# Imports

In [None]:
import os
import sys

import pytorch_lightning as pl

sys.path.append('/mnt/home/rheinrich/taaowpf')

from data.lstm.wpf_dataset_single_turbine_gefcom import WPF_SingleTurbine_DataModule
from models.lstm.lstm import WPF_AutoencoderLSTM

In [None]:
# Set PYTHONPATH to the project root (the same directory that the imports cell appends to sys.path).
# NOTE(review): presumably needed so that processes spawned by Ray can import `data` and `models` — confirm.
os.environ["PYTHONPATH"] = '/mnt/home/rheinrich/taaowpf'

In [None]:
# Hyperparameter Tuning
import logging

from pytorch_lightning.loggers import TensorBoardLogger

import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
from ray.tune.integration.pytorch_lightning import TuneReportCallback

## Initialize Ray

In [None]:
# Start Ray with a fixed resource budget and reduced log output.
ray.init(
    # resources made available to Ray
    num_cpus=32,
    num_gpus=3,
    object_store_memory=32_000_000_000,
    # location for Ray's temporary files
    _temp_dir="/mnt/home/rheinrich/ray/tmp",
    # start-up behaviour: no dashboard, tolerate re-running this cell
    include_dashboard=False,
    ignore_reinit_error=True,
    # logging: warnings and above only, nothing forwarded to the driver
    logging_level=logging.WARNING,
    log_to_driver=False,
)

# Hyperparameter Tuning

Reference: Ray Tune tutorial for PyTorch Lightning (Ray 1.11.0): https://docs.ray.io/en/releases-1.11.0/tune/tutorials/tune-pytorch-lightning.html

## ASHA Scheduler

### Function used for Training during Hyperparameter Tuning

In [None]:
def train_model_tune(config, data_dir, checkpoint_dir=None, num_gpus=0):
    """Train one WPF_AutoencoderLSTM trial and report its validation metrics to Ray Tune.

    Args:
        config: Hyperparameter dict for this trial. Keys read here:
            'forecast_horizon', 'n_past_timesteps', 'batch_size',
            'num_workers_datamodule', 'hidden_size', 'num_layers',
            'learning_rate', 'p_adv_training', 'eps_adv_training',
            'step_num_adv_training', 'norm_adv_training' and 'max_epochs'.
        data_dir: Forwarded unchanged to WPF_SingleTurbine_DataModule as `data_dir`.
        checkpoint_dir: Not read by this function.
            NOTE(review): presumably kept for Ray Tune's function-trainable signature — confirm.
        num_gpus: Number of GPUs reserved for the trial. Any non-zero value trains
            on a single GPU; 0 (the default) or None trains on the CPU.

    Returns:
        None. Metrics are reported to Tune through TuneReportCallback.
    """
    # initiate DataModule
    datamodule = WPF_SingleTurbine_DataModule(
        data_dir=data_dir,
        forecast_horizon=config['forecast_horizon'],
        n_past_timesteps=config['n_past_timesteps'],
        batch_size=config['batch_size'],
        num_workers=config['num_workers_datamodule'],
    )

    # define callback for Early Stopping (monitors the validation loss)
    early_stopping = pl.callbacks.EarlyStopping(monitor='val_loss',
                                                min_delta=1e-6,
                                                patience=15)

    # use the Ray Tune callback TuneReportCallback to report metrics back to Tune after each validation epoch;
    # 'val_loss' is exposed to Tune under the name 'loss'
    tune_report_callback = TuneReportCallback({"loss": "val_loss", "val_rmse": "val_rmse"}, on="validation_end")

    # initiate model
    model = WPF_AutoencoderLSTM(
        forecast_horizon=config['forecast_horizon'],
        n_past_timesteps=config['n_past_timesteps'],
        hidden_size=config['hidden_size'],
        num_layers=config['num_layers'],
        learning_rate=config['learning_rate'],
        p_adv_training=config['p_adv_training'],
        eps_adv_training=config['eps_adv_training'],
        step_num_adv_training=config['step_num_adv_training'],
        norm_adv_training=config['norm_adv_training'],
    )

    # Honour the resources reserved for this trial: only ask Lightning for a GPU
    # when one was actually requested, otherwise fall back to the CPU.
    accelerator = 'gpu' if num_gpus else 'cpu'

    # initiate Trainer
    trainer = pl.Trainer(
        max_epochs=config['max_epochs'],
        devices=1,  # a trial always trains on a single device
        accelerator=accelerator,
        # write TensorBoard logs directly into the trial directory managed by Tune
        logger=TensorBoardLogger(save_dir=tune.get_trial_dir(), name="", version="."),
        enable_progress_bar=False,
        enable_checkpointing=False,  # otherwise memory gets too large during hyperparameter tuning
        callbacks=[tune_report_callback, early_stopping],
    )

    # fit model
    trainer.fit(model, datamodule=datamodule)

### Function used for Hyperparameter Tuning

In [None]:
def tune_hyperparams_asha(data_dir, num_samples, gpus_per_trial=0, cpus_per_trial=1, grace_period_asha = 10,
                          experiment_name="tune_single-turbine_model_gefcom_202212",
                          local_dir="/mnt/home/rheinrich/ray/ray_results/taaowpf"):
    """Run an ASHA-scheduled Ray Tune search over the LSTM hyperparameters.

    Args:
        data_dir: Forwarded to `train_model_tune` as `data_dir`.
        num_samples: Number of times to sample from the hyperparameter space
            (passed to `tune.run` as `num_samples`).
        gpus_per_trial: GPUs requested from Tune for each trial; also forwarded to
            `train_model_tune` as `num_gpus`.
        cpus_per_trial: CPUs requested from Tune for each trial.
        grace_period_asha: `grace_period` passed to the ASHAScheduler.
        experiment_name: `name` passed to `tune.run`. Use a different name for
            every experiment so that results end up in separate folders.
        local_dir: `local_dir` passed to `tune.run`.

    Returns:
        The object returned by `tune.run`, so callers can inspect the results
        without reloading the experiment from disk.
    """
    # configure the search space; plain values are constants shared by all trials
    # (Ray documents the `upper` bound of tune.randint as exclusive)
    config = {
        "n_past_timesteps": tune.randint(lower = 1, upper = 25),
        "hidden_size": tune.choice([32, 64, 96, 128, 160, 192, 224, 256]),
        "num_layers": tune.randint(lower = 1, upper = 4),
        "learning_rate": tune.loguniform(1e-5, 1e-1),
        "forecast_horizon": 8,
        "max_epochs": 100,
        "batch_size": 256,
        "num_workers_datamodule": 0,
        "p_adv_training": 0.0, # probability is zero, so no adversarial training is used for hyperparameter tuning
        "eps_adv_training": 0.1,
        "step_num_adv_training": 100,
        "norm_adv_training": 'Linf',
    }

    # select a scheduler / algorithm for hyperparameter tuning
    scheduler = ASHAScheduler(
        max_t=config['max_epochs'],
        grace_period=grace_period_asha
    )

    # define the desired CLI output
    reporter = CLIReporter(
        #parameter_columns= list(config.keys()), # shows all hyperparameters
        parameter_columns= ["batch_size"],
        metric_columns=["loss", "val_rmse", "training_iteration"])

    # pass constants to the train function
    train_fn_with_parameters = tune.with_parameters(train_model_tune,
                                                    num_gpus=gpus_per_trial,
                                                    data_dir = data_dir)

    # specify how many resources Tune should request for each trial
    resources_per_trial = {"cpu": cpus_per_trial, "gpu": gpus_per_trial}

    # start Tune; "loss" is the name under which train_model_tune reports val_loss
    analysis = tune.run(train_fn_with_parameters,
                        resources_per_trial=resources_per_trial,
                        metric="loss",
                        mode="min",
                        config=config,
                        num_samples=num_samples,
                        scheduler=scheduler,
                        progress_reporter=reporter,
                        name=experiment_name,
                        local_dir = local_dir,
                        keep_checkpoints_num = 100,
                        checkpoint_score_attr = 'min-loss',
                        verbose = 1
                       )

    print("Best hyperparameters found were: ", analysis.best_config)

    return analysis

## Start Hyperparameter Tuning

In [None]:
# Despite its name, `data_dir` is the path of a single CSV file; it is handed to
# tune_hyperparams_asha, which forwards it to the training function.
data_dir = '/mnt/home/rheinrich/taaowpf/data/lstm/Gefcom2014_Wind/gefcom2014_W_100m_zone1.csv'

In [None]:
# Launch the ASHA search with one CPU and one GPU reserved per trial.
tune_hyperparams_asha(
    data_dir=data_dir,
    num_samples=1000,  # how many configurations are drawn from the hyperparameter space
    cpus_per_trial=1,
    gpus_per_trial=1,
    grace_period_asha=20,
)

## Get Results of Hyperparameter Tuning

#### Path to Experiment (make sure that every hyperparameter tuning experiment is stored in a separate folder!)

In [None]:
# Results folder of the experiment: the `local_dir` and `name` that tune_hyperparams_asha
# passes to tune.run, joined into one path.
experiment_path = "/mnt/home/rheinrich/ray/ray_results/taaowpf/tune_single-turbine_model_gefcom_202212"

#### Load Experiment

In [None]:
from ray.tune import ExperimentAnalysis

# Open the stored experiment; "loss" with mode "min" is the default ranking,
# the same metric and mode that were used during tuning.
analysis = ExperimentAnalysis(
    experiment_path,
    default_metric="loss",
    default_mode="min",
)

## Shutdown Ray

In [None]:
# Shut down the Ray instance that was started with ray.init earlier in this notebook.
ray.shutdown()