In [None]:
import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from model import WideModel
from tqdm.auto import tqdm

In [None]:
# load dataset
# Root directory where torchvision caches/downloads the MNIST files.
DATA_ROOT = './data'

# Mean/std of MNIST training pixels after ToTensor() scales them to [0, 1];
# these are the constants used in the official PyTorch MNIST example.
# (The previous Normalize((0,), (1,)) was an identity transform despite its
# comment claiming MNIST statistics.)
MNIST_MEAN = (0.1307,)
MNIST_STD = (0.3081,)

# ToTensor(): PIL image with values in [0, 255] -> float tensor in [0, 1].
# Normalize(): (x - mean) / std per channel (MNIST has a single channel).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

mnist_trainset = datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)


In [None]:
# create dataloaders
batch_size = 64
# Shuffle only the training data. Evaluation order does not affect the model,
# and a fixed order keeps the reported test loss reproducible: the test loss
# below is averaged per batch, so shuffling would change which samples end up
# in the final (smaller) batch and perturb the reported number between runs.
train_dataloader = DataLoader(mnist_trainset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(mnist_testset, batch_size=batch_size, shuffle=False)

In [None]:
# create model
# Seed torch's global RNG before building the model. The training DataLoader
# (no explicit generator) draws its shuffle order from this RNG, and any random
# weight init inside WideModel presumably does too (TODO confirm), so
# Restart & Run All reproduces the same run.
SEED = 0
torch.manual_seed(SEED)

model = WideModel(hidden_dim_scale=20)

# create optimizer: plain SGD (no momentum / weight decay) over all model parameters.
# model.parameters() can be passed directly; no need to copy it into a list.
lr = 1e-3
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

In [None]:
# train model
# Each epoch: one optimization pass over the training set, then an evaluation
# pass over the test set that prints the average per-batch test loss.
num_epochs = 10
for epoch in range(num_epochs):
    print(f"Training for epoch {epoch}")
    # Put the model in training mode. This only matters if WideModel contains
    # dropout/batch-norm layers; assumes WideModel is an nn.Module (it exposes
    # .parameters()) — TODO confirm.
    model.train()
    for x, y in tqdm(train_dataloader):
        loss = model.loss(x, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Testing for epoch {epoch}")
    # Inference mode for any dropout/batch-norm layers; no_grad() wraps the
    # whole pass so no autograd graph is built for any test batch.
    model.eval()
    num_test_batches = 0
    total_loss = 0.0
    with torch.no_grad():
        for x, y in tqdm(test_dataloader):
            loss = model.loss(x, y)
            # loss is a scalar (loss.backward() above is called without a
            # gradient argument), so .item() yields a plain Python float.
            total_loss += loss.item()
            num_test_batches += 1

    # Mean of the per-batch losses.
    # NOTE(review): if model.loss averages over the batch, the final, smaller
    # batch is weighted the same as full batches, so this is not exactly the
    # per-sample mean — confirm model.loss's reduction before changing.
    test_loss = total_loss / num_test_batches
    print(f"Epoch {epoch} Test Loss: {test_loss}")