# Model Training Experiment

This experiment trains either a VGG11 or ResNet18 model on the CIFAR10 or MNIST dataset, and compares the performance of three optimizers:
- **DoWG**
- **NDoWG**
- **Nesterov** (SGD with Nesterov momentum)

For each optimizer we measure and compare training and validation (test-set) accuracy and loss, recorded once per epoch.

In [1]:
import copy

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import matplotlib.pyplot as plt
import numpy as np

from dowg.dowg import DoWG, NDoWG  # Assuming these are implemented in dowg.py


In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Hyperparameters shared by every optimizer run
batch_size, epochs, learning_rate = 64, 10, 0.01

# Experiment selection
model_name = 'VGG11'    # one of: 'VGG11', 'ResNet18'
dataset_name = 'MNIST'  # one of: 'CIFAR10', 'MNIST'


In [None]:
# Data loading
# Both supported datasets have 10 classes; images are fed to the model as
# 3-channel tensors normalised to roughly [-1, 1].
if dataset_name == 'CIFAR10':
    num_classes = 10
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset_class = torchvision.datasets.CIFAR10
elif dataset_name == 'MNIST':
    num_classes = 10
    transform = transforms.Compose([
        # torchvision's VGG11 halves the spatial size with five max-pools;
        # 28x28 MNIST digits would shrink to 0x0 and crash, so upsample to
        # CIFAR's 32x32 first (also keeps ResNet18 inputs comparable).
        transforms.Resize((32, 32)),
        transforms.Grayscale(num_output_channels=3),  # VGG/ResNet expect 3 input channels
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset_class = torchvision.datasets.MNIST
else:
    raise ValueError('Unknown dataset')

trainset = dataset_class(root='./data', train=True, download=True, transform=transform)
testset = dataset_class(root='./data', train=False, download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)


In [None]:
# Model selection
def get_model(model_name, num_classes):
    """Build an untrained torchvision classifier.

    Args:
        model_name: architecture to build, either 'VGG11' or 'ResNet18'.
        num_classes: number of outputs of the final classification layer.

    Returns:
        A freshly constructed ``nn.Module`` (randomly initialised, on CPU).

    Raises:
        ValueError: if *model_name* is not one of the supported names.
    """
    builders = {
        'VGG11': models.vgg11,
        'ResNet18': models.resnet18,
    }
    build = builders.get(model_name)
    if build is None:
        raise ValueError('Unknown model')
    return build(num_classes=num_classes)

# Instantiate the selected architecture, then move its parameters to the chosen device.
model = get_model(model_name, num_classes)
model = model.to(device)


In [None]:
# Training and evaluation functions
def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over *loader* and return ``(mean_loss, accuracy)``.

    If *optimizer* is given, the model is put in training mode and updated
    after every batch. Otherwise the model is put in eval mode and gradients
    are disabled.

    Raises:
        ValueError: if *loader* yields no samples (instead of dividing by zero).
    """
    training = optimizer is not None
    model.train(training)  # model.train(False) is equivalent to model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()
            # The criterion returns a batch mean; re-weight by batch size so the
            # epoch value is a per-sample mean even when the last batch is short.
            loss_sum += loss.item() * inputs.size(0)
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            seen += labels.size(0)
    if seen == 0:
        raise ValueError('data loader yielded no samples')
    return loss_sum / seen, correct / seen


def train_and_evaluate(model, optimizer_class, optimizer_kwargs, epochs=10,
                       train_loader=None, test_loader=None):
    """Train a copy of *model* with one optimizer and record per-epoch metrics.

    The passed-in *model* is deep-copied and never modified. Repeated calls
    (one per optimizer) therefore all start from identical initial weights,
    so differences in the results come from the optimizer alone. Training
    runs on whichever device the model's parameters already live on.

    Args:
        model: template network; its current weights are the starting point.
        optimizer_class: optimizer constructor, called as
            ``optimizer_class(params, **optimizer_kwargs)``.
        optimizer_kwargs: keyword arguments forwarded to *optimizer_class*.
        epochs: number of passes over the training data.
        train_loader: training batches; defaults to the global ``trainloader``.
        test_loader: evaluation batches; defaults to the global ``testloader``.

    Returns:
        ``(train_losses, test_losses, train_accs, test_accs)``, each a list
        with one float per epoch.

    Raises:
        ValueError: if either loader yields no samples.
    """
    if train_loader is None:
        train_loader = trainloader
    if test_loader is None:
        test_loader = testloader

    model = copy.deepcopy(model)  # same starting weights for every optimizer
    run_device = next(model.parameters()).device
    criterion = nn.CrossEntropyLoss()
    optimizer = optimizer_class(model.parameters(), **optimizer_kwargs)

    train_losses, test_losses, train_accs, test_accs = [], [], [], []
    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, run_device, optimizer)
        test_loss, test_acc = _run_epoch(model, test_loader, criterion, run_device)
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)
        print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.4f}")
    return train_losses, test_losses, train_accs, test_accs


In [None]:
# Run experiments with each optimizer
optimizers = {
    'DoWG': (DoWG, {'lr': learning_rate}),
    'NDoWG': (NDoWG, {'lr': learning_rate}),
    'Nesterov': (optim.SGD, {'lr': learning_rate, 'momentum': 0.9, 'nesterov': True}),
}
# Order matches the tuple returned by train_and_evaluate.
metric_keys = ('train_losses', 'test_losses', 'train_accs', 'test_accs')
results = {}
for name, (opt_class, opt_kwargs) in optimizers.items():
    print(f"\nTraining with {name} optimizer...")
    history = train_and_evaluate(model, opt_class, opt_kwargs, epochs)
    results[name] = dict(zip(metric_keys, history))

# Plotting: one panel per recorded metric so train/test loss and accuracy
# can all be compared across optimizers (epochs numbered from 1, as in the log).
panels = [
    ('train_losses', 'Train Loss per Epoch', 'Loss'),
    ('test_losses', 'Test Loss per Epoch', 'Loss'),
    ('train_accs', 'Train Accuracy per Epoch', 'Accuracy'),
    ('test_accs', 'Test Accuracy per Epoch', 'Accuracy'),
]
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
for ax, (key, title, ylabel) in zip(axes.flat, panels):
    for name, metrics in results.items():
        values = metrics[key]
        ax.plot(range(1, len(values) + 1), values, label=name)
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend()
plt.tight_layout()
plt.show()
