In [None]:
!pip install torchsummary torchvision tqdm wandb typing_extensions pytorch_metric_learning timm==0.9.16 einops


In [None]:
# Snapshot the installed package versions next to the notebook for reproducibility.
!pip freeze > requirements.txt

In [None]:
import torch
from torchsummary import summary
import torchvision
from torchvision.utils import make_grid
from torchvision import datasets, transforms as T
from torch.utils.data import DataLoader
import torch.nn.functional as F
import os
import gc
# from tqdm import tqdm
from tqdm.auto import tqdm
from PIL import Image
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn import metrics as mt
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import glob
import wandb
import matplotlib.pyplot as plt
from pytorch_metric_learning import samplers
import csv
import logging
from timm.data import create_loader, create_transform
import torch.nn as nn
from functools import partial

from einops import rearrange
from einops.layers.torch import Rearrange

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", DEVICE)

In [None]:
# Configuration for ResNet-18 with LeRaC on CIFAR-10
config = {
  "OUTPUT_DIR": "OUTPUT/CIFAR-10-ResNet18-LeRaC",  # per-epoch checkpoints and best.pth are saved here
  "WORKERS": 8,            # DataLoader worker processes
  "PRINT_FREQ": 500,       # NOTE(review): not read anywhere in this notebook
  "AMP": {
    "ENABLED": True        # autocast in the training loop
  },
  "MODEL": {
    "NAME": "resnet18",    # NOTE(review): informational only; the model is built by calling ResNet18() directly
    "SPEC": {
      "NUM_CLASSES": 10
    }
  },
  "AUG": {
    "MIXUP_PROB": 1.0,     # probability of applying mixup/cutmix to a batch (timm Mixup `prob`)
    "MIXUP": 0.8,          # timm Mixup `mixup_alpha`
    "MIXCUT": 1.0,         # timm Mixup `cutmix_alpha`
    # Arguments forwarded to timm `create_transform` for the training split.
    "TIMM_AUG": {
      "USE_LOADER": False,  # NOTE(review): not read by build_transforms
      "RE_COUNT": 1,
      "RE_MODE": "pixel",
      "RE_SPLIT": False,    # NOTE(review): not read by build_transforms
      "RE_PROB": 0.25,
      "AUTO_AUGMENT": "rand-m9-mstd0.5-inc1",
      "HFLIP": 0.5,
      "VFLIP": 0.0,
      "COLOR_JITTER": 0.4,
      "INTERPOLATION": "bicubic"
    }
  },
  "LOSS": {
    "LABEL_SMOOTHING": 0.1
  },
  "CUDNN": {
    "BENCHMARK": True,
    "DETERMINISTIC": False,
    "ENABLED": True
  },
  "DATASET": {
    "DATASET": "cifar-10",   # NOTE(review): not read; build_dataset always loads torchvision CIFAR10
    "DATA_FORMAT": "jpg",    # NOTE(review): not read
    "ROOT": "./cifar-10",    # torchvision download/cache directory
    "TEST_SET": "val",       # NOTE(review): not read
    "TRAIN_SET": "train"     # NOTE(review): not read
  },
  "TEST": {
    "BATCH_SIZE_PER_GPU": 256,
    "IMAGE_SIZE": [32, 32],
    "MODEL_FILE": "",        # NOTE(review): not read
    "INTERPOLATION": "bicubic"
  },
  "TRAIN": {
    "BATCH_SIZE_PER_GPU": 512,
    "GRADIENT_ACCUMULATION_STEPS": 4,  # optimizer steps once every N batches
    "LR": 0.05,                        # base LR: stem LR and the LeRaC warmup target
    "IMAGE_SIZE": [32, 32],
    "BEGIN_EPOCH": 0,
    "END_EPOCH": 100,
    "LR_CURRICULUM": {
        "MIN_LR": 5e-5,      # floor for the per-block starting LRs of the curriculum
        "WARMUP_EPOCHS": 5,  # LeRaC warmup duration in epochs (this is the value get_scheduler reads)
        "C": 10.0            # per-epoch exponential growth factor during LeRaC warmup
    },
    "CLF_LR_MULTIPLIER": 0.01,  # classifier starting LR = LR * CLF_LR_MULTIPLIER
    "LR_SCHEDULER": {
      "METHOD": "lerac",     # NOTE(review): not read
      "ARGS": {
        "sched": "cosine",   # NOTE(review): not read
        "warmup_epochs": 5,  # NOTE(review): not read; warmup length comes from LR_CURRICULUM.WARMUP_EPOCHS
        "warmup_lr": 1e-5,   # NOTE(review): not read
        "min_lr": 1e-6,      # Minimum LR for cosine phase
        "cooldown_epochs": 0,  # NOTE(review): not read
        "decay_rate": 0.1      # NOTE(review): not read
      }
    },
    "OPTIMIZER": "adamW",    # NOTE(review): informational; the optimizer builder constructs AdamW directly
    "WD": 0.05,
    "WITHOUT_WD_LIST": ["bn", "bias", "ln"],  # parameter kinds exempt from weight decay (see set_wd)
    "SHUFFLE": True
  },
  "DEBUG": {
    "DEBUG": False           # enables per-parameter prints in set_wd
  }
}

# Create output directory
import os
os.makedirs(config['OUTPUT_DIR'], exist_ok=True)

# Apply the CUDNN section of the config to the cuDNN backend flags.
torch.backends.cudnn.enabled = config['CUDNN']['ENABLED']
torch.backends.cudnn.benchmark = config['CUDNN']['BENCHMARK']
torch.backends.cudnn.deterministic = config['CUDNN']['DETERMINISTIC']
print(f"✓ Config loaded for CIFAR-10 ResNet-18 with LeRaC")





### Defining transforms

In [None]:
# CIFAR-10 per-channel normalisation statistics, shared by the train and eval pipelines.
CIFAR10_MEAN = (0.491, 0.482, 0.446)
CIFAR10_STD = (0.247, 0.243, 0.261)
# Native CIFAR resolution; eval images are only resized when TEST.IMAGE_SIZE differs from it.
CIFAR_NATIVE_SIZE = 32


def build_transforms(config, is_train):
    """Build the transform pipeline for one CIFAR split.

    Args:
        config: experiment config dict. The training branch reads ``TRAIN.IMAGE_SIZE`` and
            ``AUG.TIMM_AUG``; the evaluation branch reads ``TEST.IMAGE_SIZE`` and (only when
            resizing is needed) ``TEST.INTERPOLATION``.
        is_train: True for the augmenting timm training pipeline, False for the
            deterministic evaluation pipeline.

    Returns:
        A callable mapping a PIL image to a normalised tensor.
    """
    if is_train:
        img_size = config['TRAIN']['IMAGE_SIZE'][0]
        timm_cfg = config['AUG']['TIMM_AUG']
        return create_transform(
            input_size=img_size,
            is_training=True,
            use_prefetcher=False,
            no_aug=False,
            re_prob=timm_cfg['RE_PROB'],
            re_mode=timm_cfg['RE_MODE'],
            re_count=timm_cfg['RE_COUNT'],
            scale=(0.8, 1.0),
            ratio=(3.0/4.0, 4.0/3.0),
            hflip=timm_cfg['HFLIP'],
            vflip=timm_cfg['VFLIP'],
            color_jitter=timm_cfg['COLOR_JITTER'],
            auto_augment=timm_cfg['AUTO_AUGMENT'],
            interpolation=timm_cfg['INTERPOLATION'],
            mean=CIFAR10_MEAN,
            std=CIFAR10_STD,
        )

    img_size = config['TEST']['IMAGE_SIZE'][0]
    steps = []
    if img_size != CIFAR_NATIVE_SIZE:
        # Honour TEST.IMAGE_SIZE; at the native size no resize is applied at all.
        interpolation = T.InterpolationMode(config['TEST'].get('INTERPOLATION', 'bilinear'))
        steps.append(T.Resize((img_size, img_size), interpolation=interpolation))
    steps.append(T.ToTensor())
    steps.append(T.Normalize(mean=list(CIFAR10_MEAN), std=list(CIFAR10_STD)))
    return T.Compose(steps)

### Building Datasets

In [None]:


def build_dataset(config, is_train):
    '''
    Return the torchvision CIFAR-10 split selected by ``is_train``.

    The data is downloaded into ``config['DATASET']['ROOT']`` when missing and is
    wrapped with the pipeline returned by ``build_transforms(config, is_train)``.
    '''
    split_transforms = build_transforms(config, is_train)
    cifar_root = config['DATASET']['ROOT']
    dataset = datasets.CIFAR10(
        root=cifar_root,
        train=is_train,
        download=True,
        transform=split_transforms,
    )
    logging.info(f'load samples: {len(dataset)}, is_train: {is_train}')
    return dataset

### Building Dataloader

In [None]:
def build_dataloader(config, is_train):
    """Wrap the CIFAR split from ``build_dataset`` in a ``torch.utils.data.DataLoader``.

    Args:
        config: experiment config. Reads ``TRAIN``/``TEST`` ``BATCH_SIZE_PER_GPU``,
            ``TRAIN.SHUFFLE`` (default True) and ``WORKERS``.
        is_train: training loaders shuffle according to ``TRAIN.SHUFFLE`` and drop the
            last incomplete batch. Evaluation loaders keep sample order and every sample.

    Returns:
        The configured ``DataLoader``.
    """
    split_cfg = config['TRAIN'] if is_train else config['TEST']
    batch_size_per_gpu = split_cfg['BATCH_SIZE_PER_GPU']
    shuffle = bool(config['TRAIN'].get('SHUFFLE', True)) if is_train else False
    num_workers = config['WORKERS']

    dataset = build_dataset(config, is_train)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size_per_gpu,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned host memory only helps host->GPU copies.
        pin_memory=torch.cuda.is_available(),
        drop_last=is_train,
        # Keep worker processes alive between epochs instead of re-spawning them each epoch.
        persistent_workers=num_workers > 0,
    )
# Build the training and evaluation loaders once; the cells below reuse them.
train_loader = build_dataloader(config, is_train=True)
val_loader = build_dataloader(config, is_train=False)

print(f"\nDataLoader Info:")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Train dataset size: {len(train_loader.dataset)}")
print(f"Val dataset size: {len(val_loader.dataset)}")
print(f"Number of classes: {len(train_loader.dataset.classes)}")

# Test loading a batch: pull one training batch to sanity-check tensor shapes.
images, labels = next(iter(train_loader))
print(f"\nBatch shapes:")
print(f"Images: {images.shape}")
print(f"Labels: {labels.shape}")

In [None]:
# ResNet-18 Model Definition
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers plus a shortcut connection.

    The shortcut is the identity, or a strided 1x1 conv + BatchNorm projection when the
    stride is not 1 or the channel count changes.

    Args:
        in_planes: input channel count.
        planes: output channel count (times ``expansion``, which is 1 here).
        stride: stride of the first conv and of the projection shortcut.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.planes = planes
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_in, ReLU) convs; BatchNorm scale 1 and shift 0."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + self.shortcut(x)
        return F.relu(out)


class ResNet(nn.Module):
    """ResNet with a stride-1 stem (no max-pool), four residual stages, global average
    pooling and a linear classifier.

    Args:
        block: residual block class. It must expose an integer ``expansion`` attribute and
            accept ``(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: output dimension of the linear classifier.
        first_kernel_size: stem conv kernel size. Padding is ``first_kernel_size // 2``, so
            odd kernels keep the input resolution.
    """
    def __init__(self, block, num_blocks, num_classes, first_kernel_size=3):
        super(ResNet, self).__init__()
        self.in_planes = 64

        # "Same" padding for odd kernels; a fixed padding=1 is only correct for 3x3 stems.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=first_kernel_size, stride=1,
                               padding=first_kernel_size // 2, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal convs, unit/zero BatchNorm affine, N(0, 0.01) linear weights with zero bias."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first uses ``stride``. Updates ``self.in_planes``."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Average over the full remaining spatial extent, whether square or not.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)


def ResNet18(num_classes, first_kernel_size=3):
    """Build ResNet-18: four stages of two ``BasicBlock``s each, with a ``num_classes``-way head."""
    stage_depths = [2] * 4
    return ResNet(
        BasicBlock,
        stage_depths,
        num_classes=num_classes,
        first_kernel_size=first_kernel_size,
    )


print("ResNet-18 model defined")

In [None]:
# Create ResNet-18 model for CIFAR-10. The class count comes from the same config key
# that sizes the Mixup one-hot targets, so the two cannot drift apart.
model = ResNet18(num_classes=config['MODEL']['SPEC']['NUM_CLASSES'], first_kernel_size=3)
print("ResNet-18 model created for CIFAR-10")

In [None]:
def count_parameters(model):
    """Print a summary of ``model``'s parameter counts.

    Returns:
        ``(total_params, trainable_params)``, the element counts over all parameters and
        over those with ``requires_grad``.
    """
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        n_elems = param.numel()
        total_params += n_elems
        if param.requires_grad:
            trainable_params += n_elems

    rule = "=" * 60
    print(rule)
    print("Model Parameter Count")
    print(rule)
    print(f"Total parameters: {total_params:,} ({total_params/1e6:.2f}M)")
    print(f"Trainable parameters: {trainable_params:,} ({trainable_params/1e6:.2f}M)")
    print(rule)

    return total_params, trainable_params

# Count parameters of the ResNet-18 built above (prints a summary table).
total_params, trainable_params = count_parameters(model)

## Set up the optimizer

In [None]:
def _is_depthwise(m):
    return (
        isinstance(m, nn.Conv2d)
        and m.groups == m.in_channels
        and m.groups == m.out_channels
    )

def set_wd(cfg, model):
    """Separate parameters by weight decay"""
    without_decay_list = cfg["TRAIN"]["WITHOUT_WD_LIST"]
    without_decay_depthwise = []
    without_decay_norm = []

    for m in model.modules():
        if _is_depthwise(m) and 'dw' in without_decay_list:
            without_decay_depthwise.append(m.weight)
        elif isinstance(m, nn.BatchNorm2d) and 'bn' in without_decay_list:
            without_decay_norm.append(m.weight)
            without_decay_norm.append(m.bias)
        elif isinstance(m, nn.GroupNorm) and 'gn' in without_decay_list:
            without_decay_norm.append(m.weight)
            without_decay_norm.append(m.bias)
        elif isinstance(m, nn.LayerNorm) and 'ln' in without_decay_list:
            without_decay_norm.append(m.weight)
            without_decay_norm.append(m.bias)

    with_decay = []
    without_decay = []

    skip = {}
    if hasattr(model, 'no_weight_decay'):
        skip = model.no_weight_decay()

    skip_keys = {}
    if hasattr(model, 'no_weight_decay_keywords'):
        skip_keys = model.no_weight_decay_keywords()

    for n, p in model.named_parameters():
        ever_set = False

        if p.requires_grad is False:
            continue

        skip_flag = False
        if n in skip:
            print('=> set {} wd to 0'.format(n))
            without_decay.append(p)
            skip_flag = True
        else:
            for i in skip:
                if i in n:
                    print('=> set {} wd to 0'.format(n))
                    without_decay.append(p)
                    skip_flag = True

        if skip_flag:
            continue

        for i in skip_keys:
            if i in n:
                print('=> set {} wd to 0'.format(n))

        if skip_flag:
            continue

        for pp in without_decay_depthwise:
            if p is pp:
                if cfg['DEBUG']['DEBUG']:
                    print('=> set depthwise({}) wd to 0'.format(n))
                without_decay.append(p)
                ever_set = True
                break

        for pp in without_decay_norm:
            if p is pp:
                if cfg['DEBUG']['DEBUG']:
                    print('=> set norm({}) wd to 0'.format(n))
                without_decay.append(p)
                ever_set = True
                break

        if (
            (not ever_set)
            and 'bias' in without_decay_list
            and n.endswith('.bias')
        ):
            if cfg['DEBUG']['DEBUG']:
                print('=> set bias({}) wd to 0'.format(n))
            without_decay.append(p)
        elif not ever_set:
            with_decay.append(p)

    params = [
        {'params': with_decay},
        {'params': without_decay, 'weight_decay': 0.}
    ]
    return params


def _lerac_block_lrs(base_lr, min_lr, layers, blocks):
    """Map each residual block key (``'layerX.Y'``) to its starting LeRaC learning rate.

    The first block starts at ``0.5 * base_lr`` and each following block (in
    ``layers`` x ``blocks`` order) halves it. Values are floored at ``min_lr``.
    """
    dictionary_lr = {}
    start_lr = base_lr * 0.5  # Start at 50% of base (instead of 10%)
    for layer in layers:
        for block in blocks:
            dictionary_lr[f"{layer}.{block}"] = max(start_lr, min_lr)
            if start_lr >= min_lr:
                start_lr *= 0.5  # Decay by 50% each step (instead of 10%)
    return dictionary_lr


def _lerac_param_lr(name, base_lr, clf_lr, block_lrs):
    """Starting learning rate for the parameter called ``name``.

    Stem (``conv1``/``bn1``/``maxpool``) parameters get ``base_lr``. Names containing
    ``'linear'`` get ``clf_lr``. Residual-block parameters get their block's entry in
    ``block_lrs``. Anything else falls back to ``base_lr``.
    """
    if name.startswith(('conv1', 'bn1', 'maxpool')):
        return base_lr
    if 'linear' in name:
        return clf_lr
    for key, lr in block_lrs.items():
        # Require the trailing '.' so that e.g. 'layer1.1' can never claim 'layer1.10.*'.
        if name.startswith(key + '.'):
            return lr
    return base_lr


def build_optimizer_resnet18_lerac(model, config):
    """
    Build optimizer with LeRaC (Learning Rate Curriculum) for ResNet-18

    Every trainable parameter gets its own AdamW param group whose ``lr`` is its
    curriculum starting LR (see ``_lerac_param_lr``) and whose ``weight_decay`` comes from
    ``set_wd``. The scheduler later grows these LRs towards ``TRAIN.LR``.

    Args:
        model: a ResNet-18 exposing ``conv1``, ``bn1``, ``layer1..layer4`` and ``linear``.
        config: reads ``TRAIN.LR``, ``TRAIN.LR_CURRICULUM.MIN_LR``, ``TRAIN.WD``,
            ``TRAIN.CLF_LR_MULTIPLIER`` and the keys used by ``set_wd``.

    Returns:
        A ``torch.optim.AdamW`` instance.
    """
    base_lr = config['TRAIN']['LR']
    min_lr = config['TRAIN']['LR_CURRICULUM']['MIN_LR']
    weight_decay = config['TRAIN']['WD']
    clf_lr = base_lr * config['TRAIN'].get('CLF_LR_MULTIPLIER', 0.01)

    layers = ['layer1', 'layer2', 'layer3', 'layer4']
    blocks = [str(i) for i in range(2)]  # ResNet-18 has 2 blocks per layer
    dictionary_lr = _lerac_block_lrs(base_lr, min_lr, layers, blocks)

    param_to_name = {id(p): n for n, p in model.named_parameters()}
    new_param_groups = []
    # Parameter names parallel to new_param_groups. AdamW rewrites each group's 'params'
    # tensor into a list in place, so the tensor cannot be looked up by id() afterwards.
    group_names = []
    for group in set_wd(config, model):
        group_params = group['params'] if isinstance(group['params'], list) else [group['params']]
        wd = group.get('weight_decay', weight_decay)
        for param in group_params:
            name = param_to_name.get(id(param), "unknown")
            new_param_groups.append({
                'params': param,
                'lr': _lerac_param_lr(name, base_lr, clf_lr, dictionary_lr),
                'weight_decay': wd
            })
            group_names.append(name)

    optimizer = torch.optim.AdamW(new_param_groups)

    print("=" * 80)
    print(f"Optimizer: AdamW with LeRaC for ResNet-18")
    print(f"Total parameter groups: {len(new_param_groups)}")
    print(f"Base LR: {base_lr:.10f}")
    print(f"Min LR: {min_lr:.10f}")
    print(f"Weight Decay: {weight_decay}")
    print("=" * 80)
    print("Layer-wise Learning Rates (NEW):")
    print("-" * 80)

    # Print the LR schedule to verify
    print(f"conv1/bn1:       {base_lr:.6f}")
    for key, lr in dictionary_lr.items():
        print(f"{key:15s} {lr:.6f}")
    print(f"classifier:      {clf_lr:.6f}")
    print("-" * 80)

    # Show sample from actual params
    import random
    sample_size = min(10, len(new_param_groups))
    sampled_indices = random.sample(range(len(new_param_groups)), sample_size)
    sampled_info = [
        (group_names[idx], new_param_groups[idx]['lr'], new_param_groups[idx]['weight_decay'])
        for idx in sampled_indices
    ]

    print("\nSample Parameters:")
    for name, lr, wd in sorted(sampled_info, key=lambda x: x[1], reverse=True):
        print(f"{name:55s} LR: {lr:.6f}  WD: {wd:.4f}")
    print("=" * 80)

    return optimizer


# Build the optimizer
optimizer = build_optimizer_resnet18_lerac(model, config)


## Set up the learning-rate scheduler

In [None]:
from torch.optim.lr_scheduler import _LRScheduler, SequentialLR, CosineAnnealingLR

class LeRaCScheduler(_LRScheduler):
    """
    LeRaC Warmup Scheduler: Exponential growth during warmup
    Based on Eq.(9): lr_t = init_lr * (c^t) for t in {1..warmup_epochs}

    Each param group's starting LR is read from its ``'_init_lr'`` key (falling back to its
    current ``'lr'``). In warmup epoch ``e`` (0-based) the group LR is
    ``min(_init_lr * c ** (e + 1), base_lr)``. From epoch ``warmup_epochs`` on, every group
    uses ``base_lr``.

    NOTE(review): because t starts at 1, epoch 0 already uses ``_init_lr * c`` rather than
    ``_init_lr`` itself. Confirm this matches the intended curriculum.
    """
    def __init__(self, optimizer, base_lr, warmup_epochs, c=10.0, last_epoch=-1):
        self.base_lr = float(base_lr)
        self.warmup_epochs = max(1, int(warmup_epochs))
        self.c = float(c)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # After warmup: all groups use base_lr
        if self.last_epoch >= self.warmup_epochs:
            return [self.base_lr for _ in self.optimizer.param_groups]

        # Warmup epoch counter t ∈ {1..warmup_epochs}
        t = self.last_epoch + 1

        lrs = []
        for g in self.optimizer.param_groups:
            init = float(g.get('_init_lr', g['lr']))
            # Exponential growth per Eq.(9), clamped to base_lr
            lr_t = init * (self.c ** t)
            lrs.append(min(lr_t, self.base_lr))  # Don't exceed base_lr
        return lrs


def get_scheduler(config, optimizer):
    """
    Build LeRaC scheduler: Exponential warmup → Cosine annealing

    Phase 1 (``WARMUP_EPOCHS`` epochs): ``LeRaCScheduler`` grows each param group from its
    own curriculum LR (recorded as ``'_init_lr'``) towards ``TRAIN.LR``.
    Phase 2 (remaining epochs): cosine annealing of every group from ``TRAIN.LR`` down to
    ``TRAIN.LR_SCHEDULER.ARGS.min_lr``.

    Milestones and ``T_max`` are in epochs, so the returned ``SequentialLR`` is meant to be
    stepped once per epoch.

    Args:
        config: experiment config dict.
        optimizer: optimizer whose param groups hold the curriculum starting LRs in ``'lr'``.

    Returns:
        A ``SequentialLR`` chaining the two phases.
    """
    num_epochs = int(config['TRAIN']['END_EPOCH'])
    # Same lower bound as LeRaCScheduler, so the milestone matches the warmup it runs.
    warmup_epochs = max(1, int(config['TRAIN']['LR_CURRICULUM'].get('WARMUP_EPOCHS', 5)))
    base_lr = float(config['TRAIN']['LR'])
    min_lr = float(config['TRAIN']['LR_SCHEDULER']['ARGS'].get('min_lr', 1e-6))
    c = float(config['TRAIN']['LR_CURRICULUM'].get('C', 10.0))  # Growth factor

    # Store initial LRs for LeRaC warmup
    for g in optimizer.param_groups:
        if '_init_lr' not in g:
            g['_init_lr'] = g['lr']

    # Phase 1: LeRaC exponential warmup
    lerac_scheduler = LeRaCScheduler(
        optimizer,
        base_lr=base_lr,
        warmup_epochs=warmup_epochs,
        c=c,
        last_epoch=-1
    )

    # Phase 2: Cosine annealing after warmup
    cosine_epochs = max(1, num_epochs - warmup_epochs)
    cosine_scheduler = CosineAnnealingLR(
        optimizer,
        T_max=cosine_epochs,
        eta_min=min_lr
    )
    # CosineAnnealingLR takes its per-group starting LRs from 'initial_lr'. Here that key
    # holds each group's reduced curriculum LR, recorded when the warmup scheduler was
    # created. Without this override, SequentialLR's restart of the cosine phase at the
    # milestone would drop every group back to its curriculum LR instead of base_lr,
    # undoing the warmup for the rest of training.
    cosine_scheduler.base_lrs = [base_lr] * len(optimizer.param_groups)

    # Combine: LeRaC warmup → Cosine decay
    scheduler = SequentialLR(
        optimizer,
        schedulers=[lerac_scheduler, cosine_scheduler],
        milestones=[warmup_epochs]
    )

    print("="*80)
    print("LeRaC Scheduler Configuration")
    print("="*80)
    print(f"Phase 1 - LeRaC Warmup:")
    print(f"  Duration: {warmup_epochs} epochs")
    print(f"  Growth factor (c): {c}")
    print(f"  Target LR: {base_lr}")
    print(f"\nPhase 2 - Cosine Annealing:")
    print(f"  Duration: {cosine_epochs} epochs")
    print(f"  Start LR: {base_lr}")
    print(f"  Min LR: {min_lr}")
    print(f"\nTotal epochs: {num_epochs}")
    print("="*80)

    return scheduler

# Build scheduler for CIFAR-10 (the training loop steps it once per epoch, after validation)
scheduler = get_scheduler(config, optimizer)

## Set up the loss criterion

In [None]:
from timm.data.mixup import Mixup

class SoftTargetCrossEntropy(nn.Module):
    """Cross-entropy against soft (probability-vector) targets, averaged over the batch.

    Mixup/cutmix produce such targets in place of integer labels.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x, target):
        log_probs = F.log_softmax(x, dim=-1)
        per_sample = (-target * log_probs).sum(dim=-1)
        return per_sample.mean()


def build_criterion(is_train=True):
    """Return the soft-target training loss, or hard-label ``nn.CrossEntropyLoss`` for evaluation."""
    return SoftTargetCrossEntropy() if is_train else nn.CrossEntropyLoss()

# Dynamically get number of classes from config
num_classes = config['MODEL']['SPEC']['NUM_CLASSES']
# Mixup builds the soft training targets, so LOSS.LABEL_SMOOTHING is applied here;
# SoftTargetCrossEntropy then consumes the smoothed targets as given.
label_smoothing = config['LOSS'].get('LABEL_SMOOTHING', 0.0)

aug = config['AUG']
mixup_fn = Mixup(
    mixup_alpha=aug['MIXUP'],
    cutmix_alpha=aug['MIXCUT'],
    cutmix_minmax=None,
    prob=aug['MIXUP_PROB'],
    label_smoothing=label_smoothing,
    num_classes=num_classes  # Dynamic based on dataset
)

# Both criteria are parameter-free; .to(DEVICE) keeps device handling consistent with the
# rest of the notebook instead of hard-coding CUDA.
criterion = build_criterion().to(DEVICE)
criterion_eval = build_criterion(is_train=False).to(DEVICE)

print(f"✓ Criterion setup complete for {num_classes} classes")
print(f"  Mixup alpha: {aug['MIXUP']}")
print(f"  Cutmix alpha: {aug['MIXCUT']}")
print(f"  Mixup probability: {aug['MIXUP_PROB']}")
print(f"  Label smoothing: {label_smoothing}")

In [None]:
class AverageMeter(object):
    """Track the most recent value and a count-weighted running average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, average, running sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh ``avg``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
def accuracy(output, target, topk=(1,)):
    """Top-k accuracy (percent) of ``output`` scores against integer ``target`` labels.

    ``k`` values larger than ``output.size(1)`` are clamped to it. Returns one 0-dim
    tensor per entry of ``topk``, in the same order.
    """
    num_classes = output.size(1)
    k_max = min(max(topk), num_classes)
    n_samples = target.size(0)

    ranked = output.topk(k_max, dim=1, largest=True, sorted=True).indices.t()
    hits = ranked.eq(target.reshape(1, -1).expand_as(ranked))

    results = []
    for k in topk:
        n_hit = hits[:min(k, k_max)].reshape(-1).float().sum(0)
        results.append(n_hit * 100. / n_samples)
    return results

In [None]:
def train_one_epoch(epoch, model, train_loader, criterion, optimizer, scheduler,
                    config, scaler, mixup_fn):
    """
    Train for one epoch
    Note: scheduler is passed but NOT stepped here - stepping happens in main loop

    Gradients are accumulated over ``config['TRAIN']['GRADIENT_ACCUMULATION_STEPS']``
    batches per optimizer step. A shorter trailing window at the end of the epoch is
    stepped too. Each batch loss is divided by the size of its window, so the applied
    gradient is the window *mean* rather than the sum. The logged loss stays unscaled.

    Args:
        epoch: zero-based epoch index (used for the progress-bar label).
        model, train_loader, optimizer: the usual training objects.
        criterion: loss accepting the (possibly mixup-softened) targets.
        scheduler: unused; kept for call-site compatibility.
        config: reads ``TRAIN.GRADIENT_ACCUMULATION_STEPS`` and ``AMP.ENABLED``.
        scaler: ``GradScaler`` used for backward and the optimizer step.
        mixup_fn: optional callable ``(images, targets) -> (images, targets)``.

    Returns:
        ``(mean_loss, mean_acc)`` over the epoch, sample-weighted. Under mixup, accuracy
        is measured against the arg-max of the soft targets.
    """
    model.train()

    accum_steps = max(1, int(config['TRAIN']['GRADIENT_ACCUMULATION_STEPS']))
    amp_enabled = config['AMP']['ENABLED']
    num_batches = len(train_loader)

    losses = AverageMeter()
    acc_m = AverageMeter()

    # Don't fold gradients left over from an earlier call into the first window.
    optimizer.zero_grad(set_to_none=True)

    batch_bar = tqdm(total=num_batches, dynamic_ncols=True,
                     leave=False, position=0, desc=f'Train Epoch {epoch+1}')

    pending = False  # True while accumulated gradients have not been applied yet
    for batch_idx, (images, targets) in enumerate(train_loader):
        images = images.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        # Apply mixup/cutmix
        if mixup_fn is not None:
            images, targets = mixup_fn(images, targets)

        with torch.cuda.amp.autocast(enabled=amp_enabled):
            outputs = model(images)
            loss = criterion(outputs, targets)

        # Number of batches in the accumulation window this batch belongs to (the last may be short).
        window_start = batch_idx - batch_idx % accum_steps
        window_size = max(1, min(accum_steps, num_batches - window_start))
        scaler.scale(loss / window_size).backward()
        pending = True

        if (batch_idx + 1) % accum_steps == 0 or batch_idx + 1 == num_batches:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            pending = False

        # Metrics (unscaled loss)
        losses.update(loss.item(), images.size(0))

        # Calculate accuracy (handle both hard labels and soft targets from mixup)
        hard_targets = targets if targets.dim() == 1 else targets.argmax(dim=1)
        acc = accuracy(outputs, hard_targets)[0]
        acc_m.update(acc.item(), images.size(0))

        # Update progress bar
        batch_bar.set_postfix(
            loss=f"{losses.avg:.4f}",
            acc=f"{acc_m.avg:.2f}%",
            lr=f"{optimizer.param_groups[0]['lr']:.6f}"
        )
        batch_bar.update()

    # Safety net for loaders that yield more batches than len() reports.
    if pending:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

    batch_bar.close()

    torch.cuda.empty_cache()

    return losses.avg, acc_m.avg

In [None]:
@torch.no_grad()
def validate(model, val_loader, criterion, config):
    """Evaluate ``model`` on ``val_loader`` in full precision (autocast disabled).

    ``config`` is accepted for signature symmetry with the training code and is not read.

    Returns:
        ``(mean_loss, mean_top1_acc)``, weighted by batch size.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    model.eval()
    progress = tqdm(total=len(val_loader), dynamic_ncols=True, position=0, leave=False, desc='Val.', ncols=5)

    for images, targets in val_loader:
        images = images.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=False):
            outputs = model(images)
            loss = criterion(outputs, targets)

        batch_acc = accuracy(outputs, targets)[0].item()
        batch_loss = loss.item()
        batch_n = images.size(0)
        loss_meter.update(batch_loss, batch_n)
        acc_meter.update(batch_acc, batch_n)
        progress.set_postfix(
            acc="{:.02f}% ({:.02f}%)".format(batch_acc, acc_meter.avg),
            loss="{:.04f} ({:.04f})".format(batch_loss, loss_meter.avg))

        progress.update()

    progress.close()
    print(f' * Acc {acc_meter.avg:.3f}')

    return loss_meter.avg, acc_meter.avg

In [None]:
def save_model(model, optimizer, scheduler, epoch, path):
    """Write model, optimizer and scheduler state dicts plus ``epoch`` to ``path``."""
    torch.save(
        {'model_state_dict'         : model.state_dict(),
         'optimizer_state_dict'     : optimizer.state_dict(),
         'scheduler_state_dict'     : scheduler.state_dict(),
         'epoch'                    : epoch},
         path)


def load_model(model, optimizer=None, scheduler=None, path='./checkpoint.pth', map_location='cpu'):
    """Restore a checkpoint written by ``save_model``.

    Args:
        model: module whose state is overwritten in place.
        optimizer: optional optimizer to restore; returned as None when not given.
        scheduler: optional LR scheduler to restore; returned as None when not given.
        path: checkpoint file.
        map_location: device the stored tensors are loaded onto first. The default
            ``'cpu'`` makes GPU-written checkpoints readable on any machine;
            ``load_state_dict`` then copies the values into the existing parameters
            wherever they live.

    Returns:
        ``(model, optimizer, scheduler, epoch)``.
    """
    # torch.load unpickles the file: only load checkpoints you produced yourself.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return model, optimizer, scheduler, checkpoint['epoch']

In [None]:
os.makedirs(config['OUTPUT_DIR'], exist_ok=True)
model = model.to(DEVICE)
# Loss scaling is only needed when AMP autocast is on; follow the same config flag.
scaler = torch.cuda.amp.GradScaler(enabled=config['AMP']['ENABLED'])

In [None]:
# Authenticate with W&B using the WANDB_API_KEY environment variable rather than a hard-coded key.
wandb.login(key=os.environ.get('WANDB_API_KEY')) # API Key is in your wandb account, under settings (wandb.ai/settings)


In [None]:
run = wandb.init(
    # NOTE(review): this run name says "cvt-13" and "no-curriculum-learning", but the config
    # above trains ResNet-18 with LeRaC. Consider renaming so runs are not mislabelled.
    name = "idl-project-cvt-13-baseline-no-curriculum-learning-cifar-10-lerac-2", ## Wandb creates random run names if you skip this field
    reinit = True, ### Allows reinitializing runs when you re-run this cell
    project = "idl-project", ### Project should be created in your wandb account
    config = config ### Wandb Config for your run
)

In [None]:
# Free unreachable Python objects, then release PyTorch's cached CUDA blocks before training.
gc.collect() # These commands help you when you face CUDA OOM error
torch.cuda.empty_cache()

In [None]:
train_losses = []
train_accs = []
val_losses = []
val_accs = []
best_loss = float('inf')  # lowest validation loss seen so far
start_epoch = config['TRAIN']['BEGIN_EPOCH']
end_epoch = config['TRAIN']['END_EPOCH']

print("Starting Training")
print(f"Epochs: {start_epoch} -> {end_epoch}")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Device: {DEVICE}")

for epoch in range(start_epoch, end_epoch):
    print("\nEpoch {}/{}".format(epoch+1, end_epoch))
    train_loss, train_acc = train_one_epoch(epoch, model, train_loader, criterion, optimizer, scheduler,
                    config, scaler, mixup_fn)

    val_loss, val_acc = validate(model, val_loader, criterion_eval, config)
    # The LR schedule is per-epoch: advance it once, after validation.
    scheduler.step()

    # Rewrite best.pth only when validation loss improves. A -1 sentinel combined with
    # min(val_loss, best_loss) would never move and would mark every epoch as "best".
    is_best = val_loss < best_loss
    if is_best:
        best_loss = val_loss
    save_model(model, optimizer, scheduler, epoch, os.path.join(config['OUTPUT_DIR'], f'{epoch}.pth'))
    if is_best:
        save_model(model, optimizer, scheduler, epoch, os.path.join(config['OUTPUT_DIR'], 'best.pth'))

    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    print(f"Epoch {epoch+1} Summary:")
    print(f"  Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    print(f"  Train Acc: {train_acc:.3f}, Val Acc: {val_acc:.3f}")
    metrics = {'train_loss': train_loss, 'val_loss': val_loss, 'train_acc': train_acc, 'val_acc': val_acc, 'epoch': epoch}
    if run is not None:
        run.log(metrics)

In [None]:
# Resume from the best.pth that the training loop writes into OUTPUT_DIR.
model, optimizer, scheduler, epoch = load_model(model, optimizer, scheduler, path=os.path.join(config['OUTPUT_DIR'], 'best.pth'))

In [None]:
train_losses = []
train_accs = []
val_losses = []
val_accs = []
# Seed best_loss with the restored model's own validation loss, so a worse resumed epoch
# cannot overwrite best.pth.
best_loss, _ = validate(model, val_loader, criterion_eval, config)
# The checkpoint stores the last *completed* epoch (saved after that epoch's scheduler
# step), so training resumes with the next one.
start_epoch = epoch + 1
end_epoch = config['TRAIN']['END_EPOCH']

print("Starting Training")
print(f"Epochs: {start_epoch} -> {end_epoch}")
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Device: {DEVICE}")

for epoch in range(start_epoch, end_epoch):
    print("\nEpoch {}/{}".format(epoch+1, end_epoch))
    train_loss, train_acc = train_one_epoch(epoch, model, train_loader, criterion, optimizer, scheduler,
                    config, scaler, mixup_fn)

    val_loss, val_acc = validate(model, val_loader, criterion_eval, config)
    # Advance the restored per-epoch LR schedule, as in the initial training loop;
    # without this the LR stays frozen at its checkpointed value.
    scheduler.step()

    is_best = val_loss < best_loss
    if is_best:
        best_loss = val_loss
    save_model(model, optimizer, scheduler, epoch, os.path.join(config['OUTPUT_DIR'], f'{epoch}.pth'))
    if is_best:
        save_model(model, optimizer, scheduler, epoch, os.path.join(config['OUTPUT_DIR'], 'best.pth'))

    train_losses.append(train_loss)
    train_accs.append(train_acc)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    print(f"Epoch {epoch+1} Summary:")
    print(f"  Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    print(f"  Train Acc: {train_acc:.3f}, Val Acc: {val_acc:.3f}")
    metrics = {'train_loss': train_loss, 'val_loss': val_loss, 'train_acc': train_acc, 'val_acc': val_acc, 'epoch': epoch}
    if run is not None:
        run.log(metrics)

In [None]:
@torch.no_grad()
def test(model, test_loader, config):
    """Evaluate ``model`` on ``test_loader`` and keep per-sample outputs.

    Autocast is used only when ``config['AMP']['ENABLED']`` is set and DEVICE is CUDA.
    Softmax probabilities are computed in float32 even when the logits come out of autocast.

    Returns:
        dict with ``'predictions'`` (top-1 class per sample), ``'targets'``,
        ``'probabilities'`` (one softmax row per sample), ``'loss'`` (mean cross-entropy)
        and ``'top1_acc'`` (percent).
    """
    model.eval()

    use_amp = bool(config['AMP']['ENABLED']) and DEVICE == 'cuda'
    criterion = torch.nn.CrossEntropyLoss()
    losses = AverageMeter()
    top1 = AverageMeter()

    pred_chunks = []
    target_chunks = []
    prob_chunks = []

    for images, targets in tqdm(test_loader, desc='Testing'):
        images = images.to(DEVICE, non_blocking=True)
        targets = targets.to(DEVICE, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, targets)

        probs = torch.softmax(outputs.float(), dim=1)
        _, preds = outputs.topk(1, 1, True, True)

        # Collect whole batches and concatenate once at the end instead of per-row extends.
        pred_chunks.append(preds.cpu().numpy().reshape(-1))
        target_chunks.append(targets.cpu().numpy())
        prob_chunks.append(probs.cpu().numpy())

        acc1 = accuracy(outputs, targets)
        losses.update(loss.item(), images.size(0))
        top1.update(acc1[0].item(), images.size(0))

    def _stack(chunks):
        # An empty loader yields an empty array rather than a concatenate error.
        return np.concatenate(chunks) if chunks else np.array([])

    return {
        'predictions': _stack(pred_chunks),
        'targets': _stack(target_chunks),
        'probabilities': _stack(prob_chunks),
        'loss': losses.avg,
        'top1_acc': top1.avg
    }

# val_loader wraps the CIFAR-10 split built with train=False, so this evaluates the model
# currently in memory on the held-out test images.
test_results = test(model, val_loader, config)

print("Final Test Results")
print(f"Test Loss: {test_results['loss']:.4f}")
print(f"Test Acc@1: {test_results['top1_acc']:.3f}%")
