# Model Training and Testing

## Imports

In [None]:
# Enable IPython's autoreload extension (mode 2) so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import MinMaxScaler
from sunpy.coordinates.sun import carrington_rotation_number, carrington_rotation_time
from matplotlib.colors import Normalize
import pandas as pd
import os
import shutil
import datetime

# Project folders are resolved relative to the current working directory
import sys

data_dir, huxt_utils_dir, ml_utils_dir = (
    os.path.join(os.getcwd(), 'src', folder) for folder in ('data', 'huxt', 'ml')
)

# Make the helper modules under src/huxt and src/ml importable by name
sys.path.extend([huxt_utils_dir, ml_utils_dir])

# Import my own modules from utils

from ensemble_analysis import evaluate_predictions
from data_loader import load_huxt_data_as_windows, load_omni_data
from data_functions import convert_hpo_to_boolean, train_test_val_split, balance_data
from ensemble_methods import *

## Main function

In [None]:
def main(f, run_number, target, 
         input_window_size, input_buffer_size,
         output_window_size, post_window_size, 
         variables, balance, nens, scale_p, model_class, 
         final_name, HUXt_dir, X_windows, y_windows, times, 
         subfolder, seed):
    """
    Train the ensemble classifiers and the final classifier on the given windows,
    evaluate the final forecast on the test split and append one metric row to f.

    The storm thresholds are not parameters: `storm_training_threshold` and
    `storm_testing_threshold` are read from module (notebook) scope and must be
    defined before this function is called.

    input: 
    - f                   : file - open, writable metric table file; one row is appended
    - run_number          : int - id for saving model outputs (not used inside this function)
    - target              : str - target variable name ('hp30' or 'hp60')
    - input_window_size   : int - length of input window, including the buffer (in time steps;
                            the metric row divides by 2 to report hours)
    - input_buffer_size   : int - length of buffer zone (same units as input_window_size)
    - output_window_size  : int - length of output window (same units as input_window_size)
    - post_window_size    : int - length of post window (for plotting); must be > 0 because
                            the post window is removed with `[:-post_window_size]` slices
    - variables           : list/array of str - variable names included in the model
    - balance             : bool - balances storm and non-storm output windows when True
    - nens                : int - number of HUXt ensemble members; forced to 1 for the
                            persistence baselines and written to the metric row
    - scale_p             : bool - scales ensemble probabilities when True
    - model_class         : sklearn model - Classification model from sklearn
    - final_name          : string - model name for final_classifier
    - HUXt_dir            : os.path - Path to the HUXt data directory (not used inside this function)
    - X_windows           : windows of training parameters
    - y_windows           : windows of target variable
    - times               : times corresponding to the windows
    - subfolder           : str - name of the subfolder to save plots to (not used inside this function)
    - seed                : int - random seed forwarded to train_test_val_split

    output:
    - tuple (probabilistic_predictions, X_test, y_test, y_test_hpo, times_test,
             train_predictions, train_maes, final_name)
    """

    # The persistence baselines get their test input from hpo values / times
    # rather than from the ensemble predictions, so they are recorded with nens = 1
    if final_name in ['persistence', '27_day_persistence']:
        nens = 1

    # Balance the data
    if balance: 
        X_windows, y_windows, times = balance_data(X_windows, y_windows, times,
                                                   input_window_size=input_window_size,
                                                   post_window_size=post_window_size,
                                                   storm_threshold=storm_training_threshold)

    # Split to train, val and test sets
    (X_train, X_val, X_test,
     y_train_hpo, y_val_hpo, y_test_hpo,
     times_train, times_val, times_test) = train_test_val_split(X_windows, y_windows, times, seed)

    # One scaler shared by all splits: fitted on the training split only
    scaler = MinMaxScaler(feature_range=(0, 1))

    def refactor_windows(X, y_hpo, name='', fit=False):
        '''
        Build the scaled model inputs and the boolean storm labels for one split.

        When fit is True the shared scaler is fitted on this split first; otherwise
        the already fitted scaler is reused, so every split is scaled identically.
        '''
        # Zero the last variable from the start of the buffer onwards.
        # NOTE: this writes into X in place.
        X[:, :, input_window_size - input_buffer_size:, -1] = 0

        # Remove post_window from our input variables
        X_without_post = X[:, :, :-post_window_size]

        # Storm / no-storm labels for the output window, using the same threshold
        # that is used for balancing so the two stay consistent
        y_bool = convert_hpo_to_boolean(y_hpo,
                                        input_window_size=input_window_size,
                                        post_window_size=post_window_size,
                                        storm_threshold=storm_training_threshold)

        # Combine dimensions for rescaling (one column per variable)
        X_flat = X_without_post.reshape(-1, X_without_post.shape[-1])

        # Scale the data and put it back to original shape
        if fit:
            scaler.fit(X_flat)
        scaled_X = scaler.transform(X_flat).reshape(X_without_post.shape)

        # Combine last 2 dimensions
        scaled_X_reshaped = scaled_X.reshape(scaled_X.shape[:-2] + (-1,))

        # Remove (V - OMNI) for buffer and output window
        # NOTE(review): this drops the last input_buffer_size + output_window_size
        # entries of the flattened axis; confirm that this matches the
        # intended entries for the memory layout of X.
        scaled_X_reshaped = scaled_X_reshaped[:, :, :-(input_buffer_size + output_window_size)]

        # Target values during the input window (excluding the buffer), if requested
        X_hpo = y_hpo[:, :input_window_size - input_buffer_size] if target in variables else None

        print(name, 'ensemble input shape', scaled_X_reshaped.shape)

        # The label array is returned twice to keep the original 6-tuple layout
        return X, scaled_X, scaled_X_reshaped, X_hpo, y_bool, y_bool

    def remove_small_storms(X, y_hpo, times, storm_testing_threshold):
        '''
        Keep windows whose output-window maximum reaches storm_testing_threshold,
        plus a random sample of windows whose maximum is below the training threshold.
        Windows between the two thresholds are dropped.
        '''
        # get the max hpo value for each output window
        y_max_hpo = np.max(y_hpo[:, input_window_size:-post_window_size], axis=1)

        # Extract indices for when we exceed large storm threshold and when we don't exceed storm threshold
        storm_indices = np.where(y_max_hpo >= storm_testing_threshold)[0]
        non_storm_indices = np.where(y_max_hpo < storm_training_threshold)[0]

        # Randomly drop non-storms to balance with the storm times. The sample is
        # capped at the number of available non-storms because sampling without
        # replacement cannot return more items than the population holds.
        # NOTE(review): this draws from NumPy's global RNG; `seed` is only passed
        # to train_test_val_split in this function.
        n_keep = min(len(storm_indices), len(non_storm_indices))
        non_storm_indices = np.random.choice(non_storm_indices, size=n_keep, replace=False)

        # Combine indices
        all_indices = np.concatenate((storm_indices, non_storm_indices))

        # Extract correct parts of our arrays
        return X[all_indices], y_hpo[all_indices], times[all_indices]

    def get_maes(X):
        ''' Mean absolute value of the last variable over the input window (buffer excluded) '''
        X_input = X[:, :, :input_window_size - input_buffer_size, -1]
        return np.mean(np.abs(X_input), axis=-1)

    def build_final_input(predictions, maes, sort_by_mae):
        ''' Pair predictions with MAEs, optionally ordering both by ascending MAE along axis 1 '''
        if sort_by_mae:
            order = np.argsort(maes, axis=1)
            predictions = np.take_along_axis(predictions, order, axis=1)
            maes = np.take_along_axis(maes, order, axis=1)
        return [predictions, maes]

    # Remove storms based on testing threshold
    if storm_testing_threshold != storm_training_threshold:
        X_test, y_test_hpo, times_test = remove_small_storms(X_test, y_test_hpo, times_test, storm_testing_threshold=storm_testing_threshold)

    print(X_test.shape)

    # Extract arrays needed 
    print('Refactoring...')
    X_train, scaled_X_train, scaled_X_train_reshaped, X_train_hpo, y_train, y_train_bool = refactor_windows(X_train, y_train_hpo, 'train', fit=True)
    X_val, scaled_X_val, scaled_X_val_reshaped, X_val_hpo, y_val, y_val_bool = refactor_windows(X_val, y_val_hpo, 'validation')
    X_test, scaled_X_test, scaled_X_test_reshaped, X_test_hpo, y_test, y_test_bool = refactor_windows(X_test, y_test_hpo, 'test')

    print(scaled_X_train_reshaped.shape, y_train.shape)

    # Create ensemble models
    model_params = {'max_iter':2000}
    print("Training Classifiers...")
    model_array = create_ensemble_models(scaled_X_train_reshaped, X_train_hpo, y_train, model_class, model_params)

    # Make ensemble predictions
    print('Making ensemble predictions...')

    train_predictions = make_ensemble_predictions(scaled_X_train_reshaped, X_train_hpo, y_train, model_array, predict_probabilities=True)
    test_predictions = make_ensemble_predictions(scaled_X_test_reshaped, X_test_hpo, y_test, model_array, predict_probabilities=True)
    val_predictions = make_ensemble_predictions(scaled_X_val_reshaped, X_val_hpo, y_val, model_array, predict_probabilities=True)

    # Decide whether to sort final_classifier input by MAE
    sort = final_name in ['logreg_sorted', 'attention_NN'] or final_name.startswith('logreg_top_')

    # Find MAES for input window
    train_maes = get_maes(X_train)
    val_maes = get_maes(X_val)
    test_maes = get_maes(X_test)

    # Inputs of the final classifier (the validation input is built for parity
    # with train/test but is not consumed further down)
    train_input = build_final_input(train_predictions, train_maes, sort)
    val_input = build_final_input(val_predictions, val_maes, sort)
    test_input = build_final_input(test_predictions, test_maes, sort)

    # Pass hpo for the input window 
    if final_name == 'persistence':
        test_input = y_test_hpo[:, :input_window_size-input_buffer_size]

    if final_name == '27_day_persistence':
        # Pass the times corresponding to the output window
        test_input = times_test[:, input_window_size:-post_window_size]

    # Train final classifier
    print('Training final classifier...')
    final_classifier = train_final_classifier(train_input, y_train, final_name)

    # Make probabilistic predictions
    print('Making final forecasts...')
    probabilistic_predictions = make_final_classifier_predictions(test_input, final_classifier, final_name, scale=scale_p)

    res = evaluate_predictions(probabilistic_predictions, y_test)

    # Write metrics to file. Window sizes are divided by the cadence factor
    # because the metric table header labels these columns in hours.
    cadence_factor = 2
    f.write('\n')
    f.write('-'.join(variables))
    # `nens` (not the notebook-level ensemble count) is written so that the row
    # reflects what this call actually used, including the persistence override
    f.write(f',{input_window_size//cadence_factor},{input_buffer_size//cadence_factor},{output_window_size//cadence_factor},{balance},{nens},{sort},{scale_p},{final_name},{storm_testing_threshold},')
    f.write(','.join([str(i) for i in res.values()]))

    test_tuple = (probabilistic_predictions, X_test, y_test, y_test_hpo, times_test, train_predictions, train_maes, final_name)

    print('Done')
    return test_tuple

## Setting Parameters

In [None]:
# OMNI solar wind flow speed data; only needed for visual comparisons
OMNI = load_omni_data(data_dir)

# Run identifiers: `run_number` names the saved outputs,
# `HUXt_run_number` selects which HUXt dataset folder is read
run_number = 1
HUXt_run_number = '1'

# How many ensembles the specified HUXt dataset contains
ensemble_choice = 100

# Window sizes in time steps (48 -> 24 hours, 98 -> 49 hours etc.)
# Output window: the period to forecast for
output_window_size = 48
# Input window: all data preceding the output window (including the buffer)
input_window_size = 98
# Buffer: lead time for the forecast
input_buffer_size = 2
# Post window: data to include after the forecast window
post_window_size = 24
# Stride: time between the starts of successive windows
stride = output_window_size + 1

# Variables used for training
variables = ['velocity', 'gradient', 'v_minus_omni', 'target']

# Ensembles to use (max = no. ensembles in HUXt database)
n_ensembles = 100

# Target variable (must be 'hp30')
target = 'hp30'

# Whether to balance storm and non-storm windows
balance = True
# Whether to scale logreg output
scale = True

# Type of ensemble classifier
model_class = LogisticRegression
# Type of final classifier
final_name = 'weighted_mean'

# Carrington rotation range to load (min = 1892, max = 2278)
start_cr, end_cr = 1892, 2278

# Training and testing thresholds for the Hp30 index
storm_training_threshold, storm_testing_threshold = 4.66, 4.66

random_seed = 151201

## Training and Testing

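The metric table for this run will be stored at `src/figures/metric_tables/run_<run_number>_metric_table.csv`, where single-digit run numbers are zero-padded (e.g. `run_01_metric_table.csv`).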

In [None]:
# To overwrite an existing metric table, change the following 'False' to 'True'
# Make sure to click run, then change back to 'False' to prevent unwanted overwrites 
OVERWRITE = False

# Define save name (single-digit run numbers get a leading zero, e.g. run_01)
leading_zero = 0 if run_number < 10 else ''
save_name = f'run_{leading_zero}{run_number}'

# Setup data directories
huxt_data_dir = os.path.join(os.getcwd(), 'src', 'data', 'huxt', f'HUXt{HUXt_run_number}_modified')
# NOTE: despite its name, `metric_table_dir` is the path of the CSV file itself
metric_table_dir = os.path.join(os.getcwd(), 'src', 'figures', 'metric_tables', f'{save_name}_metric_table.csv')
figure_dir = os.path.join(os.getcwd(), 'src', 'figures')

# The folder holding the metric table must exist before the file can be opened
os.makedirs(os.path.dirname(metric_table_dir), exist_ok=True)

# replace 'target' with target index name
if 'target' in variables: 
    variables.append(target)
    variables.remove('target')

# The header is only needed for a fresh table: when overwriting, or when the
# file does not exist yet / is still empty. Writing it again on every run would
# glue a second header onto the end of the last metric row of an existing table.
FIRST_WRITE = (
    OVERWRITE
    or not os.path.exists(metric_table_dir)
    or os.path.getsize(metric_table_dir) == 0
)

# Load in data before touching the metric table, so that a loading failure
# leaves the existing table unchanged
print(f'LOADING CR {start_cr} TO {end_cr}...')
X_windows, y_windows, times = load_huxt_data_as_windows(huxt_data_dir,
    target, 
    input_window_size=input_window_size, 
    output_window_size=output_window_size, 
    post_window_size=post_window_size, 
    stride=stride,
    start_cr=start_cr,
    end_cr=end_cr,
    n_ensembles=n_ensembles,
    ensemble_choice=ensemble_choice,
    )

# 'w' clears the metric table when overwriting, 'a' appends to it otherwise
with open(metric_table_dir, 'w' if OVERWRITE else 'a') as f:
    if FIRST_WRITE:
        f.write('Variables,Input Window Size (hours),Input Buffer Size (hours),Output Window Size (hours),Balanced,N_ensembles,Sorted,Scaled,Final Classifier,Storm Test Threshold,')
        # Use dummy values to get metric names
        f.write(','.join(evaluate_predictions(np.array([[1], [0]]), np.array([[1], [0]])).keys()))
        FIRST_WRITE = False

    # Run training and testing function (appends one metric row to f)
    output = main(f=f,
         run_number=run_number, 
         target=target, 
         input_window_size=input_window_size, 
         output_window_size=output_window_size,
         post_window_size=post_window_size,
         input_buffer_size=input_buffer_size,
         variables=variables,
         balance=balance,
         model_class=model_class,
         final_name=final_name,
         HUXt_dir=huxt_data_dir,
         X_windows=X_windows, 
         y_windows=y_windows,
         times=times,
         nens=n_ensembles,
         scale_p=scale,
         subfolder=save_name,
         seed=random_seed,)