# Imports, helper functions, configuration settings

In [None]:
# Standard library imports
import pickle

# Local imports
from config import (
    EPOCHS,
    LEARNING_RATE,
    SLIDER_LENGTH,
    MODEL,
    RUN_ID_MS,
    INPUT_LIST,
    VARIABLE,
)
from data import (
    normalize,
    compute_mean_std,
    normalize_input_data,
    concatenate_training_data,
)
from models import CNNModel, ResNet18
from plotting import plot_training_loss, global_timeseries_plot, global_anomaly_plot
from tools import train_k_fold, make_predictions, RMSE, predict_model, metrics

# Data Formatting

In [None]:
# 3 EM average dataset (pickled train/test splits).
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at a trusted, locally produced pickle.
file_path = "/discover/nobackup/jmekus/loaded_DT_data_EMavg.pkl"
with open(file_path, mode="rb") as f:
    loaded_data = pickle.load(f)

# Pull the four train/test entries out of the loaded data in one step
X_train, X_test, Y_train, Y_test = (
    loaded_data[split] for split in ("X_train", "X_test", "Y_train", "Y_test")
)

# Normalize the inputs using the mean/std derived from X_train
meanstd_inputs = compute_mean_std(INPUT_LIST, X_train)
X_train_norm, X_test_xr = normalize_input_data(
    X_train,
    X_test,
    INPUT_LIST,
    meanstd_inputs,
    normalize,
)

# Assemble the arrays that are handed to the training routine
X_train_all, Y_train_all = concatenate_training_data(
    X_train,
    Y_train,
    X_train_norm,
    VARIABLE,
    SLIDER_LENGTH,
)

# Model Training

In [None]:
# Set model
# Channel counts are taken from the configured input and target lists.
# NOTE(review): len(VARIABLE) assumes VARIABLE is a sequence of names rather
# than a single string — confirm against config.
input_channels = len(INPUT_LIST)
output_channels = len(VARIABLE)
if MODEL == 'CNN':
    model = CNNModel(input_channels, output_channels)
elif MODEL == 'RESNET':
    model = ResNet18(input_channels, output_channels)
else:
    # Fail fast on an unrecognized setting: without this branch `model` would
    # be left undefined, or (in a notebook) silently keep a stale value from
    # an earlier run and be trained again below.
    raise ValueError(
        f"Unsupported MODEL {MODEL!r}; expected 'CNN' or 'RESNET'"
    )
    
# Run k-fold training; the helper hands back the model and the training losses
model, all_train_losses = train_k_fold(
    INPUT_LIST,
    VARIABLE,
    LEARNING_RATE,
    X_train_all,
    Y_train_all,
    EPOCHS,
    model,
)

# Draw the training-loss curve from the losses collected above
plot_training_loss(EPOCHS, all_train_losses)

# Model Predictions

In [None]:
# Predict for SSP245 (unsuffixed names), SSP585 and SSP126.
# NOTE(review): Y_test is both passed in and rebound by this call, so
# re-running the cell feeds the previously returned Y_test back in —
# confirm make_predictions tolerates that.
Y_pred, Y_test, Y_pred_585, Y_test_585, Y_pred_126, Y_test_126 = make_predictions(
    X_test, Y_test, X_train, Y_train,
    model, INPUT_LIST, meanstd_inputs, SLIDER_LENGTH,
    RUN_ID_MS, predict_model, VARIABLE,
)

# Evaluation

In [None]:
# RMSE per scenario, evaluated in the order SSP245, SSP126, SSP585
rmse_245, rmse_126, rmse_585 = (
    RMSE(predicted, target, VARIABLE)
    for predicted, target in (
        (Y_pred, Y_test),
        (Y_pred_126, Y_test_126),
        (Y_pred_585, Y_test_585),
    )
)

# Report the three scores
print(f"RMSE for SSP245: {rmse_245}")
print(f"RMSE for SSP126: {rmse_126}")
print(f"RMSE for SSP585: {rmse_585}")

In [None]:
# Spatial and global RMSE for the SSP245 predictions (Y_pred vs. Y_test)
spatial_RMSE, global_RMSE = metrics(
    Y_hat=Y_pred,
    Y_test=Y_test,
    VARIABLE=VARIABLE,
)

In [None]:
# Global time-series plot built from the predictions and the train/test targets
global_timeseries_plot(
    Y_pred=Y_pred,
    Y_train=Y_train,
    Y_test=Y_test,
    VARIABLE=VARIABLE,
)

In [None]:
# Spatial anomaly maps for predictions vs. test targets.
# p_value=0.05 is forwarded to the helper (presumably a significance
# threshold — confirm in plotting.global_anomaly_plot).
global_anomaly_plot(
    Y_pred=Y_pred,
    Y_test=Y_test,
    p_value=0.05,
    VARIABLE=VARIABLE,
)