# Overview 
An implementation of the WikiDiverse entity linking system using attention mechanisms, thereby leveraging a multimodal attention-based approach to fuse both text and image information for entity disambiguation.

# 1. Required Libraries 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertModel, CLIPProcessor, CLIPModel
import torch.nn.functional as F
import PIL
from PIL import Image
from tqdm import tqdm

# 2. Textual Encoder (BERT-based) 
* BERT will process the textual input, including the context and entity descriptions.*

In [None]:
# Initialize the tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

# Tokenize input text
input_text = "The Lions versus the Packers (2007)."
encoded_input = tokenizer(input_text, return_tensors='pt', padding=True, truncation=True)

# Get the output of BERT (last hidden state)
text_output = bert_model(**encoded_input).last_hidden_state  # shape: (batch_size, seq_len, hidden_size)

# 3. Image Encoder (CLIP or ResNet)
*We can extract image features using a pretrained model like CLIP or ResNet, which can take both text and images as input and project them into a shared space.*

In [None]:
# Initialize CLIP processor and model
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")

# Preprocess and extract image features
image = PIL.Image.open("path_to_image.jpg")
inputs = clip_processor(text=["The Lions versus the Packers (2007)."], images=image, return_tensors="pt", padding=True)

# Extract image and text features
outputs = clip_model(**inputs)
image_features = outputs.image_embeds  # shape: (batch_size, embedding_dim)
text_features = outputs.text_embeds  # shape: (batch_size, embedding_dim)

# 4. Multimodal Attention Layer
* The core of multimodal fusion will be the cross-attention mechanism. Cross-attention will help the model attend to relevant parts of the text and image, guiding the prediction towards the correct entity. * 

In [None]:
class MultimodalAttention(nn.Module):
    def __init__(self, hidden_size, num_attention_heads):
        super(MultimodalAttention, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=num_attention_heads)
        self.fc = nn.Linear(hidden_size, hidden_size)

    def forward(self, text_features, image_features):
        # Cross attention between text and image features
        # Attention of text features to image features (or vice versa)
        text_features = text_features.unsqueeze(0)  # Add batch dimension
        image_features = image_features.unsqueeze(0)  # Add batch dimension

        # Cross-attention: text to image (or reverse)
        attn_output, attn_weights = self.attention(text_features, image_features, image_features)

        # Pass through a feedforward layer
        attn_output = self.fc(attn_output)
        return attn_output, attn_weights

# Initialize attention layer
multimodal_attention = MultimodalAttention(hidden_size=768, num_attention_heads=8)

# Apply attention to the text and image features
attn_output, attn_weights = multimodal_attention(text_features, image_features)

# 5. Final Entity Disambiguation
*After obtaining the fused multimodal features from attention, we use a simple MLP (multilayer perceptron) or a classification head to output the predicted entity from the candidate list.*

In [None]:
class EntityDisambiguationHead(nn.Module):
    def __init__(self, hidden_size, num_candidates):
        super(EntityDisambiguationHead, self).__init__()
        self.fc1 = nn.Linear(hidden_size, 512)
        self.fc2 = nn.Linear(512, num_candidates)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, features):
        x = F.relu(self.fc1(features))
        x = self.fc2(x)
        return self.softmax(x)

# Initialize the entity disambiguation head
num_candidates = 10  # Example: 10 candidates
disambiguation_head = EntityDisambiguationHead(hidden_size=768, num_candidates=num_candidates)

# Predict the entity using the multimodal features
entity_scores = disambiguation_head(attn_output.squeeze(0))  # Squeeze batch dimension


# 6. Loss and Optimization
*For training, use cross-entropy loss to optimize the model for entity disambiguation.* 

In [None]:
# Cross-entropy loss for entity disambiguation
labels = torch.tensor([correct_entity_index])  # The index of the correct entity in the candidate list
criterion = nn.CrossEntropyLoss()
loss = criterion(entity_scores, labels)

# Backpropagate and update the model
loss.backward()
optimizer.step()

# 2

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertModel, CLIPProcessor, CLIPModel
import torch.nn.functional as F
import PIL
from PIL import Image
from tqdm import tqdm

# Initialize BERT tokenizer and model for textual encoding
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

# Initialize CLIP processor and model for image encoding
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")

# Cross-Attention Layer
class CrossAttentionLayer(nn.Module):
    def __init__(self, hidden_size, num_attention_heads):
        super(CrossAttentionLayer, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=num_attention_heads)
        self.fc = nn.Linear(hidden_size, hidden_size)

    def forward(self, text_features, image_features):
        # Cross-attention: text to image and image to text
        text_features = text_features.unsqueeze(0)  # Add batch dimension
        image_features = image_features.unsqueeze(0)  # Add batch dimension

        # Perform attention between text and image features
        attn_output_text, _ = self.attention(text_features, image_features, image_features)
        attn_output_image, _ = self.attention(image_features, text_features, text_features)
        
        # Combine the outputs (you can experiment with different strategies like sum, concat, etc.)
        combined_output = attn_output_text + attn_output_image
        combined_output = self.fc(combined_output)  # Feedforward layer after attention
        return combined_output

# Entity Disambiguation Head
class EntityDisambiguationHead(nn.Module):
    def __init__(self, hidden_size, num_candidates):
        super(EntityDisambiguationHead, self).__init__()
        self.fc1 = nn.Linear(hidden_size, 512)
        self.fc2 = nn.Linear(512, num_candidates)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, features):
        x = F.relu(self.fc1(features))
        x = self.fc2(x)
        return self.softmax(x)

# Model combining Text and Image features
class MultimodalEntityLinkingModel(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, num_candidates):
        super(MultimodalEntityLinkingModel, self).__init__()
        self.text_encoder = bert_model
        self.image_encoder = clip_model
        self.cross_attention_layer = CrossAttentionLayer(hidden_size, num_attention_heads)
        self.disambiguation_head = EntityDisambiguationHead(hidden_size, num_candidates)

    def forward(self, text_input, image_input):
        # Textual feature extraction using BERT
        encoded_input = tokenizer(text_input, return_tensors='pt', padding=True, truncation=True)
        text_output = self.text_encoder(**encoded_input).last_hidden_state  # shape: (batch_size, seq_len, hidden_size)
        
        # Image feature extraction using CLIP
        inputs = clip_processor(text=[text_input], images=image_input, return_tensors="pt", padding=True)
        outputs = self.image_encoder(**inputs)
        image_features = outputs.image_embeds  # shape: (batch_size, embedding_dim)

        # Apply cross-attention to fuse text and image features
        combined_features = self.cross_attention_layer(text_output, image_features)

        # Disambiguate and predict the entity
        entity_scores = self.disambiguation_head(combined_features.squeeze(0))  # Removing batch dimension
        return entity_scores

# Helper function for calculating loss and accuracy
def calculate_loss_and_accuracy(model, text_input, image_input, correct_entity_index):
    # Forward pass through the model
    entity_scores = model(text_input, image_input)

    # Compute the cross-entropy loss for entity disambiguation
    labels = torch.tensor([correct_entity_index])  # The index of the correct entity in the candidate list
    criterion = nn.CrossEntropyLoss()
    loss = criterion(entity_scores, labels)
    
    # Get the predicted entity
    predicted_entity = torch.argmax(entity_scores, dim=1)
    accuracy = (predicted_entity == labels).float().mean()
    
    return loss, accuracy

# Example Usage
# Initialize model parameters
hidden_size = 768  # BERT hidden size
num_attention_heads = 8
num_candidates = 10  # Number of candidates for entity linking

# Initialize the model
model = MultimodalEntityLinkingModel(hidden_size, num_attention_heads, num_candidates)

# Example input (text and image)
text_input = "The Lions versus the Packers (2007)."
image_path = "path_to_image.jpg"
image_input = Image.open(image_path)

# Assume we have the index of the correct entity (for example purposes)
correct_entity_index = 0

# Calculate loss and accuracy
loss, accuracy = calculate_loss_and_accuracy(model, text_input, image_input, correct_entity_index)
print(f"Loss: {loss.item()}, Accuracy: {accuracy.item()}")

In [1]:
from transformers import BertTokenizer

# Initialize BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Example text (from the WikiDiverse dataset)
text_input = "The Lions versus the Packers (2007)."

# Tokenize the text input
encoded_input = tokenizer(text_input, return_tensors='pt', padding=True, truncation=True)

# Output the tokenized input
print(encoded_input)


{'input_ids': tensor([[  101,  1996,  7212,  6431,  1996, 15285,  1006,  2289,  1007,  1012,
           102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


In [3]:
import hashlib
import re
import os
import requests
from PIL import Image
from io import BytesIO

# Path to store images
DATASET_PATH = r"C:\Users\Min Dator\aics-project\wikidiverse_data\images"

# Ensure the directory exists
if not os.path.exists(DATASET_PATH):
    os.makedirs(DATASET_PATH)

# Function to download and process the image
def download_image(url):
    try:
        # Get the image content
        response = requests.get(url)
        
        if response.status_code == 200:
            m_img = url.split('/')[-1]
            
            # Create a unique file name using MD5 hash
            prefix = hashlib.md5(m_img.encode()).hexdigest()
            suffix = re.sub(r'(\S+(?=\.(jpg|JPG|png|PNG|svg|SVG)))|(\S+(?=\.(jpeg|JPEG)))', '', m_img)
            
            # Construct the file path for the image
            file_path = os.path.join(DATASET_PATH, prefix + suffix)
            file_path = file_path.replace('.svg', '.png').replace('.SVG', '.png')  # Replace .svg with .png

            # Open the image and save it
            image = Image.open(BytesIO(response.content))
            image.save(file_path)

            print(f"Image saved at {file_path}")
        else:
            print(f"Failed to retrieve image. Status code: {response.status_code}")

    except Exception as e:
        print(f"Error downloading image: {e}")

# Example usage with data (replace 'data' with the actual dataset)
data = [
    ["The Lions versus the Packers (2007).", "https://upload.wikimedia.org/wikipedia/commons/0/06/DetroitLionsRunningPlay-2007.jpg", "sports", [
        ["Lions", "Organization", 4, 9, "https://en.wikipedia.org/wiki/Detroit_Lions"],
        ["Packers", "Organization", 21, 28, "https://en.wikipedia.org/wiki/Green_Bay_Packers"]
    ]]
]

# Iterate over the data and download images
for item in data:
    image_url = item[1]  # Get the image URL (second element in the data)
    download_image(image_url) 

Failed to retrieve image. Status code: 403
