Installing packages

In [None]:
!pip install torch torchvision sentence-transformers matplotlib seaborn --quiet

[2K   [91m━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m145.1/363.4 MB[0m [31m33.4 MB/s[0m eta [36m0:00:07[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m26.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m27.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m12.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━

Import statements

In [None]:
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.datasets import OxfordIIITPet
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, TensorDataset

from sentence_transformers import SentenceTransformer

import matplotlib.pyplot as plt
import seaborn as sns

# Reproducibility: seed Python's `random` and torch (CPU and, when present, every GPU)
# so that Restart & Run All gives the same class split, shuffling and weights.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

has_cuda = torch.cuda.is_available()
if has_cuda:
    torch.cuda.manual_seed_all(SEED)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if has_cuda else "cpu")
print("Using device:", device)

Using device: cpu


Define the image transforms and load the Oxford-IIIT Pet dataset

In [None]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),   # ResNet-18 expects 224×224 (note: squashes, not crops)
    transforms.ToTensor(),
    # normalize with ImageNet means/stds
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])

# Download / load OxfordIIITPet.
# split='trainval' is torchvision's default and is spelled out on purpose: only the
# trainval images are used in this notebook; the official 'test' split is not loaded.
dataset = OxfordIIITPet(
    root='.',
    split='trainval',
    download=True,
    transform=transform,
    target_types='category'   # returns (image, int_label)
)

# Fetch the sample once: every dataset[i] access decodes and transforms the image again.
example_img, example_label = dataset[0]

print("Total examples:", len(dataset))
print("Example shape:", example_img.shape, "| Label:", example_label, "| Class-name:", dataset.classes[example_label])

# Print all class names
print("\nAll {} classes (pet breeds):".format(len(dataset.classes)))
print(dataset.classes)


Total examples: 3680
Example shape: torch.Size([3, 224, 224]) | Label: 0 | Class-name: Abyssinian

All 37 classes (pet breeds):
['Abyssinian', 'American Bulldog', 'American Pit Bull Terrier', 'Basset Hound', 'Beagle', 'Bengal', 'Birman', 'Bombay', 'Boxer', 'British Shorthair', 'Chihuahua', 'Egyptian Mau', 'English Cocker Spaniel', 'English Setter', 'German Shorthaired', 'Great Pyrenees', 'Havanese', 'Japanese Chin', 'Keeshond', 'Leonberger', 'Maine Coon', 'Miniature Pinscher', 'Newfoundland', 'Persian', 'Pomeranian', 'Pug', 'Ragdoll', 'Russian Blue', 'Saint Bernard', 'Samoyed', 'Scottish Terrier', 'Shiba Inu', 'Siamese', 'Sphynx', 'Staffordshire Bull Terrier', 'Wheaten Terrier', 'Yorkshire Terrier']


Split class names into seen and unseen

In [None]:
all_classes = dataset.classes
num_classes = len(all_classes)  # Should be 37

# Shuffle deterministically with a dedicated RNG rather than the global `random`
# stream, so re-running this cell on its own always reproduces the same split.
# random.Random(42) starts in the same state as the global generator right after
# random.seed(42); provided nothing consumed the global stream in between, the split
# is the same one the seeded global shuffle produced.
SPLIT_SEED = 42
shuffled = list(all_classes)
random.Random(SPLIT_SEED).shuffle(shuffled)

split_ratio = 0.6
split_index = int(num_classes * split_ratio)

seen_classes = shuffled[:split_index]
unseen_classes = shuffled[split_index:]

# Zero-shot requires the two class sets to be disjoint and to cover every class.
assert not set(seen_classes) & set(unseen_classes), "seen/unseen classes overlap"
assert len(seen_classes) + len(unseen_classes) == num_classes

print(f"Seen classes ({len(seen_classes)}): {seen_classes}")
print(f"Unseen classes ({len(unseen_classes)}): {unseen_classes}")


Seen classes (22): ['Bengal', 'Maine Coon', 'English Cocker Spaniel', 'British Shorthair', 'Newfoundland', 'Ragdoll', 'Russian Blue', 'Beagle', 'Pomeranian', 'Samoyed', 'Sphynx', 'Shiba Inu', 'Siamese', 'Chihuahua', 'Egyptian Mau', 'Leonberger', 'Saint Bernard', 'Havanese', 'Yorkshire Terrier', 'Birman', 'Pug', 'Abyssinian']
Unseen classes (15): ['Wheaten Terrier', 'English Setter', 'Keeshond', 'American Pit Bull Terrier', 'Staffordshire Bull Terrier', 'Scottish Terrier', 'Miniature Pinscher', 'Basset Hound', 'Persian', 'Boxer', 'German Shorthaired', 'Great Pyrenees', 'Japanese Chin', 'American Bulldog', 'Bombay']


Build filtered datasets: training (seen classes) and test (unseen classes)

In [None]:
class FilteredOxfordPets(Dataset):
    """View of a classification dataset restricted to a subset of its classes.

    Only samples whose label belongs to ``allowed_classes`` are kept, and labels are
    re-indexed to ``0 .. len(allowed_classes) - 1`` in the order the classes are
    listed. Example: if ``allowed_class_indices == [3, 7, 12]`` then an original
    label of 7 becomes 1.

    Args:
        base_dataset: dataset exposing ``classes`` (list of class names) and
            returning ``(image, int_label)`` pairs, e.g. ``OxfordIIITPet`` with
            ``target_types='category'``.
        allowed_classes: class names to keep. Each must appear in
            ``base_dataset.classes``; otherwise ``ValueError`` is raised.
    """

    def __init__(self, base_dataset: "OxfordIIITPet", allowed_classes: list):
        super().__init__()
        self.base_dataset = base_dataset
        self.allowed_classes = allowed_classes

        # Map each allowed class-name to its original integer index
        self.allowed_class_indices = [
            base_dataset.classes.index(c) for c in allowed_classes
        ]

        # original label -> new label (its position in allowed_class_indices).
        # setdefault keeps the first position if a class is listed twice, matching
        # list.index semantics, while giving O(1) lookups.
        self._new_label = {}
        for new_label, orig_label in enumerate(self.allowed_class_indices):
            self._new_label.setdefault(orig_label, new_label)

        # Indices in base_dataset whose label is one of the allowed classes.
        self.filtered_indices = [
            i
            for i, label in enumerate(self._read_labels(base_dataset))
            if label in self._new_label
        ]

    @staticmethod
    def _read_labels(base_dataset) -> list:
        """Return the integer label of every sample in ``base_dataset``, in index order.

        torchvision's OxfordIIITPet keeps per-sample category labels in the private
        ``_labels`` list. Reading it avoids decoding and transforming every image
        just to look at its label. The code falls back to indexing the dataset when:
        - that attribute is missing, or
        - a ``target_transform`` or non-category target type means the cached
          labels may differ from what ``__getitem__`` returns.
        """
        cached = getattr(base_dataset, "_labels", None)
        target_types = getattr(base_dataset, "_target_types", ["category"])
        if (
            cached is not None
            and getattr(base_dataset, "target_transform", None) is None
            and list(target_types) == ["category"]
        ):
            return list(cached)
        return [base_dataset[i][1] for i in range(len(base_dataset))]

    def __len__(self):
        return len(self.filtered_indices)

    def __getitem__(self, idx):
        actual_idx = self.filtered_indices[idx]
        img, orig_label = self.base_dataset[actual_idx]
        return img, self._new_label[orig_label]

# Instantiate train / test splits
train_dataset = FilteredOxfordPets(dataset, seen_classes)
test_dataset = FilteredOxfordPets(dataset, unseen_classes)

# Sanity checks for the zero-shot setup:
# - the two splits are class-disjoint;
# - together they cover every image of the base dataset exactly once.
assert not set(seen_classes) & set(unseen_classes), "seen/unseen classes overlap"
assert len(train_dataset) + len(test_dataset) == len(dataset), "splits do not cover the dataset"

print(f"Train dataset size (seen classes): {len(train_dataset)}")
print(f"Test dataset size (unseen classes): {len(test_dataset)}\n")

# Print mapping sanity check (new label index → breed name) for both splits
for header, breeds in (
    ("Training (seen) class idx → breed name:", seen_classes),
    ("\nTest (unseen) class idx → breed name:", unseen_classes),
):
    print(header)
    for new_idx, breed in enumerate(breeds):
        print(f"  {new_idx} → {breed}")


Train dataset size (seen classes): 2184
Test dataset size (unseen classes): 1496

Training (seen) class idx → breed name:
  0 → Bengal
  1 → Maine Coon
  2 → English Cocker Spaniel
  3 → British Shorthair
  4 → Newfoundland
  5 → Ragdoll
  6 → Russian Blue
  7 → Beagle
  8 → Pomeranian
  9 → Samoyed
  10 → Sphynx
  11 → Shiba Inu
  12 → Siamese
  13 → Chihuahua
  14 → Egyptian Mau
  15 → Leonberger
  16 → Saint Bernard
  17 → Havanese
  18 → Yorkshire Terrier
  19 → Birman
  20 → Pug
  21 → Abyssinian

Test (unseen) class idx → breed name:
  0 → Wheaten Terrier
  1 → English Setter
  2 → Keeshond
  3 → American Pit Bull Terrier
  4 → Staffordshire Bull Terrier
  5 → Scottish Terrier
  6 → Miniature Pinscher
  7 → Basset Hound
  8 → Persian
  9 → Boxer
  10 → German Shorthaired
  11 → Great Pyrenees
  12 → Japanese Chin
  13 → American Bulldog
  14 → Bombay


Building class-label embeddings

In [None]:
embedder = SentenceTransformer("all-MiniLM-L6-v2")  # 384-D output


def _breed_prompt(breed: str) -> str:
    """Return a caption-style prompt for ``breed``, e.g. "a photo of an Abyssinian".

    The bare breed name is wrapped in a short sentence, intended to give the text
    encoder more context. The article is picked from the first letter so that
    ungrammatical prompts such as "a photo of a American Bulldog" are not produced.
    """
    article = "an" if breed[:1].lower() in ("a", "e", "i", "o", "u") else "a"
    return f"a photo of {article} {breed}"


prompted_seen = [_breed_prompt(breed) for breed in seen_classes]
prompted_unseen = [_breed_prompt(breed) for breed in unseen_classes]

with torch.no_grad():
    # Embed all seen & unseen prompts; normalize_embeddings=True yields unit-length rows
    seen_emb_list = embedder.encode(prompted_seen, convert_to_tensor=True, normalize_embeddings=True)
    unseen_emb_list = embedder.encode(prompted_unseen, convert_to_tensor=True, normalize_embeddings=True)

# seen_emb_list: (num_seen, 384), normalized
# unseen_emb_list: (num_unseen, 384), normalized

seen_embeddings = seen_emb_list.cpu()    # Move to CPU so we can save
unseen_embeddings = unseen_emb_list.cpu()

print("Seen_embeddings shape:", seen_embeddings.shape)
print("Unseen_embeddings shape:", unseen_embeddings.shape)

# Save to disk
torch.save({
    "seen": seen_embeddings,
    "unseen": unseen_embeddings,
}, "class_name_embeddings.pt")

print("› Saved class_name_embeddings.pt")


Seen_embeddings shape: torch.Size([22, 384])
Unseen_embeddings shape: torch.Size([15, 384])
› Saved class_name_embeddings.pt


Build a frozen, ImageNet-pretrained ResNet-18 that outputs a 512-D feature vector per image

In [None]:
class ResNet18FeatureExtractor(nn.Module):
    """ImageNet-pretrained ResNet-18 with its final fully connected layer removed.

    ``forward`` maps an image batch to the globally average-pooled penultimate
    activations, flattened to one 512-D vector per image.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)
        # Everything except the last child (the classifier) is kept, so the
        # trunk ends with the global average pool.
        *trunk, _classifier = backbone.children()
        self.features = nn.Sequential(*trunk)

    def forward(self, x):
        pooled = self.features(x)            # (B, 512, 1, 1)
        return pooled.view(len(pooled), -1)  # (B, 512)

# The backbone is used as a fixed feature extractor: eval mode (BatchNorm uses its
# running statistics) and no gradients for any of its parameters.
feature_extractor = ResNet18FeatureExtractor()
feature_extractor = feature_extractor.to(device).eval()
feature_extractor.requires_grad_(False)

Build a linear mapper from 512-D image features to the 384-D text-embedding space

In [None]:
class ImageToEmbeddingMapper(nn.Module):
    """Single affine projection from image-feature space to text-embedding space.

    Args:
        input_dim: size of the incoming image feature vectors (512 for ResNet-18).
        output_dim: size of the target sentence embeddings (384 for all-MiniLM-L6-v2).
    """

    def __init__(self, input_dim=512, output_dim=384):
        super().__init__()
        # The attribute is named ``fc`` so state_dict keys stay "fc.weight" / "fc.bias".
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Project ``x`` (..., input_dim) to (..., output_dim); no normalization applied."""
        return self.fc(x)

# Trainable head: 512-D ResNet-18 features → 384-D all-MiniLM-L6-v2 embedding space
mapper = ImageToEmbeddingMapper(input_dim=512, output_dim=384)
mapper = mapper.to(device)

Extract frozen ResNet-18 features for the training (seen-class) images and look up their target embeddings

In [None]:
# A single deterministic pass over the frozen backbone: shuffling is unnecessary here
# (the training minibatch loader shuffles the cached features), and without it the
# rows of train_features line up with train_dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2,
    pin_memory=(device.type == "cuda"),  # pinning only speeds up host→GPU copies
)

# Load the previously-saved class embeddings (plain tensors, so weights_only=True suffices)
data = torch.load("class_name_embeddings.pt", weights_only=True)
seen_embeddings = data["seen"].to(device)    # (num_seen, 384)
unseen_embeddings = data["unseen"].to(device)  # (num_unseen, 384)

# Collect all training features (512-D) and the corresponding labels, on the CPU
all_feats = []
all_lbls = []

feature_extractor.eval()
with torch.no_grad():
    for imgs, lbls in train_loader:
        feats = feature_extractor(imgs.to(device))  # → (B, 512)
        all_feats.append(feats.cpu())
        # Labels are only collected, never used on the device here; each ∈ [0, num_seen−1]
        all_lbls.append(lbls)

train_features = torch.cat(all_feats, dim=0)  # (N_seen_examples, 512)
train_labels = torch.cat(all_lbls, dim=0)     # (N_seen_examples,)
assert train_features.shape[0] == train_labels.shape[0] == len(train_dataset)

print("Train_features:", train_features.shape)
print("Train_labels:", train_labels.shape)

# Build target embeddings for each example:
#   target_embeddings[i] = seen_embeddings[ train_labels[i] ]
target_embeddings = seen_embeddings[train_labels.to(device)]  # (N_seen_examples, 384)

print("Target_embeddings:", target_embeddings.shape)

Train_features: torch.Size([2184, 512])
Train_labels: torch.Size([2184])
Target_embeddings: torch.Size([2184, 384])


Train the mapper to regress each image's class-name embedding (MSE on unit-normalized vectors)

In [None]:
criterion = nn.MSELoss()
optimizer = optim.Adam(mapper.parameters(), lr=1e-4, weight_decay=1e-5)

EPOCHS = 100
batch_size = 64   # We’ll split the big train_features into minibatches
LOG_EVERY = 10    # print the loss on the first, the last and every LOG_EVERY-th epoch

# Minibatch over the cached (features, label) pairs. The target embedding of each
# example is looked up on the fly inside the loop: tgt = seen_embeddings[label_batch].
train_tensor_dataset = TensorDataset(train_features, train_labels)
train_tensor_loader = DataLoader(train_tensor_dataset, batch_size=batch_size, shuffle=True)

loss_history = []  # mean training loss per epoch (e.g. for plotting)
for epoch in range(EPOCHS):
    mapper.train()
    running_loss = 0.0

    for feat_batch, lbl_batch in train_tensor_loader:
        feat_batch = feat_batch.to(device)       # (B, 512)
        lbl_batch = lbl_batch.to(device)         # (B,)

        # Put predictions on the unit sphere like the targets. For unit vectors
        # ||a - b||² = 2 - 2·cos(a, b), so the MSE penalizes angular error.
        pred_norm = F.normalize(mapper(feat_batch), dim=1)  # (B, 384)
        tgt = seen_embeddings[lbl_batch]                    # (B, 384), already normalized

        loss = criterion(pred_norm, tgt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * feat_batch.size(0)

    epoch_loss = running_loss / len(train_tensor_dataset)
    loss_history.append(epoch_loss)
    if epoch == 0 or (epoch + 1) % LOG_EVERY == 0 or epoch + 1 == EPOCHS:
        print(f"Epoch [{epoch+1}/{EPOCHS}]  Loss: {epoch_loss:.6f}")


Epoch [1/100]  Loss: 0.003288
Epoch [2/100]  Loss: 0.001994
Epoch [3/100]  Loss: 0.001780
Epoch [4/100]  Loss: 0.001663
Epoch [5/100]  Loss: 0.001573
Epoch [6/100]  Loss: 0.001493
Epoch [7/100]  Loss: 0.001419
Epoch [8/100]  Loss: 0.001349
Epoch [9/100]  Loss: 0.001283
Epoch [10/100]  Loss: 0.001223
Epoch [11/100]  Loss: 0.001170
Epoch [12/100]  Loss: 0.001123
Epoch [13/100]  Loss: 0.001078
Epoch [14/100]  Loss: 0.001039
Epoch [15/100]  Loss: 0.001005
Epoch [16/100]  Loss: 0.000974
Epoch [17/100]  Loss: 0.000943
Epoch [18/100]  Loss: 0.000918
Epoch [19/100]  Loss: 0.000893
Epoch [20/100]  Loss: 0.000871
Epoch [21/100]  Loss: 0.000850
Epoch [22/100]  Loss: 0.000829
Epoch [23/100]  Loss: 0.000810
Epoch [24/100]  Loss: 0.000795
Epoch [25/100]  Loss: 0.000779
Epoch [26/100]  Loss: 0.000764
Epoch [27/100]  Loss: 0.000750
Epoch [28/100]  Loss: 0.000735
Epoch [29/100]  Loss: 0.000723
Epoch [30/100]  Loss: 0.000712
Epoch [31/100]  Loss: 0.000699
Epoch [32/100]  Loss: 0.000688
Epoch [33/100]  L

Zero-shot evaluation: classify unseen-class images by their nearest class-name embedding

In [None]:
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2,
    pin_memory=(device.type == "cuda"),  # pinning only speeds up host→GPU copies
)

mapper.eval()
feature_extractor.eval()

# To turn predicted index → breed name:
idx2unseen = {i: breed for i, breed in enumerate(unseen_classes)}

# The class-name embeddings do not change between batches: normalize them once.
unseen_norm = F.normalize(unseen_embeddings.to(device), dim=1)  # (num_unseen, 384)

pred_batches, true_batches = [], []
with torch.no_grad():
    for imgs, true_lbls in test_loader:
        feats = feature_extractor(imgs.to(device))      # (B, 512)
        preds_norm = F.normalize(mapper(feats), dim=1)  # (B, 384), unit length
        # Cosine similarities: (B, 384) × (384, num_unseen) → (B, num_unseen)
        sims = preds_norm @ unseen_norm.T
        pred_batches.append(sims.argmax(dim=1).cpu())   # nearest unseen class
        true_batches.append(true_lbls.cpu())            # each ∈ [0, num_unseen−1]

all_preds = torch.cat(pred_batches) if pred_batches else torch.empty(0, dtype=torch.long)
all_true = torch.cat(true_batches) if true_batches else torch.empty(0, dtype=torch.long)

total = all_true.numel()
correct = int((all_preds == all_true).sum().item())

# Print a few examples spread over the whole test set. The loader is not shuffled and
# the test images appear to be grouped by class (the first 20 were all one breed), so
# the first rows alone would not be representative. ✔ = correct, ✘ = wrong.
MAX_PRINT = 20
step = max(1, total // MAX_PRINT)
for i in range(0, total, step)[:MAX_PRINT]:
    t_lbl, p_lbl = all_true[i].item(), all_preds[i].item()
    mark = "✔" if t_lbl == p_lbl else "✘"
    print(f"{mark} True: {idx2unseen[t_lbl]:20s} | Predicted: {idx2unseen[p_lbl]}")

accuracy = 100.0 * correct / total if total else float("nan")
print("\n🔍 Zero-Shot Accuracy on Unseen Classes: {:.2f}%".format(accuracy))
if unseen_classes:
    print("   (chance level with {} classes: {:.2f}%)".format(len(unseen_classes), 100.0 / len(unseen_classes)))

# Per-class breakdown: top-1 accuracy restricted to the images of each unseen breed.
print("\nPer-class accuracy:")
for cls_idx, breed in idx2unseen.items():
    mask = all_true == cls_idx
    n_cls = int(mask.sum().item())
    if n_cls:
        cls_acc = 100.0 * (all_preds[mask] == cls_idx).float().mean().item()
        print(f"  {breed:28s} {cls_acc:6.2f}%  (n={n_cls})")

✔ True: American Bulldog     | Predicted: Miniature Pinscher
✔ True: American Bulldog     | Predicted: Scottish Terrier
✔ True: American Bulldog     | Predicted: Staffordshire Bull Terrier
✔ True: American Bulldog     | Predicted: Basset Hound
✔ True: American Bulldog     | Predicted: Scottish Terrier
✔ True: American Bulldog     | Predicted: Scottish Terrier
✔ True: American Bulldog     | Predicted: Scottish Terrier
✔ True: American Bulldog     | Predicted: Scottish Terrier
✔ True: American Bulldog     | Predicted: Basset Hound
✔ True: American Bulldog     | Predicted: Basset Hound
✔ True: American Bulldog     | Predicted: Great Pyrenees
✔ True: American Bulldog     | Predicted: Scottish Terrier
✔ True: American Bulldog     | Predicted: Basset Hound
✔ True: American Bulldog     | Predicted: Basset Hound
✔ True: American Bulldog     | Predicted: Basset Hound
✔ True: American Bulldog     | Predicted: Scottish Terrier
✔ True: American Bulldog     | Predicted: Basset Hound
✔ True: America