In [1]:
import torch
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from model import Network
from utils import RunBuilder, RunManager, Epoch

from collections import OrderedDict

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

%load_ext autoreload
%autoreload 2

In [2]:
train_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
params = OrderedDict(
    lr = [.01, 0.001],
    batch_size = [1000, 2000],
    num_workers = [0, 1]
) 

with RunManager('results', device) as m:
    
    for run in RunBuilder.get_runs(params):
        network = Network()
        network.to(device)
        loader = DataLoader(train_set, batch_size=run.batch_size, num_workers=run.num_workers)
        optimizer = optim.Adam(network.parameters(), lr=run.lr)
        
        m.begin_run(run, network,loader)
        
        epochs = (Epoch(loader) for i in range(5))
        for epoch in epochs:
            
            for batch, labels in loader:
                batch, labels = batch.to(device), labels.to(device)
                preds = network(batch) #pass batch
                loss = F.cross_entropy(preds, labels, reduction='sum')
                
                optimizer.zero_grad() #set the gradients to zero
                loss.backward() #calculate gradients
                optimizer.step() #update weights

                epoch.track_loss(loss)
                epoch.track_num_correct(preds,labels)

            epoch.end()
            m.write_epoch(epoch.to_dict())
        m.end_run()

Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,num_workers
0,1,0,0.563438,0.805167,8.884274,9.478643,0.01,1000,0
1,1,1,0.329687,0.878467,8.95304,18.535555,0.01,1000,0
2,1,2,0.286246,0.894017,9.518364,28.15844,0.01,1000,0
3,1,3,0.265327,0.90145,11.92539,40.18634,0.01,1000,0
4,1,4,0.244339,0.90925,11.090806,51.394298,0.01,1000,0
5,2,0,0.570017,0.800283,8.212492,9.061668,0.01,1000,1
6,2,1,0.337432,0.87565,8.95717,18.193883,0.01,1000,1
7,2,2,0.300355,0.887633,9.40224,27.713147,0.01,1000,1
8,2,3,0.277852,0.897067,8.337512,36.206661,0.01,1000,1
9,2,4,0.259427,0.903783,10.099322,46.460416,0.01,1000,1
