In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import datasets
from torch.utils import data
import matplotlib.pyplot as plt
import itertools
from sklearn.metrics import accuracy_score

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((.1307, ), (.3081, ))])
train_data = datasets.MNIST("~/.pytorch", transform=transform, download=True)
test_data = datasets.MNIST("~/.pytorch", train=False, transform=transform, download=True)

In [4]:
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Conv2d(1, 32, 3),
                                  nn.MaxPool2d(2),
                                  nn.ReLU(),
                                  nn.Conv2d(32, 64, 3),
                                  nn.MaxPool2d(2),
                                  nn.ReLU(),
                                  nn.Flatten())
        
        self.head = nn.Sequential(nn.Linear(1600, 512),
                                  nn.ReLU(),
                                  nn.Dropout(p=0.5),
                                  nn.Linear(512, 10))
    
    def forward(self, x):
        x = self.body(x)
        x = self.head(x)
        return x

In [5]:
@torch.enable_grad()
def update(model, data_loader, loss, opt):
    device = next(model.parameters()).device
    model.train()
    
    errors= []
    for x, y in data_loader:
        x = x.to(device)
        y = y.to(device)
        logits = model(x)
        err = loss(logits, y) 
        errors.append(err)
        
        opt.zero_grad()
        err.backward()
        opt.step()
    
    return errors


@torch.no_grad()
def evaluate(model, data_loader, metric):
    device = next(model.parameters()).device
    model.eval()
    
    accuracies = []
    for x, y in data_loader:
        x = x.to(device)
        y = y.to(device)
        logits = model(x)
        y_hat = torch.argmax(logits.cpu(), 1)
        
        acc = metric(y.cpu(), y_hat) 
        accuracies.append(acc)
    
    return accuracies

In [6]:
class Grid():
    def __init__(self, *params):
        self.params = {}
        for p in params:
            self.params[p] = None

    def add_values(self, key, values):
        self.params[key] = values

    def print_grid(self):
        for key in self.params.keys():
            print(key, end=": ")
            for val in self.params[key]:
                print(val, end=" ")
            print()

    def get_combo(self):
        for item in itertools.product(*self.params.values()):
            t = zip(self.params.keys(), item)
            yield dict(t)

In [7]:
hyper_params = Grid("lr", "nr_epochs", "batch_size")
hyper_params.add_values("lr", [0.001, 0.0001])
hyper_params.add_values("batch_size", [64])
hyper_params.add_values("nr_epochs", [10])

hyper_params.print_grid()
for params in hyper_params.get_combo():
    print(params)

lr: 0.001 0.0001 
nr_epochs: 10 
batch_size: 64 
{'lr': 0.001, 'nr_epochs': 10, 'batch_size': 64}
{'lr': 0.0001, 'nr_epochs': 10, 'batch_size': 64}


In [8]:
path = "best_model.pt"
max_acc = 0

criterion = nn.CrossEntropyLoss()

for params in hyper_params.get_combo():
    batch_size = params["batch_size"]
    lr = params["lr"]
    n_epochs = params["nr_epochs"]
    
    net = ConvNet().to(device)
    opt = torch.optim.Adam(net.parameters(), lr=lr)
    train_loader = data.DataLoader(train_data, shuffle=True, batch_size=batch_size, num_workers=4)
    test_loader = data.DataLoader(test_data, shuffle=False, batch_size=batch_size, num_workers=4)

    print("--"*40)
    for epoch in range(n_epochs):
        errors = update(net, train_loader, criterion, opt)
        acc = evaluate(net, test_loader, accuracy_score)
        avg_acc = sum(acc)/len(acc)
        print(f"Epoch: {epoch+1} \t  Params: lr: {lr}, batch_size: {batch_size} \t Acc: {avg_acc * 100}%")
        
    
        if avg_acc > max_acc:
            max_acc = avg_acc
            torch.save(net.state_dict(), path)

------------------------------------------------------------
Epoch: 1 	  Params: lr: 0.001, batch_size: 64 	 Acc: 98.44745222929936%
Epoch: 2 	  Params: lr: 0.001, batch_size: 64 	 Acc: 98.8953025477707%
Epoch: 3 	  Params: lr: 0.001, batch_size: 64 	 Acc: 99.19386942675159%
Epoch: 4 	  Params: lr: 0.001, batch_size: 64 	 Acc: 99.19386942675159%
Epoch: 5 	  Params: lr: 0.001, batch_size: 64 	 Acc: 99.28343949044586%
Epoch: 6 	  Params: lr: 0.001, batch_size: 64 	 Acc: 99.03463375796179%
Epoch: 7 	  Params: lr: 0.001, batch_size: 64 	 Acc: 99.23367834394905%
Epoch: 8 	  Params: lr: 0.001, batch_size: 64 	 Acc: 99.15406050955414%
Epoch: 9 	  Params: lr: 0.001, batch_size: 64 	 Acc: 99.23367834394905%
Epoch: 10 	  Params: lr: 0.001, batch_size: 64 	 Acc: 99.17396496815286%
------------------------------------------------------------
Epoch: 1 	  Params: lr: 0.0001, batch_size: 64 	 Acc: 97.22332802547771%
Epoch: 2 	  Params: lr: 0.0001, batch_size: 64 	 Acc: 98.19864649681529%
Epoch: 3 	  