In [5]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from transformers import BertTokenizer, BertModel
from torchvision import models, transforms
from PIL import Image
from tqdm import tqdm
import requests
import json

# Paths to dataset and image folder.
# The dataset location is machine-specific, so it can be overridden with the
# WIKIDIVERSE_PATH environment variable. The previous hard-coded path is kept
# as the fallback so existing setups behave exactly as before.
DATASET_PATH = os.environ.get(
    "WIKIDIVERSE_PATH",
    r"C:\Users\Min Dator\aics-project\wikidiverse.json",
)
IMAGES_FOLDER = "downloaded_images/"
os.makedirs(IMAGES_FOLDER, exist_ok=True)  # download cache, created on first run

# Initialize BERT tokenizer and model for textual encoding
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

# Initialize ResNet model for image encoding (using ResNet-50 here)
resnet_model = models.resnet50(weights='ResNet50_Weights.DEFAULT')
# Drop the last child (the classification layer) so the network outputs
# pooled features instead of class scores.
resnet_model = nn.Sequential(*list(resnet_model.children())[:-1])

# Image preprocessing pipeline: resize, center-crop to 224x224, convert to a
# tensor and normalize each channel with the mean/std below (the standard
# ImageNet statistics used with torchvision's pretrained weights).
image_preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Cross-Attention Layer
class CrossAttentionLayer(nn.Module):
    """Bidirectional cross-attention between text and image features.

    The same ``nn.MultiheadAttention`` module is used in both directions
    (text attending to image, image attending to text). Each attended
    sequence is mean-pooled over its sequence axis, the two pooled results
    are summed and passed through a linear layer.

    ``nn.MultiheadAttention`` is built with its default ``batch_first=False``,
    so it consumes tensors laid out as ``(seq_len, batch, hidden_size)``.

    Args:
        hidden_size: Embedding width shared by both modalities.
        num_attention_heads: Number of attention heads; must divide
            ``hidden_size``.
    """

    def __init__(self, hidden_size, num_attention_heads):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=num_attention_heads)
        self.fc = nn.Linear(hidden_size, hidden_size)

    @staticmethod
    def _to_sequence_first(features):
        """Return ``features`` laid out the way ``nn.MultiheadAttention`` expects."""
        if features.dim() <= 2:
            # Pooled input, e.g. (batch, hidden): treat it as a length-1
            # sequence. Note the new leading axis is the *sequence* axis.
            return features.unsqueeze(0)
        if features.dim() == 3:
            # (batch, seq_len, hidden) -> (seq_len, batch, hidden)
            return features.transpose(0, 1)
        raise ValueError(
            f"Expected features with at most 3 dimensions, got {features.dim()}"
        )

    def forward(self, text_features, image_features):
        """Fuse text and image features.

        Args:
            text_features: ``(batch, hidden)`` pooled features or
                ``(batch, seq_len, hidden)`` token features.
            image_features: ``(batch, hidden)`` pooled features or
                ``(batch, seq_len, hidden)`` region/patch features.

        Returns:
            Tensor of shape ``(1, batch, hidden)`` for batched inputs.

        Raises:
            ValueError: If an input has more than three dimensions.
        """
        text_seq = self._to_sequence_first(text_features)
        image_seq = self._to_sequence_first(image_features)

        attn_output_text, _ = self.attention(text_seq, image_seq, image_seq)
        attn_output_image, _ = self.attention(image_seq, text_seq, text_seq)

        # The two outputs have different sequence lengths when the inputs do,
        # so pool each over its sequence axis before adding them. For
        # length-1 sequences the mean is the identity.
        pooled_text = attn_output_text.mean(dim=0, keepdim=True)
        pooled_image = attn_output_image.mean(dim=0, keepdim=True)

        return self.fc(pooled_text + pooled_image)

# Entity Disambiguation Head
class EntityDisambiguationHead(nn.Module):
    """Two-layer MLP that turns fused features into candidate probabilities.

    The features go through a 512-unit hidden layer with ReLU, then a linear
    layer with one output per candidate, then a softmax over the last axis.
    The output is therefore a probability distribution, not raw logits.

    Args:
        hidden_size: Width of the incoming feature vectors.
        num_candidates: Number of candidate entities to score.
    """

    def __init__(self, hidden_size, num_candidates):
        super().__init__()
        self.fc1 = nn.Linear(in_features=hidden_size, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=num_candidates)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, features):
        """Return per-candidate probabilities for ``features``."""
        return self.softmax(self.fc2(F.relu(self.fc1(features))))

# Model combining Text and Image features
class MultimodalEntityLinkingModel(nn.Module):
    """Score candidate entities from a caption and its image.

    Pipeline: the module-level BERT encoder yields one vector per caption
    (the first-token embedding), the module-level ResNet trunk yields one
    vector per image, which is linearly projected to ``hidden_size``. The two
    are fused by ``CrossAttentionLayer`` and scored by
    ``EntityDisambiguationHead``.

    Args:
        hidden_size: Width used by the fusion layer and the head. It must
            equal the text encoder's hidden width, because the text features
            are not projected.
        num_attention_heads: Number of heads in the cross-attention layer.
        num_candidates: Number of candidate entities scored per example.
        image_feature_dim: Width of the flattened image-encoder output.
            The default of 2048 is the pooled feature width of torchvision's
            ResNet-50; change it if a different backbone is used.
    """

    def __init__(self, hidden_size, num_attention_heads, num_candidates, image_feature_dim=2048):
        super().__init__()
        self.text_encoder = bert_model
        self.image_encoder = resnet_model
        # Image features and text features must share one width before they
        # can attend to each other.
        self.image_projection = nn.Linear(image_feature_dim, hidden_size)
        self.cross_attention_layer = CrossAttentionLayer(hidden_size, num_attention_heads)
        self.disambiguation_head = EntityDisambiguationHead(hidden_size, num_candidates)

    def forward(self, text_input, image_input):
        """Return candidate scores for the given caption(s) and image(s).

        Args:
            text_input: A raw string, a list/tuple of strings, or an
                already-tokenized mapping of tensors such as the tokenizer's
                output. Raw text is tokenized here.
            image_input: A preprocessed image tensor, either ``(C, H, W)``
                for a single image or ``(batch, C, H, W)``.

        Returns:
            Tensor of shape ``(batch, num_candidates)`` as produced by the
            disambiguation head.
        """
        if isinstance(text_input, (str, list, tuple)):
            encoded_input = tokenizer(text_input, return_tensors='pt', padding=True, truncation=True)
        else:
            # Already tokenized; tokenizing it a second time would fail.
            encoded_input = text_input
        text_output = self.text_encoder(**encoded_input).last_hidden_state
        # Keep one vector per caption: the embedding of the first token.
        text_features = text_output[:, 0]

        if image_input.dim() == 3:
            image_input = image_input.unsqueeze(0)  # Add batch dimension
        image_features = self.image_encoder(image_input)
        image_features = image_features.view(image_features.size(0), -1)
        image_features = self.image_projection(image_features)

        # Both inputs are (batch, hidden_size); the fusion layer returns
        # (1, batch, hidden_size), so drop the leading axis before scoring.
        combined_features = self.cross_attention_layer(text_features, image_features)
        entity_scores = self.disambiguation_head(combined_features.squeeze(0))
        return entity_scores

# Helper Functions for Preprocessing
def download_image(image_url, timeout=30):
    """Download ``image_url`` into ``IMAGES_FOLDER`` unless already cached.

    The local file name is the last path segment of the URL. An existing
    file with that name is reused without contacting the server.

    Args:
        image_url: URL of the image to fetch.
        timeout: Seconds to wait for the server before giving up, so one
            unresponsive host cannot stall the whole preprocessing run.

    Returns:
        Path of the local copy inside ``IMAGES_FOLDER``.

    Raises:
        ValueError: If no file name can be derived from the URL.
        requests.RequestException: On connection problems, timeouts or an
            HTTP error status (4xx/5xx).
    """
    image_name = image_url.split("/")[-1]
    if not image_name:
        raise ValueError(f"Cannot derive a file name from URL: {image_url!r}")
    local_path = os.path.join(IMAGES_FOLDER, image_name)

    if os.path.exists(local_path):
        return local_path

    # Write to a temporary name and rename at the end, so an interrupted
    # download is never mistaken for a cached image on the next run.
    partial_path = local_path + ".part"
    with requests.get(image_url, stream=True, timeout=timeout) as response:
        # Fail loudly instead of returning a path that was never written.
        response.raise_for_status()
        if response.status_code == 200:
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            os.replace(partial_path, local_path)
    return local_path

def preprocess_dataset(dataset_path):
    """Load the dataset JSON and build model-ready samples.

    For each item the caption is tokenized with the module-level
    ``tokenizer`` and the image is downloaded (or read from the cache),
    converted to RGB and run through ``image_preprocess``. Items that lack a
    caption or image URL, or whose image cannot be fetched or decoded, are
    reported and skipped.

    NOTE(review): assumes the JSON is a list of objects with ``caption``,
    ``image_url`` and ``entities`` keys - confirm against the dataset release.

    Args:
        dataset_path: Path to the dataset JSON file.

    Returns:
        A list of dicts with keys ``"text"`` (tokenizer output), ``"image"``
        (preprocessed image tensor) and ``"entities"`` (taken from the item
        unchanged, ``None`` when absent).

    Raises:
        OSError: If ``dataset_path`` cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    preprocessed_data = []

    # Read as UTF-8 explicitly; the platform default encoding (for example
    # cp1252 on Windows) may fail on non-ASCII captions.
    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    for item in tqdm(data, desc="Preprocessing Dataset"):
        text = item.get("caption")
        image_url = item.get("image_url")
        entities = item.get("entities")

        # One incomplete record must not abort the whole run.
        if text is None or not image_url:
            print(f"Skipping item without caption or image_url (image_url={image_url!r})")
            continue

        tokenized_text = tokenizer(text, return_tensors="pt", padding=True, truncation=True)

        # Best-effort per item: any download or decode failure is reported
        # and the item is dropped.
        try:
            local_image_path = download_image(image_url)
            # The context manager closes the underlying file handle.
            with Image.open(local_image_path) as raw_image:
                image = raw_image.convert("RGB")
            image_tensor = image_preprocess(image)
        except Exception as e:
            print(f"Error processing image {image_url}: {e}")
            continue

        preprocessed_data.append({
            "text": tokenized_text,
            "image": image_tensor,
            "entities": entities,
        })

    return preprocessed_data

# Example Usage
# Model hyperparameters for the demo run.
hidden_size, num_attention_heads, num_candidates = 768, 8, 10

model = MultimodalEntityLinkingModel(
    hidden_size=hidden_size,
    num_attention_heads=num_attention_heads,
    num_candidates=num_candidates,
)

# Preprocess the dataset
preprocessed_data = preprocess_dataset(dataset_path=DATASET_PATH)

# Example Training: take the first preprocessed sample as the demo input.
sample = preprocessed_data[0]
text_input, image_input = sample["text"], sample["image"]
# Placeholder label: the first candidate is treated as the correct entity.
correct_entity_index = 0

# Calculate loss and accuracy
def calculate_loss_and_accuracy(model, text_input, image_input, correct_entity_index):
    """Compute the classification loss and accuracy for one example.

    ``model`` is expected to return *probabilities* of shape
    ``(1, num_candidates)``, i.e. softmax already applied. The loss is
    therefore the negative log-likelihood of the log-probabilities.
    ``nn.CrossEntropyLoss`` applies its own softmax, so feeding it
    probabilities would normalize twice and flatten the gradients.

    Args:
        model: Callable taking ``(text_input, image_input)`` and returning
            candidate probabilities.
        text_input: Text input forwarded to ``model``.
        image_input: Image input forwarded to ``model``.
        correct_entity_index: Index of the correct candidate.

    Returns:
        ``(loss, accuracy)`` as scalar tensors.
    """
    entity_scores = model(text_input, image_input)
    labels = torch.tensor([correct_entity_index], device=entity_scores.device)

    # Clamp before the log so a probability of exactly 0 cannot produce -inf.
    log_probs = torch.log(entity_scores.clamp_min(1e-12))
    loss = F.nll_loss(log_probs, labels)

    predicted_entity = torch.argmax(entity_scores, dim=1)
    accuracy = (predicted_entity == labels).float().mean()
    return loss, accuracy

# Evaluate the example sample once and report both metrics.
loss, accuracy = calculate_loss_and_accuracy(
    model=model,
    text_input=text_input,
    image_input=image_input,
    correct_entity_index=correct_entity_index,
)
print(f"Loss: {loss.item()}, Accuracy: {accuracy.item()}")

FileNotFoundError: [Errno 2] No such file or directory: 'C:\\Users\\Min Dator\\aics-project\\wikidiverse.json'