## import modules

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import optuna

# Seed PyTorch's global random number generator with a fixed value.
torch.manual_seed(0)
# Device that the training and evaluation routines move each batch to.
device = torch.device("cpu")

## define model architecture

In [2]:
class ConvNet(nn.Module):
    """Convolutional classifier whose architecture is sampled from a trial.

    The network is a stack of 3x3 convolutions (each followed by ReLU), one
    2x2 max-pool, dropout, a flatten, a stack of fully connected layers (each
    followed by ReLU and dropout) and a final 10-way linear layer with
    log-softmax. It expects single-channel 28x28 inputs, which is the shape
    used to size the first fully connected layer.

    Args:
        trial: object exposing ``suggest_int(name, low, high)`` and
            ``suggest_float(name, low, high)``; it decides the number of
            layers, their widths and the dropout probabilities.
    """

    def __init__(self, trial):
        super().__init__()
        num_conv_layers = trial.suggest_int("num_conv_layers", 1, 4)
        num_fc_layers = trial.suggest_int("num_fc_layers", 1, 2)

        # Plain Python list used while assembling the network; the modules
        # are registered with this nn.Module through ``self.model`` below.
        self.layers = []
        input_depth = 1  # grayscale image
        for i in range(num_conv_layers):
            output_depth = trial.suggest_int(f"conv_depth_{i}", 16, 64)
            self.layers.append(nn.Conv2d(input_depth, output_depth, 3, 1))
            self.layers.append(nn.ReLU())
            input_depth = output_depth
        self.layers.append(nn.MaxPool2d(2))
        # A single dropout follows the whole conv stack. Its parameter name
        # carries the index of the last conv layer; this is spelled out
        # explicitly instead of relying on the loop variable leaking out of
        # the ``for`` above, and the resulting name is unchanged.
        last_conv_idx = num_conv_layers - 1
        p = trial.suggest_float(f"conv_dropout_{last_conv_idx}", 0.1, 0.4)
        self.layers.append(nn.Dropout(p))
        self.layers.append(nn.Flatten())

        # Size of the flattened conv output determines the first Linear layer.
        input_feat = self._get_flatten_shape()
        for i in range(num_fc_layers):
            output_feat = trial.suggest_int(f"fc_output_feat_{i}", 16, 64)
            self.layers.append(nn.Linear(input_feat, output_feat))
            self.layers.append(nn.ReLU())
            p = trial.suggest_float(f"fc_dropout_{i}", 0.1, 0.4)
            self.layers.append(nn.Dropout(p))
            input_feat = output_feat
        self.layers.append(nn.Linear(input_feat, 10))
        self.layers.append(nn.LogSoftmax(dim=1))

        self.model = nn.Sequential(*self.layers)

    def _get_flatten_shape(self):
        """Return the per-sample feature count produced by the current layers.

        Pushes one random 1x1x28x28 sample through the layers collected so
        far in ``self.layers`` and returns the size of the flattened result.
        """
        conv_model = nn.Sequential(*self.layers)
        # Only the output shape is needed, so skip autograd bookkeeping.
        with torch.no_grad():
            op_feat = conv_model(torch.rand(1, 1, 28, 28))
        return op_feat.view(1, -1).size(1)

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``."""
        return self.model(x)

## create data loaders

In [3]:
# Normalization constants: 0.1302 is the mean and 0.3069 the standard deviation
# of the pixel values, computed over all images in the training dataset.
# The same preprocessing is applied to both splits, so it is built once.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1302,), (0.3069,)),
])

train_ds = datasets.MNIST('../data', train=True, download=True,
                          transform=mnist_transform)
test_ds = datasets.MNIST('../data', train=False,
                         transform=mnist_transform)

train_dataloader = DataLoader(train_ds, batch_size=32, shuffle=True)
# Evaluation only sums the loss and the number of correct predictions over the
# whole split, so the batch order is irrelevant and shuffling is not needed.
test_dataloader = DataLoader(test_ds, batch_size=500, shuffle=False)

## define training and inference routines

In [4]:
def train(model, device, train_dataloader, optim, epoch):
    """Run one pass over ``train_dataloader``, updating ``model`` in place.

    Args:
        model: network to train; switched to training mode.
        device: device each batch is moved to before the forward pass.
        train_dataloader: iterable of ``(inputs, targets)`` batches that also
            exposes ``len()`` and a ``.dataset`` attribute (used for logging).
        optim: optimizer whose ``zero_grad()``/``step()`` are called per batch.
        epoch: epoch number, used only in the progress message.
    """
    model.train()
    for batch_idx, batch in enumerate(train_dataloader):
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)

        optim.zero_grad()
        log_probs = model(inputs)
        # nll_loss is the negative log-likelihood loss
        batch_loss = F.nll_loss(log_probs, targets)
        batch_loss.backward()
        optim.step()

        # Progress is reported on every 500th batch only.
        if batch_idx % 500 != 0:
            continue
        samples_seen = batch_idx * len(inputs)
        samples_total = len(train_dataloader.dataset)
        percent_done = 100. * batch_idx / len(train_dataloader)
        print('epoch: {} [{}/{} ({:.0f}%)]\t training loss: {:.6f}'.format(
            epoch, samples_seen, samples_total, percent_done, batch_loss.item()))

In [5]:
def test(model, device, test_dataloader):
    model.eval()
    loss = 0
    success = 0
    with torch.no_grad():
        for X, y in test_dataloader:
            X, y = X.to(device), y.to(device)
            pred_prob = model(X)
            loss += F.nll_loss(pred_prob, y, reduction='sum').item()  # loss summed across the batch
            pred = pred_prob.argmax(dim=1, keepdim=True)  # use argmax to get the most likely prediction
            success += pred.eq(y.view_as(pred)).sum().item()

    loss /= len(test_dataloader.dataset)
    
    accuracy = 100. * success / len(test_dataloader.dataset)

    print('\nTest dataset: Overall Loss: {:.4f}, Overall Accuracy: {}/{} ({:.0f}%)\n'.format(
        loss, success, len(test_dataloader.dataset), accuracy))
    
    return accuracy

## define optimizer and model training routine

In [6]:
def objective(trial):
    """Optuna objective: build, train and evaluate one sampled configuration.

    Samples a ``ConvNet`` architecture, an optimizer class and a learning rate
    from ``trial``, trains for two epochs, and reports the test accuracy to
    the trial after each epoch so that the trial can be pruned.

    Args:
        trial: the Optuna trial supplying hyperparameter suggestions.

    Returns:
        Test accuracy (in percent) after the final epoch.

    Raises:
        optuna.exceptions.TrialPruned: if ``trial.should_prune()`` returns
            true after an epoch.
    """
    # Batches are moved to `device` inside train()/test(), so the model must
    # live on the same device; without this the run fails for any non-CPU device.
    model = ConvNet(trial).to(device)
    opt_name = trial.suggest_categorical("optimizer", ["Adam", "Adadelta", "RMSprop", "SGD"])
    # NOTE(review): the recorded run shows an Adam trial at lr ~0.46 finishing
    # at ~10% accuracy; this range may be too high for some of the optimizers
    # above — confirm before relying on the search results.
    lr = trial.suggest_float("lr", 1e-1, 5e-1, log=True)
    # Look the optimizer class up by name in torch.optim.
    optimizer = getattr(optim, opt_name)(model.parameters(), lr=lr)

    for epoch in range(1, 3):
        train(model, device, train_dataloader, optimizer, epoch)
        accuracy = test(model, device, test_dataloader)
        # Intermediate value used by the study's pruner.
        trial.report(accuracy, epoch)

        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return accuracy

## run the hyperparameter search

In [None]:
# Create the study; larger objective values (test accuracy) are better.
study = optuna.create_study(study_name="mastering_pytorch", direction="maximize")
# Run the search, bounded both by a trial count and by a time budget.
study.optimize(objective, n_trials=100, timeout=2000)

# Split the finished trials by their final state.
TrialState = optuna.trial.TrialState
all_trials = study.trials
pruned_trials = list(filter(lambda t: t.state == TrialState.PRUNED, all_trials))
complete_trials = list(filter(lambda t: t.state == TrialState.COMPLETE, all_trials))

print("results: ")
for label, group in (
    ("num_trials_conducted: ", all_trials),
    ("num_trials_pruned: ", pruned_trials),
    ("num_trials_completed: ", complete_trials),
):
    print(label, len(group))

print("results from best trial:")
trial = study.best_trial

print("accuracy: ", trial.value)
print("hyperparameters: ")
for param_name, param_value in trial.params.items():
    print(f"{param_name}: {param_value}")

[I 2020-10-24 10:57:24,406] A new study created in memory with name: mastering_pytorch



Test dataset: Overall Loss: 2.3394, Overall Accuracy: 974/10000 (10%)



[I 2020-10-24 11:01:01,311] Trial 0 finished with value: 10.32 and parameters: {'num_conv_layers': 1, 'num_fc_layers': 1, 'conv_depth_0': 29, 'conv_dropout_0': 0.33975460255509893, 'fc_output_feat_0': 63, 'fc_dropout_0': 0.28457195650592076, 'optimizer': 'Adam', 'lr': 0.4635077277229724}. Best is trial 0 with value: 10.32.



Test dataset: Overall Loss: 2.3260, Overall Accuracy: 1032/10000 (10%)

