In [1]:
import tkinter as tk
from tkinter import filedialog
from tkinter import Text
from PIL import Image, ImageTk
import torch
import torch.nn as nn
import clip
import cv2
from tqdm import tqdm
import numpy as np
from math import ceil
import torch.nn.functional as F

# MMoE model
class MMoE(nn.Module):
    def __init__(self, input_dim, expert_dim, num_experts, num_tasks, output_dim, num_fc_layers=2):
        super(MMoE, self).__init__()
        
        self.num_tasks = num_tasks
        self.num_experts = num_experts
        
        # Experts: A shared set of experts (fully connected layers)
        self.experts = nn.ModuleList([self._build_fc_layers(input_dim, expert_dim, num_fc_layers) for _ in range(num_experts)])
        
        # Gates: One gating network per task
        self.gates = nn.ModuleList([self._build_fc_layers(input_dim, num_experts, num_fc_layers) for _ in range(num_tasks)])
        
        # Output layers: One output layer per task
        self.output_layers = nn.ModuleList([nn.Linear(expert_dim, output_dim) for _ in range(num_tasks)])

    def _build_fc_layers(self, input_dim, output_dim, num_fc_layers):
        layers = []
        for i in range(num_fc_layers):
            in_dim = input_dim if i == 0 else output_dim
            layers.append(nn.Linear(in_dim, output_dim))
            if i < num_fc_layers - 1:
                layers.append(nn.ReLU())
        return nn.Sequential(*layers)
    
    def forward(self, x):
        # Get expert outputs
        expert_outputs = [expert(x) for expert in self.experts]
        # Stack expert outputs to shape (batch_size, num_experts, expert_dim)
        expert_outputs = torch.stack(expert_outputs, dim=1)
        
        # Get the output for each task
        outputs = []
        for i in range(self.num_tasks):
            gate_output = self.gates[i](x)  # (batch_size, num_experts)
            gate_output = F.softmax(gate_output, dim=1)  # Normalize to get the weights
            
            # Weighted sum of expert outputs for the task
            task_output = torch.sum(gate_output.unsqueeze(2) * expert_outputs, dim=1)  # (batch_size, expert_dim)
            
            # Apply task-specific output layer
            task_output = self.output_layers[i](task_output)
            outputs.append(task_output)
        
        return outputs

# Discriminator model
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),  # (64, 32, 32)
            nn.LeakyReLU(0.2),
            
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # (128, 16, 16)
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # (256, 8, 8)
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),  # (512, 4, 4)
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            
            nn.Flatten(),  # (512 * 4 * 4 = 8192)
            nn.Linear(512 * 4 * 4, 1),  # Fully connected layer to a single output
            nn.Sigmoid()  # Output: probability for binary classification
        )

    def forward(self, x):
        return self.model(x)

# Text and Image Similarity Model incorporating MMoE and Discriminator
class TextImageSimilarityModel(nn.Module):
    def __init__(self, text_image_input_dim, input_dim, embed_dim, num_fc_layers, expert_dim, num_experts, num_tasks, output_dim, num_fc_layers_mmoe):
        super(TextImageSimilarityModel, self).__init__()
        
        # Text branch
        self.text_layers = self._build_fc_layers(input_dim, embed_dim, num_fc_layers)
        
        # Image branch
        self.image_layers = self._build_fc_layers(input_dim, embed_dim, num_fc_layers)
        
        self.mmoe = MMoE(text_image_input_dim, expert_dim, num_experts, num_tasks, output_dim, num_fc_layers_mmoe)
        
        self.discriminator = Discriminator()
        
        # Final fully connected layer for output
        self.fc_layer = nn.Linear(1 + output_dim, 1)  # Adjust based on the concatenated dimensions
        
    def _build_fc_layers(self, input_dim, embed_dim, num_fc_layers):
        layers = []
        for i in range(num_fc_layers):
            in_dim = input_dim if i == 0 else embed_dim
            out_dim = embed_dim
            layers.append(nn.Linear(in_dim, out_dim))
            if i < num_fc_layers - 1:
                layers.append(nn.ReLU())
        return nn.Sequential(*layers)
    
    def forward(self, text_features, image_features, image):
        # Process text features
        text_embedding = self.text_layers(text_features)
        text_embedding = F.normalize(text_embedding, dim=-1)
        
        # Process image features
        image_embedding = self.image_layers(image_features)
        image_embedding = F.normalize(image_embedding, dim=-1)
        
        # Concatenate the text and image embeddings
        combined_features = torch.cat([text_embedding, image_embedding], dim=-1)
        
        # Pass the combined features to MMoE
        task_outputs = self.mmoe(combined_features)
        
        # For discriminator
        out_dis = self.discriminator(image)
        
        # Concatenate task_outputs[0] and discriminator output
        combined_out = torch.cat([task_outputs[0], out_dis], dim=-1)
        
        # Pass through the final FC layer
        out = self.fc_layer(combined_out)
        
        return out


class FileLoaderApp:
    def __init__(self, root):
        self.root = root
        self.root.title("Text and Image Loader with CLIP")

        # Initialize CLIP model and preprocessing
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.clip_model, self.preprocess, self.device = self.create_clip_model_and_preprocessor()

        # Frame for buttons
        button_frame = tk.Frame(root)
        button_frame.pack(pady=10)

        # Buttons to load text and image
        self.text_button = tk.Button(button_frame, text="Load Text File", command=self.load_text)
        self.text_button.grid(row=0, column=0, padx=10)

        self.image_button = tk.Button(button_frame, text="Load Image File", command=self.load_image)
        self.image_button.grid(row=0, column=1, padx=10)

        # Text area for displaying text content
        self.text_area = Text(root, height=5, width=50)
        self.text_area.pack(pady=10)
        
        

        # Buttons for feature extraction
        # self.extract_text_button = tk.Button(root, text="Extract Text Features", command=self.extract_text_features)
        # self.extract_text_button.pack(pady=5)

        # self.extract_image_button = tk.Button(root, text="Extract Image Features", command=self.extract_image_features)
        # self.extract_image_button.pack(pady=5)
        
        self.load_model_button = tk.Button(root, text="Check News", command=self.load_text_image_similarity_model)
        self.load_model_button.pack(pady=5)


        # Labels for showing extracted features
        self.text_features_label = tk.Label(root, text="Text Features: (Not Extracted)", wraplength=400, justify="left")
        self.text_features_label.pack(pady=5)

        self.image_features_label = tk.Label(root, text="Image Features: (Not Extracted)", wraplength=400, justify="left")
        self.image_features_label.pack(pady=5)
        
        self.model_data_label = tk.Label(root, text="No model", wraplength=400, justify="left")
        self.model_data_label.pack(pady=5)

        # Canvas for displaying the image
        self.image_canvas = tk.Canvas(root, width=400, height=300, bg="gray")
        self.image_canvas.pack(pady=10)

        # Label for showing image path
        self.image_label = tk.Label(root, text="No image loaded", fg="blue")
        self.image_label.pack(pady=5)

        # Variables to hold loaded text and image
        self.loaded_text = None
        self.loaded_image = None

    def create_clip_model_and_preprocessor(self):
        """Initialize the CLIP model and preprocessing pipeline for images."""
        print("Loading CLIP model...")
        model, preprocess = clip.load("ViT-B/32", device=self.device)
        print("CLIP model loaded.")
        return model, preprocess, self.device

    def load_text(self):
        """Open a file dialog to select a text file and display its content."""
        file_path = filedialog.askopenfilename(filetypes=[("Text Files", "*.txt")])
        if file_path:
            with open(file_path, "r") as file:
                self.loaded_text = file.read()
            self.text_area.delete("1.0", tk.END)  # Clear existing content
            self.text_area.insert(tk.END, self.loaded_text)  # Insert file content

    def load_image(self):
        """Open a file dialog to select an image file and display it."""
        file_path = filedialog.askopenfilename(filetypes=[("Image Files", "*.jpg;*.jpeg;*.png;*.bmp")])
        if file_path:
            self.image_label.config(text=file_path)

            # Load and display the image
            img = Image.open(file_path)
            self.loaded_image = np.array(img)  # Save NumPy array for feature extraction
            img.thumbnail((400, 400))  # Resize to fit canvas
            img = ImageTk.PhotoImage(img)

            self.image_canvas.delete("all")  # Clear existing image
            self.image_canvas.create_image(200, 200, anchor=tk.CENTER, image=img)
            self.image_canvas.image = img  # Keep a reference to avoid garbage collection



    def extract_text_features(self):
        """Extract features from the loaded text using CLIP on CUDA."""
        if not self.loaded_text:
            self.text_features_label.config(text="No text loaded to extract features!")
            return

        # Tokenize and encode the text using the CLIP model
        print("Extracting text features on CUDA...")
        with torch.no_grad():
            text_tokens = clip.tokenize([self.loaded_text]).to(self.device)  # Move to CUDA
            text_features = self.clip_model.encode_text(text_tokens)

            # Normalize the features
            text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
                    # Display the text features
            self.text_features_label.config(text="Text Features is extracted")
            return text_features
        
    def load_text_image_similarity_model(self,device='cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Load and initialize the TextImageSimilarityModel.

        Args:
            text_image_input_dim (int): Dimensionality of the input for the MMoE module.
            input_dim (int): Input dimension for the feature extraction layers.
            embed_dim (int): Dimensionality of the embeddings in feature extraction layers.
            num_fc_layers (int): Number of fully connected layers in feature extraction.
            expert_dim (int): Dimensionality of the expert layers in MMoE.
            num_experts (int): Number of experts in MMoE.
            num_tasks (int): Number of tasks for the MMoE model.
            output_dim (int): Output dimension for each task.
            num_fc_layers_mmoe (int): Number of fully connected layers in the MMoE model.
            device (str): Device to load the model on ('cpu' or 'cuda').

        Returns:
            TextImageSimilarityModel: The loaded and initialized model.
        """
        # Instantiate the model
        model = TextImageSimilarityModel(text_image_input_dim=512, input_dim=512, embed_dim=256, num_fc_layers=4, 
                                 expert_dim=64, num_experts=8, num_tasks=2, output_dim=1, num_fc_layers_mmoe=2).to(device)
        model.load_state_dict(torch.load("model_final.pth"))
        model.eval()
        
        if self.text_area.get("1.0", "end-1c").strip():  # Check if the content is not empty
            self.loaded_text = self.text_area.get("1.0", "end-1c")

        
        """Extract features from the loaded text using CLIP on CUDA."""
        if not self.loaded_text:
            self.text_features_label.config(text="No text loaded to extract features!")
            return

        # Tokenize and encode the text using the CLIP model
        test_text_features = []
        print("Extracting text features on CUDA...")
        with torch.no_grad():
            text_tokens = clip.tokenize([self.loaded_text]).to(self.device)  # Move to CUDA
            text_features = self.clip_model.encode_text(text_tokens)

            # Normalize the features
            text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
                    # Display the text features
            test_text_features = text_features.to(device, dtype=torch.float32) 
            self.text_features_label.config(text="Text Features is extracted")
        
        
        
        
        test_image_features = []
        
        """Extract features from the loaded image using CLIP on CUDA."""
        if self.loaded_image is None:
            self.image_features_label.config(text="No image loaded to extract features!")
            return

        # Process the image and extract features
        print("Extracting image features on CUDA...")
        image_features = self.extract_image_features_from_np_single(
            self.clip_model, self.loaded_image, self.preprocess, self.device
        )
        test_image_features=image_features.to(device, dtype=torch.float32) 
        

        # Display the image features
        self.image_features_label.config(text="Image Features is extracted")
        
        
        image = self.loaded_image
        
        #Resize images to 64x64 using OpenCV
        test_imagesX = np.array(cv2.resize(image, (64, 64)))
               
        test_imagesX = test_imagesX.astype(np.float32) / 255.0  # Normalize images (scale pixel values to [0, 1])
        # Convert train_imagesX to PyTorch tensor and move to device

        # Check dimensions
        print(f"Test text features shape: {test_text_features.shape}")
        print(f"Test image features shape: {test_image_features.shape}")
        print(f"Test image features shape: {test_imagesX.shape}")
        
       
        test_imagesX_tensor = torch.from_numpy(test_imagesX)
        test_imagesX_tensor = test_imagesX_tensor.unsqueeze(0).permute(0, 3, 1, 2).to(device) 
        print(f"Test image tensor shape: {test_imagesX_tensor.shape}")
        
         # Forward pass
        with torch.no_grad():
            output = model(test_text_features, test_image_features, test_imagesX_tensor)
            #print(f"Output: {output.item()}")
            probabilities = output.squeeze().cpu().numpy()  # Convert to numpy
            predictions = (probabilities >= 0.5).astype(int)  # Binary classification
            # Compare with label
            #print(f"Predicted label: {'Real' if predictions >=0.5 else 'Fake'}")

        # Compare with label
        #print(f"Actual label: {test_label[idx].item()}")
        
        # idx = 1800  # Select the index of the data point you want to test
        # # Load single data point
        # test_text_features = torch.load('clip_test_text_features.pt').to(device, dtype=torch.float32)
        # test_image_features = torch.load('clip_test_image_features.pt').to(device, dtype=torch.float32)
        # test_label = torch.load('clip_test_label.pt').to(device, dtype=torch.float32)  # Test label

        # test_imagesX = np.load('traslate_test_images_64.npy')
        # test_imagesX = test_imagesX.astype(np.float32) / 255.0  # Normalize images (scale pixel values to [0, 1])
        # # Convert train_imagesX to PyTorch tensor and move to device
        # test_imagesX_tensor = torch.from_numpy(test_imagesX).permute(0, 3, 1, 2).to(device)  # (batch_size, C, H, W)

        # idx = 1800  # Select the index of the data point you want to test
        # # Check dimensions
        # print(f"Test text features shape: {test_text_features[idx].shape}")
        # print(f"Test image features shape: {test_image_features[idx].shape}")
        # print(f"Test image tensor shape: {test_imagesX_tensor[idx].shape}")
        

        # # Forward pass
        # with torch.no_grad():
        #     output = model(test_text_features[idx].unsqueeze(0), test_image_features[idx].unsqueeze(0), test_imagesX_tensor[idx].unsqueeze(0))
        #     print(f"Output: {output.item()}")
        #     probabilities = output.squeeze().cpu().numpy()  # Convert to numpy
        #     predictions = (probabilities >= 0.5).astype(int)  # Binary classification
        #     # Compare with label
        #     print(f"Predicted label: {'Real' if predictions >=0.5 else 'Fake'}")

        # # Compare with label
        # print(f"Actual label: {test_label[idx].item()}")


        
        # Display the image features
        self.model_data_label.config(text=f"The model is loaded. The model output: {output.item()}. \nPredicted label: {'Real' if predictions >=0.5 else 'Fake'}")
        
        return model


    def extract_image_features(self):
        """Extract features from the loaded image using CLIP on CUDA."""
        if self.loaded_image is None:
            self.image_features_label.config(text="No image loaded to extract features!")
            return

        # Process the image and extract features
        print("Extracting image features on CUDA...")
        image_features = self.extract_image_features_from_np_single(
            self.clip_model, self.loaded_image, self.preprocess, self.device
        )

        # Display the image features
        self.image_features_label.config(text="Image Features is extracted")
        return image_features



    def extract_image_features_from_np_single(self, model, image, preprocess, device='cpu'):
        """
        Extract image features using CLIP for a single NumPy array image.
        Args:
            model: CLIP model.
            image: Single NumPy array (H, W, C).
            preprocess: CLIP image preprocessing pipeline.
            device: Device to run the extraction ('cpu' or 'cuda').
        Returns:
            Normalized image features as a PyTorch tensor.
        """
        # Convert NumPy array to PIL image and preprocess
        img_pil = Image.fromarray(np.uint8(image)).convert('RGB')  # Convert to PIL image
        processed_image = preprocess(img_pil).unsqueeze(0).to(device)  # Preprocess and move to device

        # Pass through the model to extract features
        with torch.no_grad():
            image_features = model.encode_image(processed_image)
            # Normalize the features
            image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)

        return image_features

# Initialize the GUI application
if __name__ == "__main__":
    root = tk.Tk()
    app = FileLoaderApp(root)
    root.mainloop()


Loading CLIP model...
CLIP model loaded.


  model.load_state_dict(torch.load("model_final.pth"))


Extracting text features on CUDA...


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)


Extracting image features on CUDA...
Test text features shape: torch.Size([1, 512])
Test image features shape: torch.Size([1, 512])
Test image features shape: (64, 64, 3)
Test image tensor shape: torch.Size([1, 3, 64, 64])
Extracting text features on CUDA...
Extracting image features on CUDA...
Test text features shape: torch.Size([1, 512])
Test image features shape: torch.Size([1, 512])
Test image features shape: (64, 64, 3)
Test image tensor shape: torch.Size([1, 3, 64, 64])
