## Import neccessary pachages

In [10]:
import sys
sys.path.append("..")

from Utils.TinyImageNet_loader_new import get_tinyimagenet_dataloaders
import os
import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode
from torch.utils.data import DataLoader
from timm.data import Mixup
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.scheduler import CosineLRScheduler

## Define Hyperparameters

In [11]:
# General training settings
NUM_EPOCHS = 30
BATCH_SIZE = 128
IMAGE_SIZE = 384
NUM_CLASSES = 200  # Tiny ImageNet has 200 classes

# Optimizer settings
INIT_LR = 1e-3
WEIGHT_DECAY = 0.05

# Data augmentation settings
MIXUP_ALPHA = 0.8
CUTMIX_ALPHA = 1.0
RANDOM_ERASE_PROB = 0.25
LABEL_SMOOTHING = 0.1
STOCHASTIC_DEPTH = 0.1  # drop_path_rate


## Load Tiny Image Net

In [12]:
tiny_transform_train = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    transforms.RandomErasing(p=RANDOM_ERASE_PROB)
])
tiny_transform_val = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])


tiny_transform_test = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
])



train_loader, val_loader, test_loader = get_tinyimagenet_dataloaders(
    data_dir='../datasets',
    transform_train=tiny_transform_train,
    transform_val=tiny_transform_val,
    transform_test=tiny_transform_test,
    batch_size=BATCH_SIZE,
    image_size=IMAGE_SIZE
)


## Define Model

In [None]:
# Create model
model = timm.create_model(
    'swin_large_patch4_window12_384',
    pretrained=False,
    num_classes=NUM_CLASSES,
    drop_path_rate=STOCHASTIC_DEPTH
)



# Total parameters in the model
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")

print("_____________________________________________________________")

# Total trainable parameters in the model
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable_params:,}")


model = model.cuda()

# Define loss: cross-entropy with label smoothing
# Note: If using TIMM's mixup, label smoothing can also be handled inside Mixup,
# but if we want it explicitly here:
criterion = nn.CrossEntropyLoss()

# Define optimizer and LR scheduler
optimizer = optim.AdamW(model.parameters(), lr=INIT_LR, weight_decay=WEIGHT_DECAY)

# Cosine annealing over NUM_EPOCHS
# Using timm's CosineLRScheduler:
scheduler = CosineLRScheduler(
    optimizer,
    t_initial=NUM_EPOCHS,
    lr_min=1e-5,
    warmup_lr_init=1e-6,      # adjust if you want a warmup phase
    warmup_t=5,               # number of warmup epochs
    cycle_limit=1,
    t_in_epochs=True,
)


AssertionError: Torch not compiled with CUDA enabled

## Accuracy

In [14]:
def accuracy(output, target, topk=(1,5)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    # top maxk indices
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


## Training Loop

In [15]:
def train_one_epoch(epoch, model, train_loader, optimizer, scheduler, criterion):
    model.train()
    running_loss = 0.0
    top1_acc = 0.0
    top5_acc = 0.0
    total_samples = 0

    for step, (images, targets) in enumerate(train_loader):
        images = images.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images)

        loss = criterion(outputs, targets)  # if mixup, targets are soft
        loss.backward()
        optimizer.step()

        # Update metrics
        acc1, acc5 = accuracy(outputs, targets.argmax(dim=1) if targets.ndim == 2 else targets)
        running_loss += loss.item() * images.size(0)
        top1_acc += acc1.item() * images.size(0)
        top5_acc += acc5.item() * images.size(0)
        total_samples += images.size(0)

    scheduler.step(epoch)  # update LR schedule after each epoch

    epoch_loss = running_loss / total_samples
    epoch_top1 = top1_acc / total_samples
    epoch_top5 = top5_acc / total_samples

    print(f"Train Epoch [{epoch}] | Loss: {epoch_loss:.4f} | Top-1: {epoch_top1:.2f}% | Top-5: {epoch_top5:.2f}%")
    return epoch_loss, epoch_top1, epoch_top5


def validate(epoch, model, val_loader, criterion):
    model.eval()
    running_loss = 0.0
    top1_acc = 0.0
    top5_acc = 0.0
    total_samples = 0

    with torch.no_grad():
        for images, targets in val_loader:
            images = images.cuda(non_blocking=True)
            targets = targets.cuda(non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, targets)

            acc1, acc5 = accuracy(outputs, targets)
            running_loss += loss.item() * images.size(0)
            top1_acc += acc1.item() * images.size(0)
            top5_acc += acc5.item() * images.size(0)
            total_samples += images.size(0)

    epoch_loss = running_loss / total_samples
    epoch_top1 = top1_acc / total_samples
    epoch_top5 = top5_acc / total_samples

    print(f"Val Epoch   [{epoch}] | Loss: {epoch_loss:.4f} | Top-1: {epoch_top1:.2f}% | Top-5: {epoch_top5:.2f}%")
    return epoch_loss, epoch_top1, epoch_top5


## Main Training Loop

best_top1 = 0.0

for epoch in range(NUM_EPOCHS):
    # --- Train ---
    train_loss, train_top1, train_top5 = train_one_epoch(
        epoch, model, train_loader, optimizer, scheduler, criterion
    )

    # --- Validate ---
    val_loss, val_top1, val_top5 = validate(
        epoch, model, val_loader, criterion
    )

    # Save checkpoint if it’s the best so far
    if val_top1 > best_top1:
        best_top1 = val_top1
        torch.save(model.state_dict(), f"swin_l_patch4_384_best.pth")

    # Optionally save at each epoch:
    # torch.save(model.state_dict(), f"swin_l_patch4_384_epoch{epoch}.pth")
