In [1]:
import warnings
# Silence every Python warning for the rest of the kernel session.
# NOTE(review): this also hides deprecation warnings raised by the libraries used below.
warnings.filterwarnings('ignore')

import shutil
import pytorch_lightning as pl
import torch
import yaml
from audio_classification.tools import *
import optuna
from optuna.integration import PyTorchLightningPruningCallback

In [2]:
import yaml
# Load the experiment configuration into the notebook-global ``cfg``.
# ``yaml.safe_load`` is used instead of a bare ``yaml.load(stream)``: calling
# ``yaml.load`` without an explicit ``Loader`` is deprecated in PyYAML 5.x and
# raises ``TypeError`` in PyYAML >= 6. The safe loader only builds plain
# scalars, lists and dicts.
# NOTE(review): the absolute NFS path ties this notebook to one machine;
# consider making it configurable.
with open("/nfs/students/winter-term-2020/project-1/project-1/audio_classification/configs/m11_bmw.yaml", "r") as config_file:
    cfg = yaml.safe_load(config_file)

In [3]:
from pytorch_lightning import Callback

class MetricsCallback(Callback):
    """PyTorch Lightning callback that keeps a history of validation metrics.

    After every validation run a snapshot of ``trainer.callback_metrics`` is
    appended to ``self.metrics``, so ``self.metrics[-1]`` holds the metrics of
    the most recent validation run.
    """
    def __init__(self):
        super().__init__()
        # One entry per validation run, oldest first.
        self.metrics = []

    def on_validation_end(self, trainer, pl_module):
        """Record the metrics currently exposed by ``trainer``.

        Args:
            trainer: the running trainer; only ``callback_metrics`` is read.
            pl_module: the module being trained (unused).
        """
        # Store a shallow copy rather than the mapping itself: the trainer may
        # keep updating the same mapping in place, which would make every
        # history entry alias the latest values.
        self.metrics.append(dict(trainer.callback_metrics))

In [4]:
def objective(trial):
    """Optuna objective: train one model with trial-suggested hyperparameters.

    Samples ``batch_size``, ``learning_rate`` and ``weight_decay`` from
    ``trial``, builds the data loaders and model from the notebook-global
    ``cfg``, trains for at most 10 epochs, saves the model as
    ``<trial.number>.p`` under ``checkpoints`` and reports the last recorded
    validation accuracy.

    Args:
        trial: the ``optuna`` trial used to sample hyperparameters and passed
            to the pruning callback.

    Returns:
        float: the ``"val_acc"`` entry of the most recent validation run.

    Raises:
        IndexError: if no validation run was recorded during training.
    """
    # Collects the trainer's metrics after every validation run.
    metrics_callback = MetricsCallback()

    trial_hparams = {"batch_size": trial.suggest_int("batch_size", 1, 12),
                     "learning_rate": trial.suggest_loguniform("learning_rate", 1e-5, 1e-2),
                     "weight_decay": trial.suggest_loguniform("weight_decay", 1e-8, 1e-3)
                    }

    # The test loader is not needed for the hyperparameter search.
    train_loader, val_loader, _test_loader, class_weights = get_dataloader(cfg, trial_hparams, transform=get_transform(cfg))

    if class_weights is not None:
        # Fall back to the CPU when CUDA is unavailable instead of crashing on
        # a hard-coded 'cuda' device.
        weights_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        class_weights = torch.tensor(class_weights).to(device=weights_device)

    # Hands the monitored "val_acc" metric to Optuna during training; whether
    # a trial is actually stopped early depends on the study's pruner.
    pruning_callback = PyTorchLightningPruningCallback(trial, monitor="val_acc")
    trainer = pl.Trainer(gpus=cfg["SOLVER"]["NUM_GPUS"],
                         min_epochs=cfg["SOLVER"]["MIN_EPOCH"],
                         max_epochs=10,
                         progress_bar_refresh_rate=10,
                         callbacks=[metrics_callback, pruning_callback],
                         logger=True,
                        )

    model = get_model(cfg, class_weights, trial_hparams, train_loader, val_loader)
    model.prepare_data()
    trainer.fit(model)

    save_model(model, '{}.p'.format(trial.number), "checkpoints")

    # Score the trial by the validation accuracy of the latest model, which is
    # what the search maximizes. Convert to a plain float so Optuna does not
    # hold on to a (possibly GPU-resident) tensor.
    return float(metrics_callback.metrics[-1]["val_acc"])

In [None]:
# NopPruner never prunes, so Optuna stops no trial early even though
# objective() attaches a pruning callback to the trainer.
pruner = optuna.pruners.NopPruner()
# Maximize the value returned by objective() (the last validation accuracy).
study = optuna.create_study(direction="maximize", pruner=pruner)
# Run until 20 trials have finished or 600 seconds have elapsed, whichever comes first.
study.optimize(objective, n_trials=20, timeout=600)

[32m[I 2020-11-26 14:47:30,856][0m A new study created in memory with name: no-name-da43d366-3eca-4673-9023-4862bf9de518[0m


In [None]:
# Summarize the search: how many trials ran, then the best trial's score and
# the hyperparameters that produced it.
print(f"Number of finished trials: {len(study.trials)}")

print("Best trial:")
best_trial = study.best_trial

print(f"  Value: {best_trial.value}")

print("  Params: ")
for key, value in best_trial.params.items():
    print(f"    {key}: {value}")

In [None]:
# Remove the per-trial model files written by objective(). A missing directory
# is tolerated so the cell can be re-run, or run after a search that saved
# nothing, without raising FileNotFoundError.
try:
    shutil.rmtree("checkpoints")
except FileNotFoundError:
    pass