# 1. An overview
Siamese networks consist of two identical sub-networks that share weights and learn to compute the similarity between two input samples. The goal is to learn embeddings such that similar inputs are close in the embedding space, while dissimilar inputs are far apart. For the WikiDiverse dataset, where we have image-caption pairs, we can build a Siamese network that processes text and image data (or just one modality like text or image) and learns to compute similarity between two entities from the knowledge base.
* Siamese Network Structure: Two identical sub-networks with shared weights compute an embedding for each input of a pair; the distance between the two embeddings measures their similarity.
* Application: For WikiDiverse, compute similarity between image-caption pairs to link knowledge-base entities.

# 2. Required Libraries

In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score  
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
# Dataset path. Override with the WIKIDIVERSE_PATH environment variable when running
# on a machine other than the original author's.
DATASET_PATH = os.environ.get(
    "WIKIDIVERSE_PATH",
    r"C:\Users\Min Dator\aics-project\wikidiverse_w_cands\wikidiverse_w_cands",
)
IMAGE_DIR = os.path.join(DATASET_PATH, "images")

# Hyperparameters
BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 0.0001
EMBED_DIM = 256
# Rows in the text embedding table. Captions are tokenized with the
# "bert-base-uncased" WordPiece tokenizer, whose ids span 30522 entries; a smaller
# table (previously 10000) makes nn.Embedding raise IndexError on rarer tokens.
VOCAB_SIZE = 30522

# 3. Data Class
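Define a PyTorch `Dataset` that loads each image, tokenizes its caption, and returns the sample's label; the dataloaders for training and evaluation are built in the next section.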

In [3]:
class WikiDiverseDataset(Dataset):
    """Map-style dataset yielding ``(image, token_ids, label)`` triples.

    Args:
        image_paths: paths to the image files, one per sample.
        captions: caption strings aligned with ``image_paths``.
        labels: numeric labels aligned with ``image_paths``; returned as float32 tensors.
        tokenizer: callable tokenizer invoked with ``truncation``, ``padding``,
            ``max_length`` and ``return_tensors="pt"`` keyword arguments
            (e.g. a Hugging Face tokenizer).
        transform: optional callable applied to the RGB PIL image.
        max_length: captions are truncated/padded to exactly this many tokens.

    Raises:
        ValueError: if ``image_paths``, ``captions`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, captions, labels, tokenizer, transform=None, max_length=100):
        if not (len(image_paths) == len(captions) == len(labels)):
            raise ValueError(
                "image_paths, captions and labels must have the same length, got "
                f"{len(image_paths)}, {len(captions)} and {len(labels)}"
            )
        self.image_paths = image_paths
        self.captions = captions
        self.labels = labels
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # The context manager closes the file handle right away; convert() returns a
        # new, fully loaded image that does not depend on the open file.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)

        text = self.tokenizer(
            self.captions[idx],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dimension of 1; drop it so the
        # DataLoader can stack samples itself.
        text_input = text["input_ids"].squeeze(0)

        return image, text_input, torch.tensor(self.labels[idx], dtype=torch.float32)

# 4. Datasets and Dataloaders

In [None]:
# Read the annotation table: one row per sample with an image file name, a caption and a label
data_file = os.path.join(DATASET_PATH, "captions_and_labels.csv")
df = pd.read_csv(data_file)

# Resolve every image file name against the image directory
image_paths = df["image_filename"].map(lambda filename: os.path.join(IMAGE_DIR, filename)).tolist()

# Captions and labels as plain Python lists, aligned row-by-row with image_paths
captions, labels = df["caption"].tolist(), df["label"].tolist()

In [None]:
# WordPiece tokenizer that turns captions into integer token ids
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Hold out 20% of the samples for evaluation. All three lists are split with the same
# permutation, so they stay aligned; the fixed seed makes the split reproducible.
(
    train_paths, test_paths,
    train_captions, test_captions,
    train_labels, test_labels,
) = train_test_split(image_paths, captions, labels, test_size=0.2, random_state=42)

# Image preprocessing: resize to 100x100, convert to a tensor, then normalize with the
# standard ImageNet per-channel mean/std commonly used with ImageNet-pretrained backbones.
transform = transforms.Compose(
    [
        transforms.Resize((100, 100)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [None]:
# Build the train and test datasets from the aligned split lists (train first, then test)
train_dataset, test_dataset = (
    WikiDiverseDataset(paths, caps, labs, tokenizer, transform)
    for paths, caps, labs in (
        (train_paths, train_captions, train_labels),
        (test_paths, test_captions, test_labels),
    )
)

# Only the training loader shuffles; evaluation iterates in a fixed order
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# 5. Cross Attention Module
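The cross-attention mechanism lets one modality (e.g., text) attend to another (e.g., image): queries come from the first modality, keys and values come from the second, and the attended result is added back to the query (residual connection) and layer-normalized.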

In [None]:
class CrossAttention(nn.Module):
    """Residual multi-head attention in which ``query`` attends over ``key``/``value``.

    Computes ``LayerNorm(query + Dropout(MultiheadAttention(query, key, value)))``.

    Args:
        embed_dim: feature size of query/key/value; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability used inside the attention and on its output.
        batch_first: if True (default), tensors are ``(batch, seq, embed_dim)``;
            if False, ``(seq, batch, embed_dim)`` as in PyTorch's default.
    """

    def __init__(self, embed_dim, num_heads=4, dropout=0.1, batch_first=True):
        super(CrossAttention, self).__init__()
        # The sub-networks pass tensors built with x.unsqueeze(1) on (batch, dim)
        # features, i.e. (batch, seq, embed). With PyTorch's default batch_first=False
        # the batch axis would be read as the sequence axis and samples would attend
        # to each other.
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=batch_first
        )
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        """Attend from ``query`` to ``key``/``value``; the output has the shape of ``query``."""
        attn_output, _ = self.multihead_attn(query, key, value)  # attention weights unused
        attn_output = self.dropout(attn_output)
        return self.layer_norm(query + attn_output)

# 6. Sub-Networks with Attention

In [None]:
class ImageSubNetworkWithAttention(nn.Module):
    """Image encoder: ResNet-50 backbone -> linear projection -> optional cross-attention over text.

    Args:
        embed_dim: size of the projected image embedding, which is also the attention
            size; text features passed to ``forward`` must have this many channels.
        num_heads: number of attention heads; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim=256, num_heads=4):
        super(ImageSubNetworkWithAttention, self).__init__()
        # NOTE(review): `pretrained=` is deprecated in torchvision>=0.13 in favour of
        # `weights=`; kept here for compatibility with older installs.
        base_model = models.resnet50(pretrained=True)
        # Drop the final fully connected (classification) layer; the remaining stack
        # outputs 2048 channels per image.
        self.features = nn.Sequential(*list(base_model.children())[:-1])
        self.fc = nn.Sequential(
            nn.Linear(2048, embed_dim),
            nn.ReLU()
        )
        self.cross_attention = CrossAttention(embed_dim, num_heads)

    def forward(self, x, text_features=None):
        """Encode a batch of images, optionally conditioned on text features.

        Args:
            x: image batch accepted by the ResNet backbone.
            text_features: optional text context of shape ``(batch, embed_dim)`` or
                ``(batch, n, embed_dim)``. When ``None`` the cross-attention step is
                skipped and the projected image features are returned.

        Returns:
            Tensor of shape ``(batch, embed_dim)``.
        """
        x = self.features(x)
        x = torch.flatten(x, 1)  # (batch, 2048)
        x = self.fc(x)
        if text_features is None:
            return x
        # MultiheadAttention needs 3-D keys/values when the query is 3-D; treat a
        # single text vector per sample as a length-1 sequence.
        if text_features.dim() == 2:
            text_features = text_features.unsqueeze(1)
        return self.cross_attention(x.unsqueeze(1), text_features, text_features).squeeze(1)

class TextSubNetworkWithAttention(nn.Module):
    """Caption encoder: embedding -> LSTM -> optional cross-attention over image -> projection.

    Args:
        vocab_size: rows in the embedding table; every token id fed to ``forward``
            must be smaller than this.
        embed_dim: size of the token embeddings fed to the LSTM.
        hidden_dim: LSTM hidden size, which is also the attention size; image features
            passed to ``forward`` must have this many channels.
        num_heads: number of attention heads; must divide ``hidden_dim``.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_heads=4):
        super(TextSubNetworkWithAttention, self).__init__()
        # Token id 0 is treated as padding: its embedding stays zero and it is
        # excluded from the sequence length in forward().
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.ReLU()
        )
        self.cross_attention = CrossAttention(hidden_dim, num_heads)

    def forward(self, x, image_features=None):
        """Encode a batch of token-id sequences.

        Args:
            x: integer token ids of shape ``(batch, seq_len)``; id 0 is padding and is
                assumed to appear only at the end of each sequence (right padding).
            image_features: optional image context of shape ``(batch, hidden_dim)`` or
                ``(batch, n, hidden_dim)``. When ``None`` the cross-attention step is
                skipped.

        Returns:
            Tensor of shape ``(batch, 256)``.
        """
        # Real (non-padding) length of each sequence. The clamp avoids zero lengths,
        # which pack_padded_sequence rejects.
        lengths = (x != 0).sum(dim=1).clamp(min=1).cpu()
        embedded = self.embedding(x)
        # Packing makes the LSTM stop at each sequence's last real token instead of
        # running on through the trailing padding.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths, batch_first=True, enforce_sorted=False
        )
        _, (hidden, _) = self.lstm(packed)
        hidden = hidden[-1]  # final hidden state of the last LSTM layer: (batch, hidden_dim)
        if image_features is not None:
            # MultiheadAttention needs 3-D keys/values when the query is 3-D.
            if image_features.dim() == 2:
                image_features = image_features.unsqueeze(1)
            hidden = self.cross_attention(
                hidden.unsqueeze(1), image_features, image_features
            ).squeeze(1)
        return self.fc(hidden)

# 7. Siamese Network with Cross-Attention
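This model fuses image and text features with cross-attention and returns one combined embedding for each of the two inputs; their similarity is measured afterwards as the Euclidean distance between the two embeddings.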

In [None]:
class SiameseNetworkWithCrossAttention(nn.Module):
    """Two-branch network that fuses image and text features with cross-attention.

    Both branches reuse the same ``image_net`` and ``text_net`` (shared weights).
    Within a branch, the text is first encoded with no image context. That encoding
    conditions the image encoder, and the resulting image embedding then conditions a
    second text encoding. Each branch's output is ``cat([image_emb, text_emb], dim=1)``.

    Args:
        vocab_size: number of rows in the text embedding table.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.image_net = ImageSubNetworkWithAttention()
        self.text_net = TextSubNetworkWithAttention(vocab_size)

    def forward(self, img1, img2, text1, text2):
        """Return the fused embeddings ``(embedding1, embedding2)`` of the two branches."""
        branches = ((img1, text1), (img2, text2))
        # Calls are issued stage by stage (all text-only passes, then all image passes,
        # then all conditioned text passes) so stochastic layers such as dropout draw
        # random numbers in a fixed order.
        text_only = [self.text_net(tokens, None) for _, tokens in branches]
        image_embs = [
            self.image_net(image, context)
            for (image, _), context in zip(branches, text_only)
        ]
        text_embs = [
            self.text_net(tokens, context)
            for (_, tokens), context in zip(branches, image_embs)
        ]
        fused = [torch.cat([i_emb, t_emb], dim=1) for i_emb, t_emb in zip(image_embs, text_embs)]
        return fused[0], fused[1]

# 8. Loss Function
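The contrastive loss pulls similar pairs together and pushes dissimilar pairs apart. In this implementation, label 0 marks a similar pair (its squared distance is penalized) and label 1 marks a dissimilar pair (penalized only while its distance is below the margin).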

In [None]:
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss on the Euclidean distance between two embeddings.

    Label convention: ``0`` = similar pair (loss ``d**2``), ``1`` = dissimilar pair
    (loss ``max(margin - d, 0)**2``). The batch mean is returned.

    Args:
        margin: distance beyond which dissimilar pairs no longer contribute.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        dist = torch.nn.functional.pairwise_distance(output1, output2)
        similar_term = (1 - label) * dist.pow(2)
        dissimilar_term = label * (self.margin - dist).clamp(min=0.0).pow(2)
        return (similar_term + dissimilar_term).mean()

# 9. Training and Evaluation

In [None]:
# Training and Evaluation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SiameseNetworkWithCrossAttention(VOCAB_SIZE).to(device)
criterion = ContrastiveLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

train_losses = []  # mean training loss per epoch
val_losses = []    # mean validation loss per epoch

for epoch in range(EPOCHS):
    # ---- Training pass ----
    model.train()
    train_loss = 0.0
    for img1, text1, label in train_loader:
        img1, text1, label = img1.to(device), text1.to(device), label.to(device)
        optimizer.zero_grad()
        # NOTE(review): both Siamese branches receive the SAME (image, caption). The two
        # embeddings therefore differ only through dropout while training, and they are
        # (near-)identical in eval mode, so the contrastive objective has no real pair
        # to compare. TODO: feed a second, paired sample per item.
        output1, output2 = model(img1, img1, text1, text1)
        loss = criterion(output1, output2, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_losses.append(train_loss / len(train_loader))

    # ---- Validation pass ----
    model.eval()
    val_loss = 0.0
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for img1, text1, label in test_loader:
            img1, text1, label = img1.to(device), text1.to(device), label.to(device)
            output1, output2 = model(img1, img1, text1, text1)
            val_loss += criterion(output1, output2, label).item()
            # The embedding distance is the score. ContrastiveLoss pushes label-1 pairs
            # apart, so a larger distance should indicate label 1.
            preds = torch.nn.functional.pairwise_distance(output1, output2)
            all_labels.extend(label.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
    val_losses.append(val_loss / len(test_loader))

    # ROC AUC is undefined when the validation labels contain a single class
    # (sklearn raises ValueError), so report it only when both classes are present.
    if len(set(all_labels)) > 1:
        roc_auc_text = f"{roc_auc_score(all_labels, all_preds):.4f}"
    else:
        roc_auc_text = "n/a (single class in validation labels)"
    print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}, ROC AUC: {roc_auc_text}")

# Plot per-epoch training and validation loss curves
epoch_axis = range(1, EPOCHS + 1)
plt.figure()
plt.plot(epoch_axis, train_losses, label="Train Loss")
plt.plot(epoch_axis, val_losses, label="Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
#print("\nStarting Evaluation...")
#accuracy, precision, recall, f1 = evaluate_model(model, test_loader)

##### END #### 