# An overview
Siamese networks consist of two identical sub-networks that share weights and learn to compute the similarity between two input samples. The goal is to learn embeddings such that similar inputs are close in the embedding space, while dissimilar inputs are far apart. For the WikiDiverse dataset, where we have image-caption pairs, we can build a Siamese network that processes text and image data (or just one modality like text or image) and learns to compute similarity between two entities from the knowledge base.
* Siamese network structure: two identical, weight-sharing sub-networks each compute an embedding for one input of a pair; the model learns how similar the two embeddings are.
* Application: on WikiDiverse, compute the similarity between image–caption pairs in order to link them to knowledge-base entities.

In [None]:
# Required Libraries
import os
import json
import hashlib
import re
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from transformers import BertTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, default_collate

In [511]:
# 1. Dataset path and hyperparameters.
# The dataset root defaults to the original author's local checkout. Set the
# WIKIDIVERSE_PATH environment variable to use another location without editing
# this cell (an empty value falls back to the default).
DATASET_PATH = os.environ.get("WIKIDIVERSE_PATH") or r'C:\Users\Min Dator\aics-project\wikidiverse_w_cands'
IMAGE_DIR = os.path.join(DATASET_PATH, "wikinewsImgs")  # directory holding the downloaded images
BATCH_SIZE = 16
EPOCHS = 10
LEARNING_RATE = 0.0001
EMBED_DIM = 256  # NOTE(review): not referenced below; the sub-networks hard-code 256 as their default
VOCAB_SIZE = 10000  # size of the nn.Embedding table in the text sub-network

In [522]:
# 2. Helper Functions:
def get_image_path(url, img_dir):
    """Map an image URL to the local path of its downloaded copy.

    The local file name is the MD5 hex digest of the URL's last path segment,
    followed by that segment's image extension (.jpg/.jpeg/.png/.svg, matched
    case-insensitively). SVG images are looked up as ``.png`` files.

    Args:
        url: Image URL; only the text after the final ``/`` is used.
        img_dir: Directory containing the downloaded images.

    Returns:
        The joined path. Existence of the file is not checked.
    """
    filename = url.split('/')[-1]
    prefix = hashlib.md5(filename.encode()).hexdigest()
    # Drop everything before the image extension, keeping e.g. ".jpg". A name
    # without a recognised extension is kept whole.
    suffix = re.sub(r'(\S+(?=\.(jpg|jpeg|png|svg)))', '', filename, flags=re.IGNORECASE)
    # The regex above is case-insensitive, so the kept suffix may be ".SVG";
    # rewrite both spellings (the prefix is hex digits and never contains ".svg").
    suffix = suffix.replace('.svg', '.png').replace('.SVG', '.png')
    return os.path.join(img_dir, f"{prefix}{suffix}")

In [535]:
# 3. WikiDiverse Dataset Class:
class WikiDiverseDataset(Dataset):
    """Paired image/text samples for the Siamese model.

    Each JSON entry is expected to provide ``image1``/``image2`` (file names
    relative to ``img_dir``), ``text1``/``text2`` and a numeric ``label``.
    Construction drops entries whose images are missing on disk, and entries that
    lack the image keys or are not mappings.

    Args:
        json_path: Path to a JSON file containing a list of entries.
        img_dir: Directory the image file names are resolved against.
        transform: Optional callable applied to each RGB PIL image.
        text_tokenizer: Optional tokenizer applied to each text. When omitted,
            raw strings are returned.
    """

    def __init__(self, json_path, img_dir, transform=None, text_tokenizer=None):
        self.data = []
        self.img_dir = img_dir
        self.transform = transform
        self.text_tokenizer = text_tokenizer

        try:
            # JSON is UTF-8. Without an explicit encoding, Windows decodes with the
            # locale code page and fails on non-ASCII captions.
            with open(json_path, 'r', encoding='utf-8') as f:
                entries = json.load(f)
        except PermissionError:
            print(f"Permission denied for file: {json_path}. Please ensure you have read permissions.")
            entries = []

        for entry in entries:
            try:
                img1_path, img2_path = self._image_paths(entry)
            except (KeyError, TypeError) as e:
                # Missing key, non-mapping entry or non-string file name: skip it.
                print(f"Error processing entry: {e}")
                continue
            if os.path.exists(img1_path) and os.path.exists(img2_path):
                self.data.append(entry)

        print(f"Total valid entries: {len(self.data)}")

    def _image_paths(self, entry):
        """Return the paths of the two images referenced by *entry*, joined onto img_dir."""
        return (os.path.join(self.img_dir, entry['image1']),
                os.path.join(self.img_dir, entry['image2']))

    def _load_image(self, path):
        """Open *path* as an RGB image and apply the optional transform."""
        # The context manager releases the file handle immediately. convert() returns
        # an independent, fully loaded copy, so it stays valid after the file is closed.
        with Image.open(path) as img:
            img = img.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

    def _tokenize(self, text):
        """Tokenize *text* when a tokenizer is configured; otherwise return it unchanged."""
        if not self.text_tokenizer:
            return text
        # padding="max_length" gives every sample the same length so default_collate
        # can stack them. padding=True pads only to the longest sequence in this
        # single call, which leaves lengths varying between samples.
        return self.text_tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(img1, text1, img2, text2, label)`` for sample *idx*."""
        entry = self.data[idx]
        img1_path, img2_path = self._image_paths(entry)
        img1 = self._load_image(img1_path)
        img2 = self._load_image(img2_path)
        text1 = self._tokenize(entry['text1'])
        text2 = self._tokenize(entry['text2'])
        label = torch.tensor(entry['label'], dtype=torch.float32)
        return img1, text1, img2, text2, label

def collate_fn(batch):
    """Collate a list of samples with ``default_collate``, skipping ``None`` entries."""
    present = list(filter(lambda sample: sample is not None, batch))
    return default_collate(present)

In [1]:
def load_data(json_path, img_dir, transform=None, text_tokenizer=None):
    """Build train/validation/test loaders using an 80/10/10 split (seed 42).

    Args:
        json_path: JSON file read by WikiDiverseDataset.
        img_dir: Directory containing the images referenced by the entries.
        transform: Optional image transform forwarded to the dataset.
        text_tokenizer: Optional tokenizer forwarded to the dataset.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    from torch.utils.data import Subset

    dataset = WikiDiverseDataset(json_path, img_dir, transform=transform, text_tokenizer=text_tokenizer)
    # Split indices, not the dataset. train_test_split on a Dataset indexes every
    # sample up front, loading all images into memory and applying any random
    # augmentation only once. Subset keeps loading lazy and per batch. The same
    # random_state over the same length yields the same partition.
    indices = list(range(len(dataset)))
    train_idx, val_test_idx = train_test_split(indices, test_size=0.2, random_state=42)
    val_idx, test_idx = train_test_split(val_test_idx, test_size=0.5, random_state=42)

    train_loader = DataLoader(Subset(dataset, train_idx), batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(Subset(dataset, val_idx), batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
    test_loader = DataLoader(Subset(dataset, test_idx), batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

    return train_loader, val_loader, test_loader

# Load data
#train_loader, val_loader, test_loader = load_data(DATASET_PATH, IMAGE_DIR)

In [2]:
# 4. Data Augmentation:
from torchvision.transforms import Compose
from torchvision import models, transforms
from torchvision.transforms import Compose, ColorJitter
from transformers import BertTokenizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score, f1_score
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, default_collate

# Image pipeline: resize to 224x224, apply random geometric and colour
# augmentation, then convert to a tensor normalised with the commonly used
# ImageNet mean/std.
# NOTE(review): the augmentation is random. If this same transform is used for
# validation/test data, evaluation images are augmented too.
transform = Compose(
    [transforms.Resize((224, 224))]
    + [  # random geometric augmentation
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(degrees=30),
    ]
    + [ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)]  # photometric
    + [  # tensor conversion and per-channel normalisation
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [537]:
# 5. Cross-Attention Mechanism:
class CrossAttention(nn.Module):
    """Multi-head attention from a query sequence onto a key/value sequence,
    followed by dropout, a residual connection to the query, and LayerNorm.

    With the default ``batch_first=True``, ``query`` is (batch, q_len, embed_dim)
    and ``key``/``value`` are (batch, kv_len, embed_dim). When the query is 3-D, a
    2-D key/value of shape (batch, embed_dim) is treated as a length-1 sequence.
    The image and text sub-networks pass exactly this form.

    Args:
        embed_dim: Feature size of query, key and value.
        num_heads: Number of attention heads (must divide ``embed_dim``).
        dropout: Dropout used inside the attention and on its output.
        batch_first: Whether inputs are laid out (batch, seq, dim).
    """

    def __init__(self, embed_dim, num_heads=4, dropout=0.1, batch_first=True):
        super(CrossAttention, self).__init__()
        # Callers build queries as x.unsqueeze(1), i.e. (batch, 1, dim). With PyTorch's
        # default (seq, batch, dim) layout, the batch axis was read as the sequence
        # axis, so each sample attended over the other samples in the batch.
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=batch_first)
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value):
        if query.dim() == 3:
            # Promote (batch, dim) context vectors to a length-1 sequence.
            seq_axis = 1 if self.multihead_attn.batch_first else 0
            if key.dim() == 2:
                key = key.unsqueeze(seq_axis)
            if value.dim() == 2:
                value = value.unsqueeze(seq_axis)
        attn_output, _ = self.multihead_attn(query, key, value)
        attn_output = self.dropout(attn_output)
        # Residual connection to the query. The previous attn_output + attn_output
        # was just 2 * attn_output and dropped the residual path entirely.
        return self.layer_norm(query + attn_output)


In [538]:
# 6. Image and Text Sub-Networks:
class ImageSubNetworkWithAttention(nn.Module):
    """Image encoder: pretrained ResNet-50 trunk, a projection to ``embed_dim``,
    then cross-attention from the image embedding onto text features.

    Args:
        embed_dim: Output embedding size (also the attention width).
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim=256, num_heads=4):
        super().__init__()
        backbone = models.resnet50(pretrained=True)
        # Keep every child module except the last one (the classification FC layer).
        trunk = list(backbone.children())[:-1]
        self.features = nn.Sequential(*trunk)
        projection = [nn.Linear(2048, embed_dim), nn.ReLU(), nn.Dropout(0.5)]
        self.fc = nn.Sequential(*projection)
        self.cross_attention = CrossAttention(embed_dim, num_heads)

    def forward(self, x, text_features):
        pooled = self.features(x)
        embedded = self.fc(torch.flatten(pooled, 1))
        # Attend from the image embedding (as a length-1 sequence) onto the text features.
        attended = self.cross_attention(embedded.unsqueeze(1), text_features, text_features)
        return attended.squeeze(1)

class TextSubNetworkWithAttention(nn.Module):
    """Text encoder: token embedding, single-layer LSTM, optional cross-attention
    from the final hidden state onto image features, then a 256-d projection.

    Args:
        vocab_size: Number of rows in the embedding table (token id 0 is padding).
        embed_dim: Token embedding size.
        hidden_dim: LSTM hidden size (also the attention width).
        num_heads: Number of attention heads.
    """

    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_heads=4):
        super(TextSubNetworkWithAttention, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # nn.LSTM applies `dropout` only between stacked layers. With a single layer,
        # the former dropout=0.5 had no effect and only raised a UserWarning.
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        self.cross_attention = CrossAttention(hidden_dim, num_heads)

    def forward(self, x, image_features):
        """Encode token ids *x*.

        When ``image_features`` is None, cross-attention is skipped and the raw
        LSTM state is projected. This is the text-only pass the Siamese model
        makes first.
        """
        x = self.embedding(x)
        _, (hidden, _) = self.lstm(x)
        # hidden is (num_layers, batch, hidden_dim); take the last layer's state.
        hidden = hidden[-1]
        if image_features is not None:
            hidden = self.cross_attention(hidden.unsqueeze(1), image_features, image_features).squeeze(1)
        return self.fc(hidden)

In [539]:
# 7. Siamese Network with Cross-Attention:
class SiameseNetworkWithCrossAttention(nn.Module):
    """Shared-weight Siamese model over (image, text) pairs.

    Both sides use the same image and text sub-networks. For each side the output
    is the concatenation (along dim 1) of its cross-attended image embedding and
    its cross-attended text embedding.
    """

    def __init__(self, vocab_size):
        super(SiameseNetworkWithCrossAttention, self).__init__()
        self.image_net = ImageSubNetworkWithAttention()
        self.text_net = TextSubNetworkWithAttention(vocab_size)

    def forward(self, img1, img2, text1, text2):
        texts = (text1, text2)
        # Pass 1: text-only encodings, used as attention context for the images.
        text_context = [self.text_net(tokens, None) for tokens in texts]
        # Pass 2: image embeddings attending over their paired text encodings.
        image_embeds = [self.image_net(img, ctx) for img, ctx in zip((img1, img2), text_context)]
        # Pass 3: texts re-encoded while attending over their image embeddings.
        text_embeds = [self.text_net(tokens, img_emb) for tokens, img_emb in zip(texts, image_embeds)]
        left, right = (torch.cat([img_emb, txt_emb], dim=1) for img_emb, txt_emb in zip(image_embeds, text_embeds))
        return left, right

In [544]:
# 9. Training and evaluation:
# Train the model, then report validation accuracy after each epoch and test accuracy at the end.
# Training and Evaluation
from sklearn.metrics import accuracy_score
import torch.nn.functional as F
import torch.nn.functional as F
from torchvision.transforms import Compose, ColorJitter
from sklearn.model_selection import train_test_split
# Load data (80/10/10 train/validation/test split).
loaders = load_data(DATASET_PATH, IMAGE_DIR)
train_loader, val_loader, test_loader = loaders

# Training setup: pick the GPU when available and build model, optimiser and loss.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")
model = SiameseNetworkWithCrossAttention(VOCAB_SIZE).to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CosineEmbeddingLoss()

def train(model, device, loader, optimizer, criterion):
    """Run one training epoch and return the mean loss over the batches processed.

    Each batch must hold ``(img1, text1, img2, text2, labels)``. A label > 0 means
    "similar"; anything else means "dissimilar". Labels are converted to the 1-D
    +1/-1 targets expected by ``nn.CosineEmbeddingLoss`` before calling
    ``criterion(emb1, emb2, target)``.

    Returns:
        Mean loss over processed batches; skipped batches are not counted.
        NaN if no batch was processed.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in loader:
        valid_batch = [sample for sample in batch if sample is not None]
        if len(valid_batch) != 5:  # Check if all elements are valid
            print(f"Warning: Invalid batch encountered. Skipping this iteration.")
            continue
        img1, text1, img2, text2, labels = (item.to(device) for item in valid_batch)

        combined_embedding1, combined_embedding2 = model(img1, img2, text1, text2)
        # CosineEmbeddingLoss needs a 1-D target of +1/-1. The old
        # labels.unsqueeze(1) gave an (N, 1) target that the loss rejects, and a
        # 0 target is neither +1 nor -1, so dissimilar pairs contributed no loss.
        targets = ((labels > 0).float() * 2 - 1).view(-1)

        loss = criterion(combined_embedding1, combined_embedding2, targets)
        total_loss += loss.item()
        num_batches += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches  # average loss per processed batch

def evaluate(model, device, loader, threshold=0.5):
    """Return accuracy (percent) of thresholded cosine similarity on *loader*.

    A pair is predicted similar (1) when the cosine similarity of its two combined
    embeddings exceeds ``threshold``, otherwise 0. Predictions are compared with
    the batch labels, which are expected to be 0/1.

    Args:
        model: Siamese model returning ``(embedding1, embedding2)``.
        device: Device the inputs are moved to.
        loader: Iterable of ``(img1, text1, img2, text2, labels)`` batches.
        threshold: Similarity cut-off for a positive prediction.

    Returns:
        Accuracy in percent, or NaN when *loader* yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for img1, text1, img2, text2, labels in loader:
            embedding1, embedding2 = model(img1.to(device), img2.to(device), text1.to(device), text2.to(device))
            similarity = F.cosine_similarity(embedding1, embedding2)
            predicted = (similarity > threshold).long().cpu()
            truth = labels.view(-1).long().cpu()
            correct += (predicted == truth).sum().item()
            total += truth.numel()

    if total == 0:
        return float("nan")
    return 100.0 * correct / total  # Return accuracy as percentage

# Training loop. device, model, optimizer and criterion were created earlier in this
# cell. Re-creating them here discarded that model and built a second ResNet-50.
for epoch in range(EPOCHS):
    train_loss = train(model, device, train_loader, optimizer, criterion)
    print(f"Epoch {epoch+1}/{EPOCHS}, Train Loss: {train_loss:.4f}")

    # Evaluate on validation set
    val_accuracy = evaluate(model, device, val_loader)
    print(f"Epoch {epoch+1}/{EPOCHS}, Val Accuracy: {val_accuracy:.4f}")

# Final evaluation on test set
test_accuracy = evaluate(model, device, test_loader)
print(f"Test Accuracy: {test_accuracy:.4f}")

FileNotFoundError: [Errno 2] No such file or directory: 'C:\\\\Users\\\\Min Dator\\\\AI project\\\\aics-project\\\\wikidiverse_w_cands\\wikinewsImgs\\abc77ae74d5b046c8b569191ccf39a3b.jpg'

In [None]:
#9. Training and Evaluation:
# #Train the model and track performance using metrics like ROC AUC.
# # Training and Evaluation

# import torch
# import torch.optim as optim
# from sklearn.metrics import roc_auc_score
# import matplotlib.pyplot as plt
# from torch.utils.data import DataLoader

# # Assuming other required variables (e.g., VOCAB_SIZE, LEARNING_RATE, etc.) are defined elsewhere

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = SiameseNetworkWithCrossAttention(VOCAB_SIZE).to(device)  # Ensure the model is initialized properly
# criterion = ContrastiveLoss()
# optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# train_losses = []
# val_losses = []

# for epoch in range(EPOCHS):
#     model.train()
#     train_loss = 0
#     for img1, text1, img2, text2, label in train_loader:
#         img1, text1, img2, text2, label = img1.to(device), text1.to(device), img2.to(device), text2.to(device), label.to(device)

#         # Forward pass through the Siamese network
#         img_embed1, img_embed2 = model.forward_once(img1), model.forward_once(img2)
#         text_embed1, text_embed2 = text_encoder(text1), text_encoder(text2)

#         # Combine embeddings
#         combined_embed1 = torch.cat([img_embed1, text_embed1], dim=1)
#         combined_embed2 = torch.cat([img_embed2, text_embed2], dim=1)

#         # Compute similarity and loss
#         loss = criterion(combined_embed1, combined_embed2, label)

#         # Backpropagation and optimizer step
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()    
        
#         train_loss += loss.item()

#     train_losses.append(train_loss / len(train_loader))

#     # Validation Phase
#     model.eval()
#     val_loss = 0
#     all_labels = []
#     all_preds = []
    
#     with torch.no_grad():
#         for img1, text1, label in test_loader:
#             img1, text1, label = img1.to(device), text1.to(device), label.to(device)
#             output1, output2 = model(img1, img1, text1, text1)
#             loss = criterion(output1, output2, label)
#             val_loss += loss.item()
            
#             # Calculate pairwise distance between outputs for prediction
#             preds = torch.nn.functional.pairwise_distance(output1, output2)  # Smaller distance => closer similarity
#             all_labels.extend(label.cpu().numpy())  # Collect actual labels
#             all_preds.extend(preds.cpu().numpy())  # Collect predicted distances

#     val_losses.append(val_loss / len(test_loader))

#     # Calculate ROC AUC score based on predicted distances and true labels
#     roc_auc = roc_auc_score(all_labels, all_preds)
#     print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}, ROC AUC: {roc_auc:.4f}")

# # Plotting the training and validation loss curves
# plt.figure()
# plt.plot(range(1, EPOCHS + 1), train_losses, label="Train Loss")
# plt.plot(range(1, EPOCHS + 1), val_losses, label="Validation Loss")
# plt.xlabel("Epochs")
# plt.ylabel("Loss")
# plt.legend()
# plt.show()

In [None]:
def save_model(model, epoch, loss):
    """Save ``model``'s state dict plus ``epoch`` and ``loss`` to
    ``model_epoch_<epoch>.pth`` in the current working directory."""
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
        'loss': loss,
    }
    destination = f'model_epoch_{epoch}.pth'
    torch.save(checkpoint, destination)

def load_model(model, checkpoint_path, map_location="cpu"):
    """Load a checkpoint written by ``save_model`` into ``model`` in place.

    Args:
        model: Module whose parameters are overwritten.
        checkpoint_path: Path to the checkpoint file.
        map_location: Where tensors are placed while reading. It defaults to CPU so a
            checkpoint saved on a GPU machine also loads on a CPU-only one.
            ``load_state_dict`` then copies the values onto the model's own device.

    Returns:
        ``(epoch, loss)`` as stored in the checkpoint.
    """
    # NOTE(security): torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint['epoch'], checkpoint['loss']

# File name save_model writes for the final epoch.
checkpoint_path = f'model_epoch_{EPOCHS}.pth'

# Save the trained weights together with the last epoch's training loss.
save_model(model, EPOCHS, train_loss)

# Reload the checkpoint for inference; returns the stored epoch and loss.
epoch, loss = load_model(model, checkpoint_path)

### END ###