In [None]:
!pip install -qU timm

import os, math, random, warnings, gc
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm.auto import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import timm
warnings.filterwarnings('ignore')

# ---- Device & Seed ----
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN behaviour.

    Args:
        seed: integer seed applied to every generator. It is also exported as
            PYTHONHASHSEED, which only influences Python processes started after
            this call (the running interpreter's hash seed is fixed at startup).
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning selects kernels by timing, which can differ between runs;
    # keep it off together with `deterministic` for reproducible training.
    torch.backends.cudnn.benchmark = False

seed_everything(42)

# ---- Load Data ----
# Kaggle layout: metadata CSVs at the dataset root, images in nested folders.
INPUT_DIR = '/kaggle/input/jaguar-re-id'
TRAIN_DIR = os.path.join(INPUT_DIR, 'train/train')
TEST_DIR = os.path.join(INPUT_DIR, 'test/test')

train_df, test_df, sample_sub = (
    pd.read_csv(os.path.join(INPUT_DIR, csv_name))
    for csv_name in ('train.csv', 'test.csv', 'sample_submission.csv')
)

# Label encoding: integer class ids follow the sorted order of individual names.
label_to_idx = {name: idx for idx, name in enumerate(sorted(train_df['ground_truth'].unique()))}
NUM_CLASSES = len(label_to_idx)
train_df['label'] = train_df['ground_truth'].map(label_to_idx)

print(f"PyTorch {torch.__version__} | Device: {DEVICE}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"\nTraining images: {len(train_df)}",
      f"Individuals: {NUM_CLASSES}",
      f"Test pairs: {len(test_df):,}",
      sep='\n')


In [None]:
# ============================================================
# EXPLORATORY DATA ANALYSIS
# ============================================================
# Panel grid: one row per randomly chosen individual (up to 4), up to 5 photos per row,
# followed by the distribution of images per individual.

fig, axes = plt.subplots(4, 5, figsize=(16, 13))
fig.suptitle('Sample Jaguar Images by Individual', fontsize=16, fontweight='bold')
# Blank every panel up front so unused cells (short rows, or whole rows when there are
# fewer than 4 individuals) never show empty tick frames.
for panel in axes.flat:
    panel.axis('off')

individuals = sorted(train_df['ground_truth'].unique())
np.random.seed(42)
selected = list(np.random.choice(individuals, size=min(4, len(individuals)), replace=False))

for row_idx, indiv in enumerate(selected):
    indiv_df = train_df[train_df['ground_truth'] == indiv]
    samples = indiv_df.sample(min(5, len(indiv_df)), random_state=42)
    for col_idx, (_, img_row) in enumerate(samples.iterrows()):
        panel = axes[row_idx, col_idx]
        img_path = os.path.join(TRAIN_DIR, img_row['filename'])
        try:
            # Context manager releases the file handle once the image has been plotted.
            with Image.open(img_path) as img:
                panel.imshow(img)
        except Exception as err:  # best-effort preview: report the failure, keep plotting
            print(f"  Could not display {img_path}: {err}")
            panel.text(0.5, 0.5, 'unreadable', ha='center', va='center',
                       transform=panel.transAxes, fontsize=8)
        if col_idx == 0:
            # str() so non-string identifiers can also be truncated for the title.
            panel.set_title(f"{str(indiv)[:20]}\n({len(indiv_df)} imgs)",
                            fontsize=9, fontweight='bold')

plt.tight_layout()
plt.show()

fig, ax = plt.subplots(figsize=(14, 4))
counts = train_df['ground_truth'].value_counts().sort_index()
colors = plt.cm.viridis(np.linspace(0.3, 0.9, NUM_CLASSES))
ax.bar(range(NUM_CLASSES), counts.values, color=colors)
ax.set_xlabel('Individual Jaguar')
ax.set_ylabel('Number of Images')
ax.set_title(f'Images per Individual ({len(train_df)} total, {NUM_CLASSES} jaguars)')
ax.set_xticks(range(NUM_CLASSES))
ax.set_xticklabels(counts.index, rotation=45, ha='right', fontsize=7)
plt.tight_layout()
plt.show()

print("Images per individual:")
print(f"  Min: {counts.min()} | Max: {counts.max()} | Mean: {counts.mean():.1f} | Median: {counts.median():.1f}")
print(f"\nTest: {len(set(test_df['query_image']) | set(test_df['gallery_image']))} unique images, {len(test_df):,} pairs")
print(f"\nCLOSED-SET: all {NUM_CLASSES} test individuals appear in training.")
print("Strategy: 2x EVA02 (different seeds) + DINOv2, weighted ensemble, QE + re-rank")


In [None]:
# ============================================================
# 2x EVA02-Large (different seeds) + DINOv2-Base
# Weighted ensemble + QE + K-Reciprocal Re-ranking
# ============================================================

# ---- GeM Pooling ----
class GeM(nn.Module):
    """Generalized-mean pooling over the last two dimensions.

    Computes mean(clamp(x, eps) ** p) ** (1 / p) over the full (H, W) extent with a
    learnable exponent p, so (..., H, W) becomes (..., 1, 1).
    """

    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        # Learnable exponent stored as a 1-element parameter.
        self.p = nn.Parameter(p * torch.ones(1))
        self.eps = eps

    def forward(self, x):
        exponent = self.p
        # Clamping keeps the fractional root well-defined for non-positive activations.
        powered = torch.clamp(x, min=self.eps) ** exponent
        pooled = F.avg_pool2d(powered, kernel_size=tuple(x.shape[-2:]))
        return pooled ** (1.0 / exponent)

# ---- ArcFace ----
class ArcFaceLayer(nn.Module):
    """Normalized-softmax classification head with a margin on the target logit.

    Logits are s * cos(theta) between the L2-normalized input and L2-normalized class
    weights. With a label, the target-class logit is penalized by a margin:

    * margin_type='cosine' (default): cos(theta) - m, an additive *cosine* margin
      (CosFace-style). Despite the class name this is what the layer has always
      applied; it stays the default so existing training runs are unchanged.
    * margin_type='arc': cos(theta + m), ArcFace's additive *angular* margin.

    Args:
        in_features: embedding dimension.
        out_features: number of classes.
        s: logit scale.
        m: margin (cosine units for 'cosine', radians for 'arc').
        margin_type: 'cosine' or 'arc'.

    Raises:
        ValueError: for an unknown margin_type.
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.5, margin_type='cosine'):
        super().__init__()
        if margin_type not in ('cosine', 'arc'):
            raise ValueError(f"margin_type must be 'cosine' or 'arc', got {margin_type!r}")
        self.s, self.m = s, m
        self.margin_type = margin_type
        # Constants for cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m, self.sin_m = math.cos(m), math.sin(m)
        # For theta > pi - m, cos(theta + m) is no longer monotonic in theta; common
        # ArcFace implementations fall back to a linear cosine penalty there.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, label=None):
        """Return raw cosines (no label) or scaled margin logits (with integer labels)."""
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        if label is None:
            return cosine
        if self.margin_type == 'arc':
            # clamp guards sqrt against tiny negative values from rounding (|cos| > 1).
            sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(min=0.0))
            phi = cosine * self.cos_m - sine * self.sin_m
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        else:
            phi = cosine - self.m
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        # Margin applies only to the target class; other classes keep plain cosine.
        return ((one_hot * phi) + ((1.0 - one_hot) * cosine)) * self.s

# ---- Model (EVA02 + ViT compatible) ----
class JaguarModel(nn.Module):
    """timm backbone -> GeM pooling -> BatchNorm1d embedding, plus an ArcFaceLayer head.

    forward(x) returns the BN-normalized embedding used for retrieval;
    forward(x, label) returns the head's scaled margin logits for training.
    Backbones may return 4-D CNN feature maps or (B, N, C) ViT token sequences.
    """

    def __init__(self, model_name, num_classes, s=30.0, m=0.5):
        super().__init__()
        # num_classes=0 removes timm's classifier; dynamic_img_size lets ViTs run at
        # the resolution chosen in MODELS.
        self.backbone = timm.create_model(model_name, pretrained=True,
                                          num_classes=0, dynamic_img_size=True)
        self.feat_dim = self.backbone.num_features
        self.gem = GeM()
        self.bn = nn.BatchNorm1d(self.feat_dim)
        self.head = ArcFaceLayer(self.feat_dim, num_classes, s=s, m=m)
        n_params = sum(p.numel() for p in self.parameters())
        print(f"  Dim: {self.feat_dim} | Params: {n_params:,}")

    def _tokens_to_map(self, tokens):
        """Drop prefix (class/register) tokens; return patch tokens as a (B, C, N, 1) map."""
        n_tokens = tokens.shape[1]
        n_prefix = getattr(self.backbone, 'num_prefix_tokens', None)
        if n_prefix is None:
            # Fallback heuristic: assume a square patch grid and treat the surplus
            # leading tokens as prefix tokens.
            side = math.isqrt(n_tokens)
            n_prefix = n_tokens - side * side
        patch_tokens = tokens[:, n_prefix:, :]
        # GeM pools over the whole map, so the true (H, W) layout does not matter;
        # an (N, 1) column pools exactly the same tokens and also handles non-square
        # inputs, which dynamic_img_size permits.
        return patch_tokens.transpose(1, 2).unsqueeze(-1)

    def forward(self, x, label=None):
        features = self.backbone.forward_features(x)
        if features.dim() == 3:
            features = self._tokens_to_map(features)
        emb = self.bn(self.gem(features).flatten(1))
        if label is not None:
            return self.head(emb, label)
        return emb

# ---- Dataset ----
class JaguarDataset(Dataset):
    """Images listed in a DataFrame's 'filename' column, loaded from `img_dir`.

    Items are (image, label tensor) for training data (requires a 'label' column) and
    (image, filename) when is_test=True. An unreadable file is replaced by a black
    placeholder so a single bad image does not abort an epoch or inference pass, but
    the failure is printed once per file (per worker process).
    """

    PLACEHOLDER_SIZE = (448, 448)

    def __init__(self, df, img_dir, transform=None, is_test=False):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test
        # Paths already reported as unreadable, to avoid repeating the warning each epoch.
        self._reported_failures = set()

    def __len__(self):
        return len(self.df)

    def _load_image(self, filename):
        """Load `filename` as RGB, falling back to a black placeholder on any read error."""
        path = os.path.join(self.img_dir, filename)
        try:
            # convert() returns an in-memory copy, so the file can be closed immediately.
            with Image.open(path) as img:
                return img.convert('RGB')
        except Exception as err:
            if path not in self._reported_failures:
                self._reported_failures.add(path)
                print(f"  WARNING: could not read {path} ({err}); using a black placeholder")
            return Image.new('RGB', self.PLACEHOLDER_SIZE)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = self._load_image(row['filename'])
        if self.transform:
            img = self.transform(img)
        if self.is_test:
            return img, row['filename']
        return img, torch.tensor(row['label'], dtype=torch.long)

# ---- Transforms ----
# ImageNet statistics: the fallback normalization when a backbone's timm
# pretrained_cfg does not provide its own mean/std.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Augmentation presets selectable via get_train_transform(strength=...).
_AUG_PRESETS = {
    'normal': dict(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1),
                   jitter=dict(brightness=0.2, contrast=0.2), erase_p=0.25),
    'strong': dict(degrees=20, translate=(0.15, 0.15), scale=(0.85, 1.15),
                   jitter=dict(brightness=0.3, contrast=0.3, saturation=0.2), erase_p=0.35),
}

def get_train_transform(img_size, strength='normal', mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Training augmentation pipeline.

    Args:
        img_size: output side length; images are resized (not cropped) to a square.
        strength: 'normal' or 'strong' (wider affine range, saturation jitter, more erasing).
        mean, std: per-channel normalization statistics (default: ImageNet).

    Raises:
        ValueError: for an unknown `strength`.
    """
    if strength not in _AUG_PRESETS:
        raise ValueError(f"unknown aug strength {strength!r}; expected one of {sorted(_AUG_PRESETS)}")
    preset = _AUG_PRESETS[strength]
    return T.Compose([
        T.Resize((img_size, img_size)),
        T.RandomHorizontalFlip(),
        T.RandomAffine(degrees=preset['degrees'], translate=preset['translate'],
                       scale=preset['scale']),
        T.ColorJitter(**preset['jitter']),
        T.ToTensor(),
        T.Normalize(list(mean), list(std)),
        T.RandomErasing(p=preset['erase_p']),
    ])

def get_test_transform(img_size, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Deterministic inference pipeline: square resize, tensor conversion, normalization."""
    return T.Compose([
        T.Resize((img_size, img_size)),
        T.ToTensor(),
        T.Normalize(list(mean), list(std)),
    ])

def _backbone_norm_stats(model):
    """Return (mean, std) from the backbone's timm pretrained_cfg, else ImageNet defaults.

    Some checkpoints (possibly EVA02's) are pretrained with statistics other than
    ImageNet's; using the checkpoint's own values keeps fine-tuning inputs in the
    distribution the backbone was trained on.
    """
    cfg = getattr(model.backbone, 'pretrained_cfg', None)
    if not isinstance(cfg, dict):
        cfg = {}
    mean = cfg.get('mean') or IMAGENET_MEAN
    std = cfg.get('std') or IMAGENET_STD
    return tuple(mean), tuple(std)

# ============================================================
# TRAINING FUNCTION
# ============================================================
def train_and_extract(model_name, img_size, batch_size, grad_accum, epochs, lr, wd,
                      seed=42, aug_strength='normal', test_files=None):
    """Fine-tune one backbone with the margin-softmax head, then embed the test images.

    Args:
        model_name: timm model identifier for the backbone.
        img_size: square input resolution for training and inference.
        batch_size: micro-batch size; the effective batch is batch_size * grad_accum.
        grad_accum: micro-batches accumulated per optimizer step.
        epochs: training epochs (cosine LR schedule stepped once per epoch).
        lr, wd: AdamW learning rate and weight decay.
        seed: passed to seed_everything before the model and loaders are built.
        aug_strength: 'normal' or 'strong' training augmentation preset.
        test_files: filenames (relative to TEST_DIR) to embed; defaults to the
            module-level `unique_test` list.

    Returns:
        (embeddings, names): float32 array with one L2-normalized, hflip-TTA-averaged
        embedding per test file, and the filenames in row order.
    """
    seed_everything(seed)
    if test_files is None:
        test_files = unique_test
    # AMP only on CUDA; on CPU both autocast and GradScaler become no-ops.
    use_amp = DEVICE.type == 'cuda'

    print(f"\n{'='*60}")
    print(f"TRAINING: {model_name.split('.')[-1]} (seed={seed}, aug={aug_strength})")
    print(f"  {img_size}px | bs={batch_size}x{grad_accum} | {epochs} epochs | lr={lr}")
    print(f"{'='*60}")

    model = JaguarModel(model_name, NUM_CLASSES).to(DEVICE)
    mean, std = _backbone_norm_stats(model)
    print(f"  Normalization: mean={mean} std={std}")

    train_loader = DataLoader(
        JaguarDataset(train_df.copy(), TRAIN_DIR, get_train_transform(img_size, aug_strength, mean, std)),
        batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True,
        # BatchNorm1d cannot train on a single-sample batch; drop only such a remainder.
        drop_last=len(train_df) % batch_size == 1,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
    criterion = nn.CrossEntropyLoss()

    def _optimizer_step():
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    steps_per_epoch = len(train_loader)
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        optimizer.zero_grad()
        for i, (imgs, labels) in enumerate(tqdm(train_loader, leave=False, desc=f"Epoch {epoch+1}")):
            imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
            with torch.amp.autocast('cuda', enabled=use_amp):
                loss = criterion(model(imgs, labels), labels) / grad_accum
            scaler.scale(loss).backward()
            if (i + 1) % grad_accum == 0:
                _optimizer_step()
            total_loss += loss.item() * grad_accum
        # Flush a trailing partial accumulation window only if one exists. Stepping
        # after every grad was cleared makes GradScaler.step fail with
        # "No inf checks were recorded for this optimizer".
        if steps_per_epoch % grad_accum != 0:
            _optimizer_step()
        scheduler.step()
        print(f"  Epoch {epoch+1}/{epochs} | Loss: {total_loss/max(steps_per_epoch, 1):.4f} | LR: {scheduler.get_last_lr()[0]:.2e}")

    # Extract embeddings with TTA (original + hflip)
    print("  Extracting embeddings...")
    test_loader = DataLoader(
        JaguarDataset(pd.DataFrame({'filename': list(test_files)}), TEST_DIR,
                      get_test_transform(img_size, mean, std), is_test=True),
        batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True,
    )

    model.eval()
    feats, names = [], []
    with torch.no_grad():
        for imgs, fnames in tqdm(test_loader, leave=False, desc="Inference"):
            imgs = imgs.to(DEVICE)
            with torch.amp.autocast('cuda', enabled=use_amp):
                f1 = model(imgs)
                f2 = model(torch.flip(imgs, [3]))
            f_avg = F.normalize((f1 + f2) / 2, dim=1)
            feats.append(f_avg.float().cpu())
            names.extend(fnames)

    embeddings = torch.cat(feats, 0).numpy()

    del model, optimizer, scaler, scheduler, train_loader, test_loader
    gc.collect()
    torch.cuda.empty_cache()

    print(f"  Done! Embeddings: {embeddings.shape}")
    return embeddings, names

# ============================================================
# POST-PROCESSING
# ============================================================
def query_expansion(emb, top_k=3):
    """Average query expansion on L2-normalized embeddings.

    Each row is replaced by the mean of its `top_k` most similar rows (normally
    including itself, as a unit vector's self-similarity is maximal) and re-normalized.

    Args:
        emb: (n, d) array of row embeddings, expected L2-normalized.
        top_k: neighbours to average; values larger than n use every row.

    Returns:
        (n, d) array of unit-norm expanded embeddings.

    Raises:
        ValueError: if top_k < 1.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    print(f"  Applying Query Expansion (top_k={top_k})...")
    sims = emb @ emb.T
    neighbours = np.argsort(-sims, axis=1)[:, :top_k]
    # emb[neighbours] has shape (n, top_k, d); average over the neighbour axis.
    expanded = emb[neighbours].mean(axis=1)
    return expanded / np.linalg.norm(expanded, axis=1, keepdims=True)

def k_reciprocal_rerank(prob, k1=20, k2=6, lambda_value=0.3, dist_threshold=0.6):
    """Simplified k-reciprocal re-ranking of a square similarity matrix.

    Distances are taken as 1 - prob. For each item i, R(i) is the set of items in i's
    top-(k1 + 1) ranking that also rank i within their own top-(k1 + 1). The Jaccard
    distance 1 - |R(i) & R(j)| / |R(i) | R(j)| is blended with the original distance.

    Args:
        prob: (n, n) similarity matrix (higher = more similar).
        k1: neighbourhood size for the reciprocal sets.
        k2: accepted for API compatibility; unused (this variant has no local
            query-expansion step).
        lambda_value: weight of the Jaccard distance in the blend.
        dist_threshold: Jaccard distance is only computed for pairs whose original
            distance is below this value; every other pair keeps the maximal Jaccard
            distance of 1.

    Returns:
        (n, n) similarity 1 - (lambda * jaccard + (1 - lambda) * original_dist).
    """
    print("  Applying K-Reciprocal Re-ranking...")
    original_dist = 1 - prob
    initial_rank = np.argsort(original_dist, axis=1)
    n_items = prob.shape[0]

    recip_sets = []
    for i in range(n_items):
        forward = initial_rank[i, :k1 + 1]
        backward = initial_rank[forward, :k1 + 1]
        recip_sets.append(set(forward[np.any(backward == i, axis=1)].tolist()))

    # Unscored pairs share no evidence of being neighbours, so they start at the
    # maximal Jaccard distance 1. Starting at 0 would treat them as perfect matches
    # and raise their final similarity above genuinely closer pairs.
    jaccard_dist = np.ones_like(original_dist)
    for i in range(n_items):
        for j in np.flatnonzero(original_dist[i] < dist_threshold):
            shared = len(recip_sets[i] & recip_sets[j])
            if shared:
                jaccard_dist[i, j] = 1 - shared / len(recip_sets[i] | recip_sets[j])

    return 1 - (jaccard_dist * lambda_value + original_dist * (1 - lambda_value))

# ============================================================
# MAIN: TRAIN 3 MODELS, WEIGHTED ENSEMBLE, SUBMIT
# ============================================================
# Every image referenced by any test pair, in sorted order; train_and_extract embeds
# this list, so all models produce identically ordered embedding rows.
unique_test = sorted(set(test_df['query_image']) | set(test_df['gallery_image']))
print(f"\nUnique test images: {len(unique_test)}")
print(f"\n{'#'*60}")
print(f"# STRATEGY: Beat 0.937")
print(f"# 1. EVA02-Large seed=42  (normal aug)  -> weight 0.40")
print(f"# 2. EVA02-Large seed=123 (strong aug)  -> weight 0.40")
print(f"# 3. DINOv2-Base seed=42  (normal aug)  -> weight 0.20")
print(f"# Weighted ensemble -> QE -> Re-ranking")
print(f"{'#'*60}")

MODELS = [
    # (name, img_size, bs, ga, epochs, lr, wd, seed, aug, weight)
    ('eva02_large_patch14_448.mim_m38m_ft_in22k_in1k', 448, 4, 4, 10, 2e-5, 1e-3, 42,  'normal', 0.40),
    ('eva02_large_patch14_448.mim_m38m_ft_in22k_in1k', 448, 4, 4, 10, 2e-5, 1e-3, 123, 'strong', 0.40),
    ('vit_base_patch14_dinov2.lvd142m',                 518, 8, 2, 10, 1e-5, 1e-4, 42,  'normal', 0.20),
]

all_sim_matrices = []
all_weights = []
img_map = None  # filename -> row index, shared by every similarity matrix

for model_name, img_size, bs, ga, epochs, lr, wd, seed, aug, weight in MODELS:
    try:
        emb, names = train_and_extract(model_name, img_size, bs, ga, epochs, lr, wd,
                                       seed=seed, aug_strength=aug)
        if img_map is None:
            img_map = {n: i for i, n in enumerate(names)}
        elif list(img_map) != list(names):
            # All matrices are indexed through img_map, so row order must agree.
            raise RuntimeError("embedding order differs from the first model; cannot ensemble")

        emb = query_expansion(emb)
        all_sim_matrices.append(emb @ emb.T)
        all_weights.append(weight)
        print(f"  Model complete! (weight={weight}) \u2705")

    except Exception as e:
        # Best-effort ensemble: a failed model is reported and skipped.
        print(f"\n  FAILED: {e}")
        import traceback
        traceback.print_exc()
        print(f"  Skipping...\n")
        gc.collect()
        torch.cuda.empty_cache()
        continue

# ---- Weighted Ensemble ----
print(f"\n{'='*60}")
print(f"WEIGHTED ENSEMBLE: {len(all_sim_matrices)} models")
print(f"{'='*60}")

if len(all_sim_matrices) == 0:
    raise RuntimeError("No models trained successfully!")

# Normalize weights of the models that succeeded so they sum to 1.0
total_w = sum(all_weights)
norm_weights = [w / total_w for w in all_weights]
print(f"  Weights: {norm_weights}")

combined_sim = sum(w * s for w, s in zip(norm_weights, all_sim_matrices))
combined_sim = k_reciprocal_rerank(combined_sim)

# ---- Generate Submission ----
print("\nGenerating submission...")
query_idx = test_df['query_image'].map(img_map)
gallery_idx = test_df['gallery_image'].map(img_map)
unmapped = query_idx.isna() | gallery_idx.isna()
if unmapped.any():
    raise KeyError(f"{int(unmapped.sum())} test pairs reference images with no embedding")

# Vectorized lookup of every (query, gallery) pair in one indexing operation.
pair_sims = combined_sim[query_idx.to_numpy(dtype=np.int64),
                         gallery_idx.to_numpy(dtype=np.int64)].astype(np.float64)
n_bad = int((~np.isfinite(pair_sims)).sum())
if n_bad:
    print(f"  WARNING: {n_bad} non-finite similarities replaced before clipping")
    pair_sims = np.nan_to_num(pair_sims, nan=0.0, posinf=1.0, neginf=0.0)
preds = np.clip(pair_sims, 0.0, 1.0)

sub = pd.DataFrame({'row_id': test_df['row_id'], 'similarity': preds})
sub.to_csv('submission.csv', index=False)

print(f"\n{'='*60}")
print(f"DONE!")
print(f"{'='*60}")
print(f"Models ensembled: {len(all_sim_matrices)}")
print(f"Weights: {norm_weights}")
print(f"Mean similarity: {np.mean(preds):.4f}")
print(f"Range: [{preds.min():.4f}, {preds.max():.4f}]")
print(f"\nSubmission saved to submission.csv")
print(sub.head(10))
