Written and developed by Aaron Lozhkin

# Load the Huggingface Snacks Dataset

In [None]:
# Install the Hugging Face libraries used below: `datasets` (snacks data) and `transformers` (ResNet backbone).
!pip install transformers datasets

In [None]:
from datasets import load_dataset
# Load the "Matthijs/snacks" dataset; its 'train', 'test' and 'validation' splits are used below.
ds = load_dataset("Matthijs/snacks")
ds  # Bare expression so the notebook displays the loaded object.

# Load Snacks Dataset into PyTorch

In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision import datasets

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random


# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Helper Function to Display Images

def displayTriplet(anchor, positive, negative, title=None, inverseTransform=True):
    """Show an anchor / positive / negative triplet side by side.

    Args:
        anchor: Image drawn in the left panel.
        positive: Image drawn in the middle panel.
        negative: Image drawn in the right panel.
        title: Optional figure-level title; skipped when None.
        inverseTransform: When True, each image is first passed through the
            inverse of the mean/std normalization and converted to a PIL
            image. Pass False for images that can be handed to ``imshow``
            as they are.
    """
    figure, axes = plt.subplots(1, 3, figsize=(10, 4))

    if title is not None:
        figure.suptitle(title)

    images = [anchor, positive, negative]
    if inverseTransform:
        # Undo Normalize(mean, std) in two steps: dividing by 1/std multiplies
        # by std, then subtracting -mean adds the mean back.
        undo_normalization = transforms.Compose([
            transforms.Normalize(mean=[0., 0., 0.], std=[1/0.229, 1/0.224, 1/0.225]),
            transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1., 1., 1.]),
            transforms.ToPILImage(),
        ])
        images = [undo_normalization(image) for image in images]

    # One panel per role, left to right, with the axes hidden.
    for axis, image, caption in zip(axes, images, ('Anchor', 'Positive', 'Negative')):
        axis.imshow(image)
        axis.set_title(caption)
        axis.axis('off')

    plt.show()

In [None]:
from datasets.packaged_modules import imagefolder
# Create a custom Siamese Triplet Dataset class to return triplets from the snacks dataset

class SiameseTripletDataset(Dataset):
    """Dataset that yields (anchor, positive, negative) image triplets.

    The anchor is the image at the requested index, the positive is another
    image with the same label, and the negative is an image with a different
    label. Positive and negative are drawn at random on every access, so the
    same index generally yields different triplets across calls.

    Args:
        ds: Indexable dataset whose rows are dicts with 'image' and 'label'
            keys and which supports column access via ``ds['label']``.
        width: Target image width in pixels.
        height: Target image height in pixels.
    """

    def __init__(self, ds, width, height):
        self.ds = ds
        self.width = width
        self.height = height

        # Read the label column once; it feeds both the class count and the
        # index below without touching (and decoding) the image column.
        labels = self.ds['label']
        self.num_classes = len(set(labels))

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # torchvision's Resize takes its size as (height, width).
            transforms.Resize((self.height, self.width)),
            # Standardize each channel with the mean/std statistics commonly
            # used for ImageNet-pretrained models.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        # Map each label to the list of dataset indices carrying that label.
        self.label_to_images = {}
        for idx, label in enumerate(labels):
            self.label_to_images.setdefault(label, []).append(idx)

    def __len__(self):
        return len(self.ds)

    def show_images(self, idx):
        """Draw a freshly sampled triplet for the anchor at ``idx``."""
        anchor_image, positive_image, negative_image = self[idx]
        displayTriplet(anchor_image, positive_image, negative_image)

    def getImage(self, idx, transform=True):
        """Return ``(image, label)`` for row ``idx``.

        The image goes through ``self.transform`` when ``transform`` is True
        and is returned as stored in the dataset otherwise.
        """
        row = self.ds[idx]  # fetch the row once instead of once per field
        image = self.transform(row['image']) if transform else row['image']
        return image, row['label']

    def __getitem__(self, idx):
        """Return transformed ``(anchor, positive, negative)`` images.

        Raises:
            IndexError: If the dataset holds a single label, so that no
                negative sample can be drawn.
        """
        anchor_row = self.ds[idx]
        anchor_image, anchor_label = anchor_row['image'], anchor_row['label']

        # Positive: same label, and a different image than the anchor whenever
        # the class has more than one member (an identical pair is trivial).
        same_class = self.label_to_images[anchor_label]
        positive_idx = idx
        if len(same_class) > 1:
            while positive_idx == idx:
                positive_idx = random.choice(same_class)
        positive_image = self.ds[positive_idx]['image']

        # Negative: any image from a uniformly chosen different label.
        other_labels = [label for label in self.label_to_images if label != anchor_label]
        negative_label = random.choice(other_labels)
        negative_idx = random.choice(self.label_to_images[negative_label])
        negative_image = self.ds[negative_idx]['image']

        return (
            self.transform(anchor_image),
            self.transform(positive_image),
            self.transform(negative_image),
        )

In [None]:
# We resize the images to (224, 224) similar to what is used by ResNet

width, height = 224, 224

# One triplet dataset per split of the snacks dataset.
train_dataset = SiameseTripletDataset(ds=ds['train'], width=width, height=height)
test_dataset = SiameseTripletDataset(ds=ds['test'], width=width, height=height)
validation_dataset = SiameseTripletDataset(ds=ds['validation'], width=width, height=height)

## Visualize 10 Anchor, Positive, and Negative image triplets

In [None]:
import random

# Show 10 triplets whose anchors are drawn uniformly from the train split.
for _ in range(10):
    # randrange excludes the upper bound; randint(0, len(...)) could return
    # len(train_dataset), which is one past the last valid index.
    idx = random.randrange(len(train_dataset))
    train_dataset.show_images(idx)

# Load ResNet Model and Build Siamese Network with Triplet Loss
A pre-trained ResNet model is used as the backbone of the image-embedding architecture.

In [None]:
import torch.nn.functional as F
from transformers import AutoFeatureExtractor, AutoModel

class SiameseNetwork(nn.Module):
    """Embedding network applied identically to anchor, positive and negative.

    A pre-trained "microsoft/resnet-50" backbone produces pooled features,
    which a fully connected head projects down to ``embedding_dim`` values.

    Args:
        embedding_dim: Size of the embedding produced for each image.
    """

    def __init__(self, embedding_dim=128):
        super().__init__()
        self.embedding_dim = embedding_dim

        # Pre-trained ResNet backbone from Hugging Face
        self.resnet_model = AutoModel.from_pretrained("microsoft/resnet-50")

        # Projection head: 2048 -> 1024 -> 512 -> embedding_dim
        projection_layers = [
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Linear(512, self.embedding_dim),
            nn.ReLU(inplace=True),
        ]
        self.fc = nn.Sequential(*projection_layers)

    def forward_one(self, x):
        """Embed a single batch of images."""
        pooled = self.resnet_model(x).pooler_output.squeeze()

        # squeeze() also removes the batch axis of a one-image batch; restore it
        # so the head always receives a 2-D input.
        if pooled.dim() == 1:
            pooled = pooled.unsqueeze(0)

        return self.fc(pooled)

    def forward(self, anchor, positive, negative):
        """Return the embeddings of the three inputs, in the order given."""
        embed = self.forward_one
        return embed(anchor), embed(positive), embed(negative)


class TripletLoss(nn.Module):
    """Triplet margin loss on squared Euclidean distances.

    For each row the loss is ``max(0, d(a, p) - d(a, n) + margin)`` where
    ``d`` is the squared distance summed over dimension 1; the result is the
    mean over all rows.

    Args:
        margin: Value added to the positive/negative distance gap.
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        """Return the mean triplet loss of the given embedding batches."""
        # Squared distances (no square root is taken).
        to_positive = torch.sum((anchor - positive) ** 2, dim=1)
        to_negative = torch.sum((anchor - negative) ** 2, dim=1)
        per_triplet = F.relu(to_positive - to_negative + self.margin)
        return torch.mean(per_triplet)

In [None]:
# Create dataloaders for training and validation
batch_size = 32

# Both loaders shuffle; for validation this means the first batch drawn from it is a random one.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=True)

# Model Visualization

In [None]:
# Load the saved model from google drive
# The inner wget fetches the download-confirmation token; it must request the same file id as the outer download.
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1T2BXHH4M2ZUXYBpTnWWiaAv18JZbxXsS' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1T2BXHH4M2ZUXYBpTnWWiaAv18JZbxXsS" -O siamese_triplet_model_cache.pth && rm -rf /tmp/cookies.txt

In [None]:
# Restore the Siamese network from the checkpoint downloaded by the previous cell.
# The same file is used on GPU and CPU; the path is relative to the working
# directory, which is where wget wrote it.
checkpoint_path = "siamese_triplet_model_cache.pth"

siamese_net = SiameseNetwork().to(device)
if not torch.cuda.is_available():
  print("WARNING: Model will run extremely slow on cpu. If on colab, go to Runtime->Change Runtime Type->Hardware Accelerator->GPU.")
# map_location=device places the stored tensors on whichever device was selected.
siamese_net.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Switch Dropout / BatchNorm to inference behaviour.
siamese_net.eval()

In [None]:
import matplotlib.pyplot as plt
import random

data_iter = iter(validation_loader)

# Get a random batch from the validation dataset
batch = next(data_iter)

anchor, positive, negative = batch
anchor, positive, negative = anchor.to(device), positive.to(device), negative.to(device)

# Compute embeddings using the Siamese Network.
# This is inference only, so no autograd graph is needed.
with torch.no_grad():
    anchor_embedding, positive_embedding, negative_embedding = siamese_net(anchor, positive, negative)

normalized_anchor_embedding = torch.nn.functional.normalize(anchor_embedding, p=2, dim=1)
normalized_positive_embedding = torch.nn.functional.normalize(positive_embedding, p=2, dim=1)
normalized_negative_embedding = torch.nn.functional.normalize(negative_embedding, p=2, dim=1)

# Alternative: compute a dissimilarity score using Euclidean distance
# dissimilarity_positive = F.pairwise_distance(normalized_anchor_embedding, normalized_positive_embedding, keepdim=True)
# dissimilarity_negative = F.pairwise_distance(normalized_anchor_embedding, normalized_negative_embedding, keepdim=True)

# Compute the similarity score using cosine similarity
similarity_positive = F.cosine_similarity(normalized_anchor_embedding, normalized_positive_embedding)
similarity_negative = F.cosine_similarity(normalized_anchor_embedding, normalized_negative_embedding)

# Show up to 10 triplets; the batch may hold fewer than 10 items.
for i in range(min(10, anchor.size(0))):
  displayTriplet(anchor[i], positive[i], negative[i], title= "Validation Triplet " + str(i+1) + "\n" +
               f"Positive similarity Score: {similarity_positive[i].item():.4f}" + "\t".expandtabs() +
               f"Negative similarity Score: {similarity_negative[i].item():.4f}")

# Reverse Image Search Engine

In [None]:
# Utilize the train dataset to create a reverse image search engine.
# Embed the entire train dataset

siamese_net.eval()

train_embeddings = []

# Inference only: no_grad keeps autograd from building a graph for every image.
with torch.no_grad():
    for i in tqdm(range(len(train_dataset)), desc="Generating Train Embeddings"):
        image = train_dataset.getImage(i, transform=True)[0]
        # Add a batch dimension of 1 and move the image to the model's device.
        image = image.unsqueeze(0).to(device)
        # Keep the embeddings on the CPU.
        train_embeddings.append(siamese_net.forward_one(image).cpu())

del image

In [None]:
# Take a query image from the test dataset
# Adjust the index to query for different types of images

query_idx = 99

# Fetch the untransformed image so the notebook can render it as stored in the dataset.
query_image_initial = test_dataset.getImage(query_idx, transform=False)[0]
print("Query Image")
display(query_image_initial)

In [None]:
# Display the top 20 most similar images using the Siamese Network


query_image = test_dataset.getImage(query_idx, transform=True)[0]

# Embed the query image
with torch.no_grad():
    query_embedding = siamese_net.forward_one(query_image.unsqueeze(0).to(device)).detach().cpu()

# Cosine similarity between the query and every train embedding, in one call.
# Each stored embedding has a leading batch dimension of 1, so concatenating
# along dim 0 yields one row per train image.
all_train_embeddings = torch.cat(train_embeddings, dim=0)
similarity_scores = F.cosine_similarity(query_embedding, all_train_embeddings, dim=1)

# (train index, similarity) pairs, most similar first
similarities = sorted(enumerate(similarity_scores.tolist()), key=lambda pair: pair[1], reverse=True)

# Get the top N similar images
top_n = 20
top_matches = similarities[:top_n]

# Size the grid from the number of results so that changing top_n cannot index
# past the last row; squeeze=False keeps axs 2-D even for a single row.
n_cols = 4
n_rows = max(1, -(-len(top_matches) // n_cols))  # ceiling division
fig, axs = plt.subplots(n_rows, n_cols, figsize=(30, 10 * n_rows), squeeze=False)

# Get the top N similar images and their similarity scores
top_similar_images = [train_dataset.getImage(idx, transform=False)[0] for idx, _ in top_matches]
top_similarity_scores = [similarity for _, similarity in top_matches]

# Fill the grid row by row; cells beyond the last result stay empty with hidden axes.
for i, ax in enumerate(axs.flat):
    if i < len(top_similar_images):
        ax.imshow(top_similar_images[i])
        ax.set_title(f'Search Result {i+1}' + f', Similarity Score: {top_similarity_scores[i]:.4f}', fontsize=12)
    ax.axis('off')

plt.tight_layout()
plt.show()


# Model Training Code

In [None]:
# Create Siamese network and optimizer, and move them to CUDA if available
siamese_net = SiameseNetwork().to(device)
optimizer = optim.Adam(siamese_net.parameters(), lr=0.0001)

# Define triplet loss function
criterion = TripletLoss(margin=1.0)
loss_list = []             # mean training loss per epoch
validation_loss_list = []  # mean validation loss per epoch

# Training Loop
num_epochs = 25
for epoch in range(num_epochs):

    # Training mode: Dropout is active and BatchNorm uses batch statistics.
    # This is re-enabled every epoch because validation below switches to eval mode.
    siamese_net.train()
    total_train_loss = 0.0
    num_train_batches = 0

    # Create a tqdm progress bar for the training loader
    train_loader_iter = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=False)
    for batch in train_loader_iter:
        anchor, positive, negative = [th.to(device) for th in batch]
        optimizer.zero_grad()
        output_anchor, output_positive, output_negative = siamese_net(anchor, positive, negative)
        loss = criterion(output_anchor, output_positive, output_negative)
        loss.backward()
        optimizer.step()
        train_loader_iter.set_postfix({'Loss': loss.item()})
        total_train_loss += loss.item()
        num_train_batches += 1
    # Report the epoch mean rather than whatever the last batch happened to
    # score, so it is comparable with the averaged validation loss.
    average_train_loss = total_train_loss / max(num_train_batches, 1)

    # Validation loop: eval mode so Dropout is disabled and BatchNorm neither
    # uses nor updates its statistics from validation batches.
    siamese_net.eval()
    with torch.no_grad():
        total_validation_loss = 0.0
        num_batches = 0
        for batch in validation_loader:
            anchor_val, positive_val, negative_val = [th.to(device) for th in batch]
            output_anchor_val, output_positive_val, output_negative_val = siamese_net(anchor_val, positive_val, negative_val)
            validation_loss = criterion(output_anchor_val, output_positive_val, output_negative_val)
            total_validation_loss += validation_loss.item()
            num_batches += 1
        average_validation_loss = total_validation_loss / max(num_batches, 1)

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {average_train_loss}, Validation Loss: {average_validation_loss}')
    loss_list.append(average_train_loss)
    validation_loss_list.append(average_validation_loss)

In [None]:
# Plot the training and validation losses recorded per epoch. Because the
# network builds on a pre-trained ResNet, few epochs were needed.
epochs = range(1, len(loss_list) + 1)
line_style = dict(marker='o', linestyle='-')

# One figure with both curves on shared axes
fig, ax = plt.subplots(figsize=(10, 6))
for series, series_label in ((loss_list, 'Training Loss'), (validation_loss_list, 'Validation Loss')):
    ax.plot(epochs, series, label=series_label, **line_style)

# Titles, axis labels, legend and grid
ax.set_title('Training and Validation Loss Over Epochs')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
ax.grid(True)

# Render the figure
plt.show()



## Observe Test Set Accuracy
Accuracy is measured as the percentage of test triplets for which the network scores the positive image as more similar to the anchor than the negative image.

In [None]:
from datasets.utils.version import total_ordering
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Evaluate in inference mode: disables Dropout and makes BatchNorm use its
# running statistics instead of per-batch statistics.
siamese_net.eval()

# Initialize variables for accuracy computation
correct = 0
total = 0

with torch.no_grad():
    for batch in test_loader:
        anchor, positive, negative = [th.to(device) for th in batch]

        # Compute embeddings for anchor, positive, and negative samples
        anchor_embedding = siamese_net.forward_one(anchor)
        positive_embedding = siamese_net.forward_one(positive)
        negative_embedding = siamese_net.forward_one(negative)

        normalized_anchor_embedding = torch.nn.functional.normalize(anchor_embedding, p=2, dim=1)
        normalized_positive_embedding = torch.nn.functional.normalize(positive_embedding, p=2, dim=1)
        normalized_negative_embedding = torch.nn.functional.normalize(negative_embedding, p=2, dim=1)

        # Compute the cosine similarity of the anchor to the positive and to the negative
        similarity_positive = torch.cosine_similarity(normalized_anchor_embedding, normalized_positive_embedding, dim=1)
        similarity_negative = torch.cosine_similarity(normalized_anchor_embedding, normalized_negative_embedding, dim=1)

        # A triplet counts as correct when the positive is strictly more similar than the negative.
        is_correct_positive = similarity_positive > similarity_negative
        correct += is_correct_positive.sum().item()
        total += anchor.size(0)

# Percentage of correctly ranked triplets (0.0 when the test set is empty)
accuracy = correct / total * 100.0 if total else 0.0

print(f'Overall Test Accuracy: {accuracy:.2f}%')


## Save Model For Later Use

In [None]:
# Persist only the weights (state_dict); loading them requires constructing a SiameseNetwork first.
torch.save(siamese_net.state_dict(), 'siamese_triplet_model_test.pth')