In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from pathlib import Path

In [2]:
IMG_SIZE = 224

# Per-channel RGB statistics used to normalise every split
# (the widely used ImageNet mean / std values).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def _tensor_and_normalize():
    """Return the final steps shared by every split: PIL -> tensor, then normalise."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]


# Training split: resize, then photometric + geometric augmentation.
train_transforms = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(degrees=10),
        *_tensor_and_normalize(),
    ]
)

# Validation / test splits: deterministic resize only, no augmentation.
val_test_transforms = transforms.Compose(
    [
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        *_tensor_and_normalize(),
    ]
)

In [None]:
root_dir    = Path("dataset")
batch_size  = 64
num_workers = 4
SEED        = 42
PIN_MEMORY  = torch.cuda.is_available()   # pinned host memory only helps GPU transfers

# Seed the global torch RNG so shuffle order and random augmentations are
# repeatable across "Restart Kernel -> Run All".
torch.manual_seed(SEED)

# Datasets
train_ds = datasets.ImageFolder(root_dir / "train_extracted",
                                transform=train_transforms)
val_ds   = datasets.ImageFolder(root_dir / "val_extracted",
                                transform=val_test_transforms)
# test_ds  = datasets.ImageFolder(root_dir / "test_extracted",
#                                 transform=val_test_transforms)

# ImageFolder derives label ids from each split's own sorted class folders, so a
# folder missing from one split would silently shift every later label id.
assert val_ds.class_to_idx == train_ds.class_to_idx, \
    "train/val class folders differ — label ids would not line up"

# DataLoaders
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                      num_workers=num_workers, pin_memory=PIN_MEMORY, drop_last=True)
val_dl   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False,
                      num_workers=num_workers, pin_memory=PIN_MEMORY)
# test_dl  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False,
#                       num_workers=num_workers, pin_memory=PIN_MEMORY)

# Quick sanity check
idx_to_class = {v: k for k, v in train_ds.class_to_idx.items()}
print(f"{len(idx_to_class)} classes detected:", idx_to_class)

imgs, labels = next(iter(train_dl))
print("Batch tensor shape:", imgs.shape)
print("Labels shape:", labels.shape)
# Every sample passes through Resize((IMG_SIZE, IMG_SIZE)) + ToTensor in the transforms cell.
assert tuple(imgs.shape[-2:]) == (IMG_SIZE, IMG_SIZE), imgs.shape

📚  20 classes detected: {0: '00175_Animalia_Arthropoda_Insecta_Blattodea_Blaberidae_Aptera_fusca', 1: '00176_Animalia_Arthropoda_Insecta_Blattodea_Blaberidae_Panchlora_nivea', 2: '00177_Animalia_Arthropoda_Insecta_Blattodea_Blaberidae_Pycnoscelus_surinamensis', 3: '00178_Animalia_Arthropoda_Insecta_Blattodea_Blattidae_Blatta_orientalis', 4: '00179_Animalia_Arthropoda_Insecta_Blattodea_Blattidae_Periplaneta_americana', 5: '00180_Animalia_Arthropoda_Insecta_Blattodea_Blattidae_Periplaneta_australasiae', 6: '00181_Animalia_Arthropoda_Insecta_Blattodea_Blattidae_Periplaneta_fuliginosa', 7: '00182_Animalia_Arthropoda_Insecta_Blattodea_Ectobiidae_Pseudomops_septentrionalis', 8: '00443_Animalia_Arthropoda_Insecta_Diptera_Culicidae_Aedes_aegypti', 9: '00444_Animalia_Arthropoda_Insecta_Diptera_Culicidae_Aedes_albopictus', 10: '00445_Animalia_Arthropoda_Insecta_Diptera_Culicidae_Aedes_vexans', 11: '00446_Animalia_Arthropoda_Insecta_Diptera_Culicidae_Culex_quinquefasciatus', 12: '00447_Animalia_A

In [None]:
# 📒  Cell 4 — ConvNeXt-Tiny with local IN-12k weights (classifier replaced)
#
# Builds ConvNeXt-Tiny with a fresh `num_classes`-way classifier, copies in every
# checkpoint tensor whose name AND shape match the new model, then sanity-checks
# the result with a single forward pass.

import torch, timm, pathlib, re, torch.nn as nn
try:                                             # optional: pip install torchsummary
    from torchsummary import summary
except ImportError:
    summary = None

# --------------------------------------------------------------- paths & config
ckpt_path   = pathlib.Path("convnext_tiny_in12k.pth")   # local file you downloaded
DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"
num_classes = len(idx_to_class)                         # from Cell 3
# IMG_SIZE comes from the transforms cell, so the input size is defined only once.

# --------------------------------------------------------------- 1️⃣  build full model (new head)
model = timm.create_model(
    "convnext_tiny.in12k",
    pretrained=False,          # ← don't hit the Internet
    num_classes=num_classes,
).to(DEVICE)

# --------------------------------------------------------------- 2️⃣  load backbone weights
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
raw = torch.load(ckpt_path, map_location="cpu")
# Accept either a bare state_dict or a training checkpoint that wraps one.
state_dict = raw
if isinstance(raw, dict):
    for wrapper_key in ("model", "state_dict"):
        if wrapper_key in raw:
            state_dict = raw[wrapper_key]
            break

# Strip the DDP 'module.' prefix FIRST, so the filter below sees the real names
# (filtering on the raw key let 'module.head.*' slip through).
stripped_sd = {re.sub(r"^module\.", "", k): v for k, v in state_dict.items()}

# Drop only tensors whose shape disagrees with the new model (expected: the old
# 11 821-class classifier). Same-shape `head.*` tensors (presumably the head's
# norm layer) keep their pretrained values instead of being re-initialised.
model_sd = model.state_dict()
dropped = {
    k for k, v in stripped_sd.items()
    if k in model_sd and getattr(v, "shape", None) != model_sd[k].shape
}
clean_sd = {k: v for k, v in stripped_sd.items() if k not in dropped}

missing, unexpected = model.load_state_dict(clean_sd, strict=False)
print(f"✅  Checkpoint loaded — dropped {len(dropped)} shape-mismatched tensors "
      f"• missing (fresh init): {len(missing)} • unexpected (ignored): {len(unexpected)}")
print("   freshly initialised:", sorted(missing))
if unexpected:
    print("   ⚠️  unexpected keys (first 10):", sorted(unexpected)[:10])

# A missing *backbone* tensor means the checkpoint's key names don't match this
# architecture; carrying on would silently train from random weights.
backbone_missing = [k for k in missing if not k.startswith("head.")]
if backbone_missing:
    raise RuntimeError(
        f"{len(backbone_missing)} backbone tensors not found in {ckpt_path}, "
        f"e.g. {backbone_missing[:5]}"
    )

# --------------------------------------------------------------- 3️⃣  sanity forward pass
model.eval()
with torch.no_grad():
    imgs, _ = next(iter(train_dl))        # from Cell 3
    logits = model(imgs.to(DEVICE))
print("Logits shape:", logits.shape)      # expect [batch, num_classes]
assert logits.shape == (imgs.size(0), num_classes), "unexpected logits shape"

# --------------------------------------------------------------- 4️⃣  (optional) layer table
if summary is None:
    print("(torchsummary not installed — skipping layer table)")
else:
    try:
        # torchsummary defaults to device="cuda"; pass the real device so CPU runs work.
        summary(model, input_size=(3, IMG_SIZE, IMG_SIZE), device=DEVICE)
    except Exception as exc:   # purely informational — report, don't abort the notebook
        print(f"(layer table skipped: {type(exc).__name__}: {exc})")

In [None]:
# 📒  Cell 5 — Train / validate ConvNeXt-Tiny
#
# Phase 1: train only the classifier head for FREEZE_BACKBONE_EPOCH epochs.
# Phase 2: unfreeze everything and continue with a lower LR and a fresh cosine
# schedule over the remaining epochs. NOTE: re-running this cell continues from
# the current `model` weights; re-run Cell 4 first for a clean start.
import torch, time, math
from torch.cuda.amp import autocast, GradScaler

# ---------------------- hyper-parameters -----------------------------
EPOCHS                = 10
FREEZE_BACKBONE_EPOCH = 1          # head-only epochs before unfreezing (0 = never freeze)
LR_HEAD               = 1e-3       # while backbone frozen
LR_FULL               = 3e-4       # after unfreeze
WEIGHT_DECAY          = 1e-2
USE_AMP               = DEVICE == "cuda"   # fp16 autocast + loss scaling only on GPU

criterion = torch.nn.CrossEntropyLoss()
scaler    = GradScaler(enabled=USE_AMP)    # disabled scaler == plain backward/step


# ---------------------- helpers ---------------------------------------
def set_backbone_trainable(flag: bool) -> None:
    """Set requires_grad=flag on every parameter whose name does not start with 'head.'."""
    for name, param in model.named_parameters():
        if not name.startswith("head."):
            param.requires_grad = flag


def make_optimizer(params, lr: float, t_max: int):
    """Return (AdamW, CosineAnnealingLR) with a per-epoch schedule over t_max (>= 1) epochs."""
    opt = torch.optim.AdamW(params, lr=lr, weight_decay=WEIGHT_DECAY)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1, t_max))
    return opt, sched


def run_epoch(loader, optimizer=None):
    """One pass over `loader`: trains when `optimizer` is given, otherwise evaluates.

    Returns (mean per-sample loss, accuracy).
    Raises RuntimeError if the loader yields no batches.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images = images.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)
            if training:
                optimizer.zero_grad()
            with autocast(enabled=USE_AMP):          # mixed precision (GPU only)
                logits = model(images)
                loss   = criterion(logits, labels)
            if training:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            n = labels.size(0)
            total_loss += loss.item() * n
            correct    += (logits.argmax(1) == labels).sum().item()
            seen       += n
    if seen == 0:
        raise RuntimeError("DataLoader yielded no batches "
                           "(dataset smaller than batch_size with drop_last=True?)")
    return total_loss / seen, correct / seen


# ---------------------- initial (possibly frozen) phase ---------------
if FREEZE_BACKBONE_EPOCH > 0:
    set_backbone_trainable(False)
    head_params = [p for n, p in model.named_parameters() if n.startswith("head.")]
    optimizer, scheduler = make_optimizer(head_params, LR_HEAD, EPOCHS)
else:
    set_backbone_trainable(True)
    optimizer, scheduler = make_optimizer(model.parameters(), LR_FULL, EPOCHS)

# ---------------------- training loop --------------------------------
history = []
for epoch in range(1, EPOCHS + 1):
    t0 = time.time()
    lr = optimizer.param_groups[0]["lr"]     # LR actually used during this epoch
    train_loss, train_acc = run_epoch(train_dl, optimizer)
    val_loss,   val_acc   = run_epoch(val_dl)
    scheduler.step()

    history.append(dict(epoch=epoch, lr=lr,
                        train_loss=train_loss, train_acc=train_acc,
                        val_loss=val_loss, val_acc=val_acc))
    print(f"[{epoch:02}/{EPOCHS}] "
          f"train {train_loss:.4f} / {train_acc:.2%} │ "
          f"val {val_loss:.4f} / {val_acc:.2%} │ "
          f"lr {lr:.2e} │ "
          f"{(time.time()-t0):.1f}s")

    # --------- unfreeze backbone after the head-only phase ----------
    if epoch == FREEZE_BACKBONE_EPOCH and epoch < EPOCHS:
        print("🟢 Unfreezing backbone & switching to lower LR.")
        set_backbone_trainable(True)
        # re-build optimizer to include *all* parameters; cosine over remaining epochs
        optimizer, scheduler = make_optimizer(model.parameters(), LR_FULL, EPOCHS - epoch)

print("Training complete.")