In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms, datasets
from torch.utils.data import DataLoader
import wandb
from pytorch_lightning.callbacks import EarlyStopping


_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]


def _build_dataloaders(config, train_dir, val_dir, num_workers):
    """Create the train/val ``ImageFolder`` data loaders.

    Training images get random crop / flip / affine / colour-jitter
    augmentation. Validation images are resized to ``1.15 * img_size`` and
    centre-cropped, with no random ops. Both splits are normalised with the
    ImageNet mean/std the pretrained backbone expects.

    Returns:
        dict mapping ``"train"`` and ``"val"`` to a ``DataLoader``; only the
        training loader shuffles.
    """
    normalize = transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD)
    data_transforms = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(config.img_size),
            transforms.RandomHorizontalFlip(),
            transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            transforms.ToTensor(),
            normalize,
        ]),
        "val": transforms.Compose([
            transforms.Resize(int(config.img_size * 1.15)),
            transforms.CenterCrop(config.img_size),
            transforms.ToTensor(),
            normalize,
        ]),
    }
    roots = {"train": train_dir, "val": val_dir}
    return {
        split: DataLoader(
            datasets.ImageFolder(roots[split], data_transforms[split]),
            batch_size=config.batch_size,
            shuffle=(split == "train"),
            num_workers=num_workers,
        )
        for split in ("train", "val")
    }


def _build_model(num_classes):
    """Return an ImageNet-pretrained EfficientNetV2-S whose final linear
    layer (``classifier[1]``) is replaced by a fresh ``num_classes``-way head."""
    model = models.efficientnet_v2_s(
        weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
    )
    in_features = model.classifier[1].in_features
    model.classifier[1] = nn.Linear(in_features, num_classes)
    return model


def _set_requires_grad(module, flag):
    """Set ``requires_grad`` to *flag* on every parameter of *module*."""
    for param in module.parameters():
        param.requires_grad = flag


def _keep_frozen_batchnorm_in_eval(model):
    """Put BatchNorm layers whose affine parameters are all frozen in eval mode.

    ``model.train()`` switches every BatchNorm layer to training mode. In that
    mode it normalises with batch statistics and keeps updating its running
    mean/variance, regardless of ``requires_grad``. Calling this right after
    ``model.train()`` makes frozen layers really frozen: they keep the
    pretrained statistics. BatchNorm layers without affine parameters are
    left unchanged because their trainability cannot be inferred.
    """
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            params = list(module.parameters(recurse=False))
            if params and not any(p.requires_grad for p in params):
                module.eval()


def _run_epoch(model, dataloaders, criterion, optimizer, device, stage):
    """Run one training pass and one validation pass, then log and print metrics.

    Metrics go to wandb as ``{stage}_train_loss``, ``{stage}_train_acc``,
    ``{stage}_val_loss`` and ``{stage}_val_acc``. Losses are per-sample means
    over the whole split; accuracies are top-1 fractions.

    Returns:
        The validation accuracy of this epoch.
    """
    train_loader, val_loader = dataloaders["train"], dataloaders["val"]

    model.train()
    # Without this, frozen BatchNorm layers would still update running stats.
    _keep_frozen_batchnorm_in_eval(model)
    running_loss, running_corrects = 0.0, 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # criterion returns a batch mean; re-weight by batch size so the
        # epoch value is a true per-sample mean even with a short last batch.
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (outputs.argmax(1) == labels).sum().item()
    epoch_loss = running_loss / len(train_loader.dataset)
    epoch_acc = running_corrects / len(train_loader.dataset)

    model.eval()
    val_loss, val_corrects = 0.0, 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)
            val_corrects += (outputs.argmax(1) == labels).sum().item()
    val_loss /= len(val_loader.dataset)
    val_acc = val_corrects / len(val_loader.dataset)

    wandb.log({
        f"{stage}_train_loss": epoch_loss,
        f"{stage}_train_acc": epoch_acc,
        f"{stage}_val_loss": val_loss,
        f"{stage}_val_acc": val_acc,
    })
    print(f"{stage} → TrainLoss: {epoch_loss:.4f} TrainAcc: {epoch_acc:.4f} | "
          f"ValLoss: {val_loss:.4f} ValAcc: {val_acc:.4f}")
    return val_acc


def _fine_tune_in_stages(model, dataloaders, criterion, device, config):
    """Run the three progressive-unfreezing stages, each with a fresh Adam optimizer."""
    # Stage 1: freeze everything except the classifier head.
    _set_requires_grad(model, False)
    _set_requires_grad(model.classifier, True)
    optimizer = optim.Adam(model.classifier.parameters(), lr=config.lr_stage1)
    print(">>> Stage 1: Training classifier only")
    for _ in range(config.epochs_stage1):
        _run_epoch(model, dataloaders, criterion, optimizer, device, "stage1")

    # Stage 2: additionally unfreeze the last module of ``model.features``.
    _set_requires_grad(model.features[-1], True)
    optimizer = optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=config.lr_stage2,
    )
    print("\n>>> Stage 2: Unfreezing last block + classifier")
    for _ in range(config.epochs_stage2):
        _run_epoch(model, dataloaders, criterion, optimizer, device, "stage2")

    # Stage 3: fine-tune the entire network.
    _set_requires_grad(model, True)
    optimizer = optim.Adam(model.parameters(), lr=config.lr_stage3)
    print("\n>>> Stage 3: Fine‑tuning entire network")
    for _ in range(config.epochs_stage3):
        _run_epoch(model, dataloaders, criterion, optimizer, device, "stage3")


def train(
    train_dir="/home/user/kartikey_phd/DA6401/nature_12K/inaturalist_12K/train",
    val_dir="/home/user/kartikey_phd/DA6401/nature_12K/inaturalist_12K/val",
    num_workers=4,
):
    """Fine-tune an ImageNet-pretrained EfficientNetV2-S on iNaturalist-12K.

    Training proceeds in three stages (classifier only, then + last feature
    block, then the whole network). Hyper-parameters are read from
    ``wandb.config``, which is initialised with the defaults below. Per-epoch
    metrics are logged to wandb and printed.

    Args:
        train_dir: ``ImageFolder`` root of the training split.
        val_dir: ``ImageFolder`` root of the validation split.
        num_workers: worker processes per ``DataLoader``.
    """
    wandb.init(project="iNaturalist_EffNetV2S_finetune_2", config={
        "architecture": "EfficientNetV2-S",
        "dataset": "iNaturalist12K",
        "num_classes": 10,
        "batch_size": 32,
        "epochs_stage1": 30,
        "epochs_stage2": 30,
        "epochs_stage3": 50,
        "lr_stage1": 1e-3,
        "lr_stage2": 1e-4,
        "lr_stage3": 1e-5,
        "img_size": 224,
    })
    try:
        config = wandb.config
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        dataloaders = _build_dataloaders(config, train_dir, val_dir, num_workers)
        model = _build_model(config.num_classes).to(device)
        criterion = nn.CrossEntropyLoss()
        _fine_tune_in_stages(model, dataloaders, criterion, device, config)
    except BaseException:
        # Close the run as failed (also on KeyboardInterrupt) so a crash does
        # not leave it open, then propagate the original error.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()

# Script entry point: start the full three-stage fine-tuning run when this
# file is executed directly rather than imported.
if __name__ == "__main__":
    train()

>>> Stage 1: Training classifier only
stage1 → TrainLoss: 1.2273 TrainAcc: 0.6309 | ValLoss: 0.7512 ValAcc: 0.7835
stage1 → TrainLoss: 0.9858 TrainAcc: 0.6855 | ValLoss: 0.7049 ValAcc: 0.7845
stage1 → TrainLoss: 0.9567 TrainAcc: 0.6929 | ValLoss: 0.6758 ValAcc: 0.7980
stage1 → TrainLoss: 0.9553 TrainAcc: 0.6930 | ValLoss: 0.6618 ValAcc: 0.7885
stage1 → TrainLoss: 0.9473 TrainAcc: 0.6971 | ValLoss: 0.6557 ValAcc: 0.7930
stage1 → TrainLoss: 0.9339 TrainAcc: 0.6982 | ValLoss: 0.6637 ValAcc: 0.7985
stage1 → TrainLoss: 0.9406 TrainAcc: 0.6979 | ValLoss: 0.6386 ValAcc: 0.8015
stage1 → TrainLoss: 0.9230 TrainAcc: 0.7051 | ValLoss: 0.6273 ValAcc: 0.8045
stage1 → TrainLoss: 0.9405 TrainAcc: 0.7000 | ValLoss: 0.6436 ValAcc: 0.8050
stage1 → TrainLoss: 0.9303 TrainAcc: 0.7047 | ValLoss: 0.6237 ValAcc: 0.8045
stage1 → TrainLoss: 0.9409 TrainAcc: 0.7008 | ValLoss: 0.6314 ValAcc: 0.8035
stage1 → TrainLoss: 0.9297 TrainAcc: 0.7011 | ValLoss: 0.6321 ValAcc: 0.8070
stage1 → TrainLoss: 0.9197 TrainAcc: 0