In [1]:
import torch
import torch.optim as optim
import torch.nn.functional as F

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from model import Network
from utils import RunBuilder, RunManager, Epoch

from collections import OrderedDict

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [2]:
train_set = torchvision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

In [3]:
params = OrderedDict(
    lr = [.01, 0.001],
    batch_size = [1000, 2000],
    num_workers = [0, 1]
) 

with RunManager('results') as m:
    
    for run in RunBuilder.get_runs(params):
        network = Network()
        loader = DataLoader(train_set, batch_size=run.batch_size, num_workers=run.num_workers)
        optimizer = optim.Adam(network.parameters(), lr=run.lr)
        
        m.begin_run(run, network,loader)
        
        epochs = (Epoch(loader) for i in range(5))
        for epoch in epochs:
            
            for batch in loader:
                images, labels = batch
                preds = network(images) #pass batch
                loss = F.cross_entropy(preds, labels, reduction='sum')
                
                optimizer.zero_grad() #set the gradients to zero
                loss.backward() #calculate gradients
                optimizer.step() #update weights

                epoch.track_loss(loss)
                epoch.track_num_correct(preds,labels)

            epoch.end()
            m.write_epoch(epoch.to_dict())
        m.end_run()

Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,num_workers
0,1,0,0.608819,0.787883,10.459532,11.168809,0.01,1000,0
1,1,1,0.345524,0.872917,10.868247,22.160529,0.01,1000,0
2,1,2,0.302093,0.888683,9.606172,31.880126,0.01,1000,0
3,1,3,0.271406,0.899083,9.651167,41.647221,0.01,1000,0
4,1,4,0.250365,0.906467,9.620304,51.381214,0.01,1000,0
5,2,0,0.584984,0.7968,7.649908,8.340589,0.01,1000,1
6,2,1,0.338123,0.87405,7.431083,15.906837,0.01,1000,1
7,2,2,0.292193,0.8926,7.723281,23.75554,0.01,1000,1
8,2,3,0.27105,0.89955,8.044301,31.915893,0.01,1000,1
9,2,4,0.248125,0.908283,7.901829,39.944805,0.01,1000,1
