# An overview
Siamese networks consist of two identical sub-networks that share weights and learn to compute the similarity between two input samples. The goal is to learn embeddings such that similar inputs are close in the embedding space, while dissimilar inputs are far apart. For the WikiDiverse dataset, where we have image-caption pairs, we can build a Siamese network that processes text and image data (or just one modality like text or image) and learns to compute similarity between two entities from the knowledge base.
* Siamese Network Structure: Two identical sub-networks that compute embeddings for input pairs and learn their similarity
* Application: For WikiDiverse, compute similarity between image-caption pairs to link knowledge-base entities.

# 1. Required Libraries

In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score  # Adjustment
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Dataset path
DATASET_PATH = r"C:\Users\Min Dator\aics-project\wikidiverse_w_cands\wikidiverse_w_cands"
IMAGE_DIR = os.path.join(DATASET_PATH, "images")

# Hyperparameters
BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 0.0001
EMBED_DIM = 256
# Size of the text embedding table. Captions are tokenized with the
# bert-base-uncased tokenizer, whose token ids span its 30,522-entry vocabulary;
# a smaller table makes nn.Embedding fail with "index out of range".
VOCAB_SIZE = 30522

# 2. Data Class
Define the `Dataset` class that loads each image together with its tokenized caption and label; the dataloaders are built in section 3.

In [None]:
class WikiDiverseDataset(Dataset):
    """Per-sample dataset of ``(image, caption_token_ids, label)`` triples.

    Args:
        image_paths: Image file paths, aligned index-by-index with ``captions``
            and ``labels``.
        captions: Caption strings, one per image.
        labels: Numeric labels, one per image; returned as float32 tensors.
        tokenizer: Called like a Hugging Face tokenizer; its output must
            provide ``input_ids``.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, image_paths, captions, labels, tokenizer, transform=None):
        self.image_paths = image_paths
        self.captions = captions
        self.labels = labels
        self.tokenizer = tokenizer
        self.transform = transform

    def __len__(self):
        # The label list determines how many samples there are.
        return len(self.labels)

    def __getitem__(self, idx):
        pil_image = Image.open(self.image_paths[idx]).convert("RGB")
        image = self.transform(pil_image) if self.transform else pil_image

        encoding = self.tokenizer(
            self.captions[idx],
            truncation=True,
            padding="max_length",
            max_length=100,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dimension of 1; drop it.
        token_ids = encoding["input_ids"].squeeze(0)

        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return image, token_ids, target

# 3. Datasets and Dataloaders

In [None]:
# Load data: each CSV row supplies an image filename, a caption and a label.
data_file = os.path.join(DATASET_PATH, "captions_and_labels.csv")  # Your CSV file with captions and labels
df = pd.read_csv(data_file)

# Full image paths, resolved against IMAGE_DIR.
image_paths = df["image_filename"].map(lambda name: os.path.join(IMAGE_DIR, name)).to_list()

# Captions and labels as plain Python lists aligned with image_paths.
captions = df["caption"].to_list()
labels = df["label"].to_list()

In [None]:
# Tokenizer for the captions (bert-base-uncased vocabulary).
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Hold out 20% of the samples for testing; random_state makes the split reproducible.
(
    train_paths, test_paths,
    train_captions, test_captions,
    train_labels, test_labels,
) = train_test_split(image_paths, captions, labels, test_size=0.2, random_state=42)

# Resize to 100x100, convert to a tensor, then normalize each RGB channel.
transform = transforms.Compose(
    [
        transforms.Resize((100, 100)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Wrap each split in a Dataset; only the training loader shuffles its samples.
train_dataset = WikiDiverseDataset(
    image_paths=train_paths,
    captions=train_captions,
    labels=train_labels,
    tokenizer=tokenizer,
    transform=transform,
)
test_dataset = WikiDiverseDataset(
    image_paths=test_paths,
    captions=test_captions,
    labels=test_labels,
    tokenizer=tokenizer,
    transform=transform,
)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE)

# 4. Cross Attention Module
The cross-attention mechanism allows one modality (e.g., text) to attend to another (e.g., image).

In [None]:
class CrossAttention(nn.Module):
    """Residual multi-head cross-attention followed by layer normalization.

    Tensors use the batch-first layout ``(batch, seq, embed_dim)``, which is how
    the sub-networks in this notebook build their queries (``x.unsqueeze(1)``).
    A 2-D ``(batch, embed_dim)`` key/value is treated as a length-1 sequence.

    Args:
        embed_dim: Feature size of query, key and value.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        dropout: Dropout rate inside the attention and on its output.
    """

    def __init__(self, embed_dim, num_heads=4, dropout=0.1):
        super(CrossAttention, self).__init__()
        # batch_first=True: without it nn.MultiheadAttention treats dim 0 as the
        # sequence axis, so a (batch, 1, dim) query would attend across the batch
        # (or fail when key/value have a different sequence length).
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        """Return ``LayerNorm(query + Dropout(Attention(query, key, value)))``."""
        # Promote (batch, dim) key/value to (batch, 1, dim) to match a 3-D query.
        if query.dim() == 3 and key.dim() == 2:
            key = key.unsqueeze(1)
        if query.dim() == 3 and value.dim() == 2:
            value = value.unsqueeze(1)
        attn_output, _ = self.multihead_attn(query, key, value)
        attn_output = self.dropout(attn_output)
        return self.layer_norm(query + attn_output)

# 5. Sub-Networks with Attention

In [None]:
class ImageSubNetworkWithAttention(nn.Module):
    """ResNet-50 image encoder whose projected embedding attends to text features.

    Args:
        embed_dim: Size of the projected image embedding (and of the attention).
        num_heads: Number of attention heads in the cross-attention.
    """

    def __init__(self, embed_dim=256, num_heads=4):
        super(ImageSubNetworkWithAttention, self).__init__()
        base_model = models.resnet50(pretrained=True)
        self.features = nn.Sequential(*list(base_model.children())[:-1])  # Remove FC layer
        self.fc = nn.Sequential(
            nn.Linear(2048, embed_dim),
            nn.ReLU()
        )
        self.cross_attention = CrossAttention(embed_dim, num_heads)

    def forward(self, x, text_features):
        """Encode image batch ``x`` and refine it with attention over ``text_features``.

        Args:
            x: Image batch accepted by the ResNet-50 feature stack.
            text_features: ``(batch, embed_dim)`` or ``(batch, seq, embed_dim)``
                text representation used as attention key/value.

        Returns:
            ``(batch, embed_dim)`` image embedding.
        """
        x = self.features(x)
        x = x.view(x.size(0), -1)  # Flatten pooled features to (batch, 2048)
        x = self.fc(x)
        # nn.MultiheadAttention rejects a 3-D query with 2-D key/value, so give a
        # pooled (batch, dim) text vector an explicit length-1 sequence axis.
        if text_features.dim() == 2:
            text_features = text_features.unsqueeze(1)
        x = self.cross_attention(x.unsqueeze(1), text_features, text_features).squeeze(1)
        return x

class TextSubNetworkWithAttention(nn.Module):
    """Embedding + LSTM text encoder with optional cross-attention to image features.

    Args:
        vocab_size: Number of rows in the token embedding table; must exceed the
            largest token id produced by the tokenizer.
        embed_dim: Token embedding size.
        hidden_dim: LSTM hidden size (also the attention width).
        num_heads: Number of attention heads in the cross-attention.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_heads=4):
        super(TextSubNetworkWithAttention, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.ReLU()
        )
        self.cross_attention = CrossAttention(hidden_dim, num_heads)

    def forward(self, x, image_features):
        """Encode token ids ``x`` of shape (batch, seq_len).

        Args:
            x: Token-id tensor (the LSTM is batch-first).
            image_features: Image embeddings to attend to, or ``None`` to skip
                cross-attention. The Siamese network's first, text-only pass
                passes ``None``.

        Returns:
            ``(batch, 256)`` text embedding.
        """
        x = self.embedding(x)
        _, (hidden, _) = self.lstm(x)
        hidden = hidden.squeeze(0)  # single-layer LSTM: (1, batch, hidden) -> (batch, hidden)
        if image_features is not None:
            # A pooled (batch, dim) image vector becomes a length-1 key/value sequence.
            if image_features.dim() == 2:
                image_features = image_features.unsqueeze(1)
            hidden = self.cross_attention(hidden.unsqueeze(1), image_features, image_features).squeeze(1)
        return self.fc(hidden)

# 6. Siamese Network with Cross-Attention
This combines image and text embeddings using cross-attention and returns one combined embedding per input; the similarity of a pair is measured by the distance between its two embeddings.

In [None]:
class SiameseNetworkWithCrossAttention(nn.Module):
    """Weight-shared image+text encoder applied to two inputs.

    Both inputs pass through the same ``image_net`` and ``text_net``. Encoding
    runs in three stages for each input: a text-only encoding, an image
    embedding attending to it, and a text embedding attending to that image
    embedding. Each input's output is the concatenation of its image and text
    embeddings.

    Args:
        vocab_size: Passed to the text sub-network's embedding table.
    """

    def __init__(self, vocab_size):
        super(SiameseNetworkWithCrossAttention, self).__init__()
        self.image_net = ImageSubNetworkWithAttention()
        self.text_net = TextSubNetworkWithAttention(vocab_size)

    def forward(self, img1, img2, text1, text2):
        images = (img1, img2)
        texts = (text1, text2)

        # Stage 1: text-only encodings (no image context yet).
        text_context = [self.text_net(tokens, None) for tokens in texts]
        # Stage 2: image embeddings that attend to those text encodings.
        image_embs = [self.image_net(img, ctx) for img, ctx in zip(images, text_context)]
        # Stage 3: text embeddings that attend to the image embeddings.
        text_embs = [self.text_net(tokens, emb) for tokens, emb in zip(texts, image_embs)]

        first, second = (
            torch.cat([img_emb, txt_emb], dim=1)
            for img_emb, txt_emb in zip(image_embs, text_embs)
        )
        return first, second

# 7. Loss Function
The contrastive loss function encourages similar pairs to have closer embeddings and dissimilar pairs to have distant embeddings. In this implementation, label 0 marks a similar pair and label 1 a dissimilar pair.

In [None]:
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss where label 0 = similar, label 1 = dissimilar.

    For the Euclidean distance ``d`` between the two embeddings, a similar pair
    contributes ``d**2`` and a dissimilar pair contributes
    ``max(margin - d, 0)**2``. The batch mean is returned.
    """

    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        dist = torch.nn.functional.pairwise_distance(output1, output2)
        similar_term = (1 - label) * dist.pow(2)
        hinge = (self.margin - dist).clamp(min=0.0)
        dissimilar_term = label * hinge.pow(2)
        return (similar_term + dissimilar_term).mean()

# 8. Training and Evaluation

# Training 

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SiameseNetworkWithCrossAttention(VOCAB_SIZE).to(device)
criterion = ContrastiveLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_losses = []
val_losses = []

# NOTE(review): the loaders yield single (image, caption, label) samples, so both
# Siamese branches below receive the SAME input. Their embeddings then differ only
# through dropout in train mode and are identical in eval mode. As a result the
# contrastive objective and the validation ROC AUC carry little signal until the
# dataset produces real pairs.
for epoch in range(EPOCHS):
    # Training
    model.train()
    train_loss = 0
    for img1, text1, label in train_loader:
        img1, text1, label = img1.to(device), text1.to(device), label.to(device)
        optimizer.zero_grad()
        output1, output2 = model(img1, img1, text1, text1)
        loss = criterion(output1, output2, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_losses.append(train_loss / len(train_loader))

    # Validation runs inside the epoch loop: it uses `epoch` and must run once per epoch.
    model.eval()
    val_loss = 0
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for img1, text1, label in test_loader:
            img1, text1, label = img1.to(device), text1.to(device), label.to(device)
            output1, output2 = model(img1, img1, text1, text1)
            val_loss += criterion(output1, output2, label).item()
            # The ContrastiveLoss defined for this model treats label 1 as
            # "dissimilar", so a larger distance scores the positive class.
            preds = torch.nn.functional.pairwise_distance(output1, output2)
            all_labels.extend(label.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
    val_losses.append(val_loss / len(test_loader))

    # Calculate evaluation metrics
    try:
        roc_auc = roc_auc_score(all_labels, all_preds)
    except ValueError:
        # ROC AUC is undefined when the validation labels contain a single class.
        roc_auc = float("nan")
    print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}, ROC AUC: {roc_auc:.4f}")

# Plot Loss
plt.figure()
plt.plot(range(1, EPOCHS + 1), train_losses, label="Train Loss")
plt.plot(range(1, EPOCHS + 1), val_losses, label="Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

# Evaluate

In [None]:
# NOTE(review): `evaluate_model` is only defined further down this file, so running
# the notebook top-to-bottom raises NameError here. That function also unpacks
# 7-item batches (img1, img2, text1, mask1, text2, mask2, labels), whereas this
# section's `test_loader` yields 3-item (image, input_ids, label) batches. Confirm
# the intended loader/model pairing before running. The returned metrics are not
# used afterwards.
print("\nStarting Evaluation...")
accuracy, precision, recall, f1 = evaluate_model(model, test_loader)

##### END #### 

# 2. Data Preprocessing

# Image Preprocessing
We use a standard ResNet50 model for extracting features from images. Images are resized and normalized before feeding into the model.

# Resize to 224x224, convert to a tensor, and normalize each RGB channel with the
# ImageNet mean/std commonly used for torchvision pretrained backbones.
image_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def preprocess_image(img_path):
    """Load ``img_path`` as an RGB image and return it after ``image_transforms``."""
    rgb = Image.open(img_path).convert("RGB")
    transformed = image_transforms(rgb)
    return transformed

# Text Preprocessing
We use a pretrained BERT tokenizer to tokenize and encode text inputs.

# BERT tokenizer used by preprocess_text.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def preprocess_text(text, max_length=100):
    """Tokenize ``text``, padding/truncating to ``max_length`` tokens.

    Returns:
        ``(input_ids, attention_mask)`` with the leading batch dimension of 1
        removed (1-D tensors of length ``max_length`` for a single string).
    """
    encoded = tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    ids, mask = (encoded[key].squeeze(0) for key in ('input_ids', 'attention_mask'))
    return ids, mask

# 3. Sub-Networks

# Image Embedding Network
This network uses ResNet50 to extract image features and reduces them to a fixed-size embedding.

In [None]:
class ImageEmbedding(nn.Module):
    """Pretrained ResNet-50 feature stack projected to ``embedding_dim``.

    Args:
        embedding_dim: Size of the output embedding.
    """

    def __init__(self, embedding_dim=256):
        super(ImageEmbedding, self).__init__()
        resnet = models.resnet50(pretrained=True)
        backbone_layers = list(resnet.children())
        # Keep every child except the last one (the classification head).
        self.feature_extractor = nn.Sequential(*backbone_layers[:-1])
        # 2048 = width of the flattened ResNet-50 feature vector.
        self.fc = nn.Linear(2048, embedding_dim)

    def forward(self, x):
        """Return a ``(batch, embedding_dim)`` embedding for image batch ``x``."""
        pooled = self.feature_extractor(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

# Text Embedding Network
This network uses BERT to extract text embeddings, followed by a linear layer to reduce dimensions.

In [None]:
class TextEmbedding(nn.Module):
    """BERT encoder whose pooled output is projected to ``embedding_dim``.

    Args:
        embedding_dim: Size of the output embedding.
    """

    def __init__(self, embedding_dim=256):
        super(TextEmbedding, self).__init__()
        # The notebook's active imports only bring in BertTokenizer, so import
        # BertModel here; otherwise constructing this module raises NameError.
        from transformers import BertModel

        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # 768 = hidden size of bert-base.
        self.fc = nn.Linear(768, embedding_dim)

    def forward(self, input_ids, attention_mask):
        """Return a ``(batch, embedding_dim)`` embedding for the tokenized batch."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = outputs.pooler_output  # pooled CLS token representation
        embedding = self.fc(cls_embedding)
        return embedding

# 4. Cross-Attention Mechanism
The cross-attention mechanism allows one modality (e.g., text) to attend to another (e.g., image).

In [None]:
class CrossAttention(nn.Module):
    """Batch-first multi-head attention followed by a linear projection.

    ``query``, ``key`` and ``value`` are ``(batch, seq, embed_dim)`` tensors,
    because the attention module is built with ``batch_first=True``.
    """

    def __init__(self, embed_dim, num_heads):
        super(CrossAttention, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.fc = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value):
        attended = self.multihead_attn(query, key, value)[0]  # attention weights are discarded
        return self.fc(attended)

# 5. Siamese Network with Cross-Attention
This combines image and text embeddings using cross-attention and returns one fused embedding per input; the similarity of a pair is measured by the distance between its two embeddings.

class SiameseNetworkWithCrossAttention(nn.Module):
    """Siamese encoder variant that fuses the raw image and text embeddings.

    Both inputs go through the same (weight-shared) modules; for each input the
    image and text embeddings are concatenated and projected back to
    ``embedding_dim``.

    NOTE(review): this class is redefined right after this cell with a variant
    that feeds the cross-attention output into the fusion layer; that later
    definition is the one bound to this name once both have run.
    """

    def __init__(self, embedding_dim=256, num_heads=4):
        super(SiameseNetworkWithCrossAttention, self).__init__()
        self.image_embedding = ImageEmbedding(embedding_dim)
        self.text_embedding = TextEmbedding(embedding_dim)
        # Not used in forward(); kept so parameter names (and saved state_dicts)
        # match the cross-attention variant.
        self.cross_attention = CrossAttention(embed_dim=embedding_dim, num_heads=num_heads)
        self.fc = nn.Linear(embedding_dim * 2, embedding_dim)  # Final layer after concatenation

    def forward(self, img1, img2, text1, mask1, text2, mask2):
        """Return the fused ``(batch, embedding_dim)`` embeddings of both inputs."""
        # Extract embeddings
        img_emb1 = self.image_embedding(img1)
        img_emb2 = self.image_embedding(img2)
        text_emb1 = self.text_embedding(text1, mask1)
        text_emb2 = self.text_embedding(text2, mask2)

        # This variant concatenates the raw text embedding. Cross-attention outputs
        # were computed here before but never used, so they are no longer computed.
        combined_emb1 = torch.cat((img_emb1, text_emb1), dim=1)
        combined_emb2 = torch.cat((img_emb2, text_emb2), dim=1)

        # Pass the combined embeddings through a final fully connected layer
        combined_emb1 = self.fc(combined_emb1)
        combined_emb2 = self.fc(combined_emb2)

        return combined_emb1, combined_emb2

class SiameseNetworkWithCrossAttention(nn.Module):
    """Siamese image+text encoder with text-to-image cross-attention fusion.

    Both inputs are encoded by the same (weight-shared) modules. For each input,
    the text embedding queries the image embedding via cross-attention; the
    image embedding and the attended vector are concatenated and projected back
    to ``embedding_dim``.
    """

    def __init__(self, embedding_dim=256, num_heads=4):
        super(SiameseNetworkWithCrossAttention, self).__init__()
        self.image_embedding = ImageEmbedding(embedding_dim)
        self.text_embedding = TextEmbedding(embedding_dim)
        self.cross_attention = CrossAttention(embed_dim=embedding_dim, num_heads=num_heads)
        self.fc = nn.Linear(embedding_dim * 2, embedding_dim)

    def forward(self, img1, img2, text1, mask1, text2, mask2):
        image_embs = [self.image_embedding(img) for img in (img1, img2)]
        text_embs = [
            self.text_embedding(ids, mask)
            for ids, mask in ((text1, mask1), (text2, mask2))
        ]

        fused = []
        for img_emb, text_emb in zip(image_embs, text_embs):
            # Length-1 sequences: the text vector attends to the image vector.
            img_seq = img_emb.unsqueeze(1)
            attended = self.cross_attention(text_emb.unsqueeze(1), img_seq, img_seq).squeeze(1)
            fused.append(torch.cat((img_emb, attended), dim=1))

        combined_emb1, combined_emb2 = (self.fc(pair) for pair in fused)
        return combined_emb1, combined_emb2

# 6. Loss Function
The contrastive loss function encourages similar pairs to have closer embeddings and dissimilar pairs to have distant embeddings. In this implementation, label 1 marks a similar pair and label 0 a dissimilar pair.

class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss where label 1 = similar, label 0 = dissimilar.

    With Euclidean distance ``d``, similar pairs are pulled together by ``d**2``
    and dissimilar pairs are pushed apart by ``max(margin - d, 0)**2``; the
    batch mean is returned.
    """

    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, emb1, emb2, label):
        distance = torch.nn.functional.pairwise_distance(emb1, emb2)
        pull = label * distance.square()
        push = (1 - label) * (self.margin - distance).clamp(min=0.0).square()
        return torch.mean(pull + push)

# 7. Training and Evaluation

# Training Loop

# NOTE(review): this loop expects 7-item pair batches
# (img1, img2, text1, mask1, text2, mask2, labels); the WikiDiverseDataset written
# for this draft returns 4-item single samples, so a pair-producing dataset is
# still required.
num_epochs = EPOCHS  # `num_epochs` is otherwise only defined much later in the file
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SiameseNetworkWithCrossAttention().to(device)
criterion = ContrastiveLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for img1, img2, text1, mask1, text2, mask2, labels in train_loader:
        # Keep inputs on the same device as the model (evaluate_model does the same).
        img1, img2 = img1.to(device), img2.to(device)
        text1, mask1 = text1.to(device), mask1.to(device)
        text2, mask2 = text2.to(device), mask2.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        emb1, emb2 = model(img1, img2, text1, mask1, text2, mask2)
        loss = criterion(emb1, emb2, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # Report the mean batch loss over the epoch, not just the final batch's loss.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss / max(len(train_loader), 1)}")

# Evaluation 

import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Evaluate the model
def evaluate_model(model, test_loader, threshold=0.5, device=None):
    """Score a Siamese model on pair batches by thresholding embedding distance.

    A pair is predicted similar (1) when the Euclidean distance between its two
    embeddings is below ``threshold``. This matches a label convention of
    1 = similar, which is the convention of the second ContrastiveLoss in this file.

    Args:
        model: Module called as ``model(img1, img2, text1, mask1, text2, mask2)``
            and returning two embedding tensors.
        test_loader: Iterable of 7-item batches
            ``(img1, img2, text1, mask1, text2, mask2, labels)``.
        threshold: Distance below which a pair is predicted similar.
        device: Device for the inputs. Defaults to the device of the model's
            first parameter (CPU for a parameterless model).

    Returns:
        ``(accuracy, precision, recall, f1)``.
    """
    if device is None:
        # Infer the device from the model rather than relying on a global
        # `device` (this section only defines one in a commented-out line).
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    all_predictions = []
    all_labels = []

    with torch.no_grad():  # Disable gradient computation for evaluation
        for img1, img2, text1, mask1, text2, mask2, labels in test_loader:
            # Move data to the same device as the model
            img1, img2 = img1.to(device), img2.to(device)
            text1, mask1 = text1.to(device), mask1.to(device)
            text2, mask2 = text2.to(device), mask2.to(device)
            labels = labels.to(device)

            emb1, emb2 = model(img1, img2, text1, mask1, text2, mask2)
            euclidean_distance = torch.nn.functional.pairwise_distance(emb1, emb2)

            # Close pairs are predicted similar (1), distant pairs dissimilar (0).
            predictions = (euclidean_distance < threshold).long()

            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_predictions)
    # zero_division=0 returns the same 0.0 that scikit-learn reports by default
    # for an undefined ratio, without emitting a warning.
    precision = precision_score(all_labels, all_predictions, zero_division=0)
    recall = recall_score(all_labels, all_predictions, zero_division=0)
    f1 = f1_score(all_labels, all_predictions, zero_division=0)

    print(f"Evaluation Metrics:\n")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")

    return accuracy, precision, recall, f1

# Assuming `test_loader` is your DataLoader for the test dataset
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#model = model.to(device)  # Move model to appropriate device

# Evaluate the model
#evaluate_model(model, test_loader)

# 8. DataLoader 
Prepare the dataset and dataloader for training and evaluation.

class WikiDiverseDataset(Dataset):
    """Dataset yielding ``(image, input_ids, attention_mask, label)`` samples.

    Args:
        image_paths: Image file paths, aligned index-by-index with ``texts`` and
            ``labels``.
        texts: Raw text for each sample.
        labels: Label for each sample, returned unchanged.
        tokenizer: Called like a Hugging Face tokenizer; its output must provide
            ``input_ids`` and ``attention_mask``.
        transform: Callable applied to the RGB PIL image, or ``None`` to return
            the untransformed image.
        max_length: Padded/truncated token length (default 100, matching
            ``preprocess_text``).
    """

    def __init__(self, image_paths, texts, labels, tokenizer, transform, max_length=100):
        self.image_paths = image_paths
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Use this instance's tokenizer/transform (rather than module-level
        # globals) so the constructor arguments actually take effect.
        img = Image.open(self.image_paths[idx]).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        encoded = self.tokenizer(
            self.texts[idx],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Drop the leading batch dimension of 1 added by return_tensors="pt".
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)

        return img, input_ids, attention_mask, self.labels[idx]

# Usage
# NOTE(review): `train_image_paths` and `train_texts` are not defined anywhere in
# this file (the first draft's split produced `train_paths` and `train_captions`),
# so this raises NameError as written. Confirm which split variables were intended.
train_dataset = WikiDiverseDataset(train_image_paths, train_texts, train_labels, tokenizer, image_transforms)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

  #### END #####

import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from PIL import Image
from torchvision import transforms

# Dataset paths: the dataset root plus one sub-directory per split.
DATASET_PATH = r"C:\Users\Min Dator\aics-project\wikidiverse_w_cands\wikidiverse_w_cands"
TRAIN_PATH, VALID_PATH, TEST_PATH = (
    os.path.join(DATASET_PATH, split_dir)
    for split_dir in ("train_w_10cands", "valid_w_10cands", "test_w_10cands")
)

# Fail fast (in train/valid/test order) if any split directory is missing.
for path in (TRAIN_PATH, VALID_PATH, TEST_PATH):
    if os.path.exists(path):
        continue
    raise FileNotFoundError(f"Dataset directory not found: {path}")
print("All dataset directories found.")

# Hyperparameters
BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 0.0001
EMBED_DIM = 256
# Vocabulary size of the bert-base-uncased tokenizer loaded below (token ids
# 0..30521); 10000 is smaller than the ids that tokenizer produces.
VOCAB_SIZE = 30522

# Dataset class
class WikiDiverseDataset(Dataset):
    """Dataset built from a directory of ``<stem>.jpg`` / ``<stem>.txt`` / ``<stem>.label`` files.

    Only images that have both a caption file and a label file next to them are
    included. Samples are ordered by file name.

    Args:
        data_dir: Directory holding the image, caption and label files.
        tokenizer: Called like a Hugging Face tokenizer; its output must provide
            ``input_ids``.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, data_dir, tokenizer, transform=None):
        self.data_dir = data_dir
        self.tokenizer = tokenizer
        self.transform = transform
        self.image_paths = []
        self.captions = []
        self.labels = []

        # Load dataset from the directory
        self._load_data()

    def _load_data(self):
        """Collect every ``.jpg`` that has matching ``.txt`` and ``.label`` files."""
        suffix = ".jpg"
        # os.listdir returns entries in arbitrary order; sort for a reproducible
        # sample order.
        for file_name in sorted(os.listdir(self.data_dir)):
            if not file_name.endswith(suffix):
                continue
            # Build sibling paths from the file stem; str.replace on the full path
            # would also rewrite any ".jpg" appearing in the directory names.
            stem = file_name[: -len(suffix)]
            image_path = os.path.join(self.data_dir, file_name)
            caption_path = os.path.join(self.data_dir, stem + ".txt")
            label_path = os.path.join(self.data_dir, stem + ".label")

            if not (os.path.exists(caption_path) and os.path.exists(label_path)):
                continue

            with open(caption_path, "r", encoding="utf-8") as f:
                caption = f.read().strip()
            with open(label_path, "r", encoding="utf-8") as f:
                label = int(f.read().strip())

            # Append only after both files parsed, so the three lists stay aligned.
            self.image_paths.append(image_path)
            self.captions.append(caption)
            self.labels.append(label)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Load and preprocess image
        image = Image.open(self.image_paths[idx]).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Tokenize text; drop the leading batch dimension added by return_tensors="pt".
        text = self.tokenizer(self.captions[idx], truncation=True, padding="max_length", max_length=100, return_tensors="pt")
        text_input = text["input_ids"].squeeze(0)

        return image, text_input, torch.tensor(self.labels[idx], dtype=torch.float32)

# Tokenizer shared by all three splits.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Image preprocessing: 100x100 resize, tensor conversion, per-channel normalization.
transform = transforms.Compose(
    [
        transforms.Resize((100, 100)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# One dataset per split directory; only the training loader shuffles.
train_dataset = WikiDiverseDataset(data_dir=TRAIN_PATH, tokenizer=tokenizer, transform=transform)
valid_dataset = WikiDiverseDataset(data_dir=VALID_PATH, tokenizer=tokenizer, transform=transform)
test_dataset = WikiDiverseDataset(data_dir=TEST_PATH, tokenizer=tokenizer, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE)

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(valid_dataset)}")
print(f"Test samples: {len(test_dataset)}")

import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from PIL import Image
from torchvision import transforms

# File paths: entity description and entity image tables, plus the dataset root.
DESC_FILE = r"C:\Users\Min Dator\aics-project\wikipedia_entity2desc_filtered.tsv"
IMGS_FILE = r"C:\Users\Min Dator\aics-project\wikipedia_entity2imgs.tsv"
DATASET_PATH = r"C:\Users\Min Dator\aics-project\wikidiverse_w_cands\wikidiverse_w_cands"

# Fail fast if either TSV is missing (the dataset root is not checked here).
for file in (DESC_FILE, IMGS_FILE):
    if os.path.exists(file):
        continue
    raise FileNotFoundError(f"File not found: {file}")
print("All files found.")

# Load the two headerless TSV tables, naming their columns explicitly.
desc_df = pd.read_csv(DESC_FILE, sep="\t", header=None, names=["entity_id", "description"])
imgs_df = pd.read_csv(IMGS_FILE, sep="\t", header=None, names=["entity_id", "image_path"])

# Inner-join image rows with descriptions on entity id, then make image paths
# relative to DATASET_PATH into full paths.
data_df = imgs_df.merge(desc_df, on="entity_id")
data_df["image_path"] = data_df["image_path"].apply(lambda rel: os.path.join(DATASET_PATH, rel))

# Keep only rows whose image file exists on disk.
data_df = data_df.loc[data_df["image_path"].apply(os.path.exists)]
print(f"Dataset loaded with {len(data_df)} samples.")

# Hyperparameters
BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 0.0001
EMBED_DIM = 256
# Vocabulary size of the bert-base-uncased tokenizer used in this section
# (token ids 0..30521); 10000 is smaller than the ids it produces.
VOCAB_SIZE = 30522

# Dataset class
class WikiDiverseDataset(Dataset):
    """Dataset over a DataFrame with ``image_path`` and ``description`` columns.

    Each sample is ``(image, input_ids, label)``. When the frame has a ``label``
    column its value is used; otherwise (or when it is missing) every sample
    gets the placeholder label 1.0.

    Args:
        dataframe: Rows indexed positionally via ``iloc``.
        tokenizer: Called like a Hugging Face tokenizer; its output must provide
            ``input_ids``.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, dataframe, tokenizer, transform=None):
        self.data = dataframe
        self.tokenizer = tokenizer
        self.transform = transform

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _label_for(row):
        """Return the row's ``label`` as a float32 tensor, defaulting to 1.0."""
        value = row.get("label", 1.0)
        if pd.isna(value):
            value = 1.0
        return torch.tensor(float(value), dtype=torch.float32)

    @staticmethod
    def _text_for(row):
        """Return the row's description as a string; missing (NaN) becomes ""."""
        # pandas.read_csv turns empty fields into NaN, which the tokenizer rejects.
        description = row["description"]
        return "" if pd.isna(description) else str(description)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Load and preprocess image
        image = Image.open(row["image_path"]).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Tokenize text; drop the leading batch dimension added by return_tensors="pt".
        text = self.tokenizer(self._text_for(row), truncation=True, padding="max_length", max_length=100, return_tensors="pt")
        text_input = text["input_ids"].squeeze(0)

        return image, text_input, self._label_for(row)

# Tokenizer for the entity descriptions.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Image preprocessing: 100x100 resize, tensor conversion, per-channel normalization.
transform = transforms.Compose(
    [
        transforms.Resize((100, 100)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Split into 70% train / 15% validation / 15% test, in the current row order
# (no shuffling).
train_df, valid_df, test_df = (
    data_df.iloc[start:stop]
    for start, stop in (
        (None, int(0.7 * len(data_df))),
        (int(0.7 * len(data_df)), int(0.85 * len(data_df))),
        (int(0.85 * len(data_df)), None),
    )
)

# One dataset per split; only the training loader shuffles.
train_dataset = WikiDiverseDataset(dataframe=train_df, tokenizer=tokenizer, transform=transform)
valid_dataset = WikiDiverseDataset(dataframe=valid_df, tokenizer=tokenizer, transform=transform)
test_dataset = WikiDiverseDataset(dataframe=test_df, tokenizer=tokenizer, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE)

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(valid_dataset)}")
print(f"Test samples: {len(test_dataset)}")

### END ###

# 1. Overview
This implementation is a multimodal entity linking system that leverages both text and image information to perform entity disambiguation. The model uses BERT for text encoding and CLIP for image encoding. A cross-attention mechanism fuses text and image features, and a disambiguation head predicts the correct entity from a list of candidates. The system includes training and validation loops to optimize the model.
# Key Components
- Text Encoder: BERT-based textual feature extraction
- Image Encoder: CLIP-based image feature extraction
- Cross-Attention: Fuse text and image representations
- Entity Disambiguation Head: Predict the correct entity
- Training and Validation: Dataset handling, loss computation, and optimizer setup
# Data Requirements (each sample includes)
- 'text': The textual description or context
- 'image_path': Path to the corresponding image file
- 'label': Index of the correct entity in the candidate list

# 2. Required Libraries 

In [None]:
#import torch
#import torch.nn as nn
#import torch.optim as optim
#from transformers import BertTokenizer, BertModel, CLIPProcessor, CLIPModel, get_scheduler
#import torch.nn.functional as F
#from PIL import Image
#from tqdm import tqdm
#from torch.utils.data import Dataset, DataLoader

# 3. Textual Encoder (BERT-based) and Image Encoder (CLIP)

In [None]:
# Initialize BERT tokenizer and model for textual encoding
#tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#bert_model = BertModel.from_pretrained('bert-base-uncased')

In [3]:
# Initialize CLIP processor and model for image encoding
#clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
#clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")

# 4. Multimodal Attention Layer

In [4]:
#class CrossAttentionLayer(nn.Module):
#    def __init__(self, hidden_size, num_attention_heads):
#        super(CrossAttentionLayer, self).__init__()
#        self.attention = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=num_attention_heads)
#        self.fc = nn.Linear(hidden_size, hidden_size)
#
#    def forward(self, text_features, image_features):
#        text_features = text_features.unsqueeze(0)
#        image_features = image_features.unsqueeze(0)
#
#        # Perform attention between text and image features
#        attn_output_text, _ = self.attention(text_features, image_features, image_features)
#        attn_output_image, _ = self.attention(image_features, text_features, text_features)
#        
#        combined_output = attn_output_text + attn_output_image
#        combined_output = self.fc(combined_output)
#        return combined_output

# 5. Entity Disambiguation

In [5]:
#class EntityDisambiguationHead(nn.Module):
 #   def __init__(self, hidden_size, num_candidates):
  #      super(EntityDisambiguationHead, self).__init__()
   #     self.fc1 = nn.Linear(hidden_size, 512)
    #    self.fc2 = nn.Linear(512, num_candidates)
     #   self.softmax = nn.Softmax(dim=-1)

    #def forward(self, features):
     #   x = F.relu(self.fc1(features))
      #  x = self.fc2(x)
       # return self.softmax(x)

# 6. Multimodal Entity Linking Model

In [6]:
#class MultimodalEntityLinkingModel(nn.Module):
#    def __init__(self, hidden_size, num_attention_heads, num_candidates):
#        super(MultimodalEntityLinkingModel, self).__init__()
#        self.text_encoder = bert_model
#        self.image_encoder = clip_model
#        self.cross_attention_layer = CrossAttentionLayer(hidden_size, num_attention_heads)
#        self.disambiguation_head = EntityDisambiguationHead(hidden_size, num_candidates)
#
#    def forward(self, text_input, image_input):
#        # Textual feature extraction using BERT
#        encoded_input = tokenizer(text_input, return_tensors='pt', padding=True, truncation=True)
#        text_output = self.text_encoder(**encoded_input).last_hidden_state  # shape: (batch_size, seq_len, hidden_size)
#        
#        # Image feature extraction using CLIP
#        inputs = clip_processor(text=[text_input], images=image_input, return_tensors="pt", padding=True)
#        outputs = self.image_encoder(**inputs)
#        image_features = outputs.image_embeds  # shape: (batch_size, embedding_dim)
#
#        # Apply cross-attention to fuse text and image features
#        combined_features = self.cross_attention_layer(text_output, image_features)

#        # Disambiguate and predict the entity
#        entity_scores = self.disambiguation_head(combined_features.squeeze(0))  # Removing batch dimension
#        return entity_scores

# 7. Dataset class

In [None]:
#class WikiDiverseDataset(Dataset):
#    def __init__(self, data, tokenizer, clip_processor):
#        self.data = data
#        self.tokenizer = tokenizer
#        self.clip_processor = clip_processor
#
#    def __len__(self):
#        return len(self.data)
#
#    def __getitem__(self, idx):
#        sample = self.data[idx]
#        text = sample['text']
#        image_path = sample['image_path']
#        label = sample['label']
#
#        # Process text
#        encoded_text = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
#
        # Process image
 #       image = Image.open(image_path).convert("RGB")
 #       processed_image = self.clip_processor(images=image, return_tensors="pt").pixel_values.squeeze(0)

 #       return encoded_text, processed_image, label

# 8. Training and Validation Function 

def _progress(iterable, desc):
    """Wrap ``iterable`` in a tqdm progress bar when tqdm is installed."""
    try:
        # tqdm is only imported in a commented-out cell, so import it here.
        from tqdm import tqdm
    except ImportError:  # progress bars are optional; iterate plainly without them
        return iterable
    return tqdm(iterable, desc=desc)


def _batch_to_device(batch, device):
    """Unpack an ``(encoded_text, processed_image, labels)`` batch onto ``device``.

    NOTE(review): ``squeeze(0)`` only removes a leading dimension of size 1. If each
    sample's tokenizer output is (1, seq_len), collated values are
    (batch, 1, seq_len) and ``squeeze(1)`` may be what is intended; confirm
    against the dataset.
    """
    encoded_text, processed_image, labels = batch
    encoded_text = {key: val.squeeze(0).to(device) for key, val in encoded_text.items()}
    return encoded_text, processed_image.to(device), labels.to(device)


def _epoch_averages(total_loss, total_correct, total_samples):
    """Return per-sample ``(mean_loss, accuracy)``; raise if nothing was seen."""
    if total_samples == 0:
        raise ValueError("dataloader yielded no samples")
    return total_loss / total_samples, total_correct / total_samples


def train_model(model, dataloader, optimizer, criterion, device):
    """Run one training epoch of ``model`` over ``dataloader``.

    Loss and accuracy are weighted by batch size, so a smaller final batch does
    not skew the epoch averages. This assumes ``criterion`` returns a per-batch
    mean, as ``nn.CrossEntropyLoss`` does by default.

    Returns:
        ``(mean_loss, accuracy)`` over all samples in the epoch.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for batch in _progress(dataloader, "Training"):
        encoded_text, processed_image, labels = _batch_to_device(batch, device)

        optimizer.zero_grad()
        entity_scores = model(encoded_text, processed_image)
        loss = criterion(entity_scores, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        predictions = torch.argmax(entity_scores, dim=1)
        total_correct += (predictions == labels).sum().item()
        total_loss += loss.item() * batch_size
        total_samples += batch_size

    return _epoch_averages(total_loss, total_correct, total_samples)


def validate_model(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Returns:
        ``(mean_loss, accuracy)`` weighted by batch size, as in ``train_model``.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():
        for batch in _progress(dataloader, "Validation"):
            encoded_text, processed_image, labels = _batch_to_device(batch, device)

            entity_scores = model(encoded_text, processed_image)
            loss = criterion(entity_scores, labels)

            batch_size = labels.size(0)
            predictions = torch.argmax(entity_scores, dim=1)
            total_correct += (predictions == labels).sum().item()
            total_loss += loss.item() * batch_size
            total_samples += batch_size

    return _epoch_averages(total_loss, total_correct, total_samples)

# 9. Training Loop

hidden_size = 768
num_attention_heads = 8
num_candidates = 10
num_epochs = 5
batch_size = 16
learning_rate = 2e-5

# NOTE(review): train_data, val_data, clip_processor, MultimodalEntityLinkingModel
# and get_scheduler are only defined/imported in commented-out cells above, so this
# cell raises NameError until those definitions are restored.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_dataset = WikiDiverseDataset(train_data, tokenizer, clip_processor)
val_dataset = WikiDiverseDataset(val_data, tokenizer, clip_processor)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

model = MultimodalEntityLinkingModel(hidden_size, num_attention_heads, num_candidates).to(device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# The scheduler is stepped once per epoch (end of the loop below), so its horizon
# must be counted in epochs. Sizing it as num_epochs * len(train_dataloader) would
# leave the learning rate almost undecayed after num_epochs steps.
num_training_steps = num_epochs
lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")

    train_loss, train_acc = train_model(model, train_dataloader, optimizer, criterion, device)
    val_loss, val_acc = validate_model(model, val_dataloader, criterion, device)

    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")
    lr_scheduler.step()