# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
import os
import copy

# ---- Configuration ----
data_dir = "/kaggle/working/dataset_split"
num_classes = 8
batch_size = 32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Preprocessing ----
# One pipeline is shared by all three splits: resize to 224x224, convert to a
# tensor, then normalise each channel with the ImageNet mean / std.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ---- Datasets: one ImageFolder per sub-directory of data_dir ----
train_ds = datasets.ImageFolder(
    root=os.path.join(data_dir, "train"),
    transform=transform,
)
val_ds = datasets.ImageFolder(
    root=os.path.join(data_dir, "val"),
    transform=transform,
)
test_ds = datasets.ImageFolder(
    root=os.path.join(data_dir, "test"),
    transform=transform,
)

# ---- Loaders: only the training loader shuffles ----
train_loader = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(dataset=val_ds, batch_size=batch_size)
test_loader = torch.utils.data.DataLoader(dataset=test_ds, batch_size=batch_size)

# Class names as discovered by ImageFolder on the training split.
class_names = train_ds.classes
print("Classes:", class_names)


# ----- Training function with early stopping -----
def train_model(model, name, head_epochs=5, finetune_epochs=20, patience=5, lr_head=1e-3, lr_finetune=1e-5):
    """Train ``model`` in two phases and return it loaded with its best weights.

    Phase A freezes every parameter except the classification head
    (``model.fc`` when that attribute exists, otherwise ``model.classifier``)
    and trains only the head for ``head_epochs`` epochs.  Phase B unfreezes all
    parameters and fine-tunes for up to ``finetune_epochs`` epochs, stopping
    early once ``patience`` consecutive Phase B epochs fail to set a new best
    validation accuracy.

    The best validation accuracy is tracked across *both* phases, so the
    returned/saved weights are never worse (on the validation set) than what
    head-only training already achieved, and a run with ``finetune_epochs=0``
    keeps the trained head instead of reverting to the untrained one.

    Relies on the module-level ``train_loader``, ``val_loader`` and ``device``
    and on ``train_one_epoch``.  The best weights are written to
    ``"{name}_best.pth"`` in the current working directory.

    Args:
        model: Network to train; must expose ``fc`` or ``classifier``.
        name: Label used in log lines and in the checkpoint file name.
        head_epochs: Number of Phase A (head-only) epochs.
        finetune_epochs: Maximum number of Phase B (full fine-tune) epochs.
        patience: Phase B epochs without improvement tolerated before stopping.
        lr_head: AdamW learning rate for Phase A.
        lr_finetune: AdamW learning rate for Phase B.

    Returns:
        The same ``model`` object, loaded with the best-scoring weights.
    """
    criterion = nn.CrossEntropyLoss()

    # Place the model on its final device before any optimizer is built, which
    # is the ordering the torch.optim documentation recommends.
    model.to(device)

    best_acc = 0.0
    best_model_wts = None  # populated by the first evaluated epoch

    def _record_if_best(val_acc):
        """Snapshot the current weights when ``val_acc`` is a new best."""
        nonlocal best_acc, best_model_wts
        if best_model_wts is not None and val_acc <= best_acc:
            return False
        best_acc = val_acc
        best_model_wts = copy.deepcopy(model.state_dict())
        return True

    # ---- Phase A: Train classifier head ----
    for param in model.parameters():
        param.requires_grad = False
    head = model.fc if hasattr(model, 'fc') else model.classifier
    for param in head.parameters():
        param.requires_grad = True

    # Only the head is trainable at this point, so hand exactly those
    # parameters to the optimizer.
    optimizer = optim.AdamW(head.parameters(), lr=lr_head, weight_decay=1e-4)

    print(f"\n[{name}] Phase A: Training head for {head_epochs} epochs")
    for epoch in range(head_epochs):
        _, _, val_acc = train_one_epoch(model, optimizer, criterion, train_loader, val_loader, epoch, head_epochs, name)
        _record_if_best(val_acc)

    # ---- Phase B: Fine-tune backbone ----
    for param in model.parameters():
        param.requires_grad = True

    optimizer = optim.AdamW(model.parameters(), lr=lr_finetune, weight_decay=1e-4)

    no_improve = 0  # the early-stopping counter only runs during Phase B
    print(f"\n[{name}] Phase B: Fine-tuning backbone for {finetune_epochs} epochs")
    for epoch in range(finetune_epochs):
        _, _, val_acc = train_one_epoch(model, optimizer, criterion, train_loader, val_loader, epoch, finetune_epochs, name)

        # Early stopping
        if _record_if_best(val_acc):
            no_improve = 0
            continue
        no_improve += 1
        if no_improve >= patience:
            print(f"[{name}] Early stopping at epoch {epoch+1}")
            break

    # Load best weights and save.  If no epoch ran at all there is no
    # snapshot, and the model simply keeps the weights it came in with.
    if best_model_wts is not None:
        model.load_state_dict(best_model_wts)
    torch.save(model.state_dict(), f"{name}_best.pth")
    print(f"[{name}] Best Val Acc: {best_acc:.2f}%")
    return model


def train_one_epoch(model, optimizer, criterion, train_loader, val_loader, epoch, total_epochs, name):
    """Run one optimisation pass over ``train_loader``, then score ``val_loader``.

    The model is moved to the module-level ``device`` on entry.  A one-line
    summary of the epoch is printed before returning.

    Args:
        model: Network to train and evaluate.
        optimizer: Optimizer stepped once per training batch.
        criterion: Loss called as ``criterion(outputs, labels)``.
        train_loader: Iterable of ``(inputs, labels)`` training batches; must
            support ``len()``.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
        epoch: Zero-based epoch index (printed as ``epoch + 1``).
        total_epochs: Epoch count shown in the log line.
        name: Label prefixed to the log line.

    Returns:
        Tuple ``(mean_batch_loss, train_acc, val_acc)``: the training loss
        averaged over batches, and both accuracies as percentages.
    """
    model.to(device)

    # ---- Training pass ----
    model.train()
    loss_sum = 0.0
    hits = 0
    seen = 0
    for batch, targets in train_loader:
        batch = batch.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(batch)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        guesses = logits.max(1).indices
        seen += targets.size(0)
        hits += (guesses == targets).sum().item()

    train_acc = 100. * hits / seen

    # ---- Validation pass (no gradients, eval-mode layers) ----
    model.eval()
    val_hits = 0
    val_seen = 0
    with torch.no_grad():
        for batch, targets in val_loader:
            batch = batch.to(device)
            targets = targets.to(device)
            guesses = model(batch).max(1).indices
            val_seen += targets.size(0)
            val_hits += (guesses == targets).sum().item()
    val_acc = 100. * val_hits / val_seen

    # Loss is reported as the mean of the per-batch losses.
    mean_loss = loss_sum / len(train_loader)
    print(f"[{name}] Epoch {epoch+1}/{total_epochs} "
          f"Train Loss: {mean_loss:.4f}, "
          f"Train Acc: {train_acc:.2f}%, Val Acc: {val_acc:.2f}%")

    return mean_loss, train_acc, val_acc


# ---- Train Models ----
# Each backbone starts from ImageNet weights; its final linear layer is swapped
# for a fresh one with `num_classes` outputs before two-phase training.

# ResNet-50: the head is the single `fc` layer.
resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
resnet.fc = nn.Linear(
    in_features=resnet.fc.in_features,
    out_features=num_classes,
)
resnet = train_model(model=resnet, name="ResNet50", head_epochs=8, finetune_epochs=25)

# EfficientNet-B0: the linear layer sits at index 1 of `classifier`.
eff = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
eff.classifier[1] = nn.Linear(
    in_features=eff.classifier[1].in_features,
    out_features=num_classes,
)
eff = train_model(model=eff, name="EfficientNet", head_epochs=5, finetune_epochs=20)

# ConvNeXt-Tiny: the linear layer sits at index 2 of `classifier`.
conv = models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1)
conv.classifier[2] = nn.Linear(
    in_features=conv.classifier[2].in_features,
    out_features=num_classes,
)
conv = train_model(model=conv, name="ConvNeXt", head_epochs=10, finetune_epochs=30)


# ---- Evaluate on Test ----
def evaluate_model(model, name, loader=None):
    """Compute, print and return the classification accuracy of ``model``.

    Args:
        model: Network to evaluate; it is moved to the module-level ``device``
            and switched to eval mode.
        name: Label prefixed to the printed result.
        loader: Iterable of ``(inputs, labels)`` batches.  Defaults to the
            module-level ``test_loader`` when omitted.

    Returns:
        Accuracy over all samples yielded by the loader, as a percentage.
    """
    if loader is None:
        loader = test_loader

    # Do not assume the caller already placed the model on the right device
    # (e.g. a checkpoint freshly loaded on the CPU).
    model.to(device)
    model.eval()

    correct, total = 0, 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    acc = 100. * correct / total
    print(f"[{name}] Test Accuracy: {acc:.2f}%")
    # Hand the number back so callers can compare models programmatically.
    return acc


# Report held-out accuracy for each trained network, in training order.
evaluate_model(model=resnet, name="ResNet50")
evaluate_model(model=eff, name="EfficientNet")
evaluate_model(model=conv, name="ConvNeXt")
