In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim
torch.set_printoptions(linewidth=120)
import torch.nn.functional as F
import torchvision
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
from itertools import product

In [2]:
def get_num_correct(preds, labels):
    return preds.argmax(dim=1).eq(labels).sum().item()

In [4]:
#CNN model
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.flatten(x, start_dim = 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.out(x)
        return x

In [5]:
train_set = torchvision.datasets.FashionMNIST(
    root = './data/FashionMNIST',
    train = True,
    download = True,
    transform = transforms.ToTensor())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#Hyperparameters
paramerters = dict(
    lr = [0.01, 0.001],
    batch_size = [32, 64, 128],
    shuffle = [True, False]
)

param_values = [v for v in paramerters.values()]

100%|██████████| 26.4M/26.4M [04:30<00:00, 97.7kB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 69.5kB/s]
100%|██████████| 4.42M/4.42M [00:23<00:00, 192kB/s] 
100%|██████████| 5.15k/5.15k [00:00<00:00, 8.77MB/s]


In [7]:
model_t = CNN()
for run_id, (lr, batch_size, shuffle) in enumerate(product(*param_values)):
    print('Run id:', run_id)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=shuffle)
    model = CNN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    comment = f' batch_size = {batch_size} lr = {lr} shuffle_{shuffle}'
    tb = SummaryWriter(comment=comment)
    images, labels = next(iter(train_loader))
    grid = torchvision.utils.make_grid(images)
    tb.add_image('images', grid)
    tb.add_graph(model_t, images)

    for epoch in range(5):
        total_loss = 0
        total_correct = 0
        for images, labels in train_loader:
            images = images.to(device)
            labels = labels.to(device)
            preds = model(images)

            loss = criterion(preds, labels)
            total_loss += loss.item()
            total_correct += get_num_correct(preds, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        tb.add_scalar('Loss', total_loss, epoch)
        tb.add_scalar('Correct', total_correct, epoch)
        tb.add_scalar('Accuracy', total_correct/len(train_set), epoch)

        print('batch_size', batch_size, 'lr', lr, 'shuffle', shuffle)
        print('epoch', epoch, 'total_correct:', total_correct, 'loss:', total_loss)
    print('-------------------------------------------------------------------')
    tb.add_hparams(
        {'lr': lr, 'batch_size': batch_size, 'shuffle': shuffle},
        {
            'accuracy': total_correct/len(train_set),
            'loss': total_loss,
        },
        )

tb.close()

Run id: 0
batch_size 32 lr 0.01 shuffle True
epoch 0 total_correct: 47471 loss: 1041.5790256336331
batch_size 32 lr 0.01 shuffle True
epoch 1 total_correct: 50565 loss: 804.1287724226713
batch_size 32 lr 0.01 shuffle True
epoch 2 total_correct: 51105 loss: 772.3443571180105
batch_size 32 lr 0.01 shuffle True
epoch 3 total_correct: 51225 loss: 755.1587502472103
batch_size 32 lr 0.01 shuffle True
epoch 4 total_correct: 51346 loss: 740.9222726151347
-------------------------------------------------------------------
Run id: 1
batch_size 32 lr 0.01 shuffle False
epoch 0 total_correct: 48407 loss: 978.7986023500562
batch_size 32 lr 0.01 shuffle False
epoch 1 total_correct: 50949 loss: 772.2412153184414
batch_size 32 lr 0.01 shuffle False
epoch 2 total_correct: 51296 loss: 743.3958904966712
batch_size 32 lr 0.01 shuffle False
epoch 3 total_correct: 51665 loss: 719.8462783880532
batch_size 32 lr 0.01 shuffle False
epoch 4 total_correct: 51624 loss: 723.77579459548
----------------------------

In [9]:
%load_ext tensorboard
%tensorboard --logdir=runs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 25932), started 0:02:47 ago. (Use '!kill 25932' to kill it.)