# In [None]:
# --------------------------------------------------------------
# 1. Imports & device
# --------------------------------------------------------------
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from PIL import Image

from tqdm import tqdm
from sklearn.cluster import KMeans
from sklearn.metrics import normalized_mutual_info_score, f1_score, pairwise_distances
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

# Prefer the GPU when PyTorch can see one; otherwise run everything on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# --------------------------------------------------------------
# 2. Seed everything
# --------------------------------------------------------------
def set_seed(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Parameters
    ----------
    seed : int
        Value handed to every random number generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Trade cuDNN autotuning for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(42)

# --------------------------------------------------------------
# 3. Load CUB metadata
# --------------------------------------------------------------
base_dir = "/kaggle/input/cub2002011/CUB_200_2011"


def _read_pairs(filename):
    """Return the whitespace-split fields of every line in a metadata file."""
    with open(os.path.join(base_dir, filename)) as handle:
        return [line.strip().split() for line in handle]


# Each metadata file holds one "<image_id> <value>" pair per line.
# image_id -> 1/0 train/test flag (stored as int)
split_dict = {
    img_id: int(flag) for img_id, flag in _read_pairs("train_test_split.txt")
}
# image_id -> path relative to the images/ directory
path_dict = dict(_read_pairs("images.txt"))
# image_id -> class id (kept as a string; converted where it is used)
label_dict = dict(_read_pairs("image_class_labels.txt"))

# --------------------------------------------------------------
# 4. Zero-shot split (first 100 = train, last 100 = test)
# --------------------------------------------------------------
# Class ids in ascending order: the first 100 are "known", the next 100 "unseen".
all_classes = sorted({int(v) for v in label_dict.values()})
train_classes = set(all_classes[:100])      # known
test_classes = set(all_classes[100:200])    # unseen

train_paths, train_labels = [], []
test_paths, test_labels = [], []

img_base = os.path.join(base_dir, "images")
for img_id, rel in path_dict.items():
    cls = int(label_dict[img_id])
    is_known = cls in train_classes
    is_unseen = cls in test_classes
    if not (is_known or is_unseen):
        continue

    full = os.path.join(img_base, rel)
    is_train = split_dict[img_id] == 1

    # Known classes keep only their official-train images, unseen classes
    # keep only their official-test images; labels are stored zero-based.
    if is_known and is_train:
        train_paths.append(full)
        train_labels.append(cls - 1)
    elif is_unseen and not is_train:
        test_paths.append(full)
        test_labels.append(cls - 1)

train_df = pd.DataFrame({"path": train_paths, "class": train_labels})
test_df = pd.DataFrame({"path": test_paths, "class": test_labels})

print(f"Train (known): {len(train_df)} imgs, {len(train_classes)} classes")
print(f"Test  (unseen): {len(test_df)} imgs, {len(test_classes)} classes")

# --------------------------------------------------------------
# 5. Validation split from *known* classes (for early-stop)
# --------------------------------------------------------------
# Fraction of every known class that is held out for validation.
val_ratio = 0.15
val_df_list = []
train_df_list = []

# Iterate in sorted order so the sequence of RNG draws (and therefore the
# split) does not depend on the iteration order of a Python set.
for c in sorted(train_classes):
    # train_df stores zero-based labels, train_classes holds one-based ids.
    sub = train_df[train_df["class"] == c - 1]
    if sub.empty:
        # Nothing to sample from; np.random.choice would raise on an empty index.
        continue
    n_val = max(1, int(len(sub) * val_ratio))
    # Always leave at least one image of the class in the training split.
    n_val = min(n_val, len(sub) - 1)
    val_idx = np.random.choice(sub.index, n_val, replace=False)
    val_df_list.append(sub.loc[val_idx])
    train_df_list.append(sub.drop(val_idx))

val_df = pd.concat(val_df_list).reset_index(drop=True)
train_df = pd.concat(train_df_list).reset_index(drop=True)

print(f"Train (after val split): {len(train_df)}")
print(f"Val   (known): {len(val_df)}")

# --------------------------------------------------------------
# 6. Transforms (paper exact)
# --------------------------------------------------------------
# Channel statistics used to normalise every image tensor.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]


def _tensor_and_normalize():
    """Return the shared tail of both pipelines: ToTensor followed by Normalize."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]


# Training: resize, random 224 crop and random horizontal flip as augmentation.
train_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
    ]
    + _tensor_and_normalize()
)

# Validation: deterministic resize + centre crop.
val_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
    ]
    + _tensor_and_normalize()
)

test_transform = val_transform  # same as val

class CUBDataset(Dataset):
    """Image dataset backed by a DataFrame with ``path`` and ``class`` columns.

    Parameters
    ----------
    df : pandas.DataFrame
        Must provide a ``path`` column (image file locations) and a ``class``
        column (labels).
    transform : callable
        Applied to the RGB PIL image of every item.
    """

    def __init__(self, df, transform):
        self.transform = transform
        self.paths = df["path"].values
        self.labels = df["class"].values

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for the item at *idx*."""
        path, label = self.paths[idx], self.labels[idx]
        rgb = Image.open(path).convert("RGB")
        return self.transform(rgb), label

batch_size = 128

# Settings shared by all three loaders.
_loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)

# Training batches are shuffled and the last incomplete batch is dropped.
train_loader = DataLoader(
    CUBDataset(train_df, train_transform),
    shuffle=True,
    drop_last=True,
    **_loader_kwargs,
)
# Validation / test keep dataset order and every sample.
val_loader = DataLoader(
    CUBDataset(val_df, val_transform),
    shuffle=False,
    **_loader_kwargs,
)
test_loader = DataLoader(
    CUBDataset(test_df, test_transform),
    shuffle=False,
    **_loader_kwargs,
)

# --------------------------------------------------------------
# 7. DVML model (exact architecture)
# --------------------------------------------------------------
class DVML(nn.Module):
    """Embedding network with an invariant head, a variational head and a decoder.

    A GoogLeNet backbone (ImageNet weights) produces a 1024-d feature ``f``.
    Three ``Linear + Dropout`` heads map ``f`` to the invariant embedding
    ``z_I`` and to the mean / log-variance of the variational part. ``T``
    synthetic embeddings ``z_hat = z_I + z_V`` are generated per image and
    decoded back to 1024-d features.

    Parameters
    ----------
    embed_dim : int
        Dimensionality of ``z_I``, ``mu``, ``logvar`` and ``z_hat``.
    T : int
        Number of synthetic samples generated per input image.
    dropout : float
        Dropout probability applied after each of the three heads.
    """

    def __init__(self, embed_dim=512, T=20, dropout=0.3):
        super().__init__()
        goog = models.googlenet(weights="IMAGENET1K_V1")
        # Drop only the last child module (the classifier); the heads below
        # expect the remaining stack to emit 1024 channels.
        # NOTE(review): wrapping the children in nn.Sequential bypasses
        # GoogLeNet.forward, so any input preprocessing done there is skipped
        # -- confirm this matches what the pretrained weights expect.
        self.backbone = nn.Sequential(*list(goog.children())[:-1])
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        self.fc_I = nn.Sequential(nn.Linear(1024, embed_dim), nn.Dropout(dropout))
        self.fc_mu = nn.Sequential(nn.Linear(1024, embed_dim), nn.Dropout(dropout))
        self.fc_logvar = nn.Sequential(nn.Linear(1024, embed_dim), nn.Dropout(dropout))

        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, 512),
            nn.Tanh(),
            nn.Linear(512, 1024)
        )
        self.T = T
        self.embed_dim = embed_dim

    def encode(self, x):
        """Return ``(z_I, mu, logvar, f)`` for an image batch *x*."""
        f = self.pool(self.backbone(x)).view(x.size(0), -1)   # (B,1024)
        z_I = self.fc_I(f)
        mu = self.fc_mu(f)
        logvar = self.fc_logvar(f)
        return z_I, mu, logvar, f

    def reparameterize(self, mu, logvar, n_samples=None):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Parameters
        ----------
        mu, logvar : torch.Tensor, shape (B, D)
        n_samples : int or None
            ``None`` draws one sample per row and returns shape (B, D).
            An integer draws that many independent samples per row and
            returns shape (B, n_samples, D).
        """
        std = torch.exp(0.5 * logvar)
        if n_samples is None:
            return mu + torch.randn_like(std) * std
        eps = torch.randn(
            mu.size(0), n_samples, mu.size(1), device=mu.device, dtype=mu.dtype
        )
        return mu.unsqueeze(1) + eps * std.unsqueeze(1)

    def forward(self, x, phase=2):
        """Run the full model.

        Parameters
        ----------
        x : torch.Tensor
            Image batch accepted by the backbone.
        phase : int
            With ``phase == 1`` the decoded features are detached, so no
            gradient flows back through them.

        Returns
        -------
        tuple
            ``(z_I, f, f_hat_flat, z_hat_flat, mu, logvar, z_hat)`` where
            ``z_hat`` is (B, T, D) and the ``*_flat`` tensors have B*T rows,
            ordered image-major (all T samples of image 0, then image 1, ...).
        """
        z_I, mu, logvar, f = self.encode(x)

        # ---- T synthetic samples ----
        # Draw T *independent* noise vectors per image; repeating a single
        # draw T times would only produce T identical copies.
        z_V = self.reparameterize(mu, logvar, n_samples=self.T)   # (B,T,D)
        z_hat = z_I.unsqueeze(1) + z_V                            # (B,T,D)

        z_hat_flat = z_hat.reshape(-1, self.embed_dim)            # (B*T,D)
        f_hat_flat = self.decoder(z_hat_flat)                     # (B*T,1024)

        if phase == 1:
            # Cut the graph at the decoder output.
            f_hat_flat = f_hat_flat.detach()

        return z_I, f, f_hat_flat, z_hat_flat, mu, logvar, z_hat

# --------------------------------------------------------------
# 8. Proxy-NCA with label-smoothing & L2 on proxies
# --------------------------------------------------------------
class ProxyNCA(nn.Module):
    """Proxy-based softmax loss with label smoothing and an L2 proxy penalty.

    One learnable proxy vector is kept per class. Embeddings and proxies are
    L2-normalised, their cosine similarities are divided by ``temp`` and used
    as logits of a cross-entropy against (smoothed) one-hot targets.

    Parameters
    ----------
    embed_dim : int
        Dimensionality of the embeddings and proxies.
    n_classes : int
        Number of proxies (one per class).
    temp : float
        Temperature dividing the cosine similarities.
    label_smoothing : float
        Mass spread uniformly over all classes in the target distribution.
    proxy_l2 : float
        Weight of the squared-norm penalty on the proxy parameters;
        ``0`` disables it.
    """

    def __init__(self, embed_dim, n_classes, temp=0.1, label_smoothing=0.1, proxy_l2=1e-4):
        super().__init__()
        self.proxies = nn.Parameter(torch.randn(n_classes, embed_dim))
        nn.init.xavier_uniform_(self.proxies)
        self.temp = temp
        self.ls = label_smoothing
        self.l2 = proxy_l2

    def forward(self, emb, label):
        """Return the scalar loss for embeddings *emb* (B, D) and int labels *label* (B,)."""
        emb = F.normalize(emb, dim=1)
        prox = F.normalize(self.proxies, dim=1)
        n_classes = prox.size(0)

        sim = F.linear(emb, prox) / self.temp                # (B,C)

        # label-smoothing
        target = F.one_hot(label, num_classes=n_classes).float()
        target = target * (1 - self.ls) + (self.ls / n_classes)

        log_prob = F.log_softmax(sim, dim=1)
        loss = -(target * log_prob).sum(dim=1).mean()

        # L2 regularisation on the *raw* proxies. The normalised proxies all
        # have unit norm, so penalising them would add a constant with zero
        # gradient and regularise nothing.
        if self.l2 > 0:
            loss = loss + self.l2 * self.proxies.pow(2).sum()
        return loss

# --------------------------------------------------------------
# 9. Loss helpers
# --------------------------------------------------------------
def kl_divergence(mu, logvar):
    """Return the element-wise mean of ``-0.5 * (1 + logvar - mu^2 - exp(logvar))``.

    The average runs over every element of the inputs (batch and feature
    dimensions alike), so the result is a scalar tensor.
    """
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * torch.mean(per_element)

def recon_loss(f, f_hat_flat, T):
    """Mean-squared error between decoded features and their source features.

    Parameters
    ----------
    f : torch.Tensor, shape (B, F)
        One target feature vector per image.
    f_hat_flat : torch.Tensor, shape (B*T, F)
        Decoded features, ordered image-major: rows ``i*T .. i*T+T-1`` belong
        to image ``i`` (the layout produced by flattening a (B, T, F) tensor).
    T : int
        Number of decoded samples per image.

    Returns
    -------
    torch.Tensor
        Scalar MSE.
    """
    # repeat_interleave yields [f0]*T, [f1]*T, ... which matches the
    # image-major row order of f_hat_flat. `f.repeat(T, 1)` would instead tile
    # the whole batch (f0, f1, ..., f0, f1, ...) and pair most reconstructions
    # with the wrong image.
    f_rep = f.repeat_interleave(T, dim=0)
    return F.mse_loss(f_hat_flat, f_rep)

# --------------------------------------------------------------
# 10. Model / optimiser
# --------------------------------------------------------------
embed_dim = 512
T = 20
num_classes = len(train_classes)          # 100 known

model = DVML(embed_dim=embed_dim, T=T, dropout=0.3).to(device)

# Two independent proxy losses with identical hyper-parameters:
# one for z_I, one for the synthetic z_hat samples.
_proxy_kwargs = dict(temp=0.1, label_smoothing=0.1, proxy_l2=1e-4)
proxy_I = ProxyNCA(embed_dim, num_classes, **_proxy_kwargs).to(device)
proxy_hat = ProxyNCA(embed_dim, num_classes, **_proxy_kwargs).to(device)

# A single Adam optimiser updates the network and both proxy sets.
_trainable = [
    param
    for module in (model, proxy_I, proxy_hat)
    for param in module.parameters()
]
optimizer = torch.optim.Adam(_trainable, lr=1e-4, weight_decay=1e-5)

# --------------------------------------------------------------
# 11. KL annealing schedule
# --------------------------------------------------------------
total_epochs = 120
warmup_epochs = 30
kl_start = 0.0
kl_end = 1.0


def get_kl_weight(epoch):
    """Return the KL weight for *epoch*.

    The weight rises linearly from ``kl_start`` (epoch 0) to ``kl_end``
    (epoch ``warmup_epochs``) and stays at ``kl_end`` afterwards.
    """
    if epoch > warmup_epochs:
        return kl_end
    progress = epoch / warmup_epochs
    return kl_start + (kl_end - kl_start) * progress

# --------------------------------------------------------------
# 13. Feature extraction helper
# --------------------------------------------------------------
@torch.no_grad()
def extract_feats(loader):
    """Embed every batch of *loader* with the global ``model``.

    Returns
    -------
    (E, L) : tuple of np.ndarray
        ``E`` holds the L2-normalised ``z_I`` embeddings, one row per sample;
        ``L`` holds the labels yielded by the loader, in the same order.
    """
    model.eval()
    collected = []
    for imgs, ls in loader:
        z_I, *_ = model(imgs.to(device), phase=2)
        collected.append((z_I.cpu(), ls))

    all_emb = torch.cat([emb for emb, _ in collected])
    all_lbl = torch.cat([lbl for _, lbl in collected])
    return F.normalize(all_emb, dim=1).numpy(), all_lbl.numpy()
    
# --------------------------------------------------------------
# 14. Evaluation metrics
# --------------------------------------------------------------
def recall_at_k(feats, labels, k):
    """Fraction of samples with a same-label sample among their k nearest neighbours.

    Neighbours are ranked by the dot product ``feats @ feats.T``; a sample is
    never its own neighbour.

    Parameters
    ----------
    feats : array-like, shape (N, D)
        Feature vectors (pass L2-normalised rows for cosine similarity).
    labels : array-like, shape (N,)
        Class label of every row.
    k : int
        Number of neighbours inspected per sample. Values larger than
        ``N - 1`` are clamped to ``N - 1``.

    Returns
    -------
    float
        Recall@k in [0, 1]; ``0.0`` when there are fewer than two samples or
        ``k < 1`` (no neighbour can be retrieved).
    """
    feats = np.asarray(feats)
    if not np.issubdtype(feats.dtype, np.floating):
        # -inf below is only representable in a floating dtype.
        feats = feats.astype(np.float64)
    labels = np.asarray(labels)

    n = len(labels)
    if n < 2 or k < 1:
        return 0.0

    sim = feats @ feats.T
    np.fill_diagonal(sim, -np.inf)          # exclude self-matches

    # Partitioning around position k_eff-1 puts the k_eff most similar
    # samples (in arbitrary order) into the first k_eff columns.
    k_eff = min(k, n - 1)
    neigh = np.argpartition(-sim, k_eff - 1, axis=1)[:, :k_eff]

    hits = (labels[neigh] == labels[:, None]).any(axis=1)
    return int(hits.sum()) / n

def _pairwise_f1(labels, pred):
    """Pair-counting F1 between a labelling and a clustering.

    A pair of samples is a true positive when both share a class and a
    cluster. The score is invariant to how the cluster ids are numbered.
    """
    _, cls_idx = np.unique(np.asarray(labels).ravel(), return_inverse=True)
    _, clu_idx = np.unique(np.asarray(pred).ravel(), return_inverse=True)

    # contingency[i, j] = number of samples of class i placed in cluster j
    contingency = np.zeros((cls_idx.max() + 1, clu_idx.max() + 1), dtype=np.int64)
    np.add.at(contingency, (cls_idx, clu_idx), 1)

    def n_pairs(counts):
        return int((counts * (counts - 1) // 2).sum())

    tp = n_pairs(contingency)
    same_cluster = n_pairs(contingency.sum(axis=0))   # tp + fp
    same_class = n_pairs(contingency.sum(axis=1))     # tp + fn
    if tp == 0:
        return 0.0
    precision = tp / same_cluster
    recall = tp / same_class
    return 2 * precision * recall / (precision + recall)


def clustering_metrics(feats, labels, n_clusters):
    """Cluster *feats* with K-Means and score the clustering against *labels*.

    Parameters
    ----------
    feats : np.ndarray, shape (N, D)
    labels : np.ndarray, shape (N,)
    n_clusters : int
        Number of K-Means clusters.

    Returns
    -------
    (nmi, f1) : tuple of float
        Normalised mutual information and pair-counting F1. K-Means cluster
        ids are arbitrary, so comparing them with the class ids element by
        element (a classification F1) would be meaningless; both scores here
        are invariant to cluster numbering.
    """
    km = KMeans(n_clusters=n_clusters, n_init=10, random_state=0)
    pred = km.fit_predict(feats)
    nmi = normalized_mutual_info_score(labels, pred)
    f1 = _pairwise_f1(labels, pred)
    return nmi, f1

def pairwise_prec_recall(feats, labels):
    """Precision / recall of "distance <= median" as a same-class predictor.

    Every unordered pair of *distinct* samples is predicted "same class" when
    its Euclidean distance does not exceed the median pair distance.
    Self-pairs are left out: they have distance 0 and identical labels, so
    they would always count as true positives and pull the median down.

    Parameters
    ----------
    feats : np.ndarray, shape (N, D)
    labels : array-like, shape (N,)

    Returns
    -------
    (prec, rec) : tuple of float
        ``(0.0, 0.0)`` when there are fewer than two samples.
    """
    labels = np.asarray(labels)
    n = len(labels)
    if n < 2:
        return 0.0, 0.0

    D = pairwise_distances(feats, metric='euclidean')

    # Strict upper triangle: each unordered pair exactly once, no diagonal.
    rows, cols = np.triu_indices(n, k=1)
    dist = D[rows, cols]
    same = labels[rows] == labels[cols]

    thr = np.median(dist)
    close = dist <= thr

    tp = np.count_nonzero(close & same)
    fp = np.count_nonzero(close & ~same)
    fn = np.count_nonzero(~close & same)
    # The epsilon keeps the ratios defined when a denominator is zero.
    prec = tp / (tp + fp + 1e-12)
    rec = tp / (tp + fn + 1e-12)
    return prec, rec
# --------------------------------------------------------------
# 12. Training loop (two phases + early-stop)
# --------------------------------------------------------------
# Epochs 1..phase1_epochs run in phase 1, the remaining epochs in phase 2.
phase1_epochs = 60
phase2_epochs = total_epochs - phase1_epochs

# Early-stopping state: best validation NMI so far, and the number of
# consecutive epochs without improvement tolerated before stopping.
best_val_nmi = 0.0
patience = 20
wait = 0

# loss weights per phase, in the order (KL, reconstruction, proxy on z_hat, proxy on z_I)
λ1_p1, λ2_p1, λ3_p1, λ4_p1 = 1.0, 1.0, 0.1, 1.0
λ1_p2, λ2_p2, λ3_p2, λ4_p2 = 0.8, 1.0, 0.2, 0.8

for epoch in range(1, total_epochs + 1):
    # Select the phase and its loss weights for this epoch.
    phase = 1 if epoch <= phase1_epochs else 2
    λ_kl   = λ1_p1 if phase == 1 else λ1_p2
    λ_recon= λ2_p1 if phase == 1 else λ2_p2
    λ_hat  = λ3_p1 if phase == 1 else λ3_p2
    λ_I    = λ4_p1 if phase == 1 else λ4_p2

    model.train()
    proxy_I.train()
    proxy_hat.train()

    epoch_loss = 0.0
    for imgs, lbls in train_loader:
        imgs, lbls = imgs.to(device), lbls.to(device)

        z_I, f, f_hat_flat, z_hat_flat, mu, logvar, _ = model(imgs, phase=phase)

        L_kl   = kl_divergence(mu, logvar)
        L_rec  = recon_loss(f, f_hat_flat, T)
        L_I    = proxy_I(z_I, lbls)
        # Every image contributes T synthetic embeddings, so each label is
        # repeated T times, image-major, to line up with z_hat_flat.
        lbl_hat= lbls.unsqueeze(1).repeat(1,T).view(-1)
        L_hat  = proxy_hat(z_hat_flat, lbl_hat)

        # Annealing factor; depends on the epoch only, not on the batch.
        kl_w = get_kl_weight(epoch)
        loss = (kl_w * λ_kl   * L_kl +
                λ_recon * L_rec +
                λ_I     * L_I +
                λ_hat   * L_hat)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # ------------------- validation -------------------
    # Embed the validation set with z_I, L2-normalise the rows, cluster them
    # with K-Means (one cluster per known class) and score the result by NMI.
    model.eval()
    with torch.no_grad():
        val_emb, val_lbl = [], []
        for imgs, lbls in val_loader:
            imgs = imgs.to(device)
            z_I, *_ = model(imgs, phase=2)
            val_emb.append(z_I.cpu())
            val_lbl.append(lbls)
        val_emb = torch.cat(val_emb).numpy()
        val_lbl = torch.cat(val_lbl).numpy()
        val_emb = val_emb / np.linalg.norm(val_emb, axis=1, keepdims=True)

        nmi_val, _ = clustering_metrics(val_emb, val_lbl, n_clusters=num_classes)
        print(f"Epoch {epoch:03d} | Phase {phase} | Loss {epoch_loss/len(train_loader):.4f} | Val NMI {nmi_val:.4f}")

        # early-stop: checkpoint on every new best NMI, otherwise count the
        # epoch towards `patience`. Only the model weights are saved; the
        # proxy parameters are not part of the checkpoint.
        if nmi_val > best_val_nmi:
            best_val_nmi = nmi_val
            torch.save(model.state_dict(), "dvml_best.pth")
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                print("Early stopping!")
                break

# load best model (the checkpoint written at the best validation NMI)
model.load_state_dict(torch.load("dvml_best.pth"))


# --------------------------------------------------------------
# 15. Evaluate on TRAIN (known) and TEST (unseen)
# --------------------------------------------------------------
def _report_metrics(E, L, n_clusters):
    """Print Recall@{1,2,4,8}, NMI, F1 and pairwise precision/recall for one split.

    Returns ``(nmi, f1, prec, rec)`` so the caller can keep the numbers.
    """
    for k in [1, 2, 4, 8]:
        print(f"Recall@{k}: {recall_at_k(E, L, k):.4f}")
    nmi, f1 = clustering_metrics(E, L, n_clusters=n_clusters)
    print(f"NMI : {nmi:.4f}")
    print(f"F1  : {f1:.4f}")
    prec, rec = pairwise_prec_recall(E, L)
    print(f"Prec: {prec:.4f}  Rec: {rec:.4f}")
    return nmi, f1, prec, rec


print("\n=== TRAIN (known) SET ===")
# train_loader shuffles, applies random crop/flip and drops the last
# incomplete batch, so it is not suitable for evaluation. Use a deterministic
# loader over the same images with the validation transform instead.
train_eval_loader = DataLoader(CUBDataset(train_df, val_transform),
                               batch_size=batch_size, shuffle=False,
                               num_workers=4, pin_memory=True)
train_E, train_L = extract_feats(train_eval_loader)
nmi_tr, f1_tr, p_tr, r_tr = _report_metrics(train_E, train_L, num_classes)

print("\n=== TEST (unseen) SET ===")
test_E, test_L = extract_feats(test_loader)
nmi_te, f1_te, p_te, r_te = _report_metrics(test_E, test_L, len(test_classes))

# --------------------------------------------------------------
# 16. PCA visualisation (test set)
# --------------------------------------------------------------
def plot_pca(E, L, title, n_samples=3000):
    """Scatter-plot the first two principal components of *E*, coloured by *L*.

    Parameters
    ----------
    E : np.ndarray, shape (N, D)
        Embedding matrix.
    L : np.ndarray, shape (N,)
        Integer class labels used as colours.
    title : str
        Plot title.
    n_samples : int
        If N > n_samples, a random subset of that size is drawn without
        replacement (using the global NumPy RNG) before fitting the PCA.
    """
    if len(E) > n_samples:
        keep = np.random.choice(len(E), n_samples, replace=False)
        E, L = E[keep], L[keep]

    projected = PCA(n_components=2).fit_transform(E)

    fig, ax = plt.subplots(figsize=(9, 7))
    points = ax.scatter(projected[:, 0], projected[:, 1],
                        c=L, cmap='tab20', s=12, alpha=0.7)
    ax.set_title(title)
    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    fig.colorbar(points, ax=ax)
    ax.grid(True)
    plt.show()


plot_pca(test_E, test_L, "PCA of z_I – 100 unseen classes")


def plot_tsne(E, L, title="t-SNE", n_samples=3000, perplexity=30, random_state=42):
    """
    Visualise embeddings E with labels L using t-SNE.

    Parameters
    ----------
    E : np.ndarray, shape (N, D)
        Embedding matrix.
    L : np.ndarray, shape (N,)
        Integer class labels.
    title : str
        Plot title.
    n_samples : int
        If N > n_samples, randomly subsample (global NumPy RNG).
    perplexity : float
        t-SNE perplexity (typical values 5-50). Clamped to ``N - 1`` when
        there are fewer points than that, because t-SNE requires the
        perplexity to be smaller than the number of samples.
    random_state : int
        Seed passed to t-SNE.
    """
    if len(E) > n_samples:
        idx = np.random.choice(len(E), n_samples, replace=False)
        E, L = E[idx], L[idx]

    if len(E) > 1:
        perplexity = min(perplexity, len(E) - 1)

    # t-SNE (Barnes-Hut, O(N log N)).
    # The iteration count is left at the library default (1000): the keyword
    # was renamed from `n_iter` to `max_iter` in newer scikit-learn releases,
    # so passing either name breaks on some installed version.
    tsne = TSNE(
        n_components=2,
        perplexity=perplexity,
        random_state=random_state,
        method='barnes_hut',   # fast approximate version
        n_jobs=-1
    )
    X_2d = tsne.fit_transform(E)

    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(
        X_2d[:, 0], X_2d[:, 1],
        c=L, cmap='tab20', s=15, alpha=0.7
    )
    plt.title(title)
    plt.xlabel("t-SNE 1")
    plt.ylabel("t-SNE 2")
    plt.colorbar(scatter, label='Class')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

# ------------------------------------------------------------------
# Example usage (same call pattern as plot_pca)
# ------------------------------------------------------------------
plot_tsne(
    test_E, test_L,
    title="t-SNE of AutoEncoder + Proxy-NCA++ (100 Unseen Classes)",
    n_samples=3000,
    perplexity=40          # tweak 20-50 for 100 classes
)