# ResNet-18 with Gaussian Blur Curriculum Learning on CIFAR-10

This notebook implements ResNet-18 with Gaussian blur curriculum learning on CIFAR-10, based on the approach from the CVT_13_blur notebook.

## Curriculum Learning Strategy
- **Early epochs**: Train with Gaussian blur applied to images
- **Later epochs**: Train with original sharp images
- **Blur parameters**: Kernel size 7, sigma 1.0, applied for first 20 epochs

## Model: ResNet-18
- ResNet-18 adapted for 32×32 inputs (3×3 stem convolution, no initial max-pool)
- ~11.2M parameters with the 10-class head
- Curriculum learning with Gaussian blur


## 1. Setup and Imports


In [None]:
# Install the third-party packages this notebook imports via the IPython %pip magic.
# NOTE(review): torch, numpy and matplotlib are imported below but not listed here —
# presumably they are preinstalled in the target environment; confirm.
%pip install torchsummary torchvision tqdm wandb


In [None]:
# Snapshot the currently installed package versions into requirements.txt
# (shell escape; overwrites any existing file of that name in the working directory).
!pip freeze > requirements.txt

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import os
import gc
from tqdm.auto import tqdm
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import json
from datetime import datetime
import wandb
import logging

# Select the compute device once; later cells move the model and batches to DEVICE.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Report the runtime environment so results can be traced back to it.
print(f"Device: {DEVICE}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if DEVICE == 'cuda':
    # Only query the device name when a CUDA device was actually selected.
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")


## 2. Configuration


In [None]:
# Training configuration with blur curriculum.
# Flat keys are optimisation / bookkeeping settings; the upper-case keys hold
# nested sections for the dataset location and the blur schedule.
config = dict(
    batch_size=128,
    num_epochs=100,
    learning_rate=0.01,
    weight_decay=5e-4,
    momentum=0.9,
    num_classes=10,
    image_size=32,
    num_workers=4,
    save_dir='./checkpoints_cifar10_blur',
    use_wandb=True,
    project_name='resnet18-cifar10-blur-curriculum',
    run_name='blur-curriculum-training',
    DATASET=dict(
        ROOT='./data',
    ),
    BLUR=dict(
        KERNEL_SIZE=7,
        SIGMA=1.0,
        EPOCHS=20,  # Number of epochs to use blur
    ),
)

# Echo the configuration; nested sections are printed one level deeper.
print("Configuration:")
for key, value in config.items():
    if not isinstance(value, dict):
        print(f"  {key}: {value}")
        continue
    print(f"  {key}:")
    for sub_key, sub_value in value.items():
        print(f"    {sub_key}: {sub_value}")


## 3. Data Loading and Preprocessing with Blur Curriculum


In [4]:
def build_transforms(config, is_train):
    """Return the torchvision preprocessing pipeline for CIFAR-10.

    Args:
        config: Configuration dict. Not read by this function; accepted so the
            signature matches the other builders.
        is_train: When true, random crop (32 px, 4 px padding) and random
            horizontal flip are applied before tensor conversion.

    Returns:
        A ``transforms.Compose`` ending in ``ToTensor`` and per-channel
        ``Normalize`` with fixed mean/std constants.
    """
    steps = []
    if is_train:
        # Augmentation is applied to training samples only.
        steps.append(transforms.RandomCrop(32, padding=4))
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    )
    return transforms.Compose(steps)

def build_gaussian_transforms(config):
    """Return the training pipeline with a Gaussian blur step inserted.

    The blur is applied after ``ToTensor`` and before ``Normalize``; its kernel
    size and sigma are read from ``config['BLUR']``.

    Args:
        config: Configuration dict providing ``BLUR.KERNEL_SIZE`` and
            ``BLUR.SIGMA``.

    Returns:
        A ``transforms.Compose``: random crop, random horizontal flip, tensor
        conversion, Gaussian blur, normalisation.
    """
    blur_cfg = config['BLUR']
    blur_step = transforms.GaussianBlur(
        kernel_size=blur_cfg['KERNEL_SIZE'],
        sigma=blur_cfg['SIGMA'],
    )
    steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        blur_step,
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    return transforms.Compose(steps)

def build_dataset(config, is_train):
    """Create the CIFAR-10 split selected by ``is_train``.

    Args:
        config: Configuration dict; ``config['DATASET']['ROOT']`` is the
            directory the data is stored in (downloaded there if absent).
        is_train: Selects the training split (and training transforms) when
            true, otherwise the test split with evaluation transforms.

    Returns:
        The ``datasets.CIFAR10`` instance with its transform attached.
    """
    # Named `pipeline` so the imported `transforms` module is not shadowed.
    pipeline = build_transforms(config, is_train)
    data_root = config['DATASET']['ROOT']
    dataset = datasets.CIFAR10(
        root=data_root,
        train=is_train,
        download=True,
        transform=pipeline,
    )
    logging.info(f'load samples: {len(dataset)}, is_train: {is_train}')
    return dataset

def build_gaussian_dataset(config):
    """Create the CIFAR-10 training split with the blur pipeline attached.

    Args:
        config: Configuration dict; supplies the dataset root and, via
            ``build_gaussian_transforms``, the blur parameters.

    Returns:
        A ``datasets.CIFAR10`` training dataset whose samples pass through
        the Gaussian-blur transform pipeline.
    """
    # Named `blur_pipeline` so the imported `transforms` module is not shadowed.
    blur_pipeline = build_gaussian_transforms(config)
    data_root = config['DATASET']['ROOT']
    dataset = datasets.CIFAR10(
        root=data_root,
        train=True,
        download=True,
        transform=blur_pipeline,
    )
    logging.info(f'load samples: {len(dataset)}, is_train: True (blur)')
    return dataset


In [None]:
# Build the three datasets: sharp training data, validation data, and the
# blurred copy of the training split used during the curriculum phase.
print("Loading CIFAR-10 dataset...")
train_dataset = build_dataset(config, is_train=True)
val_dataset = build_dataset(config, is_train=False)
gaussian_dataset = build_gaussian_dataset(config)

# Settings common to every loader.
shared_loader_args = dict(
    batch_size=config['batch_size'],
    num_workers=config['num_workers'],
    pin_memory=True,
)

# Both training loaders shuffle and drop the final partial batch; the
# validation loader keeps every sample in dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **shared_loader_args)
val_loader = DataLoader(val_dataset, shuffle=False, **shared_loader_args)
gaussian_loader = DataLoader(gaussian_dataset, shuffle=True, drop_last=True, **shared_loader_args)

# Dataset and loader sizes.
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Gaussian blur samples: {len(gaussian_dataset)}")
print(f"Number of classes: {len(train_dataset.classes)}")
print(f"Classes: {train_dataset.classes}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")
print(f"Gaussian blur batches: {len(gaussian_loader)}")

# Pull one batch from each training loader as a smoke test.
print("\nTesting data loaders:")
train_images, train_labels = next(iter(train_loader))
gaussian_images, gaussian_labels = next(iter(gaussian_loader))
print(f"Train batch shape: {train_images.shape}")
print(f"Gaussian batch shape: {gaussian_images.shape}")
print(f"Blur curriculum epochs: {config['BLUR']['EPOCHS']}")


## 4. ResNet-18 Model Definition


In [None]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    The input is added back to the second batch-norm output before the final
    ReLU. When the stride or channel count changes, the skip path is a 1x1
    convolution plus batch norm; otherwise it is an empty ``nn.Sequential``
    (identity).
    """

    # Channel multiplier applied to `planes` for the block output.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        """Build the block.

        Args:
            in_planes: Number of input channels.
            planes: Number of channels produced by both convolutions.
            stride: Stride of the first convolution (and of the projection
                shortcut, when one is needed).
        """
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # A projection is needed whenever the skip tensor would not match the
        # main-path output in spatial size or channel count.
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += self.shortcut(x)
        return F.relu(out)


class ResNet(nn.Module):
    """ResNet backbone with a 3x3 stem, four residual stages and a linear head.

    The stem keeps the input resolution (stride 1, no pooling); stages 2-4
    each start with a stride-2 block. Features are globally average-pooled
    before the classifier.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        """Build the network.

        Args:
            block: Block class (or factory) called as
                ``block(in_planes, planes, stride)`` and exposing ``expansion``.
            num_blocks: Sequence with the number of blocks for each of the
                four stages.
            num_classes: Size of the classifier output.
        """
        super().__init__()
        # Running channel count; `_make_layer` advances it after every block.
        self.in_planes = 64

        # Stem sized for small (32x32) inputs.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Residual stages.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        # Pooling to 1x1 followed by the classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``."""
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            # The next block consumes this block's output channels.
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = self.avgpool(out)
        return self.fc(torch.flatten(pooled, 1))


def ResNet18(num_classes=10):
    """Return a ``ResNet`` built from ``BasicBlock`` with two blocks per stage."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(BasicBlock, stage_depths, num_classes=num_classes)

# Smoke test: instantiate the network once and report its parameter count.
model = ResNet18(num_classes=config['num_classes'])
total_params = sum(map(torch.numel, model.parameters()))
print(f"ResNet-18 created with {total_params:,} parameters")
print(f"Model size: {total_params/1e6:.2f}M parameters")


## 5. Training and Evaluation Functions


In [None]:
class AverageMeter:
    """Tracks the latest value and the running weighted mean of a metric.

    Attributes are ``val`` (last value passed to ``update``), ``sum``
    (weighted total), ``count`` (total weight) and ``avg`` (``sum / count``).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed over ``n`` samples and refresh the mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: Score tensor; the k highest entries along dimension 1 are
            taken as the predictions for each sample.
        target: Tensor of ground-truth class indices, one per sample.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per requested k, holding the
        percentage of samples whose target is among the k highest scores.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Indices of the highest scores, transposed so row r holds every
        # sample's (r+1)-th ranked prediction.
        ranked = output.topk(largest_k, dim=1, largest=True, sorted=True).indices.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        percent_per_hit = 100.0 / n_samples
        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
            for k in topk
        ]


def train_epoch_with_curriculum(model, train_loader, gaussian_loader, criterion, optimizer, device, epoch, config):
    """Run one training epoch, drawing batches from the loader the curriculum selects.

    Epochs below ``config['BLUR']['EPOCHS']`` (zero-based) train on
    ``gaussian_loader``; all later epochs train on ``train_loader``.

    Args:
        model: Network to optimise; switched to training mode.
        train_loader: Loader yielding sharp ``(data, target)`` batches.
        gaussian_loader: Loader yielding blurred ``(data, target)`` batches.
        criterion: Loss called as ``criterion(output, target)``.
        optimizer: Optimiser stepped once per batch.
        device: Device the batches are moved to.
        epoch: Zero-based epoch index.
        config: Configuration dict providing ``BLUR.EPOCHS``.

    Returns:
        Tuple ``(mean loss, mean top-1, mean top-5)`` weighted by batch size;
        the accuracy means have whatever type ``accuracy`` produces.
    """
    model.train()

    loss_meter, top1_meter, top5_meter = AverageMeter(), AverageMeter(), AverageMeter()

    # Curriculum switch: blurred data first, sharp data afterwards.
    use_blur = epoch < config['BLUR']['EPOCHS']
    curriculum_type = "Blur" if use_blur else "Sharp"
    active_loader = gaussian_loader if use_blur else train_loader

    progress = tqdm(active_loader, desc=f'Epoch {epoch+1} ({curriculum_type})')

    for data, target in progress:
        data, target = data.to(device), target.to(device)
        batch_n = data.size(0)

        # Standard optimisation step.
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Fold this batch into the running statistics.
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        loss_meter.update(loss.item(), batch_n)
        top1_meter.update(acc1[0], batch_n)
        top5_meter.update(acc5[0], batch_n)

        progress.set_postfix({
            'Loss': f'{loss_meter.avg:.4f}',
            'Acc@1': f'{top1_meter.avg:.2f}%',
            'Acc@5': f'{top5_meter.avg:.2f}%',
            'Type': curriculum_type
        })

    return loss_meter.avg, top1_meter.avg, top5_meter.avg


def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` over ``val_loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode.
        val_loader: Loader yielding ``(data, target)`` batches.
        criterion: Loss called as ``criterion(output, target)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean loss, mean top-1, mean top-5)`` weighted by batch size;
        the accuracy means have whatever type ``accuracy`` produces.
    """
    model.eval()

    loss_meter, top1_meter, top5_meter = AverageMeter(), AverageMeter(), AverageMeter()

    with torch.no_grad():
        progress = tqdm(val_loader, desc='Validation')
        for data, target in progress:
            data, target = data.to(device), target.to(device)
            batch_n = data.size(0)

            output = model(data)
            loss = criterion(output, target)

            # Fold this batch into the running statistics.
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            loss_meter.update(loss.item(), batch_n)
            top1_meter.update(acc1[0], batch_n)
            top5_meter.update(acc5[0], batch_n)

            progress.set_postfix({
                'Loss': f'{loss_meter.avg:.4f}',
                'Acc@1': f'{top1_meter.avg:.2f}%',
                'Acc@5': f'{top5_meter.avg:.2f}%'
            })

    return loss_meter.avg, top1_meter.avg, top5_meter.avg


def adjust_learning_rate(optimizer, epoch, initial_lr, lr_decay_epochs=(60, 120, 160)):
    """Set the step-decayed learning rate for ``epoch`` on every param group.

    The rate starts at ``initial_lr`` and is multiplied by 0.1 once for each
    milestone in ``lr_decay_epochs`` that ``epoch`` has reached. The result
    overwrites the ``'lr'`` entry of every group in ``optimizer.param_groups``.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts).
        epoch: Zero-based epoch index.
        initial_lr: Learning rate before any decay.
        lr_decay_epochs: Iterable of milestone epochs. The default is an
            immutable tuple so it cannot be shared and mutated across calls.

    Returns:
        The learning rate that was written to the optimizer.
    """
    lr = initial_lr
    for decay_epoch in lr_decay_epochs:
        if epoch >= decay_epoch:
            lr *= 0.1

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return lr

# Cell-level confirmation that the helper definitions above were executed.
print("Training and evaluation functions defined.")


## 6. Initialize Model and Optimizer


In [None]:
# Make sure the checkpoint directory exists before anything is written to it.
os.makedirs(config['save_dir'], exist_ok=True)
print(f"Checkpoint directory created: {config['save_dir']}")

# Model on the selected device, cross-entropy loss, SGD with momentum and
# weight decay, and a StepLR scheduler (x0.1 every 60 scheduler steps).
model = ResNet18(num_classes=config['num_classes'])
model = model.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=config['learning_rate'],
    momentum=config['momentum'],
    weight_decay=config['weight_decay'],
)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=60, gamma=0.1)

# Summarise the optimisation setup.
print(f"Model initialized on {DEVICE}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Initial learning rate: {config['learning_rate']}")
print(f"Optimizer: SGD with momentum {config['momentum']}")
print(f"Weight decay: {config['weight_decay']}")
print(f"Blur curriculum: {config['BLUR']['EPOCHS']} epochs with blur, then sharp images")


## 7. Initialize Weights & Biases (Optional)


In [None]:
# Authenticate with Weights & Biases only when tracking is enabled in the
# configuration, so the notebook can run without W&B credentials.
# The API key is read from the WANDB_API_KEY environment variable (find it
# under wandb.ai/settings); os.environ.get yields None when it is unset.
if config.get('use_wandb', True):
    wandb.login(key=os.environ.get('WANDB_API_KEY'))


In [None]:
# Create the wandb run when tracking is enabled; otherwise leave `run` as
# None, which the later `if run is not None` checks treat as "logging off".
# NOTE(review): config['project_name'] and config['run_name'] are defined but
# not used here; the literals below are kept so existing runs stay grouped.
run = None
if config.get('use_wandb', True):
    run = wandb.init(
        name="ResNet18 CIFAR10 Blur Curriculum",  # wandb generates a random name if omitted
        reinit=True,  # allows re-initialising the run when this cell is re-executed
        # To resume a previous run instead, pass that run's id together with
        # resume="must" and remove reinit=True.
        project="cifar10-blur-curriculum",  # project in your wandb account
        config=config,  # hyper-parameters recorded with the run
    )


## 8. Training Loop with Blur Curriculum


In [None]:
# Check for an existing checkpoint and resume from it if found.
last_path = os.path.join(config['save_dir'], 'last.pth')
best_path = os.path.join(config['save_dir'], 'best_model.pth')

start_epoch = 0
best_acc1 = 0.0
best_epoch = 0

if os.path.exists(last_path):
    print(f"Found checkpoint at {last_path}. Resuming...")
    try:
        ckpt = torch.load(last_path, map_location=DEVICE)
        model.load_state_dict(ckpt['model_state_dict'])
        optimizer.load_state_dict(ckpt['optimizer_state_dict'])
        scheduler.load_state_dict(ckpt['scheduler_state_dict'])
        # Checkpoints store the zero-based index of the epoch just finished.
        start_epoch = int(ckpt.get('epoch', 0)) + 1
        best_acc1 = float(ckpt.get('best_acc1', 0.0))

        # Reconcile with best_model.pth: it records the accuracy and the
        # (zero-based) epoch of the best model, whereas last.pth may carry a
        # best_acc1 from before its own epoch was counted and has no best
        # epoch at all. Loaded on CPU since only scalars are read from it.
        if os.path.exists(best_path):
            best_ckpt = torch.load(best_path, map_location='cpu')
            best_ckpt_acc1 = float(best_ckpt.get('best_acc1', 0.0))
            if best_ckpt_acc1 >= best_acc1:
                best_acc1 = best_ckpt_acc1
                best_epoch = int(best_ckpt.get('epoch', -1)) + 1
            del best_ckpt

        print(f"Resumed from epoch {start_epoch} with best_acc1={best_acc1:.2f}%")
    except Exception as e:
        # Best-effort resume: any failure falls back to training from scratch.
        # NOTE(review): state loaded before the failure (e.g. model weights)
        # is not rolled back.
        print(f"Error loading checkpoint: {e}")
        print("Starting fresh training...")
        start_epoch = 0
        best_acc1 = 0.0
        best_epoch = 0
else:
    print("No existing checkpoint. Starting fresh.")

# Per-epoch training history. It is rebuilt empty on every run, so after a
# resume it only covers the epochs trained in this session.
history = {
    'train_loss': [],
    'train_acc1': [],
    'train_acc5': [],
    'val_loss': [],
    'val_acc1': [],
    'val_acc5': [],
    'learning_rate': [],
    'curriculum_type': []  # Track whether blur or sharp was used
}

print(f"Starting training for {config['num_epochs']} epochs...")
print(f"Blur curriculum: First {config['BLUR']['EPOCHS']} epochs with blur, then sharp images")
print(f"Training on {len(train_dataset)} samples")
print(f"Validation on {len(val_dataset)} samples")
print("="*60)


In [None]:
for epoch in range(start_epoch, config['num_epochs']):
    print(f"\nEpoch {epoch+1}/{config['num_epochs']}")

    # Blurred data for the first BLUR.EPOCHS epochs, sharp data afterwards.
    curriculum_type = "Blur" if epoch < config['BLUR']['EPOCHS'] else "Sharp"
    print(f"Curriculum: {curriculum_type} images")

    # adjust_learning_rate overwrites the optimizer's LR before every epoch,
    # so it - not the StepLR scheduler stepped below - decides the LR used
    # for training. scheduler.step() is kept so the scheduler state stored in
    # checkpoints keeps its existing format.
    current_lr = adjust_learning_rate(optimizer, epoch, config['learning_rate'])

    # Train and validate. Metrics are converted to plain floats right away so
    # that no device tensors end up in checkpoints, history or wandb logs.
    train_loss, train_acc1, train_acc5 = (
        float(metric) for metric in train_epoch_with_curriculum(
            model, train_loader, gaussian_loader, criterion, optimizer, DEVICE, epoch, config
        )
    )
    val_loss, val_acc1, val_acc5 = (
        float(metric) for metric in validate(model, val_loader, criterion, DEVICE)
    )

    scheduler.step()

    # Update the best-so-far bookkeeping BEFORE writing any checkpoint, so
    # last.pth never records a best_acc1 that ignores the current epoch.
    is_best = val_acc1 > best_acc1
    if is_best:
        best_acc1 = val_acc1
        best_epoch = epoch + 1

    checkpoint_state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'best_acc1': best_acc1,
        'best_epoch': best_epoch,
        'val_acc1': val_acc1,
        'val_acc5': val_acc5,
    }

    # 'last' checkpoint every epoch; 'best' only when validation improved.
    torch.save(checkpoint_state, last_path)
    if is_best:
        torch.save(checkpoint_state, best_path)
        print(f"New best model saved! Acc@1: {best_acc1:.2f}%")

    # Record this epoch's metrics (already plain Python values).
    history['train_loss'].append(train_loss)
    history['train_acc1'].append(train_acc1)
    history['train_acc5'].append(train_acc5)
    history['val_loss'].append(val_loss)
    history['val_acc1'].append(val_acc1)
    history['val_acc5'].append(val_acc5)
    history['learning_rate'].append(float(current_lr))
    history['curriculum_type'].append(curriculum_type)

    # Log to wandb when a run is active.
    if run is not None:
        run.log({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'train_acc1': train_acc1,
            'train_acc5': train_acc5,
            'val_loss': val_loss,
            'val_acc1': val_acc1,
            'val_acc5': val_acc5,
            'learning_rate': current_lr,
            'curriculum_type': curriculum_type
        })

    # Print epoch results.
    print(f"Train Loss: {train_loss:.4f} | Train Acc@1: {train_acc1:.2f}% | Train Acc@5: {train_acc5:.2f}%")
    print(f"Val Loss: {val_loss:.4f} | Val Acc@1: {val_acc1:.2f}% | Val Acc@5: {val_acc5:.2f}%")
    print(f"Learning Rate: {current_lr:.6f}")
    print(f"Curriculum: {curriculum_type}")
    print("-" * 60)

print(f"\nTraining completed!")
print(f"Best validation accuracy: {best_acc1:.2f}% (Epoch {best_epoch})")


## 9. Save Final Results


In [None]:
# Bundle the end-of-training state (weights, optimiser, scheduler, best
# metrics, full history and the configuration) into a single checkpoint.
final_model_path = os.path.join(config['save_dir'], 'final_model.pth')
final_checkpoint = dict(
    epoch=config['num_epochs'],
    model_state_dict=model.state_dict(),
    optimizer_state_dict=optimizer.state_dict(),
    scheduler_state_dict=scheduler.state_dict(),
    best_acc1=best_acc1,
    best_epoch=best_epoch,
    history=history,
    config=config,
)

torch.save(final_checkpoint, final_model_path)

# Save training history as JSON (convert any remaining tensors to floats)
def convert_to_serializable(obj):
    """Recursively convert tensors/arrays inside ``obj`` to plain Python values.

    Dicts are rebuilt with converted values, lists and tuples become lists of
    converted items, and any object exposing ``tolist()`` (PyTorch tensors,
    NumPy arrays and NumPy scalars) is replaced by its ``tolist()`` result.
    Everything else is returned unchanged.

    ``tolist`` is checked before ``item`` because ``item()`` raises for
    tensors/arrays holding more than one element, while ``tolist()`` handles
    both the scalar and the multi-element case.

    Args:
        obj: Arbitrarily nested structure of dicts, lists, tuples and leaves.

    Returns:
        A structure of the same shape containing only values that
        ``json.dump`` can serialise, provided the leaves themselves are
        JSON-compatible.
    """
    if isinstance(obj, dict):
        return {key: convert_to_serializable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_to_serializable(item) for item in obj]
    if hasattr(obj, 'tolist'):  # tensors, arrays, NumPy scalars
        return obj.tolist()
    if hasattr(obj, 'item'):  # scalar-like objects exposing only item()
        return obj.item()
    return obj

# Strip any tensor/array values out of the history, then write it as JSON
# next to the checkpoints.
serializable_history = convert_to_serializable(history)
history_json_path = os.path.join(config['save_dir'], 'training_history.json')

with open(history_json_path, 'w') as f:
    json.dump(serializable_history, f, indent=2)

# Tell the user where each artefact lives.
print(f"Final results saved to {config['save_dir']}")
print(f"Best model: {os.path.join(config['save_dir'], 'best_model.pth')}")
print(f"Final model: {os.path.join(config['save_dir'], 'final_model.pth')}")
print(f"Training history: {os.path.join(config['save_dir'], 'training_history.json')}")


## 10. Visualization


In [None]:
def plot_training_history(history, save_path=None):
    """Draw a 2x3 grid summarising training, then show (and optionally save) it.

    Panels: loss, top-1 accuracy, top-5 accuracy, learning rate (log scale),
    curriculum phase per epoch, and validation top-1 accuracy split by phase.
    The blur-to-sharp boundary is read from the module-level ``config`` and
    drawn as a dashed vertical line on the first four panels.

    Args:
        history: Dict of equally long per-epoch lists keyed by
            ``train_loss``, ``val_loss``, ``train_acc1``, ``val_acc1``,
            ``train_acc5``, ``val_acc5``, ``learning_rate`` and
            ``curriculum_type`` (values ``'Blur'`` / ``'Sharp'``).
        save_path: When truthy, the figure is also written to this path.
    """
    fig, axes = plt.subplots(2, 3, figsize=(18, 10))

    # Epochs are plotted one-based.
    epochs = range(1, len(history['train_loss']) + 1)
    transition_x = config['BLUR']['EPOCHS']

    def mark_transition(ax):
        # Dashed marker at the configured blur -> sharp boundary.
        ax.axvline(x=transition_x, color='g', linestyle='--', alpha=0.7, label='Blur→Sharp Transition')

    # Top row: train/validation curve pairs that share one layout.
    paired_panels = (
        (axes[0, 0], 'train_loss', 'val_loss', 'Train Loss', 'Val Loss',
         'Loss', 'Training and Validation Loss'),
        (axes[0, 1], 'train_acc1', 'val_acc1', 'Train Acc@1', 'Val Acc@1',
         'Accuracy (%)', 'Training and Validation Accuracy (Top-1)'),
        (axes[0, 2], 'train_acc5', 'val_acc5', 'Train Acc@5', 'Val Acc@5',
         'Accuracy (%)', 'Training and Validation Accuracy (Top-5)'),
    )
    for ax, train_key, val_key, train_label, val_label, y_label, title in paired_panels:
        ax.plot(epochs, history[train_key], 'b-', label=train_label, linewidth=2)
        ax.plot(epochs, history[val_key], 'r-', label=val_label, linewidth=2)
        mark_transition(ax)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Learning-rate schedule on a log axis.
    lr_ax = axes[1, 0]
    lr_ax.plot(epochs, history['learning_rate'], 'g-', linewidth=2)
    mark_transition(lr_ax)
    lr_ax.set_xlabel('Epoch')
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_title('Learning Rate Schedule')
    lr_ax.set_yscale('log')
    lr_ax.legend()
    lr_ax.grid(True, alpha=0.3)

    # Curriculum phase per epoch: 1 = blur (red), 0 = sharp (blue).
    phases = history['curriculum_type']
    blur_flags = [phase == 'Blur' for phase in phases]
    phase_ax = axes[1, 1]
    phase_ax.scatter(
        epochs,
        [1 if flag else 0 for flag in blur_flags],
        c=['red' if flag else 'blue' for flag in blur_flags],
        alpha=0.7,
        s=20,
    )
    phase_ax.set_xlabel('Epoch')
    phase_ax.set_ylabel('Curriculum Type')
    phase_ax.set_title('Curriculum Learning Schedule')
    phase_ax.set_yticks([0, 1])
    phase_ax.set_yticklabels(['Sharp', 'Blur'])
    phase_ax.grid(True, alpha=0.3)

    # Validation top-1 accuracy, drawn separately for each curriculum phase.
    blur_epochs = [n for n, phase in enumerate(phases, start=1) if phase == 'Blur']
    sharp_epochs = [n for n, phase in enumerate(phases, start=1) if phase == 'Sharp']
    blur_acc = [history['val_acc1'][n - 1] for n in blur_epochs]
    sharp_acc = [history['val_acc1'][n - 1] for n in sharp_epochs]

    split_ax = axes[1, 2]
    if blur_epochs:
        split_ax.plot(blur_epochs, blur_acc, 'ro-', label='Blur Curriculum', linewidth=2, markersize=4)
    if sharp_epochs:
        split_ax.plot(sharp_epochs, sharp_acc, 'bo-', label='Sharp Images', linewidth=2, markersize=4)
    split_ax.set_xlabel('Epoch')
    split_ax.set_ylabel('Validation Accuracy (%)')
    split_ax.set_title('Validation Accuracy by Curriculum Phase')
    split_ax.legend()
    split_ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Plot saved to {save_path}")

    plt.show()

# Render the training curves and store the figure next to the checkpoints.
training_plot_path = os.path.join(config['save_dir'], 'training_plots.png')
plot_training_history(history, save_path=training_plot_path)


## 11. Final Results Summary


In [None]:
# Final report: run settings first, then the headline metrics, then model size.
heavy_rule = "="*70
light_rule = "-"*70

print(heavy_rule)
print("FINAL TRAINING RESULTS - BLUR CURRICULUM")
print(heavy_rule)

# Experiment settings.
print(f"Dataset: CIFAR-10")
print(f"Model: ResNet-18")
print(f"Training Strategy: Gaussian Blur Curriculum Learning")
print(f"Blur Curriculum: First {config['BLUR']['EPOCHS']} epochs with blur, then sharp images")
print(f"Blur Parameters: Kernel size {config['BLUR']['KERNEL_SIZE']}, Sigma {config['BLUR']['SIGMA']}")
print(f"Total Epochs: {config['num_epochs']}")
print(f"Batch Size: {config['batch_size']}")
print(f"Initial Learning Rate: {config['learning_rate']}")
print(f"Weight Decay: {config['weight_decay']}")
print(f"Momentum: {config['momentum']}")
print(light_rule)

# Best and last-epoch metrics.
print(f"Best Validation Accuracy: {best_acc1:.2f}% (Epoch {best_epoch})")
print(f"Final Training Accuracy: {history['train_acc1'][-1]:.2f}%")
print(f"Final Validation Accuracy: {history['val_acc1'][-1]:.2f}%")
print(f"Final Training Top-5 Accuracy: {history['train_acc5'][-1]:.2f}%")
print(f"Final Validation Top-5 Accuracy: {history['val_acc5'][-1]:.2f}%")
print(light_rule)

# Model size.
print(f"Model Parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model Size: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
print(heavy_rule)

# Finish the wandb run, if one was started.
if run is not None:
    run.finish()
    print("Wandb run finished.")


## 12. Load and Test Best Model (Optional)


In [None]:
# Reload the best checkpoint into a fresh model and re-evaluate it.
best_model_path = os.path.join(config['save_dir'], 'best_model.pth')
if os.path.exists(best_model_path):
    print("Loading best model for final evaluation...")

    # map_location keeps this consistent with the resume cell and lets a
    # checkpoint written on a GPU machine load on a CPU-only one.
    checkpoint = torch.load(best_model_path, map_location=DEVICE)

    # Create model and load weights
    best_model = ResNet18(num_classes=config['num_classes']).to(DEVICE)
    best_model.load_state_dict(checkpoint['model_state_dict'])

    # Evaluate on validation set
    print("Evaluating best model on validation set...")
    val_loss, val_acc1, val_acc5 = validate(best_model, val_loader, criterion, DEVICE)

    print(f"Best Model Results:")
    print(f"  Validation Loss: {val_loss:.4f}")
    print(f"  Validation Acc@1: {val_acc1:.2f}%")
    print(f"  Validation Acc@5: {val_acc5:.2f}%")
    # The checkpoint stores the zero-based epoch index; report it one-based
    # to match the "Epoch N" numbering printed during training.
    print(f"  Best epoch: {checkpoint['epoch'] + 1}")
else:
    print("Best model checkpoint not found.")


## 13. Cleanup and Summary


In [None]:
# Release Python garbage first, then ask PyTorch to free cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

print("Training completed successfully!")
print(f"\nCheckpoint directory: {config['save_dir']}")
print("Files created:")
# List regular files in the checkpoint directory with their sizes.
for file in os.listdir(config['save_dir']):
    file_path = os.path.join(config['save_dir'], file)
    if not os.path.isfile(file_path):
        continue
    size = os.path.getsize(file_path) / (1024*1024)  # Size in MB
    print(f"  - {file} ({size:.1f} MB)")

# Closing notes on how this run relates to the baseline.
print("\nThis blur curriculum implementation can be compared with the baseline ResNet-18.")
print("Key differences:")
print(f"- First {config['BLUR']['EPOCHS']} epochs use Gaussian blur (kernel={config['BLUR']['KERNEL_SIZE']}, sigma={config['BLUR']['SIGMA']})")
print("- Remaining epochs use sharp images")
print("- Same model architecture and hyperparameters as baseline")
