# In [None]:
# =============================================================================
# JAGUAR RE-ID V3: DINOv2 + ArcFace (Single Cell - No Ordering Issues)
# =============================================================================
import os, random, warnings
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import torchvision.transforms as T
import timm
from tqdm import tqdm
from sklearn.model_selection import StratifiedShuffleSplit
warnings.filterwarnings('ignore')

# Log the runtime stack so notebook output records which build / GPU produced the run.
print(f"PyTorch: {torch.__version__}")
_has_cuda = torch.cuda.is_available()
print(f"CUDA: {_has_cuda}")
if _has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# =============================================================================
# CONFIG
# =============================================================================
class CFG:
    """Notebook-wide paths and hyper-parameters, read as class attributes."""

    INPUT_DIR = "/kaggle/input/jaguar-re-id"
    TRAIN_DIR = os.path.join(INPUT_DIR, "train/train")
    TEST_DIR = os.path.join(INPUT_DIR, "test/test")
    MODEL_NAME = 'vit_small_patch14_dinov2.lvd142m'
    IMG_SIZE = 518          # square side length (px) images are resized to
    BATCH_SIZE = 16
    EPOCHS = 12             # also the cosine-annealing period (T_max)
    LR = 1e-5
    WEIGHT_DECAY = 1e-4
    ARCFACE_SCALE = 30      # ArcFace logit scale ``s``
    ARCFACE_MARGIN = 0.5    # ArcFace additive angular margin ``m`` (radians)
    VAL_RATIO = 0.0         # NOTE: not referenced by the code in this cell
    # Fall back to CPU so the notebook still runs (slowly) instead of failing on
    # the first ``.to(CFG.DEVICE)`` when no GPU is attached.
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    SEED = 42

# Seed every RNG the notebook touches (Python, NumPy, torch CPU, all CUDA devices), in that order.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(CFG.SEED)
print(f"Model: {CFG.MODEL_NAME} @ {CFG.IMG_SIZE}px")

# =============================================================================
# DATASET
# =============================================================================
class JaguarDataset(Dataset):
    """Image dataset driven by a DataFrame with a ``filename`` column.

    Args:
        df: rows to serve; must contain ``filename`` and, unless ``is_test``,
            an integer ``label`` column. The index is reset so positional
            ``idx`` lookups are 0..len-1.
        img_dir: directory the filenames are relative to.
        transform: optional callable applied to the RGB PIL image.
        is_test: when True, items are ``(image, filename)``; otherwise
            ``(image, label_tensor)`` with a ``torch.long`` label.
    """

    def __init__(self, df, img_dir, transform=None, is_test=False):
        self.df = df.reset_index(drop=True)
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        path = os.path.join(self.img_dir, row['filename'])
        # Image.open keeps the file handle open lazily; the context manager
        # closes it right away. convert() returns a new, fully loaded image
        # (a copy even when the mode is already RGB), so it stays valid.
        with Image.open(path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        if self.is_test:
            return image, row['filename']
        return image, torch.tensor(row['label'], dtype=torch.long)

def _resize_and_normalize(*augmentations):
    """Build a pipeline: resize to CFG.IMG_SIZE, run ``augmentations``, then tensorize + ImageNet-normalize."""
    return T.Compose([
        T.Resize((CFG.IMG_SIZE, CFG.IMG_SIZE)),
        *augmentations,
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


# Training adds random flip + brightness/contrast jitter; validation/test is deterministic.
train_transform = _resize_and_normalize(
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(brightness=0.2, contrast=0.2),
)
val_transform = _resize_and_normalize()

# =============================================================================
# ARCFACE
# =============================================================================
class ArcMarginProduct(nn.Module):
    """ArcFace head: additive angular margin on the target-class logit.

    Holds one learnable weight row per class, shape ``(out_features, in_features)``.
    ``forward`` returns ``s * cos(theta_j)`` for non-target classes and
    ``s * cos(theta_y + m)`` for the target class ``y``. When ``theta_y + m``
    would pass pi, the target logit falls back to ``cos(theta_y) - sin(pi - m) * m``.
    """

    def __init__(self, in_features, out_features, s=30.0, m=0.50):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=torch.float32))
        nn.init.xavier_uniform_(self.weight)
        self.s = s
        self.m = m
        # Constants for cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m).
        self.cos_m = np.cos(m)
        self.sin_m = np.sin(m)
        # cos(theta) > cos(pi - m)  <=>  theta + m < pi ; otherwise use the linear fallback.
        self.th = np.cos(np.pi - m)
        self.mm = np.sin(np.pi - m) * m

    def forward(self, input, label):
        """Return scaled ArcFace logits of shape ``(batch, out_features)``.

        ``label`` holds one class index per row of ``input``.
        """
        eps = 1e-7
        unit_x = F.normalize(input)
        unit_w = F.normalize(self.weight)
        # Keep cosine strictly inside (-1, 1) so sqrt(1 - cos^2) stays > 0.
        cosine = F.linear(unit_x, unit_w).clamp(-1 + eps, 1 - eps)
        sine = (1.0 - cosine.pow(2)).sqrt()
        cos_theta_plus_m = cosine * self.cos_m - sine * self.sin_m
        target_logit = torch.where(cosine > self.th, cos_theta_plus_m, cosine - self.mm)
        is_target = torch.zeros(cosine.size(), device=input.device)
        is_target.scatter_(1, label.view(-1, 1).long(), 1)
        blended = (is_target * target_logit) + ((1.0 - is_target) * cosine)
        return blended * self.s

# =============================================================================
# MODEL
# =============================================================================
class JaguarModel(nn.Module):
    """timm backbone that yields embeddings, with an ArcFace head for training.

    Called with ``label`` it returns ArcFace logits (training). Called without
    it, it returns L2-normalized embeddings (inference).
    """

    def __init__(self, model_name, n_classes, s=30.0, m=0.50):
        super().__init__()
        # num_classes=0: timm builds the model without a classification head.
        self.backbone = timm.create_model(
            model_name,
            pretrained=True,
            num_classes=0,
            dynamic_img_size=True,
        )
        self.embedding_dim = self.backbone.num_features
        self.arcface = ArcMarginProduct(self.embedding_dim, n_classes, s=s, m=m)
        print(f"Embedding dim: {self.embedding_dim}")

    def forward(self, x, label=None):
        if label is None:
            return self.get_embedding(x)
        # ArcFace normalizes internally, so raw backbone features are passed in.
        return self.arcface(self.backbone(x), label)

    def get_embedding(self, x):
        """Return the backbone features of ``x``, L2-normalized along dim 1."""
        features = self.backbone(x)
        return F.normalize(features, p=2, dim=1)

# =============================================================================
# LOAD DATA
# =============================================================================
print("\nLoading data...")
train_df, test_df, sample_sub = (
    pd.read_csv(os.path.join(CFG.INPUT_DIR, csv_name))
    for csv_name in ("train.csv", "test.csv", "sample_submission.csv")
)

# Contiguous integer class ids 0..K-1, assigned in sorted order of the identity names.
label_to_idx = {name: idx for idx, name in enumerate(sorted(train_df['ground_truth'].unique()))}
n_classes = len(label_to_idx)
train_df['label'] = train_df['ground_truth'].map(label_to_idx)
print(f"Images: {len(train_df)}, Classes: {n_classes}, Test pairs: {len(test_df):,}")

# =============================================================================
# DATALOADER
# =============================================================================
train_dataset = JaguarDataset(train_df, CFG.TRAIN_DIR, transform=train_transform)
# Shuffled mini-batches of augmented training images, loaded by two worker processes.
train_loader = DataLoader(
    train_dataset,
    batch_size=CFG.BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)
print(f"Train batches: {len(train_loader)}")

# =============================================================================
# INITIALIZE MODEL
# =============================================================================
print(f"\nInitializing {CFG.MODEL_NAME}...")
model = JaguarModel(
    CFG.MODEL_NAME,
    n_classes,
    s=CFG.ARCFACE_SCALE,
    m=CFG.ARCFACE_MARGIN,
).to(CFG.DEVICE)
# One AdamW param group covers backbone + ArcFace weights at a single LR;
# the cosine schedule has period CFG.EPOCHS (stepped once per epoch in the training loop).
optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.LR, weight_decay=CFG.WEIGHT_DECAY)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CFG.EPOCHS)
criterion = nn.CrossEntropyLoss()
n_params = sum(map(torch.numel, model.parameters()))
print(f"Parameters: {n_params:,}")

# =============================================================================
# TRAINING
# =============================================================================
print("\n" + "="*50 + "\nTRAINING\n" + "="*50)
for epoch in range(CFG.EPOCHS):
    model.train()
    total_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{CFG.EPOCHS}")
    # Count steps ourselves: tqdm refreshes ``pbar.n`` only periodically
    # (throttled by miniters/mininterval) while iterating. ``pbar.n + 1``
    # therefore lags the real batch count and misreports the running mean.
    for step, (imgs, labels) in enumerate(pbar, start=1):
        imgs, labels = imgs.to(CFG.DEVICE), labels.to(CFG.DEVICE)
        optimizer.zero_grad()
        logits = model(imgs, labels)  # ArcFace-margined, scaled logits
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pbar.set_postfix({'loss': total_loss / step})
    scheduler.step()  # cosine schedule advances once per epoch
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{CFG.EPOCHS} - Loss: {avg_loss:.4f}")

print("\nTraining complete!")

# =============================================================================
# INFERENCE WITH TTA
# =============================================================================
print("\n" + "="*50 + "\nINFERENCE\n" + "="*50)
# Every image referenced by any (query, gallery) pair, in a deterministic order.
test_images = sorted(set(test_df['query_image']).union(test_df['gallery_image']))
print(f"Unique test images: {len(test_images)}")

test_data = pd.DataFrame({'filename': test_images})
test_dataset = JaguarDataset(test_data, CFG.TEST_DIR, transform=val_transform, is_test=True)
test_loader = DataLoader(
    test_dataset,
    batch_size=CFG.BATCH_SIZE,
    shuffle=False,
    num_workers=2,
)

emb_dict = {}
model.eval()
with torch.no_grad():
    for batch_imgs, batch_names in tqdm(test_loader, desc="Extracting embeddings"):
        batch_imgs = batch_imgs.to(CFG.DEVICE)
        # Flip TTA: sum the embeddings of the image and its mirror along dim 3
        # (width, since ToTensor yields C x H x W), then re-normalize.
        fused = model.get_embedding(batch_imgs) + model.get_embedding(batch_imgs.flip(3))
        fused = F.normalize(fused, p=2, dim=1).cpu().numpy()
        emb_dict.update(zip(batch_names, fused))

print(f"Extracted {len(emb_dict)} embeddings")

# =============================================================================
# GENERATE SUBMISSION
# =============================================================================
print("\nGenerating submission...")

# Scores are assigned to sample_sub by position, so its rows must follow
# test_df's order. Fail loudly instead of writing misaligned scores.
if len(sample_sub) != len(test_df):
    raise ValueError(
        f"sample_submission has {len(sample_sub)} rows but test.csv has {len(test_df)}"
    )
key_cols = [c for c in sample_sub.columns if c != 'similarity' and c in test_df.columns]
if key_cols and not np.array_equal(
    sample_sub[key_cols].astype(str).to_numpy(),
    test_df[key_cols].astype(str).to_numpy(),
):
    raise ValueError(f"sample_submission and test.csv disagree on row order for {key_cols}")

# Score all pairs at once. Build one N x N matrix of dot products over the
# unique test embeddings, then gather the requested (query, gallery) entries.
# A missing filename raises KeyError, as before.
emb_names = list(emb_dict)
row_of = {name: i for i, name in enumerate(emb_names)}
emb_matrix = np.stack([emb_dict[name] for name in emb_names])
gram = emb_matrix @ emb_matrix.T
q_idx = np.array([row_of[name] for name in test_df['query_image']], dtype=np.int64)
g_idx = np.array([row_of[name] for name in test_df['gallery_image']], dtype=np.int64)
# Embeddings are L2-normalized, so each dot product is a cosine in [-1, 1];
# map it linearly to [0, 1].
similarities = (gram[q_idx, g_idx] + 1) / 2

submission = sample_sub.copy()
submission['similarity'] = similarities
submission.to_csv('submission.csv', index=False)

print("\n" + "="*50)
print("DONE!")
print("="*50)
print(f"Similarity range: [{submission['similarity'].min():.4f}, {submission['similarity'].max():.4f}]")
print(f"Mean: {submission['similarity'].mean():.4f}, Std: {submission['similarity'].std():.4f}")
print(f"\nSubmission saved to submission.csv")
print(submission.head())