# 1. An overview
Siamese networks consist of two identical sub-networks that share weights and learn to compute the similarity between two input samples. The goal is to learn embeddings such that similar inputs are close in the embedding space, while dissimilar inputs are far apart. For the WikiDiverse dataset, where we have image-caption pairs, we can build a Siamese network that processes text and image data (or just one modality like text or image) and learns to compute similarity between two entities from the knowledge base.
* Siamese Network Structure: Two identical sub-networks that compute embeddings for input pairs and learn their similarity
* Application: For WikiDiverse, compute similarity between image-caption pairs to link knowledge-base entities.

# 2. Libraries  

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from PIL import Image
import numpy as np

# 3. Data Preprocessing 
Assuming we have preprocessed text and image data, we need to encode both image and text inputs for the Siamese network. We will first tokenize and pad the text, and then use a pretrained ResNet50 model (for example) for feature extraction from the images.

# Text Processing

# Text processing: fit a tokenizer on the corpus, turn every text into a
# sequence, and pad the sequences so all inputs share one length.
# NOTE(review): `Tokenizer`, `pad_sequences` and `text_data` are not imported or
# defined anywhere in the visible file (they look like Keras preprocessing
# helpers -- confirm). `tokenizer` is also reassigned to a BertTokenizer in the
# following cell, so this snippet appears to be superseded by it.
max_sequence_length = 100  # length passed to pad_sequences as maxlen
tokenizer = Tokenizer(num_words=10000)
tokenizer.fit_on_texts(text_data)
text_sequences = tokenizer.texts_to_sequences(text_data)
text_input = pad_sequences(text_sequences, maxlen=max_sequence_length)

In [2]:
# Module-level tokenizer loaded from the pretrained 'bert-base-uncased'
# checkpoint; tokenize_text() reads this global.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize_text(text, max_length=100):
    """Encode ``text`` with the module-level ``tokenizer``.

    The tokenizer is asked to pad to ``max_length``, truncate longer input,
    and return PyTorch tensors.

    Args:
        text: Text to encode.
        max_length: Length to pad / truncate to.

    Returns:
        An ``(input_ids, attention_mask)`` tuple; ``squeeze(0)`` is applied
        to both tensors, dropping dimension 0 when its size is 1.
    """
    tokenizer_options = dict(
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
    encoded = tokenizer(text, **tokenizer_options)
    token_ids = encoded['input_ids'].squeeze(0)
    mask = encoded['attention_mask'].squeeze(0)
    return token_ids, mask

# Image processing 

def preprocess_image(img_path):
    """Load ``img_path`` at 224x224, add a leading axis, and preprocess it.

    NOTE(review): ``image`` and ``preprocess_input`` are not imported in the
    visible file (presumably Keras utilities -- confirm), and a later cell
    defines another ``preprocess_image`` that replaces this one.

    Args:
        img_path: Path of the image to load.

    Returns:
        Whatever ``preprocess_input`` returns for the expanded array.
    """
    loaded = image.load_img(img_path, target_size=(224, 224))
    # axis=0 inserts a new first dimension of size 1.
    batched = np.expand_dims(image.img_to_array(loaded), axis=0)
    return preprocess_input(batched)

# Run preprocess_image over every path in image_data and collect the results
# into a single NumPy array.
# NOTE(review): image_data is not defined anywhere in the visible file.
image_input = np.array([preprocess_image(img_path) for img_path in image_data])

In [3]:
# Image pipeline applied by preprocess_image(): resize to 224x224, convert to a
# tensor, then normalize each of the three channels with the mean / std below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Function to preprocess images
def preprocess_image(img_path):
    """Open ``img_path``, convert it to RGB, and apply the module-level ``transform``.

    Args:
        img_path: Path of the image file to open.

    Returns:
        The result of ``transform`` applied to the RGB image.
    """
    rgb_image = Image.open(img_path).convert('RGB')
    return transform(rgb_image)

# 4. Dataset Preparation

In [4]:
class WikiDiverseDataset(Dataset):
    """Dataset over labelled image-caption records.

    Every element of ``data`` is read through the keys ``'image_path'``,
    ``'caption'`` and ``'label'``.
    """

    def __init__(self, data, img_dir, max_length=100):
        """Store the records and settings.

        Args:
            data: List of image-caption pairs.
            img_dir: Image directory. NOTE(review): stored but never read by
                this class; ``__getitem__`` uses ``item['image_path']`` as-is.
            max_length: Maximum token length forwarded to ``tokenize_text``.
        """
        self.max_length = max_length
        self.img_dir = img_dir
        self.data = data

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(img, input_ids, attention_mask, label)`` for record ``idx``."""
        record = self.data[idx]
        # Read all three fields first so a missing key fails before any
        # preprocessing is attempted.
        img_path, caption, label = (
            record['image_path'],  # Path to the image
            record['caption'],     # Text caption
            record['label'],       # 1 for similar, 0 for dissimilar pairs
        )

        img = preprocess_image(img_path)
        input_ids, attention_mask = tokenize_text(caption, self.max_length)
        return img, input_ids, attention_mask, label

# 5. Siamese Sub-Networks

# Image Sub-Network

In [5]:
class ImageEmbedding(nn.Module):
    """Image encoder: ResNet50 layers followed by a linear projection."""

    def __init__(self, embedding_dim=256):
        """Build the backbone and the projection to ``embedding_dim`` features."""
        super().__init__()
        resnet = models.resnet50(pretrained=True)
        # Keep every child module except the last one (the classification layer).
        backbone_layers = list(resnet.children())[:-1]
        self.feature_extractor = nn.Sequential(*backbone_layers)
        # The projection input width is read from the layer that was dropped.
        self.fc = nn.Linear(resnet.fc.in_features, embedding_dim)

    def forward(self, x):
        """Extract features from ``x``, flatten from dim 1 on, and project them."""
        pooled = self.feature_extractor(x)
        return self.fc(torch.flatten(pooled, 1))

# Text Sub-Network

In [6]:
class TextEmbedding(nn.Module):
    """Text encoder: pretrained BERT followed by a linear projection."""

    def __init__(self, embedding_dim=256):
        """Load 'bert-base-uncased' and create the 768 -> ``embedding_dim`` projection."""
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # BERT outputs 768-dimensional embeddings
        self.fc = nn.Linear(768, embedding_dim)

    def forward(self, input_ids, attention_mask):
        """Project the hidden state at sequence position 0 (the [CLS] token)."""
        hidden_states = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        first_token = hidden_states[:, 0, :]
        return self.fc(first_token)

# 6. Siamese Network

In [7]:
class SiameseNetwork(nn.Module):
    """Embeds two (image, text) inputs with shared image and text encoders."""

    def __init__(self, embedding_dim=256):
        """Create one image encoder and one text encoder, each of size ``embedding_dim``."""
        super().__init__()
        self.image_embedding = ImageEmbedding(embedding_dim)
        self.text_embedding = TextEmbedding(embedding_dim)

    def forward(self, img1, img2, text1, mask1, text2, mask2):
        """Return one concatenated image+text embedding per input.

        Both images are encoded first, then both texts; each side's image
        and text embeddings are joined along dim 1.
        """
        image_side = [self.image_embedding(img) for img in (img1, img2)]
        text_side = [
            self.text_embedding(ids, mask)
            for ids, mask in ((text1, mask1), (text2, mask2))
        ]
        first, second = (
            torch.cat(pair, dim=1) for pair in zip(image_side, text_side)
        )
        return first, second

# 7. Contrastive Loss Function

In [8]:
class ContrastiveLoss(nn.Module):
    """Contrastive loss over the pairwise distance between two embedding batches."""

    def __init__(self, margin=1.0):
        """Store ``margin``, the distance below which label-0 pairs are penalised."""
        super().__init__()
        self.margin = margin

    def forward(self, emb1, emb2, label):
        """Return the mean loss over the batch.

        Pairs weighted by ``label`` contribute their squared distance;
        pairs weighted by ``1 - label`` contribute the squared shortfall of
        the distance below ``margin`` (zero once the distance exceeds it).
        """
        distance = torch.nn.functional.pairwise_distance(emb1, emb2)
        similar_term = label * torch.square(distance)
        shortfall = torch.clamp(self.margin - distance, min=0.0)
        dissimilar_term = (1 - label) * torch.square(shortfall)
        return (similar_term + dissimilar_term).mean()

# 8. Training Loop

In [9]:
# Initialize dataset and dataloader
dataset = WikiDiverseDataset(data, img_dir='path_to_images')
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Use the GPU when one is present and fall back to the CPU otherwise, so the
# script also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize model, loss, and optimizer
model = SiameseNetwork(embedding_dim=256)
model = model.to(device)
criterion = ContrastiveLoss(margin=1.0)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for img1, input_ids1, mask1, label in dataloader:
        # Draw one random partner sample per row of the batch and stack them.
        # A single dataset item has no batch dimension, so it cannot be fed
        # to the model next to the batched first input.
        partner_indices = np.random.randint(len(dataset), size=img1.size(0))
        partners = [dataset[int(i)] for i in partner_indices]
        img2 = torch.stack([sample[0] for sample in partners])
        input_ids2 = torch.stack([sample[1] for sample in partners])
        mask2 = torch.stack([sample[2] for sample in partners])

        # Move data to the selected device
        img1, img2 = img1.to(device), img2.to(device)
        input_ids1, mask1 = input_ids1.to(device), mask1.to(device)
        input_ids2, mask2 = input_ids2.to(device), mask2.to(device)
        # NOTE(review): the pair label still comes from the first sample only;
        # whether it describes the (sample, random partner) pair depends on
        # how `data` was labelled -- confirm.
        label = label.to(device=device, dtype=torch.float32)

        # Forward pass
        emb1, emb2 = model(img1, img2, input_ids1, mask1, input_ids2, mask2)

        # Compute loss
        loss = criterion(emb1, emb2, label)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss/len(dataloader)}")

SyntaxError: invalid syntax (2397490576.py, line 1)

# 9. Evaluation 

In [None]:
model.eval()
# Evaluate on whichever device the model's parameters already live on.
device = next(model.parameters()).device
# The partner samples are drawn from the loader's own dataset.
eval_dataset = test_dataloader.dataset
total_loss = 0.0
with torch.no_grad():
    for img1, input_ids1, mask1, label in test_dataloader:
        # Draw one random partner sample per row of the batch and stack them.
        # A single dataset item has no batch dimension, so it cannot be fed
        # to the model next to the batched first input.
        partner_indices = np.random.randint(len(eval_dataset), size=img1.size(0))
        partners = [eval_dataset[int(i)] for i in partner_indices]
        img2 = torch.stack([sample[0] for sample in partners])
        input_ids2 = torch.stack([sample[1] for sample in partners])
        mask2 = torch.stack([sample[2] for sample in partners])

        img1, img2 = img1.to(device), img2.to(device)
        input_ids1, mask1 = input_ids1.to(device), mask1.to(device)
        input_ids2, mask2 = input_ids2.to(device), mask2.to(device)
        # NOTE(review): the pair label still comes from the first sample only;
        # confirm it is meaningful for a randomly drawn partner.
        label = label.to(device=device, dtype=torch.float32)

        emb1, emb2 = model(img1, img2, input_ids1, mask1, input_ids2, mask2)
        loss = criterion(emb1, emb2, label)
        total_loss += loss.item()

# max(..., 1) avoids a ZeroDivisionError when the loader yields no batches.
print(f"Test Loss: {total_loss/max(len(test_dataloader), 1)}")

# 1. Required Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from PIL import Image

# 2. Data Preprocessing

# Image Preprocessing
We use a standard ResNet50 model for extracting features from images. Images are resized and normalized before feeding into the model.

In [None]:
# Image pipeline applied by preprocess_image(): resize to 224x224, convert to a
# tensor, then normalize each of the three channels with the mean / std below.
image_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def preprocess_image(img_path):
    """Open ``img_path`` as an RGB image and run it through ``image_transforms``.

    Args:
        img_path: Path of the image file to open.

    Returns:
        The result of ``image_transforms`` applied to the RGB image.
    """
    return image_transforms(Image.open(img_path).convert("RGB"))

# Text Preprocessing
We use a pretrained BERT tokenizer to tokenize and encode text inputs.

In [None]:
# Module-level tokenizer loaded from the pretrained 'bert-base-uncased'
# checkpoint; preprocess_text() reads this global.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def preprocess_text(text, max_length=100):
    """Encode ``text`` with the module-level ``tokenizer``.

    The tokenizer pads to ``max_length``, truncates longer input, and
    returns PyTorch tensors.

    Returns:
        An ``(input_ids, attention_mask)`` tuple; ``squeeze(0)`` is applied
        to both tensors, dropping dimension 0 when its size is 1.
    """
    encoded = tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    token_ids = encoded['input_ids'].squeeze(0)
    mask = encoded['attention_mask'].squeeze(0)
    return token_ids, mask

# 3. Sub-Networks

# Image Embedding Network
This network uses ResNet50 to extract image features and reduces them to a fixed-size embedding.

In [None]:
class ImageEmbedding(nn.Module):
    """Image encoder: ResNet50 layers followed by a 2048 -> ``embedding_dim`` projection."""

    def __init__(self, embedding_dim=256):
        """Build the backbone and the projection layer."""
        super().__init__()
        resnet = models.resnet50(pretrained=True)
        # Keep every child module except the last one (the classification layer).
        backbone_layers = list(resnet.children())[:-1]
        self.feature_extractor = nn.Sequential(*backbone_layers)
        # Reduce to embedding dimension
        self.fc = nn.Linear(2048, embedding_dim)

    def forward(self, x):
        """Extract features, reshape them to (batch, -1), and project them."""
        pooled = self.feature_extractor(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

# Text Embedding Network
This network uses BERT to extract text embeddings, followed by a linear layer to reduce dimensions.

In [None]:
class TextEmbedding(nn.Module):
    """Text encoder: pretrained BERT followed by a 768 -> ``embedding_dim`` projection."""

    def __init__(self, embedding_dim=256):
        """Load 'bert-base-uncased' and create the projection layer."""
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.fc = nn.Linear(768, embedding_dim)

    def forward(self, input_ids, attention_mask):
        """Project BERT's ``pooler_output`` to the embedding size.

        NOTE(review): per the Hugging Face documentation, ``pooler_output``
        is the [CLS] hidden state after the pooler's dense + tanh layer,
        not the raw [CLS] token state.
        """
        pooled = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).pooler_output
        return self.fc(pooled)

# 4. Cross-Attention Mechanism
The cross-attention mechanism allows one modality (e.g., text) to attend to another (e.g., image).

In [None]:
class CrossAttention(nn.Module):
    """Multi-head attention (batch-first) followed by a linear layer."""

    def __init__(self, embed_dim, num_heads):
        """Create the attention module and an ``embed_dim`` -> ``embed_dim`` linear layer."""
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.fc = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value):
        """Attend from ``query`` over ``key`` / ``value`` and project the result.

        The attention weights returned by the attention module are discarded.
        """
        attended = self.multihead_attn(query, key, value)[0]
        return self.fc(attended)

# 5. Siamese Network with Cross-Attention
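This combines image and text embeddings using cross-attention and returns one fused embedding per input; the similarity between the two inputs is then measured as the distance between these embeddings.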

In [None]:
class SiameseNetworkWithCrossAttention(nn.Module):
    """Siamese network that fuses image and text embeddings via cross-attention.

    Both inputs go through the same encoders, attention layer and output
    projection, so the two returned embeddings are directly comparable.
    """

    def __init__(self, embedding_dim=256, num_heads=4):
        """Create the shared encoders, the attention layer and the output projection.

        Args:
            embedding_dim: Size of each modality embedding and of the output.
            num_heads: Number of attention heads.
        """
        super(SiameseNetworkWithCrossAttention, self).__init__()
        self.image_embedding = ImageEmbedding(embedding_dim)
        self.text_embedding = TextEmbedding(embedding_dim)
        self.cross_attention = CrossAttention(embed_dim=embedding_dim, num_heads=num_heads)
        self.fc = nn.Linear(embedding_dim * 2, embedding_dim)

    def _fuse(self, img, text, mask):
        """Return the fused embedding for one (image, text, mask) input."""
        img_emb = self.image_embedding(img)
        text_emb = self.text_embedding(text, mask)

        # Text is the query; the image embedding is the only key/value
        # (sequence length 1 after unsqueeze).
        attended = self.cross_attention(
            text_emb.unsqueeze(1), img_emb.unsqueeze(1), img_emb.unsqueeze(1)
        ).squeeze(1)

        # With a single key the softmax attention weight is always 1, so
        # `attended` depends on the image alone and the text query would have
        # no effect on the output. The residual connection on the query lets
        # the text embedding reach the fused representation.
        text_fused = text_emb + attended

        # Combine both views and reduce to a single embedding.
        return self.fc(torch.cat((img_emb, text_fused), dim=1))

    def forward(self, img1, img2, text1, mask1, text2, mask2):
        """Return ``(combined_emb1, combined_emb2)``, one fused embedding per input."""
        combined_emb1 = self._fuse(img1, text1, mask1)
        combined_emb2 = self._fuse(img2, text2, mask2)
        return combined_emb1, combined_emb2

# 6. Loss Function
The contrastive loss function encourages similar pairs to have closer embeddings and dissimilar pairs to have distant embeddings.

In [None]:
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss on pairwise embedding distances."""

    def __init__(self, margin=1.0):
        """Remember ``margin`` for the ``1 - label`` term of the loss."""
        super().__init__()
        self.margin = margin

    def forward(self, emb1, emb2, label):
        """Return the batch mean of the per-pair contrastive loss.

        The squared distance is weighted by ``label``; the squared,
        zero-clamped value of ``margin - distance`` is weighted by
        ``1 - label``.
        """
        gap = torch.nn.functional.pairwise_distance(emb1, emb2)
        hinge = torch.clamp(self.margin - gap, min=0.0)
        per_pair = label * torch.square(gap) + (1 - label) * torch.square(hinge)
        return per_pair.mean()

# 7. Training and Evaluation


In [None]:
# Training Loop
model = SiameseNetworkWithCrossAttention()
criterion = ContrastiveLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# This section never defines num_epochs itself; 10 matches the value used by
# the earlier training cell.
num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = 0
    # NOTE(review): each batch must carry seven fields (two images, two
    # token-id/mask pairs and the labels). The WikiDiverseDataset defined in
    # this file yields four fields per item -- confirm which loader is meant.
    for img1, img2, text1, mask1, text2, mask2, labels in train_loader:
        optimizer.zero_grad()
        emb1, emb2 = model(img1, img2, text1, mask1, text2, mask2)
        loss = criterion(emb1, emb2, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch rather than the last batch's loss;
    # max(..., 1) keeps an empty loader from raising.
    mean_loss = running_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss}")

In [None]:
# Evaluation: embed every test pair without tracking gradients and measure the
# distance between the two embeddings.
model.eval()
with torch.no_grad():
    for img1, img2, text1, mask1, text2, mask2, labels in test_loader:
        emb1, emb2 = model(
            img1, img2,
            text1, mask1,
            text2, mask2,
        )
        euclidean_distance = torch.nn.functional.pairwise_distance(
            emb1, emb2
        )
        # Evaluate accuracy, precision, recall, etc.

# 8. DataLoader Example
Prepare the dataset and dataloader for training and evaluation.

In [None]:
class WikiDiverseDataset(Dataset):
    """Dataset yielding ``(img, input_ids, attention_mask, label)`` per index.

    NOTE(review): the training and evaluation loops in this file unpack seven
    fields per batch, while this dataset yields four -- confirm which side is
    intended.
    """

    def __init__(self, image_paths, texts, labels, tokenizer, transform, max_length=100):
        """Store the parallel sample lists and the preprocessing callables.

        Args:
            image_paths: Image file paths, indexed in parallel with ``texts``.
            texts: Text for each sample.
            labels: Label for each sample; its length defines the dataset size.
            tokenizer: Callable used to encode text. When ``None``, the
                module-level ``preprocess_text`` is used instead.
            transform: Callable applied to the opened RGB image. When ``None``,
                the module-level ``preprocess_image`` is used instead.
            max_length: Length the text is padded / truncated to.
        """
        self.image_paths = image_paths
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.transform = transform
        self.max_length = max_length

    def __len__(self):
        """Return the number of samples (the length of ``labels``)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Load and preprocess sample ``idx``."""
        img_path = self.image_paths[idx]
        text = self.texts[idx]
        label = self.labels[idx]

        # Use the transform / tokenizer handed to the constructor; previously
        # both were stored but ignored in favour of module-level helpers.
        if self.transform is None:
            img = preprocess_image(img_path)
        else:
            img = self.transform(Image.open(img_path).convert("RGB"))

        if self.tokenizer is None:
            input_ids, attention_mask = preprocess_text(text, self.max_length)
        else:
            encoded = self.tokenizer(
                text,
                padding='max_length',
                truncation=True,
                max_length=self.max_length,
                return_tensors="pt",
            )
            input_ids = encoded['input_ids'].squeeze(0)
            attention_mask = encoded['attention_mask'].squeeze(0)

        return img, input_ids, attention_mask, label

# Usage: build the training dataset and wrap it in a shuffling DataLoader with
# batches of 32 samples.
# NOTE(review): train_image_paths, train_texts and train_labels are not defined
# anywhere in the visible file and must be supplied before this cell runs.
train_dataset = WikiDiverseDataset(train_image_paths, train_texts, train_labels, tokenizer, image_transforms)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)