In [1]:
import pandas as pd
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# Auto select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ---------------------------------------------------
# Load metadata
# ---------------------------------------------------
with open("/kaggle/input/cub2002011/CUB_200_2011/train_test_split.txt") as f:
    split = dict(line.strip().split() for line in f)

with open("/kaggle/input/cub2002011/CUB_200_2011/images.txt") as f:
    paths = dict(line.strip().split() for line in f)

with open("/kaggle/input/cub2002011/CUB_200_2011/image_class_labels.txt") as f:
    labels = dict(line.strip().split() for line in f)

# ---------------------------------------------------
# Select only FIRST 5 classes
# ---------------------------------------------------
selected_classes = set(list({int(v) for v in labels.values()})[:200])

print("Using classes:", selected_classes)

train_paths, train_labels = [], []
test_paths, test_labels = [], []

base = "/kaggle/input/cub2002011/CUB_200_2011/images/"

for img_id, rel in paths.items():
    cls = int(labels[img_id])
    if cls not in selected_classes:
        continue

    full = base + rel
    if split[img_id] == "1":
        train_paths.append(full)
        train_labels.append(cls)
    else:
        test_paths.append(full)
        test_labels.append(cls)

print("Train images:", len(train_paths))
print("Test images :", len(test_paths))

# ---------------------------------------------------
# Convert to DataFrames (path + class)
# ---------------------------------------------------
train_df = pd.DataFrame({"path": train_paths, "class": train_labels})
test_df  = pd.DataFrame({"path": test_paths , "class": test_labels})

# ---------------------------------------------------
# Per-class sample counts
# ---------------------------------------------------

train_count = train_df["class"].value_counts().sort_index()
test_count  = test_df["class"].value_counts().sort_index()

print("\n===== TRAIN PER-CLASS COUNTS =====")
print(train_count)

print("\n===== TEST PER-CLASS COUNTS =====")
print(test_count)

print("\n===== SUMMARY =====")
print("Train: classes =", train_count.index.nunique(),
      "| min =", train_count.min(),
      "| max =", train_count.max(),
      "| avg =", train_count.mean())

print("Test : classes =", test_count.index.nunique(),
      "| min =", test_count.min(),
      "| max =", test_count.max(),
      "| avg =", test_count.mean())

# ---------------------------------------------------
# Dataset class with transforms
# ---------------------------------------------------
transform_train = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

transform_test = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

class CUBDataset(Dataset):
    def __init__(self, df, transform):
        self.paths = df["path"].values
        self.labels = df["class"].values
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        img = self.transform(img)
        cls = self.labels[idx] - 1
        return img, cls

# ---------------------------------------------------
# Dataloaders
# ---------------------------------------------------
train_dataset = CUBDataset(train_df, transform_train)
test_dataset = CUBDataset(test_df, transform_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
test_loader  = DataLoader(test_dataset , batch_size=32, shuffle=False, num_workers=2)

print("Train loader batches:", len(train_loader))
print("Test loader batches :", len(test_loader))

Using device: cuda
Using classes: {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200}
Train images: 5994
Test images : 5794

===== TRAIN PER-CLASS COUNTS =====

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SimpleMetricAutoEncoder(nn.Module):
    def __init__(self, embed_dim=512, num_classes=200):
        super().__init__()
        goog = models.googlenet(weights="IMAGENET1K_V1")
        self.backbone = nn.Sequential(
            goog.conv1, goog.maxpool1,
            goog.conv2, goog.conv3, goog.maxpool2,
            goog.inception3a, goog.inception3b, goog.maxpool3,
            goog.inception4a, goog.inception4b, goog.inception4c,
            goog.inception4d, goog.inception4e,
            goog.maxpool4,
            goog.inception5a, goog.inception5b,
            goog.avgpool
        )
        self.fc_embed = nn.Linear(1024, embed_dim)
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1024)
        )
        self.proxies = nn.Parameter(torch.randn(num_classes, embed_dim))

    def extract_f(self, x):
        f = self.backbone(x)
        return f.view(f.size(0), -1)

    def forward(self, x):
        f = self.extract_f(x)
        z = self.fc_embed(f)
        f_hat = self.decoder(z)
        return z, f, f_hat

class ProxyNCAPlusPlus(nn.Module):
    def __init__(self, scale=3):
        super().__init__()
        self.scale = scale

    def forward(self, z, labels, proxies):
        z = F.normalize(z, dim=1)
        p = F.normalize(proxies, dim=1)
        sim = self.scale * (z @ p.t())
        return F.cross_entropy(sim, labels)

def recon_loss(f, f_hat):
    return F.mse_loss(f_hat, f)

num_classes = len(set([lbl for _, lbl in train_loader.dataset]))
model = SimpleMetricAutoEncoder(embed_dim=128, num_classes=num_classes).to(device)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
proxynca = ProxyNCAPlusPlus(scale=8)

for epoch in range(3):
    model.train()
    for imgs, labels in train_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        z, f, f_hat = model(imgs)

        loss = proxynca(z, labels, model.module.proxies if isinstance(model, nn.DataParallel) else model.proxies)
        loss = loss + recon_loss(f, f_hat)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print("Epoch", epoch, "Loss =", loss.item())

def extract_embeddings(model, loader):
    E, L = [], []
    model.eval()
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs = imgs.to(device)
            z, _, _ = model(imgs)
            E.append(z.cpu())
            L.append(lbls)
    return torch.cat(E), torch.cat(L)

def recall_at_k(E, L, K=1):
    E = F.normalize(E, dim=1)
    S = E @ E.t()
    N = len(L)
    S[range(N), range(N)] = -1
    _, idx = S.topk(K, dim=1)
    c = 0
    for i in range(N):
        if L[i] in L[idx[i]]:
            c += 1
    return c / N

E, L = extract_embeddings(model, test_loader)
for k in [1,2,4]:
    print("R@", k, "=", recall_at_k(E, L, k))

Epoch 1 Loss = 4.440831184387207
Epoch 2 Loss = 4.384459018707275
R@ 1 = 0.5288229202623403
R@ 2 = 0.6432516396272006
R@ 4 = 0.7469796341042457


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class DVML(nn.Module):
    def __init__(self, embed_dim=512, var_dim=128, num_classes=200):
        super().__init__()
        goog = models.googlenet(weights="IMAGENET1K_V1")
        self.backbone = nn.Sequential(
            goog.conv1, goog.maxpool1,
            goog.conv2, goog.conv3, goog.maxpool2,
            goog.inception3a, goog.inception3b, goog.maxpool3,
            goog.inception4a, goog.inception4b, goog.inception4c,
            goog.inception4d, goog.inception4e,
            goog.maxpool4,
            goog.inception5a, goog.inception5b,
            goog.avgpool
        )
        self.fc_mu = nn.Linear(1024, embed_dim)
        self.fc_logvar = nn.Linear(1024, var_dim)
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1024)
        )
        self.proxies = nn.Parameter(torch.randn(num_classes, embed_dim))

    def encode(self, x):
        f = self.backbone(x)
        f = f.view(f.size(0), -1)
        mu = self.fc_mu(f)
        logvar = self.fc_logvar(f)
        return mu, logvar, f

    def reparam(self, mu, logvar, noise_scale):
        std = torch.exp(0.5 * logvar) * noise_scale
        eps = torch.randn_like(std)
        return mu + eps * std, eps * std

    def forward(self, x, noise_scale):
        mu, logvar, f = self.encode(x)
        z, noise = self.reparam(mu, logvar, noise_scale)
        f_hat = self.decoder(z)
        return z, mu, logvar, f, f_hat, noise


class ProxyNCAPlusPlus(nn.Module):
    def __init__(self, scale=3):
        super().__init__()
        self.scale = scale

    def forward(self, z, labels, proxies):
        z = F.normalize(z, dim=1)
        p = F.normalize(proxies, dim=1)
        sim = self.scale * (z @ p.t())
        return F.cross_entropy(sim, labels)


def kl_loss(mu, logvar):
    return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())


def recon_loss(f, f_hat):
    return F.mse_loss(f_hat, f)


num_classes = len(set([lbl for _, lbl in train_loader.dataset]))
model = DVML(embed_dim=128, var_dim=128, num_classes=num_classes).to(device)

if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
proxynca = ProxyNCAPlusPlus(scale=8)


for epoch in range(4):
    model.train()
    
    kl_w = min(0.1, epoch / 50)

    noise_scale = max(0.3, 1.0 - epoch/10)

    for imgs, labels in train_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        z, mu, logvar, f, f_hat, noise = model(imgs, noise_scale)

        proxies = model.module.proxies if isinstance(model, nn.DataParallel) else model.proxies
        loss_metric = proxynca(z, labels, proxies)

        z_syn = mu + noise
        loss_metric_syn = proxynca(z_syn, labels, proxies)

        loss = loss_metric + loss_metric_syn + recon_loss(f, f_hat) + kl_w * kl_loss(mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(epoch, loss.item())


def extract_embeddings(model, loader):
    E, L = [], []
    model.eval()
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs = imgs.to(device)
            z, _, _, _, _, _ = model(imgs, noise_scale=0.0)
            E.append(z.cpu())
            L.append(lbls.cpu())
    return torch.cat(E), torch.cat(L)


def recall_at_k(E, L, K=1):
    E = F.normalize(E, dim=1)
    S = E @ E.t()
    N = len(L)
    S[range(N), range(N)] = -1
    _, idx = S.topk(K, dim=1)
    c = 0
    for i in range(N):
        if L[i] in L[idx[i]]:
            c += 1
    return c / N


E, L = extract_embeddings(model, test_loader)
for k in [1,2,4]:
    print("R@", k, "=", recall_at_k(E, L, k))

0 10.556137084960938
1 9.585996627807617
2 8.980067253112793
3 8.689001083374023
R@ 1 = 0.5517777010700725
R@ 2 = 0.6589575422851225
R@ 4 = 0.762167759751467
