In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision as vision
import torchvision.transforms as transforms
import tensorboard

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from IPython.display import display, clear_output
import pandas as pd
import time
import json

from itertools import product
from collections import namedtuple
from collections import OrderedDict
from tensorboard import notebook

In [2]:
# Report the library versions this notebook was run with (useful when
# reproducing results) and list any TensorBoard servers already running.
version_report = [
    ("    PyTorch Version:", torch.__version__),
    ("Torchvision Version:", vision.__version__),
    ("Tensorboard Version:", tensorboard.__version__),
]
for label, version in version_report:
    print(label, version)
print()
print("----------------------------------------")
notebook.list()

    PyTorch Version: 1.10.0
Torchvision Version: 0.11.1
Tensorboard Version: 2.7.0

----------------------------------------
Known TensorBoard instances:
  - port 6006: logdir runs (started 3 days, 1:35:39 ago; pid 26424)


In [3]:
class RunBuilder():
    """Expand a grid of hyperparameter values into one named tuple per combination."""

    @staticmethod
    def get_runs(params):
        """Return every combination of the values in ``params`` as ``Run`` namedtuples.

        ``params`` maps each hyperparameter name to an iterable of candidate values.
        The ``Run`` fields follow the mapping's key order, and combinations are
        produced by ``itertools.product``, so the last key varies fastest.
        """
        Run = namedtuple('Run', params.keys())
        return [Run(*combo) for combo in product(*params.values())]

In [4]:
class RunManager():
    """Collect per-epoch training metrics across a sequence of hyperparameter runs.

    Expected call order: ``begin_run`` once per run; then, for each epoch,
    ``begin_epoch``, ``track_loss`` / ``track_num_correct`` once per batch, and
    ``end_epoch``; finally ``end_run``. Every ``end_epoch`` appends one row to
    ``run_data``, logs scalars and parameter histograms to TensorBoard, and
    re-displays the accumulated results table. ``save`` writes ``run_data`` out
    as CSV and JSON.
    """

    def __init__(self):
        # Per-epoch accumulators (reset by begin_epoch; epoch_count reset by end_run).
        self.epoch_count = 0
        self.epoch_loss = 0
        self.epoch_num_correct = 0
        self.epoch_start_time = None

        # Per-run bookkeeping; run_data accumulates one row per epoch across ALL runs.
        self.run_params = None
        self.run_count = 0
        self.run_data = []
        self.run_start_time = None

        # Objects belonging to the current run, set by begin_run.
        self.network = None
        self.loader = None
        self.tb = None

    def begin_run(self, run, network, loader):
        """Start a new run: reset the run timer and open a TensorBoard writer.

        Args:
            run: hyperparameter namedtuple for this run. Its repr is used as the
                SummaryWriter ``comment`` and its fields are copied into every
                results row.
            network: the model being trained.
            loader: the DataLoader used for training.

        Side effect: draws one batch from ``loader`` to log an image grid and
        the model graph.
        """
        self.run_start_time = time.time()

        self.run_params = run
        self.run_count += 1

        self.network = network
        self.loader = loader
        self.tb = SummaryWriter(comment=f'{run}')

        images, labels = next(iter(self.loader))
        grid = vision.utils.make_grid(images)

        self.tb.add_image('images', grid)
        self.tb.add_graph(self.network, images)

    def end_run(self):
        """Close the current run's TensorBoard writer (if any) and reset the epoch counter."""
        if self.tb is not None:
            self.tb.close()
        self.epoch_count = 0

    def begin_epoch(self):
        """Start timing a new epoch and zero the per-epoch accumulators."""
        self.epoch_start_time = time.time()

        self.epoch_count += 1
        self.epoch_loss = 0
        self.epoch_num_correct = 0

    def end_epoch(self):
        """Finalize the epoch: log metrics to TensorBoard and refresh the results table.

        Loss and accuracy are averaged over ``len(self.loader.dataset)`` samples.
        Parameters whose ``.grad`` is ``None`` (no gradient was computed for them)
        still get a value histogram, but the gradient histogram is skipped because
        there is nothing to log.
        """
        epoch_duration = time.time() - self.epoch_start_time
        run_duration = time.time() - self.run_start_time

        num_samples = len(self.loader.dataset)
        loss = self.epoch_loss / num_samples
        accuracy = self.epoch_num_correct / num_samples

        self.tb.add_scalar('Loss', loss, self.epoch_count)
        self.tb.add_scalar('Accuracy', accuracy, self.epoch_count)

        for name, param in self.network.named_parameters():
            self.tb.add_histogram(name, param, self.epoch_count)
            # Unused parameters have grad None; logging None would raise.
            if param.grad is not None:
                self.tb.add_histogram(f'{name}.grad', param.grad, self.epoch_count)

        results = OrderedDict()
        results["Run"] = self.run_count
        results["Epoch"] = self.epoch_count
        results["Loss"] = loss
        results["Accuracy"] = accuracy
        results["Epoch Duration"] = epoch_duration
        results["Run Duration"] = run_duration
        # Append this run's hyperparameters as extra columns.
        results.update(self.run_params._asdict())
        self.run_data.append(results)
        df = pd.DataFrame.from_dict(self.run_data, orient='columns')

        clear_output(wait=True)
        display(df)

    def track_loss(self, loss, batch_size=None):
        """Add one batch's contribution to the epoch's summed loss.

        The training loop passes ``F.cross_entropy`` output, which is a per-batch
        mean under the default reduction, so it is scaled back up by the batch size.

        Args:
            loss: scalar loss tensor for the batch.
            batch_size: number of samples in this batch. Defaults to
                ``self.loader.batch_size``. That default over-counts the final
                batch when the dataset size is not a multiple of the batch size,
                so pass the real size (e.g. ``len(labels)``) for an exact epoch loss.
        """
        if batch_size is None:
            batch_size = self.loader.batch_size
        self.epoch_loss += loss.item() * batch_size

    def track_num_correct(self, preds, labels):
        """Add the number of correct predictions in this batch to the epoch total."""
        self.epoch_num_correct += self.get_num_correct(preds, labels)

    @torch.no_grad()
    def get_num_correct(self, preds, labels):
        """Count rows of ``preds`` whose argmax along dim 1 equals the matching label."""
        return preds.argmax(dim=1).eq(labels).sum().item()

    def save(self, fileName):
        """Write every collected results row to ``{fileName}.csv`` and ``{fileName}.json``."""
        pd.DataFrame.from_dict(
            self.run_data
            ,orient='columns'
        ).to_csv(f'{fileName}.csv')

        with open(f'{fileName}.json', 'w', encoding='utf-8') as f:
            json.dump(self.run_data, f, ensure_ascii = False, indent=4)

In [5]:
class Network(nn.Module):
    """Small two-conv / three-linear CNN producing 10 class logits.

    Expects single-channel inputs of shape (batch, 1, 28, 28). The ``fc1`` input
    size of 12 * 4 * 4 assumes this: 28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8
    -> pool -> 4.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two 5x5 convolutions (each followed by ReLU + 2x2 max-pool in forward).
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        # Classifier head over the flattened 12-channel 4x4 feature map.
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Return raw (un-softmaxed) logits of shape (batch, 10) for input batch ``t``."""
        for conv in (self.conv1, self.conv2):
            t = F.max_pool2d(F.relu(conv(t)), kernel_size=2, stride=2)

        t = t.flatten(start_dim=1)
        for hidden in (self.fc1, self.fc2):
            t = F.relu(hidden(t))

        return self.out(t)

In [6]:
# Load the FashionMNIST training set, converting each image to a tensor.
# download=True is safe to leave on: torchvision does not re-download when the
# dataset is already present under `root`, and it lets Restart & Run All work
# on a fresh machine.
train_set = vision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=True
    ,transform=transforms.Compose([
        transforms.ToTensor()
    ])
)

In [7]:
number_of_epochs = 5
# Re-seeded at the start of every run so weight initialisation and (with no
# explicit DataLoader generator) shuffling are reproducible, and runs differ
# only in their hyperparameters.
SEED = 42
NUM_WORKERS = 2

# Hyperparameter grid; RunBuilder expands it into one run per combination.
params = OrderedDict(
    lr=[0.1, 0.01, 0.001]
    ,batch_size=[100, 1000, 10000]
    ,shuffle=[True, False]
)

m = RunManager()
for run in RunBuilder.get_runs(params):
    torch.manual_seed(SEED)

    network = Network()
    loader = DataLoader(train_set, batch_size=run.batch_size, shuffle=run.shuffle, num_workers=NUM_WORKERS)
    optimizer = optim.Adam(network.parameters(), lr=run.lr)

    m.begin_run(run, network, loader)
    try:
        for _ in range(number_of_epochs):
            m.begin_epoch()
            for images, labels in loader:
                preds = network(images)
                loss = F.cross_entropy(preds, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # NOTE(review): track_loss scales by loader.batch_size, so the epoch
                # loss is exact only if len(train_set) is a multiple of every
                # batch_size in the grid; verify if the grid changes.
                m.track_loss(loss)
                m.track_num_correct(preds, labels)
            m.end_epoch()
    finally:
        # Close this run's SummaryWriter even if training raises.
        m.end_run()
m.save('results')


Unnamed: 0,Run,Epoch,Loss,Accuracy,Epoch Duration,Run Duration,lr,batch_size,shuffle
0,1,1,2.435442,0.101050,7.533236,9.630650,0.100,100,True
1,1,2,2.309248,0.099000,7.497050,17.171709,0.100,100,True
2,1,3,2.308990,0.098583,7.431345,24.640062,0.100,100,True
3,1,4,2.309949,0.101117,7.715756,32.391824,0.100,100,True
4,1,5,2.309337,0.098850,7.221189,39.647019,0.100,100,True
...,...,...,...,...,...,...,...,...,...
85,18,1,2.293446,0.100033,7.287707,15.399014,0.001,10000,False
86,18,2,2.244566,0.132483,7.150994,22.599019,0.001,10000,False
87,18,3,2.118959,0.366167,7.340566,29.993595,0.001,10000,False
88,18,4,1.840271,0.503367,7.216948,37.264554,0.001,10000,False
