In [1]:
import torch
import torch.nn as nn
import torch.optim as opt
torch.set_printoptions(linewidth=120)
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [33]:
def get_num_correct(preds, labels):
    """Count how many predictions in a batch agree with the labels.

    Args:
        preds: Tensor of per-class scores; the predicted class for each
            sample is the index of the largest value along dim 1.
        labels: Tensor of target class indices, compared element-wise
            against the predicted classes.

    Returns:
        The number of matching positions as a plain Python number.
    """
    predicted_classes = torch.argmax(preds, dim=1)
    matches = torch.eq(predicted_classes, labels)
    return matches.sum().item()

In [27]:
class CNN(nn.Module):
    """Small convolutional classifier: two conv/pool stages and three linear layers.

    The first linear layer takes 12*4*4 features, which is what two 5x5
    convolutions (no padding), each followed by 2x2 max-pooling, produce
    for a single-channel 28x28 input. The last layer emits 10 raw scores
    per sample; no softmax is applied inside the network.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor.
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 12, kernel_size=5)

        # Fully connected classification head.
        self.fc1 = nn.Linear(12 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 60)
        self.out = nn.Linear(60, 10)

    def forward(self, x):
        """Run a batch through the network and return the raw class scores.

        Args:
            x: Batch of single-channel images (conv1 has in_channels=1).

        Returns:
            Tensor with 10 scores per sample.
        """
        # Stage 1: conv -> ReLU -> 2x2 max-pool.
        x = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=2, stride=2)
        # Stage 2: conv -> ReLU -> 2x2 max-pool.
        x = F.max_pool2d(F.relu(self.conv2(x)), kernel_size=2, stride=2)

        # Collapse everything except the batch dimension for the linear layers.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.out(x)

In [28]:
# Fashion-MNIST training split, stored under ./data (downloaded if missing)
# and converted to tensors by ToTensor().
train_set = torchvision.datasets.FashionMNIST(
    root="./data",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Mini-batches of 100 samples, reshuffled on every pass over the data.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=100,
    shuffle=True,
)

In [29]:
# Log one sample batch and the model graph to TensorBoard.
tb = SummaryWriter()
model = CNN()

# Pull a single batch from the loader.
images, labels = next(iter(train_loader))

# Tile the batch into one image and record it under the "images" tag.
grid = torchvision.utils.make_grid(images)
tb.add_image(tag="images", img_tensor=grid)

# Record the network graph, using the batch as the example input.
tb.add_graph(model=model, input_to_model=images)
tb.close()

In [35]:
# Train a fresh CNN for 10 epochs and log per-epoch metrics to TensorBoard.
device = ("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)
train_loader = torch.utils.data.DataLoader(train_set, batch_size = 100, shuffle = True)
optimizer = opt.Adam(model.parameters(), lr = 0.01)
criterion = torch.nn.CrossEntropyLoss()

tb = SummaryWriter()

for epoch in range(10):
    # Running totals for this epoch only.
    total_loss = 0
    total_correct = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images)

        loss = criterion(preds, labels)
        # total_loss accumulates one loss value per batch.
        total_loss += loss.item()
        total_correct += get_num_correct(preds, labels)

        # Standard update step: clear old gradients, backpropagate, apply.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Per-epoch scalars plus histograms of the convolutional parameters.
    tb.add_scalar("Loss", total_loss, epoch)
    tb.add_scalar("Correct", total_correct, epoch)
    tb.add_scalar("Accuracy", total_correct/len(train_set), epoch)
    tb.add_histogram("conv1.bias", model.conv1.bias, epoch)
    tb.add_histogram("conv1.weight", model.conv1.weight, epoch)
    tb.add_histogram("conv2.bias", model.conv2.bias, epoch)
    tb.add_histogram("conv2.weight", model.conv2.weight, epoch)

    print("epoch:", epoch, "total_correct:", total_correct, "loss:", total_loss)

# Close the writer once, after the last epoch. Closing it inside the epoch
# loop would shut the writer down while it is still needed for later epochs.
tb.close()

epoch: 0 total_correct: 46597 loss: 351.93421091139317
epoch: 1 total_correct: 51676 loss: 227.25896137952805
epoch: 2 total_correct: 52354 loss: 209.65638510882854
epoch: 3 total_correct: 52514 loss: 202.17722918093204
epoch: 4 total_correct: 52875 loss: 191.5535284280777
epoch: 5 total_correct: 53107 loss: 187.56945431232452
epoch: 6 total_correct: 53252 loss: 184.39352037012577
epoch: 7 total_correct: 53178 loss: 184.16533616930246
epoch: 8 total_correct: 53427 loss: 179.81141930818558
epoch: 9 total_correct: 53472 loss: 177.68081154674292


# Hyperparameter tuning

In [38]:
from itertools import product

# Search space: every key is a hyperparameter, every value the list of
# candidates to try for it.
parameters = {
    "lr": [0.01, 0.001],
    "batch_size": [32, 64, 128],
    "shuffle": [True, False],
}

# Candidate lists in the same order as the keys above.
param_values = list(parameters.values())
print(param_values)

# Cartesian product of the candidate lists: one tuple per configuration.
for lr, batch_size, shuffle in product(*param_values):
    print(lr, batch_size, shuffle)

[[0.01, 0.001], [32, 64, 128], [True, False]]
0.01 32 True
0.01 32 False
0.01 64 True
0.01 64 False
0.01 128 True
0.01 128 False
0.001 32 True
0.001 32 False
0.001 64 True
0.001 64 False
0.001 128 True
0.001 128 False


In [39]:
# Training Loop
# One independent training run (fresh model, loader, optimizer and
# TensorBoard writer) per combination of lr / batch_size / shuffle.
for run_id, (lr,batch_size, shuffle) in enumerate(product(*param_values)):
    print("run id:", run_id + 1)
    model = CNN().to(device)
    train_loader = torch.utils.data.DataLoader(train_set,batch_size = batch_size, shuffle = shuffle)
    optimizer = opt.Adam(model.parameters(), lr= lr)
    criterion = torch.nn.CrossEntropyLoss()
    # The comment is passed to SummaryWriter so each run is logged separately.
    comment = f' batch_size = {batch_size} lr = {lr} shuffle = {shuffle}'
    tb = SummaryWriter(comment=comment)
    for epoch in range(5):
        # Running totals for this epoch only.
        total_loss = 0
        total_correct = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images)

            loss = criterion(preds, labels)
            # NOTE(review): total_loss adds one loss value per batch, so it
            # grows with the number of batches. Assuming the criterion's
            # default mean reduction, the "Loss"/"loss" values are therefore
            # not directly comparable between runs with different batch_size.
            total_loss+= loss.item()
            total_correct+= get_num_correct(preds, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        tb.add_scalar("Loss", total_loss, epoch)
        tb.add_scalar("Correct", total_correct, epoch)
        tb.add_scalar("Accuracy", total_correct/ len(train_set), epoch)

        print("batch_size:",batch_size, "lr:",lr,"shuffle:",shuffle)
        print("epoch:", epoch, "total_correct:", total_correct, "loss:",total_loss)
    print("___________________________________________________________________")

    # Summarise the run with the metrics of its final epoch.
    tb.add_hparams(
            {"lr": lr, "bsize": batch_size, "shuffle":shuffle},
            {
                "accuracy": total_correct/ len(train_set),
                "loss": total_loss,
            },
        )

    # Close this run's writer before the next run creates a new one.
    # A single close after the outer loop would only reach the last writer
    # and leave every earlier one open.
    tb.close()


run id: 1
batch_size: 32 lr: 0.01 shuffle: True
epoch: 0 total_correct: 48590 loss: 974.33458911255
batch_size: 32 lr: 0.01 shuffle: True
epoch: 1 total_correct: 51141 loss: 759.3493862450123
batch_size: 32 lr: 0.01 shuffle: True
epoch: 2 total_correct: 51297 loss: 739.9819678105414
batch_size: 32 lr: 0.01 shuffle: True
epoch: 3 total_correct: 51698 loss: 717.3941930010915
batch_size: 32 lr: 0.01 shuffle: True
epoch: 4 total_correct: 51732 loss: 707.9779039155692
___________________________________________________________________
run id: 2
batch_size: 32 lr: 0.01 shuffle: False
epoch: 0 total_correct: 47870 loss: 1025.0641575157642
batch_size: 32 lr: 0.01 shuffle: False
epoch: 1 total_correct: 50789 loss: 789.9204152449965
batch_size: 32 lr: 0.01 shuffle: False
epoch: 2 total_correct: 50956 loss: 776.318956349045
batch_size: 32 lr: 0.01 shuffle: False
epoch: 3 total_correct: 51442 loss: 738.742635935545
batch_size: 32 lr: 0.01 shuffle: False
epoch: 4 total_correct: 51586 loss: 725.7080