In [260]:
import os
#import timm
import torch
import time
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
from torchvision.models.vision_transformer import vit_b_16, ViT_B_16_Weights
from transformers import ViTMAEModel, AutoFeatureExtractor
from tqdm import tqdm
import matplotlib.pyplot as plt
from animal_dataset import AnimalCLEFDataset  # now importable

from transformers import ViTMAEForPreTraining, ViTMAEConfig

In [261]:
# =============================
# Configurations
# =============================
BATCH_SIZE = 32
NUM_WORKERS = 0
IMAGE_SIZE = 224
LR = 1e-4

epoch = 15
EPOCHS_CLP = epoch
EPOCHS_PLF = epoch

LAMBDA = 0.3
ALPHA_CL = 1#.2
ALPHA_REC = 1#1.8

FORCE_TRAIN_RESTART_CLP = True
FORCE_TRAIN_RESTART_PLF = True

# NEW: what percentile of your database‐scores to use as the open‐set cutoff
OPENSET_PERCENTILE = 0.60  # try 0.50 (median), or 0.75, etc.

#DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINT_DIR = './checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

root = '../animal-clef-2025_data'

# reproducibility
torch.manual_seed(42)

<torch._C.Generator at 0x113329c30>

In [262]:
# Set device, Mac, GPU, or CPU
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print("Using device:", DEVICE)

Using device: mps


In [263]:
"""
# =============================
# Dataset
# =============================
class AnimalCLEFDataset(Dataset):
    def __init__(self, root, split="database", transform=None):
        self.root = root.rstrip('/')
        meta = pd.read_csv(f"{self.root}/metadata.csv")
        sel = meta[meta['path'].str.contains(f"/{split}/")].reset_index(drop=True)
        if sel.empty:
            raise ValueError(f"No entries for split '{split}'")

        self.paths = sel['path'].tolist()
        self.image_ids = sel['image_id'].tolist()

        if split == 'database':
            #  Use individual identity,  
            ids = sel['identity'].astype(str)

            #  Build mapping from identity string → label index
            self.id2idx = {iid: i for i, iid in enumerate(sorted(ids.unique()))}

            #  Map each sample's identity to its label
            self.labels = ids.map(self.id2idx).tolist()

            # Safety check
            num_classes = len(self.id2idx)
            assert all(0 <= label < num_classes for label in self.labels), "Invalid labels found"
        else:
            self.labels = [-1] * len(sel)

        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        img = Image.open(f"{self.root}/{self.paths[i]}").convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, self.labels[i]
""";

In [264]:
# =============================
# MAE Encoder + Projection Head + Decoder
# =============================
class MAEFramework(nn.Module):
    def __init__(self,
                 model_name: str = "facebook/vit-mae-base",
                 proj_dim:   int   = 256,
                 mask_ratio: float = 0.75):
        super().__init__()
        # 1) encoder-only ViT‐MAE
        self.encoder = ViTMAEModel.from_pretrained(model_name)
        # 2) full MAE (encoder+decoder) for reconstruction
        self.mae = ViTMAEForPreTraining.from_pretrained(model_name)
        self.mae.config.mask_ratio = mask_ratio

        # projection head on top of the [CLS]-token
        self.embed_dim = self.encoder.config.hidden_size
        self.proj_head = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dim, proj_dim),
        )

    def forward(self, pixel_values: torch.Tensor, return_feats: bool = False):
        # ---  A) contrastive backbone  ---
        enc_out = self.encoder(
            pixel_values=pixel_values,
            return_dict=True,
            output_hidden_states=True        # ← request all layer outputs
        )
        cls_feat = enc_out.last_hidden_state[:, 0, :]   # (B, hidden_size)
        proj     = self.proj_head(cls_feat)             # (B, proj_dim)

        # ---  B) reconstruction head  ---
        pre_out  = self.mae(pixel_values=pixel_values, return_dict=True)
        rec_loss = pre_out.loss                         # scalar
        rec      = self.mae.unpatchify(pre_out.logits)  # (B,3,H,W)

        if return_feats:
            # drop the patch‐embeddings and keep only the block outputs
            feats = list(enc_out.hidden_states[1:])    # list of [B,seq_len,D]
            return cls_feat, proj, rec_loss, rec, feats

        return proj, rec_loss, rec



"""class MAEFramework(nn.Module):
    def __init__(self,
                 model_name: str = "facebook/vit-mae-base",
                 proj_dim: int = 256,
                 mask_ratio: float = 0.75):
        super().__init__()
        # 1) Load a Vision‐MAE model (encoder + decoder, pretrained weights available)
        self.mae = ViTMAEForPreTraining.from_pretrained(model_name)
        # you can override mask ratio if you like
        self.mae.config.mask_ratio = mask_ratio

        # 2) Projection head (same as before)
        self.embed_dim = self.mae.config.hidden_size  # typically 768
        self.proj_head = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dim, proj_dim),
        )

    def forward(self, pixel_values: torch.Tensor, return_feats: bool = False):
        
        #pixel_values: Tensor of shape (B,3,H,W), normalized to [0,1]
        
        # the model itself will sample a mask (mask_ratio),
        # run encoder → decoder → unpatchify → compute reconstruction loss
        outputs = self.mae(pixel_values=pixel_values, return_dict=True)

        # HuggingFace returns:
        #  - loss:   the MSE on masked patches averaged over patches & pixels
        #  - predicted_pixel_values: the full (B,3,H,W) reconstruction
        rec_loss = outputs.loss                # (scalar, averaged over batch)
        rec      = outputs.reconstructed_pixel_values

        # take CLS token from encoder (before decoder)
        # note: HF stores the last hidden state of the encoder in .encoder_last_hidden_state
        cls_feat = outputs.encoder_last_hidden_state[:, 0, :]
        proj     = self.proj_head(cls_feat)

        if return_feats:
            # we don’t extract intermediate “layer_indices” feats here
            return cls_feat, proj, rec_loss, rec, []
        return proj, rec_loss, rec
""";

"""
class MAEFramework(nn.Module):
    def __init__(self,
                 embed_dim: int = 768,
                 proj_dim: int = 256,
                 decoder_dim: int = 256,
                 layer_indices: list[int] = [3, 6, 9]):
        super().__init__()
        self.embed_dim = embed_dim
        # 1) Backbone ViT
        self.encoder = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        self.encoder.head = nn.Identity()
        self.layer_indices = set(layer_indices)

        # 3) Projection head
        self.proj_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dim, proj_dim),
        )
        # 4) Decoder
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, decoder_dim),
            nn.ReLU(inplace=True),
            nn.Linear(decoder_dim, 3 * IMAGE_SIZE * IMAGE_SIZE),
        )

    def forward(self, x: torch.Tensor, return_feats: bool = False):
        B = x.size(0)
        # patch embed
        x_p = self.encoder.conv_proj(x)
        x_p = x_p.flatten(2).transpose(1, 2)
        cls_tok = self.encoder.class_token.expand(B, -1, -1)
        tokens = torch.cat([cls_tok, x_p], dim=1)
        tokens = tokens + self.encoder.encoder.pos_embedding

        feats = []
        for idx, block in enumerate(self.encoder.encoder.layers):
            tokens = block(tokens)
            if idx in self.layer_indices:
                feats.append(tokens.clone())

        cls_feat = self.encoder.encoder.ln(tokens[:, 0])
        proj = self.proj_head(cls_feat)
        rec = self.decoder(cls_feat).view(B, 3, IMAGE_SIZE, IMAGE_SIZE)
        rec_loss = F.mse_loss(rec, x, reduction='none').mean([1, 2, 3])

        if return_feats:
            return cls_feat, proj, rec_loss, rec, feats
        return proj, rec_loss, rec
""";

In [265]:
# =============================
# Losses
# =============================
class SupConLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super().__init__()
        self.temp = temperature

    def forward(self, p1, p2, labels=None):
        # NT-Xent instance-level contrastive loss
        z = torch.cat([p1, p2], dim=0)
        z = F.normalize(z, dim=1)
        N = p1.size(0)
        sim = torch.matmul(z, z.T) / self.temp
        mask = torch.eye(2*N, device=sim.device).bool()
        sim.masked_fill_(mask, -9e15)
        idx = torch.arange(N, device=sim.device)
        targets = torch.cat([idx + N, idx])
        return F.cross_entropy(sim, targets)

class ProtoLoss(nn.Module):
    def forward(self, feats, prototypes, labels):
        dist = torch.cdist(feats, prototypes)
        return F.cross_entropy(-dist, labels)

In [266]:
# =============================
# Utilities
# =============================
def compute_layer_distances(bef_feats, aft_feats, temperature=0.5):
    total = []  # Initialize as a list to store distances per layer
    for b, a in zip(bef_feats, aft_feats):
        b_f, a_f = b.flatten(1), a.flatten(1)
        eu = F.pairwise_distance(b_f, a_f)
        cos = 1 - F.cosine_similarity(b_f, a_f, dim=1)
        score = eu + temperature * cos  # this score is already (B,)
        total.append(score)  # Add the current batch's score to the list
    return torch.stack(total, dim=0).mean(dim=0)  # Stack and compute mean along dim=0

def calculate_unknown_score(feat_bef, feat_aft, feat_vec, prototypes, lamda=1.0):
    # 1. Compute multilayer feature distance (s_total)
    s_total = compute_layer_distances(feat_bef, feat_aft, temperature=0.5)  # Can keep temp fixed or expose as param

    # 2. Compute max prototype similarity (s_prototypes)
    fv_n = F.normalize(feat_vec, dim=1)  # (B, D)
    p_n = F.normalize(prototypes, dim=1)  # (C, D)
    sim = torch.matmul(fv_n, p_n.T)  # (B, C)
    s_proto, _ = sim.max(dim=1)  # (B,)

    # 3. Final score using lambda
    score = s_proto - lamda * s_total

    return score  # (B,)

In [267]:
def train_CLP(model, loader, epochs, lr, ckpt, alpha_cl=1.0, alpha_rec=1.0, force_restart=False):
    """
    Contrastive‐plus‐reconstruction pre-training, but with
    supervised CE on the CLS token instead of instance‐level SupCon.
    """
    cl_hist, rec_hist, tot_hist = [], [], []
    path = os.path.join(CHECKPOINT_DIR, ckpt)

    # --- NEW: figure out how many classes we have in the 'database' split
    num_classes = len(loader.dataset.id2idx)

    # --- NEW: add a tiny classification head on CLS
    clf_head = nn.Linear(model.embed_dim, num_classes).to(DEVICE)

    # --- NEW: optimizer now includes both the ViT parameters AND the new head
    opt = torch.optim.Adam(
        list(model.parameters()) + list(clf_head.parameters()),
        lr=lr
    )
    start = 0

    if os.path.exists(path) and force_restart:
        print(f"[train_CLP] Restarting from scratch (deleting {path})")
        os.remove(path)

    # --- load checkpoint as before (we assume it saved only model & opt & ep) ---
    if os.path.exists(path):
        ck     = torch.load(path, map_location=DEVICE)
        raw_sd = ck['model']
        sd     = {k.replace("module.", ""): v for k,v in raw_sd.items()}
        model.load_state_dict(sd)
        clf_head.load_state_dict(ck['clf_head'])          # NEW: restore classifier
        opt.load_state_dict(ck['opt'])
        start = ck['ep'] + 1

    model.to(DEVICE)
    clf_head.to(DEVICE)

    for ep in range(start, epochs):
        model.train()
        total_ce, total_rec, total = 0.0, 0.0, 0.0

        for x, y in tqdm(loader, desc=f"CLP Ep{ep+1}"):
            x, y = x.to(DEVICE), y.to(DEVICE)

            # two‐view reconstructions for the MAE
            x1 = x + 0.05*torch.randn_like(x)
            x2 = x + 0.05*torch.randn_like(x)

            # forward‐pass with features
            cls_feat, _, r1_loss, _, _ = model(x1, return_feats=True)
            _,       _, r2_loss, _, _ = model(x2, return_feats=True)

            # --- NEW: supervised cross‐entropy on the CLS embedding
            logits = clf_head(cls_feat)               # (B, num_classes)
            l_ce   = F.cross_entropy(logits, y)       # standard CE

            # reconstruction loss (unchanged)
            l_rec  = r1_loss.mean() + r2_loss.mean()

            loss = alpha_cl * l_ce + alpha_rec * l_rec

            opt.zero_grad()
            loss.backward()
            opt.step()

            total     += loss.item()
            total_ce  += l_ce.item()
            total_rec += l_rec.item()

        # --- save checkpoint (model, classifier, optimizer, epoch)
        torch.save({
            'model':     model.module.state_dict() if isinstance(model, nn.DataParallel)
                                               else model.state_dict(),
            'clf_head': clf_head.state_dict(),         # NEW
            'opt':       opt.state_dict(),
            'ep':        ep
        }, path)

        avg_ce  = total_ce  / len(loader)
        avg_rec = total_rec / len(loader)
        avg_tot = total   / len(loader)

        cl_hist.append(avg_ce)
        rec_hist.append(avg_rec)
        tot_hist.append(avg_tot)

        print(f"CLP Epoch {ep+1}: total={avg_tot:.4f}, CE={avg_ce:.4f}, rec={avg_rec:.4f}")

    return model, cl_hist, rec_hist, tot_hist


In [268]:
"""def train_CLP(model, loader, epochs, lr, ckpt, alpha_cl=1.0, alpha_rec=1.0):
    
    path = os.path.join(CHECKPOINT_DIR, ckpt)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    start = 0

    # --- load checkpoint if it exists, stripping "module." if necessary
    if os.path.exists(path):
        ck = torch.load(path, map_location=DEVICE)
        raw_sd = ck['model']
        # strip DataParallel "module." prefix
        sd = {k.replace("module.", ""): v for k, v in raw_sd.items()}
        model.load_state_dict(sd)
        opt.load_state_dict(ck['opt'])
        start = ck['ep'] + 1
    # ---------------------------------------------------------------

    scl = SupConLoss()
    model.to(DEVICE)

    for ep in range(start, epochs):
        model.train()
        total = 0.0
        for x, y in tqdm(loader, desc=f"CLP Ep{ep+1}"):
            x, y = x.to(DEVICE), y.to(DEVICE)
            x1 = x + 0.05 * torch.randn_like(x)
            x2 = x + 0.05 * torch.randn_like(x)
            f1, p1, r1_loss, r1, _ = model(x1, return_feats=True)
            f2, p2, r2_loss, r2, _ = model(x2, return_feats=True)
            l_cl  = scl(p1, p2)
            l_rec = r1_loss.mean() + r2_loss.mean()
            loss  = alpha_cl * l_cl + alpha_rec * l_rec
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item()

        torch.save({
            'model': model.state_dict() if not isinstance(model, nn.DataParallel)
                              else model.module.state_dict(),
            'opt':   opt.state_dict(),
            'ep':    ep
        }, path)
        print(f"CLP Epoch {ep+1}: {total/len(loader):.4f}")
        """;


In [269]:
"""def train_PLF(model, loader, epochs, lr, ckpt, num_classes, force_restart = False):
    path = os.path.join(CHECKPOINT_DIR, ckpt)
    
        # if asked, delete the old checkpoint
    if os.path.exists(path) and force_restart:
        print(f"Checkpoint {path} exists. Restarting PLF training.")
        os.remove(path)
    
    # unwrap if DataParallel
    base_model = model.module if isinstance(model, nn.DataParallel) else model
    embed_dim   = base_model.embed_dim

    proto_tensor = torch.randn(num_classes, embed_dim, device=DEVICE)
    proto = nn.Parameter(proto_tensor, requires_grad=True)
    opt   = torch.optim.Adam(list(model.parameters()) + [proto], lr=lr)
    start = 0

    # --- load checkpoint if it exists, stripping "module." if necessary
    if os.path.exists(path):
        ck = torch.load(path, map_location=DEVICE)
        raw_sd = ck['model']
        sd = {k.replace("module.", ""): v for k, v in raw_sd.items()}
        model.load_state_dict(sd)
        proto.data = ck['proto']
        opt.load_state_dict(ck['opt'])
        start = ck['ep'] + 1
    # ---------------------------------------------------------------

    ploss = ProtoLoss()
    model.to(DEVICE)

    for ep in range(start, epochs):
        model.train()
        total = 0.0
        for x, y in tqdm(loader, desc=f"PLF Ep{ep+1}"):
            x, y = x.to(DEVICE), y.to(DEVICE)
            feats, _, _, _, _ = model(x, return_feats=True)
            loss = ploss(feats, proto, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item()

        torch.save({
            'model': model.state_dict() if not isinstance(model, nn.DataParallel)
                              else model.module.state_dict(),
            'proto': proto.data,
            'opt':   opt.state_dict(),
            'ep':    ep
        }, path)
        print(f"PLF Epoch {ep+1}: {total/len(loader):.4f}")

    return proto
""";

# =============================
# Prototype‐only training
# =============================
def train_PLF(model, loader, epochs, lr, ckpt, num_classes, force_restart=False, unfrozen_encoder=False):
    """
    Train only the prototypes and the proj_head parameters.
    """
    path = os.path.join(CHECKPOINT_DIR, ckpt)

    # 1) create fresh prototypes
    proto = nn.Parameter(torch.randn(num_classes, model.embed_dim, device=DEVICE))
    
    # 2) build optimizer params: always proj_head + proto,
    #    and optionally any encoder layers you’ve unfrozen
    params = list(model.proj_head.parameters()) + [proto]
    if unfrozen_encoder:
        
        # include only those encoder params that we explicitly un-froze
        params += [p for p in model.encoder.parameters() if p.requires_grad]
    opt = torch.optim.Adam(params, lr=lr)

    """
    # 2) optimizer over just proj_head + proto
    opt = torch.optim.Adam(
        list(model.proj_head.parameters()) + [proto],
        lr=lr
    )
    """

    start = 0
    if os.path.exists(path) and force_restart:
        os.remove(path)
    if os.path.exists(path):
        ck = torch.load(path, map_location=DEVICE)
        proto.data = ck['proto']
        model.proj_head.load_state_dict(ck['proj_head'])
        opt.load_state_dict(ck['opt'])
        start = ck['ep'] + 1

    ploss = ProtoLoss()
    model.to(DEVICE)

    for ep in range(start, epochs):
        model.train()
        total = 0.0
        for x, y in tqdm(loader, desc=f"PLF Ep{ep+1}"):
            x, y = x.to(DEVICE), y.to(DEVICE)
            # forward through *frozen* encoder + proj_head
            feats, _, _, _, _ = model(x, return_feats=True)
            loss = ploss(feats, proto, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item()

        # save checkpoint (prototypes, proj_head, optimizer, epoch)
        torch.save({
            'proto':     proto.data,
            'proj_head': model.proj_head.state_dict(),
            'opt':       opt.state_dict(),
            'ep':        ep
        }, path)

        print(f"PLF Epoch {ep+1}: {total/len(loader):.4f}")

    return proto




In [270]:
"""
# =============================
# Training: CLP
# =============================
def train_CLP(model, loader, epochs, lr, ckpt, alpha_cl=1.0, alpha_rec=1.0):
    path = os.path.join(CHECKPOINT_DIR, ckpt)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    start = 0
    if os.path.exists(path):
        ck = torch.load(path)
        model.load_state_dict(ck['model'])
        opt.load_state_dict(ck['opt'])
        start = ck['ep'] + 1
    scl = SupConLoss()
    model.to(DEVICE)

    for ep in range(start, epochs):
        model.train()
        total = 0.0
        for x, y in tqdm(loader, desc=f"CLP Ep{ep+1}"):
            x, y = x.to(DEVICE), y.to(DEVICE)
            x1 = x + 0.05 * torch.randn_like(x)
            x2 = x + 0.05 * torch.randn_like(x)
            f1, p1, r1_loss, r1, _ = model(x1, return_feats=True)
            f2, p2, r2_loss, r2, _ = model(x2, return_feats=True)
            l_cl = scl(p1, p2)
            l_rec = r1_loss.mean() + r2_loss.mean()
            loss = alpha_cl * l_cl + alpha_rec * l_rec
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item()
        torch.save({'model': model.state_dict(), 'opt': opt.state_dict(), 'ep': ep}, path)
        print(f"CLP Epoch {ep+1}: {total/len(loader):.4f}")


# =============================
# Training: PLF
# =============================
def train_PLF(model, loader, epochs, lr, ckpt, num_classes, encoder_pre):
    path = os.path.join(CHECKPOINT_DIR, ckpt)
    # unwrap if DataParallel
    base_model = model.module if isinstance(model, nn.DataParallel) else model
    embed_dim = base_model.embed_dim
    proto_tensor = torch.randn(num_classes, embed_dim, device=DEVICE)
    proto = nn.Parameter(proto_tensor, requires_grad=True)
    opt = torch.optim.Adam(list(model.parameters()) + [proto], lr=lr)
    start = 0
    if os.path.exists(path):
        ck = torch.load(path)
        model.load_state_dict(ck['model'])
        proto.data = ck['proto']
        opt.load_state_dict(ck['opt'])
        start = ck['ep'] + 1
    ploss = ProtoLoss()
    model.to(DEVICE)

    for ep in range(start, epochs):
        model.train()
        total = 0.0
        for x, y in tqdm(loader, desc=f"PLF Ep{ep+1}"):
            x, y = x.to(DEVICE), y.to(DEVICE)
            feats, _, _, _, _ = model(x, return_feats=True)
            loss = ploss(feats, proto, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item()
        torch.save({'model': model.state_dict(), 'proto': proto.data, 'opt': opt.state_dict(), 'ep': ep}, path)
        print(f"PLF Epoch {ep+1}: {total/len(loader):.4f}")
    return proto
""";
    
# =============================
# Inference
# =============================

@torch.no_grad()
def inference(enc_pre, model, proto, loader, threshold, lamda=1.0):
    enc_pre.eval()
    model.eval()
    preds = []

    for x, _ in tqdm(loader, desc="Infer"):
        x = x.to(DEVICE)
        enc_pre = enc_pre.to(DEVICE)
        model = model.to(DEVICE)

        _, _, _, _, bef_feats = enc_pre(x, return_feats=True)
        feat, _, _, _, aft_feats = model(x, return_feats=True)

        # 1. Calculate unknown score
        scores = calculate_unknown_score(bef_feats, aft_feats, feat, proto, lamda = lamda)  # shape: (B,)

        # 2. Predict the most similar known class
        idx = torch.argmax(torch.matmul(F.normalize(feat, dim=1), F.normalize(proto, dim=1).T), dim=1)  # (B,)

        # 3. Threshold to filter unknowns
        known = scores > threshold
        pred = idx.clone()
        pred[~known] = -1  # mark unknowns

        preds.append(pred.cpu())

    return torch.cat(preds, dim=0)

In [271]:
# =============================
# Submission
# =============================
def generate_submission(root, preds, db_ds):
    sub = pd.read_csv(f"{root}/sample_submission.csv")
    meta = pd.read_csv(f"{root}/metadata.csv")
    q = meta[meta['path'].str.contains('/query/')].reset_index(drop=True)
    q['pred_idx'] = preds.numpy()

    # Step 1: Print raw predicted indices
    print("\nRaw predicted indices (pred_idx):")
    print(q[['image_id', 'pred_idx']].head(10))

    # Step 2: Index to identity mapping
    idx2id = {v: k for k, v in db_ds.id2idx.items()}
    q['prediction'] = q['pred_idx'].apply(lambda i: 'new_individual' if i < 0 else idx2id.get(int(i), f"unknown_{i}"))

    # Step 3: Print mapped predictions
    print("\nMapped predictions (after idx2id):")
    print(q[['image_id', 'prediction']].head(10))

    # drop the original metadata identity so we don't end up with two columns
    q = q.drop(columns=['identity'])

    # rename 'prediction' → 'identity' so the output column is called identity
    q = q.rename(columns={'prediction':'identity'})

    # now merge on image_id and keep the 'identity' column
    out = sub[['image_id']].merge(q[['image_id','identity']], on='image_id')

    #out = sub[['image_id']].merge(q[['image_id', 'identity']], on='image_id')
    #out = sub[['image_id']].merge(q[['image_id','prediction']], on='image_id')


    timestamp = time.strftime("%Y%m%d-%H%M%S")
    save_path = f'submission_{timestamp}.csv'
    out.to_csv(save_path, index=False)

    return save_path

In [None]:
# =============================
# Main Workflow
# =============================
"""
def main():
    #root = '/kaggle/input/animal-clef-2025'
    tf = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor()
    ])
    #idx2id = {v: k for k, v in db_ds.id2idx.items()}

    db_ds = AnimalCLEFDataset(root, 'database', transform=tf)
    
    #print("Example metadata path[0]:", db_ds.paths[0])
    #print("Looking for file at:", os.path.join(db_ds.root, db_ds.paths[0]))
    #print("Exists on disk?", os.path.exists(os.path.join(db_ds.root, db_ds.paths[0])))
    
    db_loader = DataLoader(db_ds, batch_size = BATCH_SIZE, shuffle=True,
                           num_workers=NUM_WORKERS, pin_memory=True)
    query_ds = AnimalCLEFDataset(root, 'query', transform=tf)
    q_loader = DataLoader(query_ds, batch_size = BATCH_SIZE, shuffle=False,
                           num_workers=NUM_WORKERS, pin_memory=True)

    clp_model = MAEFramework()
    plf_model = MAEFramework()

    # pick your device
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        DEVICE = torch.device("mps")
    elif torch.cuda.is_available():
        DEVICE = torch.device("cuda")
    else:
        DEVICE = torch.device("cpu")

    # decide which GPUs to wrap
    cuda_gpus = torch.cuda.device_count()
    if DEVICE.type == "cuda" and cuda_gpus > 1:
        device_ids = list(range(cuda_gpus))  # e.g. [0,1,…]
        clp_model = nn.DataParallel(clp_model,  device_ids=device_ids).to(DEVICE)
        plf_model = nn.DataParallel(plf_model,  device_ids=device_ids).to(DEVICE)
    else:
        # single‐device: MPS, single CUDA, or CPU
        clp_model = clp_model.to(DEVICE)
        plf_model = plf_model.to(DEVICE)

    # now training and inference will work on either MPS or CUDA
    #train_CLP(clp_model, db_loader, EPOCHS_CLP, LR,  'clp.pth', alpha_cl = ALPHA_CL, alpha_rec = ALPHA_REC)
    
    clp_model, cl_hist, rec_hist, tot_hist = train_CLP(
        clp_model, db_loader, EPOCHS_CLP, LR, 'clp.pth',
        alpha_cl=ALPHA_CL, alpha_rec=ALPHA_REC, force_restart=FORCE_TRAIN_RESTART_CLP)
    
    epochs = list(range(1, len(cl_hist)+1))
    plt.plot(epochs, cl_hist, label='Contrastive Loss')
    plt.plot(epochs, rec_hist, label='Reconstruction Loss')
    plt.plot(epochs, tot_hist, label='Total Loss')

    plt.title(f'CLP Losses (α_cl={ALPHA_CL}, α_rec={ALPHA_REC}, λ={LAMBDA}, batch={BATCH_SIZE})')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    prototype = train_PLF(plf_model, db_loader, EPOCHS_PLF, LR,
                          'plf.pth', len(db_ds.id2idx),  force_restart = FORCE_TRAIN_RESTART_PLF)
    

    
    print("Prototype shape:", prototype.shape)
    print("Number of unique classes:", len(db_ds.id2idx))

    # Plot the score distribution before estimating threshold
    #plot_score_distribution(clp_model, plf_model, prototype, db_loader)

    # Estimate threshold from database distribution
    dists = []
    with torch.no_grad():
        for x, _ in DataLoader(db_ds, batch_size = BATCH_SIZE, shuffle=False,
                               num_workers=NUM_WORKERS, pin_memory=True):
            x = x.to(DEVICE)
            feat, _, _, _, aft_feats = plf_model(x, return_feats=True)
            _, _, _, _, bef_feats = clp_model(x, return_feats=True)
            scores = calculate_unknown_score(bef_feats, aft_feats, feat, prototype, lamda = LAMBDA)
            dists.extend(scores.cpu().tolist())


    threshold = torch.quantile(torch.tensor(dists), 0.95).item()


    # Inference on query set
    with torch.no_grad():
        #lamda = 0.2 # or any value you want to test
        preds = inference(clp_model, plf_model, prototype, q_loader, threshold, lamda = LAMBDA)


    print("Saving submission…")
    generate_submission(root, preds, db_ds)
""";
def main():
    # transforms & datasets
    #tf = transforms.Compose([
    #    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    #    transforms.ToTensor()
    #])

    feat_extractor = AutoFeatureExtractor.from_pretrained("facebook/vit-mae-base")
    mean = feat_extractor.image_mean
    std  = feat_extractor.image_std

    tf = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    db_ds    = AnimalCLEFDataset(root, 'database', transform=tf)
    query_ds = AnimalCLEFDataset(root, 'query',    transform=tf)
    
    pin = (DEVICE.type == "cuda")

    db_loader = DataLoader(db_ds, batch_size=BATCH_SIZE, shuffle=True,
                           num_workers=NUM_WORKERS, pin_memory=pin)
    q_loader  = DataLoader(query_ds, batch_size=BATCH_SIZE, shuffle=False,
                           num_workers=NUM_WORKERS, pin_memory=pin)

    # --- 1) Train CLP as before ---
    clp_model = MAEFramework().to(DEVICE)
    if DEVICE.type == "cuda" and torch.cuda.device_count() > 1:
        clp_model = nn.DataParallel(clp_model)
    clp_model, cl_hist, rec_hist, tot_hist = train_CLP(
        clp_model, db_loader, EPOCHS_CLP, LR, 'clp.pth',
        alpha_cl=ALPHA_CL, alpha_rec=ALPHA_REC,
        force_restart=FORCE_TRAIN_RESTART_CLP
    )

    # plot losses
    epochs = list(range(1, len(cl_hist)+1))
    plt.plot(epochs, cl_hist, label='Contrastive')
    plt.plot(epochs, rec_hist,  label='Reconstruction')
    plt.plot(epochs, tot_hist,  label='Total')
    #plt.title(f"CLP Losses (α_cl={ALPHA_CL}, α_rec={ALPHA_REC})")
    
    plt.title(
        f"Contrastive and Reconstruction Losses (α_cl={ALPHA_CL}, α_rec={ALPHA_REC})\n"
        f"λ={LAMBDA}, bs={BATCH_SIZE}, LR={LR}\n"
        f"{int(OPENSET_PERCENTILE*100)}%-quantile cutoff={threshold:.4f}"
    )

    plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend(); plt.show()

    # --- 2) Initialize PLF from CLP encoder & freeze everything except proj_head & proto ---
    #plf_model = MAEFramework().to(DEVICE)
    
    # --- 2) Initialize PLF from CLP encoder & freeze everything except proj_head & proto ---
    plf_model = MAEFramework().to(DEVICE)

    # copy encoder, head, decoder weights
    if isinstance(clp_model, nn.DataParallel):
        enc_sd  = clp_model.module.encoder.state_dict()
        head_sd = clp_model.module.proj_head.state_dict()
        dec_sd  = clp_model.module.mae.state_dict()
    else:
        enc_sd  = clp_model.encoder.state_dict()
        head_sd = clp_model.proj_head.state_dict()
        dec_sd  = clp_model.mae.state_dict()

    plf_model.encoder.load_state_dict(enc_sd)
    plf_model.proj_head.load_state_dict(head_sd)
    plf_model.mae.load_state_dict(dec_sd, strict=False)

    # only need one decoder‐freeze pass now
    #for p in plf_model.mae.parameters():
    #    p.requires_grad = False

    # 2) freeze entire MAE backbone (encoder + decoder)
    for p in plf_model.encoder.parameters():
        p.requires_grad = False
    for p in plf_model.mae.parameters():
        p.requires_grad = False

    # NEW: unfreeze the *last* 2 ViT blocks so they can adapt to class prototypes
    for blk in plf_model.encoder.encoder.layer[-2:]:
        for p in blk.parameters():
            p.requires_grad = True

    # proj_head remains trainable by default

    if DEVICE.type == "cuda" and torch.cuda.device_count() > 1:
        plf_model = nn.DataParallel(plf_model)

    """
    # --- 3) Train PLF (only prototypes + proj_head) ---
    prototype = train_PLF(
        plf_model, db_loader, EPOCHS_PLF, LR, 'plf.pth',
        len(db_ds.id2idx),
        force_restart=FORCE_TRAIN_RESTART_PLF
    )
    """

    # --- 3) Train PLF (prototypes + proj_head + last 2 encoder blocks) ---
    # Make sure train_PLF’s optimizer picks up the newly-unfrozen layers:
    prototype = train_PLF(
        plf_model, db_loader, EPOCHS_PLF, LR, 'plf.pth',
        len(db_ds.id2idx),
        force_restart=FORCE_TRAIN_RESTART_PLF,
        unfrozen_encoder=True  # ← you can add a flag so train_PLF adds encoder params
    )

    print("Prototype shape:", prototype.shape)

    # --- 4) Estimate threshold ---
    dists = []
    with torch.no_grad():
        for x, _ in DataLoader(db_ds, batch_size=BATCH_SIZE,
                               shuffle=False, num_workers=NUM_WORKERS, pin_memory=pin):
            x = x.to(DEVICE)
            _, _, _, _, bef_feats = clp_model(x, return_feats=True)
            _, _, _, _, aft_feats = plf_model(x, return_feats=True)
            feat, _, _, _, _      = plf_model(x, return_feats=True)
            scores = calculate_unknown_score(bef_feats, aft_feats, feat, prototype, lamda=LAMBDA)
            #print("min, max, 95%-thresh of database scores:", 
            #min(dists), max(dists), threshold)
            dists.extend(scores.cpu().tolist())
    threshold = torch.quantile(torch.tensor(dists), OPENSET_PERCENTILE).item()

    #print(f"DB scores → min {min(dists):.4f}, max {max(dists):.4f}, 95%-quantile {threshold:.4f}, frac≥thresh {(np.array(dists)>threshold).mean():.2%}")
    print(
        f"DB scores → "
        f"min {min(dists):.4f}, "
        f"max {max(dists):.4f}, "
        f"{int(OPENSET_PERCENTILE*100)}%-quantile {threshold:.4f}, "
        f"frac≥thresh {(np.array(dists) > threshold).mean():.2%}"
    )


    # --- 5) Inference + submission ---
    with torch.no_grad():
        preds = inference(clp_model, plf_model, prototype, q_loader, threshold, lamda=LAMBDA)
    print("Saving submission…")
    generate_submission(root, preds, db_ds)


if __name__ == '__main__':
    main()



[train_CLP] Restarting from scratch (deleting ./checkpoints/clp.pth)


CLP Ep1: 100%|██████████| 409/409 [07:12<00:00,  1.06s/it]


CLP Epoch 1: total=5.8306, CE=5.5661, rec=0.2644


CLP Ep2:  68%|██████▊   | 277/409 [05:01<02:30,  1.14s/it]

In [258]:
#preds = inference(clp_model, plf_model, prototype, q_loader, threshold, lamda=lamda)
#print("Saving submission…")
#generate_submission(root, preds, db_ds)

In [259]:
from transformers import ViTMAEForPreTraining, ViTMAEConfig