# Imports

In [None]:
import random
import sys

import numpy as np
import pytorch_lightning as pl
import torch

In [None]:
# Make the project root importable so the `data` and `models` packages below
# resolve. NOTE(review): this path is machine-specific; adjust it (or install the
# project as a package) when running elsewhere.
sys.path.append('/mnt/home/rheinrich/taaowpf')

from data.cnn.wpf_dataset_germany_all_experiments import WPF_Germany_DataModule
from models.cnn.resnet import WPF_ResNet

# Train a separate ResNet model for each of the 8 experiments

In [None]:
# Seed the PyTorch, Python `random`, and NumPy global RNGs (in that order)
# with a single shared value so notebook runs are reproducible.
SEED = 0
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(SEED)

In [None]:
# Fraction of training batches that use adversarial examples:
# 0.0 -> normal training, 1.0 -> adversarial training.
# Checkpoint directory and file names are derived from this value so that
# switching modes cannot mislabel checkpoints or overwrite the other mode's runs.
# NOTE(review): any value > 0 is labelled 'adversarial'.
p_adv_training = 1.0
training_mode = 'adversarial' if p_adv_training > 0 else 'normal'

# Input data files; identical for every experiment, so defined once.
windspeed_dir = '/mnt/home/rheinrich/taaowpf/data/cnn/wind_speed_100m_germany_res10x10_012018_062021.csv'
windpower_dir = '/mnt/home/rheinrich/taaowpf/data/cnn/windpower_germany_102018_062021.csv'

# Train a separate ResNet model for each of the 8 experiments.
for experiment in range(1, 9):
    # Re-seed at the start of every experiment so each one is reproducible on
    # its own, independent of how much randomness earlier experiments consumed.
    # `workers=True` also seeds the DataLoader worker processes (num_workers > 0).
    pl.seed_everything(0, workers=True)

    ## Set Hyperparameters for model & training
    config = {
        'forecast_horizon': 8,  # 8 hour ahead wind power forecast
        'n_past_timesteps': 4,  # Number of past time steps considered for prediction (excluding time step at prediction time)
        'resnet_version': 34,
        'forecast_version': 'single',  # alternative: 'all'
        'batch_size': 256,
        'num_workers': 32,
        'max_epochs': 100,
        'learning_rate': 0.001,  # Default learning rate of the PyTorch Adam optimizer
        'p_adv_training': p_adv_training,  # 0. for normal training ; 1.0 for adversarial training
        'eps_adv_training': 0.15,  # Maximum perturbation caused by adversarial attacks.
        'step_num_adv_training': 100,  # Number of PGD-iterations for adversarial attacks.
        'norm_adv_training': 'Linf',  # Norm used to calculate adversarial attacks.
        'checkpoint_dirpath': f'./checkpoints_{training_mode}_training/',
        'checkpoint_filename': f'best_resnet_model_{training_mode}_training_experiment{experiment}',
    }

    # Initialize DataModule for this experiment
    datamodule = WPF_Germany_DataModule(windspeed_dir=windspeed_dir,
                                        windpower_dir=windpower_dir,
                                        forecast_horizon=config['forecast_horizon'],
                                        n_past_timesteps=config['n_past_timesteps'],
                                        batch_size=config['batch_size'],
                                        num_workers=config['num_workers'],
                                        experiment=experiment,
                                        )

    # Model
    ## Callbacks: keep only the checkpoint with the lowest validation loss and
    ## stop once val_loss has not improved for 15 epochs.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        monitor='val_loss',
        dirpath=config['checkpoint_dirpath'],
        filename=config['checkpoint_filename'],
        save_top_k=1,
        mode='min')

    early_stopping = pl.callbacks.EarlyStopping(monitor='val_loss',
                                                patience=15)

    ## Create model
    model = WPF_ResNet(resnet_version=config['resnet_version'],
                       forecast_version=config['forecast_version'],
                       forecast_horizon=config['forecast_horizon'],
                       n_past_timesteps=config['n_past_timesteps'],
                       learning_rate=config['learning_rate'],
                       p_adv_training=config['p_adv_training'],
                       eps_adv_training=config['eps_adv_training'],
                       step_num_adv_training=config['step_num_adv_training'],
                       norm_adv_training=config['norm_adv_training'])

    ## Create trainer (single GPU)
    trainer = pl.Trainer(max_epochs=config['max_epochs'],
                         devices=1,
                         accelerator='gpu',
                         callbacks=[checkpoint_callback, early_stopping],
                         )

    ## Train model
    trainer.fit(model, datamodule=datamodule)

    # To evaluate the best checkpoint of this experiment, call
    # trainer.validate(...) / trainer.test(...) with datamodule=datamodule and
    # ckpt_path='best'.