In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
from torchvision.models.vision_transformer import vit_b_16, ViT_B_16_Weights
from tqdm import tqdm
import matplotlib.pyplot as plt

# =============================
# Configurations
# =============================
# Data loading
BATCH_SIZE = 32  # NOTE(review): main() builds its DataLoaders with batch_size=8, not this value
NUM_WORKERS = 4
IMAGE_SIZE = 224  # images are resized to IMAGE_SIZE x IMAGE_SIZE; the decoder emits the same size

# Optimisation: epoch budget for each training stage, shared learning rate
EPOCHS_CLP = EPOCHS_PLF = 20
LR = 1e-4

# Runtime placement and checkpoint location
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
CHECKPOINT_DIR = './checkpoints'

# Import-time side effects: seed torch's RNG (reproducibility) and make sure
# the checkpoint directory exists before any training function writes to it.
torch.manual_seed(42)
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# =============================
# Dataset
# =============================
class AnimalCLEFDataset(Dataset):
    """Image dataset for one split listed in ``{root}/metadata.csv``.

    Rows whose ``path`` contains ``/{split}/`` are selected, in file order.
    For the ``database`` split every distinct ``image_id`` (as a string) is
    mapped to a contiguous integer label, sorted lexicographically, and the
    mapping is exposed as ``id2idx``. Every other split gets label ``-1``.

    Args:
        root: dataset directory containing ``metadata.csv`` and the images.
        split: split name matched literally inside each ``path``.
        transform: optional callable applied to each RGB PIL image.

    Raises:
        ValueError: if no metadata row belongs to ``split``.
    """

    def __init__(self, root, split="database", transform=None):
        self.root = root.rstrip('/')
        meta = pd.read_csv(f"{self.root}/metadata.csv")
        # regex=False: match the split name literally; na=False: rows with a
        # missing path are simply not selected instead of breaking the mask.
        mask = meta['path'].str.contains(f"/{split}/", regex=False, na=False)
        sel = meta[mask].reset_index(drop=True)
        if sel.empty:
            raise ValueError(f"No entries for split '{split}'")
        self.paths = sel['path'].tolist()
        self.transform = transform
        if split == 'database':
            # NOTE(review): labels come from ``image_id``. If that column is
            # unique per image, every database image is its own class and
            # predictions map back to image ids rather than to individuals.
            # Confirm this is the intended label column.
            ids = sel['image_id'].astype(str)
            self.id2idx = {iid: i for i, iid in enumerate(sorted(ids.unique()))}
            # Every label lies in [0, len(id2idx)) by construction of id2idx.
            self.labels = ids.map(self.id2idx).tolist()
        else:
            self.labels = [-1] * len(sel)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        """Return ``(image, label)`` for row ``i``; ``image`` is transformed if a transform was given."""
        # The context manager releases the file handle promptly; convert()
        # returns a new image, so it stays valid after the file is closed.
        with Image.open(f"{self.root}/{self.paths[i]}") as src:
            img = src.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, self.labels[i]

# =============================
# MAE Encoder + Projection Head + Decoder
# =============================
class MAEFramework(nn.Module):
    """ViT-B/16 encoder with a projection head and a pixel-reconstruction decoder.

    ``forward`` re-implements the ViT token pipeline (patch embedding, [CLS]
    token, positional embedding, transformer blocks, final LayerNorm) so that
    token activations after selected blocks can be captured. No patch masking
    is applied: the decoder reconstructs the whole input image from the final
    [CLS] feature alone.

    Args:
        embed_dim: width of the backbone tokens (must match the backbone).
        proj_dim: output width of the projection head.
        decoder_dim: hidden width of the reconstruction decoder.
        layer_indices: indices of encoder blocks whose output tokens are
            returned when ``forward(..., return_feats=True)``.
    """

    def __init__(self,
                 embed_dim: int = 768,
                 proj_dim: int = 256,
                 decoder_dim: int = 256,
                 layer_indices: tuple[int, ...] = (3, 6, 9)):
        # layer_indices defaults to a tuple (not a list) so the default object
        # can never be shared and mutated across instances; lists still work.
        super().__init__()
        self.embed_dim = embed_dim
        # 1) Backbone ViT with ImageNet-1k weights.
        self.encoder = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        # NOTE(review): torchvision's VisionTransformer appears to name its
        # classifier ``heads``, so this assignment likely just adds an unused
        # attribute. It is harmless (forward below never calls the classifier)
        # and is kept so saved checkpoints keep the same state_dict keys.
        self.encoder.head = nn.Identity()
        # Set for O(1) membership checks inside the block loop.
        self.layer_indices = set(layer_indices)

        # 2) Projection head (used by the contrastive loss).
        self.proj_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dim, proj_dim),
        )
        # 3) Decoder: maps the [CLS] feature to a full 3 x IMAGE_SIZE x IMAGE_SIZE image.
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, decoder_dim),
            nn.ReLU(inplace=True),
            nn.Linear(decoder_dim, 3 * IMAGE_SIZE * IMAGE_SIZE),
        )

    def forward(self, x: torch.Tensor, return_feats: bool = False):
        """Encode ``x`` and reconstruct it.

        ``x`` must be (B, 3, IMAGE_SIZE, IMAGE_SIZE): the reconstruction is
        viewed to that shape and compared with ``x`` element-wise.

        Returns:
            ``(proj, rec_loss, rec)`` by default, or
            ``(cls_feat, proj, rec_loss, rec, feats)`` when ``return_feats``.
            ``cls_feat`` is the LayerNorm-ed [CLS] token (B, embed_dim).
            ``proj`` is the projection-head output.
            ``rec_loss`` is the per-sample MSE, shape (B,).
            ``rec`` is the reconstructed image.
            ``feats`` is a list of token tensors captured after each block in
            ``layer_indices``, in block order.
        """
        B = x.size(0)
        # Patch embedding -> (B, num_patches, C) token sequence.
        x_p = self.encoder.conv_proj(x)
        x_p = x_p.flatten(2).transpose(1, 2)
        cls_tok = self.encoder.class_token.expand(B, -1, -1)
        tokens = torch.cat([cls_tok, x_p], dim=1)
        tokens = tokens + self.encoder.encoder.pos_embedding

        feats = []
        for idx, block in enumerate(self.encoder.encoder.layers):
            tokens = block(tokens)
            if idx in self.layer_indices:
                # clone() so the captured activation is an independent tensor.
                feats.append(tokens.clone())

        cls_feat = self.encoder.encoder.ln(tokens[:, 0])
        proj = self.proj_head(cls_feat)
        rec = self.decoder(cls_feat).view(B, 3, IMAGE_SIZE, IMAGE_SIZE)
        rec_loss = F.mse_loss(rec, x, reduction='none').mean([1, 2, 3])

        if return_feats:
            return cls_feat, proj, rec_loss, rec, feats
        return proj, rec_loss, rec

# =============================
# Losses
# =============================
class SupConLoss(nn.Module):
    """NT-Xent (SimCLR-style) instance-level contrastive loss.

    Despite the class name, this loss is not supervised: ``labels`` is
    accepted for API compatibility and ignored. Row ``i`` of ``p1`` and row
    ``i`` of ``p2`` are treated as the positive pair; every other row in the
    concatenated batch acts as a negative.

    Args:
        temperature: divisor applied to the cosine-similarity logits.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temp = temperature

    def forward(self, p1, p2, labels=None):
        n = p1.size(0)
        views = F.normalize(torch.cat((p1, p2), dim=0), dim=1)
        logits = views @ views.t() / self.temp
        # Exclude each row's similarity with itself via a huge negative logit.
        self_pairs = torch.eye(2 * n, dtype=torch.bool, device=logits.device)
        logits = logits.masked_fill(self_pairs, -9e15)
        # The positive for row i is the other view of the same sample: (i + n) mod 2n.
        positives = (torch.arange(2 * n, device=logits.device) + n) % (2 * n)
        return F.cross_entropy(logits, positives)

class ProtoLoss(nn.Module):
    """Prototype classification loss.

    Each feature is scored against every prototype by negative Euclidean
    distance, and cross-entropy is taken against the integer class labels.
    """

    def forward(self, feats, prototypes, labels):
        # presumably feats is (B, D) and prototypes (C, D), giving (B, C) distances
        distances = torch.cdist(feats, prototypes)
        logits = distances.neg()
        return F.cross_entropy(logits, labels)

# =============================
# Utilities
# =============================
def compute_layer_distances(bef_feats, aft_feats, temperature=0.5):
    """Average per-sample feature drift between two aligned lists of layer outputs.

    For each pair ``(before, after)``, both tensors are flattened from dim 1.
    The pair's score is the Euclidean distance plus ``temperature`` times the
    cosine distance (``1 - cos``), giving one value per sample. These per-layer
    scores are then averaged across layers.

    Returns:
        1-D tensor with one score per sample.
    """
    def _layer_score(before, after):
        before, after = before.flatten(1), after.flatten(1)
        euclid = F.pairwise_distance(before, after)
        cosine_gap = 1 - F.cosine_similarity(before, after, dim=1)
        return euclid + temperature * cosine_gap

    per_layer = [_layer_score(b, a) for b, a in zip(bef_feats, aft_feats)]
    return torch.stack(per_layer, dim=0).mean(dim=0)

def calculate_unknown_score(feat_bef, feat_aft, feat_vec, prototypes, temperature=0.5, eps=1e-8):
    """Return one "unknown" score per sample; larger means less like any prototype.

    score = compute_layer_distances(feat_bef, feat_aft, temperature)
            + temperature * (1 - max_c cos(feat_vec, prototypes[c]))

    Args:
        feat_bef, feat_aft: aligned lists of per-layer features from the
            reference model and the fine-tuned model.
        feat_vec: embeddings, presumably (B, D).
        prototypes: class prototypes, presumably (C, D).
        temperature: weight on the cosine terms.
        eps: lower bound on the norm when L2-normalising ``feat_vec`` and
            ``prototypes``; guards against division by ~0.

    Returns:
        Tensor of shape (B,): a single score per sample, not per class.
    """
    # 1. Layer-wise drift between the two models' intermediate features.
    ld = compute_layer_distances(feat_bef, feat_aft, temperature)  # (B,)

    # 2. Cosine similarity to every prototype.
    # NOTE(review): ProtoLoss trains prototypes with Euclidean distance, while
    # this score uses cosine similarity. Confirm the mismatch is intended.
    fv_n = F.normalize(feat_vec, dim=1, eps=eps)  # (B, D)
    p_n = F.normalize(prototypes, dim=1, eps=eps)  # (C, D)
    sim = torch.matmul(fv_n, p_n.T)  # (B, C)

    # 3. Distance to the closest prototype.
    max_sim, _ = sim.max(dim=1)  # (B,)
    proto_dist = 1 - max_sim  # (B,)

    # 4. Combine into a single score per sample.
    return ld + temperature * proto_dist  # (B,)


# =============================
# Training: CLP
# =============================
def train_CLP(model, loader, epochs, lr, ckpt):
    """Contrastive + reconstruction training stage ("CLP"), resumable from a checkpoint.

    Each batch produces two views of ``x`` by adding Gaussian noise
    (std 0.05). The loss is NT-Xent on the two projections plus the mean
    per-sample reconstruction MSE of both views. After every epoch,
    ``{'model', 'opt', 'ep'}`` is saved to ``CHECKPOINT_DIR/ckpt``. If that
    file already exists, training resumes at epoch ``ep + 1``.

    Args:
        model: MAEFramework (optionally wrapped in nn.DataParallel).
        loader: yields ``(images, labels)``; labels are ignored here.
        epochs: total number of epochs, counting already-completed ones.
        lr: Adam learning rate.
        ckpt: checkpoint filename inside CHECKPOINT_DIR.
    """
    path = os.path.join(CHECKPOINT_DIR, ckpt)
    # Move the model before creating/restoring the optimizer so that restored
    # optimizer state is placed on the same device as the parameters.
    model.to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    start = 0
    if os.path.exists(path):
        # map_location lets a checkpoint written on one device be resumed on another.
        ck = torch.load(path, map_location=DEVICE)
        model.load_state_dict(ck['model'])
        opt.load_state_dict(ck['opt'])
        start = ck['ep'] + 1
    scl = SupConLoss()

    for ep in range(start, epochs):
        model.train()
        total = 0.0
        for x, _ in tqdm(loader, desc=f"CLP Ep{ep+1}"):
            x = x.to(DEVICE)
            x1 = x + 0.05 * torch.randn_like(x)
            x2 = x + 0.05 * torch.randn_like(x)
            # Default forward returns (proj, rec_loss, rec); the intermediate
            # layer features are not needed, so they are not cloned/gathered.
            p1, r1_loss, _ = model(x1)
            p2, r2_loss, _ = model(x2)
            l_cl = scl(p1, p2)
            l_rec = r1_loss.mean() + r2_loss.mean()
            loss = l_cl + l_rec
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item()
        torch.save({'model': model.state_dict(), 'opt': opt.state_dict(), 'ep': ep}, path)
        print(f"CLP Epoch {ep+1}: {total/len(loader):.4f}")

# =============================
# Training: PLF
# =============================
def train_PLF(model, loader, epochs, lr, ckpt, num_classes, encoder_pre):
    """Prototype-learning fine-tuning stage ("PLF"), resumable from a checkpoint.

    Learns one prototype per class jointly with the model by minimising
    ProtoLoss between the [CLS] features and the prototypes. After every
    epoch, ``{'model', 'proto', 'opt', 'ep'}`` is saved to
    ``CHECKPOINT_DIR/ckpt``.

    Args:
        model: MAEFramework (optionally wrapped in nn.DataParallel).
        loader: yields ``(images, integer_labels)``.
        epochs: total number of epochs, counting already-completed ones.
        lr: Adam learning rate, shared by the model and the prototypes.
        ckpt: checkpoint filename inside CHECKPOINT_DIR.
        num_classes: number of prototypes.
        encoder_pre: model from the previous stage. When no checkpoint exists,
            its weights initialise ``model``, so fine-tuning starts from the
            pre-trained encoder. Pass None to keep ``model``'s own weights.

    Returns:
        The learned prototypes, an nn.Parameter of shape (num_classes, embed_dim).
    """
    path = os.path.join(CHECKPOINT_DIR, ckpt)
    # Unwrap DataParallel to reach the underlying module's attributes.
    base_model = model.module if isinstance(model, nn.DataParallel) else model
    embed_dim = base_model.embed_dim
    # Move the model first so restored optimizer state lands on the parameters' device.
    model.to(DEVICE)
    proto = nn.Parameter(torch.randn(num_classes, embed_dim, device=DEVICE), requires_grad=True)
    opt = torch.optim.Adam(list(model.parameters()) + [proto], lr=lr)
    start = 0
    if os.path.exists(path):
        ck = torch.load(path, map_location=DEVICE)
        model.load_state_dict(ck['model'])
        # In-place copy keeps the Parameter object the optimizer tracks and
        # fails loudly if num_classes no longer matches the checkpoint.
        with torch.no_grad():
            proto.copy_(ck['proto'])
        opt.load_state_dict(ck['opt'])
        start = ck['ep'] + 1
    elif encoder_pre is not None:
        # Fresh start: fine-tune from the previous stage's weights.
        base_pre = encoder_pre.module if isinstance(encoder_pre, nn.DataParallel) else encoder_pre
        base_model.load_state_dict(base_pre.state_dict())
    ploss = ProtoLoss()

    for ep in range(start, epochs):
        model.train()
        total = 0.0
        for x, y in tqdm(loader, desc=f"PLF Ep{ep+1}"):
            x, y = x.to(DEVICE), y.to(DEVICE)
            feats, _, _, _, _ = model(x, return_feats=True)
            loss = ploss(feats, proto, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item()
        torch.save({'model': model.state_dict(), 'proto': proto.data, 'opt': opt.state_dict(), 'ep': ep}, path)
        print(f"PLF Epoch {ep+1}: {total/len(loader):.4f}")
    return proto

# =============================
# Inference
# =============================

@torch.no_grad()
def inference(enc_pre, model, proto, loader, threshold):
    """Predict a prototype index per image, or -1 for "unknown".

    ``calculate_unknown_score`` returns one score per image. Images whose
    score is below ``threshold`` are assigned the prototype with the highest
    cosine similarity; this is the same prototype whose similarity enters the
    score. All other images get -1.

    Args:
        enc_pre: reference model whose intermediate features give "before" activations.
        model: fine-tuned model giving the embedding and "after" activations.
        proto: class prototypes, presumably (C, D).
        loader: yields ``(images, _)``.
        threshold: float or 0-dim tensor; scores below it count as known.

    Returns:
        1-D CPU LongTensor of predicted prototype indices (-1 = unknown).
    """
    # Device placement and eval mode are loop-invariant: set them once.
    enc_pre = enc_pre.to(DEVICE).eval()
    model = model.to(DEVICE).eval()
    threshold = float(threshold)
    proto_unit = F.normalize(proto.to(DEVICE), dim=1)
    preds = []

    for x, _ in tqdm(loader, desc="Infer"):
        x = x.to(DEVICE)

        _, _, _, _, bef_feats = enc_pre(x, return_feats=True)
        feat, _, _, _, aft_feats = model(x, return_feats=True)

        # One score per sample, shape (B,). The original code called
        # .min(dim=1) on this 1-D tensor, which raises IndexError.
        scores = calculate_unknown_score(bef_feats, aft_feats, feat, proto)
        idx = (F.normalize(feat, dim=1) @ proto_unit.T).argmax(dim=1)  # (B,)

        known = scores < threshold
        pred = idx.clone()
        pred[~known] = -1  # mark unknowns

        preds.append(pred.cpu())

    return torch.cat(preds, dim=0)

# =============================
# Submission
# =============================
def generate_submission(root, preds, db_ds, out_dir='/kaggle/working'):
    """Write a timestamped submission CSV and return its path.

    ``preds[i]`` is assumed to belong to the i-th ``/query/`` row of
    ``{root}/metadata.csv`` (the selection AnimalCLEFDataset uses for the
    query split). Negative predictions become ``'new_individual'``. Other
    predictions are mapped back through ``db_ds.id2idx``. The rows are then
    aligned to ``sample_submission.csv`` by ``image_id`` using an inner join,
    so sample rows without a matching query image are dropped.

    Args:
        root: dataset directory containing the two CSV files.
        preds: 1-D tensor of predicted label indices (-1 = unknown).
        db_ds: dataset providing the ``id2idx`` mapping used for training labels.
        out_dir: directory the CSV is written into.

    Returns:
        Path of the written CSV (columns ``image_id``, ``prediction``).

    Raises:
        ValueError: if the number of predictions differs from the number of query rows.
    """
    import time

    sub = pd.read_csv(f"{root}/sample_submission.csv")
    meta = pd.read_csv(f"{root}/metadata.csv")
    q = meta[meta['path'].str.contains('/query/')].reset_index(drop=True)

    pred_idx = preds.cpu().numpy()
    if len(pred_idx) != len(q):
        raise ValueError(f"Got {len(pred_idx)} predictions for {len(q)} query images")
    q['pred_idx'] = pred_idx

    idx2id = {v: k for k, v in db_ds.id2idx.items()}
    q['prediction'] = q['pred_idx'].apply(lambda i: 'new_individual' if i < 0 else idx2id[int(i)])
    # NOTE(review): only the 'image_id' column of sample_submission.csv is
    # read. Confirm the expected prediction column name matches 'prediction'.
    out = sub[['image_id']].merge(q[['image_id', 'prediction']], on='image_id')

    # Timestamped filename so successive runs do not overwrite each other.
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    save_path = os.path.join(out_dir, f'submission_{timestamp}.csv')
    out.to_csv(save_path, index=False)

    return save_path

@torch.no_grad()
def plot_score_distribution(enc_pre, model, proto, loader):
    """Show a histogram of per-image unknown scores to help pick a threshold.

    Runs without autograd. Before this change, ``.numpy()`` failed on score
    tensors that required grad, and no autograd graph needs to be kept anyway.

    Args:
        enc_pre: reference model giving the "before" activations.
        model: fine-tuned model giving the embedding and "after" activations.
        proto: class prototypes.
        loader: yields ``(images, _)``.
    """
    # Loop-invariant setup, done once.
    enc_pre = enc_pre.to(DEVICE).eval()
    model = model.to(DEVICE).eval()
    all_scores = []

    for x, _ in tqdm(loader, desc="Scoring for Plot"):
        x = x.to(DEVICE)

        _, _, _, _, bef_feats = enc_pre(x, return_feats=True)
        feat, _, _, _, aft_feats = model(x, return_feats=True)

        # One score per image, shape (B,). The original called .min(dim=1)
        # on this 1-D tensor, which raises IndexError.
        scores = calculate_unknown_score(bef_feats, aft_feats, feat, proto)
        all_scores.append(scores.cpu())

    all_scores = torch.cat(all_scores).numpy()

    plt.figure(figsize=(8, 5))
    plt.hist(all_scores, bins=50, alpha=0.75, color='steelblue')
    plt.xlabel("Unknown Score")
    plt.ylabel("Count")
    plt.title("Score Distribution — Use to Tune Threshold")
    plt.grid(True)
    plt.show()

# =============================
# Main Workflow
# =============================
def main():
    """End-to-end run: train CLP, then PLF, calibrate the unknown threshold, predict, write submission."""
    root = '/kaggle/input/animal-clef-2025'
    tf = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor()
    ])
    # NOTE(review): loaders use batch size 8 rather than the BATCH_SIZE
    # constant; kept at 8 to match the runs and checkpoints produced so far.
    loader_bs = 8

    db_ds = AnimalCLEFDataset(root, 'database', transform=tf)
    db_loader = DataLoader(db_ds, batch_size=loader_bs, shuffle=True,
                           num_workers=NUM_WORKERS, pin_memory=True)
    query_ds = AnimalCLEFDataset(root, 'query', transform=tf)
    q_loader = DataLoader(query_ds, batch_size=loader_bs, shuffle=False,
                          num_workers=NUM_WORKERS, pin_memory=True)

    clp_model = MAEFramework()
    plf_model = MAEFramework()

    # Use at most the first two GPUs; None when no CUDA device is present.
    device_ids = list(range(min(torch.cuda.device_count(), 2))) or None

    clp_model = nn.DataParallel(clp_model, device_ids=device_ids).to(DEVICE)
    plf_model = nn.DataParallel(plf_model, device_ids=device_ids).to(DEVICE)

    train_CLP(clp_model, db_loader, EPOCHS_CLP, LR, 'clp.pth')
    prototype = train_PLF(plf_model, db_loader, EPOCHS_PLF, LR,
                          'plf.pth', len(db_ds.id2idx), clp_model)

    # Optional: inspect the database score distribution to tune the threshold by hand.
    # plot_score_distribution(clp_model, plf_model, prototype, db_loader)

    # Calibrate the threshold on the (known) database images: 95th percentile
    # of their per-image unknown scores.
    clp_model.eval()
    plf_model.eval()
    calib_loader = DataLoader(db_ds, batch_size=loader_bs, shuffle=False,
                              num_workers=NUM_WORKERS, pin_memory=True)
    dists = []
    with torch.no_grad():
        for x, _ in calib_loader:
            x = x.to(DEVICE)
            feat, _, _, _, aft_feats = plf_model(x, return_feats=True)
            _, _, _, _, bef_feats = clp_model(x, return_feats=True)
            # One score per image, shape (B,). Collect tensors: the original
            # extended a list with a float, which raises TypeError.
            scores = calculate_unknown_score(bef_feats, aft_feats, feat, prototype)
            dists.append(scores.cpu())

    threshold = torch.quantile(torch.cat(dists), 0.95)

    # Inference on the query set (inference() already disables autograd).
    preds = inference(clp_model, plf_model, prototype, q_loader, threshold)

    print("Saving submission…")
    save_path = generate_submission(root, preds, db_ds)
    print(f"Submission written to {save_path}")

if __name__ == '__main__':
    main()


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:01<00:00, 176MB/s]  
CLP Ep1: 100%|██████████| 1635/1635 [09:21<00:00,  2.91it/s]


CLP Epoch 1: 1.3126


CLP Ep2: 100%|██████████| 1635/1635 [09:20<00:00,  2.92it/s]


CLP Epoch 2: 1.2318


CLP Ep3: 100%|██████████| 1635/1635 [09:20<00:00,  2.92it/s]


CLP Epoch 3: 1.2261


CLP Ep4: 100%|██████████| 1635/1635 [09:20<00:00,  2.92it/s]


CLP Epoch 4: 1.2161


CLP Ep5: 100%|██████████| 1635/1635 [09:19<00:00,  2.92it/s]


CLP Epoch 5: 1.2157


CLP Ep6: 100%|██████████| 1635/1635 [09:20<00:00,  2.92it/s]


CLP Epoch 6: 1.2030


CLP Ep7: 100%|██████████| 1635/1635 [09:19<00:00,  2.92it/s]


CLP Epoch 7: 1.2148


CLP Ep8: 100%|██████████| 1635/1635 [09:19<00:00,  2.92it/s]


CLP Epoch 8: 1.2180


CLP Ep9: 100%|██████████| 1635/1635 [09:19<00:00,  2.92it/s]


CLP Epoch 9: 1.2193


CLP Ep10: 100%|██████████| 1635/1635 [09:19<00:00,  2.92it/s]


CLP Epoch 10: 1.2000


CLP Ep11: 100%|██████████| 1635/1635 [09:19<00:00,  2.92it/s]


CLP Epoch 11: 1.2184


CLP Ep12: 100%|██████████| 1635/1635 [09:19<00:00,  2.92it/s]


CLP Epoch 12: 1.2023


CLP Ep13: 100%|██████████| 1635/1635 [09:19<00:00,  2.92it/s]


CLP Epoch 13: 1.1855


CLP Ep14: 100%|██████████| 1635/1635 [09:19<00:00,  2.92it/s]


CLP Epoch 14: 1.1930


CLP Ep15: 100%|██████████| 1635/1635 [09:19<00:00,  2.92it/s]


CLP Epoch 15: 1.1908


CLP Ep16: 100%|██████████| 1635/1635 [09:20<00:00,  2.92it/s]


CLP Epoch 16: 1.1880


CLP Ep17: 100%|██████████| 1635/1635 [09:20<00:00,  2.92it/s]


CLP Epoch 17: 1.1881


CLP Ep18: 100%|██████████| 1635/1635 [09:20<00:00,  2.92it/s]


CLP Epoch 18: 1.1830


CLP Ep19: 100%|██████████| 1635/1635 [09:20<00:00,  2.92it/s]


CLP Epoch 19: 1.1669


CLP Ep20:  91%|█████████ | 1490/1635 [08:30<00:49,  2.92it/s]