In [None]:
# Download the CUB_200_2011 archive and the segmentations archive.
# -L follows redirects, -O saves each file under its remote name in the
# current working directory.
!curl -LO https://data.caltech.edu/records/65de6-vp158/files/CUB_200_2011.tgz
!curl -LO https://data.caltech.edu/records/w9d68-gec53/files/segmentations.tgz

In [None]:
# Unpack both downloaded archives into the current working directory.
!tar -xzf CUB_200_2011.tgz
!tar -xzf segmentations.tgz

In [None]:
!pip install umap-learn

In [None]:
# check device (whether NVIDIA or AMD)
import torch

cuda_available = torch.cuda.is_available()
print(cuda_available)
print(torch.cuda.device_count())
if cuda_available:
    # These queries need an actual GPU; running them unconditionally would
    # crash the cell on a CPU-only machine before the fallback below is reached.
    current_device = torch.cuda.current_device()
    print(current_device)
    print(torch.cuda.get_device_name(current_device))
    # Pass the device explicitly: older PyTorch releases require the argument.
    print(torch.cuda.get_device_properties(current_device))
else:
    print("No CUDA/ROCm device detected; falling back to CPU.")
device = torch.device("cuda" if cuda_available else "cpu")

In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from collections import defaultdict

# Dataset locations, all derived from the directory the notebook runs in.
CURR_DIR = os.getcwd()
BASE_DIR = os.path.join(CURR_DIR, 'CUB_200_2011')
# NOTE(review): this is a bare relative path, so it is resolved against the
# working directory rather than BASE_DIR -- presumably the archive places
# attributes.txt next to the CUB_200_2011/ folder; verify after extraction.
ATTRIBUTES_FILE = 'attributes.txt'
IMAGE_LABELS_FILE = os.path.join(BASE_DIR, 'attributes', 'image_attribute_labels.txt')
IMAGES_FILE = os.path.join(BASE_DIR, 'images.txt')
# set seed (for reproducibility)
# The same seed is applied to Python's `random`, NumPy and PyTorch.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# --- attribute id -> attribute name ---
# Each usable line is "<attr_id> <attr_name>"; lines that do not split into
# exactly two parts are ignored.
print("Loading attribute definitions...")
with open(ATTRIBUTES_FILE, 'r') as attr_file:
    attr_rows = (row.strip().split(maxsplit=1) for row in attr_file)
    attribute_id_to_name = {row[0]: row[1] for row in attr_rows if len(row) == 2}
print(f"-> Loaded {len(attribute_id_to_name)} attribute definitions.")

# --- image id -> list of (attr_id, attr_name) for attributes marked present ---
# Only the first three whitespace-separated fields of each line are used.
image_to_attributes = defaultdict(list)
print("Loading image attributes...")
with open(IMAGE_LABELS_FILE, 'r') as labels_file:
    for record in labels_file:
        fields = record.strip().split()
        image_id, attr_id, is_present = fields[0], fields[1], fields[2]
        if is_present != '1':
            continue
        image_to_attributes[image_id].append(
            (attr_id, attribute_id_to_name.get(attr_id, "Unknown"))
        )
print(f"-> Loaded attributes for {len(image_to_attributes)} images.")

# --- image id -> path relative to the images/ folder ---
# Lines that do not split into exactly two fields are ignored.
print("Loading image paths...")
with open(IMAGES_FILE, 'r') as paths_file:
    path_rows = (row.strip().split() for row in paths_file)
    image_id_to_path = {row[0]: row[1] for row in path_rows if len(row) == 2}
print(f"-> Loaded {len(image_id_to_path)} image paths.")

# Draw two random images and show each with its class, id, file name and the
# first few attributes that are marked present.
all_image_ids = list(image_id_to_path.keys())
random_image_ids = random.sample(all_image_ids, 2)

print("Generating plot...")
fig, axes = plt.subplots(1, 2, figsize=(15, 18))
fig.subplots_adjust(hspace=0.5, wspace=0.1)
axes = axes.flatten()

for ax, image_id in zip(axes, random_image_ids):
    # Relative paths look like "<class folder>/<file name>"; the numeric class
    # id is the part of the folder name before the first dot.
    image_relative_path = image_id_to_path[image_id]
    class_folder, image_name = os.path.dirname(image_relative_path), os.path.basename(image_relative_path)
    class_id = int(class_folder.split('.')[0])
    image_full_path = os.path.join(BASE_DIR, 'images', image_relative_path)

    # At most four attribute names, followed by an ellipsis entry if more exist.
    attributes = image_to_attributes.get(image_id, [])
    bullet_lines = [f"- {attr_name}" for _, attr_name in attributes[:4]]
    attr_string = "\n".join(bullet_lines)
    if len(attributes) > 4:
        attr_string += "\n- ..."

    ax.imshow(mpimg.imread(image_full_path))
    ax.axis('off')
    ax.set_title(
        f"Class #: {class_id} | Image #: {image_id}\n{image_name}",
        fontsize=9,
        wrap=True,
    )
    # Attribute list goes just below the image, in axes coordinates.
    ax.text(0, -0.05, attr_string,
            transform=ax.transAxes,
            fontsize=8,
            verticalalignment='top')

# save plot to file and display
output_filename = 'cub_attributes_grid.png'
fig.savefig(output_filename)
print(f"\n created and saved plot to '{output_filename}'")
plt.show()

In [None]:
#UMAP and T-SNE
from sklearn.manifold import TSNE
from torchvision import transforms
from PIL import Image
import matplotlib
import umap

def umap_tsne(model, embedding, labels, plot_title, output_file_name):
    """
    Project embeddings to 2-D with t-SNE and UMAP and plot both side by side.

    Args:
        model: Unused; kept so existing calls keep working.
        embedding: Array-like of shape (n_samples, n_features). CPU torch
            tensors are accepted and converted with ``np.asarray``.
        labels: Array-like of n_samples integer class ids (0-indexed).
        plot_title: Title drawn above both panels.
        output_file_name: Path the figure is saved to.

    One scatter series is drawn per class id actually present in ``labels``
    (legend entries are shown 1-indexed), so no class is silently dropped
    when more than ten classes are passed in.
    """
    features = np.asarray(embedding)
    class_labels = np.asarray(labels)

    print(f"Got {features.shape[0]} baseline embeddings.")

    # Run t-SNE
    tsne = TSNE(n_components=2, perplexity=30, max_iter=1000, random_state=42)
    tsne_embeddings = tsne.fit_transform(features)
    print("t-SNE finished.")

    # Run UMAP
    umap_reducer = umap.UMAP(n_components=2, random_state=42)
    umap_embeddings = umap_reducer.fit_transform(features)
    print("UMAP finished.")

    # Classes are taken from the data instead of a hard-coded range(10).
    class_ids = np.unique(class_labels)

    # Plot both projections with the same per-class scatter logic.
    fig, axes = matplotlib.pyplot.subplots(1, 2, figsize=(20, 10))
    panels = (
        (axes[0], 't-SNE', tsne_embeddings),
        (axes[1], 'UMAP', umap_embeddings),
    )
    for ax, method, points in panels:
        ax.set_title(f'{method} of Embeddings')
        for class_id in class_ids:
            members = class_labels == class_id
            ax.scatter(
                points[members, 0],
                points[members, 1],
                alpha=0.6,
                label=f'Class {int(class_id) + 1}'
            )
        ax.set_xlabel(f'{method} Component 1')
        ax.set_ylabel(f'{method} Component 2')
        ax.legend()

    fig.suptitle(plot_title, fontsize=16)
    # Save through the figure handle so the function does not depend on a
    # `plt` alias defined in another cell.
    fig.savefig(output_file_name)

    matplotlib.pyplot.show()
    matplotlib.pyplot.close(fig)

    print(f"\n created and saved plot to '{output_file_name}'")

In [None]:
import torch.nn
import torch.nn.functional
import torchvision.models

class EmbeddingNet(torch.nn.Module):
    """
    Pretrained ResNet (18, 34 or 50) whose classifier head is swapped for a
    linear layer that emits L2-normalized embeddings. Used for triplet-loss
    training and for the UMAP / t-SNE visualizations.
    """
    def __init__(self, embedding_dim=128, backbone_name='resnet18'):
        super().__init__()

        # (name, constructor, weights enum) for every supported backbone.
        known_backbones = (
            ('resnet18', torchvision.models.resnet18, torchvision.models.ResNet18_Weights),
            ('resnet34', torchvision.models.resnet34, torchvision.models.ResNet34_Weights),
            ('resnet50', torchvision.models.resnet50, torchvision.models.ResNet50_Weights),
        )
        for name, build_backbone, weights_enum in known_backbones:
            if backbone_name == name:
                self.backbone = build_backbone(weights=weights_enum.DEFAULT)
                break
        else:
            raise ValueError(f"Unsupported backbone: {backbone_name}. Please choose 'resnet18', 'resnet34', or 'resnet50'.")

        # The new head takes the features that used to feed the classifier.
        self.backbone.fc = torch.nn.Linear(self.backbone.fc.in_features, embedding_dim)

    def forward(self, x):
        """
        Return one unit-length (L2-normalized along dim 1) embedding per input.
        """
        return torch.nn.functional.normalize(self.backbone(x), p=2, dim=1)

In [None]:
def get_all_embeddings(model, loader, device):
    """
    Embed every (images, labels) batch yielded by `loader`.

    Puts `model` in eval mode, runs it without gradient tracking on `device`,
    and returns (embeddings, labels) as two CPU tensors concatenated along
    dim 0 in loader order.
    """
    model.eval()
    embedding_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(loader, desc="Getting Embeddings"):
            batch_embeddings = model(batch_images.to(device))
            # Collect on the CPU so device memory is not held across batches.
            embedding_chunks.append(batch_embeddings.cpu())
            label_chunks.append(batch_labels.cpu())

    return torch.cat(embedding_chunks, dim=0), torch.cat(label_chunks, dim=0)

In [None]:
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm # library for progress bars
from torchvision import transforms
import pandas

class CUBEmbeddingDataset(Dataset):
    """
    A dataset that returns just one image and its label.
    Used for generating embeddings for the entire test set (precision@k).
    Each item is an (image, label) pair with 0-indexed labels.

    Args:
        root_dir: Folder containing images.txt, image_class_labels.txt,
            train_test_split.txt and the images/ directory.
        split: 'train' selects rows with is_train == 1; any other value
            selects rows with is_train == 0.
        transform: Optional callable applied to each loaded PIL image.
        num_classes: Keep only classes with 0-indexed id < num_classes.
            Defaults to 10 (the previous hard-coded subset); pass None to
            use every class.
    """
    def __init__(self, root_dir, split='test', transform=None, num_classes=10):
        self.root_dir = root_dir
        self.image_dir = os.path.join(self.root_dir, 'images')
        self.transform = transform
        self.split = split
        self.num_classes = num_classes

        # Load the three space-separated metadata tables and join on image id.
        images_df = self._read_table('images.txt', ['img_id', 'filepath'])
        labels_df = self._read_table('image_class_labels.txt', ['img_id', 'class_id'])
        split_df = self._read_table('train_test_split.txt', ['img_id', 'is_train'])

        data_df = images_df.merge(labels_df, on='img_id').merge(split_df, on='img_id')
        data_df['class_id'] = data_df['class_id'] - 1 # 0-indexed

        # Optional restriction to the first `num_classes` classes.
        if num_classes is not None:
            data_df = data_df[data_df['class_id'] < num_classes].reset_index(drop=True)

        # Filter by split
        target_split = 1 if self.split == 'train' else 0
        self.data_df = data_df[data_df['is_train'] == target_split].reset_index(drop=True)

        self.data_list = list(zip(self.data_df['filepath'], self.data_df['class_id']))

    def _read_table(self, filename, columns):
        """Read one space-separated, header-less metadata file from root_dir."""
        return pandas.read_csv(os.path.join(self.root_dir, filename), sep=' ', names=columns)

    def __len__(self):
        return len(self.data_list)

    def _load_image(self, filepath):
        """Open the image at `filepath` (relative to images/) as RGB."""
        full_path = os.path.join(self.image_dir, filepath)
        img = Image.open(full_path).convert('RGB')
        return img

    def __getitem__(self, index):
        # Get the image and its label
        img_path, label = self.data_list[index]
        img = self._load_image(img_path)

        if self.transform:
            img = self.transform(img)

        return img, label

In [None]:
class TripletCUBDataset(torch.utils.data.Dataset):
    """
    Dataset that serves (anchor, positive, negative) image triplets from
    CUB-200 for triplet-loss training.

    Args:
        root_dir: Folder containing images.txt, image_class_labels.txt,
            train_test_split.txt and the images/ directory.
        split: 'train' selects rows with is_train == 1; any other value
            selects rows with is_train == 0.
        transform: Optional callable applied to each of the three images.
    """
    def __init__(self, root_dir, split='train', transform=None):
        self.root_dir = root_dir
        self.image_dir = os.path.join(self.root_dir, 'images')
        self.transform = transform
        self.split = split
        # Load metadata
        self._load_metadata()

    def _load_metadata(self):
        """Read the metadata files and build the per-class index lookup."""
        # Read images.txt: <image_id> <filepath>
        images_df = pandas.read_csv(
            os.path.join(self.root_dir, 'images.txt'),
            sep=' ',
            names=['img_id', 'filepath']
        )

        # Read image_class_labels.txt: <image_id> <class_id>
        labels_df = pandas.read_csv(
            os.path.join(self.root_dir, 'image_class_labels.txt'),
            sep=' ',
            names=['img_id', 'class_id']
        )

        # Read train_test_split.txt: <image_id> <is_train>
        split_df = pandas.read_csv(
            os.path.join(self.root_dir, 'train_test_split.txt'),
            sep=' ',
            names=['img_id', 'is_train']
        )

        # Merge dataframes
        data_df = images_df.merge(labels_df, on='img_id').merge(split_df, on='img_id')

        # Class ids become 0-indexed.
        data_df['class_id'] = data_df['class_id'] - 1

        # Filter by split (1 for train, 0 for test)
        target_split = 1 if self.split == 'train' else 0
        self.data_df = data_df[data_df['is_train'] == target_split].reset_index(drop=True)

        # Create a list of all data points (filepath, class_id)
        self.data_list = list(zip(self.data_df['filepath'], self.data_df['class_id']))

        # Create a dictionary mapping class_id -> [list of indices in self.data_list]
        self.class_to_indices = {}
        for idx, (_, class_id) in enumerate(self.data_list):
            self.class_to_indices.setdefault(class_id, []).append(idx)

        # Store a list of all unique class IDs
        self.classes = list(self.class_to_indices.keys())

    def __len__(self):
        return len(self.data_list)

    def _load_image(self, filepath):
        """Helper to load an image from its relative path as RGB.

        An OSError is reported with the offending path and then re-raised.
        """
        full_path = os.path.join(self.image_dir, filepath)

        try:
            return Image.open(full_path).convert('RGB')
        except OSError as e:
            print("!!!!!!!!!!!!!! ERROR LOADING IMAGE !!!!!!!!!!!!!!")
            print(f"Failed to load image: {full_path}")
            print(f"Error: {e}")
            raise

    def __getitem__(self, index):
        """
        Generate a triplet for the anchor at `index`.

        Returns:
            (anchor_img, positive_img, negative_img). The positive is a
            different image of the anchor's class when one exists (otherwise
            the anchor itself); the negative is drawn from another class.

        Raises:
            ValueError: If the split holds fewer than two classes, because no
                negative sample can be drawn (this used to loop forever).
        """
        if len(self.classes) < 2:
            raise ValueError(
                "TripletCUBDataset needs at least two classes to draw a negative sample."
            )

        # get the ANCHOR
        anchor_path, anchor_class = self.data_list[index]
        anchor_img = self._load_image(anchor_path)

        # get a POSITIVE image (same class, different image)
        positive_indices = self.class_to_indices[anchor_class]

        # Ensure we don't pick the same image as the anchor; a class with a
        # single image falls back to using the anchor as its own positive.
        positive_index = index
        while positive_index == index and len(positive_indices) > 1:
            positive_index = random.choice(positive_indices)

        positive_path, _ = self.data_list[positive_index]
        positive_img = self._load_image(positive_path)

        # get NEGATIVE image (different class); terminates because at least
        # two classes exist (checked above).
        negative_class = anchor_class
        while negative_class == anchor_class:
            negative_class = random.choice(self.classes)

        negative_indices = self.class_to_indices[negative_class]
        negative_index = random.choice(negative_indices)

        negative_path, _ = self.data_list[negative_index]
        negative_img = self._load_image(negative_path)

        # apply transforms
        if self.transform:
            anchor_img = self.transform(anchor_img)
            positive_img = self.transform(positive_img)
            negative_img = self.transform(negative_img)

        return anchor_img, positive_img, negative_img

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm # library for progress bars
from torchvision import transforms
import pandas

# Training hyper-parameters.
EPOCHS = 2
BATCH_SIZE = 64
LEARN_RATE = 0.001
NUM_WORKERS = 0


def _resize_step():
    """Fresh resize transform to 224x224, the first step of both pipelines."""
    return transforms.Resize((224, 224))


def _tensor_steps():
    """Fresh tensor-conversion + normalization steps that end both pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]


# train transforms: resize, then random flip (50% of images) and random
# rotation (up to 10 degrees), then tensor conversion + normalization
train_transforms = transforms.Compose([
    _resize_step(),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    *_tensor_steps(),
])

# test transforms: same pipeline without the random augmentations
test_transforms = transforms.Compose([
    _resize_step(),
    *_tensor_steps(),
])

# Settings shared by every DataLoader below.
_loader_options = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

## for triplet-loss
# triplet datasets: augmented transforms for train, plain ones for test
train_dataset = TripletCUBDataset(root_dir=BASE_DIR, split='train', transform=train_transforms)
test_dataset = TripletCUBDataset(root_dir=BASE_DIR, split='test', transform=test_transforms)

# only the training loader shuffles
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

## for embedding evaluation (precision@k)
# single-image (image, label) dataset over the test split
eval_dataset = CUBEmbeddingDataset(root_dir=BASE_DIR, split='test', transform=test_transforms)
eval_loader = DataLoader(eval_dataset, shuffle=False, **_loader_options)

for _size_message in (
    f"Train dataset size: {len(train_dataset)}",
    f"Test dataset size: {len(test_dataset)}",
    f"Eval dataset size (for embeddings): {len(eval_dataset)}",
):
    print(_size_message)

# Smoke test: pull a single batch from the training loader and report the
# shape of each triplet member. Any failure is printed, not raised.
print("\nTesting the train_loader...")
try:
    first_batch = next(iter(train_loader))
    anchor_batch, positive_batch, negative_batch = first_batch
    for role, batch in zip(("Anchor", "Positive", "Negative"), first_batch):
        print(f"{role} batch shape: {batch.shape}")
except Exception as err:
    print(f"Error loading data: {err}")

In [None]:
import time
from tqdm import tqdm
import torch.optim
from sklearn.metrics.pairwise import cosine_similarity

# Helpers: batch cosine stats
def batch_cosine_stats(a, p, n):
    """
    Per-example cosine similarities for a batch of triplet embeddings.

    Args:
        a, p, n: torch embeddings of shape (B, D) for anchors, positives and
            negatives (row i of each belongs to the same triplet).

    Returns:
        sim_ap: numpy array (B,), cosine similarity of anchor i and positive i.
        sim_an: numpy array (B,), cosine similarity of anchor i and negative i.
        batch_top1: float in [0, 1], share of triplets with sim_ap > sim_an.

    The similarities are computed row by row, so no B x B pairwise matrix is
    built just to read its diagonal. Inputs are detached; gradients are not
    affected.
    """
    anchors, positives, negatives = a.detach(), p.detach(), n.detach()

    sim_ap = torch.nn.functional.cosine_similarity(anchors, positives, dim=1).cpu().numpy()
    sim_an = torch.nn.functional.cosine_similarity(anchors, negatives, dim=1).cpu().numpy()
    top1 = (sim_ap > sim_an).mean()
    return sim_ap, sim_an, float(top1)

# Training
def train(model, train_loader, optimizer, loss_fn, device):
    """
    Run one training epoch over (anchor, positive, negative) batches.

    Returns a dict with:
        "loss":    summed batch losses divided by the number of batches
        "mean_ap": mean anchor-positive cosine similarity per example
        "mean_an": mean anchor-negative cosine similarity per example
        "top1":    share of triplets whose positive is closer than the negative
        "count":   number of triplets seen
    """
    model.train()

    loss_total = 0.0
    ap_total = an_total = top1_total = 0.0
    seen = 0

    for batch in tqdm(train_loader, desc="Training"):
        anchors, positives, negatives = (
            images.to(device, non_blocking=True) for images in batch
        )

        optimizer.zero_grad()

        # Embed all three inputs, each of shape (B, D) ...
        raw_embeddings = [model(images) for images in (anchors, positives, negatives)]
        # ... then L2-normalize so that cosine similarity equals the dot product.
        a, p, n = (
            torch.nn.functional.normalize(emb, p=2, dim=1) for emb in raw_embeddings
        )

        # Cosine statistics for this batch, weighted by batch size.
        sim_ap, sim_an, batch_top1 = batch_cosine_stats(a, p, n)
        batch_size = a.size(0)
        ap_total += float(sim_ap.sum())
        an_total += float(sim_an.sum())
        top1_total += batch_top1 * batch_size
        seen += batch_size

        loss = loss_fn(a, p, n)
        loss.backward()
        optimizer.step()

        loss_total += float(loss.item())

    return {
        "loss": loss_total / len(train_loader),
        "mean_ap": ap_total / seen,
        "mean_an": an_total / seen,
        "top1": top1_total / seen,
        "count": seen,
    }

In [None]:
# setup for each model; ResNet18 and ResNet34
# These three values are shared by both training runs and by the test cells.
embedding_dim = 512  # passed to EmbeddingNet as the size of its embedding layer
margin = 1.0  # margin handed to the triplet loss below
loss_fn = torch.nn.TripletMarginLoss(margin=margin)

In [None]:
# initialize ResNet18
# Build the ResNet18-backed embedding model, move it to the selected device
# and create an Adam optimizer over its parameters.
# NOTE: the name `optimizer` is rebound when the ResNet34 model is initialised
# later, so this cell must be re-run before training ResNet18 again.
model_18 = EmbeddingNet(embedding_dim, backbone_name='resnet18')
model_18.to(device)
optimizer = torch.optim.Adam(model_18.parameters(), lr=LEARN_RATE)

In [None]:
# ResNet18 training loop
print("Starting ResNet18 training...")
print(f"Selected: \n epochs: {EPOCHS} \n batch size: {BATCH_SIZE} \n learn rate: {LEARN_RATE}")
# Keep this model's per-epoch metrics under their own name so that training
# another model afterwards cannot overwrite them. `train_hist` remains as an
# alias of the most recently trained model's history for cells that read it.
train_hist_18 = []
train_hist = train_hist_18
for epoch in range(EPOCHS):
    start_time = time.time()

    tr = train(model_18, train_loader, optimizer, loss_fn, device)

    train_hist_18.append(tr)

    elapsed = time.time() - start_time
    print(
        f"Epoch {epoch+1}/{EPOCHS} | {elapsed:.2f}s | "
        f"Train Loss: {tr['loss']:.4f} | "
        f"Train cos(AP): {tr['mean_ap']:.3f} | "
        f"Train cos(AN): {tr['mean_an']:.3f} | "
        f"Train Top1: {tr['top1']:.3f} "
    )
    print("\n")

print("ResNet18 training finished.")

# save model's weight and parameters
MODEL_PATH_18 = 'resnet18_model.pth'
torch.save(model_18.state_dict(), MODEL_PATH_18)

print(f"Model state_dict saved to {MODEL_PATH_18}")

In [None]:
# initialize ResNet34
# Build the ResNet34-backed embedding model, move it to the selected device
# and create an Adam optimizer over its parameters.
# NOTE: this rebinds `optimizer`, which until now held the ResNet18 optimizer.
model_34 = EmbeddingNet(embedding_dim, backbone_name='resnet34')
model_34.to(device)
optimizer = torch.optim.Adam(model_34.parameters(), lr=LEARN_RATE)

In [None]:
# ResNet34 training loop
print("Starting ResNet34 training...")
print(f"Selected: \n epochs: {EPOCHS} \n batch size: {BATCH_SIZE} \n learn rate: {LEARN_RATE}")
# Keep this model's per-epoch metrics under their own name so they do not
# replace the history of a model trained earlier. `train_hist` remains as an
# alias of the most recently trained model's history for cells that read it.
train_hist_34 = []
train_hist = train_hist_34
for epoch in range(EPOCHS):
    start_time = time.time()

    tr = train(model_34, train_loader, optimizer, loss_fn, device)

    train_hist_34.append(tr)

    elapsed = time.time() - start_time
    print(
        f"Epoch {epoch+1}/{EPOCHS} | {elapsed:.2f}s | "
        f"Train Loss: {tr['loss']:.4f} | "
        f"Train cos(AP): {tr['mean_ap']:.3f} | "
        f"Train cos(AN): {tr['mean_an']:.3f} | "
        f"Train Top1: {tr['top1']:.3f} "
    )
    print("\n")

print("ResNet34 training finished.")

# save model's weight and parameters
MODEL_PATH_34 = 'resnet34_model.pth'
torch.save(model_34.state_dict(), MODEL_PATH_34)

print(f"Model state_dict saved to {MODEL_PATH_34}")

In [None]:
# manage GPU memory
import torch
import torch.nn as nn

def memory_stats():
    """Print allocated, then reserved, CUDA memory in MiB (one number per line)."""
    for query in (torch.cuda.memory_allocated, torch.cuda.memory_reserved):
        print(query()/1024**2)
    
# Disabled allocation probe, kept for manual experiments:
# def allocate():
#     x = torch.randn(1024*1024, device='cuda')
#     memory_stats()

# Print memory usage, call empty_cache(), then print usage again.
# NOTE: nothing is deleted here -- the `del` lines below are disabled, and
# model_18 / model_34 are still used by later cells -- so only memory that is
# cached but no longer held by any tensor can be released at this point.
# import gc
memory_stats()
# gc.collect()
# del model_18
# del model_34
# del optimizer
torch.cuda.empty_cache()
memory_stats()
print("Freed GPU memory.")

In [None]:
# Training-loss curve for the history currently stored in `train_hist`.
import matplotlib.pyplot as plt

loss_hist = [epoch_stats['loss'] for epoch_stats in train_hist]

# one x position per recorded epoch, starting at 1
epochs_range = range(1, len(loss_hist) + 1)

loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.plot(epochs_range, loss_hist, marker='o', label='Training Loss')
loss_ax.set_title('Training Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.legend()
loss_ax.grid(True)
loss_fig.tight_layout()

# write the figure to disk before showing it
training_loss_graph = 'training_loss_graph.png'
loss_fig.savefig(training_loss_graph)

print(f"\n created and saved plot to '{training_loss_graph}'")

plt.show()

In [None]:
# test model
def test(model, test_loader, loss_fn, device):
    """
    Evaluate `model` on (anchor, positive, negative) batches without gradients.

    The embeddings returned by the model are used as-is (no extra
    normalization is applied here).

    Returns a dict with:
        "loss":    summed batch losses divided by the number of batches
        "mean_ap": mean anchor-positive cosine similarity per example
        "mean_an": mean anchor-negative cosine similarity per example
        "top1":    share of triplets whose positive is closer than the negative
        "count":   number of triplets seen
    """
    model.eval()

    loss_total = 0.0
    ap_total = an_total = top1_total = 0.0
    seen = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Testing"):
            anchors, positives, negatives = (
                images.to(device, non_blocking=True) for images in batch
            )

            anchor_emb = model(anchors)
            positive_emb = model(positives)
            negative_emb = model(negatives)

            # Cosine statistics for this batch, weighted by batch size.
            sim_ap, sim_an, batch_top1 = batch_cosine_stats(anchor_emb, positive_emb, negative_emb)
            batch_size = anchor_emb.size(0)
            ap_total += float(sim_ap.sum())
            an_total += float(sim_an.sum())
            top1_total += batch_top1 * batch_size
            seen += batch_size

            loss_total += float(loss_fn(anchor_emb, positive_emb, negative_emb).item())

    return {
        "loss": loss_total / len(test_loader),
        "mean_ap": ap_total / seen,
        "mean_an": an_total / seen,
        "top1": top1_total / seen,
        "count": seen,
    }

In [None]:
# test on ResNet18
# using same margin and loss_fn from training run
print("\nTesting ResNet18 Model...")

# initialize and load model to device
loaded_model_18 = EmbeddingNet(embedding_dim, backbone_name='resnet18')
loaded_model_18.to(device)
loaded_model_18.load_state_dict(torch.load(MODEL_PATH_18, map_location=device))

# Evaluate the model restored from the checkpoint (not the in-memory training
# instance), so this cell actually verifies what was saved to disk.
test_stats_18 = test(loaded_model_18, test_loader, loss_fn, device)

print("--- ResNet18 Test Results ---")
print(f"  Loss: {test_stats_18['loss']:.4f}")
print(f"  Top-1 Acc: {test_stats_18['top1'] * 100:.2f}%")
print(f"  Avg. Pos Cosine: {test_stats_18['mean_ap']:.4f}")
print(f"  Avg. Neg Cosine: {test_stats_18['mean_an']:.4f}")

In [None]:
# test on ResNet34
from PIL import ImageFile
# allow truncated images to load (test dataset has truncated images)
ImageFile.LOAD_TRUNCATED_IMAGES = True

print("\nTesting ResNet34 Model...")

# initialize and load model to device
loaded_model_34 = EmbeddingNet(embedding_dim, backbone_name='resnet34')
loaded_model_34.to(device)
loaded_model_34.load_state_dict(torch.load(MODEL_PATH_34, map_location=device))

# Evaluate the model restored from the checkpoint (not the in-memory training
# instance), so this cell actually verifies what was saved to disk.
test_stats_34 = test(loaded_model_34, test_loader, loss_fn, device)

print("--- ResNet34 Test Results ---")
print(f"  Loss: {test_stats_34['loss']:.4f}")
print(f"  Top-1 Acc: {test_stats_34['top1'] * 100:.2f}%")
print(f"  Avg. Pos Cosine: {test_stats_34['mean_ap']:.4f}")
print(f"  Avg. Neg Cosine: {test_stats_34['mean_an']:.4f}")

In [None]:
# precision@k for ResNet18
# Embed every image served by eval_loader with the trained ResNet18 model;
# the resulting embeddings and labels feed the Precision@k cells below.
embeddings_18, labels_18 = get_all_embeddings(model_18, eval_loader, device)
print(f"got ResNet18 {embeddings_18.shape[0]} embeddings with {labels_18.shape[0]} labels.")

In [None]:
# precision@k for ResNet34
# Embed every image served by eval_loader with the trained ResNet34 model;
# the resulting embeddings and labels feed the Precision@k cells below.
embeddings_34, labels_34 = get_all_embeddings(model_34, eval_loader, device)
print(f"got ResNet34 {embeddings_34.shape[0]} embeddings with {labels_34.shape[0]} labels.")

In [None]:
def calculate_precision_at_k(embeddings, labels, k, device):
    """
    Mean Precision@k over all embeddings, each used once as the query.

    For every query the k most similar *other* embeddings are retrieved (by
    dot product, which equals cosine similarity for L2-normalized inputs) and
    precision is the fraction of them sharing the query's label.

    Args:
        embeddings: Tensor of shape (N, D), expected to be L2-normalized.
        labels: Tensor of N class labels aligned with `embeddings`.
        k: Number of neighbours to retrieve; must satisfy 1 <= k < N.
        device: Device on which the similarity matrix is computed.

    Returns:
        The average precision across all N queries (numpy float).

    Raises:
        ValueError: If k is outside 1 <= k < N.
    """
    num_embeddings = len(embeddings)
    if not 1 <= k < num_embeddings:
        raise ValueError(
            f"k must satisfy 1 <= k < number of embeddings ({num_embeddings}), got {k}"
        )

    with torch.no_grad():
        embeddings_dev = embeddings.to(device)
        labels_dev = labels.to(device)

        # cosine similarity; embeddings are L2-normalized, dot product = cosine similarity
        similarity_matrix = torch.mm(embeddings_dev, embeddings_dev.T)

        # Exclude each query from its own results explicitly. Dropping the
        # first top-k hit instead is wrong when another embedding ties with
        # the query (e.g. duplicates): the query itself could then be counted.
        similarity_matrix.fill_diagonal_(float('-inf'))

        # indices of the k nearest neighbours of every query, shape (N, k)
        _, top_k_indices = torch.topk(similarity_matrix, k, dim=1, largest=True)

        # number of retrieved neighbours that share the query's label, shape (N,)
        retrieved_labels = labels_dev[top_k_indices]
        num_correct = (retrieved_labels == labels_dev.unsqueeze(1)).sum(dim=1)

    # per-query precision, then the average across all queries
    precisions = num_correct.cpu().numpy() / k
    return np.mean(precisions)

In [None]:
# ensure using device = cuda (NVIDIA or AMD)
# Precision@k of the ResNet18 embeddings for k = 1, 5 and 10.
print("\n--- ResNet18 Retrieval Results ---")
p_at_1, p_at_5, p_at_10 = (
    calculate_precision_at_k(embeddings_18, labels_18, k=top_k, device=device)
    for top_k in (1, 5, 10)
)

for top_k, score in ((1, p_at_1), (5, p_at_5), (10, p_at_10)):
    row_label = f"Precision@{top_k}:"
    # pad the label to a fixed width so the percentages line up
    print(f"  {row_label:<14}{score * 100:.2f}%")

In [None]:
# ensure using device = cuda (NVIDIA or AMD)
# Precision@k of the ResNet34 embeddings for k = 1, 5 and 10.
print("\n--- ResNet34 Retrieval Results ---")
p_at_1, p_at_5, p_at_10 = (
    calculate_precision_at_k(embeddings_34, labels_34, k=top_k, device=device)
    for top_k in (1, 5, 10)
)

for top_k, score in ((1, p_at_1), (5, p_at_5), (10, p_at_10)):
    row_label = f"Precision@{top_k}:"
    # pad the label to a fixed width so the percentages line up
    print(f"  {row_label:<14}{score * 100:.2f}%")

In [None]:
# plot after training resnet18
# Use the `get_all_embeddings` function and `eval_loader` to get embeddings from the model
trained_resnet18_embeddings, trained_resnet18_labels = get_all_embeddings(model_18, eval_loader, device)

# These plots show the trained model, so the log line says "after".
print("plots after training")
# function call for the plots
umap_tsne(model_18, trained_resnet18_embeddings, trained_resnet18_labels, "Visualization of images after training (ResNet18)", 'trained_model_plot_resnet18.png')

In [None]:
# plot after training resnet34
# Use the `get_all_embeddings` function and `eval_loader` to get embeddings from the model
trained_resnet34_embeddings, trained_resnet34_labels = get_all_embeddings(model_34, eval_loader, device)

# These plots show the trained model, so the log line says "after".
print("plots after training")
# function call for the plots
umap_tsne(model_34, trained_resnet34_embeddings, trained_resnet34_labels, "Visualization of images after training (ResNet34)", 'trained_model_plot_resnet34.png')