In [36]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
from torchvision.models.vision_transformer import vit_b_16, ViT_B_16_Weights
from tqdm import tqdm
import matplotlib.pyplot as plt
from animal_dataset import AnimalCLEFDataset  # now importable

In [37]:
# =============================
# Configurations
# =============================
BATCH_SIZE = 32
NUM_WORKERS = 4
IMAGE_SIZE = 224
EPOCHS_CLP = 1
EPOCHS_PLF = 1
LR = 1e-4
#DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINT_DIR = './checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

root = '../animal-clef-2025_data'

# reproducibility
torch.manual_seed(42)

<torch._C.Generator at 0x10e51d990>

In [38]:
# Set device, Mac, GPU, or CPU
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    DEVICE = torch.device("mps")
elif torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print("Using device:", DEVICE)

Using device: mps


In [39]:
"""
# =============================
# Dataset
# =============================
class AnimalCLEFDataset(Dataset):
    def __init__(self, root, split="database", transform=None):
        self.root = root.rstrip('/')
        meta = pd.read_csv(f"{self.root}/metadata.csv")
        sel = meta[meta['path'].str.contains(f"/{split}/")].reset_index(drop=True)
        if sel.empty:
            raise ValueError(f"No entries for split '{split}'")

        self.paths = sel['path'].tolist()
        self.image_ids = sel['image_id'].tolist()

        if split == 'database':
            #  Use individual identity,  
            ids = sel['identity'].astype(str)

            #  Build mapping from identity string → label index
            self.id2idx = {iid: i for i, iid in enumerate(sorted(ids.unique()))}

            #  Map each sample's identity to its label
            self.labels = ids.map(self.id2idx).tolist()

            # Safety check
            num_classes = len(self.id2idx)
            assert all(0 <= label < num_classes for label in self.labels), "Invalid labels found"
        else:
            self.labels = [-1] * len(sel)

        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        img = Image.open(f"{self.root}/{self.paths[i]}").convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, self.labels[i]
""";

In [40]:
# =============================
# MAE Encoder + Projection Head + Decoder
# =============================
class MAEFramework(nn.Module):
    def __init__(self,
                 embed_dim: int = 768,
                 proj_dim: int = 256,
                 decoder_dim: int = 256,
                 layer_indices: list[int] = [3, 6, 9]):
        super().__init__()
        self.embed_dim = embed_dim
        # 1) Backbone ViT
        self.encoder = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        self.encoder.head = nn.Identity()
        self.layer_indices = set(layer_indices)

        # 3) Projection head
        self.proj_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dim, proj_dim),
        )
        # 4) Decoder
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, decoder_dim),
            nn.ReLU(inplace=True),
            nn.Linear(decoder_dim, 3 * IMAGE_SIZE * IMAGE_SIZE),
        )

    def forward(self, x: torch.Tensor, return_feats: bool = False):
        B = x.size(0)
        # patch embed
        x_p = self.encoder.conv_proj(x)
        x_p = x_p.flatten(2).transpose(1, 2)
        cls_tok = self.encoder.class_token.expand(B, -1, -1)
        tokens = torch.cat([cls_tok, x_p], dim=1)
        tokens = tokens + self.encoder.encoder.pos_embedding

        feats = []
        for idx, block in enumerate(self.encoder.encoder.layers):
            tokens = block(tokens)
            if idx in self.layer_indices:
                feats.append(tokens.clone())

        cls_feat = self.encoder.encoder.ln(tokens[:, 0])
        proj = self.proj_head(cls_feat)
        rec = self.decoder(cls_feat).view(B, 3, IMAGE_SIZE, IMAGE_SIZE)
        rec_loss = F.mse_loss(rec, x, reduction='none').mean([1, 2, 3])

        if return_feats:
            return cls_feat, proj, rec_loss, rec, feats
        return proj, rec_loss, rec


In [41]:
# =============================
# Losses
# =============================
class SupConLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super().__init__()
        self.temp = temperature

    def forward(self, p1, p2, labels=None):
        # NT-Xent instance-level contrastive loss
        z = torch.cat([p1, p2], dim=0)
        z = F.normalize(z, dim=1)
        N = p1.size(0)
        sim = torch.matmul(z, z.T) / self.temp
        mask = torch.eye(2*N, device=sim.device).bool()
        sim.masked_fill_(mask, -9e15)
        idx = torch.arange(N, device=sim.device)
        targets = torch.cat([idx + N, idx])
        return F.cross_entropy(sim, targets)

class ProtoLoss(nn.Module):
    def forward(self, feats, prototypes, labels):
        dist = torch.cdist(feats, prototypes)
        return F.cross_entropy(-dist, labels)

In [42]:
# =============================
# Utilities
# =============================
def compute_layer_distances(bef_feats, aft_feats, temperature=0.5):
    total = []  # Initialize as a list to store distances per layer
    for b, a in zip(bef_feats, aft_feats):
        b_f, a_f = b.flatten(1), a.flatten(1)
        eu = F.pairwise_distance(b_f, a_f)
        cos = 1 - F.cosine_similarity(b_f, a_f, dim=1)
        score = eu + temperature * cos  # this score is already (B,)
        total.append(score)  # Add the current batch's score to the list
    return torch.stack(total, dim=0).mean(dim=0)  # Stack and compute mean along dim=0

def calculate_unknown_score(feat_bef, feat_aft, feat_vec, prototypes, lamda=1.0):
    # 1. Compute multilayer feature distance (s_total)
    s_total = compute_layer_distances(feat_bef, feat_aft, temperature=0.5)  # Can keep temp fixed or expose as param

    # 2. Compute max prototype similarity (s_prototypes)
    fv_n = F.normalize(feat_vec, dim=1)  # (B, D)
    p_n = F.normalize(prototypes, dim=1)  # (C, D)
    sim = torch.matmul(fv_n, p_n.T)  # (B, C)
    s_proto, _ = sim.max(dim=1)  # (B,)

    # 3. Final score using lambda
    score = s_proto - lamda * s_total

    return score  # (B,)

In [43]:
def train_CLP(model, loader, epochs, lr, ckpt, alpha_cl=1.0, alpha_rec=1.0):
    path = os.path.join(CHECKPOINT_DIR, ckpt)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    start = 0

    # --- load checkpoint if it exists, stripping "module." if necessary
    if os.path.exists(path):
        ck = torch.load(path, map_location=DEVICE)
        raw_sd = ck['model']
        # strip DataParallel "module." prefix
        sd = {k.replace("module.", ""): v for k, v in raw_sd.items()}
        model.load_state_dict(sd)
        opt.load_state_dict(ck['opt'])
        start = ck['ep'] + 1
    # ---------------------------------------------------------------

    scl = SupConLoss()
    model.to(DEVICE)

    for ep in range(start, epochs):
        model.train()
        total = 0.0
        for x, y in tqdm(loader, desc=f"CLP Ep{ep+1}"):
            x, y = x.to(DEVICE), y.to(DEVICE)
            x1 = x + 0.05 * torch.randn_like(x)
            x2 = x + 0.05 * torch.randn_like(x)
            f1, p1, r1_loss, r1, _ = model(x1, return_feats=True)
            f2, p2, r2_loss, r2, _ = model(x2, return_feats=True)
            l_cl  = scl(p1, p2)
            l_rec = r1_loss.mean() + r2_loss.mean()
            loss  = alpha_cl * l_cl + alpha_rec * l_rec
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item()

        torch.save({
            'model': model.state_dict() if not isinstance(model, nn.DataParallel)
                              else model.module.state_dict(),
            'opt':   opt.state_dict(),
            'ep':    ep
        }, path)
        print(f"CLP Epoch {ep+1}: {total/len(loader):.4f}")


In [44]:
def train_PLF(model, loader, epochs, lr, ckpt, num_classes, encoder_pre):
    path = os.path.join(CHECKPOINT_DIR, ckpt)
    # unwrap if DataParallel
    base_model = model.module if isinstance(model, nn.DataParallel) else model
    embed_dim   = base_model.embed_dim

    proto_tensor = torch.randn(num_classes, embed_dim, device=DEVICE)
    proto = nn.Parameter(proto_tensor, requires_grad=True)
    opt   = torch.optim.Adam(list(model.parameters()) + [proto], lr=lr)
    start = 0

    # --- load checkpoint if it exists, stripping "module." if necessary
    if os.path.exists(path):
        ck = torch.load(path, map_location=DEVICE)
        raw_sd = ck['model']
        sd = {k.replace("module.", ""): v for k, v in raw_sd.items()}
        model.load_state_dict(sd)
        proto.data = ck['proto']
        opt.load_state_dict(ck['opt'])
        start = ck['ep'] + 1
    # ---------------------------------------------------------------

    ploss = ProtoLoss()
    model.to(DEVICE)

    for ep in range(start, epochs):
        model.train()
        total = 0.0
        for x, y in tqdm(loader, desc=f"PLF Ep{ep+1}"):
            x, y = x.to(DEVICE), y.to(DEVICE)
            feats, _, _, _, _ = model(x, return_feats=True)
            loss = ploss(feats, proto, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item()

        torch.save({
            'model': model.state_dict() if not isinstance(model, nn.DataParallel)
                              else model.module.state_dict(),
            'proto': proto.data,
            'opt':   opt.state_dict(),
            'ep':    ep
        }, path)
        print(f"PLF Epoch {ep+1}: {total/len(loader):.4f}")

    return proto


In [45]:
"""
# =============================
# Training: CLP
# =============================
def train_CLP(model, loader, epochs, lr, ckpt, alpha_cl=1.0, alpha_rec=1.0):
    path = os.path.join(CHECKPOINT_DIR, ckpt)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    start = 0
    if os.path.exists(path):
        ck = torch.load(path)
        model.load_state_dict(ck['model'])
        opt.load_state_dict(ck['opt'])
        start = ck['ep'] + 1
    scl = SupConLoss()
    model.to(DEVICE)

    for ep in range(start, epochs):
        model.train()
        total = 0.0
        for x, y in tqdm(loader, desc=f"CLP Ep{ep+1}"):
            x, y = x.to(DEVICE), y.to(DEVICE)
            x1 = x + 0.05 * torch.randn_like(x)
            x2 = x + 0.05 * torch.randn_like(x)
            f1, p1, r1_loss, r1, _ = model(x1, return_feats=True)
            f2, p2, r2_loss, r2, _ = model(x2, return_feats=True)
            l_cl = scl(p1, p2)
            l_rec = r1_loss.mean() + r2_loss.mean()
            loss = alpha_cl * l_cl + alpha_rec * l_rec
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item()
        torch.save({'model': model.state_dict(), 'opt': opt.state_dict(), 'ep': ep}, path)
        print(f"CLP Epoch {ep+1}: {total/len(loader):.4f}")


# =============================
# Training: PLF
# =============================
def train_PLF(model, loader, epochs, lr, ckpt, num_classes, encoder_pre):
    path = os.path.join(CHECKPOINT_DIR, ckpt)
    # unwrap if DataParallel
    base_model = model.module if isinstance(model, nn.DataParallel) else model
    embed_dim = base_model.embed_dim
    proto_tensor = torch.randn(num_classes, embed_dim, device=DEVICE)
    proto = nn.Parameter(proto_tensor, requires_grad=True)
    opt = torch.optim.Adam(list(model.parameters()) + [proto], lr=lr)
    start = 0
    if os.path.exists(path):
        ck = torch.load(path)
        model.load_state_dict(ck['model'])
        proto.data = ck['proto']
        opt.load_state_dict(ck['opt'])
        start = ck['ep'] + 1
    ploss = ProtoLoss()
    model.to(DEVICE)

    for ep in range(start, epochs):
        model.train()
        total = 0.0
        for x, y in tqdm(loader, desc=f"PLF Ep{ep+1}"):
            x, y = x.to(DEVICE), y.to(DEVICE)
            feats, _, _, _, _ = model(x, return_feats=True)
            loss = ploss(feats, proto, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item()
        torch.save({'model': model.state_dict(), 'proto': proto.data, 'opt': opt.state_dict(), 'ep': ep}, path)
        print(f"PLF Epoch {ep+1}: {total/len(loader):.4f}")
    return proto
""";
    
# =============================
# Inference
# =============================

@torch.no_grad()
def inference(enc_pre, model, proto, loader, threshold, lamda=1.0):
    enc_pre.eval()
    model.eval()
    preds = []

    for x, _ in tqdm(loader, desc="Infer"):
        x = x.to(DEVICE)
        enc_pre = enc_pre.to(DEVICE)
        model = model.to(DEVICE)

        _, _, _, _, bef_feats = enc_pre(x, return_feats=True)
        feat, _, _, _, aft_feats = model(x, return_feats=True)

        # 1. Calculate unknown score
        scores = calculate_unknown_score(bef_feats, aft_feats, feat, proto, lamda=lamda)  # shape: (B,)

        # 2. Predict the most similar known class
        idx = torch.argmax(torch.matmul(F.normalize(feat, dim=1), F.normalize(proto, dim=1).T), dim=1)  # (B,)

        # 3. Threshold to filter unknowns
        known = scores > threshold
        pred = idx.clone()
        pred[~known] = -1  # mark unknowns

        preds.append(pred.cpu())

    return torch.cat(preds, dim=0)

In [46]:
# =============================
# Submission
# =============================
def generate_submission(root, preds, db_ds):
    sub = pd.read_csv(f"{root}/sample_submission.csv")
    meta = pd.read_csv(f"{root}/metadata.csv")
    q = meta[meta['path'].str.contains('/query/')].reset_index(drop=True)
    q['pred_idx'] = preds.numpy()

    # Step 1: Print raw predicted indices
    print("\nRaw predicted indices (pred_idx):")
    print(q[['image_id', 'pred_idx']].head(10))

    # Step 2: Index to identity mapping
    idx2id = {v: k for k, v in db_ds.id2idx.items()}
    q['prediction'] = q['pred_idx'].apply(lambda i: 'new_individual' if i < 0 else idx2id.get(int(i), f"unknown_{i}"))

    # Step 3: Print mapped predictions
    print("\nMapped predictions (after idx2id):")
    print(q[['image_id', 'prediction']].head(10))

    out = sub[['image_id']].merge(q[['image_id', 'prediction']], on='image_id')

    import time
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    save_path = f'submission_{timestamp}.csv'
    out.to_csv(save_path, index=False)

    return save_path

In [47]:
# =============================
# Main Workflow
# =============================
def main():
    #root = '/kaggle/input/animal-clef-2025'
    tf = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor()
    ])
    #idx2id = {v: k for k, v in db_ds.id2idx.items()}

    db_ds = AnimalCLEFDataset(root, 'database', transform=tf)
    
    #print("Example metadata path[0]:", db_ds.paths[0])
    #print("Looking for file at:", os.path.join(db_ds.root, db_ds.paths[0]))
    #print("Exists on disk?", os.path.exists(os.path.join(db_ds.root, db_ds.paths[0])))
    
    db_loader = DataLoader(db_ds, batch_size=8, shuffle=True,
                           num_workers=NUM_WORKERS, pin_memory=True)
    query_ds = AnimalCLEFDataset(root, 'query', transform=tf)
    q_loader = DataLoader(query_ds, batch_size=8, shuffle=False,
                           num_workers=NUM_WORKERS, pin_memory=True)

    clp_model = MAEFramework()
    plf_model = MAEFramework()

    # pick your device
    if torch.backends.mps.is_available() and torch.backends.mps.is_built():
        DEVICE = torch.device("mps")
    elif torch.cuda.is_available():
        DEVICE = torch.device("cuda")
    else:
        DEVICE = torch.device("cpu")

    # decide which GPUs to wrap
    cuda_gpus = torch.cuda.device_count()
    if DEVICE.type == "cuda" and cuda_gpus > 1:
        device_ids = list(range(cuda_gpus))  # e.g. [0,1,…]
        clp_model = nn.DataParallel(clp_model,  device_ids=device_ids).to(DEVICE)
        plf_model = nn.DataParallel(plf_model,  device_ids=device_ids).to(DEVICE)
    else:
        # single‐device: MPS, single CUDA, or CPU
        clp_model = clp_model.to(DEVICE)
        plf_model = plf_model.to(DEVICE)

    # now training and inference will work on either MPS or CUDA
    train_CLP(clp_model, db_loader, EPOCHS_CLP, LR,  'clp.pth', alpha_cl=1.0, alpha_rec=1.0)
    prototype = train_PLF(plf_model, db_loader, EPOCHS_PLF, LR,
                          'plf.pth', len(db_ds.id2idx), clp_model)
    
    print("Prototype shape:", prototype.shape)
    print("Number of unique classes:", len(db_ds.id2idx))

    # Plot the score distribution before estimating threshold
    #plot_score_distribution(clp_model, plf_model, prototype, db_loader)

    # Estimate threshold from database distribution
    dists = []
    with torch.no_grad():
        for x, _ in DataLoader(db_ds, batch_size=8, shuffle=False,
                               num_workers=NUM_WORKERS, pin_memory=True):
            x = x.to(DEVICE)
            feat, _, _, _, aft_feats = plf_model(x, return_feats=True)
            _, _, _, _, bef_feats = clp_model(x, return_feats=True)
            scores = calculate_unknown_score(bef_feats, aft_feats, feat, prototype)
            dists.extend(scores.cpu().tolist())


    threshold = torch.quantile(torch.tensor(dists), 0.95)


    # Inference on query set
    with torch.no_grad():
        lamda = 0.2 # or any value you want to test
        preds = inference(clp_model, plf_model, prototype, q_loader, threshold, lamda=lamda)


    print("Saving submission…")
    generate_submission(root, preds, db_ds)

if __name__ == '__main__':
    main()


Prototype shape: torch.Size([1102, 768])
Number of unique classes: 1102


Infer: 100%|██████████| 267/267 [00:52<00:00,  5.13it/s]

Saving submission…

Raw predicted indices (pred_idx):
   image_id  pred_idx
0         3        50
1         5        24
2        12        35
3        13        35
4        18        35
5        19        35
6        27        50
7        33        35
8        36        50
9        45        24

Mapped predictions (after idx2id):
   image_id          prediction
0         3  LynxID2025_lynx_62
1         5  LynxID2025_lynx_32
2        12  LynxID2025_lynx_43
3        13  LynxID2025_lynx_43
4        18  LynxID2025_lynx_43
5        19  LynxID2025_lynx_43
6        27  LynxID2025_lynx_62
7        33  LynxID2025_lynx_43
8        36  LynxID2025_lynx_62
9        45  LynxID2025_lynx_32





In [None]:
#preds = inference(clp_model, plf_model, prototype, q_loader, threshold, lamda=lamda)
#print("Saving submission…")
#generate_submission(root, preds, db_ds)

NameError: name 'clp_model' is not defined

Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:01<00:00, 205MB/s]  
CLP Ep1: 100%|██████████| 1635/1635 [15:00<00:00,  1.81it/s]


CLP Epoch 1: 1.3089


PLF Ep1: 100%|██████████| 1635/1635 [08:10<00:00,  3.33it/s]


PLF Epoch 1: 6.0188
Prototype shape: torch.Size([1102, 768])
Number of unique classes: 1102


Infer: 100%|██████████| 267/267 [00:49<00:00,  5.35it/s]

Saving submission…

Raw predicted indices (pred_idx):
   image_id  pred_idx
0         3        24
1         5        24
2        12        24
3        13        35
4        18        24
5        19        24
6        27        24
7        33        40
8        36        24
9        45        24

Mapped predictions (after idx2id):
   image_id          prediction
0         3  LynxID2025_lynx_32
1         5  LynxID2025_lynx_32
2        12  LynxID2025_lynx_32
3        13  LynxID2025_lynx_43
4        18  LynxID2025_lynx_32
5        19  LynxID2025_lynx_32
6        27  LynxID2025_lynx_32
7        33  LynxID2025_lynx_49
8        36  LynxID2025_lynx_32
9        45  LynxID2025_lynx_32



