# Imports

In [None]:
import random
import sys

import numpy as np
import pytorch_lightning as pl
import torch

sys.path.append('/mnt/home/rheinrich/taaowpf')

from data.lstm.wpf_dataset_single_turbine_gefcom import WPF_SingleTurbine_DataModule
from models.lstm.lstm import WPF_AutoencoderLSTM

# Train a separate LSTM model for each wind farm (zone 1 to 10).

In [None]:
# Seed every random number generator used in this notebook (PyTorch, Python's
# built-in `random`, NumPy) with the same fixed value so runs are repeatable.
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(0)

In [None]:
for zone in range(1, 11):
    # ------------------------------------------------------------------
    # Hyperparameters for model & training (values after hyperparameter
    # tuning). Only the checkpoint filename depends on the zone.
    # ------------------------------------------------------------------
    config = dict(
        forecast_horizon=8,  # 8 hour ahead wind power forecast
        n_past_timesteps=12,  # past timesteps considered for prediction (including the time step at prediction time)
        hidden_size=32,
        num_layers=1,
        batch_size=256,
        num_workers=32,
        max_epochs=100,
        learning_rate=0.01,
        p_adv_training=1.0,  # 0. for normal training; 1.0 for adversarial training
        eps_adv_training=0.15,  # maximum perturbation caused by adversarial attacks
        step_num_adv_training=100,  # number of PGD iterations for adversarial attacks
        norm_adv_training='Linf',  # norm used to calculate adversarial attacks
        # For normal training use './checkpoints_normal_training/' instead.
        checkpoint_dirpath='./checkpoints_adversarial_training/',
        # For normal training use f'best_lstm_model_gefcom_normal_training_zone{zone}' instead.
        checkpoint_filename=f'best_lstm_model_gefcom_adversarial_training_zone{zone}',
    )

    # ------------------------------------------------------------------
    # DataModule for this zone's CSV file
    # ------------------------------------------------------------------
    data_dir = f'/mnt/home/rheinrich/taaowpf/data/lstm/Gefcom2014_Wind/gefcom2014_W_100m_zone{zone}.csv'

    datamodule = WPF_SingleTurbine_DataModule(
        data_dir=data_dir,
        **{
            name: config[name]
            for name in ('forecast_horizon', 'n_past_timesteps', 'batch_size', 'num_workers')
        }
    )

    # Optional: inspect one training batch. Left disabled; enabling it needs
    # matplotlib (`plt`) and pandas (`pd`), which the import cell shown in
    # this notebook does not bring in.
    #
    # datamodule.setup()
    #
    # fig, axs = plt.subplots(1,3, figsize=(15, 5))
    # axs = axs.flatten()
    #
    # for inputs_windspeed, inputs_windpower, targets in datamodule.train_dataloader():
    #     input_sample_windspeed = inputs_windspeed[0]
    #     input_sample_windpower = inputs_windpower[0]
    #     target_sample = targets[0]
    #
    #     print((inputs_windspeed.shape, inputs_windpower.shape, targets.shape))
    #     print(target_sample)
    #
    #     f1 = pd.DataFrame(input_sample_windspeed.numpy()).plot(title = "Input Wind Speed", ax = axs[0])
    #     f2 = pd.DataFrame(input_sample_windpower.numpy()).plot(title = "Input Wind Power", ax = axs[1])
    #     f3 = pd.DataFrame(target_sample.numpy()).plot(title = "Target", ax = axs[2])
    #     plt.show()
    #
    #     break

    # ------------------------------------------------------------------
    # Callbacks: both watch 'val_loss'
    # ------------------------------------------------------------------
    # Keep a single checkpoint (save_top_k=1) selected by minimal 'val_loss'.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath=config['checkpoint_dirpath'],
        filename=config['checkpoint_filename'],
        monitor='val_loss',
        mode='min',
        save_top_k=1,
    )

    # Early stopping on 'val_loss' with a patience of 15.
    early_stopping = pl.callbacks.EarlyStopping(monitor='val_loss', patience=15)

    # ------------------------------------------------------------------
    # Model: architecture, optimisation and adversarial-training settings
    # are all forwarded from `config` under identical keyword names.
    # ------------------------------------------------------------------
    model = WPF_AutoencoderLSTM(
        **{
            name: config[name]
            for name in (
                'forecast_horizon',
                'n_past_timesteps',
                'hidden_size',
                'num_layers',
                'learning_rate',
                'p_adv_training',
                'eps_adv_training',
                'step_num_adv_training',
                'norm_adv_training',
            )
        }
    )

    # ------------------------------------------------------------------
    # Trainer: one GPU device, both callbacks attached
    # ------------------------------------------------------------------
    trainer = pl.Trainer(
        accelerator='gpu',
        devices=1,
        max_epochs=config['max_epochs'],
        callbacks=[checkpoint_callback, early_stopping],
    )

    # Train the model for this zone.
    trainer.fit(model, datamodule=datamodule)

    # Optional evaluation of the best checkpoint (disabled):
    #   validation set
    # trainer.validate(model, datamodule = datamodule, ckpt_path = 'best')
    #   test set
    # trainer.test(model, datamodule = datamodule, ckpt_path = 'best')