# Model Training and Hyperparameter Sweep

This notebook implements a CIFAR-10 training pipeline that fine-tunes a pretrained ResNet18 (transfer learning) and uses Weights & Biases (W&B) Sweeps for hyperparameter optimization.


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import wandb
import os
import copy
from pathlib import Path

# Pick the compute device: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Define Sweep Configuration
sweep_config = {
    'method': 'bayes',  # Bayesian optimization
    'metric': {
        # Name of the logged metric the sweep optimizes.
        'name': 'val_acc',
        'goal': 'maximize'
    },
    'parameters': {
        'learning_rate': {
            # The range spans three orders of magnitude, so sample on a log
            # scale. With the default uniform distribution roughly 90% of the
            # trials would fall above 0.01 and the small-learning-rate region
            # would be almost never explored.
            'distribution': 'log_uniform_values',
            'min': 0.0001,
            'max': 0.1
        },
        'batch_size': {
            'values': [32, 64, 128]
        },
        'optimizer': {
            'values': ['adam', 'sgd']
        },
        'epochs': {
            'value': 5  # Keep it small for demonstration, increase for real results
        }
    }
}

PROJECT_NAME = "cifar10_mlops_project"
# Register the sweep with W&B; the returned id is what the agent attaches to.
sweep_id = wandb.sweep(sweep_config, project=PROJECT_NAME)

In [None]:
def build_dataset(batch_size):
    """Create the CIFAR-10 training and test data loaders.

    The training pipeline applies a random 32-pixel crop (with 4 pixels of
    padding) and a random horizontal flip before converting to a tensor and
    normalizing. The test pipeline only converts and normalizes. Both
    datasets are rooted at ``../data/raw`` and are requested with
    ``download=True``.

    Args:
        batch_size: Batch size used for both loaders.

    Returns:
        A ``(trainloader, testloader)`` tuple. Only the training loader
        shuffles; both use two worker processes.
    """
    data_root = '../data/raw'
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    def tensor_and_normalize():
        # Fresh transform instances for each pipeline.
        return [
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]

    # Augmentation is applied to the training split only.
    augmentations = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    transform_train = transforms.Compose(augmentations + tensor_and_normalize())
    transform_test = transforms.Compose(tensor_and_normalize())

    trainset = torchvision.datasets.CIFAR10(
        root=data_root,
        train=True,
        download=True,
        transform=transform_train,
    )
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )

    testset = torchvision.datasets.CIFAR10(
        root=data_root,
        train=False,
        download=True,
        transform=transform_test,
    )
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
    )

    return trainloader, testloader

def build_model():
    """Build a ResNet18 with pretrained weights and a 10-class output layer.

    The final fully connected layer is replaced by a freshly initialized
    ``nn.Linear`` with 10 outputs (CIFAR-10); all other layers keep their
    pretrained weights and remain trainable.

    Returns:
        The model, moved to the module-level ``device``.
    """
    # Newer torchvision releases deprecate ``pretrained=True`` in favour of
    # the ``weights=`` argument. Use it when available and fall back to the
    # legacy flag on older installations.
    weights_enum = getattr(torchvision.models, "ResNet18_Weights", None)
    if weights_enum is not None:
        model = torchvision.models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = torchvision.models.resnet18(pretrained=True)

    # NOTE: to train only the new head, set ``requires_grad = False`` on every
    # parameter here, before the final layer is replaced below.

    # Replace last layer for CIFAR-10 (10 classes)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, 10)

    return model.to(device)

def _build_optimizer(name, parameters, learning_rate):
    """Return the optimizer selected by ``name`` (``'sgd'`` or ``'adam'``)."""
    if name == "sgd":
        return optim.SGD(parameters, lr=learning_rate, momentum=0.9)
    if name == "adam":
        return optim.Adam(parameters, lr=learning_rate)
    # Fail loudly instead of leaving the optimizer undefined.
    raise ValueError(
        f"Unsupported optimizer {name!r}; expected 'sgd' or 'adam'"
    )


def _evaluate_accuracy(model, loader):
    """Return the model's accuracy over ``loader`` as a percentage."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


def train(config=None):
    """Run one training experiment inside a W&B run.

    Builds the data loaders and model, trains for ``config.epochs`` epochs,
    and after each epoch evaluates on the loader returned as ``testloader``.
    Logs ``batch_loss`` per batch and ``epoch`` / ``loss`` / ``val_acc`` per
    epoch. Whenever ``val_acc`` improves, the model's ``state_dict`` is saved
    under ``../models`` and logged as a W&B model artifact.

    Args:
        config: Optional configuration forwarded to ``wandb.init``. The
            values actually used are read back from ``wandb.config``
            (``batch_size``, ``optimizer``, ``learning_rate``, ``epochs``).

    Raises:
        ValueError: If ``config.optimizer`` is neither ``'sgd'`` nor ``'adam'``.
    """
    # Initialize a new wandb run
    with wandb.init(config=config):
        config = wandb.config

        trainloader, testloader = build_dataset(config.batch_size)
        model = build_model()

        criterion = nn.CrossEntropyLoss()
        optimizer = _build_optimizer(
            config.optimizer, model.parameters(), config.learning_rate
        )

        # Create the checkpoint directory once, not on every improvement.
        Path("../models").mkdir(parents=True, exist_ok=True)
        model_path = f"../models/model_best_{wandb.run.id}.pth"

        best_acc = 0.0

        for epoch in range(config.epochs):
            model.train()
            running_loss = 0.0

            for inputs, labels in trainloader:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # Read the scalar once and reuse it for both the running
                # total and the log entry.
                batch_loss = loss.item()
                running_loss += batch_loss

                # Log batch metrics
                wandb.log({"batch_loss": batch_loss})

            # Validation
            val_acc = _evaluate_accuracy(model, testloader)
            epoch_loss = running_loss / len(trainloader)

            # Log epoch metrics
            wandb.log({"epoch": epoch, "loss": epoch_loss, "val_acc": val_acc})
            print(f"Epoch {epoch}: Loss {epoch_loss:.3f}, Val Acc {val_acc:.2f}%")

            # Save best model to W&B
            if val_acc > best_acc:
                best_acc = val_acc

                # Save locally (overwrites this run's previous best checkpoint)
                torch.save(model.state_dict(), model_path)

                # Log as artifact
                artifact = wandb.Artifact(f"model-best-{wandb.run.id}", type="model")
                artifact.add_file(model_path)
                wandb.log_artifact(artifact)
                print(f"New best model saved with acc: {best_acc}")

In [None]:
# Launch the sweep agent for the sweep registered above.
# `function` is the callable executed for each run; `count=5` limits the
# agent to 5 experiments.
wandb.agent(sweep_id, function=train, count=5)