# I. Pre-Process
## A. Format
Pré-process des PNG

- Standardiser l’image : lire le PNG, convertir en RGB (gérer l’alpha si RGBA), puis resize + crop (ou pad) pour obtenir exactement la taille attendue par le backbone.

- Mettre au format “modèle” : convertir en float tensor, normaliser avec les mean/std du pré-entraînement (ImageNet / CLIP selon le backbone).

In [20]:
Path_train = "/content/drive/MyDrive/jaguar-re-id(1)/train/train"
Path_test = "/content/drive/MyDrive/jaguar-re-id(1)/test/test"
import os
from PIL import Image, ImageOps
import numpy as np
import torch
import os, glob
import pandas as pd
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import KFold, StratifiedKFold
import torchvision.transforms as T


class ImageDataset(Dataset):
    def __init__(self, paths, labels=None, transform=None, float_size="fp32"):
        self.paths = paths
        self.labels = labels
        self.transform = transform
        self.float_size = float_size

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        img = Image.open(p).convert("RGB")
        x = self.transform(img) if self.transform else img

        if self.float_size == "fp16":
            x = x.half()
        elif self.float_size == "bf16":
            x = x.to(torch.bfloat16)

        if self.labels is None:
            return x, os.path.basename(p)  # test: (tensor, filename)
        return x, self.labels[idx]        # train/val: (tensor, class_id)



# II. DataLoader



In [21]:
import os, glob
import pandas as pd
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import KFold, StratifiedKFold
import torchvision.transforms as T

import os, glob
import pandas as pd
import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import KFold, StratifiedKFold
import torchvision.transforms as T

def make_loaders(
    train_dir,
    test_dir,
    img_size=224,
    batch_size=64,
    num_workers=2,
    val_size=1.0,
    n_folds=5,
    fold=0,
    seed=42,
    train_csv="/content/drive/MyDrive/jaguar-re-id(1)/train.csv",
    test_csv="/content/drive/MyDrive/jaguar-re-id(1)/test.csv",
    img_col="filename",          # <-- CHANGE
    label_col="ground_truth",    # <-- CHANGE
    float_size="fp32",
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    train_tf=None, test_tf=None,
):
    def to_path(name, root):
        name = str(name)
        if not name.lower().endswith(".png"):
            name += ".png"
        return name if os.path.isabs(name) else os.path.join(root, name)

    # ---- TRAIN (from train.csv)
    df_tr = pd.read_csv(train_csv)
    train_paths = [to_path(n, train_dir) for n in df_tr[img_col].tolist()]
    raw_labels  = df_tr[label_col].astype(str).tolist()

    # ---- TEST (optional, for later inference)
    if test_csv is not None and os.path.exists(test_csv):
        df_te = pd.read_csv(test_csv)
        if "query_image" in df_te.columns and "gallery_image" in df_te.columns:
            test_imgs = sorted(set(df_te["query_image"]) | set(df_te["gallery_image"]))
            test_paths = [to_path(n, test_dir) for n in test_imgs]
        elif img_col in df_te.columns:
            test_paths = [to_path(n, test_dir) for n in df_te[img_col].tolist()]
        else:
            test_paths = sorted(glob.glob(os.path.join(test_dir, "*.png")))
    else:
        test_paths = sorted(glob.glob(os.path.join(test_dir, "*.png")))

    # ---- encode labels
    classes = sorted(set(raw_labels))
    class_to_idx = {c: i for i, c in enumerate(classes)}
    y = [class_to_idx[c] for c in raw_labels]
    num_classes = len(classes)

    # ---- transforms (train aug, val/test no aug)
    if train_tf is None:
        # fallback: your existing ViT transforms (keep your old code here)
      train_tf = T.Compose([
          T.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
          T.RandomHorizontalFlip(p=0.5),
          T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
          T.ToTensor(),
          T.Normalize(mean, std),
          T.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3), value="random"),
      ])

    if test_tf is None:
        test_tf = T.Compose([
            T.Resize(img_size),
            T.CenterCrop(img_size),
            T.ToTensor(),
            T.Normalize(mean, std),
        ])

    # ---- K-fold split
    splitter = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed) if num_classes > 1 else \
               KFold(n_splits=n_folds, shuffle=True, random_state=seed)

    idx_all = list(range(len(train_paths)))
    tr_idx, va_idx = list(splitter.split(idx_all, y if num_classes > 1 else None))[fold]

    if val_size < 1.0:
        g = torch.Generator().manual_seed(seed + fold)
        va_idx = torch.tensor(va_idx)[torch.randperm(len(va_idx), generator=g)].tolist()
        k = max(1, int(len(va_idx) * val_size))
        va_keep, va_rest = va_idx[:k], va_idx[k:]
        tr_idx = list(tr_idx) + list(va_rest)
        va_idx = va_keep

    tr_paths  = [train_paths[i] for i in tr_idx]
    va_paths  = [train_paths[i] for i in va_idx]
    tr_labels = [y[i] for i in tr_idx]
    va_labels = [y[i] for i in va_idx]

    train_ds = ImageDataset(tr_paths, tr_labels, transform=train_tf, float_size=float_size)
    val_ds   = ImageDataset(va_paths, va_labels, transform=test_tf,  float_size=float_size)
    test_ds  = ImageDataset(test_paths, None,     transform=test_tf,  float_size=float_size)
    train_eval_ds = ImageDataset(tr_paths, tr_labels, transform=test_tf, float_size=float_size)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=num_workers, pin_memory=True)
    train_eval_loader = DataLoader(train_eval_ds, batch_size=batch_size, shuffle=False,num_workers=num_workers, pin_memory=True)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
    test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

    meta = {
        "N_train": len(train_ds),
        "N_val": len(val_ds),
        "N_test": len(test_ds),
        "num_classes": num_classes,
        "class_to_idx": class_to_idx,
    }
    return train_loader, train_eval_loader, val_loader, test_loader, meta



In [22]:
import torch
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
else:
    print("CPU runtime")


NVIDIA H100 80GB HBM3


# III. Model
## A. Backbone

Backbone: Swin-L or OpenCLIP ViT-L/H, plus one strong CNN for ensemble diversity

Head: GeM → Linear/BN/PReLU → L2-norm

Loss: Sub-center ArcFace + dynamic margin

Training: progressive resize, class-imbalance sampling, heavy aug

Inference: cosine kNN + DBA/re-ranking

## B. Head

In [23]:
import torch
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights
import torch
import torch.nn as nn
import torch.nn.functional as F

class ViTBaseBackbone(nn.Module):
    def __init__(self, pretrained=True):
        super().__init__()
        weights = ViT_B_16_Weights.IMAGENET1K_V1 if pretrained else None # 85M params
        m = vit_b_16(weights=weights)
        self.backbone = m
        self.out_dim = m.heads.head.in_features
        self.backbone.heads = nn.Identity()  # remove classifier -> returns (B, out_dim)

    def forward(self, x):
        x = self.backbone(x)  # (B, D)
        x = torch.nn.functional.normalize(x, dim=1)
        return x

class GeM(nn.Module):
    def __init__(self, p=3.0, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.tensor(p))
        self.eps = eps

    def forward(self, x):  # (B,C,H,W)
        x = x.clamp(min=self.eps).pow(self.p)
        x = F.avg_pool2d(x, (x.size(-2), x.size(-1)))
        x = x.pow(1.0 / self.p)
        return x.squeeze(-1).squeeze(-1)  # (B,C)

class Head(nn.Module):
    """
    GeM → Linear/BN/PReLU → L2-norm
    """
    def __init__(self, d_model, out_dim=None):
        super().__init__()
        out_dim = d_model if out_dim is None else out_dim
        self.gem = GeM()
        self.fc  = nn.Linear(d_model, out_dim, bias=False)
        self.bn  = nn.BatchNorm1d(out_dim)
        self.act = nn.PReLU(out_dim)

    def forward(self, x):
        # x can be (B,C,H,W) or already pooled (B,C)
        if x.dim() == 4:
            x = self.gem(x)
        x = self.act(self.bn(self.fc(x)))
        return F.normalize(x, dim=1)





## C. Training

In [24]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchvision.models.vision_transformer import interpolate_embeddings
import timm
from timm.data import resolve_model_data_config, create_transform


# ---- ArcFace (minimal) + dynamic margin handled in train loop
class ArcFace(nn.Module):
    def __init__(self, in_dim, num_classes, s=30.0):
        super().__init__()
        self.W = nn.Parameter(torch.empty(num_classes, in_dim))
        nn.init.xavier_uniform_(self.W)
        self.s = s

    def forward(self, x, y, m):
        x = F.normalize(x, dim=1)
        W = F.normalize(self.W, dim=1)
        cos = F.linear(x, W).clamp(-1 + 1e-7, 1 - 1e-7)  # (B, num_classes)

        sin = torch.sqrt(1.0 - cos * cos)
        cos_m, sin_m = torch.cos(torch.tensor(m, device=x.device)), torch.sin(torch.tensor(m, device=x.device))
        phi = cos * cos_m - sin * sin_m

        onehot = torch.zeros_like(cos)
        onehot.scatter_(1, y.view(-1, 1), 1.0)
        logits = (onehot * phi) + ((1.0 - onehot) * cos)
        return logits * self.s

@torch.no_grad()
def get_emb(loader, vit, head, device):
    embs, ys = [], []
    for x, y in loader:
        x = x.to(device)
        e = head(vit(x))
        embs.append(e.cpu())
        ys.append(y)
    return torch.cat(embs), torch.cat(ys)

@torch.no_grad()
def dba(gallery_emb, k=10):
    # gallery_emb must be L2-normalized
    sim = gallery_emb @ gallery_emb.T
    idx = sim.topk(k + 1, dim=1).indices[:, 1:]      # skip self
    neigh = gallery_emb[idx]                          # (N, k, D)
    out = (gallery_emb.unsqueeze(1) + neigh).mean(1)  # (N, D)
    return F.normalize(out, dim=1)

@torch.no_grad()
def knn_acc(query_emb, query_y, gallery_emb, gallery_y, k=1):
    sim = query_emb @ gallery_emb.T
    nn_idx = sim.topk(k, dim=1).indices              # (Nq, k)
    pred = gallery_y[nn_idx[:, 0]]                   # top-1
    return (pred == query_y).float().mean().item()



In [25]:
import numpy as np
import torch
import timm
from timm.data import resolve_model_data_config, create_transform

def build_timm_transforms(model, img_size, is_train):
    cfg = resolve_model_data_config(model)
    cfg = dict(cfg)
    cfg["input_size"] = (3, img_size, img_size)   # force stage size
    tf = create_transform(**cfg, is_training=is_train)
    return tf, cfg
def average_precision_from_scores(scores: np.ndarray, rel: np.ndarray) -> float:
    order = scores.argsort()[::-1]
    rel_sorted = np.asarray(rel, dtype=bool)[order]
    pos = np.flatnonzero(rel_sorted)          # ranks (0-based) where rel=1
    if pos.size == 0:
        return 0.0
    return float((np.arange(1, pos.size + 1) / (pos + 1)).mean())


def identity_balanced_map(
    query_emb: torch.Tensor,
    query_y: torch.Tensor,
    gallery_emb: torch.Tensor,
    gallery_y: torch.Tensor,
    query_names=None,
    gallery_names=None,
    remove_self=True,
) -> float:
    sim = (query_emb @ gallery_emb.T).detach().cpu().numpy()   # (Nq, Ng)
    qy  = query_y.detach().cpu().numpy()
    gy  = gallery_y.detach().cpu().numpy()

    # precompute relevance masks per identity in the gallery
    masks = {i: (gy == i) for i in np.unique(qy)}

    name_to_gidx = None
    if remove_self and query_names is not None and gallery_names is not None:
        name_to_gidx = {n: j for j, n in enumerate(gallery_names)}

    sum_ap, cnt = {}, {}
    for i, ident in enumerate(qy):
        order = sim[i].argsort()[::-1]
        rel_sorted = masks[ident][order]

        # remove self-match if query image also exists in gallery
        if name_to_gidx is not None:
            j = name_to_gidx.get(query_names[i], None)
            if j is not None:
                rel_sorted = rel_sorted & (order != j)

        pos = np.flatnonzero(rel_sorted)
        ap = 0.0 if pos.size == 0 else float((np.arange(1, pos.size + 1) / (pos + 1)).mean())

        sum_ap[ident] = sum_ap.get(ident, 0.0) + ap
        cnt[ident]    = cnt.get(ident, 0) + 1

    return float(np.mean([sum_ap[k] / cnt[k] for k in sum_ap])) if sum_ap else 0.0


In [26]:
@torch.no_grad()
def get_emb_and_names(loader, vit, head, device):
    embs, ys, names = [], [], []
    ptr = 0  # used only when loader returns labels and dataset has paths

    for x, y_or_name in loader:
        bs = x.size(0)
        x = x.to(device)
        embs.append(head(vit(x)).cpu())

        if isinstance(y_or_name[0], str):
            names.extend(list(y_or_name))
        else:
            ys.append(y_or_name.cpu())
            # recover filenames if possible (val/test_tf loaders are usually not shuffled)
            if hasattr(loader.dataset, "paths"):
                batch_paths = loader.dataset.paths[ptr:ptr+bs]
                names.extend([os.path.basename(p) for p in batch_paths])
                ptr += bs

    embs = torch.cat(embs)
    ys = torch.cat(ys) if len(ys) else None
    return embs, ys, names



## GPU memory testing

In [27]:
import gc, torch, torch.nn.functional as F

def enable_grad_ckpt(model, enabled=True):
    if hasattr(model, "set_grad_checkpointing"):
        model.set_grad_checkpointing(enabled)
        return True
    if hasattr(model, "grad_checkpointing"):
        model.grad_checkpointing = enabled
        return True
    return False

def _set_bn_eval(m):
    if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
        m.eval()

def one_train_step_peak_gb(backbone, head, arc, img_size, batch_size, num_classes, device="cuda", amp=True):
    backbone.train(); head.train(); arc.train()

    # BN-safe: if batch_size==1, run BN layers in eval (head only)
    if batch_size == 1:
        head.apply(_set_bn_eval)

    amp = amp and device.startswith("cuda")
    opt = torch.optim.AdamW(
        list(backbone.parameters()) + list(head.parameters()) + list(arc.parameters()),
        lr=1e-4
    )
    scaler = torch.amp.GradScaler("cuda", enabled=amp)

    if device.startswith("cuda"):
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    x = torch.randn(batch_size, 3, img_size, img_size, device=device)
    y = torch.randint(0, num_classes, (batch_size,), device=device)

    opt.zero_grad(set_to_none=True)
    with torch.amp.autocast(device_type="cuda", enabled=amp):
        emb = head(backbone(x))
        logits = arc(emb, y, 0.5)
        loss = F.cross_entropy(logits, y)

    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()

    if device.startswith("cuda"):
        return torch.cuda.max_memory_allocated() / 1024**3
    return 0.0

def find_max_batch(backbone, head, arc, img_size, num_classes, device="cuda", amp=True,
                   candidates=(2,4,8,16,24,32,48,64,96,128)):
    ok_bs, ok_peak = None, None
    for bs in candidates:
        try:
            peak = one_train_step_peak_gb(backbone, head, arc, img_size, bs, num_classes, device=device, amp=amp)
            ok_bs, ok_peak = bs, peak
            print(f"OK  bs={bs:>3} | peak_alloc={peak:.2f} GB")
        except torch.cuda.OutOfMemoryError:
            print(f"OOM bs={bs:>3}")
            if device.startswith("cuda"):
                torch.cuda.empty_cache()
            gc.collect()
            break
    return ok_bs, ok_peak


In [28]:
import os
import torch
import torch.nn.functional as F
import timm

def train_fold_SWIN(train_dir, test_dir, fold=0, n_folds=5,
                    epochs_stage_1=5, epochs_stage_2=2,
                    lr_stage_1=1e-3, lr_stage_2=4e-4,
                    out_dir="/content/drive/MyDrive/jaguar-re-id(1)/ckpts/"):

    os.makedirs(out_dir, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    stages = [(384, epochs_stage_1, lr_stage_1), (448, epochs_stage_2, lr_stage_2)]
    ckpt0 = os.path.join(out_dir, f"mega_fold{fold}_stage0.pth")
    ckptb = os.path.join(out_dir, f"mega_fold{fold}_best.pth")

    best = -1.0
    best_meta = None

    for si, (img_size, epochs, lr) in enumerate(stages):
        backbone = timm.create_model("hf-hub:BVRA/MegaDescriptor-L-384", pretrained=(si == 0)).to(device)

        train_tf, _ = build_timm_transforms(backbone, img_size, is_train=True)
        test_tf,  _ = build_timm_transforms(backbone, img_size, is_train=False)

        train_loader, train_eval_loader, val_loader, _, meta = make_loaders(
            train_dir=train_dir, test_dir=test_dir,
            img_size=img_size, n_folds=n_folds, fold=fold,
            train_tf=train_tf, test_tf=test_tf
        )
        num_classes = meta["num_classes"]

        head = Head(d_model=1536).to(device)
        arc  = ArcFace(in_dim=1536, num_classes=num_classes).to(device)

        if si > 0:
            s = torch.load(ckpt0, map_location="cpu")
            backbone.load_state_dict(s["backbone"], strict=True)
            head.load_state_dict(s["head"], strict=True)
            arc.load_state_dict(s["arc"], strict=True)

        opt = torch.optim.Adam([
            {"params": backbone.parameters(), "lr": lr * 0.1},
            {"params": head.parameters(),     "lr": lr},
            {"params": arc.parameters(),      "lr": lr},
        ])
        use_amp = (device == "cuda")
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
        m = 0.5

        for ep in range(epochs):
            backbone.train(); head.train(); arc.train()
            for x, y in train_loader:
                x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
                opt.zero_grad(set_to_none=True)
                with torch.cuda.amp.autocast(enabled=use_amp):
                    emb = head(backbone(x))
                    loss = F.cross_entropy(arc(emb, y, m), y)
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()

            backbone.eval(); head.eval(); arc.eval()
            tr_emb, tr_y, _ = get_emb_and_names(train_eval_loader, backbone, head, device)
            va_emb, va_y, _ = get_emb_and_names(val_loader,       backbone, head, device)
            score = identity_balanced_map(va_emb, va_y, tr_emb, tr_y, remove_self=False)

            print(f"[fold {fold}] [stage {img_size}] ep {ep+1}/{epochs} ibmAP={score:.4f}")

            if score > best:
                best, best_meta = score, meta
                torch.save(
                    {"img_size": img_size, "backbone": backbone.state_dict(), "head": head.state_dict(), "arc": arc.state_dict()},
                    ckptb
                )

        if si == 0:
            torch.save(
                {"img_size": img_size, "backbone": backbone.state_dict(), "head": head.state_dict(), "arc": arc.state_dict()},
                ckpt0
            )

    return ckptb, best, best_meta


def run_cv_SWIN(train_dir, test_dir, n_folds=5,
                epochs_stage_1=5, epochs_stage_2=2,
                lr_stage_1=1e-3, lr_stage_2=4e-4,
                out_dir="/content/drive/MyDrive/jaguar-re-id(1)/ckpts/"):

    ckpts, scores = [], []
    for fold in range(n_folds):
        ckpt, score, _ = train_fold_SWIN(
            train_dir, test_dir, fold=fold, n_folds=n_folds,
            epochs_stage_1=epochs_stage_1, epochs_stage_2=epochs_stage_2,
            lr_stage_1=lr_stage_1, lr_stage_2=lr_stage_2,
            out_dir=out_dir
        )
        ckpts.append(ckpt); scores.append(score)
    return ckpts, scores


In [None]:

out_dir = "/content/drive/MyDrive/jaguar-re-id(1)/ckpts_MegaDescriptor/"

ckpts, scores = run_cv_SWIN(Path_train, Path_test, n_folds=2, epochs_stage_1=12, epochs_stage_2=2,lr_stage_1=1e-3, lr_stage_2=4e-4 ,out_dir=out_dir)

print("ckpts:", ckpts)
print("scores:", scores)
print("mean:", sum(scores)/len(scores))




  scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
  with torch.cuda.amp.autocast(enabled=use_amp):


[fold 0] [stage 384] ep 1/12 ibmAP=0.4948
[fold 0] [stage 384] ep 2/12 ibmAP=0.6220
[fold 0] [stage 384] ep 3/12 ibmAP=0.7284
[fold 0] [stage 384] ep 4/12 ibmAP=0.7680
[fold 0] [stage 384] ep 5/12 ibmAP=0.7935


In [None]:
# ---- Training function: progressive resize + pos-emb interpolation (1 fold)
import os
import torch

def train_fold(train_dir, test_dir, fold=0, n_folds=5,
               epochs_224=2, epochs_384=1, lr_224=3e-4, lr_384=1e-4,
               out_dir="/content/drive/MyDrive/jaguar-re-id(1)/ckpts/"):

    os.makedirs(out_dir, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    weights = ViT_B_16_Weights.IMAGENET1K_V1
    pre = weights.transforms()
    mean = pre.mean if hasattr(pre, "mean") else (0.485, 0.456, 0.406)
    std  = pre.std  if hasattr(pre, "std")  else (0.229, 0.224, 0.225)

    ckpt_stage0 = os.path.join(out_dir, f"vit_fold{fold}_stage224.pth")
    ckpt_best   = os.path.join(out_dir, f"vit_fold{fold}_best.pth")

    stages = [(224, epochs_224, lr_224), (384, epochs_384, lr_384)]

    best_score = -1.0
    best_meta = None

    for stage_i, (img_size, epochs, lr) in enumerate(stages):
        train_loader, train_eval_loader, val_loader, _, meta = make_loaders(
            train_dir=train_dir, test_dir=test_dir,
            img_size=img_size, n_folds=n_folds, fold=fold,
            mean=mean, std=std
        )
        num_classes = meta["num_classes"]

        if stage_i == 0:
            vit = vit_b_16(weights=weights, image_size=img_size)
            vit.heads = nn.Identity()
            head = Head(d_model=768).to(device)
            arc  = ArcFace(in_dim=768, num_classes=num_classes).to(device)
        else:
            vit = vit_b_16(weights=None, image_size=img_size)
            vit.heads = nn.Identity()
            head = Head(d_model=768).to(device)
            arc  = ArcFace(in_dim=768, num_classes=num_classes).to(device)

            state = torch.load(ckpt_stage0, map_location="cpu")
            vit_state = interpolate_embeddings(
                image_size=img_size, patch_size=16,
                model_state=state["vit"], interpolation_mode="bicubic"
            )
            vit.load_state_dict(vit_state, strict=False)
            head.load_state_dict(state["head"])
            arc.load_state_dict(state["arc"])

        vit = vit.to(device)
        #for p in vit.parameters(): p.requires_grad = False
        #vit.eval()

        vit_lr = lr * 0.1  # <-- smaller lr for ViT (example: 10x smaller)

        opt = torch.optim.Adam([
            {"params": vit.parameters(),  "lr": vit_lr},
            {"params": head.parameters(), "lr": lr},
            {"params": arc.parameters(),  "lr": lr},
        ])

        for epoch in range(epochs):
            head.train(); arc.train();vit.train()

            m_start, m_end = 0.0, 0.5
            m = m_start if epochs == 1 else (m_start + (m_end - m_start) * epoch / (epochs - 1))

            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                opt.zero_grad()
                f = vit(x)
                emb = head(f)
                logits = arc(emb, y, m)
                loss = F.cross_entropy(logits, y)
                loss.backward()
                opt.step()

            # ---- VALIDATION (your ib-mAP eval)
            vit.eval()
            head.eval()
            arc.eval()

            # IMPORTANT: use a non-aug train loader for embeddings (shuffle=False, test_tf)
            tr_emb, tr_y, tr_names = get_emb_and_names(train_eval_loader, vit, head, device)
            va_emb, va_y, va_names = get_emb_and_names(val_loader,       vit, head, device)

            score_map = identity_balanced_map(
                query_emb=va_emb, query_y=va_y,
                gallery_emb=tr_emb, gallery_y=tr_y,
                remove_self=False,
            )

            print(f"[fold {fold}] [stage {img_size}] epoch {epoch+1}/{epochs}  margin={m:.3f}  val_ibmAP={score_map:.4f}")

            if score_map > best_score:
                best_score = score_map
                best_meta = meta
                torch.save(
                    {"vit": vit.state_dict(), "head": head.state_dict(), "arc": arc.state_dict()},
                    ckpt_best
                )

        # save stage-0 checkpoint for stage-1 init
        if stage_i == 0:
            torch.save(
                {"vit": vit.state_dict(), "head": head.state_dict(), "arc": arc.state_dict()},
                ckpt_stage0
            )

    return ckpt_best, best_score, best_meta



def run_cv(train_dir, test_dir, n_folds=5,
           epochs_224=2, epochs_384=1, lr_224=3e-4, lr_384=1e-4,
           out_dir="/content/drive/MyDrive/jaguar-re-id(1)/ckpts/"):

    ckpts, scores = [], []
    for fold in range(n_folds):
        ckpt, score, _ = train_fold(
            train_dir, test_dir,
            fold=fold, n_folds=n_folds,
            epochs_224=epochs_224, epochs_384=epochs_384,
            lr_224=lr_224, lr_384=lr_384,
            out_dir=out_dir
        )
        ckpts.append(ckpt)
        scores.append(score)
    return ckpts, scores



In [None]:

out_dir = "/content/drive/MyDrive/jaguar-re-id(1)/ckpts_MegaDescriptor/"

ckpts, scores = run_cv(Path_train, Path_test, n_folds=2, epochs_384=12, epochs_512=2, out_dir=out_dir)

print("ckpts:", ckpts)
print("scores:", scores)
print("mean:", sum(scores)/len(scores))




Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth


100%|██████████| 330M/330M [00:02<00:00, 123MB/s]


[fold 0] [stage 224] epoch 1/5  margin=0.000  val_ibmAP=0.4623
[fold 0] [stage 224] epoch 2/5  margin=0.125  val_ibmAP=0.6246
[fold 0] [stage 224] epoch 3/5  margin=0.250  val_ibmAP=0.7365
[fold 0] [stage 224] epoch 4/5  margin=0.375  val_ibmAP=0.7971
[fold 0] [stage 224] epoch 5/5  margin=0.500  val_ibmAP=0.8297


## D.Inference

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torchvision.models import vit_b_16
from tqdm import tqdm
import math
import torch
import torch.nn as nn
from torchvision.models import vit_b_16
from torchvision.models.vision_transformer import interpolate_embeddings

def load_vit_head(ckpt_path, device, force_img_size=None):
    state = torch.load(ckpt_path, map_location="cpu")
    vit_state = state["vit"]

    # ---- infer checkpoint img_size from pos_embedding length
    # pos_embedding shape: (1, 1 + (H/16)*(W/16), 768)
    n_tokens = vit_state["encoder.pos_embedding"].shape[1]
    grid = int(round(math.sqrt(n_tokens - 1)))   # e.g. 14 for 224, 24 for 384
    ckpt_img_size = grid * 16

    img_size = ckpt_img_size if force_img_size is None else int(force_img_size)

    vit = vit_b_16(weights=None, image_size=img_size)
    vit.heads = nn.Identity()
    head = Head(d_model=768)

    # ---- if you force a different img_size, interpolate pos embeddings
    if img_size != ckpt_img_size:
        vit_state = interpolate_embeddings(
            image_size=img_size,
            patch_size=16,
            model_state=vit_state,
            interpolation_mode="bicubic"
        )

    vit.load_state_dict(vit_state, strict=True)
    head.load_state_dict(state["head"], strict=True)

    vit.to(device).eval()
    head.to(device).eval()
    return vit, head, img_size


@torch.no_grad()
def embed_images(unique_filenames, test_dir, vit, head, device, img_size, mean, std, batch_size=32):
    import torchvision.transforms as T
    from torch.utils.data import DataLoader

    tf = T.Compose([
        T.Resize(img_size),
        T.CenterCrop(img_size),
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    print("doing dataloader")
    paths = [os.path.join(test_dir, fn) for fn in unique_filenames]
    ds = ImageDataset(paths, labels=None, transform=tf, float_size="fp32")
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    emb_map = {}
    for x, names in tqdm(loader):
        x = x.to(device, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=(device == "cuda")):
            e = head(vit(x))
        e = e.float().cpu()
        for name, vec in zip(names, e):
            emb_map[name] = vec
    return emb_map

def make_submission(test_csv_path, test_dir, ckpt_path, out_path="submission.csv"):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mean = (0.485, 0.456, 0.406)
    std  = (0.229, 0.224, 0.225)

    df = pd.read_csv(test_csv_path)

    # optional safety checks (recommended)
    # ensures you match test.csv order and count
    assert "row_id" in df.columns and "query_image" in df.columns and "gallery_image" in df.columns
    assert len(df) == df["row_id"].shape[0]

    unique_imgs = sorted(set(df["query_image"]) | set(df["gallery_image"]))

    vit, head, img_size = load_vit_head(ckpt_path, device, force_img_size=384)
    print("doing embed image")
    emb_map = embed_images(unique_imgs, test_dir, vit, head, device, img_size, mean, std)

    sims = []
    for i,(q, g) in enumerate(zip(df["query_image"].values, df["gallery_image"].values)):
        cos_sim = float((emb_map[q] * emb_map[g]).sum().item())  # in [-1, 1]
        sim01 = (cos_sim + 1.0) * 0.5
        if i % 100 == 0 :
            print(f"Scoring pairs: {i}/{len(df)} ({100*i/len(df):.1f}%)")
        sims.append(min(1.0, max(0.0, sim01)))  # clamp to [0,1]

    sub = pd.DataFrame({"row_id": df["row_id"].values, "similarity": sims})
    sub.to_csv(out_path, index=False)
    return out_path


In [None]:

test_dir = "/content/drive/MyDrive/jaguar-re-id(1)/test/test"
test_csv = "/content/drive/MyDrive/jaguar-re-id(1)/test.csv"
ckpt_path = "/content/drive/MyDrive/jaguar-re-id(1)/ckpts/vit_fold0_best.pth"  # or best fold

out = make_submission(test_csv, test_dir, ckpt_path, out_path="submission.csv")


doing embed image
doing dataloader


  with torch.cuda.amp.autocast(enabled=(device == "cuda")):
100%|██████████| 12/12 [02:59<00:00, 14.92s/it]


Scoring pairs: 0/137270 (0.0%)
Scoring pairs: 100/137270 (0.1%)
Scoring pairs: 200/137270 (0.1%)
Scoring pairs: 300/137270 (0.2%)
Scoring pairs: 400/137270 (0.3%)
Scoring pairs: 500/137270 (0.4%)
Scoring pairs: 600/137270 (0.4%)
Scoring pairs: 700/137270 (0.5%)
Scoring pairs: 800/137270 (0.6%)
Scoring pairs: 900/137270 (0.7%)
Scoring pairs: 1000/137270 (0.7%)
Scoring pairs: 1100/137270 (0.8%)
Scoring pairs: 1200/137270 (0.9%)
Scoring pairs: 1300/137270 (0.9%)
Scoring pairs: 1400/137270 (1.0%)
Scoring pairs: 1500/137270 (1.1%)
Scoring pairs: 1600/137270 (1.2%)
Scoring pairs: 1700/137270 (1.2%)
Scoring pairs: 1800/137270 (1.3%)
Scoring pairs: 1900/137270 (1.4%)
Scoring pairs: 2000/137270 (1.5%)
Scoring pairs: 2100/137270 (1.5%)
Scoring pairs: 2200/137270 (1.6%)
Scoring pairs: 2300/137270 (1.7%)
Scoring pairs: 2400/137270 (1.7%)
Scoring pairs: 2500/137270 (1.8%)
Scoring pairs: 2600/137270 (1.9%)
Scoring pairs: 2700/137270 (2.0%)
Scoring pairs: 2800/137270 (2.0%)
Scoring pairs: 2900/137270