In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision as vision
import torchvision.transforms as transforms

torch.set_printoptions(linewidth=120) # Display options for output

import tensorboard

from torch.utils.tensorboard import SummaryWriter
from tensorboard import notebook

from itertools import product

In [2]:
print("    PyTorch Version:", torch.__version__)
print("Torchvision Version:", vision.__version__)
print("Tensorboard Version:", tensorboard.__version__)
print()
print("----------------------------------------")
notebook.list()

    PyTorch Version: 1.10.0
Torchvision Version: 0.11.1
Tensorboard Version: 2.7.0

----------------------------------------
Known TensorBoard instances:
  - port 6006: logdir runs (started 0:58:59 ago; pid 26424)


In [3]:
def get_num_correct(preds, labels):
  return preds.argmax(dim=1).eq(labels).sum().item()

In [4]:
parameters = dict(
        lr = [0.01, 0.001]
        ,batch_size = [10, 100, 1000]
        ,shuffle = [True, False]
)

In [5]:
param_values = [i for i in parameters.values()]
param_values

[[0.01, 0.001], [10, 100, 1000], [True, False]]

In [12]:
for lr, batch_size, shuffle in product(*param_values):
    print(lr, batch_size, shuffle)

0.01 10 True
0.01 10 False
0.01 100 True
0.01 100 False
0.01 1000 True
0.01 1000 False
0.001 10 True
0.001 10 False
0.001 100 True
0.001 100 False
0.001 1000 True
0.001 1000 False


In [8]:
class Network(nn.Module):
  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
    self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

    self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
    self.fc2 = nn.Linear(in_features=120, out_features=60)
    self.out = nn.Linear(in_features=60, out_features=10)

  def forward(self, t):

    t = F.relu(self.conv1(t))
    t = F.max_pool2d(t, kernel_size=2, stride=2)

    t = F.relu(self.conv2(t))
    t = F.max_pool2d(t, kernel_size=2, stride=2) 

    t = t.flatten(start_dim=1)
    t = F.relu(self.fc1(t))

    t = F.relu(self.fc2(t))

    t = self.out(t)

    return t

In [9]:
# Load the training set. 

train_set = vision.datasets.FashionMNIST(
    root='./data'
    ,train=True
    ,download=False # If the data needs to be downloaded, change to True.
    ,transform=transforms.Compose([
                transforms.ToTensor()
    ])
)

In [16]:
run_number = 0
network = Network()
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
images, labels = next(iter(train_loader))
grid = vision.utils.make_grid(images)
optimizer = optim.Adam(network.parameters(), lr=lr)

for lr, batch_size, shuffle in product(*param_values):
    comment = f' batch_size={batch_size} lr={lr} shuffle={shuffle}'

# allows us to identify each run in tensorboard
#comment = f' batch_size={batch_size} lr={lr}' 
    tb = SummaryWriter(comment=comment) 
    tb.add_image('images', grid)
    tb.add_graph(network, images)
    # ---------------------------------------------

    for epoch in range(5):
      total_loss = 0
      total_correct = 0
      for batch in train_loader:
        images, labels = batch # get batch
        preds = network(images) # Pass batch
        loss = F.cross_entropy(preds, labels) # calculate loss
        optimizer.zero_grad() # zero out gradients
        loss.backward() # Calculate new gradients
        optimizer.step() # update weights

        total_loss += loss.item() * batch_size
        total_correct += get_num_correct(preds, labels)

      tb.add_scalar('Loss', total_loss, epoch)
      tb.add_scalar('Number Correct', total_correct, epoch)
      tb.add_scalar('Accuracy', total_correct/len(train_set), epoch)

      for name, weight in network.named_parameters():
            tb.add_histogram(name, weight, epoch)
            tb.add_histogram(f'{name}.grad', weight.grad, epoch)

    run_number += 1
    print("Run:", run_number, "total_correct:", total_correct, "loss:", total_loss)

tb.close()


Run: 1 total_correct: 46906 loss: 347.0539879798889
Run: 2 total_correct: 50060 loss: 271.28078520298004
Run: 3 total_correct: 51188 loss: 2409.0529084205627
Run: 4 total_correct: 52132 loss: 2166.92715883255
Run: 5 total_correct: 52538 loss: 20428.76425385475
Run: 6 total_correct: 53004 loss: 19349.056839942932
Run: 7 total_correct: 53270 loss: 184.72918450832367
Run: 8 total_correct: 53555 loss: 176.84176236391068
Run: 9 total_correct: 53749 loss: 1715.3362542390823
Run: 10 total_correct: 53941 loss: 1649.2481961846352
Run: 11 total_correct: 54143 loss: 15946.110963821411
Run: 12 total_correct: 54282 loss: 15401.7873108387


In [None]:
#