https://towardsdatascience.com/a-complete-guide-to-using-tensorboard-with-pytorch-53cb2301e8c3

In [24]:
import torch
import torch.nn as nn
import torch.optim as opt
torch.set_printoptions(linewidth=120)
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

In [None]:
torch.

# Tensorboard SummaryWriter

In [6]:
from torch.utils.tensorboard import SummaryWriter

In [7]:
def get_num_correct(preds, labels):
    return preds.argmax(dim=1).eq(labels).sum().item()
# sum of predics equal to labels when preds == argmax of class probabilities.

## Simple CNN

In [12]:
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
        
        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, kernel_size=2, stride=2)
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.out(x)
        
        return x

In [9]:
train_set = torchvision.datasets.FashionMNIST(root="./data",
train = True,
 download=True,
transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(train_set,batch_size = 100, shuffle = True)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz


0it [00:00, ?it/s]

Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


0it [00:00, ?it/s]

Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


0it [00:00, ?it/s]

Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


0it [00:00, ?it/s]

Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!


  return torch.from_numpy(parsed).view(length, num_rows, num_cols)


In [13]:
tb = SummaryWriter() 

SummaryWriter(log_dir) creates log_dir to store log files.

In [13]:
# default `log_dir` is "runs"
model = CNN()
images, labels = next(iter(train_loader))
grid = torchvision.utils.make_grid(images)
tb.add_image("images", grid)
tb.add_graph(model, images)
tb.close()

# Training Loop 

In [15]:
device = ("cuda" if torch.cuda.is_available() else cpu)
model = CNN().to(device)
train_loader = torch.utils.data.DataLoader(train_set,batch_size = 100, shuffle = True)
optimizer = opt.Adam(model.parameters(), lr= 0.01)
criterion = torch.nn.CrossEntropyLoss()

tb = SummaryWriter()

for epoch in range(10):

    total_loss = 0
    total_correct = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images)

        loss = criterion(preds, labels)
        total_loss+= loss.item()
        total_correct+= get_num_correct(preds, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    tb.add_scalar("Loss", total_loss, epoch)
    tb.add_scalar("Correct", total_correct, epoch)
    tb.add_scalar("Accuracy", total_correct/ len(train_set), epoch)

    tb.add_histogram("conv1.bias", model.conv1.bias, epoch)
    tb.add_histogram("conv1.weight", model.conv1.weight, epoch)
    tb.add_histogram("conv2.bias", model.conv2.bias, epoch)
    tb.add_histogram("conv2.weight", model.conv2.weight, epoch)

    print("epoch:", epoch, "total_correct:", total_correct, "loss:",total_loss)

tb.close()

epoch: 0 total_correct: 48173 loss: 318.929568991065
epoch: 1 total_correct: 51843 loss: 220.7587616443634
epoch: 2 total_correct: 52521 loss: 202.61319528520107
epoch: 3 total_correct: 52843 loss: 194.4499491006136
epoch: 4 total_correct: 52984 loss: 190.95944210886955
epoch: 5 total_correct: 53149 loss: 185.70717848092318
epoch: 6 total_correct: 53244 loss: 182.6273631080985
epoch: 7 total_correct: 53524 loss: 177.43566562235355
epoch: 8 total_correct: 53423 loss: 178.32259202748537
epoch: 9 total_correct: 53680 loss: 173.0257710069418


# Monitoring Hyperparameter tuning

In [17]:
from itertools import product
parameters = dict(lr = [0.01, 0.001],
                  batch_szie=[32,64,128],
                  shuffle = [True, False])
param_values = [v for v in parameters.values()]

In [27]:
for run_id, (lr, batch_size, shuffle) in enumerate(product(*param_values)):
    print("run id", run_id+1)
    model = CNN().to(device)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=shuffle)
    optimizer = opt.Adam(model.parameters(), lr=lr)
    criterion=torch.nn.CrossEntropyLoss()
    comment = f'batch_size = {batch_size} lr = {lr} shuffle = {shuffle}'
    tb = SummaryWriter(comment=comment)
    for epoch in range(5):
        total_loss = 0
        total_correct = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images)
            
            loss = criterion(preds, labels)
            total_loss += loss.item()
            total_correct += get_num_correct(preds, labels)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        tb.add_scalar("Loss", total_loss, epoch)
        tb.add_scalar("Correct", total_correct, epoch)
        tb.add_scalar("Accuracy", total_correct / len(train_set), epoch)
        
        print("batch_size:",batch_size, "lr:",lr,"shuffle:",shuffle)
        print("epoch:", epoch, "total_correct:", total_correct, "loss:",total_loss)
    print("__________________________________________________________")

    tb.add_hparams(
            {"lr": lr, "bsize": batch_size, "shuffle":shuffle},
            {
                "accuracy": total_correct/ len(train_set),
                "loss": total_loss,
            },
        )

tb.close()

run id 1
batch_size: 32 lr: 0.01 shuffle: True
epoch: 0 total_correct: 47802 loss: 1023.2121961787343
batch_size: 32 lr: 0.01 shuffle: True
epoch: 1 total_correct: 50385 loss: 810.8808999881148
batch_size: 32 lr: 0.01 shuffle: True
epoch: 2 total_correct: 50992 loss: 777.6378016620874
batch_size: 32 lr: 0.01 shuffle: True
epoch: 3 total_correct: 51422 loss: 750.2654168866575
batch_size: 32 lr: 0.01 shuffle: True
epoch: 4 total_correct: 51558 loss: 728.7547300029546
__________________________________________________________
run id 2
batch_size: 32 lr: 0.01 shuffle: False
epoch: 0 total_correct: 46353 loss: 1128.91398049891
batch_size: 32 lr: 0.01 shuffle: False
epoch: 1 total_correct: 49656 loss: 891.4816038459539
batch_size: 32 lr: 0.01 shuffle: False
epoch: 2 total_correct: 50110 loss: 858.0037266984582
batch_size: 32 lr: 0.01 shuffle: False
epoch: 3 total_correct: 50267 loss: 844.4450650215149
batch_size: 32 lr: 0.01 shuffle: False
epoch: 4 total_correct: 50277 loss: 844.193120464682