# Import Libraries and Setup Constant Configuration

In [None]:
import os
import sys
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
import importlib
from matplotlib import pyplot as plt
from pathlib import Path
import random
import timm
import argparse
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import json
from tqdm.auto import tqdm

# =========================
# Constant Configuration
# =========================
# Domain ids as produced by data_utils: 0 = herbarium, 1 = photo.
HERBARIUM_DOMAIN = 0
PHOTO_DOMAIN = 1
EMBED_DIM = 512  # output dimension of the projection head

EPOCHS = 30
BATCH_SIZE = 32
WEIGHT_DECAY = 1e-4

# Learning rates. Each name is defined exactly once. All three are currently
# the same value (1e-6); raise LR_HEAD if the randomly initialised projection
# head should learn faster than the pretrained backbone.
LR_HEAD = 1e-6          # projection head
LR_BACKBONE = 1e-6      # single backbone group (two-group optimizer cell)
LR_BACKBONE_MAX = 1e-6  # topmost backbone block (layer-decay optimizer cell)
LAYER_DECAY = 0.8       # each block gets 80% of the LR of the block above it
# =========================

# Project root is taken to be two directory levels above the current working
# directory, so this depends on where the notebook is launched from.
PROJECT_ROOT = Path.cwd().parents[1]
COMMON_DIR = PROJECT_ROOT / "common"

# Created eagerly (including parents) so later saves cannot fail on a missing folder.
CHECKPOINT_DIR = PROJECT_ROOT / "experiments" / "triplet_dann" / "checkpoints"
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

# Make the shared `common/` modules (config, data_utils) importable.
if str(COMMON_DIR) not in sys.path:
    sys.path.append(str(COMMON_DIR))

from config import DATA_ROOT
import data_utils
# Reload so edits to data_utils are picked up without restarting the kernel.
importlib.reload(data_utils)
from data_utils import build_train_dataset, build_test_dataset, get_with_without_label_sets

train_ds = build_train_dataset()
test_ds = build_test_dataset()
with_set, without_set = get_with_without_label_sets()

# Bare expression: shown as the cell output in a notebook, a no-op in a script.
len(train_ds), len(test_ds), len(with_set), len(without_set), DATA_ROOT



# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bare expression: shown as the cell output in a notebook.
device

# Data Preprocessing

In [None]:

# Visual sanity check: display one training sample with its label and domain.
sample = train_ds[10]
img = sample["image"]
label = sample["label"]
domain = sample["domain"]

# Undo the input normalisation per channel and clip to the displayable range.
# NOTE(review): assumes a 3-channel CHW tensor normalised with the standard
# ImageNet mean/std (the original code used the red-channel values 0.485 /
# 0.229 for every channel) - confirm against the transforms in data_utils.
_vis_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
_vis_std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
_vis_img = img.permute(1, 2, 0).numpy() * _vis_std + _vis_mean

plt.imshow(np.clip(_vis_img, 0.0, 1.0))
plt.title(f"label={label}, domain={domain}")
plt.axis("off")

In [None]:
# Cell 4: build the train/test datasets and the with/without-pair label sets.
# `config` and `data_utils` are imported as top-level modules, so the folder
# that contains them must already be on sys.path when this cell runs.
from config import NUM_CLASSES, PROJECT_ROOT
from data_utils import (
    build_test_dataset,
    build_train_dataset,
    get_with_without_label_sets,
)

print("PROJECT_ROOT:", PROJECT_ROOT)

# Datasets used by every later cell.
train_dataset = build_train_dataset()
test_dataset = build_test_dataset()
print(f"Train samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Label sets for classes with / without herbarium-photo pairs.
with_set, without_set = get_with_without_label_sets()
print(f"Classes with pairs    : {len(with_set)}")
print(f"Classes without pairs : {len(without_set)}")

# Distinct labels that actually occur in each split.
all_train_labels = {entry["label"] for entry in train_dataset.samples}
all_test_labels = {entry["label"] for entry in test_dataset.samples}
print("Distinct train labels:", len(all_train_labels))
print("Distinct test labels :", len(all_test_labels))

# Sanity check: every label seen in the data should belong to one of the two sets.
missing_from_sets = (all_train_labels | all_test_labels) - (with_set | without_set)
if not missing_from_sets:
    print("All labels covered by with/without splits.")
else:
    print("[Warning] Some labels not in with/without sets:", sorted(missing_from_sets))


In [None]:
# Cell 5: TripletDataset for cross-domain metric learning
class TripletDataset(Dataset):
    """Dataset that samples random cross-domain triplets on the fly.

    Each item is a freshly drawn triplet:
      - anchor:   a herbarium OR photo sample of some class
      - positive: same class, the other domain
      - negative: a different class, either domain

    Only classes that have at least one sample in BOTH domains are used
    (a subset of ``with_labels``). Sampling uses Python's ``random`` module,
    so seed it for reproducible triplets.
    """

    def __init__(self, base_dataset, with_labels):
        """Index the base dataset by label and domain.

        Args:
            base_dataset: Dataset exposing ``.samples`` (a sequence of dicts
                with ``"label"`` and ``"domain"`` keys) and returning dicts
                with ``"image"``, ``"label"`` and ``"domain"`` from
                ``__getitem__``.
            with_labels: Iterable of label indices expected to have
                herbarium-photo pairs.

        Raises:
            ValueError: If fewer than two classes have samples in both
                domains, because no negative class could ever be drawn.
        """
        self.base_dataset = base_dataset
        # A set gives O(1) membership tests while scanning every sample.
        candidate_labels = set(with_labels)

        # label -> {HERBARIUM_DOMAIN: [sample idxs], PHOTO_DOMAIN: [sample idxs]}
        self.label_domain_index = {}
        for idx, s in enumerate(self.base_dataset.samples):
            lbl = s["label"]
            if lbl not in candidate_labels:
                continue
            if lbl not in self.label_domain_index:
                self.label_domain_index[lbl] = {
                    HERBARIUM_DOMAIN: [],
                    PHOTO_DOMAIN: [],
                }
            self.label_domain_index[lbl][s["domain"]].append(idx)

        # Keep only labels that really have both domains (the label list may
        # disagree with the data actually present).
        self.with_labels = [
            lbl
            for lbl in sorted(candidate_labels)
            if lbl in self.label_domain_index
            and self.label_domain_index[lbl][HERBARIUM_DOMAIN]
            and self.label_domain_index[lbl][PHOTO_DOMAIN]
        ]
        print(f"[TripletDataset] Usable with-pair classes: {len(self.with_labels)}")

        if len(self.with_labels) < 2:
            raise ValueError(
                "TripletDataset needs at least 2 classes with both herbarium and "
                f"photo samples, found {len(self.with_labels)}."
            )

    def __len__(self):
        # Arbitrary epoch size: triplets are random, so the base length is
        # only used to decide how many are drawn per epoch.
        return len(self.base_dataset)

    def _sample_cross_pair(self):
        """Draw one triplet.

        Returns:
            Tuple ``(anchor_idx, pos_idx, neg_idx, lbl, neg_lbl)`` where the
            indices point into ``base_dataset``.
        """
        n_classes = len(self.with_labels)

        # Anchor class, and a different negative class drawn uniformly from
        # the remaining ones without building a filtered list each time.
        lbl_pos = random.randrange(n_classes)
        neg_pos = random.randrange(n_classes - 1)
        if neg_pos >= lbl_pos:
            neg_pos += 1
        lbl = self.with_labels[lbl_pos]
        neg_lbl = self.with_labels[neg_pos]

        # Randomly pick which domain is the anchor; the positive is the other.
        dom_anchor = random.choice([HERBARIUM_DOMAIN, PHOTO_DOMAIN])
        dom_pos = PHOTO_DOMAIN if dom_anchor == HERBARIUM_DOMAIN else HERBARIUM_DOMAIN

        anchor_idx = random.choice(self.label_domain_index[lbl][dom_anchor])
        pos_idx = random.choice(self.label_domain_index[lbl][dom_pos])

        # Every usable class has samples in both domains, so either domain
        # can serve for the negative.
        neg_dom = random.choice([HERBARIUM_DOMAIN, PHOTO_DOMAIN])
        neg_idx = random.choice(self.label_domain_index[neg_lbl][neg_dom])

        return anchor_idx, pos_idx, neg_idx, lbl, neg_lbl

    def __getitem__(self, idx):
        # idx is ignored; a fresh random triplet is generated on every call.
        anchor_idx, pos_idx, neg_idx, _, _ = self._sample_cross_pair()

        a = self.base_dataset[anchor_idx]
        p = self.base_dataset[pos_idx]
        n = self.base_dataset[neg_idx]

        return {
            "anchor": a["image"],
            "positive": p["image"],
            "negative": n["image"],
            "anchor_label": a["label"],
            "positive_label": p["label"],
            "negative_label": n["label"],
            "anchor_domain": a["domain"],
            "positive_domain": p["domain"],
            "negative_domain": n["domain"],
        }

# Instantiate TripletDataset + DataLoader

triplet_dataset = TripletDataset(train_dataset, with_set)
# NOTE: TripletDataset ignores the requested index and samples at random, so
# `shuffle=True` has no effect on which triplets are drawn.
# drop_last=True keeps every batch at exactly BATCH_SIZE triplets.
triplet_loader = DataLoader(
    triplet_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    drop_last=True,
)

# Smoke test: pull one batch and show its keys and the anchor tensor shape.
batch = next(iter(triplet_loader))
print("Triplet batch keys:", batch.keys())
print("Triplet batch anchor tensor shape:", batch["anchor"].shape)


# Architecture Setup

In [None]:
# Cell 6 – Model definition (DINOv2 backbone)
class TripletEncoder(nn.Module):
    """ViT backbone plus MLP projection head producing L2-normalised embeddings.

    The backbone is always the timm model
    ``vit_base_patch14_reg4_dinov2.lvd142m``, initialised from the local
    checkpoint ``PROJECT_ROOT / "model_best.pth.tar"`` (non-strict load).

    Args:
        embed_dim: Output dimension of the projection head.
        backbone_type: Label stored lower-cased in ``self.backbone_type``. It
            does NOT select the architecture; it is only used to gate
            ``mode="last_k"`` in :meth:`set_backbone_trainable` and in error
            messages.
        pretrained: Accepted for backward compatibility; currently unused
            (weights always come from the local checkpoint).
        freeze_backbone: If True, all backbone parameters start frozen.
        proj_hidden_dim: Width of the hidden layers of the projection head.
        proj_layers: Number of ``nn.Linear`` layers in the projection head;
            every layer but the last is followed by BatchNorm1d + ReLU
            (+ optional Dropout).
        dropout_p: Dropout probability in the hidden layers (0 disables it).

    Raises:
        RuntimeError: If the backbone exposes neither ``num_features`` nor
            ``embed_dim``.
    """

    def __init__(
        self,
        embed_dim: int,
        backbone_type: str = "dinov2_vitb14",
        pretrained: bool = True,
        freeze_backbone: bool = False,
        proj_hidden_dim: int = 1024,
        proj_layers: int = 2,
        dropout_p: float = 0.0,
    ):
        super().__init__()

        self.backbone_type = backbone_type.lower()

        # Load pretrained backbone weights from the external checkpoint.
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects, so
        # only load checkpoints from a trusted source. The safe-globals
        # registration below only matters for weights_only=True loads.
        torch.serialization.add_safe_globals([argparse.Namespace])
        ckpt = torch.load(
            PROJECT_ROOT / "model_best.pth.tar",
            map_location="cpu",
            weights_only=False
        )

        state = ckpt["state_dict"]

        backbone = timm.create_model(
            "vit_base_patch14_reg4_dinov2.lvd142m",
            pretrained=False,
            num_classes=0  # remove original classifier
        )

        # Non-strict load: mismatching keys are reported, not fatal. Inspect
        # the printed lists to make sure the weights were really applied.
        missing, unexpected = backbone.load_state_dict(state, strict=False)
        print("Missing keys:", missing)
        print("Unexpected keys:", unexpected)

        self.backbone = backbone

        # Feature dimension of the backbone output (input of the head).
        feat_dim = getattr(self.backbone, "num_features", None)
        if feat_dim is None:
            feat_dim = getattr(self.backbone, "embed_dim", None)
        if feat_dim is None:
            raise RuntimeError(
                f"Backbone {self.backbone_type} has no 'embed_dim'/'num_features' attribute; "
                "inspect the model to get the feature dimension."
            )

        # ---- optional freezing of entire backbone ----
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        # ---- projection head (feat_dim -> embed_dim) ----
        proj_layers_list = []
        in_dim = feat_dim
        for _ in range(proj_layers - 1):
            proj_layers_list.append(nn.Linear(in_dim, proj_hidden_dim))
            proj_layers_list.append(nn.BatchNorm1d(proj_hidden_dim))
            proj_layers_list.append(nn.ReLU(inplace=True))
            if dropout_p > 0:
                proj_layers_list.append(nn.Dropout(dropout_p))
            in_dim = proj_hidden_dim
        proj_layers_list.append(nn.Linear(in_dim, embed_dim))

        self.proj_head = nn.Sequential(*proj_layers_list)

    # ---------- fine-tuning control ----------
    def set_backbone_trainable(self, mode: str = "all", last_k: int = 2):
        """Control which backbone parameters receive gradients.

        The projection head is never touched by this method.

        Args:
            mode: One of
                - ``"all"``    : fine-tune the entire backbone
                - ``"none"``   : freeze the entire backbone
                - ``"last_k"`` : unfreeze only the last ``last_k`` entries of
                  ``backbone.blocks`` (requires ``backbone_type`` to start
                  with ``"dinov2"``)
            last_k: Number of trailing blocks to unfreeze in ``"last_k"``
                mode. Values larger than the number of blocks are clamped;
                ``0`` leaves the whole backbone frozen.

        Raises:
            ValueError: On an unknown ``mode`` or a negative ``last_k``.
            NotImplementedError: ``"last_k"`` with a non-DINOv2 backbone type.
            RuntimeError: ``"last_k"`` when the backbone has no ``.blocks``.
        """
        mode = mode.lower()
        if mode not in ("all", "none", "last_k"):
            raise ValueError(f"Invalid mode: {mode}")
        if mode == "last_k" and last_k < 0:
            raise ValueError(f"last_k must be >= 0, got {last_k}")

        # "all" enables everything; the other modes start fully frozen.
        train_everything = mode == "all"
        for p in self.backbone.parameters():
            p.requires_grad = train_everything

        if mode != "last_k":
            return

        if not self.backbone_type.startswith("dinov2"):
            raise NotImplementedError("mode='last_k' only implemented for DINOv2 in this notebook.")

        blocks = getattr(self.backbone, "blocks", None)
        if blocks is None:
            raise RuntimeError(
                "DINOv2 backbone has no .blocks attribute; inspect the model to adapt this."
            )

        last_k = min(last_k, len(blocks))
        if last_k == 0:
            # blocks[-0:] would be the FULL slice and unfreeze every block.
            return
        for blk in blocks[-last_k:]:
            for p in blk.parameters():
                p.requires_grad = True

    def forward_backbone(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw backbone output for ``x`` (before the projection head)."""
        return self.backbone(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return projection-head embeddings, L2-normalised along the last dim."""
        feats = self.forward_backbone(x)
        z = self.proj_head(feats)
        z = F.normalize(z, p=2, dim=-1)
        return z

# Instantiate with DINOv2 backbone
# NOTE: `backbone_type` is only a label here and `pretrained` is unused by
# TripletEncoder.__init__; the backbone architecture and its weights are
# fixed inside the class (timm ViT + local checkpoint).
model = TripletEncoder(
    embed_dim=EMBED_DIM,
    backbone_type="dinov2_vitb14",
    pretrained=True,
    freeze_backbone=False,  # we'll control fine-tuning explicitly below
).to(device)



print("Using device:", device)
print("Backbone:", model.backbone_type)

## Fine-tuning selection

In [None]:



# 1. Unfreeze only the last 5 backbone blocks (everything else stays frozen)
model.set_backbone_trainable(mode="last_k", last_k=5)

# 2. Build parameter groups with layer-wise LR decay
param_groups = []

# Group A: the projection head (uses LR_HEAD)
param_groups.append({
    "params": [p for p in model.proj_head.parameters() if p.requires_grad],
    "lr": LR_HEAD,
    "weight_decay": WEIGHT_DECAY
})

# Group B: the backbone blocks (decaying LR)
# Blocks are visited from the last one to the first; the last block gets
# LR_BACKBONE_MAX and each earlier trainable block gets LAYER_DECAY times less.
current_lr = LR_BACKBONE_MAX

if hasattr(model.backbone, "blocks"):
    # All blocks are visited; frozen ones are skipped by the filter below
    for block in reversed(model.backbone.blocks):
        # Only add if this block was unfrozen
        block_params = [p for p in block.parameters() if p.requires_grad]
        if block_params:
            print(f"Block LR: {current_lr:.2e}") # Debug print
            param_groups.append({
                "params": block_params,
                "lr": current_lr,
                "weight_decay": WEIGHT_DECAY
            })
            current_lr *= LAYER_DECAY  # Decay for the next block down
            
# Parameter counts, for logging only
backbone_params = [p for p in model.backbone.parameters() if p.requires_grad]
head_params = [p for p in model.proj_head.parameters() if p.requires_grad]

print(f"Trainable backbone params: {sum(p.numel() for p in backbone_params):,}")
print(f"Trainable head params    : {sum(p.numel() for p in head_params):,}")

# 3. Create optimizer
optimizer = optim.AdamW(param_groups)
# Cosine LR schedule over EPOCHS, applied to every parameter group
# NOTE: this rebinds the module-level EPOCHS constant (same value, 30).
EPOCHS = 30
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

In [None]:
# Cell 7 – Optimizer with separate LR for backbone vs projection head
# NOTE: running this cell rebinds `optimizer` and `scheduler`, replacing any
# created by an earlier cell, and re-applies the freezing pattern below.



# choose fine-tuning mode here:
#   "all"   -> full DINO fine-tuning
#   "none"  -> freeze DINO, train only projection
#   "last_k" with last_k blocks
model.set_backbone_trainable(mode="last_k", last_k=9)


# Only parameters that currently require gradients are handed to the optimizer.
backbone_params = [p for p in model.backbone.parameters() if p.requires_grad]
head_params = [p for p in model.proj_head.parameters() if p.requires_grad]

print(f"Trainable backbone params: {sum(p.numel() for p in backbone_params):,}")
print(f"Trainable head params    : {sum(p.numel() for p in head_params):,}")

# Two parameter groups with their own LR; weight decay is shared by both.
optimizer = optim.AdamW(
    [
        {"params": backbone_params, "lr": LR_BACKBONE},
        {"params": head_params, "lr": LR_HEAD},
    ],
    weight_decay=WEIGHT_DECAY,
)

# optional LR scheduler, same for both groups
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)


# Model Training

In [None]:
# Cell 8 (Modified): Resume Training from Checkpoint

# --- Setup ---
SEED = 42
torch.manual_seed(SEED)
if device.type == "cuda":
    torch.cuda.manual_seed_all(SEED)

margin = 0.2
triplet_loss_fn = nn.TripletMarginLoss(margin=margin, p=2)
scaler = GradScaler(enabled=(device.type == "cuda"))

CKPT_DIR = PROJECT_ROOT / "experiments" / "2_stream" / "checkpoints_k9"
CKPT_DIR.mkdir(exist_ok=True, parents=True)

# ---------------------------------------------------------
# RESUME LOGIC
# ---------------------------------------------------------
RESUME_FROM = "epoch_1.pt"  # The filename you want to load
resume_path = CKPT_DIR / RESUME_FROM

START_EPOCH = 0
best_without_acc = 0.0  # best unpaired top-1 so far (primary criterion)
best_acc = 0.0          # paired top-1 of that best model (tie-breaker)
history = []

if resume_path.exists():
    print(f"--> Loading checkpoint: {resume_path}")
    # The checkpoint is a plain state_dict, so the restricted unpickler suffices.
    state_dict = torch.load(resume_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)

    # Extract epoch number from filename (e.g., "epoch_1.pt" -> 1)
    # This assumes the file format is exactly "epoch_{int}.pt"
    try:
        START_EPOCH = int(RESUME_FROM.split("_")[1].split(".")[0])
    except (ValueError, IndexError):
        print("Could not parse epoch from filename, starting from next step manually.")
        START_EPOCH = 1  # Manual override if parsing fails
    else:
        print(f"--> Resuming start from Epoch {START_EPOCH + 1} (Index {START_EPOCH})")

    # Optimizer state was not saved, so the optimizer starts fresh; the
    # scheduler is advanced so the LR matches the resumed epoch (also in the
    # fallback case above, which previously left the schedule at step 0).
    for _ in range(START_EPOCH):
        scheduler.step()

    # Optional: load history so the "best model" criteria stay correct.
    history_path = CKPT_DIR / "history.json"
    if history_path.exists():
        with open(history_path, "r") as f:
            history = json.load(f)
        # Restore BOTH criteria used by the training loop (higher unpaired
        # top-1 wins, ties broken by paired top-1) so best_model.pt is not
        # overwritten by a worse result.
        for record in history:
            candidate = (record["unpaired_top1"], record.get("paired_top1", 0.0))
            if candidate > (best_without_acc, best_acc):
                best_without_acc, best_acc = candidate
        print(f"--> History loaded. Best previous Unpaired Top-1: {best_without_acc:.4%}")
else:
    print(f"Checkpoint {resume_path} not found. Starting from scratch.")

# TripletDataset draws triplets with Python's `random`, which the torch seeds
# above do not cover. Offsetting by START_EPOCH keeps a resumed run from
# replaying the triplet sequence of the very first epochs.
random.seed(SEED + START_EPOCH)
np.random.seed(SEED + START_EPOCH)

# ---------------------------------------------------------
# HELPER FUNCTIONS (prototype building and evaluation)
# ---------------------------------------------------------
def build_prototypes(model, loader):
    """Compute one L2-normalised mean embedding per class from herbarium samples.

    Puts ``model`` in eval mode (it is NOT switched back) and iterates
    ``loader`` without gradients. Only samples whose ``"domain"`` equals
    ``HERBARIUM_DOMAIN`` contribute.

    Args:
        model: Callable mapping an image batch to embeddings with
            ``EMBED_DIM`` features.
        loader: Iterable of dict batches with ``"image"``, ``"label"`` and
            ``"domain"`` tensors.

    Returns:
        Tuple ``(prototypes, proto_count)``: ``prototypes`` has shape
        ``(NUM_CLASSES, EMBED_DIM)``, with all-zero rows for classes that had
        no herbarium sample; ``proto_count`` holds the number of herbarium
        samples seen per class.
    """
    model.eval()
    proto_sum = torch.zeros(NUM_CLASSES, EMBED_DIM, device=device)
    proto_count = torch.zeros(NUM_CLASSES, dtype=torch.long, device=device)

    with torch.no_grad():
        for batch in loader:
            images = batch["image"].to(device, non_blocking=True)
            labels = batch["label"].to(device, non_blocking=True)
            domains = batch["domain"].to(device, non_blocking=True)

            mask = domains == HERBARIUM_DOMAIN  # herbarium only
            if not mask.any():
                continue

            emb = model(images[mask]).to(proto_sum.dtype)
            lbls = labels[mask].long()

            # Scatter-add per class in one call instead of a Python loop
            # over individual samples.
            proto_sum.index_add_(0, lbls, emb)
            proto_count.index_add_(0, lbls, torch.ones_like(lbls))

    # Mean per class, then re-normalise. Clamping the divisor keeps classes
    # without samples at exactly zero (F.normalize maps a zero row to zero).
    denom = proto_count.clamp(min=1).unsqueeze(1).to(proto_sum.dtype)
    prototypes = F.normalize(proto_sum / denom, p=2, dim=-1)

    return prototypes, proto_count

def run_eval(model, loader, prototypes, proto_count):
    """Nearest-prototype evaluation with top-1 / top-5 accuracy.

    Puts ``model`` in eval mode, scores every batch of ``loader`` against
    ``prototypes`` by dot product and ranks the classes. Classes whose
    ``proto_count`` is zero are excluded from the ranking by forcing their
    score to -1e9.

    Args:
        model: Callable mapping an image batch to embeddings.
        loader: Iterable of dict batches with ``"image"`` and ``"label"``.
        prototypes: Tensor with one row per class.
        proto_count: Per-class number of samples behind each prototype.

    Returns:
        Dict with ``overall_*``, ``paired_*`` and ``unpaired_*`` top-1/top-5
        accuracies (0.0 when a group is empty) and a ``"counts"`` sub-dict.
        "paired" means the label is in ``with_set``; "unpaired" means it is
        in ``without_set``.
    """
    model.eval()
    top_k = min(5, prototypes.shape[0])
    has_prototype = proto_count > 0

    # tallies[group] = [n_samples, n_top1_hits, n_top5_hits]
    tallies = {
        "overall": [0, 0, 0],
        "paired": [0, 0, 0],
        "unpaired": [0, 0, 0],
    }

    def _tally(group, hit1, hit5):
        row = tallies[group]
        row[0] += 1
        row[1] += int(hit1)
        row[2] += int(hit5)

    with torch.no_grad():
        for batch in loader:
            imgs = batch["image"].to(device, non_blocking=True)
            labels = batch["label"].to(device, non_blocking=True)

            sims = model(imgs) @ prototypes.T
            sims[:, ~has_prototype] = -1e9
            _, ranked = sims.topk(k=top_k, dim=1)

            for lbl, candidates in zip(labels.cpu().tolist(), ranked.cpu().tolist()):
                hit1 = candidates[0] == lbl
                hit5 = lbl in candidates

                _tally("overall", hit1, hit5)
                if lbl in with_set:
                    _tally("paired", hit1, hit5)
                elif lbl in without_set:
                    _tally("unpaired", hit1, hit5)

    def _rate(group, column):
        # Fraction of hits in `column`; 0.0 for an empty group.
        total = tallies[group][0]
        if total > 0:
            return float(tallies[group][column]) / float(total)
        return 0.0

    return {
        "overall_top1": _rate("overall", 1),
        "overall_top5": _rate("overall", 2),
        "paired_top1":  _rate("paired", 1),
        "paired_top5":  _rate("paired", 2),
        "unpaired_top1": _rate("unpaired", 1),
        "unpaired_top5": _rate("unpaired", 2),
        "counts": {
            "overall": tallies["overall"][0],
            "paired": tallies["paired"][0],
            "unpaired": tallies["unpaired"][0],
        },
    }

# --- Loaders ---
# Unshuffled loaders: train set for building prototypes, test set for evaluation.
proto_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
eval_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"Starting Training from Epoch {START_EPOCH + 1} to {EPOCHS}...")

# --- Main Loop Modified Range ---
# `epoch` is 0-based; everything reported or saved uses epoch + 1.
for epoch in range(START_EPOCH, EPOCHS):
    # 1. Train
    model.train()
    running_loss = 0.0
    count = 0
    
    pbar = tqdm(triplet_loader, desc=f"Ep {epoch+1}/{EPOCHS}", leave=False)
    for batch in pbar:
        optimizer.zero_grad()
        
        anc = batch["anchor"].to(device)
        pos = batch["positive"].to(device)
        neg = batch["negative"].to(device)
        
        # Mixed precision is only enabled on CUDA; on CPU this runs in full precision.
        with autocast(enabled=(device.type == "cuda")):
            # Three separate forward passes: anchor, positive and negative
            # each go through the model as their own batch.
            ea = model(anc)
            ep = model(pos)
            en = model(neg)
            loss = triplet_loss_fn(ea, ep, en)
            
        # GradScaler sequence: scale the loss for backward, step the
        # optimizer through the scaler, then update the scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        running_loss += loss.item()
        count += 1
        pbar.set_postfix({"loss": f"{running_loss/count:.4f}"})
    
    # Mean loss over the batches of this epoch.
    # NOTE(review): raises ZeroDivisionError if triplet_loader yields no batches.
    epoch_loss = running_loss / count
    
    # 2. Build Prototypes
    # (build_prototypes switches the model to eval mode; model.train() at the
    # top of the loop restores training mode for the next epoch.)
    prototypes, proto_count = build_prototypes(model, proto_loader)
    
    # 3. Evaluate
    eval_metrics = run_eval(model, eval_loader, prototypes, proto_count)
    
    # 4. Record History
    record = {
        "epoch": epoch + 1,
        "train_loss": float(epoch_loss),
        "overall_top1":  eval_metrics["overall_top1"],
        "overall_top5":  eval_metrics["overall_top5"],
        "paired_top1":   eval_metrics["paired_top1"],
        "paired_top5":   eval_metrics["paired_top5"],
        "unpaired_top1": eval_metrics["unpaired_top1"],
        "unpaired_top5": eval_metrics["unpaired_top5"],
        "counts": eval_metrics["counts"],
    }
    history.append(record)
    
    print(
        f"[Epoch {epoch+1}] "
        f"Loss: {epoch_loss:.4f} | "
        f"Ov T1: {eval_metrics['overall_top1']:.4%} | "
        f"Ov T5: {eval_metrics['overall_top5']:.4%} | "
        f"Paired T1: {eval_metrics['paired_top1']:.4%} | "
        f"Paired T5: {eval_metrics['paired_top5']:.4%} | "
        f"Unpaired T1: {eval_metrics['unpaired_top1']:.4%} | "
        f"Unpaired T5: {eval_metrics['unpaired_top5']:.4%} "
    )
    
    # 5. Save checkpoint
    # Only the model weights are saved (no optimizer / scheduler / scaler state).
    torch.save(model.state_dict(), CKPT_DIR / f"epoch_{epoch+1}.pt")
    
    # Track best model
    # Primary criterion: unpaired top-1. Tie-breaker: paired top-1.
    current_without_acc = eval_metrics["unpaired_top1"]
    current_with_acc = eval_metrics["paired_top1"]
    
    should_save_best = False
    if current_without_acc > best_without_acc:
        # Better unpaired accuracy
        best_without_acc = current_without_acc
        best_acc = current_with_acc
        should_save_best = True
        print(f"--> New Best Model Saved! (Unpaired Top-1 = {best_without_acc:.4%})")
    elif current_without_acc == best_without_acc and current_with_acc > best_acc:
        # Same unpaired accuracy but better paired accuracy
        best_acc = current_with_acc
        should_save_best = True
        print(f"--> New Best Model Saved! (Unpaired Top-1 = {best_without_acc:.4%}, Paired Top-1 = {best_acc:.4%})")
    
    if should_save_best:
        torch.save(model.state_dict(), CKPT_DIR / "best_model.pt")
        
    # Save History
    # Rewritten in full every epoch so an interrupted run keeps its records.
    with open(CKPT_DIR / "history.json", "w") as f:
        json.dump(history, f, indent=4)

    # Advance the LR schedule once per epoch.
    scheduler.step()

print("Training Complete.")

# Model Evaluation

In [None]:
# Cell 8: Build herbarium prototypes for all classes
# The prototypes depend on the current model weights; rebuild them whenever
# different weights are loaded into `model`.

proto_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

model.eval()

# Accumulate sum and counts per class
proto_sum = torch.zeros(NUM_CLASSES, EMBED_DIM, device=device)
proto_count = torch.zeros(NUM_CLASSES, dtype=torch.long, device=device)

with torch.no_grad():
    pbar = tqdm(proto_loader, desc="Building prototypes", leave=False)
    for batch in pbar:
        images = batch["image"].to(device, non_blocking=True)
        labels = batch["label"].to(device, non_blocking=True)
        domains = batch["domain"].to(device, non_blocking=True)

        # Only herbarium samples
        mask = (domains == HERBARIUM_DOMAIN)
        if not mask.any():
            continue

        imgs_h = images[mask]
        lbls_h = labels[mask].long()

        emb = model(imgs_h).to(proto_sum.dtype)  # already normalized
        # Scatter-add per class in one call instead of a per-sample loop.
        proto_sum.index_add_(0, lbls_h, emb)
        proto_count.index_add_(0, lbls_h, torch.ones_like(lbls_h))

# Compute mean and re-normalize per class. Classes without herbarium samples
# keep an all-zero prototype (clamped divisor; F.normalize maps zero to zero)
# and are detected later through proto_count / the prototype norm.
prototypes = F.normalize(
    proto_sum / proto_count.clamp(min=1).unsqueeze(1).to(proto_sum.dtype),
    p=2,
    dim=-1,
)

missing_proto = (proto_count == 0).nonzero(as_tuple=True)[0].tolist()
if missing_proto:
    print("[Warning] Classes with no herbarium prototypes:", missing_proto)
else:
    print("All classes have herbarium prototypes.")

# Optionally save prototypes to disk.
# A dedicated filename is used: the previous target, "epoch_4.pt", is the name
# the training loop gives to the epoch-4 MODEL checkpoint in the same folder,
# so saving prototypes there overwrote that checkpoint.
proto_path = CKPT_DIR / "prototypes.pt"
torch.save(
    {
        "prototypes": prototypes.cpu(),
        "proto_count": proto_count.cpu(),
        "embed_dim": EMBED_DIM,
    },
    proto_path,
)
print("Saved prototypes to:", proto_path)


In [None]:
# Load checkpoint for evaluation
CKPT_DIR_TEST = PROJECT_ROOT / "experiments" / "2_stream" / "checkpoints_k5"
checkpoint_path = CKPT_DIR_TEST / "epoch_25.pt"

if checkpoint_path.exists():
    # The file is a plain state_dict, so the restricted unpickler suffices.
    model.load_state_dict(
        torch.load(checkpoint_path, map_location=device, weights_only=True)
    )
    print(f"✓ Loaded checkpoint: {checkpoint_path}")
    # Prototypes are embeddings produced by the model, so any `prototypes`
    # tensor computed before this load belongs to the previous weights.
    print("  Note: rebuild the herbarium prototypes with these weights before evaluating.")
else:
    print(f"⚠ Warning: Checkpoint not found at {checkpoint_path}")
    print("  Using current model state in memory")

model.eval()

In [None]:
# Cell 9: Evaluation (Top-1 / Top-5, overall + with/without-pair)
# NOTE(review): `prototypes` must have been built with the weights currently
# in `model`; if a checkpoint was loaded after they were computed, rebuild
# them first.

eval_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

model.eval()
prototypes = prototypes.to(device)
# Prototypes that are zero vectors (no herbarium samples) have norm 0; the
# mask marks the usable rows so the others can be excluded when ranking.
proto_norms = prototypes.norm(dim=1)
valid_proto_mask = (proto_norms > 0.0)

def evaluate_topk(k=5):
    """Evaluate nearest-prototype classification on ``eval_loader``.

    Uses the module-level ``model``, ``prototypes``, ``valid_proto_mask``,
    ``with_set`` and ``without_set``. Classes without a valid prototype are
    excluded from the ranking by forcing their score to -1e9.

    Args:
        k: Number of top-ranked classes counted as a "top-k" hit. Clamped to
            the number of prototype rows, like the in-training evaluation.

    Returns:
        Dict with ``"overall"``, ``"with"`` and ``"without"`` entries, each
        holding ``"total"``, ``"top1_correct"`` and ``"topk_correct"`` counts.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    # topk() raises when k exceeds the number of classes.
    k = min(k, prototypes.shape[0])

    total = 0
    correct_top1 = 0
    correct_topk = 0

    total_with = 0
    correct_top1_with = 0
    correct_topk_with = 0

    total_without = 0
    correct_top1_without = 0
    correct_topk_without = 0

    with torch.no_grad():
        pbar = tqdm(eval_loader, desc="Evaluating", leave=False)
        for batch in pbar:
            imgs = batch["image"].to(device, non_blocking=True)
            labels = batch["label"].to(device, non_blocking=True)

            emb = model(imgs)  # [B, D], normalized

            # Dot product == cosine similarity for unit-norm vectors.
            sims = emb @ prototypes.T  # [B, NUM_CLASSES]

            # Push classes with invalid prototypes to the bottom of the ranking.
            sims[:, ~valid_proto_mask] = -1e9

            _, topk_idx = sims.topk(k=k, dim=1)  # [B, k]
            pred_top1 = topk_idx[:, 0]

            # Overall counts, computed on-device.
            total += labels.size(0)
            correct_top1 += (pred_top1 == labels).sum().item()
            correct_topk += (topk_idx == labels.unsqueeze(1)).any(dim=1).sum().item()

            # Split by with/without-pair (needs Python-side set membership).
            labels_cpu = labels.cpu().tolist()
            pred_top1_cpu = pred_top1.cpu().tolist()
            topk_idx_cpu = topk_idx.cpu().tolist()

            for lbl, p1, pk_list in zip(labels_cpu, pred_top1_cpu, topk_idx_cpu):
                hit1 = p1 == lbl
                hitk = lbl in pk_list

                if lbl in with_set:
                    total_with += 1
                    correct_top1_with += int(hit1)
                    correct_topk_with += int(hitk)
                elif lbl in without_set:
                    total_without += 1
                    correct_top1_without += int(hit1)
                    correct_topk_without += int(hitk)

    results = {
        "overall": {
            "total": total,
            "top1_correct": correct_top1,
            "topk_correct": correct_topk,
        },
        "with": {
            "total": total_with,
            "top1_correct": correct_top1_with,
            "topk_correct": correct_topk_with,
        },
        "without": {
            "total": total_without,
            "top1_correct": correct_top1_without,
            "topk_correct": correct_topk_without,
        },
    }
    return results

# Two full passes over the test set: one with k=1 and one with k=5.
# (Each result also contains the top-1 counts, so the k=5 pass alone already
# carries both numbers.)
results_k1 = evaluate_topk(k=1)
results_k5 = evaluate_topk(k=5)

def print_results(name, res):
    """Print the top-1 and top-k accuracy of one result group.

    Args:
        name: Label printed in front of the line.
        res: Dict with ``"total"``, ``"top1_correct"`` and ``"topk_correct"``
            counts. When ``"total"`` is 0 only a "no samples" line is printed.
    """
    n_samples = res["total"]
    if n_samples != 0:
        top1_hits = res["top1_correct"]
        topk_hits = res["topk_correct"]
        top1_part = f"Top-1 = {top1_hits}/{n_samples} = {top1_hits/n_samples:.4%}"
        topk_part = f"Top-k = {topk_hits}/{n_samples} = {topk_hits/n_samples:.4%}"
        print(f"{name}: {top1_part}, {topk_part}")
    else:
        print(f"{name}: no samples.")

# Report both evaluation passes, each split into overall / with-pair /
# without-pair groups.
_report_groups = (
    ("Overall", "overall"),
    ("With-pair", "with"),
    ("Without-pair", "without"),
)
_report_sections = (
    ("=== Top-1 (k=1) ===", results_k1),
    ("\n=== Top-5 (k=5) ===", results_k5),
)
for _header, _section in _report_sections:
    print(_header)
    for _title, _key in _report_groups:
        print_results(_title, _section[_key])


In [None]:


# The number of classes is taken from the confusion matrix (cm.shape[0]); NUM_CLASSES is not read here.

def calculate_metrics_and_print(cm, all_labels, all_preds, best_idx, worst_idx):
    """Print macro-averaged metrics and one-vs-rest details for two classes.

    Args:
        cm: Square confusion matrix; rows are summed as actual classes and
            columns as predicted classes (presumably a NumPy array, since
            ``.diagonal()`` and ``.sum(axis=...)`` are used).
        all_labels: Ground-truth labels passed to
            ``precision_recall_fscore_support``.
        all_preds: Predicted labels passed to the same function.
        best_idx: Class index reported under the "BEST" heading.
        worst_idx: Class index reported under the "WORST" heading.
    """
    # Per-class precision / recall / F1 for every class index of the matrix,
    # including classes that never occur (those are scored as 0).
    class_ids = np.arange(cm.shape[0])
    precision, recall, f1, support = precision_recall_fscore_support(
        all_labels, all_preds, labels=class_ids, zero_division=0.0
    )

    # One-vs-rest counts derived from the confusion matrix.
    TP = cm.diagonal()
    FP = cm.sum(axis=0) - TP   # predicted as the class, but wrong
    FN = cm.sum(axis=1) - TP   # belongs to the class, but missed
    TN = cm.sum() - TP - FP - FN

    # Macro averages: unweighted mean over all class indices.
    banner = "=" * 40
    print("\n" + banner)
    print("Aggregate Metrics (Macro-Averaged)")
    print(f"Macro Precision: {precision.mean():.4f}")
    print(f"Macro Recall:    {recall.mean():.4f}")
    print(f"Macro F1-Score:  {f1.mean():.4f}")
    print(banner + "\n")

    # Detailed report for the two requested classes.
    rule = "-" * 40
    for idx, name in ((best_idx, "BEST"), (worst_idx, "WORST")):
        print(f"--- Detailed Metrics for {name} Class (Index {idx}) ---")

        if support[idx] == 0:
            # No ground-truth samples for this class: nothing meaningful to show.
            print("Note: No samples in the test set for this class index.")
            print(rule)
            continue

        # Reported as "Accuracy": correct predictions over the class's samples.
        accuracy = TP[idx] / support[idx]
        print(f"Accuracy:  {accuracy:.4f}")
        print(f"TP (True Positives):  {TP[idx]}")
        print(f"TN (True Negatives):  {TN[idx]}")
        print(f"FP (False Positives): {FP[idx]}")
        print(f"FN (False Negatives): {FN[idx]}")
        print(f"Precision: {precision[idx]:.4f}")
        print(f"Recall:    {recall[idx]:.4f}")
        print(f"F1-Score:  {f1[idx]:.4f}")
        print(rule)


# --- Integration into the original notebook flow ---

# 1. Modify the prediction loop (visualize_performance equivalent)
#    to return all_labels, all_preds, and cm.

# 2. Call the function:
#    best_idx, worst_idx = visualize_performance_with_metrics(...)
#    calculate_metrics_and_print(cm, all_labels, all_preds, best_idx, worst_idx)