<a href="https://colab.research.google.com/github/KaushikKoirala/hybrid-curriculum-learning/blob/main/imagenet_lerac.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchsummary torchvision tqdm wandb typing_extensions pytorch_metric_learning timm==0.9.16 einops


In [None]:
!pip freeze > requirements.txt

In [None]:
import torch
from torchsummary import summary
import torchvision
from torchvision.utils import make_grid
from torchvision import datasets, transforms as T
from torch.utils.data import DataLoader
import torch.nn.functional as F
import os
import gc
# from tqdm import tqdm
from tqdm.auto import tqdm
from PIL import Image
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn import metrics as mt
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import glob
import wandb
import matplotlib.pyplot as plt
from pytorch_metric_learning import samplers
import csv
import logging
from timm.data import create_loader, create_transform
import torch.nn as nn
from functools import partial

from einops import rearrange
from einops.layers.torch import Rearrange

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", DEVICE)

In [None]:
# Experiment configuration for ResNet-18 trained with LeRaC (Learning Rate
# Curriculum) on ImageNet100.
# NOTE(review): several keys (e.g. CUDNN, PRINT_FREQ, TRAIN.OPTIMIZER) are not
# read by the code visible in this notebook — confirm before relying on them.
config = {
  "OUTPUT_DIR": "OUTPUT/ImageNet100-ResNet18-LeRaC",
  "WORKERS": 8,
  "PRINT_FREQ": 500,
  "AMP": {
    "ENABLED": True
  },
  "MODEL": {
    "NAME": "resnet18",
    "SPEC": {
      "NUM_CLASSES": 100
    }
  },
  "AUG": {
    "MIXUP_PROB": 1.0,
    "MIXUP": 0.8,
    "MIXCUT": 1.0,
    "TIMM_AUG": {
      "USE_LOADER": False,
      "RE_COUNT": 1,
      "RE_MODE": "pixel",
      "RE_SPLIT": False,
      "RE_PROB": 0.25,
      "AUTO_AUGMENT": "rand-m9-mstd0.5-inc1",
      "HFLIP": 0.5,
      "VFLIP": 0.0,
      "COLOR_JITTER": 0.4,
      "INTERPOLATION": "bicubic"
    }
  },
  "LOSS": {
    "LABEL_SMOOTHING": 0.1
  },
  "CUDNN": {
    "BENCHMARK": True,
    "DETERMINISTIC": False,
    "ENABLED": True
  },
  "DATASET": {
    "DATASET": "imagenet",
    "DATA_FORMAT": "jpg",
    "ROOT": "./ImageNet100_224",
    "TEST_SET": "val",
    "TRAIN_SET": "train"
  },
  "TEST": {
    "BATCH_SIZE_PER_GPU": 64,
    "IMAGE_SIZE": [224, 224],
    "MODEL_FILE": "",
    "INTERPOLATION": "bicubic"
  },
  "TRAIN": {
    "BATCH_SIZE_PER_GPU": 128,
    "GRADIENT_ACCUMULATION_STEPS": 4,  # Micro-batches accumulated per optimizer step
    "LR": 0.05,  # Base learning rate; same value as the CIFAR-10 run (10x higher than the previous setting)
    "IMAGE_SIZE": [224, 224],
    "BEGIN_EPOCH": 0,
    "END_EPOCH": 100,
    "LR_CURRICULUM": {
        "MIN_LR": 5e-5,  # Lower bound for the per-block initial learning rates
        "WARMUP_EPOCHS": 5,
        "C": 10.0  # Per-epoch growth factor applied during the LeRaC warmup
    },
    "CLF_LR_MULTIPLIER": 0.01,  # Initial classifier LR = LR * this multiplier
    "LR_SCHEDULER": {
      "METHOD": "lerac",
      "ARGS": {
        "sched": "cosine",
        "warmup_epochs": 5,
        "warmup_lr": 1e-5,  # Raised to match the larger LR scale
        "min_lr": 1e-6,     # Floor (eta_min) for the cosine decay
        "cooldown_epochs": 0,
        "decay_rate": 0.1
      }
    },
    "OPTIMIZER": "adamW",
    "WD": 0.05,
    "WITHOUT_WD_LIST": ["bn", "bias", "ln"],
    "SHUFFLE": True
  },
  "DEBUG": {
    "DEBUG": False
  }
}

# Make sure the checkpoint/output directory exists before training starts.
import os
os.makedirs(config['OUTPUT_DIR'], exist_ok=True)

# Print the headline hyper-parameters so they are visible in the cell output.
print("="*80)
print("✓ Config loaded for ResNet-18 with LeRaC on ImageNet100")
print("="*80)
print(f"Base LR: {config['TRAIN']['LR']} (SAME as CIFAR-10)")
print(f"Batch size: {config['TRAIN']['BATCH_SIZE_PER_GPU']}")
print(f"Gradient accumulation: {config['TRAIN']['GRADIENT_ACCUMULATION_STEPS']}")
# Effective batch size = per-step batch size x accumulation steps.
print(f"Effective batch size: {config['TRAIN']['BATCH_SIZE_PER_GPU'] * config['TRAIN']['GRADIENT_ACCUMULATION_STEPS']}")
print(f"Min LR: {config['TRAIN']['LR_CURRICULUM']['MIN_LR']}")
print("="*80)

### Defining transforms

In [None]:
def build_transforms(config, is_train):
    """Return the image transform pipeline for training or evaluation.

    Training uses timm's ``create_transform`` driven by ``config['AUG']['TIMM_AUG']``.
    Evaluation resizes to ``img_size / 0.875`` (bicubic), center-crops to
    ``img_size``, converts to a tensor and normalizes.

    Args:
        config: Experiment configuration dictionary.
        is_train: Selects the training pipeline when true, else the eval one.

    Returns:
        A callable transform to pass to the dataset.
    """
    # Normalization statistics shared by both pipelines.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)

    if not is_train:
        eval_size = config['TEST']['IMAGE_SIZE'][0]
        return T.Compose([
            T.Resize(int(eval_size / 0.875), interpolation=T.InterpolationMode.BICUBIC),
            T.CenterCrop(eval_size),
            T.ToTensor(),
            T.Normalize(mean=list(mean), std=list(std)),
        ])

    aug = config['AUG']['TIMM_AUG']
    return create_transform(
        input_size=config['TRAIN']['IMAGE_SIZE'][0],
        is_training=True,
        use_prefetcher=False,
        no_aug=False,
        re_prob=aug['RE_PROB'],
        re_mode=aug['RE_MODE'],
        re_count=aug['RE_COUNT'],
        scale=(0.08, 1.0),
        ratio=(3.0/4.0, 4.0/3.0),
        hflip=aug['HFLIP'],
        vflip=aug['VFLIP'],
        color_jitter=aug['COLOR_JITTER'],
        auto_augment=aug['AUTO_AUGMENT'],
        interpolation=aug['INTERPOLATION'],
        mean=mean,
        std=std,
    )

### Building Datasets

In [None]:
def _build_imagenet_dataset(config, is_train):
    """Create an ``ImageFolder`` dataset for the train or test split.

    The split directory is ``config['DATASET']['ROOT']`` joined with the
    configured TRAIN_SET or TEST_SET name; the matching transform pipeline
    comes from ``build_transforms``.
    """
    split_key = 'TRAIN_SET' if is_train else 'TEST_SET'
    split_dir = os.path.join(config['DATASET']['ROOT'], config['DATASET'][split_key])
    pipeline = build_transforms(config, is_train)

    print(f"Loading dataset from: {split_dir}")
    image_folder = datasets.ImageFolder(split_dir, pipeline)

    logging.info(f'Loaded {len(image_folder)} samples, is_train: {is_train}')
    print(f"✓ Dataset loaded: {len(image_folder)} images, {len(image_folder.classes)} classes")

    return image_folder


def build_dataset(config, is_train):
    """Public dataset factory; delegates to the ImageFolder-based ImageNet builder."""
    return _build_imagenet_dataset(config, is_train)

### Building Dataloader

In [None]:
def build_dataloader(config, is_train):
    """Build a ``DataLoader`` for the training or evaluation split.

    Args:
        config: Experiment configuration. Reads ``TRAIN``/``TEST``
            ``BATCH_SIZE_PER_GPU``, ``WORKERS`` and, for training,
            ``TRAIN.SHUFFLE`` (defaults to True when the key is absent).
        is_train: Build the training loader when true, else the eval loader.

    Returns:
        A ``torch.utils.data.DataLoader`` over the dataset from ``build_dataset``.
    """
    if is_train:
        batch_size_per_gpu = config['TRAIN']['BATCH_SIZE_PER_GPU']
        # Honor the config switch instead of hard-coding shuffling on.
        shuffle = bool(config['TRAIN'].get('SHUFFLE', True))
    else:
        batch_size_per_gpu = config['TEST']['BATCH_SIZE_PER_GPU']
        shuffle = False

    dataset = build_dataset(config, is_train)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size_per_gpu,
        shuffle=shuffle,
        num_workers=config['WORKERS'],
        pin_memory=True,
        drop_last=is_train,  # Drop the last incomplete batch in training only
    )


# Build the train/val dataloaders at cell execution time and report their sizes.
print("\n" + "="*60)
print("Building DataLoaders...")
print("="*60)

train_loader = build_dataloader(config, is_train=True)
val_loader = build_dataloader(config, is_train=False)

print(f"\nDataLoader Info:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Train dataset size: {len(train_loader.dataset)}")
print(f"  Val dataset size: {len(val_loader.dataset)}")
print(f"  Number of classes: {len(train_loader.dataset.classes)}")
print(f"  Train batch size: {config['TRAIN']['BATCH_SIZE_PER_GPU']}")
print(f"  Val batch size: {config['TEST']['BATCH_SIZE_PER_GPU']}")

# Smoke test: pull a single training batch and print its basic properties.
print("\n" + "="*60)
print("Testing Data Loading...")
print("="*60)

images, labels = next(iter(train_loader))
print(f"✓ Batch loaded successfully!")
print(f"  Images shape: {images.shape}")
print(f"  Labels shape: {labels.shape}")
print(f"  Image dtype: {images.dtype}")
print(f"  Label range: [{labels.min()}, {labels.max()}]")

# Report whether the split directories exist. This runs after the loaders were
# already built, so it is informational only.
train_path = os.path.join(config['DATASET']['ROOT'], config['DATASET']['TRAIN_SET'])
val_path = os.path.join(config['DATASET']['ROOT'], config['DATASET']['TEST_SET'])
print(f"\nDataset Directories:")
print(f"  Train: {train_path} - Exists: {os.path.exists(train_path)}")
print(f"  Val: {val_path} - Exists: {os.path.exists(val_path)}")
print("="*60)

In [None]:
# ResNet-18 Model Definition for ImageNet
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 conv -> BN -> ReLU -> 3x3 conv -> BN).

    The skip path is the identity unless the stride or channel count changes,
    in which case a 1x1 convolution followed by batch norm projects the input.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        if stride != 1 or in_planes != out_planes:
            # Projection shortcut so the residual matches the main path's shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            # An empty Sequential acts as the identity.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        main = F.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))
        main += self.shortcut(x)
        return F.relu(main)


class ResNet(nn.Module):
    """ResNet backbone with an ImageNet-style stem and a linear classifier head.

    Args:
        block: Residual block class exposing an ``expansion`` attribute and
            accepting ``(in_planes, planes, stride)``.
        num_blocks: Number of blocks in each of the four stages.
        num_classes: Size of the classifier output.
    """

    def __init__(self, block, num_blocks, num_classes=100):
        super().__init__()
        self.in_planes = 64

        # Stem: 7x7 stride-2 convolution, batch norm, 3x3 stride-2 max pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages registered as layer1..layer4, in that order.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for i, (planes, stride) in enumerate(stage_specs):
            setattr(self, f"layer{i + 1}", self._make_layer(block, planes, num_blocks[i], stride))

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack blocks for one stage; only the first block uses ``stride``."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes this block's output channels.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.avgpool(x)
        # Flatten to (batch, features) before the classifier.
        return self.linear(x.view(x.size(0), -1))


def ResNet18(num_classes=100):
    """Build a ResNet-18: ``BasicBlock`` with two blocks in each of the four stages."""
    blocks_per_stage = [2, 2, 2, 2]
    return ResNet(BasicBlock, blocks_per_stage, num_classes=num_classes)

# Cell-output confirmation that the model classes above have been defined.
print("ResNet-18 model defined for ImageNet")

In [None]:
# Instantiate ResNet-18 with the number of classes taken from the config.
print("\n" + "="*60)
print("Creating ResNet-18 Model...")
print("="*60)

num_classes = config['MODEL']['SPEC']['NUM_CLASSES']
model = ResNet18(num_classes=num_classes)

print(f"✓ ResNet-18 created with {num_classes} output classes")

In [None]:
def count_parameters(model):
    """Print and return the total and trainable parameter counts of ``model``.

    Returns:
        Tuple ``(total_params, trainable_params)`` of element counts.
    """
    total_params = 0
    trainable_params = 0
    for tensor in model.parameters():
        size = tensor.numel()
        total_params += size
        if tensor.requires_grad:
            trainable_params += size

    rule = "="*60
    print(rule)
    print("Model Parameter Count")
    print(rule)
    print(f"Total parameters: {total_params:,} ({total_params/1e6:.2f}M)")
    print(f"Trainable parameters: {trainable_params:,} ({trainable_params/1e6:.2f}M)")
    print(rule)

    return total_params, trainable_params

# Report the parameter counts of the freshly created model and keep both totals.
total_params, trainable_params = count_parameters(model)

## Setup Optimizer

In [None]:
def _is_depthwise(m):
    return (
        isinstance(m, nn.Conv2d)
        and m.groups == m.in_channels
        and m.groups == m.out_channels
    )

def set_wd(cfg, model):
    """Separate parameters by weight decay"""
    without_decay_list = cfg["TRAIN"]["WITHOUT_WD_LIST"]
    without_decay_depthwise = []
    without_decay_norm = []

    for m in model.modules():
        if _is_depthwise(m) and 'dw' in without_decay_list:
            without_decay_depthwise.append(m.weight)
        elif isinstance(m, nn.BatchNorm2d) and 'bn' in without_decay_list:
            without_decay_norm.append(m.weight)
            without_decay_norm.append(m.bias)
        elif isinstance(m, nn.GroupNorm) and 'gn' in without_decay_list:
            without_decay_norm.append(m.weight)
            without_decay_norm.append(m.bias)
        elif isinstance(m, nn.LayerNorm) and 'ln' in without_decay_list:
            without_decay_norm.append(m.weight)
            without_decay_norm.append(m.bias)

    with_decay = []
    without_decay = []

    skip = {}
    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()

    skip_keys = {}
    if hasattr(model, 'no_weight_decay_keywords'):
        skip_keys = model.no_weight_decay_keywords()

    for n, p in model.named_parameters():
        ever_set = False

        if p.requires_grad is False:
            continue

        skip_flag = False
        if n in skip:
            print('=> set {} wd to 0'.format(n))
            without_decay.append(p)
            skip_flag = True
        else:
            for i in skip:
                if i in n:
                    print('=> set {} wd to 0'.format(n))
                    without_decay.append(p)
                    skip_flag = True

        if skip_flag:
            continue

        for i in skip_keys:
            if i in n:
                print('=> set {} wd to 0'.format(n))

        if skip_flag:
            continue

        for pp in without_decay_depthwise:
            if p is pp:
                if cfg['DEBUG']['DEBUG']:
                    print('=> set depthwise({}) wd to 0'.format(n))
                without_decay.append(p)
                ever_set = True
                break

        for pp in without_decay_norm:
            if p is pp:
                if cfg['DEBUG']['DEBUG']:
                    print('=> set norm({}) wd to 0'.format(n))
                without_decay.append(p)
                ever_set = True
                break

        if (
            (not ever_set)
            and 'bias' in without_decay_list
            and n.endswith('.bias')
        ):
            if cfg['DEBUG']['DEBUG']:
                print('=> set bias({}) wd to 0'.format(n))
            without_decay.append(p)
        elif not ever_set:
            with_decay.append(p)

    params = [
        {'params': with_decay},
        {'params': without_decay, 'weight_decay': 0.}
    ]
    return params


def build_optimizer_resnet18_lerac(model, config):
    """Create an AdamW optimizer with LeRaC per-block initial learning rates.

    Every parameter gets its own param group. Stem parameters (``conv1``/``bn1``)
    start at the base LR, the classifier at ``LR * CLF_LR_MULTIPLIER``, and each
    residual block ``layerX.Y`` at a rate that halves block by block (starting
    from half the base LR) and is floored at ``LR_CURRICULUM.MIN_LR``. Weight
    decay per parameter follows the grouping returned by ``set_wd``.

    Args:
        model: Network whose parameter names follow the ResNet-18 layout above.
        config: Experiment configuration dictionary.

    Returns:
        The configured ``torch.optim.AdamW`` instance.
    """
    train_cfg = config['TRAIN']
    base_lr = train_cfg['LR']
    min_lr = train_cfg['LR_CURRICULUM']['MIN_LR']
    weight_decay = train_cfg['WD']
    clf_lr = base_lr * train_cfg.get('CLF_LR_MULTIPLIER', 0.01)

    # ResNet-18: four stages with two blocks each (indices 0 and 1).
    layers = ['layer1', 'layer2', 'layer3', 'layer4']
    blocks = [str(i) for i in range(2)]

    # Curriculum: each successive block starts at half the previous rate,
    # never dropping below min_lr.
    dictionary_lr = {}
    running_lr = base_lr * 0.5
    for block_key in (f"{layer}.{block}" for layer in layers for block in blocks):
        dictionary_lr[block_key] = max(running_lr, min_lr)
        if running_lr >= min_lr:
            running_lr *= 0.5

    def _initial_lr(name):
        """Pick the starting learning rate from a parameter's qualified name."""
        if name.startswith(('conv1', 'bn1', 'maxpool')):
            return base_lr
        if 'linear' in name or 'fc' in name:
            return clf_lr
        for prefix, block_lr in dictionary_lr.items():
            if name.startswith(prefix):
                return block_lr
        return base_lr

    name_by_id = {id(p): n for n, p in model.named_parameters()}
    new_param_groups = []

    # Walk both the decayed and the non-decayed group from set_wd.
    for group in set_wd(config, model):
        members = group['params'] if isinstance(group['params'], list) else [group['params']]
        wd = group.get('weight_decay', weight_decay)
        new_param_groups.extend(
            {
                'params': param,
                'lr': _initial_lr(name_by_id.get(id(param), "unknown")),
                'weight_decay': wd,
            }
            for param in members
        )

    optimizer = torch.optim.AdamW(new_param_groups)

    heavy_rule = "=" * 80
    print(heavy_rule)
    print(f"Optimizer: AdamW with LeRaC for ResNet-18 (ImageNet)")
    print(f"Total parameter groups: {len(new_param_groups)}")
    print(f"Base LR: {base_lr:.10f}")
    print(f"Min LR: {min_lr:.10f}")
    print(f"Weight Decay: {weight_decay}")
    print(heavy_rule)
    print("Layer-wise Learning Rates:")
    print("-" * 80)

    print(f"conv1/bn1:       {base_lr:.6f}")
    for key, block_lr in dictionary_lr.items():
        print(f"{key:15s} {block_lr:.6f}")
    print(f"classifier:      {clf_lr:.6f}")
    print(heavy_rule)

    return optimizer

# Build the AdamW optimizer with per-block LeRaC learning rates for the model.
optimizer = build_optimizer_resnet18_lerac(model, config)

## Setup Scheduler

In [None]:
from torch.optim.lr_scheduler import _LRScheduler, SequentialLR, CosineAnnealingLR

class LeRaCScheduler(_LRScheduler):
    """LeRaC warmup: per-group exponential growth towards a shared base LR.

    Based on Eq.(9): lr_t = init_lr * (c^t) for t in {1..warmup_epochs},
    clamped so that no group exceeds ``base_lr``. Once ``last_epoch`` reaches
    ``warmup_epochs`` every group is set to ``base_lr``.

    Args:
        optimizer: Wrapped optimizer; each param group carries its own initial LR.
        base_lr: Target learning rate shared by all groups after warmup.
        warmup_epochs: Number of warmup epochs (at least 1 is enforced).
        c: Per-epoch multiplicative growth factor.
        last_epoch: Index of the last epoch, as in other PyTorch schedulers.
    """
    def __init__(self, optimizer, base_lr, warmup_epochs, c=10.0, last_epoch=-1):
        self.base_lr = float(base_lr)
        self.warmup_epochs = max(1, int(warmup_epochs))
        self.c = float(c)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # After warmup: all groups use base_lr
        if self.last_epoch >= self.warmup_epochs:
            return [self.base_lr for _ in self.optimizer.param_groups]

        # Warmup epoch counter t ∈ {1..warmup_epochs}
        t = self.last_epoch + 1

        lrs = []
        for g in self.optimizer.param_groups:
            # Prefer the stashed curriculum LR, then the `initial_lr` recorded by
            # the scheduler base class. Falling back to the live `lr` would
            # compound the growth, because `lr` is overwritten on every step.
            init = float(g.get('_init_lr', g.get('initial_lr', g['lr'])))
            # Exponential growth per Eq.(9), clamped to base_lr
            lrs.append(min(init * (self.c ** t), self.base_lr))
        return lrs


def get_scheduler(config, optimizer):
    """Build the LeRaC schedule: exponential warmup followed by cosine annealing.

    Args:
        config: Experiment configuration. Reads ``TRAIN.END_EPOCH``, ``TRAIN.LR``,
            ``TRAIN.LR_CURRICULUM`` (``WARMUP_EPOCHS``, ``C``) and
            ``TRAIN.LR_SCHEDULER.ARGS.min_lr``.
        optimizer: Optimizer whose param groups hold the per-group initial LRs.
            Each group gets an ``_init_lr`` entry recording that value.

    Returns:
        A ``SequentialLR`` that runs ``LeRaCScheduler`` for the warmup epochs and
        then ``CosineAnnealingLR`` from ``TRAIN.LR`` down to ``min_lr``.
    """
    num_epochs = int(config['TRAIN']['END_EPOCH'])
    warmup_epochs = int(config['TRAIN']['LR_CURRICULUM'].get('WARMUP_EPOCHS', 5))
    base_lr = float(config['TRAIN']['LR'])
    min_lr = float(config['TRAIN']['LR_SCHEDULER']['ARGS'].get('min_lr', 1e-7))
    c = float(config['TRAIN']['LR_CURRICULUM'].get('C', 10.0))  # Growth factor

    # Store initial LRs for LeRaC warmup
    for g in optimizer.param_groups:
        if '_init_lr' not in g:
            g['_init_lr'] = g['lr']

    # Phase 1: LeRaC exponential warmup
    lerac_scheduler = LeRaCScheduler(
        optimizer,
        base_lr=base_lr,
        warmup_epochs=warmup_epochs,
        c=c,
        last_epoch=-1
    )

    # Phase 2: Cosine annealing after warmup
    cosine_epochs = max(1, num_epochs - warmup_epochs)
    cosine_scheduler = CosineAnnealingLR(
        optimizer,
        T_max=cosine_epochs,
        eta_min=min_lr
    )
    # CosineAnnealingLR records each group's `initial_lr` (the small per-block
    # curriculum value) as its base LR. Left as is, the hand-over at the
    # milestone would reset every group to that value instead of continuing
    # from the shared base LR reached by the warmup.
    cosine_scheduler.base_lrs = [base_lr for _ in optimizer.param_groups]

    # Combine: LeRaC warmup → Cosine decay
    scheduler = SequentialLR(
        optimizer,
        schedulers=[lerac_scheduler, cosine_scheduler],
        milestones=[warmup_epochs]
    )

    print("="*80)
    print("LeRaC Scheduler Configuration (ImageNet100)")
    print("="*80)
    print(f"Phase 1 - LeRaC Warmup:")
    print(f"  Duration: {warmup_epochs} epochs")
    print(f"  Growth factor (c): {c}")
    print(f"  Target LR: {base_lr}")
    print(f"\nPhase 2 - Cosine Annealing:")
    print(f"  Duration: {cosine_epochs} epochs")
    print(f"  Start LR: {base_lr}")
    print(f"  Min LR: {min_lr}")
    print(f"\nTotal epochs: {num_epochs}")
    print("="*80)

    return scheduler


# Build the warmup + cosine learning-rate schedule for the optimizer above.
scheduler = get_scheduler(config, optimizer)

## Set Criterion

In [None]:
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy

def build_criterion(config):
    """Select the training loss that matches the configured augmentation.

    * Mixup or CutMix active (``AUG.MIXUP > 0`` or ``AUG.MIXCUT > 0``): the
      targets are soft, so ``SoftTargetCrossEntropy`` is used. This mirrors the
      activation rule in ``build_mixup``; checking MIXUP alone would pair
      CutMix-only runs with a hard-label loss.
    * Otherwise, with ``LOSS.LABEL_SMOOTHING > 0``: ``LabelSmoothingCrossEntropy``.
    * Otherwise: plain ``torch.nn.CrossEntropyLoss``.

    Args:
        config: Experiment configuration dictionary.

    Returns:
        The loss module to use during training.
    """
    aug_cfg = config['AUG']
    smoothing = config['LOSS']['LABEL_SMOOTHING']

    soft_targets = aug_cfg['MIXUP'] > 0. or aug_cfg.get('MIXCUT', 0.) > 0.
    if soft_targets:
        # Use soft target cross entropy for mixup/cutmix
        return SoftTargetCrossEntropy()
    if smoothing > 0.:
        return LabelSmoothingCrossEntropy(smoothing=smoothing)
    # Standard cross entropy
    return torch.nn.CrossEntropyLoss()

# Training loss, chosen from the config (soft-target loss when mixup is enabled).
criterion = build_criterion(config)

# Validation always uses plain cross entropy on hard labels.
criterion_eval = torch.nn.CrossEntropyLoss()

print(f"\nCriterion Info:")
print(f"Training: Label Smoothing = {config['LOSS']['LABEL_SMOOTHING']}")
print(f"Training: Mixup/Cutmix enabled = {config['AUG']['MIXUP'] > 0.}")
print(f"Validation: Standard CrossEntropyLoss")

In [None]:
from timm.data.mixup import Mixup

def build_mixup(config):
    """Create the timm ``Mixup`` callable, or return None when it is disabled.

    Mixup is considered active when either ``AUG.MIXUP`` or ``AUG.MIXCUT`` is
    positive. Label smoothing and the class count are forwarded from the config.
    """
    aug_cfg = config['AUG']
    if not (aug_cfg['MIXUP'] > 0 or aug_cfg['MIXCUT'] > 0.):
        return None

    return Mixup(
        mixup_alpha=aug_cfg['MIXUP'],
        cutmix_alpha=aug_cfg['MIXCUT'],
        prob=aug_cfg['MIXUP_PROB'],
        switch_prob=0.5,
        mode='batch',
        label_smoothing=config['LOSS']['LABEL_SMOOTHING'],
        num_classes=config['MODEL']['SPEC']['NUM_CLASSES'],
    )

# Build the mixup/cutmix callable (None when both alphas are zero) and echo its settings.
mixup_fn = build_mixup(config)
print(f"\nMixup Info:")
print(f"Mixup alpha: {config['AUG']['MIXUP']}")
print(f"Cutmix alpha: {config['AUG']['MIXCUT']}")
print(f"Mixup probability: {config['AUG']['MIXUP_PROB']}")

In [None]:
class AverageMeter(object):
    """Track the latest value and the weighted running mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the latest value, running sum, sample count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed over ``n`` samples and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

In [None]:
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracies (in percent) for each k in ``topk``.

    ``output`` is indexed as (sample, class) and ``target`` holds one class
    index per sample. Values of k larger than the number of classes are
    clamped to the number of classes.
    """
    num_classes = output.size()[1]
    maxk = min(max(topk), num_classes)
    batch_size = target.size(0)

    # Rows of `ranked` are ranks (best first), columns are samples.
    ranked = output.topk(maxk, 1, True, True)[1].t()
    hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

    results = []
    for k in topk:
        n_correct = hits[:min(k, maxk)].reshape(-1).float().sum(0)
        results.append(n_correct * 100. / batch_size)
    return results

In [None]:
def train_one_epoch(epoch, model, train_loader, criterion, optimizer, scheduler,
                    config, scaler, mixup_fn):
    """Train ``model`` for one epoch with AMP and gradient accumulation.

    Following: https://github.com/microsoft/CvT/blob/main/lib/core/function.py

    Args:
        epoch: Zero-based epoch index (used only for debug logging).
        model: Network to train; expected to already live on ``DEVICE``.
        train_loader: Iterable yielding ``(images, targets)`` batches.
        criterion: Training loss; must accept soft targets when mixup is active.
        optimizer: Optimizer stepped every ``GRADIENT_ACCUMULATION_STEPS`` batches.
        scheduler: Unused here; the caller steps it once per epoch.
        config: Experiment configuration dictionary.
        scaler: ``GradScaler`` used for loss scaling.
        mixup_fn: Mixup/CutMix callable, or None to train on the raw batch.

    Returns:
        Tuple ``(average loss, average top-1 accuracy in percent)``.
    """
    losses = AverageMeter()
    acc_m = AverageMeter()
    model.train()

    # Guard against a zero/negative setting, which would break the modulo below.
    accumulation_steps = max(1, int(config['TRAIN'].get('GRADIENT_ACCUMULATION_STEPS', 1)))
    batch_bar = tqdm(total=len(train_loader), dynamic_ncols=True, leave=False, position=0, desc='Train', ncols=5)

    def _apply_accumulated_gradients(log_batch_idx=None):
        """Unscale, clip and apply the accumulated gradients, then clear them."""
        scaler.unscale_(optimizer)

        # Optional: print the global gradient norm (before clipping) for debugging
        if log_batch_idx is not None:
            total_norm = sum(
                p.grad.norm().item() ** 2
                for p in model.parameters()
                if p.grad is not None
            ) ** 0.5
            print(f"Epoch {epoch}, Batch {log_batch_idx}, Grad norm: {total_norm:.4f}")

        # Gradient clipping (element-wise clamp, then global norm clipping)
        for param in model.parameters():
            if param.grad is not None:
                param.grad.data.clamp_(-0.5, 0.5)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    # True while gradients have been accumulated but not yet applied.
    pending_gradients = False

    for idx, (images, targets) in enumerate(train_loader):
        images = images.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        # Apply mixup/cutmix when configured (build_mixup may return None)
        if mixup_fn is not None:
            images, targets = mixup_fn(images, targets)

        # Forward pass with mixed precision
        with torch.cuda.amp.autocast(enabled=True):
            outputs = model(images)
            loss = criterion(outputs, targets)
            # Scale so the accumulated gradient averages over the micro-batches
            loss = loss / accumulation_steps

        # Backward pass
        scaler.scale(loss).backward()
        pending_gradients = True

        # Gradient accumulation step
        if (idx + 1) % accumulation_steps == 0:
            # Log at the first optimizer step at or after every 100th batch.
            # The earlier `idx % 100 == 0` test could never coincide with an
            # accumulation boundary when accumulating over 4 batches.
            should_log = epoch >= 6 and (idx + 1) % 100 < accumulation_steps
            _apply_accumulated_gradients(idx if should_log else None)
            pending_gradients = False

        # Soft (mixup) targets are reduced to their dominant class for accuracy
        targets_for_acc = torch.argmax(targets, dim=1) if targets.dim() > 1 else targets
        acc = accuracy(outputs, targets_for_acc)

        # Update metrics (undo the accumulation scaling for reporting)
        batch_loss = loss.item() * accumulation_steps
        losses.update(batch_loss, images.size(0))
        acc_m.update(acc[0].item(), images.size(0))

        # Update progress bar
        batch_bar.set_postfix(
            acc="{:.02f}% ({:.02f}%)".format(acc[0].item(), acc_m.avg),
            loss="{:.04f} ({:.04f})".format(batch_loss, losses.avg),
            lr="{:.06f}".format(float(optimizer.param_groups[0]['lr']))
        )
        batch_bar.update()

    # Apply gradients left over from a trailing, incomplete accumulation window
    if pending_gradients:
        _apply_accumulated_gradients()

    batch_bar.close()

    torch.cuda.empty_cache()

    return losses.avg, acc_m.avg

In [None]:
@torch.no_grad()
def validate(model, val_loader, criterion, config):
    """Evaluate ``model`` on ``val_loader`` without gradients or autocast.

    Returns:
        Tuple ``(average loss, average top-1 accuracy in percent)``.
    """
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()

    model.eval()
    batch_bar = tqdm(total=len(val_loader), dynamic_ncols=True, position=0, leave=False, desc='Val.', ncols=5)

    for images, targets in val_loader:
        images = images.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        # Autocast is explicitly disabled for evaluation.
        with torch.cuda.amp.autocast(enabled=False):
            outputs = model(images)
            loss = criterion(outputs, targets)

        batch_size = images.size(0)
        batch_loss = loss.item()
        batch_top1 = accuracy(outputs, targets)[0].item()

        loss_meter.update(batch_loss, batch_size)
        top1_meter.update(batch_top1, batch_size)

        batch_bar.set_postfix(
            acc="{:.02f}% ({:.02f}%)".format(batch_top1, top1_meter.avg),
            loss="{:.04f} ({:.04f})".format(batch_loss, loss_meter.avg))
        batch_bar.update()

    batch_bar.close()
    print(f' * Acc {top1_meter.avg:.3f}')

    return loss_meter.avg, top1_meter.avg

In [None]:
def save_model(model, optimizer, scheduler, epoch, path):
    """Write a checkpoint with model/optimizer/scheduler state and the epoch to ``path``."""
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'epoch': epoch,
    }
    torch.save(checkpoint, path)


def load_model(model, optimizer=None, scheduler=None, path='./checkpoint.pth'):
    """Restore training state from a checkpoint written by ``save_model``.

    Args:
        model: Module that receives ``model_state_dict``.
        optimizer: Optional optimizer that receives ``optimizer_state_dict``.
        scheduler: Optional scheduler that receives ``scheduler_state_dict``.
        path: Checkpoint file to read.

    Returns:
        Tuple ``(model, optimizer, scheduler, epoch)``. ``optimizer`` and
        ``scheduler`` are returned as passed in (None stays None).
    """
    # Deserialize onto the CPU so that a checkpoint saved on a GPU also loads
    # on a CPU-only host. `load_state_dict` copies the values into the
    # existing parameters, so their device is unaffected.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    epoch = checkpoint['epoch']
    return model, optimizer, scheduler, epoch

In [None]:
# Ensure the output directory exists, move the model to the training device and
# create the gradient scaler used for mixed-precision training.
os.makedirs(config['OUTPUT_DIR'], exist_ok=True)
model = model.to(DEVICE)
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Log in to Weights & Biases with the key from the WANDB_API_KEY environment
# variable (the key is listed in your wandb account under wandb.ai/settings).
wandb.login(key=os.environ.get('WANDB_API_KEY'))


In [None]:
# Start a new Weights & Biases run for this experiment; the full config dict is
# attached to the run.
run = wandb.init(
    name="imagenet100-resnet18-lerac-v1",  # Display name of this run
    project="idl-project",
    config=config
)

In [None]:
# Free Python garbage and cached CUDA memory; helps after a CUDA out-of-memory error.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Main training loop: train, validate, step the LR scheduler once per epoch,
# save a periodic checkpoint every 10 epochs and the best model (lowest
# validation loss), and log the metrics to wandb.
train_losses = []
train_accs = []
val_losses = []
val_accs = []
best_loss = -1  # Sentinel meaning "no validation loss recorded yet"
start_epoch = config['TRAIN']['BEGIN_EPOCH']
end_epoch = config['TRAIN']['END_EPOCH']

print("Starting Training")
print(f"Epochs: {start_epoch} -> {end_epoch}")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Device: {DEVICE}")
print(f"Checkpoint save frequency: Every 10 epochs")

for epoch in range(start_epoch, end_epoch):
    # Training
    print(f"\nEpoch {epoch+1}/{end_epoch}")
    train_loss, train_acc = train_one_epoch(
        epoch, model, train_loader, criterion, optimizer, scheduler,
        config, scaler, mixup_fn
    )

    # Validation (plain cross entropy on hard labels)
    val_loss, val_acc = validate(model, val_loader, criterion_eval, config)

    # Advance the LR schedule once per epoch, before any checkpoint is written,
    # so saved scheduler state already points at the next epoch.
    scheduler.step()
    # Check if best model
    is_best = (best_loss == -1) or (val_loss < best_loss)
    if is_best:
        best_loss = val_loss

    # Save a periodic checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        checkpoint_path = os.path.join(config['OUTPUT_DIR'], f'epoch_{epoch+1}.pth')
        save_model(model, optimizer, scheduler, epoch, checkpoint_path)
        print(f"✓ Saved checkpoint: epoch_{epoch+1}.pth")

    # Always save best model (overwrites best.pth)
    if is_best:
        best_path = os.path.join(config['OUTPUT_DIR'], 'best.pth')
        save_model(model, optimizer, scheduler, epoch, best_path)
        print(f"✓ Saved best model: best.pth (Val Loss: {val_loss:.4f})")

    # Print summary
    print(f"Epoch {epoch+1} Summary:")
    print(f"  Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    print(f"  Train Acc: {train_acc:.3f}, Val Acc: {val_acc:.3f}")
    print(f"  Best Val Loss: {best_loss:.4f}")

    # Log to wandb
    metrics = {
        'train_loss': train_loss,
        'val_loss': val_loss,
        'train_acc': train_acc,
        'val_acc': val_acc,
        'epoch': epoch,
        'best_val_loss': best_loss
    }
    if run is not None:
        run.log(metrics)

print("\n" + "="*80)
print("Training Complete!")
print(f"Best Validation Loss: {best_loss:.4f}")
print("="*80)

In [None]:
import os

# Load the best checkpoint (lowest validation loss) written by the training
# loop. Model, optimizer and scheduler state are all restored from best.pth.
resume_path = os.path.join(config['OUTPUT_DIR'], 'best.pth')

print("="*80)
print(f"Loading checkpoint: {resume_path}")
print("="*80)

model, optimizer, scheduler, loaded_epoch = load_model(
    model, optimizer, scheduler,
    path=resume_path
)

# The checkpoint stores the zero-based index of the epoch that produced it, so
# a resumed run would continue from the following epoch.
start_epoch = loaded_epoch + 1
end_epoch = config['TRAIN']['END_EPOCH']

print(f"✓ Checkpoint loaded successfully!")
print(f"   Completed epoch: {loaded_epoch}")
print(f"   Resuming from: {start_epoch}")
print(f"   Training until: {end_epoch}")
print(f"   Remaining epochs: {end_epoch - start_epoch}")
print("="*80 + "\n")

In [None]:
@torch.no_grad()
def test(model, test_loader, config):
    """Run inference over ``test_loader`` and collect per-sample results.

    Returns:
        Dict with arrays ``predictions`` (top-1 class), ``targets`` and
        ``probabilities`` (softmax outputs), plus the average cross-entropy
        ``loss`` and the average top-1 accuracy ``top1_acc`` in percent.
    """
    model.eval()

    loss_fn = torch.nn.CrossEntropyLoss()
    loss_meter = AverageMeter()
    top1_meter = AverageMeter()

    pred_rows = []
    target_rows = []
    prob_rows = []

    for images, targets in tqdm(test_loader, desc='Testing'):
        images = images.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=True):
            outputs = model(images)
            loss = loss_fn(outputs, targets)

        # Class probabilities and the index of the highest-scoring class.
        probs = torch.softmax(outputs, dim=1)
        preds = outputs.topk(1, 1, True, True)[1]

        pred_rows.extend(preds.cpu().numpy().flatten())
        target_rows.extend(targets.cpu().numpy())
        prob_rows.extend(probs.cpu().numpy())

        n_samples = images.size(0)
        loss_meter.update(loss.item(), n_samples)
        top1_meter.update(accuracy(outputs, targets)[0].item(), n_samples)

    return {
        'predictions': np.array(pred_rows),
        'targets': np.array(target_rows),
        'probabilities': np.array(prob_rows),
        'loss': loss_meter.avg,
        'top1_acc': top1_meter.avg
    }

# Evaluate the currently loaded model. The validation loader is reused as the
# test set here.
test_results = test(model, val_loader, config)

print("Final Test Results")
print(f"Test Loss: {test_results['loss']:.4f}")
print(f"Test Acc@1: {test_results['top1_acc']:.3f}%")
