In [None]:
import pandas as pd
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# Auto select device: use CUDA when torch reports it as available, else the CPU
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print("Using device:", device)

# ---------------------------------------------------
# Load metadata
# ---------------------------------------------------
# Every line of these files is split on whitespace into a (key, value) pair;
# str.split() already ignores leading/trailing whitespace, so the pairs can be
# fed straight into dict().
with open("/kaggle/input/cub2002011/CUB_200_2011/train_test_split.txt") as split_file:
    split = dict(map(str.split, split_file))

with open("/kaggle/input/cub2002011/CUB_200_2011/images.txt") as images_file:
    paths = dict(map(str.split, images_file))

with open("/kaggle/input/cub2002011/CUB_200_2011/image_class_labels.txt") as labels_file:
    labels = dict(map(str.split, labels_file))

# ---------------------------------------------------
# Select the first NUM_SELECTED_CLASSES classes (by class id)
# ---------------------------------------------------
# Lower this value to work on a subset (e.g. 5 for a quick experiment).
NUM_SELECTED_CLASSES = 200

# Sort before slicing: a set has no guaranteed iteration order, so slicing it
# directly would not reliably pick the lowest class ids.
selected_classes = set(
    sorted({int(v) for v in labels.values()})[:NUM_SELECTED_CLASSES]
)

print("Using classes:", selected_classes)

train_paths, train_labels = [], []
test_paths, test_labels = [], []

base = "/kaggle/input/cub2002011/CUB_200_2011/images/"

for img_id, rel in paths.items():
    cls = int(labels[img_id])
    if cls not in selected_classes:
        continue

    full = base + rel
    # The split file stores its flag as text: "1" goes to train, anything else to test.
    if split[img_id] == "1":
        train_paths.append(full)
        train_labels.append(cls)
    else:
        test_paths.append(full)
        test_labels.append(cls)

print("Train images:", len(train_paths))
print("Test images :", len(test_paths))

# ---------------------------------------------------
# Convert to DataFrames (path + class)
# ---------------------------------------------------
def _as_frame(img_paths, img_labels):
    """Bundle parallel path/label lists into a two-column DataFrame."""
    return pd.DataFrame({"path": img_paths, "class": img_labels})


train_df = _as_frame(train_paths, train_labels)
test_df = _as_frame(test_paths, test_labels)

# ---------------------------------------------------
# Per-class sample counts
# ---------------------------------------------------
# value_counts() orders by frequency; sort_index() reorders by class id.
train_count = train_df["class"].value_counts().sort_index()
test_count = test_df["class"].value_counts().sort_index()

for banner, counts in (
    ("\n===== TRAIN PER-CLASS COUNTS =====", train_count),
    ("\n===== TEST PER-CLASS COUNTS =====", test_count),
):
    print(banner)
    print(counts)

print("\n===== SUMMARY =====")
for tag, counts in (
    ("Train: classes =", train_count),
    ("Test : classes =", test_count),
):
    print(tag, counts.index.nunique(),
          "| min =", counts.min(),
          "| max =", counts.max(),
          "| avg =", counts.mean())

# ---------------------------------------------------
# Dataset class with transforms
# ---------------------------------------------------
def _build_transform(*augmentations):
    """Compose resize -> optional augmentations -> tensor -> normalisation."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])


# Only the training pipeline adds a random horizontal flip.
transform_train = _build_transform(transforms.RandomHorizontalFlip())
transform_test = _build_transform()

class CUBDataset(Dataset):
    """Image dataset backed by a DataFrame with "path" and "class" columns.

    ``__getitem__`` loads the image as RGB, applies ``transform`` and returns
    it together with the stored class value minus one (the frame is presumably
    holding 1-based class ids; verify against the code that builds it).
    """

    def __init__(self, df, transform):
        self.transform = transform
        self.paths = df["path"].values
        self.labels = df["class"].values

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        image = Image.open(self.paths[idx]).convert("RGB")
        # Shift the stored class id down by one before returning it.
        return self.transform(image), self.labels[idx] - 1

# ---------------------------------------------------
# Dataloaders
# ---------------------------------------------------
train_dataset = CUBDataset(train_df, transform_train)
test_dataset = CUBDataset(test_df, transform_test)

# Options shared by both loaders; only the training loader reshuffles.
_loader_opts = {"batch_size": 32, "num_workers": 2}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_opts)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_opts)

print("Train loader batches:", len(train_loader))
print("Test loader batches :", len(test_loader))

In [None]:
# --------------------------------------------------------------
# 1. Imports & Device
# --------------------------------------------------------------
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from PIL import Image

from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score, f1_score, pairwise_distances
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Use CUDA when torch reports it as available, otherwise run on the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _use_cuda else torch.device("cpu")
print("Using device:", device)
print("PyTorch:", torch.__version__)

# --------------------------------------------------------------
# 2. Seed
# --------------------------------------------------------------
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic mode."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(42)

# --------------------------------------------------------------
# 3. Load CUB metadata
# --------------------------------------------------------------
base_dir = "/kaggle/input/cub2002011/CUB_200_2011"

# Train/test flags are converted to int; the other two files keep string values.
with open(os.path.join(base_dir, "train_test_split.txt")) as split_file:
    split_dict = {}
    for record in split_file:
        image_key, flag = record.strip().split()
        split_dict[image_key] = int(flag)

with open(os.path.join(base_dir, "images.txt")) as images_file:
    path_dict = dict(map(str.split, images_file))

with open(os.path.join(base_dir, "image_class_labels.txt")) as labels_file:
    label_dict = dict(map(str.split, labels_file))

# --------------------------------------------------------------
# 4. Zero-shot split: first 100 train, last 100 test
# --------------------------------------------------------------
all_classes = sorted({int(v) for v in label_dict.values()})
train_classes = set(all_classes[:100])
test_classes = set(all_classes[100:200])

train_paths, train_labels = [], []
test_paths, test_labels = [], []

img_base = os.path.join(base_dir, "images")
for img_id, rel in path_dict.items():
    cls = int(label_dict[img_id])
    in_known = cls in train_classes
    in_unseen = cls in test_classes
    if not (in_known or in_unseen):
        continue

    full = os.path.join(img_base, rel)
    is_train = split_dict[img_id] == 1

    # Known classes keep only their train-flagged images, unseen classes only
    # their non-train images; labels are stored shifted down by one.
    if in_known and is_train:
        train_paths.append(full)
        train_labels.append(cls - 1)
    elif in_unseen and not is_train:
        test_paths.append(full)
        test_labels.append(cls - 1)

train_df = pd.DataFrame({"path": train_paths, "class": train_labels})
test_df = pd.DataFrame({"path": test_paths, "class": test_labels})

print(f"Train (known): {len(train_df)} imgs, {len(train_classes)} classes")
print(f"Test  (unseen): {len(test_df)} imgs, {len(test_classes)} classes")

# --------------------------------------------------------------
# 5. Transforms
# --------------------------------------------------------------
def _compose_with_normalisation(*spatial_ops):
    """Chain the given PIL-level ops, then convert to a tensor and normalise."""
    return transforms.Compose([
        *spatial_ops,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std =[0.229, 0.224, 0.225]),
    ])


# Training: resize, random 224 crop, random flip.
train_transform = _compose_with_normalisation(
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
)

# Evaluation: resize, deterministic centre crop.
test_transform = _compose_with_normalisation(
    transforms.Resize(256),
    transforms.CenterCrop(224),
)

class CUBDataset(Dataset):
    """Image dataset backed by a DataFrame with "path" and "class" columns.

    Items are ``(transformed RGB image, stored class value)`` pairs; the label
    is returned exactly as it appears in the frame.
    """

    def __init__(self, df, transform):
        self.transform = transform
        self.paths = df["path"].values
        self.labels = df["class"].values

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        image = Image.open(self.paths[idx]).convert("RGB")
        return self.transform(image), self.labels[idx]

batch_size = 128

# Options common to both loaders.
_shared_loader_opts = dict(batch_size=batch_size, num_workers=4, pin_memory=True)

# The training loader reshuffles and drops the final partial batch.
train_loader = DataLoader(CUBDataset(train_df, train_transform),
                          shuffle=True, drop_last=True, **_shared_loader_opts)
test_loader = DataLoader(CUBDataset(test_df, test_transform),
                         shuffle=False, **_shared_loader_opts)

# --------------------------------------------------------------
# 6. Simple AutoEncoder Model
# --------------------------------------------------------------
class SimpleMetricAutoEncoder(nn.Module):
    """GoogLeNet feature extractor with an embedding head and a feature decoder.

    ``forward`` returns ``(z, f, f_hat)``: the embedding, the flattened
    backbone feature, and the decoder's reconstruction of that feature from
    the embedding.
    """

    # GoogLeNet sub-modules chained, in this order, to form the backbone.
    _BACKBONE_LAYERS = (
        "conv1", "maxpool1",
        "conv2", "conv3", "maxpool2",
        "inception3a", "inception3b", "maxpool3",
        "inception4a", "inception4b", "inception4c",
        "inception4d", "inception4e",
        "maxpool4",
        "inception5a", "inception5b",
        "avgpool",
    )

    def __init__(self, embed_dim=128):
        super().__init__()
        goog = models.googlenet(weights="IMAGENET1K_V1")
        # NOTE(review): the sub-modules are called directly, so whatever
        # GoogLeNet.forward does to its input before conv1 is skipped here;
        # confirm that is intended for the pretrained weights.
        self.backbone = nn.Sequential(
            *(getattr(goog, layer_name) for layer_name in self._BACKBONE_LAYERS)
        )
        self.fc_embed = nn.Linear(1024, embed_dim)
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
        )

    def extract_f(self, x):
        """Run the backbone and flatten its output to (B, 1024)."""
        feats = self.backbone(x)
        return feats.view(feats.size(0), -1)

    def forward(self, x):
        f = self.extract_f(x)
        z = self.fc_embed(f)
        return z, f, self.decoder(z)

# --------------------------------------------------------------
# 7. Proxy-NCA++ Loss (Multiple Positive Proxies)
# --------------------------------------------------------------
class ProxyNCAPlusPlus(nn.Module):
    """Proxy-based metric-learning loss with several learnable proxies per class.

    Embeddings and proxies are L2-normalised, so similarities are cosine
    similarities divided by ``temp``. For each sample the loss is
    ``-(best own-class proxy similarity - logsumexp over all proxies)``,
    averaged over the batch.

    Args:
        embed_dim: Dimensionality of the embeddings and proxies.
        num_classes: Number of classes; labels must lie in ``[0, num_classes)``.
        num_pos: Number of proxies per class.
        temp: Temperature dividing the similarities.
    """

    def __init__(self, embed_dim, num_classes, num_pos=3, temp=0.1):
        super().__init__()
        self.proxies = nn.Parameter(torch.randn(num_classes, num_pos, embed_dim))
        nn.init.xavier_uniform_(self.proxies)
        self.temp = temp
        self.num_pos = num_pos

    def forward(self, embedding, label):
        """Return the scalar loss for a batch.

        Args:
            embedding: Float tensor of shape (B, D).
            label: Integer tensor of shape (B,) holding class indices.
        """
        embedding = F.normalize(embedding, dim=1)          # (B, D)
        proxies = F.normalize(self.proxies, dim=-1)        # (C, K, D)

        D = embedding.shape[1]
        K = proxies.shape[1]

        # Build the proxy index next to the labels instead of relying on a
        # module-level `device` global, so the loss works wherever its inputs live.
        proxy_idx = torch.arange(K, device=label.device)
        # (B, 1) labels broadcast against (K,) proxy ids -> (B, K, D)
        pos_proxies = proxies[label.unsqueeze(1), proxy_idx]

        sim_pos = torch.bmm(pos_proxies, embedding.unsqueeze(-1)).squeeze(-1) / self.temp
        sim_pos = sim_pos.max(dim=1)[0]                    # (B,) best own-class proxy

        all_proxies = proxies.reshape(-1, D)               # (C*K, D)
        sim_all = F.linear(embedding, all_proxies) / self.temp
        logsumexp = torch.logsumexp(sim_all, dim=1)

        return -(sim_pos - logsumexp).mean()

# --------------------------------------------------------------
# 8. Loss Helper
# --------------------------------------------------------------
def recon_loss(f, f_hat):
    """Mean squared error between the features ``f`` and their reconstruction ``f_hat``."""
    return F.mse_loss(input=f_hat, target=f, reduction="mean")

# --------------------------------------------------------------
# 9. EVALUATION METRICS (MUST BE DEFINED BEFORE TRAINING)
# --------------------------------------------------------------
def recall_at_k(E, L, k=1):
    """Fraction of samples with a same-label sample among their k nearest neighbours.

    Neighbours are ranked by dot-product similarity (``E @ E.T``), with each
    sample excluded from its own neighbour list.

    Args:
        E: Array-like of shape (N, D) holding one embedding per row.
        L: Array-like of shape (N,) holding the labels.
        k: Number of neighbours to inspect; clamped to ``N - 1``.

    Returns:
        Recall@k as a Python float in [0, 1]; 0.0 when there are fewer than
        two samples or ``k < 1``.
    """
    E = np.asarray(E)
    if not np.issubdtype(E.dtype, np.floating):
        # -inf cannot be written into an integer similarity matrix.
        E = E.astype(np.float64)
    L = np.asarray(L)

    n = len(L)
    k = min(k, n - 1)  # a sample has at most n - 1 neighbours
    if k < 1:
        return 0.0

    sim = E @ E.T
    np.fill_diagonal(sim, -np.inf)  # never retrieve the query itself

    # After partitioning on kth = k - 1 the first k columns hold the k most
    # similar samples (in arbitrary order).
    neighbours = np.argpartition(-sim, k - 1, axis=1)[:, :k]
    hits = (L[neighbours] == L[:, None]).any(axis=1)
    return float(hits.mean())

def clustering_metrics(E, L, n_clusters):
    """Cluster the embeddings with K-means and score the result against the labels.

    Args:
        E: Array of shape (N, D) holding one embedding per row.
        L: Array-like of shape (N,) holding the true labels.
        n_clusters: Number of K-means clusters.

    Returns:
        ``(nmi, f1)``: normalised mutual information and the pairwise
        clustering F1 (harmonic mean of pair precision and pair recall).
    """
    L = np.asarray(L).ravel()
    km = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
    pred = km.fit_predict(E)
    nmi = normalized_mutual_info_score(L, pred)

    # K-means cluster ids are arbitrary, so comparing them element-wise with
    # class ids (a classification F1) is meaningless. Score sample pairs
    # instead: a pair is positive when both samples share a cluster, and
    # correct when they also share a class.
    _, class_idx = np.unique(L, return_inverse=True)
    _, cluster_idx = np.unique(pred, return_inverse=True)
    contingency = np.zeros(
        (class_idx.max() + 1, cluster_idx.max() + 1), dtype=np.int64
    )
    np.add.at(contingency, (class_idx, cluster_idx), 1)

    def _num_pairs(counts):
        # Number of unordered pairs that can be drawn from each group.
        return int((counts * (counts - 1) // 2).sum())

    tp = _num_pairs(contingency)                    # same class and same cluster
    same_cluster = _num_pairs(contingency.sum(axis=0))
    same_class = _num_pairs(contingency.sum(axis=1))

    prec = tp / same_cluster if same_cluster else 0.0
    rec = tp / same_class if same_class else 0.0
    f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
    return nmi, f1

def pairwise_prec_recall(E, L):
    """Pair-level precision and recall using the median distance as threshold.

    A pair of distinct samples is predicted "same class" when its distance
    (as returned by ``pairwise_distances``) is at most the median distance
    over all such pairs.

    Args:
        E: Array of shape (N, D) holding one embedding per row.
        L: Array-like of shape (N,) holding the true labels.

    Returns:
        ``(precision, recall)``; ``(0.0, 0.0)`` when there are no pairs.
    """
    L = np.asarray(L)
    D = pairwise_distances(E)

    # Keep each unordered pair once and drop self-pairs: the diagonal has
    # distance 0 and matching labels, so it would always count as a true
    # positive and would also pull the median threshold down.
    upper = np.triu(np.ones(D.shape, dtype=bool), k=1)
    dist = D[upper]
    if dist.size == 0:
        return 0.0, 0.0
    same = (L[:, None] == L[None, :])[upper]

    thr = np.median(dist)
    close = dist <= thr

    tp = np.count_nonzero(close & same)
    fp = np.count_nonzero(close & ~same)
    fn = np.count_nonzero(~close & same)
    # The small epsilon keeps the ratios defined when a denominator is zero.
    prec = tp / (tp + fp + 1e-12)
    rec = tp / (tp + fn + 1e-12)
    return prec, rec

# --------------------------------------------------------------
# 10. Model & Optimizer
# --------------------------------------------------------------
embed_dim = 128
num_classes = 100                     # only known classes

model = SimpleMetricAutoEncoder(embed_dim=embed_dim).to(device)
proxy_nca = ProxyNCAPlusPlus(embed_dim, num_classes, num_pos=3, temp=0.1).to(device)

# A single optimizer updates the network weights and the loss proxies together.
trainable_params = [*model.parameters(), *proxy_nca.parameters()]
optimizer = torch.optim.Adam(trainable_params, lr=1e-4, weight_decay=1e-5)

# --------------------------------------------------------------
# 11. Validation split (for early-stop)
# --------------------------------------------------------------
val_ratio = 0.1
val_df_list = []
train_df_list = []

# Hold out val_ratio of every class (at least one image) for validation.
for class_id in range(100):
    class_rows = train_df.loc[train_df["class"] == class_id]
    n_holdout = max(1, int(len(class_rows) * val_ratio))
    holdout_idx = np.random.choice(class_rows.index, n_holdout, replace=False)
    val_df_list.append(class_rows.loc[holdout_idx])
    train_df_list.append(class_rows.drop(holdout_idx))

val_df = pd.concat(val_df_list).reset_index(drop=True)
train_df_fin = pd.concat(train_df_list).reset_index(drop=True)

# Options common to both loaders.
_split_loader_opts = dict(batch_size=batch_size, num_workers=4, pin_memory=True)

train_loader = DataLoader(CUBDataset(train_df_fin, train_transform),
                          shuffle=True, drop_last=True, **_split_loader_opts)
val_loader = DataLoader(CUBDataset(val_df, test_transform),
                        shuffle=False, **_split_loader_opts)

# --------------------------------------------------------------
# 12. Validation helper (uses recall_at_k defined above)
# --------------------------------------------------------------
@torch.no_grad()
def validate():
    """Return Recall@1 of the current model on ``val_loader``."""
    model.eval()
    emb_chunks, label_chunks = [], []
    for images, targets in val_loader:
        # model(...) returns (z, f, f_hat); only the embedding z is needed.
        embedding = model(images.to(device))[0]
        emb_chunks.append(embedding.cpu())
        label_chunks.append(targets)

    embeddings = F.normalize(torch.cat(emb_chunks), dim=1).numpy()
    targets_all = torch.cat(label_chunks).numpy()
    return recall_at_k(embeddings, targets_all, k=1)

# --------------------------------------------------------------
# 13. Training Loop
# --------------------------------------------------------------
EPOCHS       = 100
# Weights of the two loss terms (ASCII names replace the garbled lambda symbols).
lambda_recon = 1.0
lambda_nca   = 1.0

# Start below any attainable recall so the first epoch always writes a
# checkpoint; otherwise a run whose R@1 never exceeds 0 would leave nothing
# for the load below.
best_val_r1 = float("-inf")
patience    = 15
wait        = 0

for epoch in range(1, EPOCHS + 1):
    model.train()
    proxy_nca.train()
    epoch_loss = 0.0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        z, f, f_hat = model(imgs)
        L_rec = recon_loss(f, f_hat)
        L_nca = proxy_nca(z, labels)

        loss = lambda_recon * L_rec + lambda_nca * L_nca

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # ------------------- validation -------------------
    val_r1 = validate()
    print(f"Epoch {epoch:02d} | Loss: {epoch_loss/len(train_loader):.4f} | Val R@1: {val_r1:.4f}")

    # Early stopping: keep the best checkpoint and stop after `patience`
    # consecutive epochs without improvement.
    if val_r1 > best_val_r1:
        best_val_r1 = val_r1
        torch.save(model.state_dict(), "ae_proxy_best.pth")
        wait = 0
    else:
        wait += 1
        if wait >= patience:
            print("Early stopping!")
            break

# Load best checkpoint (mapped onto the current device)
model.load_state_dict(torch.load("ae_proxy_best.pth", map_location=device))

# --------------------------------------------------------------
# 14. Feature extraction helper
# --------------------------------------------------------------
@torch.no_grad()
def extract_feats(loader):
    """Embed every batch of ``loader`` with the model in eval mode.

    Returns:
        ``(E, L)`` as NumPy arrays: the L2-normalised embeddings and the
        labels, concatenated in loader order.
    """
    model.eval()
    emb_chunks, label_chunks = [], []
    for images, targets in loader:
        # model(...) returns (z, f, f_hat); only the embedding z is kept.
        embedding = model(images.to(device))[0]
        emb_chunks.append(embedding.cpu())
        label_chunks.append(targets)

    embeddings = F.normalize(torch.cat(emb_chunks), dim=1).numpy()
    targets_all = torch.cat(label_chunks).numpy()
    return embeddings, targets_all

# --------------------------------------------------------------
# 15. TRAIN SET evaluation
# --------------------------------------------------------------
print("\n" + "="*60)
print("EVALUATION ON TRAIN SET (100 known classes)")
print("="*60)
# Evaluate on a deterministic pass over the training images: train_loader
# shuffles, applies random augmentation and drops the last partial batch, so
# features taken from it would be noisy and would miss some images.
train_eval_loader = DataLoader(CUBDataset(train_df_fin, test_transform),
                               batch_size=batch_size, shuffle=False,
                               num_workers=4, pin_memory=True)
train_E, train_L = extract_feats(train_eval_loader)
for k in [1,2,4,8]:
    print(f"Recall@{k}: {recall_at_k(train_E, train_L, k):.4f}")
nmi_tr, f1_tr = clustering_metrics(train_E, train_L, n_clusters=100)
print(f"NMI : {nmi_tr:.4f}")
print(f"F1  : {f1_tr:.4f}")
p_tr, r_tr = pairwise_prec_recall(train_E, train_L)
print(f"Prec: {p_tr:.4f}  Rec: {r_tr:.4f}")

# --------------------------------------------------------------
# 16. TEST SET evaluation
# --------------------------------------------------------------
_rule = "=" * 60
print("\n" + _rule)
print("EVALUATION ON TEST SET (100 unseen classes)")
print(_rule)

test_E, test_L = extract_feats(test_loader)

# Retrieval quality at increasing neighbourhood sizes.
for k in (1, 2, 4, 8):
    recall = recall_at_k(test_E, test_L, k)
    print(f"Recall@{k}: {recall:.4f}")

# Clustering quality with one cluster per class.
nmi_te, f1_te = clustering_metrics(test_E, test_L, n_clusters=100)
print(f"NMI : {nmi_te:.4f}")
print(f"F1  : {f1_te:.4f}")

p_te, r_te = pairwise_prec_recall(test_E, test_L)
print(f"Prec: {p_te:.4f}  Rec: {r_te:.4f}")

# --------------------------------------------------------------
# 17. PCA visualisation (test)
# --------------------------------------------------------------
def plot_pca(E, L, title="PCA", n_samples=3000):
    """Scatter-plot the first two principal components of ``E``, coloured by ``L``.

    When ``E`` has more than ``n_samples`` rows, a random subset of that size
    (drawn without replacement) is plotted instead.
    """
    if len(E) > n_samples:
        keep = np.random.choice(len(E), n_samples, replace=False)
        E = E[keep]
        L = L[keep]

    projected = PCA(n_components=2).fit_transform(E)

    plt.figure(figsize=(10, 8))
    points = plt.scatter(
        projected[:, 0], projected[:, 1],
        c=L, cmap='tab20', s=15, alpha=0.7,
    )
    plt.title(title)
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.colorbar(points)
    plt.grid(True)
    plt.show()


plot_pca(test_E, test_L, "PCA of AutoEncoder + Proxy-NCA++ (100 Unseen Classes)")



def plot_tsne(E, L, title="t-SNE", n_samples=3000, perplexity=30, random_state=42):
    """
    Visualise embeddings E with labels L using t-SNE.

    Parameters
    ----------
    E : np.ndarray, shape (N, D)
        Embedding matrix.
    L : np.ndarray, shape (N,)
        Integer class labels.
    title : str
        Plot title.
    n_samples : int
        If N > n_samples, randomly subsample.
    perplexity : float
        t-SNE perplexity (typical values 5-50). Lowered automatically when
        there are too few points for the requested value.
    random_state : int
        Seeds both the subsampling and t-SNE, for reproducibility.
    """
    if len(E) > n_samples:
        # Seed the subsample too, so the same points are drawn on every call.
        rng = np.random.RandomState(random_state)
        idx = rng.choice(len(E), n_samples, replace=False)
        E, L = E[idx], L[idx]

    # t-SNE rejects a perplexity that is not smaller than the number of points.
    if perplexity >= len(E):
        perplexity = max(1, len(E) - 1)

    # The iteration count is left at the library default (1000): the keyword
    # that sets it was renamed across scikit-learn releases, so omitting it
    # keeps this call valid on both old and new versions.
    tsne = TSNE(
        n_components=2,
        perplexity=perplexity,
        random_state=random_state,
        method='barnes_hut',   # fast approximate version
        n_jobs=-1
    )
    X_2d = tsne.fit_transform(E)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(
        X_2d[:, 0], X_2d[:, 1],
        c=L, cmap='tab20', s=15, alpha=0.7
    )
    plt.title(title)
    plt.xlabel("t-SNE 1")
    plt.ylabel("t-SNE 2")
    plt.colorbar(scatter, label='Class')
    plt.grid(True)
    plt.tight_layout()
    plt.show()


# ------------------------------------------------------------------
# Example usage (exactly like your PCA call)
# ------------------------------------------------------------------
tsne_title = "t-SNE of AutoEncoder + Proxy-NCA++ (100 Unseen Classes)"
# perplexity: tweak 20-50 for 100 classes
plot_tsne(test_E, test_L, title=tsne_title, n_samples=3000, perplexity=40)