In [11]:
import random

import numpy as np
import torch
from torch import nn

from dataloader import get_dataloader
from resnet import ResNet

# Reproducibility: seed every RNG this notebook uses (Python, NumPy, PyTorch
# CPU and the current CUDA device) and ask cuDNN for deterministic kernels.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)
torch.backends.cudnn.deterministic = True

# Compute device: the first CUDA GPU when one is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Hyperparameters for a run. The keys match the attributes that make() and
# train() read from their ``config`` argument (batch_size, learning_rate,
# weight_decay, momentum, n, iterations).
#
# This must be a plain dict: a trailing comma after the closing brace would
# silently turn it into a 1-tuple wrapping the dict, which is not a mapping.
config = {
    "batch_size": 128,
    "learning_rate": 0.1,
    "iterations": 64000,  # total optimizer steps, not epochs (see train())
    "weight_decay": 0.0001,
    "momentum": 0.9,
    "n": 3,  # forwarded to ResNet(...) in make()
}

In [None]:
# Weights & Biases is the experiment tracker used by train() and
# model_pipeline() below.
import wandb

# Authenticate this session with W&B before any run is started.
wandb.login()

In [None]:
def make(config):
    """Build everything a training run needs from ``config``.

    Args:
        config: Object exposing ``batch_size``, ``n``, ``learning_rate``,
            ``weight_decay`` and ``momentum`` as attributes.

    Returns:
        A 6-tuple ``(model, train_dataloader, val_dataloader,
        test_dataloader, loss_func, optimizer)``.
    """
    batch_size = config.batch_size

    # get_dataloader(True, ...) is unpacked into two loaders (train and
    # validation); get_dataloader(False, ...) is used as a single test loader.
    # Presumably the flag selects the training split; verify in dataloader.py.
    train_split, val_split = get_dataloader(True, batch_size)
    test_split = get_dataloader(False, batch_size)

    # Build the network and move it to the notebook-wide device.
    net = ResNet(config.n)
    net = net.to(device)

    # SGD with momentum and weight decay over all model parameters.
    sgd_options = {
        "lr": config.learning_rate,
        "weight_decay": config.weight_decay,
        "momentum": config.momentum,
    }
    sgd = torch.optim.SGD(net.parameters(), **sgd_options)
    criterion = nn.CrossEntropyLoss()

    return net, train_split, val_split, test_split, criterion, sgd

In [13]:
def validate(model, loader, loss_func):
    """Evaluate ``model`` on every batch produced by ``loader``.

    The model is switched to eval mode and is left in eval mode on return.
    Forward passes run under ``torch.inference_mode()``.

    Args:
        model: Network to evaluate; predictions are the arg-max of its
            outputs over dim 1.
        loader: Iterable yielding ``(images, labels)`` batches.
        loss_func: Callable ``loss_func(outputs, labels)`` returning a scalar
            tensor. Assumed to use mean reduction (the default of
            ``nn.CrossEntropyLoss``), which is why each batch loss is
            re-weighted by the batch size below.

    Returns:
        ``(val_loss, val_accuracy)`` as Python floats: the per-sample average
        loss and the fraction of correctly classified samples.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    with torch.inference_mode():
        # Iterate the loader directly: enumerate() would yield (index, batch)
        # pairs, so the two-name unpacking would bind an int to ``images``.
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            batch_size = labels.size(0)

            # Forward pass
            outputs = model(images)
            # .item() keeps the running total a plain float instead of a
            # device tensor.
            total_loss += loss_func(outputs, labels).item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size

    # Normalise by the samples actually processed rather than
    # len(loader.dataset): the two differ when the loader uses a sampler over
    # a subset of the dataset or drops the last incomplete batch.
    if seen == 0:
        raise ValueError("validate() received a loader that yielded no samples")

    return total_loss / seen, correct / seen

In [None]:
def train(model, train_loader, val_loader, loss_func, optimizer, config):
    """Train ``model`` for ``config.iterations`` optimizer steps, logging to wandb.

    Training proceeds in passes over ``train_loader``. Every 25th step the
    current batch's error and loss are logged. After each pass (including the
    final, possibly partial one) the model is evaluated with ``validate`` on
    ``val_loader`` and the result is logged.

    Args:
        model: Network to optimise; also registered with ``wandb.watch``.
        train_loader: Iterable yielding ``(images, labels)`` training batches.
        val_loader: Loader handed to ``validate`` after each pass.
        loss_func: Callable ``loss_func(outputs, labels)`` returning a scalar
            tensor that supports ``backward()``.
        optimizer: Optimizer stepping ``model``'s parameters.
        config: Object exposing ``iterations`` (total number of steps).

    Raises:
        ValueError: If ``train_loader`` yields no batches while steps remain.
    """
    # Tell wandb to watch what the model gets up to: gradients, weights, and more!
    wandb.watch(model, loss_func, log="all", log_freq=10)

    # Optimizer steps taken so far (not named ``iter`` so the builtin stays usable).
    step = 0
    while step < config.iterations:
        # validate() leaves the model in eval mode, so re-enable training mode
        # at the start of every pass.
        model.train()
        step_at_epoch_start = step

        # Train for one epoch
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            train_loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

            # Report metrics every 25th batch
            if (step + 1) % 25 == 0:
                # Log plain floats so no tensor (or its autograd graph) is
                # handed to the logger.
                accuracy = (outputs.argmax(1) == labels).float().mean().item()
                wandb.log(
                    {
                        "train/error": 1 - accuracy,
                        "train/train_loss": train_loss.item(),
                    }
                )

            step += 1
            if step >= config.iterations:
                break

        # Without this guard an empty loader would never advance ``step`` and
        # the outer while-loop would spin forever.
        if step == step_at_epoch_start:
            raise ValueError("train_loader yielded no batches; cannot make progress")

        # Validate
        val_loss, val_accuracy = validate(model, val_loader, loss_func)
        wandb.log(
            {
                "validation/error": 1 - val_accuracy,
                # NOTE(review): this key carries the *validation* loss; the name
                # is kept unchanged so existing dashboards keep matching.
                "validation/train_loss": val_loss,
            }
        )

In [None]:
def model_pipeline(hyperparameters):
    """Run one tracked experiment end to end: build, train, test.

    Args:
        hyperparameters: Run configuration passed to ``wandb.init`` as
            ``config``; all later reads go through ``wandb.config``.

    Returns:
        The model produced by ``make`` after ``train`` and ``test`` have run.
    """
    # tell wandb to get started
    with wandb.init(project="resnet", config=hyperparameters):
        # access all HPs through wandb.config, so logging matches execution!
        config = wandb.config

        # make the model, data, and optimization problem
        model, train_loader, val_loader, test_loader, loss_func, optimizer = make(
            config
        )

        # and use them to train the model; train() expects the validation
        # loader as its third positional argument, so it must be passed here.
        train(model, train_loader, val_loader, loss_func, optimizer, config)

        # and test its final performance
        # NOTE(review): test() is not defined in the cells visible here;
        # confirm it exists before this cell runs.
        test(model, test_loader)

    return model