In [None]:
# Install/upgrade torch, torchvision and torchaudio; the PyTorch CUDA 11.8 wheel
# index is searched in addition to the default package index.
# NOTE(review): `--upgrade` with no version pins means the installed versions can
# differ from run to run — pin exact versions if results must be reproducible.
!pip install --quiet --upgrade torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118

In [None]:
# Mount Google Drive at /content/drive. The checkpoint directory configured in
# CFG["ckpt_dir"] lives below this mount point, so this cell must run first.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import math, os, time, copy, random
import gc
from pathlib import Path
from typing import Tuple

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets, models
from torch.utils.data import random_split, DataLoader
from tqdm.auto import tqdm
from contextlib import nullcontext

# Run on the GPU when CUDA is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Seed Python's and PyTorch's (CPU + every CUDA device) random number generators.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [None]:
# Single place for every knob used by the cells below.
CFG = {
    # Optimisation schedule
    "num_epochs": 100,
    "batch_size": 128,
    "lr": 4e-3,
    "weight_decay": 0.05,
    "warmup_epochs": 5,
    # Data loading
    "num_workers": 0,
    "image_size": 224,        # ConvNeXt default input resolution
    "val_split_ratio": 0.1,   # fraction of the training set held out for validation
    # Training options
    "amp": True,              # Automatic Mixed Precision
    "pretrained": False,      # True: start from ImageNet weights; False: random init
    # Where checkpoints are written (on the mounted Google Drive)
    "ckpt_dir": "/content/drive/MyDrive/11785/project/cifar10-baseline-checkpoint",
}
# Create the checkpoint directory. `parents=True` also creates any missing
# intermediate folders, so a fresh Drive without the project folders does not
# raise FileNotFoundError; `exist_ok=True` keeps the cell safe to re-run.
Path(CFG["ckpt_dir"]).mkdir(parents=True, exist_ok=True)

In [None]:
# Transforms -----------------------------------------------------------
# Per-channel normalisation statistics shared by the train and eval pipelines.
NORM_MEAN = (0.4915, 0.4823, 0.4468)
NORM_STD = (0.2470, 0.2435, 0.2616)

# Training pipeline: random crop resized to the model resolution, random flip and
# the AutoAugment CIFAR10 policy, then tensor conversion + normalisation.
train_tfms = transforms.Compose(
    [
        transforms.RandomResizedCrop(CFG["image_size"]),
        transforms.RandomHorizontalFlip(),
        transforms.AutoAugment(policy=transforms.AutoAugmentPolicy.CIFAR10),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)

# Evaluation pipeline: deterministic resize (32 px larger than the target) followed
# by a centre crop to the model resolution, then tensor conversion + normalisation.
val_tfms = transforms.Compose(
    [
        transforms.Resize(CFG["image_size"] + 32),
        transforms.CenterCrop(CFG["image_size"]),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)

# Datasets -------------------------------------------------------------
root = "./data"

# Two views of the same CIFAR-10 training images: one with the training
# augmentations and one with the deterministic evaluation transforms.
full_train = datasets.CIFAR10(root, train=True,  download=True, transform=train_tfms)
full_train_eval = datasets.CIFAR10(root, train=True, download=False, transform=val_tfms)

val_len      = int(len(full_train) * CFG["val_split_ratio"])
train_len    = len(full_train) - val_len

# Seeded split so the train/val partition is the same on every run.
train_ds, val_split = random_split(full_train, [train_len, val_len],
                                   generator=torch.Generator().manual_seed(SEED))

# Re-use the held-out indices on the eval-transform view. Taking the validation
# subset straight from `random_split` would apply the random training
# augmentations (crop / flip / AutoAugment) to validation images, making the
# validation metrics noisy and not comparable with the test metrics.
val_ds = torch.utils.data.Subset(full_train_eval, val_split.indices)

test_ds = datasets.CIFAR10(root, train=False, download=True, transform=val_tfms)

# Dataloaders ----------------------------------------------------------
# Settings shared by all three loaders; only the training loader shuffles.
loader_kwargs = dict(
    batch_size=CFG["batch_size"],
    num_workers=CFG["num_workers"],
    pin_memory=True,
)

train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

In [None]:
import math
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

# Initial weights: ImageNet-1k checkpoint when CFG["pretrained"] is set, else random init.
weights = models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1 if CFG["pretrained"] else None

model = models.convnext_tiny(weights=weights)

# Swap the final linear layer of the classifier for a 10-way output layer.
in_features = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(in_features, 10)
model = model.to(device)

# Loss and optimiser; learning rate and weight decay come from CFG.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CFG["lr"],
    weight_decay=CFG["weight_decay"],
)


def build_warmup_cosine_scheduler(optimizer, steps_per_epoch, num_epochs,
                                  warmup_epochs=1, eta_min=1e-5, accum_steps=1):
    """Build a per-step LR schedule: linear warm-up followed by cosine annealing.

    The scheduler is meant to be stepped once per optimizer step.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        steps_per_epoch: Number of batches per epoch.
        num_epochs: Total number of training epochs.
        warmup_epochs: Epochs of linear warm-up (may be 0 or fractional).
        eta_min: Final learning rate of the cosine phase.
        accum_steps: Gradient-accumulation factor; optimizer steps per epoch are
            ``ceil(steps_per_epoch / accum_steps)``.

    Returns:
        A ``SequentialLR`` (warm-up then cosine) when there is a warm-up phase,
        otherwise the bare ``CosineAnnealingLR``.
    """
    steps_per_epoch = math.ceil(steps_per_epoch / max(1, accum_steps))
    total_steps = num_epochs * steps_per_epoch
    # Schedulers count whole steps; rounding is a no-op for integer warm-up epochs.
    warmup_steps = int(round(warmup_epochs * steps_per_epoch))
    cosine_steps = max(1, total_steps - warmup_steps)

    if warmup_steps <= 0:
        # SequentialLR needs exactly len(schedulers) - 1 milestones, so a single
        # scheduler with a dummy milestone raises ValueError. Without warm-up the
        # cosine schedule is simply used on its own.
        return CosineAnnealingLR(optimizer, T_max=cosine_steps, eta_min=eta_min)

    # start_factor cannot be 0 in some versions; use a tiny epsilon instead.
    warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=warmup_steps)
    cosine = CosineAnnealingLR(optimizer, T_max=cosine_steps, eta_min=eta_min)

    return SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])

# LR schedule stepped once per training batch (see run_epoch).
scheduler = build_warmup_cosine_scheduler(
    optimizer,
    steps_per_epoch=len(train_loader),
    num_epochs=CFG["num_epochs"],
    warmup_epochs=CFG["warmup_epochs"]
)

# Loss scaler for mixed precision. It is enabled under the same condition as the
# autocast context in run_epoch (AMP requested AND CUDA present), so a CPU-only
# run gets a cleanly disabled scaler instead of relying on a runtime fallback.
scaler     = torch.cuda.amp.GradScaler(enabled=CFG["amp"] and torch.cuda.is_available())

In [None]:
def accuracy(preds, targets, topk=(1,)):
    """Compute top-k accuracy in percent for each k in ``topk``.

    Args:
        preds: 2-D tensor of per-class scores, one row per sample.
        targets: 1-D tensor of ground-truth class indices, one per sample.
        topk: Iterable of k values to evaluate.

    Returns:
        List of floats (0-100), one per entry of ``topk``: the percentage of
        samples whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        _, pred = preds.topk(maxk, dim=1, largest=True, sorted=True)
        pred   = pred.t()                                   # (maxk, batch)
        correct= pred.eq(targets.view(1, -1).expand_as(pred))
        # A sample counts as a hit when ANY of its first k predictions matches.
        # Reduce over the k rows first and average over the batch only; averaging
        # over all k * batch entries would understate top-k accuracy by a factor k.
        return [correct[:k].any(dim=0).float().mean().item()*100. for k in topk]


def run_epoch(loader, model, optimizer=None, epoch:int=0, phase:str="train"):
    """Run one pass over ``loader`` and return ``(mean_loss, mean_top1_accuracy)``.

    If `optimizer` is given → training mode, otherwise evaluation mode.
    Memory-safe: no graph is kept when we don't need gradients.

    Relies on the notebook globals ``device``, ``criterion``, ``scaler``,
    ``scheduler``, ``CFG`` and ``accuracy``.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches.
        model: Network to train or evaluate.
        optimizer: Optimizer to step; ``None`` selects evaluation mode.
        epoch: Epoch number, used only for the progress-bar label.
        phase: Label shown on the progress bar.

    Returns:
        Tuple of the loss and top-1 accuracy (percent), each averaged over all
        samples seen. ``(0.0, 0.0)`` if the loader yields nothing.
    """
    train = optimizer is not None
    model.train(train)

    # Sample-weighted sums, so a smaller final batch does not skew the averages.
    loss_sum, acc_sum, n_seen = 0.0, 0.0, 0

    bar = tqdm(loader, desc=f"{phase.title():>5} | Epoch {epoch:02}", leave=False)

    # Choose the right context managers
    grad_ctx = nullcontext() if train else torch.no_grad()
    amp_ctx  = torch.amp.autocast(device_type="cuda",
                                  dtype=torch.float16,
                                  enabled=CFG["amp"] and torch.cuda.is_available())

    if train:
        # Drop gradients left over from an interrupted cell run so they are not
        # accumulated into the first step of this epoch.
        optimizer.zero_grad()

    with grad_ctx:
        for images, labels in bar:
            images, labels = images.to(device, non_blocking=True), labels.to(device)

            with amp_ctx:
                outputs = model(images)
                loss    = criterion(outputs, labels)

            if train:
                scaler.scale(loss).backward()
                scaler.step(optimizer); scaler.update()
                optimizer.zero_grad()
                scheduler.step()         # schedule advances once per batch

            batch_n    = labels.size(0)
            batch_loss = loss.item()     # read once: each .item() forces a device sync
            loss_sum  += batch_loss * batch_n
            acc_sum   += accuracy(outputs, labels)[0] * batch_n
            n_seen    += batch_n
            bar.set_postfix(loss=f"{batch_loss:.4f}")

    torch.cuda.empty_cache()     # free any leftover cached blocks
    denom = max(1, n_seen)       # avoid ZeroDivisionError on an empty loader
    return loss_sum/denom, acc_sum/denom

In [None]:
# Run this cell when you face a CUDA OOM error: first let Python's garbage
# collector run, then ask PyTorch to release its cached, unused GPU memory blocks.
gc.collect()
torch.cuda.empty_cache()

In [None]:
def save_model(model, optimizer, scheduler, metrics, epoch, path):
    """Write a training checkpoint to ``path``.

    The checkpoint is a dict with the keys ``model_state_dict``,
    ``optimizer_state_dict``, ``scheduler_state_dict`` (``''`` when no scheduler
    is given), ``metric`` and ``epoch``.

    The file is first written next to ``path`` with a ``.tmp`` suffix and then
    moved into place, so an interruption while writing never leaves a truncated
    file under ``path`` (which would destroy the previous good checkpoint).

    Args:
        model: Module whose ``state_dict`` is saved.
        optimizer: Optimizer whose ``state_dict`` is saved.
        scheduler: LR scheduler to save, or ``None``.
        metrics: Arbitrary picklable metrics object stored under ``metric``.
        epoch: Epoch number stored under ``epoch``.
        path: Destination file path.
    """
    checkpoint = {
        'model_state_dict'         : model.state_dict(),
        'optimizer_state_dict'     : optimizer.state_dict(),
        'scheduler_state_dict'     : scheduler.state_dict() if scheduler is not None else '',
        'metric'                   : metrics,
        'epoch'                    : epoch,
    }
    tmp_path = f"{path}.tmp"
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, path)   # replaces any existing file at `path`


def load_model(model, optimizer=None, scheduler=None, path=f"{CFG['ckpt_dir']}/current_epoch.pth"):
    """Restore training state from a checkpoint written by ``save_model``.

    Args:
        model: Module that receives ``model_state_dict`` (updated in place).
        optimizer: Optional optimizer that receives ``optimizer_state_dict``.
        scheduler: Optional LR scheduler that receives ``scheduler_state_dict``.
            Skipped when the checkpoint was saved without a scheduler.
        path: Checkpoint file; defaults to the most recent epoch checkpoint.

    Returns:
        Tuple ``(model, optimizer, scheduler, epoch, metrics)``; ``optimizer`` and
        ``scheduler`` are returned exactly as passed in (``None`` stays ``None``).
    """
    # Load tensors onto the CPU first: the checkpoint then also opens on a machine
    # without the GPU it was saved from, and load_state_dict copies the values
    # into the already-placed parameters/optimizer state afterwards.
    # NOTE(security): weights_only=False unpickles arbitrary Python objects —
    # only load checkpoints you created yourself.
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)

    model.load_state_dict(checkpoint['model_state_dict'])

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    # save_model stores '' when no scheduler was given; there is nothing to restore then.
    scheduler_state = checkpoint.get('scheduler_state_dict')
    if scheduler is not None and scheduler_state:
        scheduler.load_state_dict(scheduler_state)

    epoch = checkpoint['epoch']
    metrics = checkpoint['metric']
    return model, optimizer, scheduler, epoch, metrics

In [None]:
# Early-stopping bookkeeping: stop once validation accuracy has not improved
# (>= counts as an improvement) for `patience` consecutive epochs.
best_val_acc = 0.0
patience = 15
epoches_no_improve = 0

# Per-epoch curves for later plotting.
history = {"train_loss": [], "train_acc": [],
           "val_loss": [],   "val_acc": []}

for epoch in range(1, CFG["num_epochs"]+1):
    t0 = time.time()

    tr_loss, tr_acc = run_epoch(train_loader, model, optimizer, epoch, "train")
    val_loss, val_acc= run_epoch(val_loader,   model, None,     epoch, "val")

    history["train_loss"].append(tr_loss); history["train_acc"].append(tr_acc)
    history["val_loss"].append(val_loss);   history["val_acc"].append(val_acc)

    # Metrics of THIS epoch. Built every epoch so that current_epoch.pth never
    # carries the stale metrics of an earlier (best) epoch, and with "val_acc"
    # holding the accuracy rather than the loss.
    metrics = {
        "train_loss": tr_loss,
        "train_acc": tr_acc,
        "val_loss": val_loss,
        "val_acc": val_acc,
    }

    if val_acc >= best_val_acc:
        epoches_no_improve = 0
        best_val_acc = val_acc
        save_model(model, optimizer, scheduler, metrics, epoch, f"{CFG['ckpt_dir']}/best_convnext_tiny.pth")

    else:
        epoches_no_improve += 1

    # Always refresh the resume checkpoint, whether or not this epoch was the best.
    save_model(model, optimizer, scheduler, metrics, epoch, f"{CFG['ckpt_dir']}/current_epoch.pth")


    print(f"Epoch {epoch:02}/{CFG['num_epochs']} "
          f"| train loss {tr_loss:.4f} acc {tr_acc:.2f}% "
          f"| val loss {val_loss:.4f} acc {val_acc:.2f}% "
          f"| lr {scheduler.get_last_lr()[0]:.2e} "
          f"| time {(time.time()-t0):.1f}s")

    if epoches_no_improve >= patience:
        print("Early stopping")
        break

In [None]:
# Restore model, optimizer and scheduler state from the most recent epoch checkpoint.
resume_ckpt_path = f"{CFG['ckpt_dir']}/current_epoch.pth"
model, optimizer, scheduler, epoch, metrics = load_model(
    model, optimizer=optimizer, scheduler=scheduler, path=resume_ckpt_path)

# Display what was restored.
optimizer, scheduler, epoch, metrics

In [None]:
# Resume training after loading a checkpoint with load_model (previous cell),
# which sets `epoch` to the last completed epoch.
# NOTE(review): best_val_acc restarts at 0.0, so the first resumed epoch always
# overwrites best_convnext_tiny.pth even if it is worse than the stored best —
# the checkpoint does not record the best accuracy, so it cannot be restored here.
best_val_acc = 0.0
patience = 15
epoches_no_improve = 0

history = {"train_loss": [], "train_acc": [],
           "val_loss": [],   "val_acc": []}

# Continue with the epoch after the one stored in the checkpoint instead of a
# hard-coded number that silently goes stale with every new checkpoint.
start_epoch = epoch + 1

for epoch in range(start_epoch, CFG["num_epochs"]+1):
    t0 = time.time()

    tr_loss, tr_acc = run_epoch(train_loader, model, optimizer, epoch, "train")
    val_loss, val_acc= run_epoch(val_loader,   model, None,     epoch, "val")

    history["train_loss"].append(tr_loss); history["train_acc"].append(tr_acc)
    history["val_loss"].append(val_loss);   history["val_acc"].append(val_acc)

    # Metrics of THIS epoch. Built every epoch so that current_epoch.pth never
    # carries stale metrics, and with "val_acc" holding the accuracy rather than
    # the loss.
    metrics = {
        "train_loss": tr_loss,
        "train_acc": tr_acc,
        "val_loss": val_loss,
        "val_acc": val_acc,
    }

    if val_acc >= best_val_acc:
        epoches_no_improve = 0
        best_val_acc = val_acc
        save_model(model, optimizer, scheduler, metrics, epoch, f"{CFG['ckpt_dir']}/best_convnext_tiny.pth")

    else:
        epoches_no_improve += 1

    # Always refresh the resume checkpoint, whether or not this epoch was the best.
    save_model(model, optimizer, scheduler, metrics, epoch, f"{CFG['ckpt_dir']}/current_epoch.pth")


    print(f"Epoch {epoch:02}/{CFG['num_epochs']} "
          f"| train loss {tr_loss:.4f} acc {tr_acc:.2f}% "
          f"| val loss {val_loss:.4f} acc {val_acc:.2f}% "
          f"| lr {scheduler.get_last_lr()[0]:.2e} "
          f"| time {(time.time()-t0):.1f}s")

    if epoches_no_improve >= patience:
        print("Early stopping")
        break

In [None]:
# Load best weights: evaluate the checkpoint saved at the best validation accuracy
# rather than whatever weights the final training epoch left in `model` (after
# early stopping those are `patience` epochs past the best one).
# Only the model is restored; optimizer and scheduler are left untouched.
best_epoch = load_model(model, path=f"{CFG['ckpt_dir']}/best_convnext_tiny.pth")[3]
print(f"Evaluating best checkpoint (epoch {best_epoch})")

# optimizer=None selects evaluation mode; `phase` only labels the progress bar.
test_loss, test_acc = run_epoch(test_loader, model, None, phase="test")
print(f"Test  - loss: {test_loss:.4f} - accuracy: {test_acc:.2f}%")

In [None]:
import matplotlib.pyplot as plt
