In [11]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import pytorch_lightning as pl
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from PIL import Image

In [12]:
class LeNet5(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.conv = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
            nn.Conv2d(6, 16, kernel_size=5, stride=1),
            nn.Tanh(),
            nn.AvgPool2d(2, 2),
            nn.Conv2d(16, 120, kernel_size=5, stride=1),
            nn.Tanh()
        )
        self.linear = nn.Sequential(
            nn.Linear(120, 84),
            nn.Tanh(),
            nn.Linear(84, 10)
        )
        
    def forward(self, x):
        out = self.conv(x)
        out = torch.flatten(out, 1)
        out = self.linear(out)
        
        return out        
        
    def general_step(self, batch, batch_idx):
        image, labels = batch
        out = self.forward(image)
        
        loss = F.cross_entropy(out, labels)
        pred = out.argmax(axis=1)
        acc = (pred == labels).float().mean()
        
        return loss, acc
    
    def training_step(self, batch, batch_idx):
        loss, acc = self.general_step(batch, batch_idx)
        logs = {'train_loss': loss, 'training_acc': acc}
        return {'loss': loss, 'training_acc': acc, 'log': logs}
    
    def validation_step(self, batch, batch_idx):
        loss, acc = self.general_step(batch, batch_idx)
        logs = {'val_loss': loss, 'val_acc': acc}
        return {'val_loss': loss, 'val_acc': acc, 'log': logs}
    
    def test_step(self, batch, batch_idx):
        loss, acc = self.general_step(batch, batch_idx)
        
        logs = {'test_loss': loss, 'test_acc': acc}
        return {'test_loss': loss, 'test_acc': acc, 'log': logs}
    
    def validation_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        logs = {'val_loss': avg_loss, 'val_acc': avg_acc}
        return {'val_loss': avg_loss, 'val_acc': avg_acc, 'log': logs}
    
    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean().item()
        avg_acc = torch.stack([x['test_acc'] for x in outputs]).mean().item() * 100
        logs = {'test_loss': avg_loss, 'test_acc': avg_acc}
        test_acc = "{:.2f}".format(avg_acc) + '%'
        print("Test Accuracy:", test_acc)
        return {'test_loss': avg_loss, 'test_acc': avg_acc, 'log': logs}
    
    def prepare_data(self):
        transform = transforms.Compose([transforms.Resize((32, 32)),
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.5), (0.5))])
        
        mnist_train = torchvision.datasets.MNIST(root='./data/', train=True, transform=transform, download=True)
        mnist_test = torchvision.datasets.MNIST(root='./data/', train=False, transform=transform, download=True)
        
        self.train_data, self.val_data = random_split(mnist_train, [50000, 10000])
        self.test_data = mnist_test
        
    def train_dataloader(self):
        return DataLoader(self.train_data, batch_size=self.hparams['batch_size'], shuffle=True)
    
    def val_dataloader(self):
        return DataLoader(self.val_data, batch_size=self.hparams['batch_size'])
    
    def test_dataloader(self):
        return DataLoader(self.test_data)
    
    def configure_optimizers(self):
        opt = torch.optim.SGD(self.parameters(), lr=self.hparams['learning_rate'], momentum=0.9)
        return opt
        
    def predict(self, img, true_label=None):
        self.eval()
        self.cuda()
        img = img.resize((28, 28))
        img = transforms.ToTensor()(img)
        img_tensor = img.view((1,) + tuple(img.size()))
        pred = self.forward(img_tensor).argmax(axis=1)
        
        if true_label:
            print("The number is:", label)
        print("Recognized:", pred)
        
        img = img / 2 + 0.5
        npimg = img.numpy().reshape((32, 32))
        plt.imshow(npimg, cmap='gray')
        plt.show()

In [13]:
import optuna
from optuna.integration import PyTorchLightningPruningCallback

In [14]:
from pytorch_lightning import Callback

class MetricsCallback(Callback):
    """PyTorch Lightning metric callback."""
    def __init__(self):
        super().__init__()
        self.metrics = []

    def on_validation_end(self, trainer, pl_module):
        self.metrics.append(trainer.callback_metrics)

In [15]:
def objective(trial):
    # as explained above, we'll use this callback to collect the validation accuracies
    metrics_callback = MetricsCallback()  
    
    # here we sample the hyper params, similar as in our old random search
    trial_hparams = { "num_epochs": trial.suggest_int("num_epochs", 10, 65),
                     "batch_size": trial.suggest_int("batch_size", 8, 512), 
                     "learning_rate": trial.suggest_loguniform("learning_rate", 1e-8, 1e-1)}
    
    # create a trainer
    trainer = pl.Trainer(
        logger=False,                                                                  # deactivate PL logging
        max_epochs=trial_hparams['num_epochs'],                                                                  # epochs
        gpus=0 if torch.cuda.is_available() else None,                                 # #gpus
        callbacks=[metrics_callback],                                                  # save latest accuracy
        early_stop_callback=PyTorchLightningPruningCallback(trial, monitor="val_acc"), # early stopping
    )
    
    # create model from these hyper params and train it
    model = LeNet5(trial_hparams)
    model.prepare_data()
    trainer.fit(model)

    # return validation accuracy from latest model, as that's what we want to minimize by our hyper param search
    return metrics_callback.metrics[-1]["val_acc"]

In [16]:
pruner = optuna.pruners.NopPruner()
study = optuna.create_study(direction="maximize", pruner=pruner)
study.optimize(objective, n_trials=10, timeout=2400)

GPU available: True, used: False
No environment variable for node rank defined. Set as 0.

   | Name     | Type       | Params
------------------------------------
0  | conv     | Sequential | 50 K  
1  | conv.0   | Conv2d     | 156   
2  | conv.1   | Tanh       | 0     
3  | conv.2   | AvgPool2d  | 0     
4  | conv.3   | Conv2d     | 2 K   
5  | conv.4   | Tanh       | 0     
6  | conv.5   | AvgPool2d  | 0     
7  | conv.6   | Conv2d     | 48 K  
8  | conv.7   | Tanh       | 0     
9  | linear   | Sequential | 11 K  
10 | linear.0 | Linear     | 10 K  
11 | linear.1 | Tanh       | 0     
12 | linear.2 | Linear     | 850   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




[I 2020-08-08 21:28:15,572] Finished trial#0 with value: 0.46809110045433044 with parameters: {'num_epochs': 56, 'batch_size': 297, 'learning_rate': 2.479486873162931e-05}. Best is trial#0 with value: 0.46809110045433044.
GPU available: True, used: False
No environment variable for node rank defined. Set as 0.

   | Name     | Type       | Params
------------------------------------
0  | conv     | Sequential | 50 K  
1  | conv.0   | Conv2d     | 156   
2  | conv.1   | Tanh       | 0     
3  | conv.2   | AvgPool2d  | 0     
4  | conv.3   | Conv2d     | 2 K   
5  | conv.4   | Tanh       | 0     
6  | conv.5   | AvgPool2d  | 0     
7  | conv.6   | Conv2d     | 48 K  
8  | conv.7   | Tanh       | 0     
9  | linear   | Sequential | 11 K  
10 | linear.0 | Linear     | 10 K  
11 | linear.1 | Tanh       | 0     
12 | linear.2 | Linear     | 850   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




[I 2020-08-08 21:35:25,293] Finished trial#1 with value: 0.14063578844070435 with parameters: {'num_epochs': 20, 'batch_size': 448, 'learning_rate': 5.565762192463009e-06}. Best is trial#0 with value: 0.46809110045433044.
GPU available: True, used: False
No environment variable for node rank defined. Set as 0.

   | Name     | Type       | Params
------------------------------------
0  | conv     | Sequential | 50 K  
1  | conv.0   | Conv2d     | 156   
2  | conv.1   | Tanh       | 0     
3  | conv.2   | AvgPool2d  | 0     
4  | conv.3   | Conv2d     | 2 K   
5  | conv.4   | Tanh       | 0     
6  | conv.5   | AvgPool2d  | 0     
7  | conv.6   | Conv2d     | 48 K  
8  | conv.7   | Tanh       | 0     
9  | linear   | Sequential | 11 K  
10 | linear.0 | Linear     | 10 K  
11 | linear.1 | Tanh       | 0     
12 | linear.2 | Linear     | 850   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




[I 2020-08-08 21:48:30,953] Finished trial#2 with value: 0.9900318384170532 with parameters: {'num_epochs': 34, 'batch_size': 152, 'learning_rate': 0.07249465638400594}. Best is trial#2 with value: 0.9900318384170532.


In [17]:
print("Number of finished trials: {}".format(len(study.trials)))

print("Best trial:")
best_trial = study.best_trial

print("  Value: {}".format(best_trial.value))

print("  Params: ")
for key, value in best_trial.params.items():
    print("    {}: {}".format(key, value))

Number of finished trials: 3
Best trial:
  Value: 0.9900318384170532
  Params: 
    num_epochs: 34
    batch_size: 152
    learning_rate: 0.07249465638400594
