# 1. Overview
This implementation is a multimodal entity linking system that leverages both text and image information to perform entity disambiguation. The model uses BERT for text encoding and CLIP for image encoding. A cross-attention mechanism fuses text and image features, and a disambiguation head predicts the correct entity from a list of candidates. The system includes training and validation loops to optimize the model.
# Key Components
- Text Encoder: BERT-based textual feature extraction
- Image Encoder: CLIP-based image feature extraction
- Cross-Attention: Fuse text and image representations
- Entity Disambiguation Head: Predict the correct entity
- Training and Validation: Dataset handling, loss computation, and optimizer setup
# Data Requirements (each sample in the data includes)
- 'text': The textual description or context
- 'image_path': Path to the corresponding image file
- 'label': Index of the correct entity in the candidate list (an integer from 0 to num_candidates - 1)

# 2. Required Libraries 

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertModel, CLIPProcessor, CLIPModel, get_scheduler
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

# 3. Textual Encoder (BERT-based) and Image Encoder (CLIP)

In [2]:
# Text side: load the uncased BERT-base checkpoint once; the tokenizer and the
# encoder must come from the same checkpoint so their vocabularies agree.
BERT_CHECKPOINT = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
bert_model = BertModel.from_pretrained(BERT_CHECKPOINT)

In [3]:
# Image side: processor (image preprocessing) and model from the same CLIP
# ViT-B/16 checkpoint, loaded once and shared by the model and dataset.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch16"
clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)
clip_model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)

# 4. Multimodal Attention Layer

In [4]:
class CrossAttentionLayer(nn.Module):
    """Fuse two feature tensors with bidirectional multi-head attention.

    A single shared ``nn.MultiheadAttention`` (sequence-first layout, the
    PyTorch default) is used twice: once with text as query over image, and
    once with image as query over text. The two outputs are summed and passed
    through a linear layer.

    Args:
        hidden_size: Embedding width shared by both inputs and the output.
        num_attention_heads: Number of attention heads; must divide
            ``hidden_size``.
    """

    def __init__(self, hidden_size, num_attention_heads):
        super().__init__()
        # One module is reused for both attention directions.
        self.attention = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=num_attention_heads)
        self.fc = nn.Linear(hidden_size, hidden_size)

    def forward(self, text_features, image_features):
        """Return ``fc(attn(text->image) + attn(image->text))``.

        Each input gets a new leading axis. With 2-D ``(batch, hidden)``
        inputs this yields ``(1, batch, hidden)``, i.e. a length-1 sequence
        per sample, and the result has that same shape. With a single key per
        query, the attention weights are trivially 1.
        """
        txt = text_features.unsqueeze(0)
        img = image_features.unsqueeze(0)

        text_to_image = self.attention(txt, img, img)[0]
        image_to_text = self.attention(img, txt, txt)[0]
        return self.fc(text_to_image + image_to_text)

# 5. Entity Disambiguation

In [5]:
class EntityDisambiguationHead(nn.Module):
    """Two-layer MLP that scores each of ``num_candidates`` entities.

    ``forward`` returns raw logits by default, which is what
    ``nn.CrossEntropyLoss`` expects: it applies log-softmax internally, so
    feeding it softmax outputs would apply softmax twice and flatten the
    gradients. Pass ``return_probs=True`` to get probabilities instead.

    Args:
        hidden_size: Width of the incoming fused feature vector.
        num_candidates: Number of candidate entities to score.
        intermediate_size: Width of the hidden layer (512 by default, the
            previously hard-coded value).
    """

    def __init__(self, hidden_size, num_candidates, intermediate_size=512):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, num_candidates)
        # Only used when probabilities are explicitly requested.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, features, return_probs=False):
        """Score candidates.

        Args:
            features: Tensor whose last dimension is ``hidden_size``.
            return_probs: If True, apply softmax over the last dimension.

        Returns:
            Tensor of shape ``features.shape[:-1] + (num_candidates,)``:
            logits, or probabilities when ``return_probs`` is True.
        """
        hidden = F.relu(self.fc1(features))
        logits = self.fc2(hidden)
        return self.softmax(logits) if return_probs else logits

# 6. Multimodal Entity Linking Model

In [6]:
class MultimodalEntityLinkingModel(nn.Module):
    """Link a (text, image) mention to one of ``num_candidates`` entities.

    Text is encoded by the module-level ``bert_model`` and summarised by its
    [CLS] hidden state. The image is encoded by the image tower of the
    module-level ``clip_model``. Both vectors are projected to
    ``hidden_size``, fused with ``CrossAttentionLayer`` and scored with
    ``EntityDisambiguationHead``.

    Args:
        hidden_size: Width of the shared fusion space.
        num_attention_heads: Heads in the cross-attention layer.
        num_candidates: Number of candidate entities scored per sample.
    """

    def __init__(self, hidden_size, num_attention_heads, num_candidates):
        super().__init__()
        self.text_encoder = bert_model
        self.image_encoder = clip_model

        # BERT's hidden size and CLIP's projection_dim need not equal
        # hidden_size. BERT-base uses 768, while openai/clip-vit-base-patch16
        # projects images to 512, so map both into the shared fusion space.
        text_dim = self.text_encoder.config.hidden_size
        image_dim = self.image_encoder.config.projection_dim
        self.text_projection = (
            nn.Identity() if text_dim == hidden_size else nn.Linear(text_dim, hidden_size)
        )
        self.image_projection = (
            nn.Identity() if image_dim == hidden_size else nn.Linear(image_dim, hidden_size)
        )

        self.cross_attention_layer = CrossAttentionLayer(hidden_size, num_attention_heads)
        self.disambiguation_head = EntityDisambiguationHead(hidden_size, num_candidates)

    def forward(self, text_input, image_input):
        """Return candidate scores of shape ``(batch, num_candidates)``.

        Args:
            text_input: Either raw text (a ``str`` or a list/tuple of ``str``),
                or an already-tokenized mapping of tensors such as the
                ``encoded_text`` produced by the dataset / DataLoader.
            image_input: Either a CLIP ``pixel_values`` tensor
                ``(batch, 3, H, W)``, or raw image(s) accepted by
                ``clip_processor``.
        """
        device = next(self.parameters()).device

        # --- Text ---------------------------------------------------------
        if isinstance(text_input, (str, list, tuple)):
            raw = text_input if isinstance(text_input, str) else list(text_input)
            text_input = tokenizer(raw, return_tensors='pt', padding=True, truncation=True)
        # Collapse any extra leading axes (e.g. a collated (batch, 1, seq))
        # to the (batch, seq) shape BERT expects, and move to the model device.
        encoded_input = {
            key: val.reshape(-1, val.size(-1)).to(device)
            for key, val in text_input.items()
        }
        text_hidden = self.text_encoder(**encoded_input).last_hidden_state  # (batch, seq, bert_hidden)
        text_features = self.text_projection(text_hidden[:, 0])  # [CLS] token -> (batch, hidden)

        # --- Image --------------------------------------------------------
        if isinstance(image_input, torch.Tensor):
            pixel_values = image_input
        else:
            pixel_values = clip_processor(images=image_input, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device)
        # Only the image tower is needed. CLIPModel.forward also runs the text
        # tower and returns L2-normalised image_embeds, so normalise here to
        # keep the same feature scale.
        image_embeds = self.image_encoder.get_image_features(pixel_values=pixel_values)
        image_features = self.image_projection(F.normalize(image_embeds, dim=-1))  # (batch, hidden)

        # --- Fusion and scoring -------------------------------------------
        combined_features = self.cross_attention_layer(text_features, image_features)  # (1, batch, hidden)
        # squeeze(0) drops the length-1 sequence axis (not the batch axis), so
        # a batch of one still yields (1, hidden).
        return self.disambiguation_head(combined_features.squeeze(0))

# 7. Dataset class

In [None]:
class WikiDiverseDataset(Dataset):
    """Map-style dataset yielding ``(encoded_text, pixel_values, label)``.

    Each element of ``data`` must be a mapping with the keys ``'text'``,
    ``'image_path'`` and ``'label'``.

    Args:
        data: Indexable sequence of sample mappings.
        tokenizer: Hugging Face tokenizer used for the text.
        clip_processor: CLIP processor used to turn images into pixel values.
        max_length: Fixed token length every sample is padded/truncated to.
    """

    def __init__(self, data, tokenizer, clip_processor, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.clip_processor = clip_processor
        # All samples share one length so the default DataLoader collate can
        # stack them. padding=True on a single string pads nothing, so lengths
        # would differ between samples and stacking would fail.
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``({name: (max_length,) tensor}, (3, H, W) tensor, label)``."""
        sample = self.data[idx]
        text = sample['text']
        image_path = sample['image_path']
        label = sample['label']

        # Process text; drop the leading batch axis of 1 added by
        # return_tensors='pt' so collated batches are (batch, max_length).
        encoded = self.tokenizer(
            text,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
        )
        encoded_text = {key: val.squeeze(0) for key, val in encoded.items()}

        # Process image. The context manager closes the file handle;
        # convert() returns a new, fully loaded image, so it stays valid
        # after the file is closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        processed_image = self.clip_processor(images=image, return_tensors="pt").pixel_values.squeeze(0)

        return encoded_text, processed_image, label

# 8. Training and Validation Functions

In [None]:
def train_model(model, dataloader, optimizer, criterion, device, scheduler=None):
    """Run one training epoch.

    Args:
        model: Callable as ``model(encoded_text, pixel_values)`` returning
            ``(batch, num_candidates)`` scores.
        dataloader: Iterable of ``(encoded_text, pixel_values, labels)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss with mean reduction (e.g. ``nn.CrossEntropyLoss()``).
        device: Device to move each batch to.
        scheduler: Optional LR scheduler, stepped after every optimizer step.

    Returns:
        ``(mean_loss, accuracy)`` over all samples in the epoch, weighted by
        batch size so a short final batch is not over-counted.
        ``(nan, nan)`` if the dataloader is empty.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    seen = 0

    for batch in tqdm(dataloader, desc="Training"):
        encoded_text, processed_image, labels = batch
        # Normalise every text tensor to (batch, seq_len). This covers both
        # (batch, seq_len) and collated (batch, 1, seq_len) inputs.
        # squeeze(0) only worked when the batch size was 1.
        encoded_text = {
            key: val.reshape(-1, val.size(-1)).to(device)
            for key, val in encoded_text.items()
        }
        processed_image = processed_image.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        entity_scores = model(encoded_text, processed_image)
        loss = criterion(entity_scores, labels)
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        batch_size = labels.size(0)
        # criterion is assumed to average over the batch; re-weight by size.
        loss_sum += loss.item() * batch_size
        correct += (torch.argmax(entity_scores, dim=1) == labels).sum().item()
        seen += batch_size

    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum / seen, correct / seen

def validate_model(model, dataloader, criterion, device):
    """Evaluate the model for one pass over ``dataloader`` without gradients.

    Args:
        model: Callable as ``model(encoded_text, pixel_values)`` returning
            ``(batch, num_candidates)`` scores.
        dataloader: Iterable of ``(encoded_text, pixel_values, labels)`` batches.
        criterion: Loss with mean reduction (e.g. ``nn.CrossEntropyLoss()``).
        device: Device to move each batch to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples, weighted by batch size.
        ``(nan, nan)`` if the dataloader is empty.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            encoded_text, processed_image, labels = batch
            # Normalise every text tensor to (batch, seq_len). squeeze(0) only
            # worked when the batch size was 1.
            encoded_text = {
                key: val.reshape(-1, val.size(-1)).to(device)
                for key, val in encoded_text.items()
            }
            processed_image = processed_image.to(device)
            labels = labels.to(device)

            entity_scores = model(encoded_text, processed_image)
            loss = criterion(entity_scores, labels)

            batch_size = labels.size(0)
            # criterion is assumed to average over the batch; re-weight by size.
            loss_sum += loss.item() * batch_size
            correct += (torch.argmax(entity_scores, dim=1) == labels).sum().item()
            seen += batch_size

    if seen == 0:
        return float("nan"), float("nan")
    return loss_sum / seen, correct / seen


# 9. Training Loop

In [None]:
# Hyperparameters
hidden_size = 768
num_attention_heads = 8
num_candidates = 10
num_epochs = 5
batch_size = 16
learning_rate = 2e-5

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): train_data / val_data are not created in this notebook section.
# They must be defined beforehand as sequences of dicts with the keys
# 'text', 'image_path' and 'label'.
train_dataset = WikiDiverseDataset(train_data, tokenizer, clip_processor)
val_dataset = WikiDiverseDataset(val_data, tokenizer, clip_processor)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

model = MultimodalEntityLinkingModel(hidden_size, num_attention_heads, num_candidates).to(device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# The scheduler is stepped once per epoch (end of the loop below), so its
# horizon is measured in epochs. Sizing it as num_epochs * len(train_dataloader)
# would decay the LR by only ~1/len(train_dataloader) of its range per epoch.
num_training_steps = num_epochs
lr_scheduler = get_scheduler("linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")

    train_loss, train_acc = train_model(model, train_dataloader, optimizer, criterion, device)
    val_loss, val_acc = validate_model(model, val_dataloader, criterion, device)

    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")
    lr_scheduler.step()