In [5]:
import os
import random
import math

import numpy as np
import torch
from torch import nn
from tqdm.auto import tqdm

from dataloader import get_dataloader
from resnet import ResNet

# Reproducibility: seed Python, NumPy and PyTorch (CPU + current CUDA device)
# with the same value, and ask cuDNN to use deterministic algorithms.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(42)
torch.backends.cudnn.deterministic = True

# Run on the first GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [2]:
# NOTE(review): this import lives outside the notebook's first (imports) cell;
# consider moving it there so all dependencies are visible in one place.
import wandb

# Authenticate with Weights & Biases before the run is created with wandb.init below.
wandb.login()



Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.




[34m[1mwandb[0m: Currently logged in as: [33meddiezhuang[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
# Start a W&B run in the "resnet" project and register this run's hyperparameters.
wandb.init(
    config=dict(
        batch_size=128,       # passed to get_dataloader(...) in the training cell
        learning_rate=0.1,    # SGD `lr`
        num_epochs=7,         # intended number of passes over the training set
        weight_decay=0.0001,  # SGD `weight_decay`
        momentum=0.9,         # SGD `momentum`
        n=3,                  # passed to ResNet(...); presumably a depth setting, confirm in resnet.py
    ),
    project="resnet",
)

# Short alias used by the following cells to read hyperparameters.
config = wandb.config

In [None]:
def validate_model(model, val_dataloader, loss_func, device=None):
    """Compute the average loss and accuracy of ``model`` over a validation set.

    Args:
        model: ``torch.nn.Module`` to evaluate. It is switched to eval mode and
            left in eval mode on return; callers that keep training must call
            ``model.train()`` again.
        val_dataloader: Iterable yielding ``(images, labels)`` batches.
        loss_func: Callable ``loss_func(outputs, labels)``. It is assumed to
            return the *mean* loss of the batch (the default reduction of
            ``nn.CrossEntropyLoss``), so each batch loss is re-weighted by the
            batch size before averaging.
        device: Device the batches are moved to. When ``None`` (default), the
            device of the model's first parameter is used, falling back to the
            CPU for parameter-free models.

    Returns:
        tuple[float, float]: ``(val_loss, accuracy)`` where ``val_loss`` is the
        per-sample average loss and ``accuracy`` is the fraction of samples
        whose arg-max prediction equals the label. Both are ``nan`` when the
        dataloader yields no samples.
    """
    if device is None:
        # Avoid depending on a notebook-level global: follow the model's device.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    total_loss = 0.0
    correct = 0
    num_samples = 0
    with torch.inference_mode():
        for images, labels in val_dataloader:
            images, labels = images.to(device), labels.to(device)
            batch_size = labels.size(0)

            outputs = model(images)
            # Undo the per-batch mean so unevenly sized batches are weighted correctly.
            total_loss += loss_func(outputs, labels).item() * batch_size
            correct += (outputs.argmax(1) == labels).sum().item()
            num_samples += batch_size

    if num_samples == 0:
        # Nothing was evaluated; report "undefined" instead of dividing by zero.
        return float("nan"), float("nan")
    return total_loss / num_samples, correct / num_samples

In [None]:
# Build the data pipeline, model, loss and optimizer from the run's hyperparameters.
# NOTE(review): the meaning of the first positional argument (True) is defined in
# dataloader.py and is not visible here.
train_dataloader, val_dataloader = get_dataloader(True, config.batch_size)
model = ResNet(config.n).to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
    momentum=config.momentum,
)
# The run config defines "num_epochs" (there is no "num_iterations" or "epochs"
# key), so the epoch count is read from it directly.
epochs = config.num_epochs
n_steps_per_epoch = math.ceil(len(train_dataloader.dataset) / config.batch_size)

for epoch in range(epochs):
    # Re-enable training mode every epoch because validation runs at the end of the loop body.
    model.train()

    for step, (images, labels) in enumerate(train_dataloader):
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        train_loss = loss_func(outputs, labels)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        # Batch accuracy is computed from the outputs produced before this optimizer step.
        accuracy = (outputs.argmax(1) == labels).float().mean()
        metrics = {
            # Log plain Python floats rather than tensors still attached to the autograd graph.
            "train/train_loss": train_loss.item(),
            # Fractional epoch counter: completed steps divided by steps per epoch.
            "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch,
            "train/error": 1 - accuracy.item(),
        }
        wandb.log(metrics)

    # NOTE(review): the validation results are not logged in the visible part of this cell.
    val_loss, accuracy = validate_model(model, val_dataloader, loss_func)
