In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from pathlib import Path

In [2]:
# Side length (pixels) every image is resized to before entering the model.
IMG_SIZE = 224

# Per-channel RGB mean / std used for normalisation (the standard ImageNet
# statistics). Defined once so the training and evaluation pipelines cannot
# drift apart if one of them is edited later.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline: resize, then random photometric / geometric augmentation,
# then tensor conversion and normalisation.
train_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ColorJitter(brightness=0.2,
                           contrast=0.2,
                           saturation=0.2,
                           hue=0.1),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=10),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation / test pipeline: deterministic — same resize and normalisation,
# no augmentation.
val_test_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [None]:
# Build the train / validation datasets and loaders, then run a quick sanity
# check on the detected classes and on one training batch.
root_dir   = Path("dataset")
batch_size = 64
num_workers = 4
# Pinned host memory only speeds up host-to-GPU copies; without CUDA it is
# useless and DataLoader emits a warning, so enable it only when a GPU exists.
pin_memory = torch.cuda.is_available()

# Datasets
train_ds = datasets.ImageFolder(root_dir / "train_extracted",
                                transform=train_transforms)
val_ds   = datasets.ImageFolder(root_dir / "val_extracted",
                                transform=val_test_transforms)
# test_ds  = datasets.ImageFolder(root_dir / "test_extracted",
#                                 transform=val_test_transforms)

# ImageFolder derives label indices independently for each root directory.
# If one split is missing (or has an extra) class folder, the same integer
# would mean different classes in train and val, silently corrupting metrics.
if val_ds.class_to_idx != train_ds.class_to_idx:
    only_train = sorted(set(train_ds.class_to_idx) - set(val_ds.class_to_idx))
    only_val = sorted(set(val_ds.class_to_idx) - set(train_ds.class_to_idx))
    raise ValueError(
        "train/val class folders differ; label indices would not line up. "
        f"Only in train: {only_train}; only in val: {only_val}"
    )

# DataLoaders
# drop_last=True on the training loader keeps every training batch full-sized.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                      num_workers=num_workers, pin_memory=pin_memory, drop_last=True)
val_dl   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                      num_workers=num_workers, pin_memory=pin_memory)
# test_dl  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False,
#                       num_workers=num_workers, pin_memory=pin_memory)

# Quick sanity check
# Invert the folder-name -> index mapping so predictions can be decoded later.
idx_to_class = {v: k for k, v in train_ds.class_to_idx.items()}
print(f"{len(idx_to_class)} classes detected:", idx_to_class)

# Pull a single batch to confirm the pipeline yields tensors of the expected shape.
imgs, labels = next(iter(train_dl))
print("Batch tensor shape:", imgs.shape)
print("Labels shape:", labels.shape)

📚  20 classes detected: {0: '00175_Animalia_Arthropoda_Insecta_Blattodea_Blaberidae_Aptera_fusca', 1: '00176_Animalia_Arthropoda_Insecta_Blattodea_Blaberidae_Panchlora_nivea', 2: '00177_Animalia_Arthropoda_Insecta_Blattodea_Blaberidae_Pycnoscelus_surinamensis', 3: '00178_Animalia_Arthropoda_Insecta_Blattodea_Blattidae_Blatta_orientalis', 4: '00179_Animalia_Arthropoda_Insecta_Blattodea_Blattidae_Periplaneta_americana', 5: '00180_Animalia_Arthropoda_Insecta_Blattodea_Blattidae_Periplaneta_australasiae', 6: '00181_Animalia_Arthropoda_Insecta_Blattodea_Blattidae_Periplaneta_fuliginosa', 7: '00182_Animalia_Arthropoda_Insecta_Blattodea_Ectobiidae_Pseudomops_septentrionalis', 8: '00443_Animalia_Arthropoda_Insecta_Diptera_Culicidae_Aedes_aegypti', 9: '00444_Animalia_Arthropoda_Insecta_Diptera_Culicidae_Aedes_albopictus', 10: '00445_Animalia_Arthropoda_Insecta_Diptera_Culicidae_Aedes_vexans', 11: '00446_Animalia_Arthropoda_Insecta_Diptera_Culicidae_Culex_quinquefasciatus', 12: '00447_Animalia_A

In [None]:
# 📒  Cell 4 — ConvNeXt-Tiny: load backbone weights, add fresh 20-class head
import torch, timm, pathlib, re
import torch.nn as nn

# torchsummary is only needed for the optional layer summary at the bottom of
# this cell, so a missing install must not abort the whole cell.
try:
    from torchsummary import summary               # optional; pip install torchsummary
except ImportError:
    summary = None

# ------------------------------------------------------------------ config
ckpt_path   = pathlib.Path("convnext_tiny_in12k.pth")   # your local file
DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"
num_classes = len(idx_to_class)                         # from Cell 3
IMG_SIZE    = 224                                       # must match Cell 2

# ------------------------------------------------------------------ 1️⃣ build backbone (NO classifier yet)
model = timm.create_model(
    "convnext_tiny.in12k",
    pretrained=False,           # prevent auto-download
    num_classes=0               # ← this gives us the backbone only
).to(DEVICE)

# ------------------------------------------------------------------ 2️⃣ load checkpoint strictly
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code — only load checkpoints from a trusted source (or pass
# weights_only=True if the installed torch version and the file support it).
state = torch.load(ckpt_path, map_location="cpu")

# Many checkpoints nest the weights under 'model' or 'state_dict'.
# Fallback to the raw dict.
state_dict = state
if isinstance(state, dict):
    for wrapper_key in ("model", "state_dict"):
        if wrapper_key in state:
            state_dict = state[wrapper_key]
            break
# Remove any 'module.' prefix (common in DDP training)
state_dict = {re.sub(r"^module\.", "", k): v for k, v in state_dict.items()}
# The model was built with num_classes=0, so it has no classifier parameters.
# A checkpoint that still carries the pretrained classifier ('head.fc.*')
# would otherwise be rejected as "unexpected keys"; that layer is replaced
# below anyway, so drop it before loading.
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("head.fc.")}

# strict=False + explicit check: same all-keys-must-match guarantee, but the
# error message names the offending keys.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing or unexpected:
    raise RuntimeError(
        "Mismatch when loading weights: "
        f"missing={list(missing)}, unexpected={list(unexpected)}"
    )

print("✅  Backbone loaded strictly; no missing / unexpected keys")

# ------------------------------------------------------------------ 3️⃣ attach a new classifier head
in_features = model.num_features        # ConvNeXt attribute
# In timm's ConvNeXt, `model.head` is the whole head module (global pooling,
# norm, flatten, dropout, fc). Overwriting it with a bare nn.Linear would throw
# away the pooling and the just-loaded head norm and break the forward pass,
# so only the final `fc` layer is rebuilt via the model's own API.
model.reset_classifier(num_classes)
nn.init.trunc_normal_(model.head.fc.weight, std=0.02)   # match timm init
nn.init.zeros_(model.head.fc.bias)

model = model.to(DEVICE)

# Optional: freeze backbone for 1–2 epochs by uncommenting below
# (everything under `head.` — the head norm and the new fc — stays trainable)
# for name, param in model.named_parameters():
#     if not name.startswith("head."):
#         param.requires_grad = False

# ------------------------------------------------------------------ 4️⃣ sanity forward pass
# NOTE: this leaves the model in eval mode; call model.train() before training.
model.eval()
with torch.no_grad():
    imgs, _ = next(iter(train_dl))      # from Cell 3
    logits = model(imgs.to(DEVICE))
    print("Logits shape:", logits.shape)   # expect [batch, 20]
assert logits.shape == (imgs.shape[0], num_classes), (
    f"unexpected logits shape {tuple(logits.shape)}"
)

# ------------------------------------------------------------------ 5️⃣ (optional) layer summary
if summary is None:
    print("torchsummary not installed; skipping layer summary")
else:
    try:
        summary(model, input_size=(3, IMG_SIZE, IMG_SIZE))
    except Exception as exc:
        # Best-effort only: report why the summary failed instead of hiding it.
        print(f"Layer summary skipped: {exc!r}")