# Import Libraries and Setup Constant Configuration

In [1]:
import os
import sys
import time
import random
import json
import argparse
import numpy as np
from pathlib import Path
from collections import defaultdict
from typing import List, Dict, Tuple
from PIL import Image
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
import importlib
from sklearn.metrics import precision_recall_fscore_support
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from torch.cuda import amp
from torchvision import transforms
import torchvision.models as models
import timm


# =========================
# Path configuration
# =========================
# Project root is taken as two levels above the current working directory,
# so the result depends on where the notebook/kernel was started.
PROJECT_ROOT = Path.cwd().parents[1]

DATASET_DIRNAME = "AML_project_herbarium_dataset"
DATA_ROOT = PROJECT_ROOT / DATASET_DIRNAME

TRAIN_DIR = DATA_ROOT / "train"          
TEST_DIR = DATA_ROOT / "test"              

# Text files describing splits, species ids and paired/unpaired classes.
LIST_DIR = DATA_ROOT / "list"
TRAIN_LIST = LIST_DIR / "train.txt"
TEST_LIST = LIST_DIR / "test.txt"
SPECIES_LIST = LIST_DIR / "species_list.txt"
GROUNDTRUTH = LIST_DIR / "groundtruth.txt"
CLASS_WITH_PAIRS = LIST_DIR / "class_with_pairs.txt"
CLASS_WITHOUT_PAIRS = LIST_DIR / "class_without_pairs.txt"

NUM_CLASSES = 100

# Input resolution and the normalisation statistics used by the transforms.
IMAGE_SIZE = 518
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# =========================
# Device Setup
# =========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
print(f"Project Root: {PROJECT_ROOT}")
print(f"Data Root: {DATA_ROOT}")


# =========================
# Constant Configuration 
# =========================
HERBARIUM_DOMAIN = 0  # domain id for herbarium sheets
PHOTO_DOMAIN = 1      # domain id for field photos
EMBED_DIM = 512       # output size of the metric-learning projection head
NUM_WORKERS = 0       # DataLoader workers for the classification loaders
EPOCHS = 30
BATCH_SIZE = 8
USE_AMP = (device.type == "cuda")  # mixed precision only when running on CUDA

LR_BACKBONE_MAX = 1e-6  # LR of the topmost trainable backbone block
LAYER_DECAY = 0.8       # Each lower block gets 80% of the LR of the block above it
LR_BACKBONE = 1e-6      # backbone LR; NOTE(review): not referenced in the visible cells
LR_HEAD = 1e-6          # projection-head LR; NOTE: currently identical to the backbone LR
WEIGHT_DECAY = 1e-4
TRIPLET_BATCH_SIZE = 4   # triplets per batch (each triplet is three images)
TRIPLET_NUM_WORKERS = 0  # DataLoader workers for the triplet loader









  from .autonotebook import tqdm as notebook_tqdm


Device: cuda
Project Root: c:\Users\William\School\Swinburne\Computer Science\2025 Semester 2\COS30082 Applied Machine Learning\Assignment 2\Approach3
Data Root: c:\Users\William\School\Swinburne\Computer Science\2025 Semester 2\COS30082 Applied Machine Learning\Assignment 2\Approach3\AML_project_herbarium_dataset


# Data Preprocessing

In [None]:
# =========================
# 2. Data Preprocessing
# =========================

# --- Transforms ---

def build_train_transform(image_size: int = IMAGE_SIZE):
    """Return the augmentation pipeline applied to training images.

    The pipeline crops/resizes to ``image_size``, applies rotation, flips and
    colour jitter, then converts to a tensor and normalises with
    ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    colour_jitter = transforms.ColorJitter(
        brightness=0.4,
        contrast=0.4,
        saturation=0.2,
        hue=0.1,
    )
    pipeline = [
        # Geometric augmentations
        transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
        transforms.RandomRotation(15),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        # Photometric augmentation
        colour_jitter,
        # Tensor conversion + normalisation
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return transforms.Compose(pipeline)

def build_eval_transform(image_size: int = IMAGE_SIZE):
    """Return the deterministic pipeline used for evaluation images.

    Images are resized to a square of side ``image_size``, converted to a
    tensor and normalised with ``IMAGENET_MEAN`` / ``IMAGENET_STD``.
    """
    target_size = (image_size, image_size)
    steps = [
        transforms.Resize(target_size),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return transforms.Compose(steps)

# --- Data Utilities ---

_LABEL_MAP = None


def _get_label_map():
    """Return the mapping from raw species id to contiguous class index.

    The mapping is read from ``SPECIES_LIST`` on first use and cached in the
    module-level ``_LABEL_MAP``. Each line is expected to start with an
    integer id followed by ``;``; lines whose leading field is not an integer
    are skipped. Indices are assigned in order of first appearance.

    If the species file does not exist a warning is printed and an empty dict
    is returned; that empty result is not cached.
    """
    global _LABEL_MAP
    if _LABEL_MAP is None:
        if not SPECIES_LIST.exists():
            print(f"Warning: {SPECIES_LIST} not found.")
            return {}

        raw_to_index = {}
        with SPECIES_LIST.open("r", encoding="utf-8") as handle:
            for entry in map(str.strip, handle):
                if not entry:
                    continue
                id_field = entry.partition(";")[0].strip()
                try:
                    raw_id = int(id_field)
                except ValueError:
                    continue
                # The next free index always equals the number of ids seen so far.
                raw_to_index.setdefault(raw_id, len(raw_to_index))

        _LABEL_MAP = raw_to_index
    return _LABEL_MAP

def _map_label(raw_label: int) -> int:
    """Translate a raw species id into its contiguous class index.

    Raises:
        KeyError: if ``raw_label`` is absent from the species mapping.
    """
    table = _get_label_map()
    if raw_label in table:
        return table[raw_label]
    # Unknown ids are treated as a hard error rather than silently remapped.
    raise KeyError(f"Raw label {raw_label} not found in species_list.txt")

def _load_raw_id_list(path: Path):
    ids = []
    if not path.exists():
        return ids
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            token = line.split(";", 1)[0].split()[0]
            try:
                raw_id = int(token)
            except ValueError:
                continue
            ids.append(raw_id)
    return ids

def get_with_without_label_sets():
    """Return ``(with_set, without_set)`` of mapped class indices.

    Raw ids are read from ``CLASS_WITH_PAIRS`` and ``CLASS_WITHOUT_PAIRS``
    and translated through the species label map; ids missing from the map
    are dropped.
    """
    label_map = _get_label_map()

    def _to_index_set(raw_ids):
        # Keep only ids the species list knows about.
        return {label_map[raw] for raw in raw_ids if raw in label_map}

    paired_raw = _load_raw_id_list(CLASS_WITH_PAIRS)
    unpaired_raw = _load_raw_id_list(CLASS_WITHOUT_PAIRS)
    return _to_index_set(paired_raw), _to_index_set(unpaired_raw)

def _parse_train_list(path: Path, root: Path) -> List[Dict]:
    """Parse the training list into sample dictionaries.

    Each non-blank line must be ``<relative path> <raw label>``. The line is
    split on the LAST run of whitespace so that a relative path containing
    spaces is kept intact.

    Args:
        path: training list file (read as UTF-8, like the other list files).
        root: directory that relative paths are resolved against.

    Returns:
        A list of dicts with keys ``path``, ``label`` (mapped class index),
        ``domain`` (0 = herbarium, 1 = photo) and ``rel_path``. Empty if the
        list file is missing (an error message is printed).

    Raises:
        ValueError: if a line has no label field or the label is not an int.
        KeyError: if a label is not present in the species list.
    """
    samples: List[Dict] = []
    if not path.exists():
        print(f"Error: Train list {path} not found.")
        return samples

    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # rsplit keeps "dir with space/img.jpg 123" parseable; a line with
            # a single token still fails the unpacking with ValueError.
            rel_path, label_str = line.rsplit(maxsplit=1)
            label = _map_label(int(label_str))
            full_path = root / rel_path

            # Domain is inferred from the path; "herbarium" wins if both
            # substrings occur, and unknown paths fall back to herbarium (0).
            if "herbarium" in rel_path:
                domain = 0
            elif "photo" in rel_path:
                domain = 1
            else:
                domain = 0

            samples.append({
                "path": full_path,
                "label": label,
                "domain": domain,
                "rel_path": rel_path,
            })
    return samples

def _parse_test_list_with_groundtruth(test_list_path: Path, gt_path: Path, dataset_root: Path) -> List[Dict]:
    """Join the test list with its ground-truth labels, line by line.

    The test list holds one relative path per line. Each ground-truth line is
    either ``<raw label>`` alone or ``<relative path> <raw label>``; in the
    latter case the path must match the test-list entry at the same position.
    The ground-truth line is split on the LAST whitespace run so paths that
    contain spaces are compared in full.

    Args:
        test_list_path: file with one relative image path per line.
        gt_path: ground-truth file aligned with the test list.
        dataset_root: directory that relative paths are resolved against.

    Returns:
        A list of dicts with keys ``path``, ``label`` (mapped class index),
        ``domain`` (always 1, i.e. photo) and ``rel_path``. Empty if either
        file is missing (an error message is printed).

    Raises:
        ValueError: on differing line counts, a path mismatch, or a label
            that is not an integer.
        KeyError: if a label is not present in the species list.
    """
    rel_paths: List[str] = []
    if not test_list_path.exists() or not gt_path.exists():
        print("Error: Test list or Groundtruth not found.")
        return []

    with test_list_path.open("r", encoding="utf-8") as f:
        for line in f:
            rel_path = line.strip()
            if not rel_path:
                continue
            rel_paths.append(rel_path)

    gt_entries: List[Tuple[str, int]] = []
    with gt_path.open("r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            # At most two fields: everything before the final token is the path.
            parts = line.rsplit(maxsplit=1)
            if len(parts) == 1:
                # Label-only line: no path to cross-check.
                gt_entries.append(("", _map_label(int(parts[0]))))
            else:
                rel_from_gt, label_str = parts
                gt_entries.append((rel_from_gt, _map_label(int(label_str))))

    if len(rel_paths) != len(gt_entries):
        raise ValueError(f"Mismatch length: {len(rel_paths)} vs {len(gt_entries)}")

    samples: List[Dict] = []
    for idx, rel_path in enumerate(rel_paths):
        rel_from_gt, label = gt_entries[idx]
        if rel_from_gt and rel_from_gt != rel_path:
            raise ValueError(f"Mismatch at line {idx + 1}: '{rel_from_gt}' vs '{rel_path}'")

        full_path = dataset_root / rel_path
        samples.append({
            "path": full_path,
            "label": label,
            "domain": 1,
            "rel_path": rel_path,
        })

    return samples

class HerbFieldDataset(Dataset):
    """Dataset over pre-parsed sample dictionaries.

    Each sample dict must provide ``path``, ``label``, ``domain`` and
    ``rel_path``. Items are returned as dicts with keys ``image``, ``label``,
    ``domain`` and ``rel_path``.
    """

    def __init__(self, samples: List[Dict], transform=None):
        """Store the sample list and the optional per-image transform."""
        self.samples = samples
        self.transform = transform

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Load, convert to RGB and transform the image at position ``idx``."""
        s = self.samples[idx]
        try:
            # Context manager closes the underlying file as soon as the pixel
            # data has been copied by convert(); previously the handle stayed
            # open until garbage collection.
            with Image.open(s["path"]) as raw_img:
                img = raw_img.convert("RGB")
        except Exception as e:
            # Deliberate best-effort: report the problem and substitute a
            # blank image so one unreadable file does not abort an epoch.
            print(f"Error loading {s['path']}: {e}")
            img = Image.new("RGB", (IMAGE_SIZE, IMAGE_SIZE))

        if self.transform is not None:
            img = self.transform(img)
        return {
            "image": img,
            "label": s["label"],
            "domain": s["domain"],
            "rel_path": s["rel_path"],
        }

def compute_class_weights(samples: List[Dict]) -> torch.Tensor:
    """Return one weight per class, computed as ``1 / log1p(count)``.

    Counts are taken from the ``label`` field of ``samples``; classes with no
    samples are treated as having a count of 1 so the weight stays finite.
    The result has ``NUM_CLASSES`` entries.
    """
    per_class = torch.zeros(NUM_CLASSES, dtype=torch.float)
    for label in (sample["label"] for sample in samples):
        per_class[label] += 1.0
    # Floor at one occurrence before taking the logarithm.
    floored = torch.clamp(per_class, min=1.0)
    return 1.0 / torch.log1p(floored)

def build_train_dataset() -> HerbFieldDataset:
    """Build the training dataset from ``TRAIN_LIST`` with train-time augmentation."""
    return HerbFieldDataset(
        _parse_train_list(TRAIN_LIST, DATA_ROOT),
        transform=build_train_transform(),
    )

def build_test_dataset() -> HerbFieldDataset:
    """Build the test dataset from ``TEST_LIST`` + ``GROUNDTRUTH`` with the eval transform."""
    return HerbFieldDataset(
        _parse_test_list_with_groundtruth(TEST_LIST, GROUNDTRUTH, DATA_ROOT),
        transform=build_eval_transform(),
    )

# --- Initialization ---

# Build both datasets once at module level; later cells reuse these objects.
print("Initializing datasets...")
train_ds = build_train_dataset()
test_ds = build_test_dataset()
# Mapped class indices for species with / without herbarium-photo pairs.
with_set, without_set = get_with_without_label_sets()

print(f"Train size: {len(train_ds)}")
print(f"Test size: {len(test_ds)}")
print(f"With-Pairs classes: {len(with_set)}")
print(f"Without-Pairs classes: {len(without_set)}")
# Sanity check: how many distinct classes actually occur in each split.
all_train_labels = {s["label"] for s in train_ds.samples}
all_test_labels = {s["label"] for s in test_ds.samples}

print("Distinct train labels:", len(all_train_labels))
print("Distinct test labels :", len(all_test_labels))

Initializing datasets...
Train size: 4744
Test size: 207
With-Pairs classes: 60
Without-Pairs classes: 40


# Dataset Setup

In [4]:
# =========================
# 3. Data Loaders & Triplet Dataset
# =========================

# --- A. Standard Classification Loaders (For Hybrid Model / Evaluation) ---
# We need these for the CrossEntropy loss and for calculating Prototypes.

# Shuffled loader over the training set (single-image batches for CE loss
# and prototype building).
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True,  # NOTE(review): only useful when training on CUDA
)

# Evaluation loader: fixed order so results line up with the test list.
test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True,  # NOTE(review): only useful when evaluating on CUDA
)

# --- B. Triplet Dataset (Unified Superset) ---
# This implementation (from metric_learning.ipynb) computes indices internally
# and returns extra metadata (labels/domains), supporting BOTH model types.

class TripletDataset(Dataset):
    """
    Samples cross-domain triplets:
      - anchor: herbarium OR photo
      - positive: same class, other domain
      - negative: different class (any domain), drawn from the usable
        with-pair classes only
    Using only with-pair species (labels in with_set) that really have
    images in both domains.

    Triplets are drawn at random on every ``__getitem__`` call; the index
    argument is ignored.
    """

    def __init__(self, base_dataset, with_labels):
        """
        base_dataset: HerbFieldDataset (train); must expose ``.samples`` with
                      ``label`` and ``domain`` fields and support indexing
        with_labels : iterable of label indices that have herbarium-photo pairs
        """
        self.base_dataset = base_dataset
        candidate_labels = sorted(set(with_labels))
        # Set for O(1) membership tests while scanning every sample.
        candidate_set = set(candidate_labels)

        # Pre-build index: label -> {0: [idxs of herbarium], 1: [idxs of photo]}
        self.label_domain_index = {}
        for idx, s in enumerate(self.base_dataset.samples):
            lbl = s["label"]
            dom = s["domain"]
            if lbl not in candidate_set:
                continue
            if lbl not in self.label_domain_index:
                self.label_domain_index[lbl] = {
                    HERBARIUM_DOMAIN: [],
                    PHOTO_DOMAIN: [],
                }
            self.label_domain_index[lbl][dom].append(idx)

        # Keep only labels that actually have images in both domains.
        self.with_labels = [
            lbl
            for lbl in candidate_labels
            if lbl in self.label_domain_index
            and self.label_domain_index[lbl][HERBARIUM_DOMAIN]
            and self.label_domain_index[lbl][PHOTO_DOMAIN]
        ]
        print(f"[TripletDataset] Usable with-pair classes: {len(self.with_labels)}")

    def __len__(self):
        # The length is arbitrary because sampling is random; the base
        # dataset length gives a "full" epoch feel.
        return len(self.base_dataset)

    def _sample_cross_pair(self):
        """Draw ``(anchor_idx, pos_idx, neg_idx, label, neg_label)`` at random.

        Raises:
            IndexError: if fewer than two usable classes exist, because a
                positive class and a distinct negative class are both needed.
        """
        if len(self.with_labels) < 2:
            # Same exception type random.choice([]) raised before, but with an
            # actionable message.
            raise IndexError(
                "TripletDataset needs at least two classes with both herbarium "
                f"and photo images to sample a triplet (found {len(self.with_labels)})"
            )

        # Choose a class that has both domains
        lbl = random.choice(self.with_labels)

        # Randomly pick which domain is anchor vs positive
        dom_anchor = random.choice([HERBARIUM_DOMAIN, PHOTO_DOMAIN])
        dom_pos = PHOTO_DOMAIN if dom_anchor == HERBARIUM_DOMAIN else HERBARIUM_DOMAIN

        anchor_idx = random.choice(self.label_domain_index[lbl][dom_anchor])
        pos_idx = random.choice(self.label_domain_index[lbl][dom_pos])

        # Negative: any *other* class from with_labels
        neg_lbl = random.choice([c for c in self.with_labels if c != lbl])

        # Negative domain: pick any domain that actually has images for that class
        neg_dom_choices = [
            dom
            for dom in (HERBARIUM_DOMAIN, PHOTO_DOMAIN)
            if self.label_domain_index[neg_lbl][dom]
        ]
        neg_dom = random.choice(neg_dom_choices)
        neg_idx = random.choice(self.label_domain_index[neg_lbl][neg_dom])

        return anchor_idx, pos_idx, neg_idx, lbl, neg_lbl

    def __getitem__(self, idx):
        # idx is ignored; we generate a fresh triplet every time
        anchor_idx, pos_idx, neg_idx, lbl, neg_lbl = self._sample_cross_pair()

        a = self.base_dataset[anchor_idx]
        p = self.base_dataset[pos_idx]
        n = self.base_dataset[neg_idx]

        return {
            "anchor": a["image"],
            "positive": p["image"],
            "negative": n["image"],
            # Metadata (Useful for Metric learning, ignored by Hybrid)
            "anchor_label": a["label"],
            "positive_label": p["label"],
            "negative_label": n["label"],
            "anchor_domain": a["domain"],
            "positive_domain": p["domain"],
            "negative_domain": n["domain"],
        }

# --- C. Triplet Loader ---

# Triplets are sampled from the training set, restricted to with-pair classes.
triplet_ds = TripletDataset(train_ds, with_set)

triplet_loader = DataLoader(
    triplet_ds,
    batch_size=TRIPLET_BATCH_SIZE, # Defined in Cell 1
    shuffle=True,  # has no practical effect: TripletDataset ignores the index
    num_workers=TRIPLET_NUM_WORKERS, # Defined in Cell 1
    pin_memory=True,
    drop_last=True,  # keep every triplet batch at full size
)

# Warm-up check (from triplet.ipynb): pull one batch to confirm tensor shapes.
batch = next(iter(triplet_loader))
print("Triplet batch shapes:", batch["anchor"].shape, batch["positive"].shape, batch["negative"].shape)

[TripletDataset] Usable with-pair classes: 60
Triplet batch shapes: torch.Size([4, 3, 518, 518]) torch.Size([4, 3, 518, 518]) torch.Size([4, 3, 518, 518])


# Architecture Setup

In [5]:
# =========================
# 4. Model Architectures
# =========================

# --- A. Metric Learning Model (Backbone + Projection Head) ---
# Source: metric_learning.ipynb
# Best for: Pure distance-based learning (Prototypes)

class TripletEncoder(nn.Module):
    """Backbone plus MLP projection head producing L2-normalised embeddings.

    Args:
        embed_dim: size of the output embedding.
        backbone_type: stored (lower-cased) for reference only; the timm model
            created below is fixed regardless of this value.
        pretrained: forwarded to ``timm.create_model``.
        freeze_backbone: if True, all backbone parameters are frozen.
        proj_hidden_dim: width of the hidden projection layers.
        proj_layers: number of Linear layers in the head; every layer except
            the last is followed by BatchNorm1d, ReLU and optional Dropout.
        dropout_p: dropout probability in the hidden layers (0 disables it).

    Raises:
        ValueError: if the backbone exposes neither ``num_features`` nor
            ``embed_dim``.
    """

    def __init__(
        self,
        embed_dim: int,
        backbone_type: str = "dinov2_vitb14",
        pretrained: bool = True,
        freeze_backbone: bool = False,
        proj_hidden_dim: int = 1024,
        proj_layers: int = 2,
        dropout_p: float = 0.0,
    ):
        super().__init__()
        self.backbone_type = backbone_type.lower()

        # 1. Create Backbone
        # Note: In metric_learning.ipynb, there was logic to load a local checkpoint.
        # We default to standard timm loading here for safety.
        self.backbone = timm.create_model(
            "vit_base_patch14_reg4_dinov2.lvd142m",
            pretrained=pretrained,
            num_classes=0  # remove original classifier
        )

        # 2. Get Feature Dimension
        feat_dim = getattr(self.backbone, "num_features", None)
        if feat_dim is None:
            feat_dim = getattr(self.backbone, "embed_dim", None)
        if feat_dim is None:
            # Fail here with a clear message instead of inside nn.Linear.
            raise ValueError(
                "Backbone exposes neither 'num_features' nor 'embed_dim'; "
                "cannot size the projection head"
            )

        # 3. Optional Freezing
        if freeze_backbone:
            for p in self.backbone.parameters():
                p.requires_grad = False

        # 4. Projection Head (feat_dim -> embed_dim)
        proj_layers_list = []
        in_dim = feat_dim
        for _ in range(proj_layers - 1):
            proj_layers_list.append(nn.Linear(in_dim, proj_hidden_dim))
            proj_layers_list.append(nn.BatchNorm1d(proj_hidden_dim))
            proj_layers_list.append(nn.ReLU(inplace=True))
            if dropout_p > 0:
                proj_layers_list.append(nn.Dropout(dropout_p))
            in_dim = proj_hidden_dim
        proj_layers_list.append(nn.Linear(in_dim, embed_dim))

        self.proj_head = nn.Sequential(*proj_layers_list)

    def set_backbone_trainable(self, mode: str = "all", last_k: int = 2):
        """
        Helper to control fine-tuning depth (Metric Learning approach).

        mode="none" freezes the whole backbone, mode="all" unfreezes it, and
        any other value unfreezes only the last ``last_k`` entries of
        ``backbone.blocks`` (everything else in the backbone stays frozen).
        ``last_k <= 0`` leaves the backbone fully frozen.
        """
        mode = mode.lower()
        # Freeze everything first
        for p in self.backbone.parameters():
            p.requires_grad = False

        if mode == "none":
            return
        if mode == "all":
            for p in self.backbone.parameters():
                p.requires_grad = True
            return

        # Unfreeze last k blocks
        if not hasattr(self.backbone, "blocks"):
            return
        blocks = self.backbone.blocks
        last_k = min(last_k, len(blocks))
        if last_k <= 0:
            # blocks[-0:] is the WHOLE list, so zero must be handled explicitly.
            return
        for blk in blocks[-last_k:]:
            for p in blk.parameters():
                p.requires_grad = True

    def forward_backbone(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw backbone features for ``x``."""
        return self.backbone(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return projected embeddings, L2-normalised along the last dim."""
        feats = self.forward_backbone(x)
        z = self.proj_head(feats)
        z = F.normalize(z, p=2, dim=-1)
        return z


# --- B. Hybrid Model (Backbone + Linear Classifier) ---
# Source: triplet.ipynb
# Best for: Classification + Regularization (CrossEntropy + Triplet Loss)

class DinoTriplet(nn.Module):
    """Backbone with a single linear classifier on top of its features.

    ``forward`` yields class logits (optionally together with the backbone
    features); ``encode`` yields only the backbone features for use with the
    triplet loss.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        # num_classes=0 makes timm return pooled features instead of logits.
        backbone = timm.create_model(
            "vit_base_patch14_reg4_dinov2.lvd142m",
            pretrained=True,
            num_classes=0,
        )
        self.backbone = backbone
        self.classifier = nn.Linear(backbone.num_features, num_classes)

    def forward(self, x, return_embedding=False, normalize_embedding=True):
        """Return logits, or ``(logits, embedding)`` when ``return_embedding`` is set.

        The embedding is the backbone output, L2-normalised along dim 1
        unless ``normalize_embedding`` is False.
        """
        features = self.backbone(x)
        logits = self.classifier(features)
        if not return_embedding:
            return logits
        if normalize_embedding:
            return logits, F.normalize(features, p=2, dim=1)
        return logits, features

    def encode(self, x, normalize=True):
        """Return backbone features for ``x``, L2-normalised unless ``normalize`` is False."""
        features = self.backbone(x)
        return F.normalize(features, p=2, dim=1) if normalize else features

# Cell-level confirmation that both model classes were defined.
print("Models defined: TripletEncoder (Metric) and DinoTriplet (Hybrid)")

Models defined: TripletEncoder (Metric) and DinoTriplet (Hybrid)


# Optimization and Loss Configuration

In [None]:
# =========================
# 5. Optimizer, Loss, and Scheduler Setup
# =========================

# --- Helper 1: Metric Learning Optimizer Groups (Layer-wise Decay) ---
# Source: metric_learning.ipynb
def get_metric_optimizer_groups(model, lr_head, lr_backbone_max, weight_decay, layer_decay=0.8):
    """Build optimizer parameter groups with layer-wise learning-rate decay.

    Groups are emitted in this order, each only if it has trainable params:
      1. ``model.proj_head`` (if present) at ``lr_head``;
      2. each entry of ``model.backbone.blocks`` from last to first, starting
         at ``lr_backbone_max`` and multiplying by ``layer_decay`` after every
         block that contributed a group;
      3. remaining backbone parameters whose name does not contain "blocks",
         at the learning rate reached after the last decay step.

    Every group uses the same ``weight_decay``.
    """

    def _trainable(params):
        return [p for p in params if p.requires_grad]

    def _group(params, lr):
        return {"params": params, "lr": lr, "weight_decay": weight_decay}

    groups = []

    # Projection head
    if hasattr(model, "proj_head"):
        head = _trainable(model.proj_head.parameters())
        if head:
            groups.append(_group(head, lr_head))

    # Transformer blocks, topmost first; frozen blocks do not consume a decay step.
    backbone = model.backbone
    lr = lr_backbone_max
    if hasattr(backbone, "blocks"):
        for block in reversed(backbone.blocks):
            trainable = _trainable(block.parameters())
            if not trainable:
                continue
            groups.append(_group(trainable, lr))
            lr *= layer_decay

    # Everything else in the backbone (norm, patch_embed, ...) if unfrozen
    rest = [
        p
        for name, p in backbone.named_parameters()
        if p.requires_grad and "blocks" not in name
    ]
    if rest:
        groups.append(_group(rest, lr))

    return groups

# --- Helper 2: Hybrid Learning Optimizer Groups ---
# Source: triplet.ipynb
def get_hybrid_optimizer_groups(model, base_lr=1e-4, decay_rate=0.9, weight_decay=1e-4):
    """Build optimizer parameter groups for the hybrid (CE + triplet) model.

    Group order:
      1. ``model.classifier`` at ``base_lr``;
      2. ``model.backbone.norm`` at ``0.9 * base_lr``;
      3. each backbone block that has trainable params, first to last, at
         ``base_lr * decay_rate ** (n_blocks - i - 1)``;
      4. ``model.backbone.patch_embed`` (if trainable) at
         ``base_lr * decay_rate ** n_blocks``;
      5. any other trainable backbone parameter not covered above, at the
         same lowest learning rate. Previously such parameters were silently
         left out of the optimizer and never updated.

    Args:
        model: module exposing ``classifier`` and a ``backbone`` with
            ``norm``, ``blocks`` and ``patch_embed``.
        base_lr: learning rate of the classifier.
        decay_rate: per-block multiplicative learning-rate decay.
        weight_decay: weight decay applied to every group (default keeps the
            previously hard-coded 1e-4).

    Returns:
        List of parameter-group dicts suitable for a torch optimizer.
    """
    param_groups = []

    # 1. Classifier (Highest LR). Groups 1 and 2 are always emitted, even if
    #    empty, so their positions in the list stay stable.
    param_groups.append({
        "params": [p for p in model.classifier.parameters() if p.requires_grad],
        "lr": base_lr * 1.0,
        "weight_decay": weight_decay,
    })

    # 2. Final Norm
    param_groups.append({
        "params": [p for p in model.backbone.norm.parameters() if p.requires_grad],
        "lr": base_lr * 0.9,
        "weight_decay": weight_decay,
    })

    # 3. Backbone Blocks (Scaled LR): the last block gets base_lr, earlier
    #    blocks are decayed once per step away from the top.
    n_blocks = len(model.backbone.blocks)
    for depth, block in enumerate(model.backbone.blocks):
        params = [p for p in block.parameters() if p.requires_grad]
        if not params:
            continue
        param_groups.append({
            "params": params,
            "lr": base_lr * decay_rate ** (n_blocks - depth - 1),
            "weight_decay": weight_decay,
        })

    # 4. Patch Embed (Lowest LR)
    lowest_lr = base_lr * (decay_rate ** n_blocks)
    patch_params = [p for p in model.backbone.patch_embed.parameters() if p.requires_grad]
    if patch_params:
        param_groups.append({
            "params": patch_params,
            "lr": lowest_lr,
            "weight_decay": weight_decay,
        })

    # 5. Remaining trainable backbone params, identified by object identity
    #    so nothing is added to two groups.
    grouped_ids = {id(p) for group in param_groups for p in group["params"]}
    leftover = [
        p
        for p in model.backbone.parameters()
        if p.requires_grad and id(p) not in grouped_ids
    ]
    if leftover:
        param_groups.append({
            "params": leftover,
            "lr": lowest_lr,
            "weight_decay": weight_decay,
        })

    return param_groups

# =========================
# CONFIGURATION SELECTOR
# =========================

# TOGGLE THIS: "metric" OR "hybrid"
# TOGGLE THIS: "metric" OR "hybrid"
TRAINING_MODE = "hybrid" 
print(f"--> Setting up for mode: {TRAINING_MODE}")

if TRAINING_MODE == "metric":
    # 1. Instantiate Model
    model = TripletEncoder(
        embed_dim=EMBED_DIM,
        backbone_type="dinov2_vitb14",
        pretrained=True,
        freeze_backbone=False # Handled by set_backbone_trainable
    ).to(device)
    
    # 2. Fine-tuning Setup: only the last 5 backbone blocks are trainable
    model.set_backbone_trainable(mode="last_k", last_k=5)
    
    # 3. Optimizer Groups
    param_groups = get_metric_optimizer_groups(
        model, 
        lr_head=LR_HEAD, 
        lr_backbone_max=LR_BACKBONE_MAX, 
        weight_decay=WEIGHT_DECAY,
        layer_decay=LAYER_DECAY
    )
    
    # 4. Loss Functions
    # Metric learning only uses Triplet Loss
    criterion_triplet = nn.TripletMarginLoss(margin=0.2, p=2)
    criterion_ce = None # Not used

elif TRAINING_MODE == "hybrid":
    # 1. Instantiate Model
    model = DinoTriplet(num_classes=NUM_CLASSES).to(device)
    
    # 2. Freezing Logic (Manual from triplet.ipynb): freeze the backbone, then
    #    re-enable the last 5 blocks, the final norm and the classifier.
    for p in model.backbone.parameters(): p.requires_grad = False
    n_blocks = len(model.backbone.blocks)
    for i in range(n_blocks - 5, n_blocks): # Unfreeze last 5 blocks
        for p in model.backbone.blocks[i].parameters(): p.requires_grad = True
    for p in model.backbone.norm.parameters(): p.requires_grad = True
    for p in model.classifier.parameters(): p.requires_grad = True
    
    # 3. Optimizer Groups
    param_groups = get_hybrid_optimizer_groups(model, base_lr=1e-4, decay_rate=0.9)
    
    # 4. Loss Functions
    # Hybrid uses both CE (weighted) and Triplet
    class_weights = compute_class_weights(train_ds.samples).to(device)
    criterion_ce = nn.CrossEntropyLoss(weight=class_weights)
    criterion_triplet = nn.TripletMarginLoss(margin=0.3, p=2)

else:
    # Without this branch a typo either raised NameError further down or, in a
    # live notebook, silently reused model/param_groups from an earlier run.
    raise ValueError(
        f"Unknown TRAINING_MODE {TRAINING_MODE!r}; expected 'metric' or 'hybrid'"
    )

# Common Setup
# Learning rates and weight decay come from the per-group settings above.
optimizer = torch.optim.AdamW(param_groups)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
scaler = amp.GradScaler(enabled=USE_AMP)

print(f"Model: {type(model).__name__}")
print(f"Trainable Params: {sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6:.2f}M")

# Training and Evaluation Function

In [None]:
# =========================
# 6. Training and Evaluation Functions
# =========================

# --- A. Metric Learning Evaluation (Prototypes) ---

def build_prototypes(model, loader, device):
    """Compute one L2-normalised mean embedding per class from herbarium samples.

    Only samples whose ``domain`` equals 0 contribute. ``model`` is switched to
    eval mode and is expected to return embeddings of size ``EMBED_DIM``.

    Returns:
        ``(prototypes, proto_count)``: a ``(NUM_CLASSES, EMBED_DIM)`` tensor
        (rows of classes without samples stay zero) and the per-class sample
        counts as a long tensor.
    """
    model.eval()
    proto_sum = torch.zeros(NUM_CLASSES, EMBED_DIM, device=device)
    proto_count = torch.zeros(NUM_CLASSES, dtype=torch.long, device=device)

    with torch.no_grad():
        for batch in loader:
            images = batch["image"].to(device, non_blocking=True)
            labels = batch["label"].to(device, non_blocking=True)
            domains = batch["domain"].to(device, non_blocking=True)

            # Prototypes are built from the herbarium (source) domain only.
            herbarium_only = domains == 0
            if not herbarium_only.any():
                continue

            embeddings = model(images[herbarium_only])
            for vec, cls in zip(embeddings, labels[herbarium_only]):
                proto_sum[cls] += vec
                proto_count[cls] += 1

    prototypes = torch.zeros_like(proto_sum)
    populated = torch.nonzero(proto_count > 0).flatten().tolist()
    for c in populated:
        # Mean embedding, re-normalised to unit length.
        mean_vec = proto_sum[c] / proto_count[c].float()
        prototypes[c] = F.normalize(mean_vec, p=2, dim=-1)

    return prototypes, proto_count

def run_metric_eval(model, loader, prototypes, proto_count, device, with_set, without_set):
    """Evaluate nearest-prototype classification accuracy.

    Each embedding returned by ``model`` is scored against every prototype by
    dot product; classes whose ``proto_count`` is zero are excluded from the
    ranking. Accuracy is reported overall and separately for labels in
    ``with_set`` ("paired") and ``without_set`` ("unpaired").

    Args:
        model: callable returning one embedding per image.
        loader: iterable of batches with ``image`` and ``label`` entries.
        prototypes: tensor with one prototype per row.
        proto_count: per-class sample counts used to mask missing prototypes.
        device: device the images are moved to.
        with_set / without_set: collections of class indices.

    Returns:
        Dict with ``overall_top1``, ``paired_top1``, ``unpaired_top1``, the
        matching ``*_top5`` entries (previously counted but never returned)
        and ``counts`` (samples per split). Empty splits report 0.0.
    """
    model.eval()
    # Never ask topk for more candidates than there are prototypes.
    k = min(5, prototypes.size(0))
    valid_proto_mask = proto_count > 0

    split_names = ("overall", "paired", "unpaired")
    stats = {name: {"total": 0, "c1": 0, "c5": 0} for name in split_names}

    with torch.no_grad():
        for batch in loader:
            imgs = batch["image"].to(device, non_blocking=True)
            labels_cpu = batch["label"].tolist()

            emb = model(imgs)
            # Similarity = dot product (embeddings are expected to be normalised)
            sims = emb @ prototypes.T
            # Mask out classes without a prototype
            sims[:, ~valid_proto_mask] = -1e9

            topk_idx = sims.topk(k=k, dim=1).indices.tolist()

            for lbl, ranked in zip(labels_cpu, topk_idx):
                hit1 = ranked[0] == lbl
                hit5 = lbl in ranked

                # A sample always counts towards "overall" and towards at
                # most one of the paired / unpaired splits.
                targets = ["overall"]
                if lbl in with_set:
                    targets.append("paired")
                elif lbl in without_set:
                    targets.append("unpaired")

                for name in targets:
                    bucket = stats[name]
                    bucket["total"] += 1
                    bucket["c1"] += int(hit1)
                    bucket["c5"] += int(hit5)

    def safe_div(n, d):
        return n / d if d > 0 else 0.0

    results = {}
    for name in split_names:
        results[f"{name}_top1"] = safe_div(stats[name]["c1"], stats[name]["total"])
    for name in split_names:
        results[f"{name}_top5"] = safe_div(stats[name]["c5"], stats[name]["total"])
    results["counts"] = {name: stats[name]["total"] for name in split_names}
    return results


# --- B. Hybrid Learning Evaluation (Linear Classifier) ---

@torch.no_grad()
def evaluate_hybrid_split(model, loader, device, with_set, without_set):
    """Top-1 accuracy of the linear classifier, overall and per class split.

    ``model`` is called on each image batch and must return class logits.
    A label contributes to the "paired" figure if it is in ``with_set`` and
    to the "unpaired" figure if it is in ``without_set`` (the two checks are
    independent). Splits without samples report 0.0.

    Returns:
        Dict with ``overall_top1``, ``paired_top1`` and ``unpaired_top1``.
    """
    model.eval()
    tallies = {name: {"c1": 0, "total": 0} for name in ("all", "with", "without")}

    def _record(name, is_correct):
        tallies[name]["total"] += 1
        if is_correct:
            tallies[name]["c1"] += 1

    for batch in loader:
        images = batch["image"].to(device, non_blocking=True)
        targets = batch["label"].to(device, non_blocking=True)

        # DinoTriplet returns logits by default
        predicted = model(images).softmax(dim=1).argmax(dim=1)

        for truth, guess in zip(targets.tolist(), predicted.tolist()):
            is_correct = guess == truth
            _record("all", is_correct)
            if truth in with_set:
                _record("with", is_correct)
            if truth in without_set:
                _record("without", is_correct)

    def _accuracy(name):
        seen = tallies[name]["total"]
        return tallies[name]["c1"] / seen if seen > 0 else 0.0

    return {
        "overall_top1": _accuracy("all"),
        "paired_top1": _accuracy("with"),
        "unpaired_top1": _accuracy("without"),
    }


# --- C. Training Function: Metric Mode ---
# Iterates ONLY over Triplet Loader

def train_one_epoch_metric(model, loader, optimizer, loss_fn, device, scaler):
    """Run one epoch of triplet-only training.

    Args:
        model: module returning embeddings for a batch of images.
        loader: iterable of batches with ``anchor``, ``positive`` and
            ``negative`` image tensors.
        optimizer: optimizer stepped once per batch.
        loss_fn: callable taking (anchor, positive, negative) embeddings.
        device: target device; autocast is enabled only for CUDA.
        scaler: gradient scaler used for the backward pass and the step.

    Returns:
        Mean loss over the batches seen, or 0.0 if the loader yielded no
        batches (previously a ZeroDivisionError), matching the 0.0 convention
        of the evaluation helpers.
    """
    model.train()
    running_loss = 0.0
    count = 0
    
    pbar = tqdm(loader, desc="Train Metric", leave=False)
    for batch in pbar:
        optimizer.zero_grad()
        
        anc = batch["anchor"].to(device)
        pos = batch["positive"].to(device)
        neg = batch["negative"].to(device)
        
        with autocast(enabled=(device.type == "cuda")):
            # Anchor, positive and negative are encoded in three separate
            # forward passes of the same model.
            ea = model(anc)
            ep = model(pos)
            en = model(neg)
            loss = loss_fn(ea, ep, en)
            
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        running_loss += loss.item()
        count += 1
        pbar.set_postfix({"loss": f"{running_loss/count:.4f}"})
    
    if count == 0:
        # e.g. drop_last=True with fewer samples than one batch
        return 0.0
    return running_loss / count


# --- D. Training Function: Hybrid Mode ---
# Iterates over BOTH Class Loader (CE) and Triplet Loader (Triplet Loss)

def train_one_epoch_hybrid(model, cls_loader, trip_loader, optimizer, ce_fn, trip_fn, device, scaler, lambda_t=0.5):
    """Run one epoch of joint classification + triplet training.

    The epoch length is set by ``cls_loader``; ``trip_loader`` is cycled so
    every classification batch is paired with one triplet batch.

    Args:
        model: Module whose forward pass returns class logits and whose
            ``encode`` method returns embeddings.
        cls_loader: Iterable of dict batches with "image" and "label".
        trip_loader: Iterable of dict batches with "anchor", "positive" and
            "negative"; restarted whenever it is exhausted.
        optimizer: Optimizer stepped once per classification batch.
        ce_fn: Callable ``ce_fn(logits, labels)`` returning a scalar loss.
        trip_fn: Callable ``trip_fn(anchor, positive, negative)`` returning a
            scalar loss.
        device: ``torch.device`` the batches are moved to; autocast is only
            enabled when its type is "cuda".
        scaler: Object exposing ``scale``, ``step`` and ``update``.
        lambda_t: Weight applied to the triplet loss before it is added to
            the classification loss.

    Returns:
        ``(mean_loss, accuracy)`` over the classification samples seen, or
        ``(0.0, 0.0)`` when ``cls_loader`` yields no samples.

    Raises:
        ValueError: If ``trip_loader`` yields no batches at all.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    trip_iter = iter(trip_loader)

    pbar = tqdm(cls_loader, desc="Train Hybrid", leave=False)
    for batch in pbar:
        imgs = batch["image"].to(device, non_blocking=True)
        labels = batch["label"].to(device, non_blocking=True)

        # Fetch the next triplet batch, restarting the loader once exhausted.
        try:
            trip_batch = next(trip_iter)
        except StopIteration:
            trip_iter = iter(trip_loader)
            try:
                trip_batch = next(trip_iter)
            except StopIteration:
                # A fresh iterator that is immediately exhausted means the
                # loader is empty; report that instead of a bare StopIteration.
                raise ValueError(
                    "trip_loader yielded no batches; cannot compute the triplet loss"
                ) from None

        anc = trip_batch["anchor"].to(device, non_blocking=True)
        pos = trip_batch["positive"].to(device, non_blocking=True)
        neg = trip_batch["negative"].to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=(device.type == "cuda")):
            # 1. Classification loss on the labelled batch.
            logits = model(imgs)
            loss_ce = ce_fn(logits, labels)

            # 2. Triplet loss on embeddings obtained through encode().
            anc_emb = model.encode(anc)
            pos_emb = model.encode(pos)
            neg_emb = model.encode(neg)
            loss_trip = trip_fn(anc_emb, pos_emb, neg_emb)

            loss = loss_ce + lambda_t * loss_trip

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight by batch size so the epoch mean is per-sample.
        total_loss += loss.item() * imgs.size(0)

        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

    if total == 0:
        # Nothing was seen this epoch; avoid dividing by zero.
        return 0.0, 0.0
    return total_loss / total, correct / total

# Main Execution Loop

In [None]:
# =========================
# 7. Main Execution Loop
# =========================

# --- Setup Directories ---
# One checkpoint folder per training mode, so runs in different modes do not overwrite each other.
CKPT_DIR = PROJECT_ROOT / "experiments" / "merged" / f"checkpoints_{TRAINING_MODE}"
CKPT_DIR.mkdir(parents=True, exist_ok=True)

print(f"--> Checkpoints will be saved to: {CKPT_DIR}")


def _epoch_index(ckpt_path):
    """Return the integer N of an ``epoch_N.pt`` file, or -1 when N is not a number."""
    suffix = ckpt_path.stem.rsplit("_", 1)[-1]
    return int(suffix) if suffix.isdigit() else -1


# --- Resume Logic ---
# Prefers "last.pt"; otherwise falls back to the highest-numbered epoch_*.pt file.
start_epoch = 0
best_primary_metric = 0.0 # Will be "Unpaired Top-1" for Metric, "Overall Top-1" for Hybrid
history = []

resume_path = CKPT_DIR / "last.pt"
if not resume_path.exists():
    # Sort numerically: a plain string sort would rank "epoch_10" before "epoch_5".
    chkpts = sorted(CKPT_DIR.glob("epoch_*.pt"), key=_epoch_index)
    if chkpts:
        resume_path = chkpts[-1]

if resume_path.exists():
    print(f"--> Resuming from: {resume_path}")
    checkpoint = torch.load(resume_path, map_location=device)

    # Load States
    # Note: We use strict=False because switching modes might leave some keys unmatched
    # if you try to load a hybrid checkpoint into a metric model (not recommended but handled).
    model.load_state_dict(checkpoint["model_state"], strict=False)
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    if "scheduler_state" in checkpoint and checkpoint["scheduler_state"]:
        scheduler.load_state_dict(checkpoint["scheduler_state"])

    # The checkpoint stores the index of the last finished epoch.
    start_epoch = checkpoint["epoch"] + 1
    history = checkpoint.get("history", [])
    best_primary_metric = checkpoint.get("best_primary_metric", 0.0)

    print(f"--> Resuming at Epoch {start_epoch + 1} with Best Metric: {best_primary_metric:.4%}")
else:
    print("--> Starting from scratch.")


# --- Training Loop ---

print(f"--> Starting Training ({TRAINING_MODE.upper()} Mode)...")

for epoch in range(start_epoch, EPOCHS):
    start_time = time.time()

    # -------------------------------------------
    # 1. TRAIN
    # -------------------------------------------
    if TRAINING_MODE == "metric":
        # Metric Mode: Train only on triplets
        train_loss = train_one_epoch_metric(
            model, triplet_loader, optimizer, criterion_triplet, device, scaler
        )
        # No classifier accuracy exists in metric mode; log a placeholder.
        train_acc = 0.0

    else: # Hybrid Mode
        # Hybrid Mode: Train on Class + Triplet
        train_loss, train_acc = train_one_epoch_hybrid(
            model, train_loader, triplet_loader, optimizer,
            criterion_ce, criterion_triplet, device, scaler, lambda_t=0.5 # or LAMBDA_TRIPLET
        )
        # For logging simplicity, we treat the combined loss as "train_loss"

    # -------------------------------------------
    # 2. EVALUATE
    # -------------------------------------------
    if TRAINING_MODE == "metric":
        # Metric Mode: Build prototypes -> Distance check
        prototypes, proto_count = build_prototypes(model, train_loader, device)
        eval_metrics = run_metric_eval(model, test_loader, prototypes, proto_count, device, with_set, without_set)

        # Primary metric for saving best model: Unpaired Top-1 (Harder task)
        current_primary_metric = eval_metrics["unpaired_top1"]

    else: # Hybrid Mode
        # Hybrid Mode: Standard Linear Classification
        eval_metrics = evaluate_hybrid_split(model, test_loader, device, with_set, without_set)

        # Primary metric for saving best model: Overall Top-1
        current_primary_metric = eval_metrics["overall_top1"]

    epoch_time = time.time() - start_time

    # -------------------------------------------
    # 3. LOGGING
    # -------------------------------------------
    record = {
        "epoch": epoch + 1,
        "mode": TRAINING_MODE,
        "train_loss": float(train_loss),
        "train_acc": float(train_acc),
        "eval": eval_metrics, # Nested dict with all splits
        "time": epoch_time
    }
    history.append(record)

    # Console Print
    print(
        f"Ep {epoch+1:02d} | "
        f"Loss: {train_loss:.4f} | "
        f"Acc: {eval_metrics['overall_top1']:.2%} (All) | "
        f"{eval_metrics['paired_top1']:.2%} (Paired) | "
        f"{eval_metrics['unpaired_top1']:.2%} (Unpaired) | "
        f"Time: {epoch_time:.0f}s"
    )

    # -------------------------------------------
    # 4. SAVING
    # -------------------------------------------
    # Step the scheduler BEFORE saving: a resumed run starts at epoch + 1, so the
    # stored scheduler state must already include this epoch's step.
    scheduler.step()

    # Update the running best BEFORE building the checkpoint so that both
    # last.pt and best_model.pt record the up-to-date value.
    is_best = current_primary_metric > best_primary_metric
    if is_best:
        print(f"--> New Best Model! ({best_primary_metric:.2%} -> {current_primary_metric:.2%})")
        best_primary_metric = current_primary_metric

    save_dict = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scheduler_state": scheduler.state_dict(),
        "history": history,
        "best_primary_metric": best_primary_metric,
        "mode": TRAINING_MODE
    }

    # Save Latest
    torch.save(save_dict, CKPT_DIR / "last.pt")

    # Save Best
    if is_best:
        torch.save(save_dict, CKPT_DIR / "best_model.pt")

    # Periodic Save (Optional)
    # if (epoch + 1) % 5 == 0:
    #     torch.save(save_dict, CKPT_DIR / f"epoch_{epoch+1}.pt")

    # Dump History JSON
    with open(CKPT_DIR / "history.json", "w") as f:
        json.dump(history, f, indent=4)

print("Training Complete.")

# Visualization

In [None]:
# =========================
# 8. Visualization of Results
# =========================

# Nothing held in memory (e.g. a fresh session)? Fall back to the JSON log on disk.
if not history:
    history_path = CKPT_DIR / "history.json"
    if history_path.exists():
        with open(history_path, "r") as f:
            history = json.load(f)

if not history:
    print("No history found to plot.")
else:
    epochs = [h["epoch"] for h in history]
    train_loss = [h["train_loss"] for h in history]

    # Every record carries an 'eval' dict with the same three accuracy keys,
    # regardless of the training mode that produced it.
    overall_acc, paired_acc, unpaired_acc = (
        [h["eval"][key] for h in history]
        for key in ("overall_top1", "paired_top1", "unpaired_top1")
    )

    plt.figure(figsize=(12, 5))

    # Left panel: training loss per epoch.
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_loss, marker='o', label='Train Loss')
    plt.title(f"Training Loss ({TRAINING_MODE})")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.legend()

    # Right panel: the three evaluation accuracies on a shared axis.
    plt.subplot(1, 2, 2)
    accuracy_curves = (
        (overall_acc, {"marker": 'o', "label": 'Overall Top-1'}),
        (paired_acc, {"marker": 's', "linestyle": '--', "label": 'Paired Top-1'}),
        (unpaired_acc, {"marker": '^', "linestyle": '--', "label": 'Unpaired Top-1'}),
    )
    for curve, line_kwargs in accuracy_curves:
        plt.plot(epochs, curve, **line_kwargs)
    plt.title(f"Evaluation Accuracy ({TRAINING_MODE})")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.ylim(0, 1.0)
    plt.grid(True)
    plt.legend()

    plt.tight_layout()
    plt.show()

    print(f"Final Unpaired Accuracy: {unpaired_acc[-1]:.2%}")

# Evaluating the Model's Performance

## Save and Load Prototypes

In [6]:
# --- Saving Prototypes ---
def save_prototypes(prototypes, counts, path):
    """Write the class prototypes and their per-class counts to ``path``.

    The file holds a dict with the keys "prototypes" and "counts".
    """
    print(f"Saving prototypes to {path}...")
    # Both tensors are copied to CPU before serialisation.
    payload = {"prototypes": prototypes.cpu(), "counts": counts.cpu()}
    torch.save(payload, path)
    print("Done!")

# --- Loading Prototypes ---
def load_prototypes(path):
    """Read a prototype file and return the tensor stored under "prototypes".

    The file is mapped onto the notebook-level ``device`` while loading.

    Raises:
        FileNotFoundError: If nothing exists at ``path``.
    """
    if os.path.exists(path):
        print(f"Loading prototypes from {path}...")
        payload = torch.load(path, map_location=device)
        return payload["prototypes"]

    raise FileNotFoundError(f"No prototype file found at {path}")

def load_dual_prototypes(path, device=None):
    """Load prototypes from a ``.pt`` file saved either as a tensor or as a dict.

    Args:
        path: Location of the prototype file.
        device: Target device; when omitted, CUDA is used if available,
            otherwise the CPU.

    Returns:
        The prototype tensor on the target device. For dict files the tensor
        is read from the "prototypes" key, or from "protos" as a fallback.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        KeyError: If a dict file has neither a "prototypes" nor a "protos" key.
        TypeError: If the file holds something other than a tensor or a dict.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Prototype file not found at: {path}")

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Loading prototypes from {path}...")

    # weights_only=False permits arbitrary pickled objects, not just tensors.
    blob = torch.load(path, map_location=device, weights_only=False)

    # Bare tensor file: nothing to unwrap.
    if isinstance(blob, torch.Tensor):
        return blob.to(device)

    if not isinstance(blob, dict):
        raise TypeError(f"Unsupported file format. Expected Tensor or Dict, got {type(blob)}")

    # Dict file: look the tensor up under the known keys.
    if "prototypes" in blob:
        protos = blob["prototypes"]
        print(f"Loaded dictionary prototypes of shape: {protos.shape}")
        return protos.to(device)

    if "protos" in blob:
        return blob["protos"].to(device)

    # No recognised key: report what the file actually contains.
    raise KeyError(f"Could not find 'prototypes' key in dictionary. Available keys: {list(blob.keys())}")

## Create Prototype

In [7]:


def create_and_save_prototypes(model, train_ds, save_path, batch_size=32, num_workers=0):
    """Build one L2-normalised prototype per class and save them to ``save_path``.

    Only samples whose "domain" value is 0 (herbarium) contribute. A class
    prototype is the mean embedding of its samples, normalised to unit
    length; classes with no samples keep an all-zero prototype.

    Embeddings are taken from ``model.encode`` when the model has one and
    from the plain forward pass otherwise. This is the same rule
    ``evaluate_model_with_prototypes`` applies, so the saved prototypes live
    in the space the evaluator compares against (a classifier-style model's
    forward pass returns logits, not embeddings).

    Args:
        model: Embedding model; switched to eval mode.
        train_ds: Dataset yielding dicts with "image", "label" and "domain".
        save_path: Destination file (``str`` or ``Path``); missing parent
            folders are created and an existing file is overwritten.
        batch_size: Batch size used while embedding the dataset.
        num_workers: DataLoader worker count. Defaults to 0, matching the
            other loaders in this notebook (worker processes crashed on
            Windows).

    Raises:
        RuntimeError: If no herbarium sample is found in ``train_ds``.
    """
    model.eval()
    save_path = Path(save_path)

    # Use the embedding head when the model distinguishes it from forward().
    embed = model.encode if hasattr(model, "encode") else model

    # 1. Build Prototypes from Train (Herbarium samples only)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    proto_sum = None  # allocated lazily, once the embedding width is known
    proto_count = torch.zeros(NUM_CLASSES, dtype=torch.long, device=device)

    print(f"Building Prototypes from {len(train_ds)} samples...")
    with torch.no_grad():
        for batch in tqdm(train_loader, desc="Prototypes"):
            domains = batch["domain"].to(device)

            # Domain 0 is Herbarium
            mask = (domains == 0)
            if not mask.any():
                continue

            imgs = batch["image"].to(device)
            labels = batch["label"].to(device)

            emb = embed(imgs[mask])
            lbls = labels[mask].long()

            if proto_sum is None:
                proto_sum = torch.zeros(NUM_CLASSES, emb.shape[1], device=device)

            # Accumulate per-class sums and counts in one call each instead of
            # looping over samples in Python.
            proto_sum.index_add_(0, lbls, emb.to(proto_sum.dtype))
            proto_count += torch.bincount(lbls, minlength=NUM_CLASSES)

    # Check if we actually found herbarium samples
    if proto_sum is None:
        raise RuntimeError("No Herbarium images found in train_ds! Check dataset loading.")

    # Mean then L2-normalise, only for classes that received samples.
    populated = proto_count > 0
    prototypes = torch.zeros_like(proto_sum)
    class_means = proto_sum[populated] / proto_count[populated].unsqueeze(1)
    prototypes[populated] = F.normalize(class_means, p=2, dim=-1)

    # 2. Save Prototypes (always overwrites, so the call forces a refresh)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    save_prototypes(prototypes, proto_count, save_path)

    print(f"Prototypes saved to {save_path}")



## Evaluation Setup

In [10]:
def evaluate_model_with_prototypes(model, test_ds, prototypes, with_set, without_set, batch_size=32):
    """Score nearest-prototype classification on ``test_ds``.

    Each image is embedded with ``model.encode`` when the model provides it,
    otherwise with the plain forward pass, and assigned the class whose
    prototype has the largest dot product with the embedding.

    Args:
        model: Embedding model; switched to eval mode.
        test_ds: Dataset yielding dicts with "image" and "label".
        prototypes: Tensor of class prototypes, one row per class.
        with_set: Labels counted in the "With-Pair" split.
        without_set: Labels counted in the "Without-Pair" split (only checked
            for labels that are not in ``with_set``).
        batch_size: Batch size for the evaluation loader.

    Returns:
        Dict with "Overall", "With-Pair" and "Without-Pair" accuracies; a
        split with no samples reports 0.
    """
    model.eval()
    prototypes = prototypes.to(device)

    # Single-process loading: worker processes crashed on Windows.
    loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)

    split_names = ("all", "with", "without")
    hits = dict.fromkeys(split_names, 0)
    seen = dict.fromkeys(split_names, 0)

    # Models exposing encode() return logits from forward(), so prefer encode().
    embed = model.encode if hasattr(model, 'encode') else model

    print("Testing with loaded prototypes...")
    with torch.no_grad():
        for batch in tqdm(loader, desc="Testing"):
            imgs = batch["image"].to(device)
            labels = batch["label"].to(device)

            # Dot product against every prototype; the best match wins.
            preds = (embed(imgs) @ prototypes.T).argmax(dim=1)

            for pred, target in zip(preds.tolist(), labels.tolist()):
                # Every sample counts overall, plus at most one of the splits.
                splits = ["all"]
                if target in with_set:
                    splits.append("with")
                elif target in without_set:
                    splits.append("without")

                hit = int(pred == target)
                for name in splits:
                    seen[name] += 1
                    hits[name] += hit

    def accuracy(name):
        return hits[name] / seen[name] if seen[name] > 0 else 0

    return {
        "Overall": accuracy("all"),
        "With-Pair": accuracy("with"),
        "Without-Pair": accuracy("without"),
    }

## Running Models' Evaluation

In [11]:
# Checkpoints and prototype files for both models are expected side by side in
# the ensemble folder (adjust if your folder structure is different).
ensemble_dir = PROJECT_ROOT / "experiments" / "ensemble"
path_m = ensemble_dir / "metric.pt"
save_path_m = ensemble_dir / "prototype_metric.pt"
path_tr = ensemble_dir / "triplet.pt"
save_path_tr = ensemble_dir / "prototypes_triplet.pt"

# ===================================
# --- 1. Evaluate Metric Learning ---
# ===================================

model_metric = TripletEncoder(embed_dim=512).to(device)
state = torch.load(path_m, map_location=device)
# Training checkpoints wrap the weights under 'model_state'; bare state dicts are used as-is.
state = state['model_state'] if 'model_state' in state else state
model_metric.load_state_dict(state, strict=True)

# Prototypes are built only once; later runs reuse the saved file.
if not save_path_m.exists():
    create_and_save_prototypes(model_metric, train_ds, save_path_m)

prototypes_2s = load_prototypes(save_path_m)
res_2s = evaluate_model_with_prototypes(
    model_metric, test_ds, prototypes_2s, with_set, without_set
)
print(f"Metric Learning Results: {res_2s}")

# ===================================
# -------- 2. Evaluate Triplet ------
# ===================================

model_triplet = DinoTriplet(num_classes=NUM_CLASSES).to(device)
state = torch.load(path_tr, map_location=device)
state = state['model_state'] if 'model_state' in state else state
model_triplet.load_state_dict(state, strict=True)

# Same build-once policy as for the metric model.
if not save_path_tr.exists():
    create_and_save_prototypes(model_triplet, train_ds, save_path_tr)

prototypes_tr = load_prototypes(save_path_tr)
res_tr = evaluate_model_with_prototypes(
    model_triplet, test_ds, prototypes_tr, with_set, without_set
)
print(f"Triplet Results: {res_tr}")




Loading prototypes from c:\Users\William\School\Swinburne\Computer Science\2025 Semester 2\COS30082 Applied Machine Learning\Assignment 2\Approach3\experiments\ensemble\prototype_metric.pt...
Testing with loaded prototypes...


Testing: 100%|██████████| 7/7 [00:13<00:00,  1.90s/it]


Metric Learning Results: {'Overall': 0.7101449275362319, 'With-Pair': 0.7777777777777778, 'Without-Pair': 0.5185185185185185}
Loading prototypes from c:\Users\William\School\Swinburne\Computer Science\2025 Semester 2\COS30082 Applied Machine Learning\Assignment 2\Approach3\experiments\ensemble\prototypes_triplet.pt...
Testing with loaded prototypes...


Testing: 100%|██████████| 7/7 [00:10<00:00,  1.57s/it]

Triplet Results: {'Overall': 0.7681159420289855, 'With-Pair': 0.9607843137254902, 'Without-Pair': 0.2222222222222222}





# Ensemble Predictor

In [16]:
def run_inference(threshold, return_top5=False):
    """Predict every test sample with the specialist/generalist ensemble.

    For each image:
    1. ``model_triplet`` (the specialist) proposes the class whose prototype
       in ``prototypes_tr`` best matches its ``encode()`` embedding.
    2. ``model_metric`` (the generalist) embeds the same image and the
       Euclidean distance to its own prototype (``prototypes_2s``) of the
       proposed class is measured.
    3. Distance below ``threshold``: the specialist's class is kept.
    4. Otherwise: the generalist's best-matching prototype is used instead.

    Args:
        threshold: Distance below which the specialist's proposal is trusted.
        return_top5: When true, also return the five best classes per sample,
            taken from whichever model produced the final prediction.

    Returns:
        ``(predictions, targets)``, or ``(predictions, targets, top5)`` when
        ``return_top5`` is set.
    """
    specialist = model_triplet.eval()   # proposes the candidate class
    generalist = model_metric.eval()    # verifies it or takes over

    # Single-process loading: worker processes crashed on Windows.
    loader = DataLoader(test_ds, batch_size=32, shuffle=False, num_workers=0)

    hybrid_preds, hybrid_top5, targets = [], [], []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Hybrid Inference"):
            imgs = batch["image"].to(device)
            lbls = batch["label"].to(device)
            targets.extend(lbls.cpu().numpy())

            # Specialist: embeddings from encode(), scored against its own prototypes.
            sims_tr = specialist.encode(imgs) @ prototypes_tr.T
            candidates = sims_tr.argmax(dim=1)
            top5_tr = sims_tr.topk(5, dim=1)[1]

            # Generalist: its forward pass already yields embeddings.
            emb_2s = generalist(imgs)

            # Distance from each generalist embedding to the generalist
            # prototype of the class the specialist proposed.
            dists = torch.norm(emb_2s - prototypes_2s[candidates], dim=1)

            for row, (dist, candidate) in enumerate(zip(dists.tolist(), candidates.tolist())):
                if dist < threshold:
                    # Close enough: keep the specialist's answer.
                    hybrid_preds.append(candidate)
                    if return_top5:
                        hybrid_top5.append(top5_tr[row].cpu().tolist())
                    continue

                # Too far: fall back to the generalist's best-matching prototype.
                sims_2s = emb_2s[row].unsqueeze(0) @ prototypes_2s.T
                hybrid_preds.append(sims_2s.argmax().item())
                if return_top5:
                    hybrid_top5.append(sims_2s.topk(5, dim=1)[1][0].cpu().tolist())

    if not return_top5:
        return hybrid_preds, targets
    return hybrid_preds, targets, hybrid_top5

## Tuning Hybrid Threshold

In [17]:
# --- Tuning Hybrid Threshold ---
print("\n--- Tuning Hybrid Threshold ---")

# Candidate distance thresholds.
# Lower  = strict, the generalist decides more often.
# Higher = loose, the specialist is trusted more often.
# A coarser sweep (0.3 ... 1.2 in steps of 0.1) was used before this fine grid.
# NOTE(review): run_inference scores test_ds, so the threshold is selected on the
# same data the final numbers are reported on; a held-out split would avoid the
# optimistic bias.
thresholds = [0.90, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99, 1.00]

best_acc, best_threshold = 0.0, 0.0

for th in thresholds:
    preds, tgts = run_inference(th)
    correct = sum(int(p == t) for p, t in zip(preds, tgts))
    acc = correct / len(tgts)
    print(f"Threshold {th:.2f} -> Overall Acc: {acc:.4f}")

    # Strict comparison: on ties the earliest (smallest) threshold is kept.
    if acc <= best_acc:
        continue
    best_acc, best_threshold = acc, th

print("Best threshold:", best_threshold, "Acc:", best_acc)


--- Tuning Hybrid Threshold ---


Hybrid Inference: 100%|██████████| 7/7 [00:19<00:00,  2.75s/it]


Threshold 0.90 -> Overall Acc: 0.7923


Hybrid Inference: 100%|██████████| 7/7 [00:19<00:00,  2.73s/it]


Threshold 0.91 -> Overall Acc: 0.7971


Hybrid Inference: 100%|██████████| 7/7 [00:19<00:00,  2.73s/it]


Threshold 0.92 -> Overall Acc: 0.7971


Hybrid Inference: 100%|██████████| 7/7 [00:19<00:00,  2.73s/it]


Threshold 0.93 -> Overall Acc: 0.8068


Hybrid Inference: 100%|██████████| 7/7 [00:19<00:00,  2.75s/it]


Threshold 0.94 -> Overall Acc: 0.8068


Hybrid Inference: 100%|██████████| 7/7 [00:19<00:00,  2.81s/it]


Threshold 0.95 -> Overall Acc: 0.8068


Hybrid Inference: 100%|██████████| 7/7 [00:19<00:00,  2.74s/it]


Threshold 0.96 -> Overall Acc: 0.8068


Hybrid Inference: 100%|██████████| 7/7 [00:19<00:00,  2.74s/it]


Threshold 0.97 -> Overall Acc: 0.8068


Hybrid Inference: 100%|██████████| 7/7 [00:19<00:00,  2.73s/it]


Threshold 0.98 -> Overall Acc: 0.8068


Hybrid Inference: 100%|██████████| 7/7 [00:19<00:00,  2.74s/it]


Threshold 0.99 -> Overall Acc: 0.8068


Hybrid Inference: 100%|██████████| 7/7 [00:19<00:00,  2.73s/it]

Threshold 1.00 -> Overall Acc: 0.7923
Best threshold: 0.93 Acc: 0.8067632850241546





In [18]:
# --- Final Detailed Evaluation ---
# Re-runs the ensemble at the tuned threshold and reports Top-1 / Top-5
# accuracy overall and for the with-pair / without-pair class splits.
# best_threshold = 0.93
print(f"\nRunning Final Breakdown with Threshold {best_threshold}...")
final_preds, final_targets, final_top5 = run_inference(best_threshold, return_top5=True)

# Manual Breakdown Calculation (Top-1 and Top-5)
stats = {
    k: {"correct1": 0, "correct5": 0, "total": 0}
    for k in ["all", "with", "without"]
}

for p, t, top5 in zip(final_preds, final_targets, final_top5):
    is_correct1 = (p == t)
    is_correct5 = (t in top5)

    # Every sample counts overall, plus at most one of the two class splits.
    splits = ["all"]
    if t in with_set:
        splits.append("with")
    elif t in without_set:
        splits.append("without")

    for split in splits:
        stats[split]["total"] += 1
        if is_correct1: stats[split]["correct1"] += 1
        if is_correct5: stats[split]["correct5"] += 1


def _split_accuracy(hits, total):
    """Return hits / total, or 0.0 for a split that received no samples."""
    return hits / total if total > 0 else 0.0


# Guarded ratios: an empty split reports 0.0 instead of raising
# ZeroDivisionError, matching the other evaluation helpers in this notebook.
results = {}
for split, title in (("all", "Overall"), ("with", "With-Pair"), ("without", "Without-Pair")):
    counts = stats[split]
    results[f"{title} Top-1"] = _split_accuracy(counts["correct1"], counts["total"])
    results[f"{title} Top-5"] = _split_accuracy(counts["correct5"], counts["total"])

print("Hybrid Results:")
for metric, value in results.items():
    print(f"  {metric}: {value:.2%}")


Running Final Breakdown with Threshold 0.93...


Hybrid Inference: 100%|██████████| 7/7 [00:19<00:00,  2.77s/it]

Hybrid Results:
  Overall Top-1: 80.68%
  Overall Top-5: 89.37%
  With-Pair Top-1: 94.77%
  With-Pair Top-5: 97.39%
  Without-Pair Top-1: 40.74%
  Without-Pair Top-5: 66.67%



