<a href="https://colab.research.google.com/github/AnhVietPham/Deep-Learning/blob/main/DL-Pytorch/training-loop-run-builder-run-manager/Training_Loop_Run_Builder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from collections import OrderedDict
from collections import namedtuple
from itertools import product

In [None]:
class RunBuilder():
  """Expand a hyperparameter grid into a list of runs."""

  @staticmethod
  def get_runs(params):
    """Return one ``Run`` namedtuple per combination of hyperparameter values.

    Args:
      params: mapping of hyperparameter name -> iterable of candidate values.
        Keys become the ``Run`` field names, so they must be valid Python
        identifiers.

    Returns:
      list of ``Run`` namedtuples, one per element of the Cartesian product
      of ``params.values()``, in ``itertools.product`` order.
    """
    Run = namedtuple('Run', params.keys())

    runs = []
    for v in product(*params.values()):
      # No debug print here: building the grid should be silent.
      runs.append(Run(*v))
    return runs

In [None]:
# Hyperparameter grid: name -> candidate values (insertion order is kept).
params = OrderedDict([
    ("lr", [0.01, 0.001, 0.0001]),
    ("batch_size", [1000, 10000]),
    ("device", ["cuda", "cpu"]),
])

In [None]:
# Show the candidate values for each hyperparameter, in insertion order.
params.values()

odict_values([[0.01, 0.001, 0.0001], [1000, 10000], ['cuda', 'cpu']])

In [None]:
# Show the hyperparameter names; RunBuilder uses these as the Run namedtuple's fields.
params.keys()

odict_keys(['lr', 'batch_size', 'device'])

In [None]:
# Print every hyperparameter combination (the Cartesian product of the value lists).
for combination in product(*params.values()):
    print(" ".join(map(str, combination)))

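0.01 1000 cuda
0.01 1000 cpu
0.01 10000 cuda
0.01 10000 cpu
0.001 1000 cuda
0.001 1000 cpu
0.001 10000 cuda
0.001 10000 cpu
0.0001 1000 cuda
0.0001 1000 cpu
0.0001 10000 cuda
0.0001 10000 cpu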


# **A Cartesian Product**

In [1]:
X = {1, 2, 3}
Y = {1, 2, 3}
# Every ordered pair (x, y) with x taken from X and y taken from Y.
set((x, y) for x in X for y in Y)

{(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3), (3, 1), (3, 2), (3, 3)}

In [2]:
X = {1, 2, 3}
Y = {1, 2, 3}

# Same Cartesian product as above, built up one x at a time.
cartesian_product = set()
for x in X:
    cartesian_product.update((x, y) for y in Y)
cartesian_product

{(1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3), (3, 1), (3, 2), (3, 3)}

# **CNN Training Loop Refactoring - Simultaneous Hyperparameter Testing**

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from IPython.display import display, clear_output
import pandas as pd
import time
import json

from itertools import product
from collections import namedtuple
from collections import OrderedDict

In [17]:
class Network(nn.Module):
    """Small CNN: two conv -> ReLU -> max-pool stages followed by three linear layers.

    ``fc1`` expects a 12x4x4 feature map, which is what a 1x28x28 input produces:
    28 -(conv 5x5)-> 24 -(pool 2)-> 12 -(conv 5x5)-> 8 -(pool 2)-> 4.
    The output is 10 raw logits per sample.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept (conv1, conv2, fc1, fc2, out) so parameter
        # names and initialisation order stay the same.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 12, 5)

        self.fc1 = nn.Linear(12 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 60)
        self.out = nn.Linear(60, 10)

    def forward(self, t):
        # Feature extractor: conv -> ReLU -> 2x2/stride-2 max-pool, twice.
        for conv in (self.conv1, self.conv2):
            t = F.max_pool2d(F.relu(conv(t)), kernel_size=2, stride=2)

        # Keep the batch dimension and flatten the rest for the dense layers.
        t = t.flatten(start_dim=1)
        for fc in (self.fc1, self.fc2):
            t = F.relu(fc(t))

        # No softmax: the training loop uses F.cross_entropy on raw logits.
        return self.out(t)

In [5]:
class RunBuilder():
    """Turn a hyperparameter grid into a flat list of runs."""

    @staticmethod
    def get_runs(params):
        """Return a list with one ``Run`` namedtuple per value combination.

        ``params`` maps hyperparameter name -> iterable of candidate values.
        The names become ``Run`` fields, and the combinations follow
        ``itertools.product`` order over ``params.values()``.
        """
        Run = namedtuple('Run', params.keys())
        return [Run(*combo) for combo in product(*params.values())]

In [31]:
class RunManager():
  """Record per-epoch metrics for a series of hyperparameter runs.

  Expected call sequence per run:
    begin_run -> (begin_epoch -> track_loss / track_num_correct per batch
    -> end_epoch) for each epoch -> end_run.
  Each end_epoch appends one row to ``run_data``, logs to TensorBoard and
  re-displays the results table. ``save`` writes every row to
  ``<fileName>.csv`` and ``<fileName>.json``.
  """

  def __init__(self):
    # Per-epoch accumulators (reset by begin_epoch).
    self.epoch_count = 0
    self.epoch_loss = 0
    self.epoch_num_correct = 0
    self.epoch_start_time = None

    # Per-run state. run_data keeps one results row per epoch across ALL runs.
    self.run_params = None
    self.run_count = 0
    self.run_data = []
    self.run_start_time = None

    self.network = None
    self.loader = None
    self.tb = None  # SummaryWriter for the current run

  def begin_run(self, run, network, loader):
    """Start a new run and open a TensorBoard writer for it.

    Args:
      run: the run's hyperparameters. ``end_epoch`` calls ``run._asdict()``,
        so a namedtuple (as produced by RunBuilder) is expected.
      network: the model being trained.
      loader: DataLoader over the training set.
    """
    self.run_start_time = time.time()

    self.run_params = run
    self.run_count += 1

    self.network = network
    self.loader = loader
    self.tb = SummaryWriter(comment=f'-{run}')

    # Pull a single batch only to log a sample image grid and trace the graph.
    images, labels = next(iter(self.loader))
    grid = torchvision.utils.make_grid(images)

    self.tb.add_image('images', grid)
    self.tb.add_graph(self.network, images)

  def end_run(self):
    """Close the run's TensorBoard writer and reset the epoch counter."""
    self.tb.close()
    self.epoch_count = 0

  def begin_epoch(self):
    """Start timing a new epoch and zero its loss / correct-count accumulators."""
    self.epoch_start_time = time.time()

    self.epoch_count += 1
    self.epoch_loss = 0
    self.epoch_num_correct = 0

  def end_epoch(self):
    """Finish the epoch: log to TensorBoard, append a results row, redisplay the table.

    Loss and accuracy are averaged over ``len(self.loader.dataset)``.
    """
    # One timestamp, so the epoch and run durations are measured consistently.
    now = time.time()
    epoch_duration = now - self.epoch_start_time
    run_duration = now - self.run_start_time

    num_samples = len(self.loader.dataset)
    loss = self.epoch_loss / num_samples
    accuracy = self.epoch_num_correct / num_samples

    self.tb.add_scalar('Loss', loss, self.epoch_count)
    self.tb.add_scalar('Accuracy', accuracy, self.epoch_count)

    for name, param in self.network.named_parameters():
      self.tb.add_histogram(name, param, self.epoch_count)
      # A parameter's .grad is None until it receives a gradient (e.g. unused
      # or frozen parameters). add_histogram cannot convert None, so skip it
      # rather than crash.
      if param.grad is not None:
        self.tb.add_histogram(f'{name}.grad', param.grad, self.epoch_count)

    results = OrderedDict()
    results["run"] = self.run_count
    results["epoch"] = self.epoch_count
    results["loss"] = loss
    results["accuracy"] = accuracy
    results["epoch duration"] = epoch_duration
    results["run duration"] = run_duration
    # Hyperparameters become extra columns after the metrics.
    results.update(self.run_params._asdict())
    self.run_data.append(results)

    df = self._results_frame()

    clear_output(wait=True)
    display(df)

  def track_loss(self, loss, batch):
    """Add this batch's total loss to the epoch sum.

    Assumes ``loss`` is the batch mean (the default reduction of
    F.cross_entropy), so it is scaled by the batch size ``batch[0].shape[0]``.
    """
    self.epoch_loss += loss.item() * batch[0].shape[0]

  def track_num_correct(self, preds, labels):
    """Add the number of correct predictions in this batch to the epoch count."""
    self.epoch_num_correct += self._get_num_correct(preds, labels)

  def _get_num_correct(self, preds, labels):
    """Count rows of ``preds`` whose argmax over dim 1 equals ``labels``."""
    return preds.argmax(dim=1).eq(labels).sum().item()

  def _results_frame(self):
    """Build a DataFrame with one row per recorded epoch (shared by display and save)."""
    return pd.DataFrame.from_dict(self.run_data, orient='columns')

  def save(self, fileName):
    """Write every recorded row to ``{fileName}.csv`` and ``{fileName}.json``."""
    self._results_frame().to_csv(f'{fileName}.csv')

    with open(f'{fileName}.json', 'w', encoding='utf-8') as f:
      json.dump(self.run_data, f, ensure_ascii=False, indent=4)

In [16]:
# FashionMNIST training split. It is downloaded into ./data if not already
# present, and each image is converted to a tensor by ToTensor().
train_set = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=26421880.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=29515.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4422102.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=5148.0), HTML(value='')))


Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [32]:
# Hyperparameter grid: every combination becomes one training run.
params = OrderedDict(
    lr = [.01],
    batch_size = [1000],
    shuffle = [True]
)

SEED = 42        # shared by every run, so results differ only by hyperparameters
NUM_EPOCHS = 5   # epochs per run

m = RunManager()

for run in RunBuilder.get_runs(params):
  # Re-seed before building the network and loader so weight initialisation
  # and shuffling are reproducible and identical from run to run.
  torch.manual_seed(SEED)
  network = Network()
  loader = DataLoader(train_set, batch_size=run.batch_size, shuffle=run.shuffle)
  optimizer = optim.Adam(network.parameters(), lr=run.lr)

  m.begin_run(run, network, loader)
  for epoch in range(NUM_EPOCHS):
    m.begin_epoch()
    for batch in loader:
      images, labels = batch
      preds = network(images)
      loss = F.cross_entropy(preds, labels)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      m.track_loss(loss, batch)
      m.track_num_correct(preds, labels)
    m.end_epoch()
  m.end_run()
m.save('results')

Unnamed: 0,run,epoch,loss,accuracy,epoch duration,run duration,lr,batch_size,shuffle
0,1,1,0.956696,0.631467,17.416202,18.400916,0.01,1000,True
1,1,2,0.510977,0.80825,17.440042,35.98759,0.01,1000,True
2,1,3,0.415815,0.84745,17.214252,53.365074,0.01,1000,True
3,1,4,0.373065,0.862117,16.972457,70.494726,0.01,1000,True
4,1,5,0.343703,0.873517,17.317719,87.970785,0.01,1000,True
