# 1. Required Libraries

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from PIL import Image

# 2. Data Preprocessing

# Image Preprocessing
We use a standard ImageNet-pretrained ResNet50 model to extract features from images. Images are resized to 224×224 and normalized with the ImageNet channel mean and standard deviation before being fed into the model.

In [3]:
# Resize to the ResNet input size and normalize with ImageNet channel statistics.
image_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def preprocess_image(img_path):
    """Load an image file and return it as a normalized ``(3, 224, 224)`` tensor.

    Args:
        img_path: Path to an image file readable by PIL.

    Returns:
        The RGB image after ``image_transforms`` (resize, tensor conversion,
        ImageNet normalization).
    """
    # Image.open is lazy and keeps the underlying file handle open; convert()
    # produces an independent in-memory image, so the file can be closed here.
    with Image.open(img_path) as image:
        rgb = image.convert("RGB")
    return image_transforms(rgb)

# Text Preprocessing
We use a pretrained BERT tokenizer to tokenize and encode text inputs.

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')


def preprocess_text(text, max_length=100):
    """Encode ``text`` into fixed-length BERT inputs.

    The text is padded/truncated to exactly ``max_length`` tokens.

    Returns:
        A ``(input_ids, attention_mask)`` pair, each with the leading batch
        dimension of size 1 removed.
    """
    batch = tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    ids = batch['input_ids']
    mask = batch['attention_mask']
    return ids.squeeze(0), mask.squeeze(0)

# 3. Sub-Networks

# Image Embedding Network
This network uses ResNet50 to extract image features and reduces them to a fixed-size embedding.

In [None]:
class ImageEmbedding(nn.Module):
    """Encode images into fixed-size embeddings with a pretrained ResNet50.

    The ResNet50 classification head is removed; the globally pooled features
    are projected to ``embedding_dim`` by a linear layer.

    Args:
        embedding_dim: Size of the output embedding.
    """

    def __init__(self, embedding_dim=256):
        super(ImageEmbedding, self).__init__()
        # ``pretrained=True`` is deprecated since torchvision 0.13. IMAGENET1K_V1
        # is the checkpoint it used to load; fall back on older torchvision.
        weights_enum = getattr(models, "ResNet50_Weights", None)
        if weights_enum is not None:
            base_model = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
        else:
            base_model = models.resnet50(pretrained=True)
        in_features = base_model.fc.in_features  # 2048 for ResNet50
        # Drop the final classification layer, keeping the global average pool.
        self.feature_extractor = nn.Sequential(*list(base_model.children())[:-1])
        self.fc = nn.Linear(in_features, embedding_dim)

    def forward(self, x):
        """Return ``(batch, embedding_dim)`` embeddings for image batch ``x``."""
        features = self.feature_extractor(x)
        # Pooled features are (batch, C, 1, 1); flatten everything after batch.
        features = torch.flatten(features, 1)
        return self.fc(features)

# Text Embedding Network
This network uses BERT to extract text embeddings, followed by a linear layer to reduce dimensions.

In [None]:
class TextEmbedding(nn.Module):
    """Project BERT's pooled sentence representation down to ``embedding_dim``.

    Args:
        embedding_dim: Size of the output embedding.
    """

    def __init__(self, embedding_dim=256):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(768, embedding_dim)  # 768 = bert-base hidden size

    def forward(self, input_ids, attention_mask):
        """Return ``(batch, embedding_dim)`` text embeddings."""
        # pooler_output is the final [CLS] hidden state passed through BERT's
        # dense + tanh pooler layer, not the raw [CLS] vector.
        pooled = self.bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        return self.fc(pooled)

# 4. Cross-Attention Mechanism
The cross-attention mechanism allows one modality (e.g., text) to attend to another (e.g., image).

In [None]:
class CrossAttention(nn.Module):
    """Batch-first multi-head attention followed by a linear output projection.

    Args:
        embed_dim: Feature size of query/key/value and of the output.
        num_heads: Number of attention heads; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.fc = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value):
        """Attend from ``query`` onto ``key``/``value``; attention weights are discarded."""
        attended = self.multihead_attn(query, key, value)[0]
        return self.fc(attended)

# 5. Siamese Network with Cross-Attention
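This combines image and text embeddings using cross-attention and returns one fused embedding for each input of a pair; the similarity between the two embeddings is measured by the contrastive loss below.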

In [None]:
class SiameseNetworkWithCrossAttention(nn.Module):
    """Siamese encoder producing one fused image+text embedding per input.

    Both branches share the same image encoder, text encoder, cross-attention
    and projection weights.

    Args:
        embedding_dim: Size of the per-modality and fused embeddings.
        num_heads: Number of attention heads; must divide ``embedding_dim``.
    """

    def __init__(self, embedding_dim=256, num_heads=4):
        super(SiameseNetworkWithCrossAttention, self).__init__()
        self.image_embedding = ImageEmbedding(embedding_dim)
        self.text_embedding = TextEmbedding(embedding_dim)
        self.cross_attention = CrossAttention(embed_dim=embedding_dim, num_heads=num_heads)
        self.fc = nn.Linear(embedding_dim * 2, embedding_dim)

    def _encode(self, img, text, mask):
        """Return the fused ``(batch, embedding_dim)`` embedding for one branch."""
        img_emb = self.image_embedding(img)
        text_emb = self.text_embedding(text, mask)

        # Text queries the image. Both sides are a single token, so the
        # softmax over one key is identically 1 and the attention output does
        # not depend on the query at all. Without the residual text term the
        # text embedding would never reach the output and BERT would receive
        # zero gradient.
        # TODO(review): attend over spatial image tokens for query-dependent weights.
        query = text_emb.unsqueeze(1)
        context = img_emb.unsqueeze(1)
        img_text_emb = text_emb + self.cross_attention(query, context, context).squeeze(1)

        combined = torch.cat((img_emb, img_text_emb), dim=1)
        return self.fc(combined)

    def forward(self, img1, img2, text1, mask1, text2, mask2):
        """Encode both inputs of a pair and return ``(emb1, emb2)``."""
        return self._encode(img1, text1, mask1), self._encode(img2, text2, mask2)

# 6. Loss Function
The contrastive loss function encourages similar pairs to have closer embeddings and dissimilar pairs to have distant embeddings.

In [None]:
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss over pairs of embeddings.

    ``label`` is 1 for similar pairs, which are penalized by their squared
    distance. It is 0 for dissimilar pairs, which are penalized only while
    closer than ``margin``. The per-pair losses are averaged.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, emb1, emb2, label):
        dist = torch.nn.functional.pairwise_distance(emb1, emb2)
        pull = label * dist.pow(2)
        push = (1 - label) * (self.margin - dist).clamp(min=0.0).pow(2)
        return (pull + push).mean()

# 7. Training and Evaluation

# Training Loop

In [None]:
model = SiameseNetworkWithCrossAttention()
criterion = ContrastiveLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# ``num_epochs`` is otherwise only assigned in a much later cell; define it
# here so this cell runs on its own.
num_epochs = 5

# NOTE(review): this loop expects ``train_loader`` to yield paired samples
# (img1, img2, text1, mask1, text2, mask2, labels). The ``WikiDiverseDataset``
# defined below yields single (img, input_ids, attention_mask, label) samples,
# so a pair-producing dataset is still required — confirm.
for epoch in range(num_epochs):
    model.train()
    epoch_loss, num_batches = 0.0, 0
    for img1, img2, text1, mask1, text2, mask2, labels in train_loader:
        optimizer.zero_grad()
        emb1, emb2 = model(img1, img2, text1, mask1, text2, mask2)
        loss = criterion(emb1, emb2, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean over the epoch rather than just the last batch's loss.
    mean_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss}")

# Evaluation 

In [None]:
# Pairs closer than this distance are predicted "similar" (label 1, the label
# ContrastiveLoss pulls together). Half the default margin is only a starting
# point — TODO tune on a validation split.
distance_threshold = 0.5

model.eval()
correct, total = 0, 0
with torch.no_grad():
    for img1, img2, text1, mask1, text2, mask2, labels in test_loader:
        emb1, emb2 = model(img1, img2, text1, mask1, text2, mask2)
        euclidean_distance = torch.nn.functional.pairwise_distance(emb1, emb2)
        predictions = (euclidean_distance < distance_threshold).long()
        correct += (predictions == labels.long()).sum().item()
        total += labels.size(0)
        # Precision/recall can be derived from `predictions` in the same way.

print(f"Test accuracy: {correct / max(total, 1):.4f}")

# 8. DataLoader Example
Prepare the dataset and dataloader for training and evaluation.

In [None]:
class WikiDiverseDataset(Dataset):
    """Single-sample image+text dataset.

    Each item is ``(image_tensor, input_ids, attention_mask, label)``.

    Args:
        image_paths: Sequence of image file paths.
        texts: Sequence of text strings, aligned with ``image_paths``.
        labels: Sequence of labels, aligned with ``image_paths``.
        tokenizer: Hugging Face-style tokenizer callable used to encode texts.
        transform: Callable mapping a PIL RGB image to a tensor.
        max_length: Fixed token length (pad/truncate). Defaults to 100, the
            same as ``preprocess_text``.

    Raises:
        ValueError: If the three input sequences differ in length.
    """

    def __init__(self, image_paths, texts, labels, tokenizer, transform, max_length=100):
        if not len(image_paths) == len(texts) == len(labels):
            raise ValueError(
                f"image_paths, texts and labels must have equal length; got "
                f"{len(image_paths)}, {len(texts)}, {len(labels)}"
            )
        self.image_paths = image_paths
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Use the injected transform/tokenizer rather than module-level globals,
        # and close the image file once it has been converted.
        with Image.open(self.image_paths[idx]) as raw:
            img = self.transform(raw.convert("RGB"))

        encoded = self.tokenizer(
            self.texts[idx],
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Drop the size-1 batch dimension added by return_tensors="pt".
        input_ids = encoded['input_ids'].squeeze(0)
        attention_mask = encoded['attention_mask'].squeeze(0)

        return img, input_ids, attention_mask, self.labels[idx]

# Usage
# Requires train_image_paths / train_texts / train_labels to be defined.
# NOTE(review): items are single (img, ids, mask, label) samples, while the
# training loop above unpacks paired 7-tuples — confirm the intended dataset.
train_dataset = WikiDiverseDataset(
    image_paths=train_image_paths,
    texts=train_texts,
    labels=train_labels,
    tokenizer=tokenizer,
    transform=image_transforms,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

  #### END #####

# 1. Overview
This implementation is a multimodal entity linking system that leverages both text and image information to perform entity disambiguation. The model uses BERT for text encoding and CLIP for image encoding. A cross-attention mechanism fuses text and image features, and a disambiguation head predicts the correct entity from a list of candidates. The system includes training and validation loops to optimize the model.
# Key Components
- Text Encoder: BERT-based textual feature extraction
- Image Encoder: CLIP-based image feature extraction
- Cross-Attention: Fuse text and image representations
- Entity Disambiguation Head: Predict the correct entity
- Training and Validation: Dataset handling, loss computation, and optimizer setup
# Data Requirements (each sample in the data includes)
- 'text': The textual description or context
- 'image_path': Path to the corresponding image file
- 'label': Index of the correct entity in the candidate list

# 2. Required Libraries 

In [None]:
#import torch
#import torch.nn as nn
#import torch.optim as optim
#from transformers import BertTokenizer, BertModel, CLIPProcessor, CLIPModel, get_scheduler
#import torch.nn.functional as F
#from PIL import Image
#from tqdm import tqdm
#from torch.utils.data import Dataset, DataLoader

# 3. Textual Encoder (BERT-based) and Image Encoder (CLIP)

In [None]:
# Initialize BERT tokenizer and model for textual encoding
#tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
#bert_model = BertModel.from_pretrained('bert-base-uncased')

In [3]:
# Initialize CLIP processor and model for image encoding
#clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
#clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")

# 4. Multimodal Attention Layer

In [4]:
#class CrossAttentionLayer(nn.Module):
#    def __init__(self, hidden_size, num_attention_heads):
#        super(CrossAttentionLayer, self).__init__()
#        self.attention = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=num_attention_heads)
#        self.fc = nn.Linear(hidden_size, hidden_size)
#
#    def forward(self, text_features, image_features):
#        text_features = text_features.unsqueeze(0)
#        image_features = image_features.unsqueeze(0)
#
#        # Perform attention between text and image features
#        attn_output_text, _ = self.attention(text_features, image_features, image_features)
#        attn_output_image, _ = self.attention(image_features, text_features, text_features)
#        
#        combined_output = attn_output_text + attn_output_image
#        combined_output = self.fc(combined_output)
#        return combined_output

# 5. Entity Disambiguation

In [5]:
#class EntityDisambiguationHead(nn.Module):
 #   def __init__(self, hidden_size, num_candidates):
  #      super(EntityDisambiguationHead, self).__init__()
   #     self.fc1 = nn.Linear(hidden_size, 512)
    #    self.fc2 = nn.Linear(512, num_candidates)
     #   self.softmax = nn.Softmax(dim=-1)

    #def forward(self, features):
     #   x = F.relu(self.fc1(features))
      #  x = self.fc2(x)
       # return self.softmax(x)

# 6. Multimodal Entity Linking Model

In [6]:
#class MultimodalEntityLinkingModel(nn.Module):
#    def __init__(self, hidden_size, num_attention_heads, num_candidates):
#        super(MultimodalEntityLinkingModel, self).__init__()
#        self.text_encoder = bert_model
#        self.image_encoder = clip_model
#        self.cross_attention_layer = CrossAttentionLayer(hidden_size, num_attention_heads)
#        self.disambiguation_head = EntityDisambiguationHead(hidden_size, num_candidates)
#
#    def forward(self, text_input, image_input):
#        # Textual feature extraction using BERT
#        encoded_input = tokenizer(text_input, return_tensors='pt', padding=True, truncation=True)
#        text_output = self.text_encoder(**encoded_input).last_hidden_state  # shape: (batch_size, seq_len, hidden_size)
#        
#        # Image feature extraction using CLIP
#        inputs = clip_processor(text=[text_input], images=image_input, return_tensors="pt", padding=True)
#        outputs = self.image_encoder(**inputs)
#        image_features = outputs.image_embeds  # shape: (batch_size, embedding_dim)
#
#        # Apply cross-attention to fuse text and image features
#        combined_features = self.cross_attention_layer(text_output, image_features)

#        # Disambiguate and predict the entity
#        entity_scores = self.disambiguation_head(combined_features.squeeze(0))  # Removing batch dimension
#        return entity_scores

# 7. Dataset class

In [None]:
#class WikiDiverseDataset(Dataset):
#    def __init__(self, data, tokenizer, clip_processor):
#        self.data = data
#        self.tokenizer = tokenizer
#        self.clip_processor = clip_processor
#
#    def __len__(self):
#        return len(self.data)
#
#    def __getitem__(self, idx):
#        sample = self.data[idx]
#        text = sample['text']
#        image_path = sample['image_path']
#        label = sample['label']
#
#        # Process text
#        encoded_text = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True)
#
        # Process image
 #       image = Image.open(image_path).convert("RGB")
 #       processed_image = self.clip_processor(images=image, return_tensors="pt").pixel_values.squeeze(0)

 #       return encoded_text, processed_image, label

# 8. Training and Validation Function 

def _progress(iterable, desc):
    """Wrap ``iterable`` in a tqdm progress bar when tqdm is importable."""
    try:
        from tqdm import tqdm
    except ImportError:  # tqdm is optional; the notebook's import is commented out
        return iterable
    return tqdm(iterable, desc=desc)


def _prepare_batch(batch, device):
    """Unpack ``(encoded_text, processed_image, labels)`` and move it to ``device``.

    Text tensors of shape ``(batch, 1, seq_len)`` — presumably collated from
    per-sample ``(1, seq_len)`` tokenizer outputs; confirm against the dataset —
    have their singleton dim 1 removed so every batch size gets ``(batch, seq_len)``.
    """
    encoded_text, processed_image, labels = batch
    encoded_text = {
        key: (val.squeeze(1) if val.dim() == 3 else val).to(device)
        for key, val in encoded_text.items()
    }
    return encoded_text, processed_image.to(device), labels.to(device)


def _run_epoch(model, dataloader, criterion, device, optimizer=None, desc=""):
    """Run one pass over ``dataloader``; update weights if ``optimizer`` is given.

    Returns:
        ``(mean_loss, accuracy)``, both weighted by batch size. The loss
        weighting assumes ``criterion`` uses mean reduction. An empty
        dataloader yields ``(0.0, 0.0)``.
    """
    total_loss = 0.0
    correct = 0
    seen = 0

    for batch in _progress(dataloader, desc):
        encoded_text, processed_image, labels = _prepare_batch(batch, device)

        if optimizer is not None:
            optimizer.zero_grad()
        entity_scores = model(encoded_text, processed_image)
        loss = criterion(entity_scores, labels)
        if optimizer is not None:
            loss.backward()
            optimizer.step()

        batch_size = labels.size(0)
        predictions = torch.argmax(entity_scores, dim=1)
        correct += (predictions == labels).sum().item()
        total_loss += loss.item() * batch_size
        seen += batch_size

    if seen == 0:
        return 0.0, 0.0
    return total_loss / seen, correct / seen


def train_model(model, dataloader, optimizer, criterion, device):
    """Train ``model`` for one epoch; return ``(mean_loss, accuracy)``."""
    model.train()
    return _run_epoch(model, dataloader, criterion, device, optimizer=optimizer, desc="Training")


def validate_model(model, dataloader, criterion, device):
    """Evaluate ``model`` without gradients; return ``(mean_loss, accuracy)``."""
    model.eval()
    with torch.no_grad():
        return _run_epoch(model, dataloader, criterion, device, desc="Validation")

# 9. Training Loop

hidden_size = 768
num_attention_heads = 8
num_candidates = 10
num_epochs = 5
batch_size = 16
learning_rate = 2e-5

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): this cell relies on names that are currently commented out
# above or never defined here:
#   - the (data, tokenizer, clip_processor) WikiDiverseDataset variant
#   - MultimodalEntityLinkingModel, clip_processor and get_scheduler
#   - train_data and val_data
# Restore those definitions before running this cell.
train_dataset = WikiDiverseDataset(train_data, tokenizer, clip_processor)
val_dataset = WikiDiverseDataset(val_data, tokenizer, clip_processor)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

model = MultimodalEntityLinkingModel(hidden_size, num_attention_heads, num_candidates).to(device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# The scheduler is stepped once per epoch (end of the loop below), so its
# horizon must be measured in epochs. A horizon of epochs * batches would
# leave the learning rate almost undecayed.
num_training_steps = num_epochs
lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")

    train_loss, train_acc = train_model(model, train_dataloader, optimizer, criterion, device)
    val_loss, val_acc = validate_model(model, val_dataloader, criterion, device)

    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")
    lr_scheduler.step()