In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from PIL import Image
import pandas as pd
from torchvision.models.vision_transformer import vit_b_16, ViT_B_16_Weights
from tqdm import tqdm
import matplotlib.pyplot as plt

# =============================
# Configurations
# =============================
BATCH_SIZE = 32
NUM_WORKERS = 4
IMAGE_SIZE = 224
EPOCHS_CLP = 40
EPOCHS_PLF = 40
LR = 1e-4
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
CHECKPOINT_DIR = './checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# reproducibility
torch.manual_seed(42)

# =============================
# Dataset
# =============================
class AnimalCLEFDataset(Dataset):
    def __init__(self, root, split="database", transform=None):
        self.root = root.rstrip('/')
        meta = pd.read_csv(f"{self.root}/metadata.csv")
        sel = meta[meta['path'].str.contains(f"/{split}/")].reset_index(drop=True)
        if sel.empty:
            raise ValueError(f"No entries for split '{split}'")

        self.paths = sel['path'].tolist()
        self.image_ids = sel['image_id'].tolist()

        if split == 'database':
            #  Use individual identity,  
            ids = sel['identity'].astype(str)

            #  Build mapping from identity string → label index
            self.id2idx = {iid: i for i, iid in enumerate(sorted(ids.unique()))}

            #  Map each sample's identity to its label
            self.labels = ids.map(self.id2idx).tolist()

            # Safety check
            num_classes = len(self.id2idx)
            assert all(0 <= label < num_classes for label in self.labels), "Invalid labels found"
        else:
            self.labels = [-1] * len(sel)

        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        img = Image.open(f"{self.root}/{self.paths[i]}").convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, self.labels[i]

# =============================
# MAE Encoder + Projection Head + Decoder
# =============================
class MAEFramework(nn.Module):
    def __init__(self,
                 embed_dim: int = 768,
                 proj_dim: int = 256,
                 decoder_dim: int = 256,
                 layer_indices: list[int] = [3, 6, 9]):
        super().__init__()
        self.embed_dim = embed_dim
        # 1) Backbone ViT
        self.encoder = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        self.encoder.head = nn.Identity()
        self.layer_indices = set(layer_indices)

        # 3) Projection head
        self.proj_head = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(embed_dim, proj_dim),
        )
        # 4) Decoder
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, decoder_dim),
            nn.ReLU(inplace=True),
            nn.Linear(decoder_dim, 3 * IMAGE_SIZE * IMAGE_SIZE),
        )

    def forward(self, x: torch.Tensor, return_feats: bool = False):
        B = x.size(0)
        # patch embed
        x_p = self.encoder.conv_proj(x)
        x_p = x_p.flatten(2).transpose(1, 2)
        cls_tok = self.encoder.class_token.expand(B, -1, -1)
        tokens = torch.cat([cls_tok, x_p], dim=1)
        tokens = tokens + self.encoder.encoder.pos_embedding

        feats = []
        for idx, block in enumerate(self.encoder.encoder.layers):
            tokens = block(tokens)
            if idx in self.layer_indices:
                feats.append(tokens.clone())

        cls_feat = self.encoder.encoder.ln(tokens[:, 0])
        proj = self.proj_head(cls_feat)
        rec = self.decoder(cls_feat).view(B, 3, IMAGE_SIZE, IMAGE_SIZE)
        rec_loss = F.mse_loss(rec, x, reduction='none').mean([1, 2, 3])

        if return_feats:
            return cls_feat, proj, rec_loss, rec, feats
        return proj, rec_loss, rec

# =============================
# Losses
# =============================
class SupConLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super().__init__()
        self.temp = temperature

    def forward(self, p1, p2, labels=None):
        # NT-Xent instance-level contrastive loss
        z = torch.cat([p1, p2], dim=0)
        z = F.normalize(z, dim=1)
        N = p1.size(0)
        sim = torch.matmul(z, z.T) / self.temp
        mask = torch.eye(2*N, device=sim.device).bool()
        sim.masked_fill_(mask, -9e15)
        idx = torch.arange(N, device=sim.device)
        targets = torch.cat([idx + N, idx])
        return F.cross_entropy(sim, targets)

class ProtoLoss(nn.Module):
    def forward(self, feats, prototypes, labels):
        dist = torch.cdist(feats, prototypes)
        return F.cross_entropy(-dist, labels)

# =============================
# Utilities
# =============================
def compute_layer_distances(bef_feats, aft_feats, temperature=0.5):
    total = []
    for b, a in zip(bef_feats, aft_feats):  # each b/a: (B, N, D)
        b_f, a_f = b.flatten(1), a.flatten(1)  # (B, N*D)
        eu = F.pairwise_distance(b_f, a_f)  # (B,)
        cos = 1 - F.cosine_similarity(b_f, a_f, dim=1)  # (B,)
        score = eu + temperature * cos  # (B,)
        total.append(score)  # keep per-sample scores
    return torch.stack(total, dim=0).mean(dim=0)  # (B,)


def calculate_unknown_score(feat_bef, feat_aft, feat_vec, prototypes, lamda=1.0):
    s_total = compute_layer_distances(feat_bef, feat_aft)  # shape: (B,)
    
    fv_n = F.normalize(feat_vec, dim=1)
    p_n = F.normalize(prototypes, dim=1)
    sim = torch.matmul(fv_n, p_n.T)  # (B, C)
    s_proto, _ = sim.max(dim=1)  # max similarity per sample

    unknown_score = (1 - s_proto) + lamda * s_total
    return unknown_score  # higher = more unknown




# =============================
# Training: CLP
# =============================
def train_CLP(model, loader, epochs, lr, ckpt, alpha_cl=1.0, alpha_rec=1.0):
    path = os.path.join(CHECKPOINT_DIR, ckpt)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    start = 0
    if os.path.exists(path):
        ck = torch.load(path)
        model.load_state_dict(ck['model'])
        opt.load_state_dict(ck['opt'])
        start = ck['ep'] + 1
    scl = SupConLoss()
    model.to(DEVICE)

    for ep in range(start, epochs):
        model.train()
        total = 0.0
        for x, y in tqdm(loader, desc=f"CLP Ep{ep+1}"):
            x, y = x.to(DEVICE), y.to(DEVICE)
            x1 = x + 0.05 * torch.randn_like(x)
            x2 = x + 0.05 * torch.randn_like(x)
            f1, p1, r1_loss, r1, _ = model(x1, return_feats=True)
            f2, p2, r2_loss, r2, _ = model(x2, return_feats=True)
            l_cl = scl(p1, p2)
            l_rec = r1_loss.mean() + r2_loss.mean()
            loss = alpha_cl * l_cl + alpha_rec * l_rec
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item()
        torch.save({'model': model.state_dict(), 'opt': opt.state_dict(), 'ep': ep}, path)
        print(f"CLP Epoch {ep+1}: {total/len(loader):.4f}")


# =============================
# Training: PLF
# =============================
def train_PLF(model, loader, epochs, lr, ckpt, num_classes, encoder_pre):
    path = os.path.join(CHECKPOINT_DIR, ckpt)
    # unwrap if DataParallel
    base_model = model.module if isinstance(model, nn.DataParallel) else model
    embed_dim = base_model.embed_dim
    proto_tensor = torch.randn(num_classes, embed_dim, device=DEVICE)
    proto = nn.Parameter(proto_tensor, requires_grad=True)
    opt = torch.optim.Adam(list(model.parameters()) + [proto], lr=lr)
    start = 0
    if os.path.exists(path):
        ck = torch.load(path)
        model.load_state_dict(ck['model'])
        proto.data = ck['proto']
        opt.load_state_dict(ck['opt'])
        start = ck['ep'] + 1
    ploss = ProtoLoss()
    model.to(DEVICE)

    for ep in range(start, epochs):
        model.train()
        total = 0.0
        for x, y in tqdm(loader, desc=f"PLF Ep{ep+1}"):
            x, y = x.to(DEVICE), y.to(DEVICE)
            feats, _, _, _, _ = model(x, return_feats=True)
            loss = ploss(feats, proto, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total += loss.item()
        torch.save({'model': model.state_dict(), 'proto': proto.data, 'opt': opt.state_dict(), 'ep': ep}, path)
        print(f"PLF Epoch {ep+1}: {total/len(loader):.4f}")
    return proto

# =============================
# Inference
# =============================

@torch.no_grad()
def inference(enc_pre, model, proto, loader, threshold, lamda=0.2):
    enc_pre.eval()
    model.eval()
    preds = []

    for x, _ in tqdm(loader, desc="Infer"):
        x = x.to(DEVICE)
        enc_pre = enc_pre.to(DEVICE)
        model = model.to(DEVICE)

        _, _, _, _, bef_feats = enc_pre(x, return_feats=True)
        feat, _, _, _, aft_feats = model(x, return_feats=True)

        # 1. Calculate unknown score
        scores = calculate_unknown_score(bef_feats, aft_feats, feat, proto, lamda=lamda)  # shape: (B,)
        print("Score stats:", torch.min(scores).item(), torch.mean(scores).item(), torch.max(scores).item())

        # 2. Predict the most similar known class
        idx = torch.argmax(torch.matmul(F.normalize(feat, dim=1), F.normalize(proto, dim=1).T), dim=1)  # (B,)

        # 3. Threshold to filter unknowns
        known = scores < threshold
        pred = idx.clone()
        pred[~known] = -1  #  invert known to get unknown


        preds.append(pred.cpu())

    return torch.cat(preds, dim=0)


# =============================
# Submission
# =============================
def generate_submission(root, preds, db_ds):
    sub = pd.read_csv(f"{root}/sample_submission.csv")
    meta = pd.read_csv(f"{root}/metadata.csv")
    q = meta[meta['path'].str.contains('/query/')].reset_index(drop=True)
    q['pred_idx'] = preds.numpy()

    # Step 1: Print raw predicted indices
    print("\nRaw predicted indices (pred_idx):")
    print(q[['image_id', 'pred_idx']].head(10))

    # Step 2: Index to identity mapping
    idx2id = {v: k for k, v in db_ds.id2idx.items()}
    q['prediction'] = q['pred_idx'].apply(lambda i: 'new_individual' if i < 0 else idx2id.get(int(i), f"unknown_{i}"))

    # Step 3: Print mapped predictions
    print("\nMapped predictions (after idx2id):")
    print(q[['image_id', 'prediction']].head(10))

    out = sub[['image_id']].merge(q[['image_id', 'prediction']], on='image_id')

    import time
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    save_path = f'/kaggle/working/submission_{timestamp}.csv'
    out.to_csv(save_path, index=False)

    return save_path
 
# =============================
# Main Workflow
# =============================
def main():
    root = '/kaggle/input/animal-clef-2025'
    tf = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor()
    ])
    #idx2id = {v: k for k, v in db_ds.id2idx.items()}

    db_ds = AnimalCLEFDataset(root, 'database', transform=tf)
    db_loader = DataLoader(db_ds, batch_size=8, shuffle=True,
                           num_workers=NUM_WORKERS, pin_memory=True)
    query_ds = AnimalCLEFDataset(root, 'query', transform=tf)
    q_loader = DataLoader(query_ds, batch_size=8, shuffle=False,
                           num_workers=NUM_WORKERS, pin_memory=True)

    clp_model = MAEFramework()
    plf_model = MAEFramework()

    num_gpus = torch.cuda.device_count()
    if num_gpus >= 2:
        device_ids = [0, 1]
    elif num_gpus == 1:
        device_ids = [0]
    else:
        device_ids = None

    clp_model = nn.DataParallel(clp_model, device_ids=device_ids).to(DEVICE)
    plf_model = nn.DataParallel(plf_model, device_ids=device_ids).to(DEVICE)

    train_CLP(clp_model, db_loader, EPOCHS_CLP, LR, 'clp.pth', alpha_cl=1.0, alpha_rec=1.0)
    prototype = train_PLF(plf_model, db_loader, EPOCHS_PLF, LR,
                          'plf.pth', len(db_ds.id2idx), clp_model)
    print("Prototype shape:", prototype.shape)
    print("Number of unique classes:", len(db_ds.id2idx))

    # Plot the score distribution before estimating threshold
    #plot_score_distribution(clp_model, plf_model, prototype, db_loader)

    # Estimate threshold from database distribution
    dists = []
    with torch.no_grad():
        for x, _ in DataLoader(db_ds, batch_size=8, shuffle=False,
                               num_workers=NUM_WORKERS, pin_memory=True):
            x = x.to(DEVICE)
            feat, _, _, _, aft_feats = plf_model(x, return_feats=True)
            _, _, _, _, bef_feats = clp_model(x, return_feats=True)
            scores = calculate_unknown_score(bef_feats, aft_feats, feat, prototype, lamda = 0.2)
            dists.extend(scores.cpu().tolist())


    # Ensure flat list of floats
    flat_dists = [v.item() if isinstance(v, torch.Tensor) else v for v in dists]
    threshold = torch.quantile(torch.tensor(flat_dists, dtype=torch.float32), 0.95)


    print("Estimated threshold:", threshold.item())
    print("Score range — min:", min(flat_dists), "max:", max(flat_dists))


    # Inference on query set
    with torch.no_grad():
        lamda = 0.2 # or any value you want to test
        preds = inference(clp_model, plf_model, prototype, q_loader, threshold, lamda=lamda)
    print("Total predictions:", len(preds))
    print("Predicted as new_individual:", (preds == -1).sum().item())

    print("Saving submission…")
    generate_submission(root, preds, db_ds)

if __name__ == '__main__':
    main()


  ck = torch.load(path)
CLP Ep21: 100%|██████████| 1635/1635 [14:03<00:00,  1.94it/s]


CLP Epoch 21: 1.2769


CLP Ep22: 100%|██████████| 1635/1635 [14:03<00:00,  1.94it/s]


CLP Epoch 22: 1.1955


CLP Ep23: 100%|██████████| 1635/1635 [14:04<00:00,  1.94it/s]


CLP Epoch 23: 1.1957


CLP Ep24: 100%|██████████| 1635/1635 [14:04<00:00,  1.94it/s]


CLP Epoch 24: 1.1872


CLP Ep25:  67%|██████▋   | 1088/1635 [09:22<04:42,  1.93it/s]

20% Accuracy

In [1]:
# ====== Imports and Setup ======
import os
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics.pairwise import cosine_similarity
import timm
import joblib

# ====== Device ======
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====== Configuration ======
root_dir = "/kaggle/input/animal-clef-2025"
metadata_path = "/kaggle/input/animal-clef-2025/metadata.csv"
embedding_dim = 512
confidence_threshold = 0.90

# ====== Load Metadata ======
df = pd.read_csv(metadata_path)

# ====== Encoder ======
encoder = LabelEncoder()
database_df = df[df["split"] == "database"].dropna(subset=["identity"])
database_df["label"] = encoder.fit_transform(database_df["identity"])
joblib.dump(encoder, "label_encoder.pkl")

# ====== Transform ======
transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5]*3, std=[0.5]*3),
])

# ====== Dataset Class ======
class InferenceDataset(Dataset):
    def __init__(self, df, transform):
        self.paths = df["path"].tolist()
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img = Image.open(os.path.join(root_dir, self.paths[idx])).convert("RGB")
        img = self.transform(img)
        return img

# ====== Load Model ======
model = timm.create_model("hf-hub:BVRA/MegaDescriptor-L-384", pretrained=True)
model = model.to(device)
model.eval()

# ====== Embedding Function ======
def extract_embeddings(model, df, transform, batch_size=64):
    loader = DataLoader(InferenceDataset(df, transform), batch_size=batch_size, shuffle=False, num_workers=2)
    embeddings = []
    with torch.no_grad():
        for batch in tqdm(loader, desc="Extracting embeddings"):
            batch = batch.to(device)
            emb = model(batch)
            embeddings.append(emb.cpu().numpy())
    return np.vstack(embeddings)

# ====== Prepare and Embed ======
all_results = []
for class_name in ["SeaTurtleID2022", "LynxID2025", "SalamanderID2025"]:
    class_df = df[df["path"].str.contains(class_name)]
    database_df = class_df[class_df["split"] == "database"].dropna(subset=["identity"])
    query_df = class_df[class_df["split"] == "query"]

    db_embeddings = extract_embeddings(model, database_df, transform)
    query_embeddings = extract_embeddings(model, query_df, transform)

    db_labels = database_df["identity"].tolist()
    for i, query_emb in enumerate(query_embeddings):
        sims = cosine_similarity(query_emb.reshape(1, -1), db_embeddings)[0]
        max_idx = np.argmax(sims)
        max_sim = sims[max_idx]
        identity = db_labels[max_idx] if max_sim >= confidence_threshold else "new_individual"
        all_results.append({
            "image_id": query_df.iloc[i]["image_id"],
            "identity": identity
        })

# ====== Save Submission ======
submission_df = pd.DataFrame(all_results)
submission_df.to_csv("submission.csv", index=False)
print("Submission saved as submission.csv")

config.json:   0%|          | 0.00/609 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.94G [00:00<?, ?B/s]

Extracting embeddings: 100%|██████████| 137/137 [06:15<00:00,  2.74s/it]
Extracting embeddings: 100%|██████████| 8/8 [00:22<00:00,  2.85s/it]
Extracting embeddings: 100%|██████████| 47/47 [02:09<00:00,  2.75s/it]
Extracting embeddings: 100%|██████████| 15/15 [00:43<00:00,  2.91s/it]
Extracting embeddings: 100%|██████████| 22/22 [01:02<00:00,  2.84s/it]
Extracting embeddings: 100%|██████████| 11/11 [00:32<00:00,  2.96s/it]


Submission saved as submission.csv


27% Accuracy

In [3]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics.pairwise import cosine_similarity
import timm
import joblib

# ====== Configuration ======
class Config:
    ROOT_DIR = "/kaggle/input/animal-clef-2025"
    METADATA_PATH = os.path.join(ROOT_DIR, "metadata.csv")
    EMBEDDING_DIM = 512
    BATCH_SIZE = 32
    NUM_WORKERS = 2
    IMAGE_SIZE = 384
    CLASSES = ["SeaTurtleID2022", "LynxID2025", "SalamanderID2025"]
    MODEL_NAME = "hf-hub:BVRA/MegaDescriptor-L-384"
    BASE_THRESHOLD = 0.85

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====== Dataset and Model Setup ======
class AnimalDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform or transforms.Compose([
            transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
        ])
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_path = os.path.join(Config.ROOT_DIR, self.df.iloc[idx]['path'])
        try:
            with Image.open(img_path) as img:
                return self.transform(img.convert("RGB"))
        except:
            return torch.zeros((3, Config.IMAGE_SIZE, Config.IMAGE_SIZE))

class FeatureExtractor:
    def __init__(self):
        self.model = timm.create_model(Config.MODEL_NAME, pretrained=True)
        self.model = self.model.to(device).eval()
        
    @torch.no_grad()
    def extract_embeddings(self, df):
        dataset = AnimalDataset(df)
        loader = DataLoader(dataset, batch_size=Config.BATCH_SIZE,
                          num_workers=Config.NUM_WORKERS)
        
        embeddings = []
        for batch in tqdm(loader, desc="Extracting embeddings"):
            embeddings.append(self.model(batch.to(device)).cpu())
        return torch.cat(embeddings).numpy()

# ====== Similarity Search ======
def find_matches(query_emb, db_embeddings, db_df, threshold):
    """NumPy implementation of similarity search"""
    sims = cosine_similarity(query_emb.reshape(1, -1), db_embeddings)[0]
    max_idx = np.argmax(sims)
    return db_df.iloc[max_idx]["identity"] if sims[max_idx] >= threshold else "new_individual"

# ====== Main Execution ======
if __name__ == "__main__":
    df = pd.read_csv(Config.METADATA_PATH)
    extractor = FeatureExtractor()
    results = []
    
    for class_name in Config.CLASSES:
        print(f"\nProcessing {class_name}...")
        class_df = df[df["path"].str.contains(class_name)]
        db_df = class_df[class_df["split"] == "database"].dropna(subset=["identity"])
        query_df = class_df[class_df["split"] == "query"]
        
        if db_df.empty or query_df.empty:
            continue
            
        db_embeddings = extractor.extract_embeddings(db_df)
        query_embeddings = extractor.extract_embeddings(query_df)
        
        # Normalize embeddings
        db_embeddings = db_embeddings / np.linalg.norm(db_embeddings, axis=1, keepdims=True)
        query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
        
        for i, query_emb in enumerate(tqdm(query_embeddings, desc="Matching")):
            identity = find_matches(query_emb, db_embeddings, db_df, Config.BASE_THRESHOLD)
            results.append({
                "image_id": query_df.iloc[i]["image_id"],
                "identity": identity
            })
    
    pd.DataFrame(results).to_csv("submission.csv", index=False)
    print("Done! Results saved to submission.csv")


Processing SeaTurtleID2022...


Extracting embeddings: 100%|██████████| 273/273 [06:12<00:00,  1.37s/it]
Extracting embeddings: 100%|██████████| 16/16 [00:21<00:00,  1.37s/it]
Matching: 100%|██████████| 500/500 [00:19<00:00, 26.14it/s]



Processing LynxID2025...


Extracting embeddings: 100%|██████████| 93/93 [02:07<00:00,  1.37s/it]
Extracting embeddings: 100%|██████████| 30/30 [00:41<00:00,  1.39s/it]
Matching: 100%|██████████| 946/946 [00:12<00:00, 76.67it/s]



Processing SalamanderID2025...


Extracting embeddings: 100%|██████████| 44/44 [01:00<00:00,  1.37s/it]
Extracting embeddings: 100%|██████████| 22/22 [00:30<00:00,  1.39s/it]
Matching: 100%|██████████| 689/689 [00:03<00:00, 185.11it/s]

Done! Results saved to submission.csv





22% with argumentations

In [5]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics.pairwise import cosine_similarity
import timm
import joblib

# ====== Configuration ======
class Config:
    ROOT_DIR = "/kaggle/input/animal-clef-2025"
    METADATA_PATH = os.path.join(ROOT_DIR, "metadata.csv")
    EMBEDDING_DIM = 512
    BATCH_SIZE = 32
    NUM_WORKERS = 2
    IMAGE_SIZE = 384
    CLASSES = ["SeaTurtleID2022", "LynxID2025", "SalamanderID2025"]
    MODEL_NAME = "hf-hub:BVRA/MegaDescriptor-L-384"
    BASE_THRESHOLD = 0.85

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====== Dataset and Model Setup ======
class AnimalDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform or transforms.Compose([
                transforms.Resize((384, 384)),  # Keep consistent with model
                transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),  # Add slight augmentation
                transforms.ColorJitter(brightness=0.1, contrast=0.1),  # Help with lighting variations
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
        ])
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_path = os.path.join(Config.ROOT_DIR, self.df.iloc[idx]['path'])
        try:
            with Image.open(img_path) as img:
                return self.transform(img.convert("RGB"))
        except:
            return torch.zeros((3, Config.IMAGE_SIZE, Config.IMAGE_SIZE))

class FeatureExtractor:
    def __init__(self):
        self.model = timm.create_model(Config.MODEL_NAME, pretrained=True)
        self.model = self.model.to(device).eval()
        
    @torch.no_grad()
    def extract_embeddings(self, df):
        dataset = AnimalDataset(df)
        loader = DataLoader(dataset, batch_size=Config.BATCH_SIZE,
                          num_workers=Config.NUM_WORKERS)
        
        embeddings = []
        for batch in tqdm(loader, desc="Extracting embeddings"):
            embeddings.append(self.model(batch.to(device)).cpu())
        return torch.cat(embeddings).numpy()

# ====== Similarity Search ======
def find_matches(query_emb, db_embeddings, db_df, threshold):
    """NumPy implementation of similarity search"""
    sims = cosine_similarity(query_emb.reshape(1, -1), db_embeddings)[0]
    max_idx = np.argmax(sims)
    return db_df.iloc[max_idx]["identity"] if sims[max_idx] >= threshold else "new_individual"

# ====== Main Execution ======
if __name__ == "__main__":
    df = pd.read_csv(Config.METADATA_PATH)
    extractor = FeatureExtractor()
    results = []
    
    for class_name in Config.CLASSES:
        print(f"\nProcessing {class_name}...")
        class_df = df[df["path"].str.contains(class_name)]
        db_df = class_df[class_df["split"] == "database"].dropna(subset=["identity"])
        query_df = class_df[class_df["split"] == "query"]
        
        if db_df.empty or query_df.empty:
            continue
            
        db_embeddings = extractor.extract_embeddings(db_df)
        query_embeddings = extractor.extract_embeddings(query_df)
        
        # Normalize embeddings
        db_embeddings = db_embeddings / np.linalg.norm(db_embeddings, axis=1, keepdims=True)
        query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
        
        for i, query_emb in enumerate(tqdm(query_embeddings, desc="Matching")):
            identity = find_matches(query_emb, db_embeddings, db_df, Config.BASE_THRESHOLD)
            results.append({
                "image_id": query_df.iloc[i]["image_id"],
                "identity": identity
            })
    
    pd.DataFrame(results).to_csv("submission.csv", index=False)
    print("Done! Results saved to submission.csv")


Processing SeaTurtleID2022...


Extracting embeddings: 100%|██████████| 273/273 [06:12<00:00,  1.36s/it]
Extracting embeddings: 100%|██████████| 16/16 [00:22<00:00,  1.38s/it]
Matching: 100%|██████████| 500/500 [00:18<00:00, 26.42it/s]



Processing LynxID2025...


Extracting embeddings: 100%|██████████| 93/93 [02:07<00:00,  1.37s/it]
Extracting embeddings: 100%|██████████| 30/30 [00:41<00:00,  1.40s/it]
Matching: 100%|██████████| 946/946 [00:12<00:00, 73.99it/s]



Processing SalamanderID2025...


Extracting embeddings: 100%|██████████| 44/44 [01:00<00:00,  1.38s/it]
Extracting embeddings: 100%|██████████| 22/22 [00:30<00:00,  1.40s/it]
Matching: 100%|██████████| 689/689 [00:03<00:00, 190.61it/s]

Done! Results saved to submission.csv





In [3]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
import timm
import joblib
from collections import defaultdict

# ====== Configuration ======
class Config:
    ROOT_DIR = "/kaggle/input/animal-clef-2025"
    METADATA_PATH = os.path.join(ROOT_DIR, "metadata.csv")
    EMBEDDING_DIM = 512
    BATCH_SIZE = 64  # Increased batch size
    NUM_WORKERS = 4
    IMAGE_SIZE = 384
    CLASSES = ["SeaTurtleID2022", "LynxID2025", "SalamanderID2025"]
    MODEL_NAME = "hf-hub:BVRA/MegaDescriptor-L-384"
    BASE_THRESHOLD = 0.87  # Adjusted threshold
    VAL_SIZE = 0.2
    RANDOM_SEED = 42
    MIN_SAMPLES_PER_CLASS = 2  # Minimum samples required per class

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====== Dataset Class with Enhanced Augmentation ======
class AnimalDataset(Dataset):
    def __init__(self, df, transform=None, is_train=True):
        self.df = df.reset_index(drop=True)
        self.is_train = is_train
        self.transform = transform or self.get_default_transform(is_train)
        
    def get_default_transform(self, is_train):
        if is_train:
            return transforms.Compose([
                transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
                transforms.ColorJitter(brightness=0.2, contrast=0.2),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            return transforms.Compose([
                transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_path = os.path.join(Config.ROOT_DIR, self.df.iloc[idx]['path'])
        try:
            with Image.open(img_path) as img:
                if self.transform:
                    img = self.transform(img.convert("RGB"))
                return img
        except Exception as e:
            print(f"Error loading {img_path}: {str(e)}")
            return torch.zeros((3, Config.IMAGE_SIZE, Config.IMAGE_SIZE))

# ====== Feature Extractor with Enhanced Model ======
class FeatureExtractor:
    def __init__(self):
        self.model = timm.create_model(Config.MODEL_NAME, pretrained=True)
        self.model = self.model.to(device).eval()
        # Freeze all layers
        for param in self.model.parameters():
            param.requires_grad = False
        
    @torch.no_grad()
    def extract_embeddings(self, df, is_train=False):
        dataset = AnimalDataset(df, is_train=is_train)
        loader = DataLoader(dataset, batch_size=Config.BATCH_SIZE,
                          num_workers=Config.NUM_WORKERS,
                          pin_memory=True)
        
        embeddings = []
        for batch in tqdm(loader, desc="Extracting embeddings"):
            embeddings.append(self.model(batch.to(device)).cpu())
        return torch.cat(embeddings).numpy()

# ====== Robust Data Splitting ======

def safe_train_test_split(db_df, val_size=0.2, min_samples=2, random_state=42):
    """Handle classes with insufficient samples and ensure all validation classes exist in training"""
    # Step 1: Identify classes with enough samples
    class_counts = db_df['identity'].value_counts()
    valid_classes = class_counts[class_counts >= min_samples].index
    
    if len(valid_classes) == 0:
        print("Warning: No classes have sufficient samples for validation split")
        return db_df.copy(), pd.DataFrame(columns=db_df.columns)
    
    # Step 2: First split to ensure all validation classes exist in training
    train_classes, val_classes = train_test_split(
        valid_classes,
        test_size=val_size,
        random_state=random_state
    )
    
    # Step 3: Create splits
    train_df = db_df[db_df['identity'].isin(train_classes)]
    val_df = db_df[db_df['identity'].isin(val_classes)]
    
    # Step 4: Add rare classes to training
    rare_df = db_df[~db_df['identity'].isin(valid_classes)]
    if len(rare_df) > 0:
        train_df = pd.concat([train_df, rare_df])
        print(f"Added {len(rare_df)} samples from rare classes to training")
    
    return train_df, val_df
 
# ====== Enhanced Evaluation Metrics ======
def evaluate_accuracy(query_embeddings, query_labels, db_embeddings, db_labels, threshold):
    """Calculate accuracy with confidence thresholding"""
    # Normalize embeddings
    db_embeddings = db_embeddings / np.linalg.norm(db_embeddings, axis=1, keepdims=True)
    query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
    
    # Batch processing for efficiency
    sim_matrix = cosine_similarity(query_embeddings, db_embeddings)
    max_sims = np.max(sim_matrix, axis=1)
    max_indices = np.argmax(sim_matrix, axis=1)
    
    correct = 0
    total = 0
    
    for i in range(len(query_labels)):
        if max_sims[i] >= threshold:
            if db_labels[max_indices[i]] == query_labels[i]:
                correct += 1
        total += 1
    
    return correct / total if total > 0 else 0.0

# ====== Main Execution with Improved Pipeline ======
if __name__ == "__main__":
    # Load and prepare data
    df = pd.read_csv(Config.METADATA_PATH)
    encoder = LabelEncoder()
    extractor = FeatureExtractor()
    
    results = []
    val_accuracies = []
    class_reports = []
    
    for class_name in Config.CLASSES:
        print(f"\n{'='*40}\nProcessing {class_name}\n{'='*40}")
        class_df = df[df["path"].str.contains(class_name)]
        db_df = class_df[class_df["split"] == "database"].dropna(subset=["identity"])
        
        if db_df.empty:
            print(f"No database samples found for {class_name}")
            continue
            
        # Split database into train and validation
        train_df, val_df = safe_train_test_split(
            db_df,
            val_size=Config.VAL_SIZE,
            min_samples=Config.MIN_SAMPLES_PER_CLASS,
            random_state=Config.RANDOM_SEED
        )
        
        if val_df.empty:
            print(f"Skipping validation for {class_name} - insufficient samples")
            val_accuracies.append(0)
            continue
         
    
        # Encode labels - fit on combined data first
        all_identities = pd.concat([train_df['identity'], val_df['identity']]).unique()
        encoder.fit(all_identities)
        
        train_df["label"] = encoder.transform(train_df["identity"])
        val_df["label"] = encoder.transform(val_df["identity"])

        
        
        # Extract embeddings
        print("Extracting training embeddings...")
        train_embeddings = extractor.extract_embeddings(train_df, is_train=True)
        print("Extracting validation embeddings...")
        val_embeddings = extractor.extract_embeddings(val_df, is_train=False)
        
        # Evaluate on validation set
        val_accuracy = evaluate_accuracy(
            val_embeddings, 
            val_df["label"].values,
            train_embeddings,
            train_df["label"].values,
            Config.BASE_THRESHOLD
        )
        val_accuracies.append(val_accuracy)
        print(f"\nValidation Accuracy for {class_name}: {val_accuracy:.4f}")
        
        # Process queries if available
        query_df = class_df[class_df["split"] == "query"]
        if not query_df.empty:
            print("Processing query images...")
            query_embeddings = extractor.extract_embeddings(query_df, is_train=False)
            query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
            train_embeddings = train_embeddings / np.linalg.norm(train_embeddings, axis=1, keepdims=True)
            
            # Batch processing for queries
            sim_matrix = cosine_similarity(query_embeddings, train_embeddings)
            max_sims = np.max(sim_matrix, axis=1)
            max_indices = np.argmax(sim_matrix, axis=1)
            
            for i in tqdm(range(len(query_df)), desc="Matching queries"):
                identity = (train_df.iloc[max_indices[i]]["identity"] 
                          if max_sims[i] >= Config.BASE_THRESHOLD 
                          else "new_individual")
                results.append({
                    "image_id": query_df.iloc[i]["image_id"],
                    "identity": identity,
                    "confidence": float(max_sims[i])
                })
    
    # Save results and print summary
    if results:
        submission_df = pd.DataFrame(results)
        submission_df.to_csv("submission.csv", index=False)
        print("\nSubmission saved to submission.csv")
    
    # Validation report
    print("\n\nValidation Accuracy Summary:")
    for class_name, acc in zip(Config.CLASSES, val_accuracies):
        print(f"{class_name}: {acc:.4f}")
    
    valid_accs = [acc for acc in val_accuracies if acc > 0]
    if valid_accs:
        print(f"\nMean Validation Accuracy: {np.mean(valid_accs):.4f}")
    else:
        print("\nNo valid validation results available")

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  val_df["label"] = encoder.transform(val_df["identity"])



Processing SeaTurtleID2022
Added 1 samples from rare classes to training
Extracting training embeddings...


Extracting embeddings: 100%|██████████| 106/106 [04:47<00:00,  2.71s/it]


Extracting validation embeddings...


Extracting embeddings: 100%|██████████| 32/32 [01:25<00:00,  2.66s/it]



Validation Accuracy for SeaTurtleID2022: 0.0000
Processing query images...


Extracting embeddings: 100%|██████████| 8/8 [00:22<00:00,  2.87s/it]
Matching queries: 100%|██████████| 500/500 [00:00<00:00, 21411.09it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  val_df["label"] = encoder.transform(val_df["identity"])



Processing LynxID2025
Added 6 samples from rare classes to training
Extracting training embeddings...


Extracting embeddings: 100%|██████████| 31/31 [01:26<00:00,  2.79s/it]


Extracting validation embeddings...


Extracting embeddings: 100%|██████████| 16/16 [00:44<00:00,  2.77s/it]



Validation Accuracy for LynxID2025: 0.0000
Processing query images...


Extracting embeddings: 100%|██████████| 15/15 [00:44<00:00,  2.94s/it]
Matching queries: 100%|██████████| 946/946 [00:00<00:00, 22625.37it/s]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  val_df["label"] = encoder.transform(val_df["identity"])



Processing SalamanderID2025
Added 310 samples from rare classes to training
Extracting training embeddings...


Extracting embeddings: 100%|██████████| 19/19 [00:53<00:00,  2.81s/it]


Extracting validation embeddings...


Extracting embeddings: 100%|██████████| 4/4 [00:13<00:00,  3.26s/it]



Validation Accuracy for SalamanderID2025: 0.0000
Processing query images...


Extracting embeddings:  55%|█████▍    | 6/11 [00:22<00:19,  3.80s/it]


KeyboardInterrupt: 

23% Accuracy

In [1]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics.pairwise import cosine_similarity
import timm
import joblib

# ====== Configuration ======
class Config:
    ROOT_DIR = "/kaggle/input/animal-clef-2025"
    METADATA_PATH = os.path.join(ROOT_DIR, "metadata.csv")
    EMBEDDING_DIM = 512
    BATCH_SIZE = 64
    NUM_WORKERS = 2
    IMAGE_SIZE = 384
    CLASSES = ["SeaTurtleID2022", "LynxID2025", "SalamanderID2025"]
    MODEL_NAME = "hf-hub:BVRA/MegaDescriptor-L-384"
    BASE_THRESHOLD = 0.87

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====== Dataset and Model Setup ======
class AnimalDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform or transforms.Compose([
            transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
        ])
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_path = os.path.join(Config.ROOT_DIR, self.df.iloc[idx]['path'])
        try:
            with Image.open(img_path) as img:
                return self.transform(img.convert("RGB"))
        except:
            return torch.zeros((3, Config.IMAGE_SIZE, Config.IMAGE_SIZE))

class FeatureExtractor:
    def __init__(self):
        self.model = timm.create_model(Config.MODEL_NAME, pretrained=True)
        self.model = self.model.to(device).eval()
        
    @torch.no_grad()
    def extract_embeddings(self, df):
        dataset = AnimalDataset(df)
        loader = DataLoader(dataset, batch_size=Config.BATCH_SIZE,
                          num_workers=Config.NUM_WORKERS)
        
        embeddings = []
        for batch in tqdm(loader, desc="Extracting embeddings"):
            embeddings.append(self.model(batch.to(device)).cpu())
        return torch.cat(embeddings).numpy()

# ====== Similarity Search ======
def find_matches(query_emb, db_embeddings, db_df, threshold):
    """NumPy implementation of similarity search"""
    sims = cosine_similarity(query_emb.reshape(1, -1), db_embeddings)[0]
    max_idx = np.argmax(sims)
    return db_df.iloc[max_idx]["identity"] if sims[max_idx] >= threshold else "new_individual"

# ====== Main Execution ======
if __name__ == "__main__":
    df = pd.read_csv(Config.METADATA_PATH)
    extractor = FeatureExtractor()
    results = []
    
    for class_name in Config.CLASSES:
        print(f"\nProcessing {class_name}...")
        class_df = df[df["path"].str.contains(class_name)]
        db_df = class_df[class_df["split"] == "database"].dropna(subset=["identity"])
        query_df = class_df[class_df["split"] == "query"]
        
        if db_df.empty or query_df.empty:
            continue
            
        db_embeddings = extractor.extract_embeddings(db_df)
        query_embeddings = extractor.extract_embeddings(query_df)
        
        # Normalize embeddings
        db_embeddings = db_embeddings / np.linalg.norm(db_embeddings, axis=1, keepdims=True)
        query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
        
        for i, query_emb in enumerate(tqdm(query_embeddings, desc="Matching")):
            identity = find_matches(query_emb, db_embeddings, db_df, Config.BASE_THRESHOLD)
            results.append({
                "image_id": query_df.iloc[i]["image_id"],
                "identity": identity
            })
    
    pd.DataFrame(results).to_csv("submission.csv", index=False)
    print("Done! Results saved to submission.csv")

config.json:   0%|          | 0.00/609 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.94G [00:00<?, ?B/s]


Processing SeaTurtleID2022...


Extracting embeddings: 100%|██████████| 137/137 [06:17<00:00,  2.76s/it]
Extracting embeddings: 100%|██████████| 8/8 [00:23<00:00,  2.88s/it]
Matching: 100%|██████████| 500/500 [00:19<00:00, 25.39it/s]



Processing LynxID2025...


Extracting embeddings: 100%|██████████| 47/47 [02:09<00:00,  2.76s/it]
Extracting embeddings: 100%|██████████| 15/15 [00:43<00:00,  2.93s/it]
Matching: 100%|██████████| 946/946 [00:12<00:00, 76.42it/s]



Processing SalamanderID2025...


Extracting embeddings: 100%|██████████| 22/22 [01:02<00:00,  2.85s/it]
Extracting embeddings: 100%|██████████| 11/11 [00:32<00:00,  2.98s/it]
Matching: 100%|██████████| 689/689 [00:03<00:00, 197.98it/s]

Done! Results saved to submission.csv





In [None]:
27% Accuracy

In [5]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics.pairwise import cosine_similarity
import timm
import joblib

# ====== Configuration ======
class Config:
    ROOT_DIR = "/kaggle/input/animal-clef-2025"
    METADATA_PATH = os.path.join(ROOT_DIR, "metadata.csv")
    EMBEDDING_DIM = 512
    BATCH_SIZE = 96
    NUM_WORKERS = 2
    IMAGE_SIZE = 384
    CLASSES = ["SeaTurtleID2022", "LynxID2025", "SalamanderID2025"]
    MODEL_NAME = "hf-hub:BVRA/MegaDescriptor-L-384"
    BASE_THRESHOLD = 0.85

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====== Dataset and Model Setup ======
class AnimalDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform or transforms.Compose([
            transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
        ])
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_path = os.path.join(Config.ROOT_DIR, self.df.iloc[idx]['path'])
        try:
            with Image.open(img_path) as img:
                return self.transform(img.convert("RGB"))
        except:
            return torch.zeros((3, Config.IMAGE_SIZE, Config.IMAGE_SIZE))

class FeatureExtractor:
    def __init__(self):
        self.model = timm.create_model(Config.MODEL_NAME, pretrained=True)
        self.model = self.model.to(device).eval()
        
    @torch.no_grad()
    def extract_embeddings(self, df):
        dataset = AnimalDataset(df)
        loader = DataLoader(dataset, batch_size=Config.BATCH_SIZE,
                          num_workers=Config.NUM_WORKERS)
        
        embeddings = []
        for batch in tqdm(loader, desc="Extracting embeddings"):
            embeddings.append(self.model(batch.to(device)).cpu())
        return torch.cat(embeddings).numpy()

# ====== Similarity Search ======
def find_matches(query_emb, db_embeddings, db_df, threshold):
    """NumPy implementation of similarity search"""
    sims = cosine_similarity(query_emb.reshape(1, -1), db_embeddings)[0]
    max_idx = np.argmax(sims)
    return db_df.iloc[max_idx]["identity"] if sims[max_idx] >= threshold else "new_individual"

# ====== Main Execution ======
if __name__ == "__main__":
    df = pd.read_csv(Config.METADATA_PATH)
    extractor = FeatureExtractor()
    results = []
    
    for class_name in Config.CLASSES:
        print(f"\nProcessing {class_name}...")
        class_df = df[df["path"].str.contains(class_name)]
        db_df = class_df[class_df["split"] == "database"].dropna(subset=["identity"])
        query_df = class_df[class_df["split"] == "query"]
        
        if db_df.empty or query_df.empty:
            continue
            
        db_embeddings = extractor.extract_embeddings(db_df)
        query_embeddings = extractor.extract_embeddings(query_df)
        
        # Normalize embeddings
        db_embeddings = db_embeddings / np.linalg.norm(db_embeddings, axis=1, keepdims=True)
        query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
        
        for i, query_emb in enumerate(tqdm(query_embeddings, desc="Matching")):
            identity = find_matches(query_emb, db_embeddings, db_df, Config.BASE_THRESHOLD)
            results.append({
                "image_id": query_df.iloc[i]["image_id"],
                "identity": identity
            })
    
    pd.DataFrame(results).to_csv("submission.csv", index=False)
    print("Done! Results saved to submission.csv")


Processing SeaTurtleID2022...


Extracting embeddings: 100%|██████████| 91/91 [06:19<00:00,  4.17s/it]
Extracting embeddings: 100%|██████████| 6/6 [00:23<00:00,  3.88s/it]
Matching: 100%|██████████| 500/500 [00:18<00:00, 26.73it/s]



Processing LynxID2025...


Extracting embeddings: 100%|██████████| 31/31 [02:10<00:00,  4.21s/it]
Extracting embeddings: 100%|██████████| 10/10 [00:44<00:00,  4.45s/it]
Matching: 100%|██████████| 946/946 [00:11<00:00, 82.84it/s]



Processing SalamanderID2025...


Extracting embeddings: 100%|██████████| 15/15 [01:03<00:00,  4.24s/it]
Extracting embeddings: 100%|██████████| 8/8 [00:33<00:00,  4.18s/it]
Matching: 100%|██████████| 689/689 [00:03<00:00, 196.93it/s]

Done! Results saved to submission.csv





In [2]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
import timm
import joblib

# ====== Configuration ======
class Config:
    ROOT_DIR = "/kaggle/input/animal-clef-2025"
    METADATA_PATH = os.path.join(ROOT_DIR, "metadata.csv")
    EMBEDDING_DIM = 512
    BATCH_SIZE = 96
    NUM_WORKERS = 2
    IMAGE_SIZE = 384
    CLASSES = ["SeaTurtleID2022", "LynxID2025", "SalamanderID2025"]
    MODEL_NAME = "hf-hub:BVRA/MegaDescriptor-L-384"
    BASE_THRESHOLD = 0.85

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====== Dataset and Model Setup ======
class AnimalDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform or transforms.Compose([
            transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
        ])
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_path = os.path.join(Config.ROOT_DIR, self.df.iloc[idx]['path'])
        try:
            with Image.open(img_path) as img:
                return self.transform(img.convert("RGB"))
        except:
            return torch.zeros((3, Config.IMAGE_SIZE, Config.IMAGE_SIZE))

class FeatureExtractor:
    def __init__(self):
        self.model = timm.create_model(Config.MODEL_NAME, pretrained=True)
        self.model = self.model.to(device).eval()
        
    @torch.no_grad()
    def extract_embeddings(self, df):
        dataset = AnimalDataset(df)
        loader = DataLoader(dataset, batch_size=Config.BATCH_SIZE,
                            num_workers=Config.NUM_WORKERS)
        
        embeddings = []
        for batch in tqdm(loader, desc="Extracting embeddings"):
            embeddings.append(self.model(batch.to(device)).cpu())
        return torch.cat(embeddings).numpy()

# ====== Similarity Search ======
def find_matches(query_emb, db_embeddings, db_df, threshold):
    """NumPy implementation of similarity search"""
    sims = cosine_similarity(query_emb.reshape(1, -1), db_embeddings)[0]
    max_idx = np.argmax(sims)
    return db_df.iloc[max_idx]["identity"] if sims[max_idx] >= threshold else "new_individual"

# ====== Safe Stratified Split ======
def safe_stratified_split(df, test_size=0.2, seed=42):
    train_indices = []
    val_indices = []

    grouped = df.groupby("identity").indices
    for label, indices in grouped.items():
        indices = list(indices)
        if len(indices) == 1:
            train_indices.append(indices[0])  # Cannot split
        else:
            tr_idx, val_idx = train_test_split(indices, test_size=test_size, random_state=seed)
            train_indices.extend(tr_idx)
            val_indices.extend(val_idx)
    return df.iloc[train_indices].reset_index(drop=True), df.iloc[val_indices].reset_index(drop=True)

# ====== Main Execution ======
if __name__ == "__main__":
    df = pd.read_csv(Config.METADATA_PATH)
    extractor = FeatureExtractor()
    submission_results = []

    for class_name in Config.CLASSES:
        print(f"\nProcessing {class_name}...")
        class_df = df[df["path"].str.contains(class_name)]
        full_db_df = class_df[class_df["split"] == "database"].dropna(subset=["identity"])
        query_df = class_df[class_df["split"] == "query"]

        if full_db_df.empty or query_df.empty:
            continue

        # Split database into training and validation
        db_df, val_df = safe_stratified_split(full_db_df)

        print(f"DB: {len(db_df)}, VAL: {len(val_df)}, QUERY: {len(query_df)}")

        # Extract embeddings
        db_embeddings = extractor.extract_embeddings(db_df)
        val_embeddings = extractor.extract_embeddings(val_df)
        query_embeddings = extractor.extract_embeddings(query_df)

        # Normalize
        db_embeddings = db_embeddings / np.linalg.norm(db_embeddings, axis=1, keepdims=True)
        val_embeddings = val_embeddings / np.linalg.norm(val_embeddings, axis=1, keepdims=True)
        query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)

        # ====== Validation Evaluation ======
        correct = 0
        total = len(val_df)
        for i, val_emb in enumerate(tqdm(val_embeddings, desc=f"Validating {class_name}")):
            predicted = find_matches(val_emb, db_embeddings, db_df, Config.BASE_THRESHOLD)
            true_id = val_df.iloc[i]["identity"]
            if predicted == true_id:
                correct += 1
        acc = correct / total if total > 0 else 0.0
        print(f"Validation Accuracy for {class_name}: {acc:.4f} ({correct}/{total})")

        # ====== Query Predictions for Submission ======
        for i, query_emb in enumerate(tqdm(query_embeddings, desc=f"Matching Query {class_name}")):
            identity = find_matches(query_emb, db_embeddings, db_df, Config.BASE_THRESHOLD)
            submission_results.append({
                "image_id": query_df.iloc[i]["image_id"],
                "identity": identity
            })

    pd.DataFrame(submission_results).to_csv("submission.csv", index=False)
    print("Done! Submission results saved to submission.csv.")



Processing SeaTurtleID2022...
DB: 6808, VAL: 1921, QUERY: 500


Extracting embeddings: 100%|██████████| 71/71 [04:57<00:00,  4.19s/it]
Extracting embeddings: 100%|██████████| 21/21 [01:24<00:00,  4.03s/it]
Extracting embeddings: 100%|██████████| 6/6 [00:23<00:00,  3.87s/it]
Validating SeaTurtleID2022: 100%|██████████| 1921/1921 [01:01<00:00, 31.10it/s]


Validation Accuracy for SeaTurtleID2022: 0.6158 (1183/1921)


Matching Query SeaTurtleID2022: 100%|██████████| 500/500 [00:15<00:00, 31.72it/s]



Processing LynxID2025...
DB: 2339, VAL: 618, QUERY: 946


Extracting embeddings: 100%|██████████| 25/25 [01:43<00:00,  4.13s/it]
Extracting embeddings: 100%|██████████| 7/7 [00:28<00:00,  4.14s/it]
Extracting embeddings: 100%|██████████| 10/10 [00:44<00:00,  4.46s/it]
Validating LynxID2025: 100%|██████████| 618/618 [00:06<00:00, 97.49it/s] 


Validation Accuracy for LynxID2025: 0.0696 (43/618)


Matching Query LynxID2025: 100%|██████████| 946/946 [00:09<00:00, 96.95it/s] 



Processing SalamanderID2025...
DB: 1059, VAL: 329, QUERY: 689


Extracting embeddings: 100%|██████████| 12/12 [00:49<00:00,  4.15s/it]
Extracting embeddings: 100%|██████████| 4/4 [00:17<00:00,  4.49s/it]
Extracting embeddings: 100%|██████████| 8/8 [00:33<00:00,  4.20s/it]
Validating SalamanderID2025: 100%|██████████| 329/329 [00:01<00:00, 257.53it/s]


Validation Accuracy for SalamanderID2025: 0.1216 (40/329)


Matching Query SalamanderID2025: 100%|██████████| 689/689 [00:02<00:00, 261.61it/s]

Done! Submission results saved to submission.csv.





In [None]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
import timm
import joblib

# ====== Configuration ======
class Config:
    ROOT_DIR = "/kaggle/input/animal-clef-2025"
    METADATA_PATH = os.path.join(ROOT_DIR, "metadata.csv")
    EMBEDDING_DIM = 512
    BATCH_SIZE = 96
    NUM_WORKERS = 2
    IMAGE_SIZE = 384
    CLASSES = ["SeaTurtleID2022", "LynxID2025", "SalamanderID2025"]
    MODEL_NAME = "hf-hub:BVRA/MegaDescriptor-L-384"
    BASE_THRESHOLD = 0.85

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====== Dataset and Model Setup ======
class AnimalDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform or transforms.Compose([
            transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
        ])
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_path = os.path.join(Config.ROOT_DIR, self.df.iloc[idx]['path'])
        try:
            with Image.open(img_path) as img:
                return self.transform(img.convert("RGB"))
        except:
            return torch.zeros((3, Config.IMAGE_SIZE, Config.IMAGE_SIZE))

class FeatureExtractor:
    def __init__(self):
        self.model = timm.create_model(Config.MODEL_NAME, pretrained=True)
        self.model = self.model.to(device).eval()
        
    @torch.no_grad()
    def extract_embeddings(self, df):
        dataset = AnimalDataset(df)
        loader = DataLoader(dataset, batch_size=Config.BATCH_SIZE,
                            num_workers=Config.NUM_WORKERS)
        
        embeddings = []
        for batch in tqdm(loader, desc="Extracting embeddings"):
            embeddings.append(self.model(batch.to(device)).cpu())
        return torch.cat(embeddings).numpy()

# ====== Similarity Search ======
def find_matches(query_emb, db_embeddings, db_df, threshold):
    """NumPy implementation of similarity search"""
    sims = cosine_similarity(query_emb.reshape(1, -1), db_embeddings)[0]
    max_idx = np.argmax(sims)
    return db_df.iloc[max_idx]["identity"] if sims[max_idx] >= threshold else "new_individual"

def find_best_threshold(val_embeddings, val_df, db_embeddings, db_df):
    best_thresh = 0.0
    best_acc = 0.0
    for t in np.arange(0.5, 0.96, 0.01):
        correct = 0
        for i, emb in enumerate(val_embeddings):
            pred = find_matches(emb, db_embeddings, db_df, t)
            if pred == val_df.iloc[i]["identity"]:
                correct += 1
        acc = correct / len(val_df)
        if acc > best_acc:
            best_acc = acc
            best_thresh = t
    return best_thresh, best_acc


# ====== Safe Stratified Split ======
def safe_stratified_split(df, test_size=0.2, seed=42):
    train_indices = []
    val_indices = []

    grouped = df.groupby("identity").indices
    for label, indices in grouped.items():
        indices = list(indices)
        if len(indices) == 1:
            train_indices.append(indices[0])  # Cannot split
        else:
            tr_idx, val_idx = train_test_split(indices, test_size=test_size, random_state=seed)
            train_indices.extend(tr_idx)
            val_indices.extend(val_idx)
    return df.iloc[train_indices].reset_index(drop=True), df.iloc[val_indices].reset_index(drop=True)

 
# ====== Main Execution ======
if __name__ == "__main__":
    df = pd.read_csv(Config.METADATA_PATH)
    extractor = FeatureExtractor()
    submission_results = []
    thresholds = {}

    for class_name in Config.CLASSES:
        print(f"\nProcessing {class_name}...")
        class_df = df[df["path"].str.contains(class_name)]
        full_db_df = class_df[class_df["split"] == "database"].dropna(subset=["identity"])
        query_df = class_df[class_df["split"] == "query"]

        if full_db_df.empty or query_df.empty:
            continue

        # Split database into training and validation
        db_df, val_df = safe_stratified_split(full_db_df)

        print(f"DB: {len(db_df)}, VAL: {len(val_df)}, QUERY: {len(query_df)}")

        # Extract embeddings
        db_embeddings = extractor.extract_embeddings(db_df)
        val_embeddings = extractor.extract_embeddings(val_df)
        query_embeddings = extractor.extract_embeddings(query_df)

        # Normalize
        db_embeddings = db_embeddings / np.linalg.norm(db_embeddings, axis=1, keepdims=True)
        val_embeddings = val_embeddings / np.linalg.norm(val_embeddings, axis=1, keepdims=True)
        query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)

        # ====== Threshold Tuning ======
        best_thresh, best_acc = find_best_threshold(
            val_embeddings, val_df, db_embeddings, db_df
        )
        thresholds[class_name] = best_thresh
        print(f"Validation Accuracy for {class_name}: {best_acc:.4f} with threshold {best_thresh:.2f}")

        # ====== Query Predictions for Submission ======
        for i, query_emb in enumerate(tqdm(query_embeddings, desc=f"Matching Query {class_name}")):
            identity = find_matches(query_emb, db_embeddings, db_df, threshold=best_thresh)
            submission_results.append({
                "image_id": query_df.iloc[i]["image_id"],
                "identity": identity
            })

    # Save submission
    pd.DataFrame(submission_results).to_csv("submission.csv", index=False)
    print("Done! Submission results saved to submission.csv.")




Processing SeaTurtleID2022...
DB: 6808, VAL: 1921, QUERY: 500


Extracting embeddings: 100%|██████████| 71/71 [04:56<00:00,  4.18s/it]
Extracting embeddings: 100%|██████████| 21/21 [01:24<00:00,  4.04s/it]
Extracting embeddings: 100%|██████████| 6/6 [00:23<00:00,  3.92s/it]


In [1]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics.pairwise import cosine_similarity
import timm
import joblib

# ====== Configuration ======
class Config:
    ROOT_DIR = "/kaggle/input/animal-clef-2025"
    METADATA_PATH = os.path.join(ROOT_DIR, "metadata.csv")
    EMBEDDING_DIM = 512
    BATCH_SIZE = 96
    NUM_WORKERS = 2
    IMAGE_SIZE = 384
    CLASSES = ["SeaTurtleID2022", "LynxID2025", "SalamanderID2025"]
    MODEL_NAME = "hf-hub:BVRA/MegaDescriptor-L-384"
    BASE_THRESHOLD = 0.85

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====== Dataset and Model Setup ======
class AnimalDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform or transforms.Compose([
            transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
        ])
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_path = os.path.join(Config.ROOT_DIR, self.df.iloc[idx]['path'])
        try:
            with Image.open(img_path) as img:
                return self.transform(img.convert("RGB"))
        except:
            return torch.zeros((3, Config.IMAGE_SIZE, Config.IMAGE_SIZE))

class FeatureExtractor:
    def __init__(self):
        self.model = timm.create_model(Config.MODEL_NAME, pretrained=True)
        self.model = self.model.to(device).eval()
        
    @torch.no_grad()
    def extract_embeddings(self, df):
        dataset = AnimalDataset(df)
        loader = DataLoader(dataset, batch_size=Config.BATCH_SIZE,
                          num_workers=Config.NUM_WORKERS)
        
        embeddings = []
        for batch in tqdm(loader, desc="Extracting embeddings"):
            embeddings.append(self.model(batch.to(device)).cpu())
        return torch.cat(embeddings).numpy()

# ====== Similarity Search ======
def find_matches(query_emb, db_embeddings, db_df, threshold):
    """NumPy implementation of similarity search"""
    sims = cosine_similarity(query_emb.reshape(1, -1), db_embeddings)[0]
    max_idx = np.argmax(sims)
    return db_df.iloc[max_idx]["identity"] if sims[max_idx] >= threshold else "new_individual"

# ====== Main Execution ======
if __name__ == "__main__":
    df = pd.read_csv(Config.METADATA_PATH)
    extractor = FeatureExtractor()
    results = []
    
    for class_name in Config.CLASSES:
        print(f"\nProcessing {class_name}...")
        class_df = df[df["path"].str.contains(class_name)]
        db_df = class_df[class_df["split"] == "database"].dropna(subset=["identity"])
        query_df = class_df[class_df["split"] == "query"]
        
        if db_df.empty or query_df.empty:
            continue
            
        db_embeddings = extractor.extract_embeddings(db_df)
        query_embeddings = extractor.extract_embeddings(query_df)
        
        # Normalize embeddings
        db_embeddings = db_embeddings / np.linalg.norm(db_embeddings, axis=1, keepdims=True)
        query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
        
        for i, query_emb in enumerate(tqdm(query_embeddings, desc="Matching")):
            identity = find_matches(query_emb, db_embeddings, db_df, Config.BASE_THRESHOLD)
            results.append({
                "image_id": query_df.iloc[i]["image_id"],
                "identity": identity
            })
    
    pd.DataFrame(results).to_csv("submission.csv", index=False)
    print("Done! Results saved to submission.csv")

config.json:   0%|          | 0.00/609 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.94G [00:00<?, ?B/s]


Processing SeaTurtleID2022...


Extracting embeddings: 100%|██████████| 91/91 [06:20<00:00,  4.18s/it]
Extracting embeddings: 100%|██████████| 6/6 [00:23<00:00,  3.98s/it]
Matching: 100%|██████████| 500/500 [00:20<00:00, 24.72it/s]



Processing LynxID2025...


Extracting embeddings: 100%|██████████| 31/31 [02:11<00:00,  4.25s/it]
Extracting embeddings: 100%|██████████| 10/10 [00:45<00:00,  4.57s/it]
Matching: 100%|██████████| 946/946 [00:13<00:00, 71.08it/s]



Processing SalamanderID2025...


Extracting embeddings: 100%|██████████| 15/15 [01:04<00:00,  4.29s/it]
Extracting embeddings: 100%|██████████| 8/8 [00:34<00:00,  4.32s/it]
Matching: 100%|██████████| 689/689 [00:03<00:00, 177.14it/s]

Done! Results saved to submission.csv





In [6]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
import timm
import joblib
from collections import defaultdict

# ====== Configuration ======
class Config:
    ROOT_DIR = "/kaggle/input/animal-clef-2025"
    METADATA_PATH = os.path.join(ROOT_DIR, "metadata.csv")
    EMBEDDING_DIM = 512
    BATCH_SIZE = 128  # Increased batch size
    NUM_WORKERS = 4
    IMAGE_SIZE = 384
    CLASSES = ["SeaTurtleID2022", "LynxID2025", "SalamanderID2025"]
    MODEL_NAME = "hf-hub:BVRA/MegaDescriptor-L-384"
    BASE_THRESHOLD = 0.87  # Adjusted threshold
    VAL_SIZE = 0.2
    RANDOM_SEED = 42
    MIN_SAMPLES_PER_CLASS = 2  # Minimum samples required per class

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ====== Dataset Class with Enhanced Augmentation ======
class AnimalDataset(Dataset):
    def __init__(self, df, transform=None, is_train=True):
        self.df = df.reset_index(drop=True)
        self.is_train = is_train
        self.transform = transform or self.get_default_transform(is_train)
        
    def get_default_transform(self, is_train):
        if is_train:
            return transforms.Compose([
                transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
                #transforms.RandomHorizontalFlip(),
                #transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
                #transforms.ColorJitter(brightness=0.2, contrast=0.2),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            return transforms.Compose([
                transforms.Resize((Config.IMAGE_SIZE, Config.IMAGE_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_path = os.path.join(Config.ROOT_DIR, self.df.iloc[idx]['path'])
        try:
            with Image.open(img_path) as img:
                if self.transform:
                    img = self.transform(img.convert("RGB"))
                return img
        except Exception as e:
            print(f"Error loading {img_path}: {str(e)}")
            return torch.zeros((3, Config.IMAGE_SIZE, Config.IMAGE_SIZE))

# ====== Feature Extractor with Enhanced Model ======
class FeatureExtractor:
    def __init__(self):
        self.model = timm.create_model(Config.MODEL_NAME, pretrained=True)
        self.model = self.model.to(device).eval()
        # Freeze all layers
        for param in self.model.parameters():
            param.requires_grad = False
        
    @torch.no_grad()
    def extract_embeddings(self, df, is_train=False):
        dataset = AnimalDataset(df, is_train=is_train)
        loader = DataLoader(dataset, batch_size=Config.BATCH_SIZE,
                          num_workers=Config.NUM_WORKERS,
                          pin_memory=True)
        
        embeddings = []
        for batch in tqdm(loader, desc="Extracting embeddings"):
            embeddings.append(self.model(batch.to(device)).cpu())
        return torch.cat(embeddings).numpy()

# ====== Robust Data Splitting ======
def safe_train_test_split(db_df, val_size=0.2, min_samples=2, random_state=42):
    """Handle classes with insufficient samples"""
    class_counts = db_df['identity'].value_counts()
    valid_classes = class_counts[class_counts >= min_samples].index
    
    if len(valid_classes) == 0:
        print("Warning: No classes have sufficient samples for validation split")
        return db_df.copy(), pd.DataFrame(columns=db_df.columns)
    
    valid_df = db_df[db_df['identity'].isin(valid_classes)]
    train_df, val_df = train_test_split(
        valid_df,
        test_size=val_size,
        stratify=valid_df['identity'],
        random_state=random_state
    )
    
    # Add insufficient samples to training
    insufficient_df = db_df[~db_df['identity'].isin(valid_classes)]
    if len(insufficient_df) > 0:
        train_df = pd.concat([train_df, insufficient_df])
        print(f"Added {len(insufficient_df)} samples from rare classes to training")
    
    return train_df, val_df

# ====== Enhanced Evaluation Metrics ======
def evaluate_accuracy(query_embeddings, query_labels, db_embeddings, db_labels, threshold):
    """Calculate accuracy with confidence thresholding"""
    # Normalize embeddings
    db_embeddings = db_embeddings / np.linalg.norm(db_embeddings, axis=1, keepdims=True)
    query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
    
    # Batch processing for efficiency
    sim_matrix = cosine_similarity(query_embeddings, db_embeddings)
    max_sims = np.max(sim_matrix, axis=1)
    max_indices = np.argmax(sim_matrix, axis=1)
    
    correct = 0
    total = 0
    
    for i in range(len(query_labels)):
        if max_sims[i] >= threshold:
            if db_labels[max_indices[i]] == query_labels[i]:
                correct += 1
        total += 1
    
    return correct / total if total > 0 else 0.0

# ====== Main Execution with Improved Pipeline ======
if __name__ == "__main__":
    # Load and prepare data
    df = pd.read_csv(Config.METADATA_PATH)
    encoder = LabelEncoder()
    extractor = FeatureExtractor()
    
    results = []
    val_accuracies = []
    class_reports = []
    
    for class_name in Config.CLASSES:
        print(f"\n{'='*40}\nProcessing {class_name}\n{'='*40}")
        class_df = df[df["path"].str.contains(class_name)]
        db_df = class_df[class_df["split"] == "database"].dropna(subset=["identity"])
        
        if db_df.empty:
            print(f"No database samples found for {class_name}")
            continue
            
        # Split database into train and validation
        train_df, val_df = safe_train_test_split(
            db_df,
            val_size=Config.VAL_SIZE,
            min_samples=Config.MIN_SAMPLES_PER_CLASS,
            random_state=Config.RANDOM_SEED
        )
        
        if val_df.empty:
            print(f"Skipping validation for {class_name} - insufficient samples")
            val_accuracies.append(0)
            continue
            
        # Encode labels
        train_df["label"] = encoder.fit_transform(train_df["identity"])
        val_df["label"] = encoder.transform(val_df["identity"])
        
        # Extract embeddings
        print("Extracting training embeddings...")
        train_embeddings = extractor.extract_embeddings(train_df, is_train=True)
        print("Extracting validation embeddings...")
        val_embeddings = extractor.extract_embeddings(val_df, is_train=False)
        
        # Evaluate on validation set
        val_accuracy = evaluate_accuracy(
            val_embeddings, 
            val_df["label"].values,
            train_embeddings,
            train_df["label"].values,
            Config.BASE_THRESHOLD
        )
        val_accuracies.append(val_accuracy)
        print(f"\nValidation Accuracy for {class_name}: {val_accuracy:.4f}")
        
        # Process queries if available
        query_df = class_df[class_df["split"] == "query"]
        if not query_df.empty:
            print("Processing query images...")
            query_embeddings = extractor.extract_embeddings(query_df, is_train=False)
            query_embeddings = query_embeddings / np.linalg.norm(query_embeddings, axis=1, keepdims=True)
            train_embeddings = train_embeddings / np.linalg.norm(train_embeddings, axis=1, keepdims=True)
            
            # Batch processing for queries
            sim_matrix = cosine_similarity(query_embeddings, train_embeddings)
            max_sims = np.max(sim_matrix, axis=1)
            max_indices = np.argmax(sim_matrix, axis=1)
            
            for i in tqdm(range(len(query_df)), desc="Matching queries"):
                identity = (train_df.iloc[max_indices[i]]["identity"] 
                          if max_sims[i] >= Config.BASE_THRESHOLD 
                          else "new_individual")
                results.append({
                    "image_id": query_df.iloc[i]["image_id"],
                    "identity": identity,
                    "confidence": float(max_sims[i])
                })
    
    # Save results and print summary
    if results:
        submission_df = pd.DataFrame(results)
        submission_df.to_csv("submission.csv", index=False)
        print("\nSubmission saved to submission.csv")
    
    # Validation report
    print("\n\nValidation Accuracy Summary:")
    for class_name, acc in zip(Config.CLASSES, val_accuracies):
        print(f"{class_name}: {acc:.4f}")
    
    valid_accs = [acc for acc in val_accuracies if acc > 0]
    if valid_accs:
        print(f"\nMean Validation Accuracy: {np.mean(valid_accs):.4f}")
    else:
        print("\nNo valid validation results available")


Processing SeaTurtleID2022
Added 1 samples from rare classes to training
Extracting training embeddings...


Extracting embeddings:   0%|          | 0/55 [00:05<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.80 GiB. GPU 0 has a total capacity of 15.89 GiB of which 3.63 GiB is free. Process 5540 has 12.25 GiB memory in use. Of the allocated memory 10.14 GiB is allocated by PyTorch, and 1.82 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)