In [1]:
import os
import shutil

import pytorch_lightning as pl
from pytorch_lightning import Callback
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
import torch.utils.data
from torchvision import datasets
from torchvision import transforms

import optuna
from optuna.integration import PyTorchLightningPruningCallback

In [2]:
pl.__version__

'0.8.0'

In [3]:
PERCENT_VALID_EXAMPLES = 0.1
BATCH_SIZE = 128
CLASSES = 10
EPOCHS = 10
DIR = os.getcwd()
MODEL_DIR = os.path.join(DIR, 'result')

In [9]:
class MetricsCallback(Callback):
    '''PyTorch Lightning metric callback'''
    
    def __init__(self):
        self.metrics = []
        
    def on_validation_end(self, trainer, pl_module):
        self.metrics.append(trainer.callback_metrics)

In [10]:
class Net(nn.Module):
    def __init__(self, trial):
        super().__init__()
        
        self.layers = []
        self.dropouts = []
        
        # We optimize the number of layers, hidden units in each layer and dropouts.
        n_layers = trial.suggest_int('n_layers', 1, 3)
        dropout = trial.suggest_uniform('dropout', 0.2, 0.5)
        
        input_dim = 28*28
        
        for i in range(n_layers):
            output_dim = int(trial.suggest_loguniform('n_units_l{}'.format(i), 4, 128))
            self.layers.append(nn.Linear(input_dim, output_dim))
            self.dropouts.append(nn.Dropout(dropout))
            input_dim = output_dim
            
        self.layers.append(nn.Linear(input_dim, CLASSES))
        
        # Assigning the layers as class variables
        for idx, layer in enumerate(self.layers):
            setattr(self, "fc{}".format(idx), layer)
            
        # Assigning the dropouts as class variables.
        for idx, dropout in enumerate(self.dropouts):
            setattr(self, 'drop{}'.format(idx), dropout)
            
    def forward(self, data):
        data = data.view(-1, 28*28)
        
        for layer, dropout in zip(self.layers, self.dropouts):
            # This loop only goes till 2nd last layer, since dropouts is smaller than layers list
            data = F.relu(layer(data))
            data = dropout(data)
            
        return F.log_softmax(self.layers[-1](data), dim=1)

In [11]:
class LightningNet(pl.LightningModule):
    def __init__(self, trial):
        super().__init__()
        self.model = Net(trial)
        
    def forward(self, data):
        return self.model(data)
        
    def training_step(self, batch, batch_nb):
        data, target = batch
        output = self.forward(data)
        return {'loss': F.nll_loss(output, target)}
    
    def validation_step(self, batch, batch_nb):
        data, target = batch
        output = self.forward(data)
        pred = output.argmax(dim=1, keepdim=True)
        accuracy = pred.eq(target.view_as(pred)).float().mean()
        return {"batch_val_acc": accuracy}
    
    def validation_epoch_end(self, outputs):
        accuracy = sum(x['batch_val_acc'] for x in outputs) / len(outputs)
        return {'log': {'val_acc': accuracy}}
    
    def configure_optimizers(self):
        return Adam(self.model.parameters())
    
    def train_dataloader(self):
        return torch.utils.data.DataLoader(
            datasets.MNIST(DIR, train=True, download=True, transform=transforms.ToTensor()),
            batch_size=BATCH_SIZE, shuffle=True)
    
    def val_dataloader(self):
        return torch.utils.data.DataLoader(
            datasets.MNIST(DIR, train=False, download=True, transform=transforms.ToTensor()),
            batch_size=BATCH_SIZE, shuffle=False)

In [12]:
def objective(trial):
    # Filenames for each trial must be unique in order to access each checkpoint
    
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        os.path.join(MODEL_DIR, 'trial_{}'.format(trial.number), '{epoch}'), monitor='val_acc')
    
    # The default logger in pytorch lightning writes events to files to be consumed by 
    # tensorboard. We dont use any logger here as it requires us to implement several 
    # abstract methods. Instead, we setup a simple callback, that saves metrics for each
    # validation step.
    
    metrics_callback = MetricsCallback()
    
    trainer = pl.Trainer(
        logger=False,
        val_percent_check=PERCENT_VALID_EXAMPLES,
        checkpoint_callback=checkpoint_callback,
        max_epochs=EPOCHS,
        gpus=None,
        callbacks=[metrics_callback],
        early_stop_callback=PyTorchLightningPruningCallback(trial, monitor='val_acc')
    )
    
    model = LightningNet(trial)
    trainer.fit(model)
    
    return metrics_callback.metrics[-1]["val_acc"].item()

In [13]:
pruner = optuna.pruners.MedianPruner()
study = optuna.create_study(direction='maximize', pruner=pruner)
study.optimize(objective, n_trials=100, timeout=600)

print('No. of finished trials: {}'.format(len(study.trials)))
print('Best trial: ')
trial = study.best_trial
print('  Value: {}'.format(trial.value))

print('  Params:  ')
for key, value in trial.params.items():
    print('    {}: {}'.format(key, value))
shutil.rmtree(MODEL_DIR)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type | Params
-------------------------------
0 | model | Net  | 69 K  


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




[I 2020-06-26 14:46:07,417] Finished trial#0 with value: 0.9732142686843872 with parameters: {'n_layers': 2, 'dropout': 0.23973037250104887, 'n_units_l0': 82.37932730709898, 'n_units_l1': 59.093300921429176}. Best is trial#0 with value: 0.9732142686843872.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type | Params
-------------------------------
0 | model | Net  | 10 K  


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




[I 2020-06-26 14:48:23,810] Finished trial#1 with value: 0.8069196343421936 with parameters: {'n_layers': 3, 'dropout': 0.33659987829915794, 'n_units_l0': 7.483780258524184, 'n_units_l1': 31.25145750431548, 'n_units_l2': 122.69980799695885}. Best is trial#0 with value: 0.9732142686843872.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type | Params
-------------------------------
0 | model | Net  | 7 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




[I 2020-06-26 14:49:19,439] Finished trial#2 with value: 0.6964285969734192 with parameters: {'n_layers': 3, 'dropout': 0.4257747055608976, 'n_units_l0': 9.551274271107642, 'n_units_l1': 31.0610220217365, 'n_units_l2': 14.860018226891714}. Best is trial#0 with value: 0.9732142686843872.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type | Params
-------------------------------
0 | model | Net  | 61 K  


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




[I 2020-06-26 14:50:18,311] Finished trial#3 with value: 0.9084821343421936 with parameters: {'n_layers': 2, 'dropout': 0.35960264121540797, 'n_units_l0': 78.97461047112313, 'n_units_l1': 7.906216081067158}. Best is trial#0 with value: 0.9732142686843872.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type | Params
-------------------------------
0 | model | Net  | 21 K  


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




[I 2020-06-26 14:51:14,283] Finished trial#4 with value: 0.9129464030265808 with parameters: {'n_layers': 1, 'dropout': 0.49743592922065843, 'n_units_l0': 27.031501873050015}. Best is trial#0 with value: 0.9732142686843872.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type | Params
-------------------------------
0 | model | Net  | 9 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




[I 2020-06-26 14:52:55,141] Finished trial#5 with value: 0.7600446343421936 with parameters: {'n_layers': 3, 'dropout': 0.40684861579427284, 'n_units_l0': 9.994460875218575, 'n_units_l1': 14.565160468289418, 'n_units_l2': 110.60597321791592}. Best is trial#0 with value: 0.9732142686843872.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type | Params
-------------------------------
0 | model | Net  | 9 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




[I 2020-06-26 14:53:44,959] Finished trial#6 with value: 0.8861607313156128 with parameters: {'n_layers': 1, 'dropout': 0.4446681877679336, 'n_units_l0': 12.60581517973499}. Best is trial#0 with value: 0.9732142686843872.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type | Params
-------------------------------
0 | model | Net  | 9 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




[I 2020-06-26 14:54:39,288] Finished trial#7 with value: 0.875 with parameters: {'n_layers': 3, 'dropout': 0.23105506542225138, 'n_units_l0': 8.01805492679793, 'n_units_l1': 120.2799786612332, 'n_units_l2': 13.111385306016189}. Best is trial#0 with value: 0.9732142686843872.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type | Params
-------------------------------
0 | model | Net  | 3 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




[I 2020-06-26 14:55:47,136] Finished trial#8 with value: 0.53125 with parameters: {'n_layers': 2, 'dropout': 0.460941523278537, 'n_units_l0': 4.410444341667207, 'n_units_l1': 6.176153150474766}. Best is trial#0 with value: 0.9732142686843872.


No. of finished trials: 9
Best trial: 
  Value: 0.9732142686843872
  Params:  
    n_layers: 2
    dropout: 0.23973037250104887
    n_units_l0: 82.37932730709898
    n_units_l1: 59.093300921429176
