In [None]:
!pip install -qU timm

import os, math, random, warnings, gc
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from sklearn.model_selection import StratifiedKFold
import timm
warnings.filterwarnings('ignore')

# ---- Device & Seed ----
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def seed_everything(seed):
    """Seed every random number generator the notebook uses.

    Seeds Python's ``random``, NumPy's global generator and PyTorch (CPU and
    all CUDA devices), then asks cuDNN for deterministic kernels.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point; the hash seed of
    # the running process was fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning may pick a different convolution algorithm from run to run,
    # so it has to be off for the deterministic flag above to be meaningful.
    torch.backends.cudnn.benchmark = False

seed_everything(42)

# ---- Load Data ----
# Competition inputs: one directory holding the CSV manifests plus nested
# train/ and test/ image folders.
INPUT_DIR = '/kaggle/input/jaguar-re-id'
TRAIN_DIR = os.path.join(INPUT_DIR, 'train/train')
TEST_DIR = os.path.join(INPUT_DIR, 'test/test')

train_df, test_df, sample_sub = (
    pd.read_csv(os.path.join(INPUT_DIR, csv_name))
    for csv_name in ('train.csv', 'test.csv', 'sample_submission.csv')
)

# Map each identity string to a contiguous integer id (alphabetical order)
# and attach it to the training frame as the classification target.
label_to_idx = {
    identity: class_id
    for class_id, identity in enumerate(sorted(train_df['ground_truth'].unique()))
}
train_df['label'] = train_df['ground_truth'].map(label_to_idx)
NUM_CLASSES = len(label_to_idx)

print(f"PyTorch {torch.__version__} | Device: {DEVICE}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print(
    f"\nTraining images: {len(train_df)}",
    f"Individuals: {NUM_CLASSES}",
    f"Test pairs: {len(test_df):,}",
    sep="\n",
)

In [None]:
# ============================================================
# EXPLORATORY DATA ANALYSIS
# ============================================================

# Grid of example images: one row per randomly chosen individual, up to five
# images per row.
fig, axes = plt.subplots(4, 5, figsize=(16, 13))
fig.suptitle('Sample Jaguar Images by Individual', fontsize=16, fontweight='bold')

# Blank every cell up front so rows/columns that never receive an image
# (fewer than 4 individuals, fewer than 5 images, unreadable file) stay empty
# instead of showing default axes.
for grid_ax in axes.ravel():
    grid_ax.axis('off')

individuals = sorted(train_df['ground_truth'].unique())
# A private RandomState(42) yields the same draw as seeding the global NumPy
# generator with 42, without resetting global RNG state as a side effect.
selected = list(np.random.RandomState(42).choice(
    individuals, size=min(4, len(individuals)), replace=False))

for row_idx, indiv in enumerate(selected):
    indiv_df = train_df[train_df['ground_truth'] == indiv]
    samples = indiv_df.sample(min(5, len(indiv_df)), random_state=42)
    for col_idx, (_, img_row) in enumerate(samples.iterrows()):
        img_path = os.path.join(TRAIN_DIR, img_row['filename'])
        try:
            # Context manager closes the underlying file once it is drawn.
            with Image.open(img_path) as img:
                axes[row_idx, col_idx].imshow(img.convert('RGB'))
        except (OSError, ValueError) as exc:
            # Keep the grid going, but make the failure visible.
            print(f"  [EDA] could not display {img_path}: {exc}")
        if col_idx == 0:
            axes[row_idx, col_idx].set_title(f"{indiv[:20]}\n({len(indiv_df)} imgs)", fontsize=9, fontweight='bold')
plt.tight_layout()
plt.show()

# Bar chart of how many training images each individual has, followed by a
# short textual summary of the train/test sizes.
counts = train_df['ground_truth'].value_counts().sort_index()
bar_positions = range(NUM_CLASSES)
bar_colors = plt.cm.viridis(np.linspace(0.3, 0.9, NUM_CLASSES))

fig, ax = plt.subplots(figsize=(14, 4))
ax.bar(bar_positions, counts.values, color=bar_colors)
ax.set_title(f'Images per Individual ({len(train_df)} total, {NUM_CLASSES} jaguars)')
ax.set_xlabel('Individual Jaguar')
ax.set_ylabel('Number of Images')
ax.set_xticks(bar_positions)
ax.set_xticklabels(counts.index, rotation=45, ha='right', fontsize=7)
plt.tight_layout()
plt.show()

# Every image referenced by any test pair, on either side of the pair.
test_images = set(test_df['query_image']) | set(test_df['gallery_image'])
print(f"Images per individual: Min={counts.min()} Max={counts.max()} Mean={counts.mean():.1f}")
print(f"Test: {len(test_images)} unique images, {len(test_df):,} pairs")
print(f"\nStrategy: 5-Fold EVA02-Large + Alpha Masking + Focal Loss + Val mAP + Multi-Scale TTA + QE + Re-ranking")

In [None]:
# ============================================================
# JAGUAR RE-ID V8: ALL IMPROVEMENTS
# ============================================================
# Changes from V7:
#   [1] Alpha mask support — zero out backgrounds to prevent
#       spurious correlation with scenery
#   [2] Focal Loss — directly addresses class imbalance for
#       identity-balanced mAP metric
#   [3] Val mAP monitoring + best checkpoint — stop flying blind
#   [4] LR warmup — stabilize early fine-tuning of pretrained ViT
#   [5] CLS + GeM concatenation — capture global + local features
#   [6] Differential LR — lower for backbone, higher for head
# ============================================================

# ---- Cross-validation & backbone ----
N_FOLDS = 5
MODEL_NAME = 'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k'
IMG_SIZE = 448

# ---- Optimisation ----
BATCH_SIZE = 4
GRAD_ACCUM = 4           # micro-batches accumulated per optimizer update
EPOCHS = 12
BACKBONE_LR = 1e-5       # learning rate for the pretrained backbone
HEAD_LR = 5e-5           # learning rate for the newly added layers
WD = 1e-3
WARMUP_EPOCHS = 2        # epochs of linear warmup ahead of cosine decay

# ---- Switches ----
USE_ALPHA_MASK = True    # blend backgrounds away using the alpha channel
GRAD_CKPT = True         # gradient checkpointing on the backbone

# ---- Inference ----
# Each TTA resolution is a multiple of the 14px patch size.
TTA_SIZES = [392, 448, 518]

# ---- Focal loss ----
FOCAL_GAMMA = 2.0
FOCAL_ALPHA = 0.25

# One-shot summary of the run plan.
print("\n".join([
    f"Plan: {N_FOLDS}-fold {MODEL_NAME.split('.')[0]}",
    f"  {IMG_SIZE}px train | TTA at {TTA_SIZES}",
    f"  {EPOCHS} epochs per fold | {WARMUP_EPOCHS} warmup epochs",
    f"  Alpha masking: {USE_ALPHA_MASK}",
    f"  Backbone LR: {BACKBONE_LR} | Head LR: {HEAD_LR}",
    f"  Focal Loss: gamma={FOCAL_GAMMA}, alpha={FOCAL_ALPHA}",
]))

In [None]:
# ============================================================
# FOCAL LOSS — addresses class imbalance directly
# ============================================================
class FocalLoss(nn.Module):
    """Multi-class focal loss built on top of cross-entropy.

    Each sample's cross-entropy is scaled by ``alpha * (1 - p_t) ** gamma``,
    where ``p_t`` is the predicted probability of the true class, so samples
    the model already classifies confidently contribute less.

    Args:
        gamma: focusing exponent; 0 reduces to (alpha-scaled) cross-entropy.
        alpha: constant multiplier applied to every sample's loss.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``.

    Raises:
        ValueError: if ``reduction`` is not one of the supported modes.
    """
    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, gamma=2.0, alpha=0.25, reduction='mean'):
        super().__init__()
        # Fail early: previously any value other than 'mean' (including
        # 'sum' or a typo) silently returned the unreduced per-sample vector.
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        self.gamma = gamma
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of ``inputs`` (logits) against ``targets``."""
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # cross-entropy is -log(p_t), so exp(-ce) recovers p_t.
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss


# ============================================================
# GeM POOLING
# ============================================================
class GeM(nn.Module):
    """Generalized-mean pooling over the last two dimensions.

    Inputs are clamped from below at ``eps``, raised to the learnable power
    ``p``, averaged over the full spatial extent and taken to the power
    ``1 / p``. The two pooled dimensions are kept with size 1.
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Single learnable exponent, initialised to ``p``.
        self.p = nn.Parameter(torch.ones(1) * p)

    def forward(self, x):
        window = (x.size(-2), x.size(-1))
        powered = x.clamp(min=self.eps).pow(self.p)
        pooled = F.avg_pool2d(powered, window)
        return pooled.pow(1.0 / self.p)


# ============================================================
# ARCFACE HEAD
# ============================================================
class ArcFaceLayer(nn.Module):
    """Cosine-similarity classification head with a margin on the true class.

    Both the input embeddings and the class weights are L2-normalised, so the
    raw output is a matrix of cosines. When labels are supplied, the margin
    ``m`` is subtracted from the cosine of each sample's true class (the
    subtraction is applied to the cosine value itself) and all logits are
    multiplied by the scale ``s``.

    Args:
        in_features: embedding dimensionality.
        out_features: number of classes.
        s: logit scale applied when labels are given.
        m: margin subtracted from the true-class cosine.
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.5):
        super().__init__()
        self.s = s
        self.m = m
        class_weights = torch.empty(out_features, in_features, dtype=torch.float32)
        self.weight = nn.Parameter(class_weights)
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label=None):
        """Return plain cosines without labels, scaled margin logits with them."""
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        if label is None:
            return cosine
        # True exactly at each row's ground-truth class column.
        is_target = torch.zeros_like(cosine).scatter_(1, label.view(-1, 1), 1).bool()
        penalised = torch.where(is_target, cosine - self.m, cosine)
        return penalised * self.s


# ============================================================
# MODEL — now with CLS + GeM concatenation
# ============================================================
class JaguarModel(nn.Module):
    """timm backbone + (CLS ++ GeM) neck + margin head for re-identification.

    The backbone is created from the module-level ``MODEL_NAME`` with its
    classifier removed. Token outputs are summarised as the CLS token
    concatenated with a GeM pooling of the patch tokens, projected back to
    ``feat_dim`` and batch-normalised.

    Args:
        num_classes: number of identities for the classification head.
        s: logit scale forwarded to the head.
        m: margin forwarded to the head.
    """

    def __init__(self, num_classes, s=30.0, m=0.5):
        super().__init__()
        self.backbone = timm.create_model(MODEL_NAME, pretrained=True,
                                          num_classes=0, dynamic_img_size=True)
        if GRAD_CKPT:
            self.backbone.set_grad_checkpointing(True)
            print('  Gradient checkpointing enabled')
        self.feat_dim = self.backbone.num_features
        self.gem = GeM()

        # CLS + GeM concat -> project back to feat_dim
        self.neck = nn.Sequential(
            nn.Linear(self.feat_dim * 2, self.feat_dim),
            nn.BatchNorm1d(self.feat_dim),
        )
        self.head = ArcFaceLayer(self.feat_dim, num_classes, s=s, m=m)

    def forward(self, x, label=None):
        """Embed ``x``; return head logits when ``label`` is given, else the embedding.

        Args:
            x: image batch passed to ``backbone.forward_features``.
            label: optional class indices; when present the embedding is fed
                through the margin head.

        Returns:
            ``(B, feat_dim)`` embeddings, or the head's logits if ``label``
            is not None.
        """
        features = self.backbone.forward_features(x)

        if features.dim() == 3:
            # Token output (B, N, C). Ask the backbone how many leading
            # tokens are not patches (class / register tokens) instead of
            # assuming exactly one, so extra prefix tokens never leak into
            # the patch pooling. Falls back to 1 if the attribute is absent.
            n_prefix = getattr(self.backbone, 'num_prefix_tokens', 1)
            patch_tokens = features[:, n_prefix:, :]      # (B, P, C)
            if n_prefix > 0:
                cls_token = features[:, 0, :]             # (B, C)
            else:
                # No class token available: use the patch mean as the
                # global descriptor.
                cls_token = patch_tokens.mean(dim=1)

            # GeM averages over every spatial position, so the 2-D layout of
            # the patches is irrelevant. Presenting them as a (P, 1) grid
            # pools all tokens for any patch count, with no square-grid
            # assumption and no silently dropped tokens.
            spatial = patch_tokens.transpose(1, 2).unsqueeze(-1)  # (B, C, P, 1)
            gem_feat = self.gem(spatial).flatten(1)               # (B, C)

            # Concatenate CLS + GeM for global + local features
            emb = torch.cat([cls_token, gem_feat], dim=1)         # (B, 2C)
        else:
            # Fallback for CNN-style (B, C, H, W) outputs: duplicate the GeM
            # feature so the neck still receives 2C inputs.
            gem_feat = self.gem(features).flatten(1)
            emb = torch.cat([gem_feat, gem_feat], dim=1)

        emb = self.neck(emb)  # (B, feat_dim)

        if label is not None:
            return self.head(emb, label)
        return emb

In [None]:
# ============================================================
# DATASET — with alpha mask support
# ============================================================
class JaguarDataset(Dataset):
    """Image dataset with optional alpha-channel background suppression.

    Rows of ``df`` must provide a ``filename`` column and, unless
    ``is_test`` is set, a ``label`` column.

    Args:
        df: frame describing the images; its index is reset internally.
        img_dir: directory the filenames are relative to.
        transform: optional callable applied to the loaded PIL image.
        is_test: if True, items are ``(image, filename)`` instead of
            ``(image, label_tensor)``.
        use_alpha: if True, RGBA images are blended onto mid-gray using
            their alpha channel; other modes are converted to RGB unchanged.
    """

    def __init__(self, df, img_dir, transform=None, is_test=False, use_alpha=False):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test
        self.use_alpha = use_alpha

    def __len__(self):
        return len(self.df)

    def _load_image(self, filename):
        """Load ``filename`` as an RGB PIL image.

        On any loading error a gray 448x448 placeholder is returned so that a
        single bad file does not abort a long run; the failure is printed so
        it does not go unnoticed.
        """
        path = os.path.join(self.img_dir, filename)
        try:
            # The context manager releases the file handle as soon as the
            # pixels have been copied out.
            with Image.open(path) as img:
                if self.use_alpha and img.mode == 'RGBA':
                    # Soft alpha blend (not a binary mask): fully transparent
                    # pixels become neutral gray (128), partially transparent
                    # ones are mixed proportionally. Gray rather than black
                    # avoids introducing hard artificial edges.
                    alpha = np.asarray(img.getchannel('A'), dtype=np.float32) / 255.0
                    alpha = alpha[:, :, None]
                    rgb = np.asarray(img.convert('RGB'), dtype=np.float32)
                    blended = rgb * alpha + 128.0 * (1 - alpha)
                    return Image.fromarray(blended.astype(np.uint8))
                return img.convert('RGB')
        except Exception as exc:
            # print rather than warnings.warn: this notebook silences
            # warnings globally.
            print(f"  [JaguarDataset] failed to load {path!r} "
                  f"({type(exc).__name__}: {exc}); using gray placeholder")
            return Image.new('RGB', (448, 448), (128, 128, 128))

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = self._load_image(row['filename'])
        if self.transform:
            img = self.transform(img)
        if self.is_test:
            return img, row['filename']
        return img, torch.tensor(row['label'], dtype=torch.long)


# ============================================================
# TRANSFORMS — same as V7 (augmentations are adequate)
# ============================================================
# Normalisation statistics of the pretrained backbone. timm's EVA02 configs
# use the OpenAI-CLIP mean/std rather than the ImageNet defaults, so inputs
# are normalised the way the checkpoint expects.
# NOTE(review): confirm against the created backbone's `pretrained_cfg`
# ('mean' / 'std') if MODEL_NAME is changed.
NORM_MEAN = (0.48145466, 0.4578275, 0.40821073)
NORM_STD = (0.26862954, 0.26130258, 0.27577711)


def get_train_transform():
    """Training pipeline: resize to ``IMG_SIZE`` plus light augmentation.

    Applies horizontal flip, a small affine jitter and brightness/contrast
    jitter on the PIL image, then tensor conversion, normalisation and
    random erasing.
    """
    return T.Compose([
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.RandomHorizontalFlip(),
        T.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        T.ColorJitter(brightness=0.2, contrast=0.2),
        T.ToTensor(),
        T.Normalize(NORM_MEAN, NORM_STD),
        # Operates on the normalised tensor, hence placed last.
        T.RandomErasing(p=0.25),
    ])

def get_test_transform(size):
    """Deterministic pipeline: resize to ``size`` x ``size`` and normalise.

    Uses the same normalisation statistics as the training pipeline.
    """
    return T.Compose([
        T.Resize((size, size)),
        T.ToTensor(),
        T.Normalize(NORM_MEAN, NORM_STD),
    ])

In [None]:
# ============================================================
# VALIDATION: Identity-Balanced mAP
# ============================================================
def _average_precision(relevant):
    """Average precision of one ranked list from its 0/1 relevance vector."""
    hits_so_far = np.cumsum(relevant)
    precision_at_k = hits_so_far / np.arange(1, len(relevant) + 1)
    return (precision_at_k * relevant).sum() / relevant.sum()


def _identity_balanced_map(sim, labels):
    """Mean over identities of the mean AP of that identity's queries.

    Every image is used as a query against all other images. Identities with
    fewer than two images have no possible positive match and are skipped.

    Args:
        sim: ``(N, N)`` similarity matrix (higher = more similar).
        labels: ``(N,)`` identity label per image.

    Returns:
        The identity-balanced mAP, or 0.0 when no identity is evaluable.
    """
    all_indices = np.arange(len(labels))
    per_identity_aps = []

    for identity in np.unique(labels):
        query_indices = np.where(labels == identity)[0]
        if len(query_indices) < 2:
            continue  # Need at least 2 images for query-gallery pairs

        aps_for_identity = []
        for q_idx in query_indices:
            # Gallery = all images except this query
            gallery_indices = np.delete(all_indices, q_idx)
            ranking = np.argsort(-sim[q_idx, gallery_indices])
            relevant = (labels[gallery_indices[ranking]] == identity).astype(float)
            if relevant.sum() == 0:
                continue
            aps_for_identity.append(_average_precision(relevant))

        if aps_for_identity:
            per_identity_aps.append(np.mean(aps_for_identity))

    return np.mean(per_identity_aps) if per_identity_aps else 0.0


def compute_val_map(model, val_df, img_dir, batch_size=4):
    """Compute identity-balanced mAP on a validation fold.

    Embeds every image of ``val_df`` with ``model`` in eval mode (test-time
    transform at ``IMG_SIZE``, no augmentation), L2-normalises the
    embeddings and scores the resulting cosine-similarity matrix.

    Args:
        model: network returning embeddings when called without labels.
            It is left in eval mode.
        val_df: frame with ``filename`` and ``label`` columns.
        img_dir: directory containing the images.
        batch_size: inference batch size.

    Returns:
        Identity-balanced mAP; 0.0 if the fold is empty or has no identity
        with at least two images.
    """
    model.eval()

    val_dataset = JaguarDataset(
        val_df, img_dir, get_test_transform(IMG_SIZE),
        is_test=False, use_alpha=USE_ALPHA_MASK
    )
    val_loader = DataLoader(val_dataset, batch_size=batch_size,
                            shuffle=False, num_workers=2)

    # Mixed precision only when actually running on a GPU.
    use_amp = DEVICE.type == 'cuda'

    all_embs = []
    all_labels = []
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs = imgs.to(DEVICE)
            with torch.amp.autocast('cuda', enabled=use_amp):
                emb = model(imgs)
            all_embs.append(emb.float().cpu())
            all_labels.append(labels)

    if not all_embs:
        # Empty fold: nothing to rank (torch.cat would raise on an empty list).
        return 0.0

    embs = F.normalize(torch.cat(all_embs, 0), dim=1).numpy()
    labels = torch.cat(all_labels, 0).numpy()

    # Rows are unit-norm, so the Gram matrix holds cosine similarities.
    sim = embs @ embs.T
    return _identity_balanced_map(sim, labels)

In [None]:
# ============================================================
# LR SCHEDULER WITH WARMUP
# ============================================================
class WarmupCosineScheduler:
    """Linear warmup followed by cosine decay, driven once per optimizer update.

    The learning rate of every parameter group is its initial value times a
    scale factor: ``step / warmup_steps`` during warmup, then half a cosine
    wave from 1 down to 0 over the remaining steps.

    Args:
        optimizer: object exposing ``param_groups`` (list of dicts with an
            ``'lr'`` entry); the values at construction are the peak rates.
        warmup_epochs: number of warmup epochs.
        total_epochs: total number of epochs, warmup included.
        steps_per_epoch: optimizer updates per epoch.
    """

    def __init__(self, optimizer, warmup_epochs, total_epochs, steps_per_epoch):
        self.optimizer = optimizer
        self.warmup_steps = warmup_epochs * steps_per_epoch
        self.total_steps = total_epochs * steps_per_epoch
        self.base_lrs = [pg['lr'] for pg in optimizer.param_groups]
        self.current_step = 0
        if self.warmup_steps > 0:
            # The first optimizer update happens before step() is ever
            # called. Start it at the first warmup value instead of letting
            # it run at the full base learning rate.
            self._apply(self._scale_at(1))

    def _scale_at(self, step):
        """Scale factor in [0, 1] for the given 1-based step count."""
        if step <= self.warmup_steps:
            # Linear warmup
            return step / max(1, self.warmup_steps)
        # Cosine decay; progress is capped so the rate stays at zero rather
        # than climbing back up if step() is called past total_steps.
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        progress = min(1.0, (step - self.warmup_steps) / decay_steps)
        return 0.5 * (1 + math.cos(math.pi * progress))

    def _apply(self, scale):
        """Write ``base_lr * scale`` into every parameter group."""
        for base_lr, pg in zip(self.base_lrs, self.optimizer.param_groups):
            pg['lr'] = base_lr * scale

    def step(self):
        """Advance one optimizer update and refresh the learning rates."""
        self.current_step += 1
        self._apply(self._scale_at(self.current_step))

    def get_last_lr(self):
        """Current learning rate of each parameter group."""
        return [pg['lr'] for pg in self.optimizer.param_groups]

In [None]:
# ============================================================
# EXTRACT EMBEDDINGS WITH MULTI-SCALE TTA
# ============================================================
def extract_multiscale_tta(model, unique_images, img_dir, batch_size=4, tta_sizes=None):
    """Extract test-time-augmented embeddings at several resolutions.

    For each resolution every image is embedded as-is and horizontally
    flipped; the two are averaged and L2-normalised. The per-resolution
    embeddings are then averaged and normalised again, giving
    ``len(sizes) * 2`` views per image.

    Args:
        model: network returning embeddings when called without labels.
            It is left in eval mode.
        unique_images: sequence of filenames relative to ``img_dir``.
        img_dir: directory containing the images.
        batch_size: inference batch size.
        tta_sizes: iterable of square input sizes; defaults to the
            module-level ``TTA_SIZES``.

    Returns:
        ``(len(unique_images), D)`` NumPy array of unit-norm embeddings, in
        the order of ``unique_images``.
    """
    model.eval()
    sizes = TTA_SIZES if tta_sizes is None else list(tta_sizes)

    # The file list is identical for every scale; build the frame once.
    image_frame = pd.DataFrame({'filename': unique_images})
    # Mixed precision only when actually running on a GPU.
    use_amp = DEVICE.type == 'cuda'

    all_embs = []
    for tta_size in sizes:
        loader = DataLoader(
            JaguarDataset(
                image_frame, img_dir,
                get_test_transform(tta_size), is_test=True,
                use_alpha=USE_ALPHA_MASK
            ),
            batch_size=batch_size, shuffle=False, num_workers=2
        )
        feats = []
        with torch.no_grad():
            for imgs, _ in loader:
                imgs = imgs.to(DEVICE)
                with torch.amp.autocast('cuda', enabled=use_amp):
                    f1 = model(imgs)
                    f2 = model(torch.flip(imgs, [3]))  # flip along width
                # Average the two views in full precision.
                feats.append(((f1.float() + f2.float()) / 2).cpu())
        emb = F.normalize(torch.cat(feats, 0), dim=1)
        all_embs.append(emb)
        del loader
        print(f"    TTA @ {tta_size}px done")

    # Average across scales and re-normalize
    combined = F.normalize(torch.stack(all_embs).mean(dim=0), dim=1)
    return combined.numpy()

In [None]:
# ============================================================
# POST-PROCESSING
# ============================================================
def query_expansion(emb, top_k=3):
    """Replace each embedding by the normalised mean of its nearest neighbours.

    Neighbours are the ``top_k`` rows with the highest dot-product
    similarity; for unit-norm inputs this includes the row itself.

    Args:
        emb: ``(N, D)`` array of embeddings.
        top_k: number of neighbours averaged per row.

    Returns:
        ``(N, D)`` array of expanded embeddings with unit L2 norm. Rows whose
        neighbour mean is exactly zero are returned as zeros rather than NaN.
    """
    print("  Applying Query Expansion...")
    emb = np.asarray(emb)
    if emb.shape[0] == 0:
        return emb.copy()

    sims = emb @ emb.T
    indices = np.argsort(-sims, axis=1)[:, :top_k]
    # Fancy indexing gathers every row's neighbours at once: (N, k, D).
    new_emb = emb[indices].mean(axis=1)
    norms = np.linalg.norm(new_emb, axis=1, keepdims=True)
    # Guard against division by zero for degenerate (all-zero) rows.
    return new_emb / np.maximum(norms, 1e-12)

def k_reciprocal_rerank(prob, k1=20, k2=6, lambda_value=0.3):
    """Re-rank a square similarity matrix with a k-reciprocal Jaccard term.

    Two items are reciprocal neighbours when each appears in the other's
    ``k1 + 1`` nearest items. The Jaccard distance between the reciprocal
    neighbour sets of every pair is blended with the original distance
    ``1 - prob``:

        final = lambda_value * jaccard + (1 - lambda_value) * original

    Pairs whose neighbour sets do not overlap get Jaccard distance 1.
    (Leaving them at 0 would treat "no shared neighbours" as a perfect
    match and rank such pairs above closer pairs with partial overlap.)

    Args:
        prob: ``(N, N)`` similarity matrix.
        k1: neighbourhood size used for the reciprocal test.
        k2: accepted for interface compatibility; not used by this
            simplified variant.
        lambda_value: weight of the Jaccard term.

    Returns:
        ``(N, N)`` re-ranked similarity matrix (``1 - final distance``).

    Raises:
        ValueError: if ``prob`` is not a square 2-D matrix.
    """
    print("  Applying K-Reciprocal Re-ranking...")
    prob = np.asarray(prob)
    if prob.ndim != 2 or prob.shape[0] != prob.shape[1]:
        raise ValueError(f"prob must be a square matrix, got shape {prob.shape}")
    n = prob.shape[0]
    if n == 0:
        return prob.copy()

    original_dist = 1 - prob
    initial_rank = np.argsort(original_dist, axis=1)

    # in_top[i, j] is True when j is among i's k1 + 1 nearest items.
    in_top = np.zeros((n, n), dtype=bool)
    np.put_along_axis(in_top, initial_rank[:, :k1 + 1], True, axis=1)
    # Reciprocal neighbours: j is near i AND i is near j.
    reciprocal = (in_top & in_top.T).astype(np.float32)

    # Set intersections for all pairs in one matrix product; unions follow
    # from |A| + |B| - |A & B|.
    shared = reciprocal @ reciprocal.T
    set_sizes = reciprocal.sum(axis=1)
    union = set_sizes[:, None] + set_sizes[None, :] - shared

    jaccard_dist = np.ones_like(original_dist)
    overlapping = shared > 0
    jaccard_dist[overlapping] = 1 - shared[overlapping] / union[overlapping]

    return 1 - (jaccard_dist * lambda_value + original_dist * (1 - lambda_value))

In [None]:
# ============================================================
# MAIN: 5-FOLD TRAINING WITH ALL IMPROVEMENTS
# ============================================================
# Every image that appears on either side of a test pair, in a fixed order;
# img_map gives each filename its row in the embedding matrices.
unique_test = sorted(set(test_df['query_image']) | set(test_df['gallery_image']))
img_map = {n: i for i, n in enumerate(unique_test)}

print(f"\nUnique test images: {len(unique_test)}")
print(f"\n{'#'*60}")
print(f"# V8: 5-FOLD EVA02-Large + Alpha Mask + Focal Loss")
print(f"# + Val mAP + LR Warmup + CLS+GeM + Multi-Scale TTA")
print(f"# TTA: {len(TTA_SIZES)} scales x 2 flips = {len(TTA_SIZES)*2} views per image")
print(f"# Then: QE + K-Reciprocal Re-ranking")
print(f"{'#'*60}")

skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=42)

all_embeddings = []  # Store normalized embeddings from each fold
fold_val_maps = []   # Track best val mAP per fold

for fold, (train_idx, val_idx) in enumerate(skf.split(train_df, train_df['label'])):
    seed_everything(42 + fold)

    print(f"\n{'='*60}")
    print(f"FOLD {fold+1}/{N_FOLDS} | Train: {len(train_idx)} | Val: {len(val_idx)}")
    print(f"{'='*60}")

    fold_train = train_df.iloc[train_idx].copy()
    fold_val = train_df.iloc[val_idx].copy()

    model = JaguarModel(NUM_CLASSES).to(DEVICE)
    print(f"  Params: {sum(p.numel() for p in model.parameters()):,}")

    train_loader = DataLoader(
        JaguarDataset(fold_train, TRAIN_DIR, get_train_transform(),
                      use_alpha=USE_ALPHA_MASK),
        batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True,
        # A trailing batch of a single sample would make BatchNorm1d raise in
        # training mode, so incomplete batches are dropped.
        drop_last=True,
    )

    # ---- Differential LR: lower for backbone, higher for head ----
    backbone_params = list(model.backbone.parameters())
    head_params = (list(model.gem.parameters()) +
                   list(model.neck.parameters()) +
                   list(model.head.parameters()))

    optimizer = torch.optim.AdamW([
        {'params': backbone_params, 'lr': BACKBONE_LR},
        {'params': head_params, 'lr': HEAD_LR},
    ], weight_decay=WD)

    # Steps-level warmup + cosine scheduler. The final, possibly shorter,
    # accumulation group of each epoch counts as one optimizer update.
    n_batches = len(train_loader)
    steps_per_epoch = math.ceil(n_batches / GRAD_ACCUM)
    tail_size = n_batches % GRAD_ACCUM       # micro-batches in a short last group
    tail_start = n_batches - tail_size       # first micro-batch index of that group
    scheduler = WarmupCosineScheduler(optimizer, WARMUP_EPOCHS, EPOCHS, steps_per_epoch)

    scaler = torch.amp.GradScaler('cuda')
    criterion = FocalLoss(gamma=FOCAL_GAMMA, alpha=FOCAL_ALPHA)

    best_val_map = float('-inf')
    ckpt_saved = False   # guarantees a checkpoint exists before it is loaded
    ckpt_path = f'./best_fold{fold}.pt'

    # ---- Train this fold ----
    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0.0
        optimizer.zero_grad()

        for i, (imgs, labels) in enumerate(tqdm(train_loader, leave=False, desc=f"F{fold+1} E{epoch+1}")):
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            # Divide by the real size of this accumulation group so a short
            # last group is averaged correctly.
            group_size = tail_size if i >= tail_start else GRAD_ACCUM
            with torch.amp.autocast('cuda'):
                raw_loss = criterion(model(imgs, labels), labels)
            scaler.scale(raw_loss / group_size).backward()

            # Update after every full group and once more for a short last
            # group. Stepping only when gradients were actually accumulated
            # avoids GradScaler's "No inf checks were recorded" assertion,
            # which an unconditional end-of-epoch step triggers whenever the
            # batch count is a multiple of GRAD_ACCUM.
            if (i + 1) % GRAD_ACCUM == 0 or (i + 1) == n_batches:
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()  # Step-level scheduling

            total_loss += raw_loss.item()

        avg_loss = total_loss / max(1, n_batches)
        lrs = scheduler.get_last_lr()

        # ---- Validate every 2 epochs or on the last epoch ----
        if (epoch + 1) % 2 == 0 or epoch == EPOCHS - 1:
            gc.collect()
            torch.cuda.empty_cache()
            val_map = compute_val_map(model, fold_val, TRAIN_DIR, batch_size=BATCH_SIZE)
            improved = ''
            # The first validation always writes a checkpoint, even for a
            # score of exactly 0, so the load below can never hit a missing
            # (or stale) file.
            if not ckpt_saved or val_map > best_val_map:
                best_val_map = val_map
                torch.save(model.state_dict(), ckpt_path)
                ckpt_saved = True
                improved = ' ★ BEST'
            print(f"  Epoch {epoch+1}/{EPOCHS} | Loss: {avg_loss:.4f} | "
                  f"BB_LR: {lrs[0]:.2e} | Head_LR: {lrs[1]:.2e} | "
                  f"Val mAP: {val_map:.4f}{improved}")
        else:
            print(f"  Epoch {epoch+1}/{EPOCHS} | Loss: {avg_loss:.4f} | "
                  f"BB_LR: {lrs[0]:.2e} | Head_LR: {lrs[1]:.2e}")

    # ---- Load best checkpoint for this fold ----
    print(f"  Loading best checkpoint (val mAP: {best_val_map:.4f})")
    # The checkpoint is a plain state dict, so tensors-only loading suffices.
    model.load_state_dict(torch.load(ckpt_path, map_location=DEVICE, weights_only=True))
    fold_val_maps.append(best_val_map)

    # ---- Extract embeddings with multi-scale TTA ----
    print(f"  Extracting multi-scale TTA embeddings...")
    fold_emb = extract_multiscale_tta(model, unique_test, TEST_DIR, batch_size=BATCH_SIZE)
    all_embeddings.append(fold_emb)
    print(f"  Fold {fold+1} done! Embeddings: {fold_emb.shape} | Best val mAP: {best_val_map:.4f}")

    # Free memory
    del model, optimizer, scaler, scheduler, train_loader
    gc.collect()
    torch.cuda.empty_cache()

print(f"\nPer-fold val mAP: {[f'{m:.4f}' for m in fold_val_maps]}")
print(f"Mean val mAP: {np.mean(fold_val_maps):.4f} ± {np.std(fold_val_maps):.4f}")

In [None]:
# ============================================================
# ENSEMBLE ALL FOLDS + POST-PROCESSING
# ============================================================
print(f"\n{'='*60}")
print(f"ENSEMBLING {len(all_embeddings)} FOLDS")
print(f"{'='*60}")

# Average embeddings across folds, then normalize
avg_emb = np.mean(all_embeddings, axis=0)
avg_emb = avg_emb / np.linalg.norm(avg_emb, axis=1, keepdims=True)

# Query Expansion on the averaged embeddings
avg_emb = query_expansion(avg_emb)

# Compute similarity matrix
sim_matrix = avg_emb @ avg_emb.T

# Re-ranking
sim_matrix = k_reciprocal_rerank(sim_matrix)

# ---- Generate Submission ----
print("\nGenerating submission...")
# Translate both filename columns to matrix indices and read every pair's
# score with a single fancy-indexing lookup instead of a per-row Python loop.
query_idx = test_df['query_image'].map(img_map).to_numpy()
gallery_idx = test_df['gallery_image'].map(img_map).to_numpy()
# Widen to float64 before clipping so the values written to the CSV match
# what converting each score with float() would produce.
preds = np.clip(sim_matrix[query_idx, gallery_idx].astype(np.float64), 0.0, 1.0)

sub = pd.DataFrame({'row_id': test_df['row_id'], 'similarity': preds})
sub.to_csv('submission.csv', index=False)

print(f"\n{'='*60}")
print(f"DONE! V8 Complete")
print(f"{'='*60}")
print(f"Folds: {len(all_embeddings)}")
print(f"Per-fold val mAP: {[f'{m:.4f}' for m in fold_val_maps]}")
print(f"Mean val mAP: {np.mean(fold_val_maps):.4f} ± {np.std(fold_val_maps):.4f}")
print(f"TTA scales: {TTA_SIZES} x 2 flips = {len(TTA_SIZES)*2} views")
print(f"Alpha masking: {USE_ALPHA_MASK}")
print(f"Mean similarity: {np.mean(preds):.4f}")
print(f"Range: [{preds.min():.4f}, {preds.max():.4f}]")
print(f"\nSubmission saved to submission.csv")
print(sub.head(10))