# Optuna 

# In [1]:
import optuna
from optuna.integration import PyTorchLightningPruningCallback
import os 
# os.environ["CUDA_VISIBLE_DEVICES"]="3"
from wwv.Architecture.ResNet.model import ResNet
from wwv.routine import Routine 
from wwv.eval import Metric
from wwv.util import OnnxExporter

import torch 
import torch.nn.functional as F 
from wwv.eval import Metric
import statistics
from wwv.data import AudioDataModule
import wwv.config as cfg
from wwv.meta import params as params 
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation

import bisect 
import torch 
from pytorch_lightning import Trainer
import pytorch_lightning as pl 
import torch.nn.functional as F 
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping,ModelCheckpoint,LearningRateMonitor, ModelPruning

from torch.optim.lr_scheduler import ReduceLROnPlateau


# --- Configuration, paths and data loaders ---------------------------------
# Runs at import time and defines the module-level names read by
# get_callbacks() and objective(): cfg_fitting, cfg_resnet, data_path, the
# three loaders, and `cfg`.
cfg_fitting = cfg.Fitting()
cfg_resnet = cfg.ResNet()

# Mutates the shared `params` dict imported from wwv.meta before building Config.
# Config is built once; a second identical construction added nothing.
params['model_name'] = "ResNet"
Cfg = cfg.Config(params)

# Base directory handed to cfg.DataPath. Overridable via WWV_PROJECT_ROOT so the
# notebook runs on other machines; the default keeps the original location.
PROJECT_ROOT = os.environ.get(
    "WWV_PROJECT_ROOT", "/home/akinwilson/Code/HTS-Audio-Transformer"
)
data_path = cfg.DataPath(PROJECT_ROOT, Cfg.model_name, Cfg.path['model_dir'])

data_module = AudioDataModule(
    os.path.join(data_path.root_data_dir, "train.csv"),
    os.path.join(data_path.root_data_dir, "val.csv"),
    os.path.join(data_path.root_data_dir, "test.csv"),
    cfg=Cfg,
    cfg_fitting=cfg_fitting,
)

train_loader = data_module.train_dataloader()
val_loader = data_module.val_dataloader()
test_loader = data_module.test_dataloader()

# From here on `cfg` is the Config instance, NOT the wwv.config module;
# objective() passes it to ResNet and Routine as `cfg`.
cfg = Cfg



def get_callbacks():
    """Build the Lightning callbacks attached to every training run.

    Returns:
        A list ``[checkpoint, lr_monitor, early_stopping]``:

        * checkpoint keeps only the single lowest-``val_loss`` checkpoint in
          ``data_path.model_dir``, naming it after epoch and validation metrics;
        * lr_monitor logs the learning rate once per epoch;
        * early_stopping halts training after ``cfg_fitting.es_patience``
          validation checks without a ``val_loss`` improvement.
    """
    ckpt_template = '{epoch}-{val_loss:.2f}-{val_acc:.2f}-{val_ttr:.2f}-{val_ftr:.2f}'

    lr_logger = LearningRateMonitor(logging_interval='epoch')
    stopper = EarlyStopping(
        monitor='val_loss',
        mode="min",
        patience=cfg_fitting.es_patience,
    )
    best_ckpt = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        dirpath=data_path.model_dir,
        filename=ckpt_template,
    )

    return [best_ckpt, lr_logger, stopper]




def objective(trial: optuna.trial.Trial) -> float:
    """Train one ResNet with an Optuna-suggested dropout and return its ``val_acc``.

    Only ``dropout`` (sampled from [0.2, 0.5]) is searched; every other setting
    comes from the module-level ``cfg``, ``cfg_resnet`` and ``cfg_fitting``.

    Args:
        trial: Optuna trial that supplies the hyperparameters. The attached
            PyTorchLightningPruningCallback also reports ``val_acc`` to it.

    Returns:
        The ``val_acc`` value left in ``trainer.callback_metrics`` after fitting;
        the study created in this file maximises this value.

    Raises:
        KeyError: if the training routine never logged ``val_acc``.
        NOTE(review): per Optuna's integration docs, the pruning callback raises
        ``optuna.TrialPruned`` when the pruner stops the trial.
    """
    # Only dropout is tuned; architecture depth comes from cfg_resnet.num_blocks.
    dropout = trial.suggest_float("dropout", 0.2, 0.5)
    model = ResNet(num_blocks=cfg_resnet.num_blocks, cfg=cfg, dropout=dropout)

    # Pruning monitors the same metric the objective returns and the study maximises.
    # NOTE(review): get_callbacks() checkpoints every trial into the same
    # data_path.model_dir, so later trials can overwrite earlier checkpoints.
    callbacks = get_callbacks() + [PyTorchLightningPruningCallback(trial, monitor="val_acc")]

    # One TensorBoard run directory per trial. A fixed version made every trial
    # write into the same lightning_logs/version_1 folder, mixing their curves.
    logger = TensorBoardLogger(
        save_dir=data_path.model_dir,
        version=f"trial_{trial.number}",
        name="lightning_logs",
    )

    # NOTE(review): 'dp' (DataParallel) on a single device adds nothing, and
    # newer Lightning releases dropped this strategy — consider removing it.
    trainer = Trainer(accelerator="gpu",
                      devices=1,
                      strategy='dp',
                      logger=logger,
                      default_root_dir=data_path.model_dir,
                      callbacks=callbacks)

    logger.log_hyperparams(dict(dropout=dropout))

    trainer.fit(Routine(model, cfg), train_dataloaders=train_loader, val_dataloaders=val_loader)
    return trainer.callback_metrics["val_acc"].item()


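# trainer.test(dataloaders=test_loader)
# NOTE: running the line above at module level failed with
# "NameError: name 'trainer' is not defined" — `trainer` is a local variable
# inside objective(), so it does not exist at this scope.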

# In [None]:
from argparse import ArgumentParser

# if __name__ == "__main__":


# Script section: parse CLI flags, run the Optuna study over `objective`, and
# print a summary. NOTE(review): this runs unconditionally at module level;
# guard it with `if __name__ == "__main__":` if the file is ever imported.
parser = ArgumentParser(description="PyTorch Lightning example.")
parser.add_argument(
    "--pruning",
    "-p",
    action="store_true",
    help="Activate the pruning feature. `MedianPruner` stops unpromising "
    "trials at the early stages of training.",
)
# parse_known_args tolerates extra arguments (e.g. the `-f <connection file>`
# a Jupyter kernel is started with); parse_args would exit with an error there.
args, _unknown_args = parser.parse_known_args()

# Honour --pruning: MedianPruner when requested, otherwise NopPruner (never
# prunes). Written as an annotation on purpose: a chained
# `pruner = optuna.pruners.BasePruner = ...` would overwrite Optuna's BasePruner class.
pruner: optuna.pruners.BasePruner = (
    optuna.pruners.MedianPruner() if args.pruning else optuna.pruners.NopPruner()
)

study = optuna.create_study(direction="maximize", pruner=pruner)
study.optimize(objective, n_trials=100, timeout=600)

print("Number of finished trials: {}".format(len(study.trials)))

# study.best_trial raises ValueError when no trial has completed (all pruned
# or failed), so report that case instead of crashing.
completed_trials = [
    t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE
]
if not completed_trials:
    print("No trial completed; no best trial to report.")
else:
    print("Best trial:")
    trial = study.best_trial

    print("  Value: {}".format(trial.value))

    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))