In [3]:
from wildlife_datasets.datasets import AnimalCLEF2025

dataset = AnimalCLEF2025(root)

ModuleNotFoundError: No module named 'wildlife_datasets'

In [2]:
from huggingface_hub import snapshot_download
snapshot_download("facebook/vit-base-patch16-224", cache_dir="./vit_weights")

LocalEntryNotFoundError: An error happened while trying to locate the files on the Hub and we cannot find the appropriate snapshot folder for the specified revision on the local disk. Please check your internet connection and try again.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import timm
import pandas as pd
from PIL import Image
import os
import json
from tqdm import tqdm

# ===== Configuration =====
root_dir = "/kaggle/input/animal-clef-2025"
metadata_path = os.path.join(root_dir, "metadata.csv")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_size = 224
batch_size = 32
num_workers = 2
confidence_threshold = 0.80
num_epochs = 250
checkpoint_dir = "./checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# ===== Load metadata =====
df = pd.read_csv(metadata_path)

# ===== Dataset Class =====
class AnimalReIDDataset(Dataset):
    def __init__(self, dataframe, transform):
        self.df = dataframe.reset_index(drop=True)
        self.transform = transform
        self.label2id = {label: i for i, label in enumerate(sorted(self.df['identity'].unique()))}
        self.id2label = {v: k for k, v in self.label2id.items()}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(root_dir, row["path"])
        img = Image.open(img_path).convert("RGB")
        img = self.transform(img)
        label = self.label2id[row["identity"]]
        return img, label

# ===== Image Transform =====
transform = transforms.Compose([
    transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

# ===== Build Dataset and Dataloader =====
train_df = df[(df["split"] == "database") & (df["identity"].notna())]
train_dataset = AnimalReIDDataset(train_df, transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# ===== Model Definition =====
class MAEClassifier(nn.Module):
    def __init__(self, num_classes, freeze_encoder=True):
        super(MAEClassifier, self).__init__()
        self.encoder = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=0)
        self.encoder_frozen = freeze_encoder

        if self.encoder_frozen:
            for param in self.encoder.parameters():
                param.requires_grad = False

        self.classifier = nn.Linear(self.encoder.num_features, num_classes)

    def forward(self, x):
        x = self.encoder(x)
        return self.classifier(x)

    def unfreeze_encoder(self, num_blocks=6):
        if self.encoder_frozen:
            print(f"Unfreezing last {num_blocks} encoder blocks for fine-tuning...")
            for block in self.encoder.blocks[-num_blocks:]:
                for param in block.parameters():
                    param.requires_grad = True
            for param in self.encoder.norm.parameters():
                param.requires_grad = True
            self.encoder_frozen = False

# ===== Initialize or Resume Model =====
model = MAEClassifier(num_classes=len(train_dataset.label2id), freeze_encoder=True).to(device)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

start_epoch = 0
best_loss = float('inf')
checkpoint_path = os.path.join(checkpoint_dir, "last_checkpoint.pth")

if os.path.exists(checkpoint_path):
    print("Resuming from last checkpoint...")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    best_loss = checkpoint["best_loss"]
    model.encoder_frozen = checkpoint.get("encoder_frozen", True)
    print(f"Resumed from epoch {start_epoch} with best loss {best_loss:.4f}")

optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs - start_epoch)
scaler = torch.cuda.amp.GradScaler()

# ===== Training Loop =====
for epoch in range(start_epoch, num_epochs):
    if epoch == 10 and model.encoder_frozen:
        model.unfreeze_encoder(num_blocks=6)
        optimizer = optim.AdamW(model.parameters(), lr=5e-5, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs - epoch)

    model.train()
    total_loss = 0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
    for images, labels in loop:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        with torch.cuda.amp.autocast():
            outputs = model(images)
            loss = criterion(outputs, labels)

        scaler.scale(loss).backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    scheduler.step()
    print(f"Epoch {epoch+1}/{num_epochs} - Total Loss: {total_loss:.4f}")

    if total_loss < best_loss:
        best_loss = total_loss
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, "best_model.pth"))
        print(" Best model saved.")

    if (epoch + 1) % 20 == 0:
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "best_loss": best_loss,
            "encoder_frozen": model.encoder_frozen
        }, os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch+1}.pth"))
        print(f"Checkpoint saved at epoch {epoch+1}")

torch.save({
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "best_loss": best_loss,
    "encoder_frozen": model.encoder_frozen
}, checkpoint_path)
print("Last checkpoint saved.")

with open(os.path.join(checkpoint_dir, "label_map.json"), "w") as f:
    json.dump(train_dataset.id2label, f)
print(" Label map saved.")





####you can run all the above at once for training then run the one below 









model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
                                                                         

Epoch 1/250 - Total Loss: 2499.9117
 Best model saved.


                                                                         

Epoch 2/250 - Total Loss: 2239.5651
 Best model saved.


                                                                         

Epoch 3/250 - Total Loss: 2156.8118
 Best model saved.


                                                                         

Epoch 4/250 - Total Loss: 2097.8204
 Best model saved.


                                                                         

Epoch 5/250 - Total Loss: 2049.8361
 Best model saved.


                                                                         

Epoch 6/250 - Total Loss: 2011.0941
 Best model saved.


                                                                         

Epoch 7/250 - Total Loss: 1980.5683
 Best model saved.


                                                                         

Epoch 8/250 - Total Loss: 1944.4984
 Best model saved.


                                                                         

Epoch 9/250 - Total Loss: 1912.2999
 Best model saved.


                                                                          

Epoch 10/250 - Total Loss: 1881.9845
 Best model saved.
Unfreezing last 6 encoder blocks for fine-tuning...


                                                                          

Epoch 11/250 - Total Loss: 1421.1053
 Best model saved.


                                                                          

Epoch 12/250 - Total Loss: 1083.9597
 Best model saved.


                                                                          

Epoch 13/250 - Total Loss: 903.1479
 Best model saved.


                                                                          

Epoch 14/250 - Total Loss: 778.5106
 Best model saved.


                                                                          

Epoch 15/250 - Total Loss: 688.1588
 Best model saved.


                                                                          

Epoch 16/250 - Total Loss: 599.2748
 Best model saved.


                                                                          

Epoch 17/250 - Total Loss: 594.4878
 Best model saved.


                                                                          

Epoch 18/250 - Total Loss: 562.0109
 Best model saved.


                                                                          

Epoch 19/250 - Total Loss: 539.1511
 Best model saved.


                                                                          

Epoch 20/250 - Total Loss: 524.5387
 Best model saved.
Checkpoint saved at epoch 20


                                                                          

Epoch 21/250 - Total Loss: 513.9394
 Best model saved.


                                                                          

Epoch 22/250 - Total Loss: 495.7649
 Best model saved.


                                                                          

Epoch 23/250 - Total Loss: 487.2299
 Best model saved.


                                                                          

Epoch 24/250 - Total Loss: 484.1017
 Best model saved.


                                                                          

Epoch 25/250 - Total Loss: 480.4074
 Best model saved.


                                                                          

Epoch 26/250 - Total Loss: 477.7402
 Best model saved.


                                                                          

Epoch 27/250 - Total Loss: 472.2629
 Best model saved.


                                                                          

Epoch 28/250 - Total Loss: 468.9397
 Best model saved.


                                                                          

Epoch 29/250 - Total Loss: 467.6421
 Best model saved.


                                                                          

Epoch 30/250 - Total Loss: 466.3239
 Best model saved.


                                                                          

Epoch 31/250 - Total Loss: 465.7701
 Best model saved.


                                                                          

Epoch 32/250 - Total Loss: 464.3702
 Best model saved.


                                                                          

Epoch 33/250 - Total Loss: 463.7789
 Best model saved.


                                                                          

Epoch 34/250 - Total Loss: 467.1166


                                                                          

Epoch 35/250 - Total Loss: 466.7836


                                                                          

Epoch 36/250 - Total Loss: 465.6167


                                                                          

Epoch 37/250 - Total Loss: 462.9671
 Best model saved.


                                                                          

Epoch 38/250 - Total Loss: 461.8429
 Best model saved.


                                                                          

Epoch 39/250 - Total Loss: 457.1085
 Best model saved.


                                                                          

Epoch 40/250 - Total Loss: 455.2278
 Best model saved.
Checkpoint saved at epoch 40


                                                                          

Epoch 41/250 - Total Loss: 454.0866
 Best model saved.


                                                                          

Epoch 42/250 - Total Loss: 453.4291
 Best model saved.


                                                                          

Epoch 43/250 - Total Loss: 452.7997
 Best model saved.


                                                                          

Epoch 44/250 - Total Loss: 452.0289
 Best model saved.


                                                                          

Epoch 45/250 - Total Loss: 451.5159
 Best model saved.


                                                                          

Epoch 46/250 - Total Loss: 450.6421
 Best model saved.


                                                                          

Epoch 47/250 - Total Loss: 450.6023
 Best model saved.


                                                                          

Epoch 48/250 - Total Loss: 449.8192
 Best model saved.


                                                                          

Epoch 49/250 - Total Loss: 449.2373
 Best model saved.


                                                                          

Epoch 50/250 - Total Loss: 449.0377
 Best model saved.


                                                                          

Epoch 51/250 - Total Loss: 448.3159
 Best model saved.


                                                                          

Epoch 52/250 - Total Loss: 448.1705
 Best model saved.


                                                                          

Epoch 53/250 - Total Loss: 447.5078
 Best model saved.


                                                                          

Epoch 54/250 - Total Loss: 446.2238
 Best model saved.


                                                                          

Epoch 55/250 - Total Loss: 446.9246


                                                                          

Epoch 56/250 - Total Loss: 446.3559


                                                                          

Epoch 57/250 - Total Loss: 446.2549


                                                                          

Epoch 58/250 - Total Loss: 445.8619
 Best model saved.


                                                                          

Epoch 59/250 - Total Loss: 447.7929


                                                                          

Epoch 60/250 - Total Loss: 448.8202
Checkpoint saved at epoch 60


                                                                          

Epoch 61/250 - Total Loss: 448.0874


                                                                          

Epoch 62/250 - Total Loss: 447.8789


                                                                          

Epoch 63/250 - Total Loss: 446.5089


                                                                          

Epoch 64/250 - Total Loss: 443.9175
 Best model saved.


                                                                          

Epoch 65/250 - Total Loss: 443.9319


                                                                          

Epoch 66/250 - Total Loss: 445.1978


                                                                          

Epoch 67/250 - Total Loss: 444.7724


                                                                          

Epoch 68/250 - Total Loss: 444.5940


                                                                          

Epoch 69/250 - Total Loss: 443.8910
 Best model saved.


                                                                          

Epoch 70/250 - Total Loss: 442.3767
 Best model saved.


                                                                          

Epoch 71/250 - Total Loss: 440.5229
 Best model saved.


                                                                          

Epoch 72/250 - Total Loss: 439.9575
 Best model saved.


                                                                          

Epoch 73/250 - Total Loss: 439.7917
 Best model saved.


                                                                          

Epoch 74/250 - Total Loss: 439.5638
 Best model saved.


                                                                          

Epoch 75/250 - Total Loss: 438.9300
 Best model saved.


                                                                          

Epoch 76/250 - Total Loss: 439.0779


                                                                          

Epoch 77/250 - Total Loss: 438.8719
 Best model saved.


                                                                          

Epoch 78/250 - Total Loss: 438.7442
 Best model saved.


                                                                          

Epoch 79/250 - Total Loss: 438.2961
 Best model saved.


                                                                          

Epoch 80/250 - Total Loss: 438.1022
 Best model saved.
Checkpoint saved at epoch 80


                                                                          

Epoch 81/250 - Total Loss: 437.2170
 Best model saved.


                                                                          

Epoch 82/250 - Total Loss: 437.6698


                                                                          

Epoch 83/250 - Total Loss: 437.7018


                                                                          

Epoch 84/250 - Total Loss: 437.5334


                                                                          

Epoch 85/250 - Total Loss: 437.4652


                                                                          

Epoch 86/250 - Total Loss: 437.0179
 Best model saved.


                                                                          

Epoch 87/250 - Total Loss: 436.9501
 Best model saved.


                                                                          

Epoch 88/250 - Total Loss: 436.8356
 Best model saved.


                                                                          

Epoch 89/250 - Total Loss: 436.5854
 Best model saved.


                                                                          

Epoch 90/250 - Total Loss: 436.3619
 Best model saved.


                                                                          

Epoch 91/250 - Total Loss: 436.1508
 Best model saved.


                                                                          

Epoch 92/250 - Total Loss: 436.1057
 Best model saved.


                                                                          

Epoch 93/250 - Total Loss: 435.9489
 Best model saved.


                                                                          

Epoch 94/250 - Total Loss: 435.6878
 Best model saved.


                                                                          

Epoch 95/250 - Total Loss: 435.5230
 Best model saved.


                                                                          

Epoch 96/250 - Total Loss: 435.6410


                                                                          

Epoch 97/250 - Total Loss: 434.8469
 Best model saved.


                                                                          

Epoch 99/250 - Total Loss: 435.0535


                                                                           

Epoch 100/250 - Total Loss: 434.9239
Checkpoint saved at epoch 100


                                                                           

Epoch 101/250 - Total Loss: 434.8831


                                                                           

Epoch 102/250 - Total Loss: 434.6390
 Best model saved.


                                                                           

Epoch 103/250 - Total Loss: 434.6361
 Best model saved.


                                                                           

Epoch 104/250 - Total Loss: 434.4084
 Best model saved.


                                                                           

Epoch 105/250 - Total Loss: 434.3330
 Best model saved.


                                                                           

Epoch 106/250 - Total Loss: 434.2666
 Best model saved.


                                                                           

Epoch 107/250 - Total Loss: 434.0867
 Best model saved.


                                                                           

Epoch 108/250 - Total Loss: 433.4426
 Best model saved.


                                                                           

Epoch 109/250 - Total Loss: 433.0848
 Best model saved.


                                                                           

Epoch 110/250 - Total Loss: 433.0078
 Best model saved.


                                                                           

Epoch 111/250 - Total Loss: 432.8664
 Best model saved.


                                                                           

Epoch 112/250 - Total Loss: 432.9480


                                                                           

Epoch 113/250 - Total Loss: 432.7389
 Best model saved.


                                                                           

Epoch 114/250 - Total Loss: 432.7419


                                                                           

Epoch 115/250 - Total Loss: 432.6516
 Best model saved.


                                                                           

Epoch 116/250 - Total Loss: 432.5945
 Best model saved.


                                                                           

Epoch 117/250 - Total Loss: 432.5015
 Best model saved.


                                                                           

Epoch 118/250 - Total Loss: 432.4064
 Best model saved.


                                                                           

Epoch 119/250 - Total Loss: 432.4473


                                                                           

Epoch 120/250 - Total Loss: 432.2937
 Best model saved.
Checkpoint saved at epoch 120


                                                                           

Epoch 121/250 - Total Loss: 432.1809
 Best model saved.


                                                                           

Epoch 122/250 - Total Loss: 432.2279


                                                                           

Epoch 123/250 - Total Loss: 433.0517


                                                                           

Epoch 124/250 - Total Loss: 434.0378


                                                                           

Epoch 125/250 - Total Loss: 433.9860


                                                                           

Epoch 126/250 - Total Loss: 433.6739


                                                                           

Epoch 127/250 - Total Loss: 433.5586


                                                                           

Epoch 128/250 - Total Loss: 433.3189


                                                                           

Epoch 129/250 - Total Loss: 432.4508


                                                                           

Epoch 130/250 - Total Loss: 432.2581


                                                                           

Epoch 131/250 - Total Loss: 432.1827


Epoch 132/250:  51%|█████     | 209/409 [01:09<01:01,  3.26it/s, loss=1.06]

In [None]:
###Inference code

import torch.nn.functional as F 
checkpoint_dir = "./checkpoints"
checkpoint_path = os.path.join(checkpoint_dir, "best_model.pth")
label_map_path = os.path.join(checkpoint_dir, "label_map.json")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_size = 224
confidence_threshold = 0.80

# ===== Load metadata and label map =====
df = pd.read_csv(metadata_path)
with open(label_map_path, "r") as f:
    id2label = json.load(f)
label_map = {int(k): v for k, v in id2label.items()}
num_classes = len(label_map)

# ===== Transform =====
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

# ===== Model Definition =====
class MAEClassifier(torch.nn.Module):
    def __init__(self, num_classes):
        super(MAEClassifier, self).__init__()
        import timm
        self.encoder = timm.create_model("vit_base_patch16_224", pretrained=False, num_classes=0)
        self.classifier = torch.nn.Linear(self.encoder.num_features, num_classes)

    def forward(self, x):
        x = self.encoder(x)
        return self.classifier(x)

# ===== Load Model from Checkpoint =====
model = MAEClassifier(num_classes=num_classes).to(device)
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()
print(f" Loaded checkpoint from {checkpoint_path}")

# ===== Evaluate on Query Set =====
results = []
query_df = df[df["split"] == "query"].reset_index(drop=True)

for _, row in tqdm(query_df.iterrows(), total=len(query_df), desc="Evaluating"):
    img_path = os.path.join(root_dir, row["path"])
    img = Image.open(img_path).convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(img_tensor)
        probs = F.softmax(logits, dim=1)
        max_prob, pred_class = torch.max(probs, dim=1)
        identity = label_map[pred_class.item()] if max_prob.item() >= confidence_threshold else "new_individual"

    results.append({
        "image_id": row["image_id"],
        "identity": identity
    })

# ===== Save Submission =====
submission_df = pd.DataFrame(results)
submission_df.to_csv("submission.csv", index=False)
print(" Final classifier-based submission saved as 'submission.csv'")