# Optuna 

In [12]:
import optuna
from optuna.integration import PyTorchLightningPruningCallback
import os 
# os.environ["CUDA_VISIBLE_DEVICES"]="3"
from wwv.Architecture.ResNet.model import ResNet
from wwv.routine import Routine 
from wwv.eval import Metric
from wwv.util import OnnxExporter

import torch 
import torch.nn.functional as F 
from wwv.eval import Metric
import statistics
from wwv.data import AudioDataModule
import wwv.config as cfg
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation

import bisect 
import torch 
from pytorch_lightning import Trainer
import pytorch_lightning as pl 
import torch.nn.functional as F 
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping,ModelCheckpoint,LearningRateMonitor, ModelPruning

from torch.optim.lr_scheduler import ReduceLROnPlateau

from wwv.util import get_username
from wwv.Architecture.ResNet.model import ResNet
from wwv.Architecture.HTSwin.model import HTSwinTransformer
from wwv.Architecture.DeepSpeech.model import DeepSpeech
from wwv.Architecture.LeeNet.model import LeeNet
from wwv.Architecture.MobileNet.model import MobileNet

from pytorch_lightning.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    LearningRateMonitor,
)
from dotenv import load_dotenv
import os 
model_name = "ResNet"

STR_TO_MODEL_CFGS = {
    "HSTAT": cfg.HTSwin(),
    "ResNet": cfg.ResNet(),
    "DeepSpeech": cfg.DeepSpeech(),
    "LeeNet": cfg.LeeNet(),
    "MobileNet": cfg.MobileNet(),
}
STR_TO_MODELS = {
    "HSTAT": HTSwinTransformer,
    "ResNet": ResNet,
    "DeepSpeech": DeepSpeech,
    "LeeNet": LeeNet,
    "MobileNet": MobileNet,
}


cfg_model = STR_TO_MODEL_CFGS[model_name]
# select comp graph/model arch
model = STR_TO_MODELS[model_name]



env_filepath = os.getenv(
    "ENV_FILE_PATH", f"../env_vars/{model_name.lower()}/.dev.env"
)

print(f"Loading env vars from file: {env_filepath}")
load_dotenv(env_filepath)



# init the fitter <---- associated  data loaders and fitting routine to model

model = model
cfg_model = cfg_model
cfg_fitting = cfg.Fitting()
cfg_signal = cfg.Signal()
cfg_feature = cfg.Feature()






data_path = cfg.DataPath(
    f"/media/{get_username()}/Samsung_T5/data/audio/keyword-spotting", cfg_model.model_name, cfg_model.model_dir
)

def setup():
    '''
    Set up data module and loaders
    '''
    data_module = AudioDataModule(
        data_path.root_data_dir,
        cfg_model=cfg_model,
        cfg_feature=cfg_feature,
        cfg_fitting=cfg_fitting,
    )

    train_loader = data_module.train_dataloader()
    val_loader = data_module.val_dataloader()
    test_loader = data_module.test_dataloader()

    return data_module, train_loader, val_loader, test_loader

# get loaders and datamodule to access input shape
data_module, train_loader, val_loader, test_loader = setup()

# get input shape for onnx exporting
input_shape = data_module.input_shape
# init model

# Init a trainer to execute routineSTR_TO_MODELS
from wwv.util import OnnxExporter, CallbackCollection



# callback_dict = callbacks()
# callback_list = [v for (_, v) in callback_dict.items()]
number_devices = os.getenv("CUDA_VISIBLE_DEVICES", "1,").split(",")
try:
    number_devices.remove("")
except ValueError:
    pass




def get_callbacks():
    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    early_stopping = EarlyStopping(mode="min", monitor='val_loss', patience=cfg_fitting.es_patience)
    checkpoint_callback = ModelCheckpoint(monitor="val_loss",
                                            dirpath=data_path.model_dir,
                                            save_top_k=1,
                                            mode="min",
                                            filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}-{val_ttr:.2f}-{val_ftr:.2f}')
    callbacks = [checkpoint_callback, lr_monitor, early_stopping]
    return callbacks 

# def callbacks():
#     cfg_fitting =cfg_fitting
#     data_path = data_path
#     callback_collection = CallbackCollection(cfg_fitting, data_path)
#     return callback_collection()

def objective(trial: optuna.trial.Trial) -> float:

    # We optimize the number of layers, hidden units in each layer and dropouts.
    dropout = trial.suggest_float("dropout", 0.2, 0.5)

    Model = STR_TO_MODELS[model_name]

    kwargs['dropout'] = dropout

    model = Model(**kwargs)
    # setup training, validating and testing routines for the model
    routine = Routine(model, cfg_fitting, cfg_model)
    callbacks = get_callbacks() + [PyTorchLightningPruningCallback(trial, monitor="val_acc")]

    trainer = Trainer(
        accelerator="gpu",
        devices=len(number_devices),
        strategy=os.getenv("STRATEGY", "ddp"),
        sync_batchnorm=True,
        max_epochs=cfg_fitting.max_epoch,
        callbacks=callbacks,
        num_sanity_val_steps=2,
        # resume_from_checkpoint=self.cfg_fitting.resume_from_checkpoint,
        gradient_clip_val=1.0,
        fast_dev_run=cfg_fitting.fast_dev_run,
    )



    hyperparameters = dict(dropout=dropout)
    trainer.logger.log_hyperparams(hyperparameters)
    trainer.fit(
        routine, train_dataloaders=train_loader, val_dataloaders=val_loader
    )  # ,ckpt_path=PATH)

    return trainer.callback_metrics["val_acc"].item()


# trainer.test(dataloaders=test_loader)

objective(
)

Loading env vars from file: ./env_vars/resnet/.dev.env


TypeError: objective() missing 1 required positional argument: 'trial'

In [None]:
from argparse import ArgumentParser

# if __name__ == "__main__":


parser = ArgumentParser(description="PyTorch Lightning example.")
parser.add_argument(
    "--pruning",
    "-p",
    action="store_true",
    help="Activate the pruning feature. `MedianPruner` stops unpromising "
    "trials at the early stages of training.",
)
args = parser.parse_args()

pruner = optuna.pruners.BasePruner = optuna.pruners.MedianPruner()


study = optuna.create_study(direction="maximize", pruner=pruner)
study.optimize(objective, n_trials=100, timeout=600)

print("Number of finished trials: {}".format(len(study.trials)))

print("Best trial:")
trial = study.best_trial

print("  Value: {}".format(trial.value))

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))