In [1]:
import json
from pathlib import Path
from operator import itemgetter
import pandas as pd
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, mean_absolute_percentage_error, r2_score

import sys
import os

import plotly.graph_objs as go
import plotly.subplots as sp
from plotly.offline import init_notebook_mode, plot, iplot
import plotly.express as px
from plotly.subplots import make_subplots

module_path = os.path.abspath(os.path.join("../src"))
if module_path not in sys.path:
    sys.path.append(module_path)

from data import *
from train import create_model

In [2]:
experiment = "1"

In [3]:
# List all model directories
ray_results = Path("../ray_results/")
model_dirs = [d for d in ray_results.iterdir() if d.is_dir()]

In [4]:
# Function to find best checkpoints for a model directory
def find_best_checkpoints(model_dir, num_best=5):
    checkpoints = []

    # Iterate over all training runs in the model directory
    for run_dir in model_dir.iterdir():
        if run_dir.is_dir():
            # Read the progress.csv file to get the validation losses
            progress_file = run_dir / "progress.csv"
            if progress_file.exists():
                with open(run_dir / "params.json", "r") as f:
                    params = json.load(f)
                progress_data = pd.read_csv(progress_file)

                best_val_idx = progress_data["val_loss"].idxmin()
                best_val_loss = progress_data.loc[best_val_idx, "val_loss"]

                # Save the checkpoint path and validation loss
                checkpoint_path = run_dir / "my_model" / "checkpoint.pt"
                checkpoints.append((checkpoint_path, best_val_loss, params))

    # Sort the checkpoints based on validation loss
    checkpoints.sort(key=itemgetter(1))

    return checkpoints[:num_best]

In [5]:
def get_dataloader(params):
    d = Data(params["data_file"], params["datetime"])
    data_loader = d.prepare_data(**params["data"])
    return data_loader

In [6]:
def load_model_from_checkpoint(checkpoint_path, model_class):
    model = model_class()
    model.load_state_dict(torch.load(checkpoint_path))
    return model

In [7]:
def get_preds_actuals(model, test_dataloader):
    model.eval()
    all_preds = []
    all_actuals = []

    with torch.no_grad():
        for X_batch, y_batch in test_dataloader:
            batch_output = model(X_batch.float())
            all_preds.extend(batch_output.squeeze().numpy())
            all_actuals.extend(y_batch.numpy())

    return all_preds, all_actuals

In [8]:
def load_model_from_checkpoint(model, checkpoint_path):
    try:
        model_state_dict, _ = torch.load(checkpoint_path)
        model.load_state_dict(model_state_dict)
    except Exception as e:
        print(f"Error loading checkpoint: {checkpoint_path}")
        print(e)
    return model

In [9]:
def plot_pred_actual(data_loader, model, model_name):
    val_loader, test_loader = data_loader['val'], data_loader['test']
    y_val = [y for _, y in val_loader]
    y_val = np.concatenate(y_val)
    y_preds, y_test = get_preds_actuals(model, test_loader)
    
    # Get datetime values
    datetime_val = val_loader.datetime_index
    datetime_test = test_loader.datetime_index

    # Create dataframes
    val_df = pd.DataFrame({'datetime': datetime_val, 'target': y_val})
    test_df = pd.DataFrame({'datetime': datetime_test, 'target': y_test})
    predictions_df = pd.DataFrame({'datetime': datetime_test, 'predictions': y_preds})

    # Create a scatter plot for each dataset
    fig = go.Figure()

    fig.add_trace(go.Scatter(x=val_df['datetime'], y=val_df['target'],
                        mode='lines', name='Validation'))
    fig.add_trace(go.Scatter(x=test_df['datetime'], y=test_df['target'],
                        mode='lines', name='Test'))
    fig.add_trace(go.Scatter(x=predictions_df['datetime'], y=predictions_df['predictions'],
                        mode='lines', name='Predictions'))

    # Add spread plot between Test and Predictions
    fig.add_trace(go.Scatter(x=test_df['datetime'], y=test_df['target'], fill=None,
                             mode='lines', line_color='rgba(0, 0, 0, 0.1)', showlegend=False))
    fig.add_trace(go.Scatter(x=predictions_df['datetime'], y=predictions_df['predictions'], fill='tonexty',
                             mode='lines', line_color='rgba(0, 0, 0, 0.1)', name='Spread'))

    # Add axis labels and plot title
    fig.update_layout(
        title=f'Time Series Data with {model_name} Predictions',
        xaxis_title='Datetime',
        yaxis_title='Target Value',
    )

    # Show the plot
    fig.show()

# Plotting the data and predictions

In [10]:
for model_dir in model_dirs:
    if experiment not in str(model_dir):
        continue
    print(f"Best performing model on experiment: {os.path.basename(model_dir)}")
    rows = []
    best_checkpoints = find_best_checkpoints(model_dir, num_best=1)
    for i, (checkpoint, val_loss, params) in enumerate(best_checkpoints):

        # Load model and weights
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)

        data_loader = get_dataloader(params)
        plot_pred_actual(
            data_loader,
            model,
            model_name=params["model"],
        )

Best performing model on experiment: data_4-1-lstm


Best performing model on experiment: data_4-1-spatio_temporal


Best performing model on experiment: data_4-1-temporal


# Plotting the loss

In [11]:
def get_losses(run_dir):
    if run_dir.is_dir():
        # Read the progress.csv file to get the validation losses
        progress_file = run_dir / "progress.csv"
        if progress_file.exists():
            return pd.read_csv(progress_file)[['train_loss', 'val_loss']]

def plot_losses(losses):
    fig = go.Figure()

    fig.add_trace(go.Scatter(y=losses['train_loss'], mode='lines', name='Train Loss'))
    fig.add_trace(go.Scatter(y=losses['val_loss'], mode='lines', name='Validation Loss'))

    fig.update_layout(title='Train and Validation Losses', xaxis_title='Epoch', yaxis_title='Loss')

    fig.show()

for model_dir in model_dirs:
    if experiment not in str(model_dir):
        continue
    best_checkpoints = find_best_checkpoints(model_dir, num_best=1)
    for i, (checkpoint, val_loss, params) in enumerate(best_checkpoints):
        # Access run_dir from checkpoint_path
        print(f"Best model from {model_dir}")
        run_dir = checkpoint.parents[1]
        losses = get_losses(run_dir)
        plot_losses(losses)


Best model from ..\ray_results\data_4-1-lstm


Best model from ..\ray_results\data_4-1-spatio_temporal


Best model from ..\ray_results\data_4-1-temporal


# Hour ahead forecast

In [12]:
model_dfs = {}
for model_dir in model_dirs:
    rows = []
    best_checkpoints = find_best_checkpoints(model_dir, num_best=5)
    for i, (checkpoint, val_loss, params) in enumerate(best_checkpoints):
        
        # Load model and weights
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)

        data_loader = get_dataloader(params)
        test_loader = data_loader['test']

        y_preds, y_test = get_preds_actuals(model, test_loader)

        # Calculate the Mean Absolute Error (MAE)
        mae = mean_absolute_error(y_test, y_preds)
        # Calculate the Root Mean Squared Error (RMSE)
        rmse = np.sqrt(mean_squared_error(y_test, y_preds))
        # Calculate the Mean Absolute Percentage Error (MAPE)
        mape = mean_absolute_percentage_error(y_test, y_preds)
        # Calculate the Determination Coefficient (R^2)
        r2 = r2_score(y_test, y_preds)

        rows.append(
            {
                "model": params["model"],
                "val_mae": val_loss,
                "test_mae": mae,
                "rmse": rmse,
                "mape": mape,
                "r2": r2,
                "variables": params["data"]["variables"],
            }
        )
    df = pd.DataFrame(rows).sort_values("test_mae")
    model_dfs[model_dir.name] = df


In [13]:
torch.cuda.empty_cache()

In [14]:
pd.set_option('max_colwidth', 500)

In [15]:
model_dfs.get(f"data_4-{experiment}-lstm")

Unnamed: 0,model,val_mae,test_mae,rmse,mape,r2,variables
1,LSTM,0.393925,0.421661,0.895829,0.051693,0.986557,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Wate..."
4,LSTM,0.4465,0.444891,1.052171,0.048105,0.981455,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Wate..."
0,LSTM,0.208098,0.491119,1.229909,0.044536,0.97466,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"
3,LSTM,0.445775,1.046228,2.317164,0.094184,0.910056,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"
2,LSTM,0.408358,2.111323,3.56039,0.280829,0.787648,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Air_Temperature_Fister, Precipitation_Fister, Flow_Lyngsvatn_Overflow, Flow_Tapping, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Precipitation_Nilsebu, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Mu..."


In [16]:
model_dfs.get(f"data_4-{experiment}}-temporal")

Unnamed: 0,model,val_mae,test_mae,rmse,mape,r2,variables
0,LSTMTemporalAttention,0.17784,0.217467,0.618214,0.022463,0.993598,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"
3,LSTMTemporalAttention,0.195202,0.325253,1.020972,0.028004,0.982538,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"
1,LSTMTemporalAttention,0.191166,0.345888,0.802695,0.04442,0.989207,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"
2,LSTMTemporalAttention,0.192919,0.385117,0.71269,0.060404,0.991491,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"
4,LSTMTemporalAttention,0.213901,0.396865,0.836251,0.038355,0.988285,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"


In [17]:
model_dfs.get(f"data_4-{experiment}-spatio_temporal")

Unnamed: 0,model,val_mae,test_mae,rmse,mape,r2,variables
0,LSTMSpatioTemporalAttention,0.167536,0.251866,0.521203,0.032792,0.995449,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"
1,LSTMSpatioTemporalAttention,0.198232,0.27757,0.566039,0.036746,0.994633,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Air_Temperature_Fister, Precipitation_Fister, Flow_Lyngsvatn_Overflow, Flow_Tapping, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Precipitation_Nilsebu, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Mu..."
4,LSTMSpatioTemporalAttention,0.238339,0.617871,1.1962,0.068751,0.97603,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Wate..."
3,LSTMSpatioTemporalAttention,0.23114,0.651344,1.019112,0.080424,0.982602,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"
2,LSTMSpatioTemporalAttention,0.202778,0.652884,0.903606,0.137459,0.986322,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Wate..."


In [18]:
# concatenate the dataframes
df_concat = pd.concat([model_dfs.get(f"data_4-{experiment}-lstm"), model_dfs.get(f"data_4-{experiment}-temporal"), model_dfs.get(f"data_4-{experiment}-spatio_temporal")])

df_concat = df_concat.drop(columns=['variables'])

# calculate the mean of each evaluation metric
df_avg = df_concat.groupby(['model']).mean()
df_avg.sort_values("test_mae")

Unnamed: 0_level_0,val_mae,test_mae,rmse,mape,r2
model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
LSTMTemporalAttention,0.194206,0.334118,0.798164,0.038729,0.989024
LSTMSpatioTemporalAttention,0.207605,0.490307,0.841232,0.071234,0.987007
LSTM,0.380531,0.903044,1.811093,0.103869,0.928075


Based on the provided results, the following observations can be made:

1. LSTMSpatioTemporalAttention model has the best overall performance among the three models. It has the lowest validation mean absolute error (val_mae) of 0.167536 and a high coefficient of determination (r2) of 0.995449 in model 0.

2. LSTMTemporalAttention model also performs well, with model 0 having the lowest val_mae of 0.177840 and an r2 of 0.993598.

3. LSTM model performance is comparatively lower than the other two models. However, it still shows reasonable performance with model 4 having the lowest val_mae of 0.519872 and an r2 of 0.985007.

In conclusion, the LSTMSpatioTemporalAttention model generally outperforms the other two models across the different sets of input variables. The best performing model is LSTMSpatioTemporalAttention model 0, with the lowest validation mean absolute error and a high coefficient of determination, indicating a good fit to the data.

# Attention weights understanding 

In [19]:
def visualize_temporal_attention(attention_weights, batch_idx, features):
    # Extract attention weights for a specific batch element
    attention_matrix = attention_weights[batch_idx].detach().cpu().numpy()

    fig = go.Figure(
        data=go.Heatmap(
            z=attention_matrix.T,
            y=features,
            x=[f't-{i}' for i in range(attention_matrix.shape[0])],
            colorscale='Viridis',
        ),
    )

    fig.update_layout(
        title="Temporal Attention Weights",
        xaxis_title="Input Time Step",
        yaxis_title="Features",
        width=900,
        height=800,
    )
    print("Temporal weights")
    fig.show()

In [20]:
def visualize_spatial_attention(attention_weights, batch_idx, features):
    # Extract attention weights for a specific batch element
    attention_matrix = attention_weights[batch_idx].detach().cpu().numpy()
    
    fig = go.Figure(
        data=go.Heatmap(
            z=attention_matrix.T,
            x=[f't-{i}' for i in range(attention_matrix.shape[0])],
            y=features,
            colorscale='Viridis',
        ),
    )

    fig.update_layout(
        title="Spatial Attention Weights",
        xaxis_title="Input Time Step",
        yaxis_title="Features",
        width=900,
        height=800,
    )
    print("Spatial weigths")
    fig.show()

In [21]:
def plot_attention(params, spatial_weights=None, temporal_weights=None):
    features = [params["data"]['target_variable']] + params["data"]["variables"]
    # If the attention weights are torch tensors, convert them to numpy arrays first
    if isinstance(spatial_weights, torch.Tensor):
        visualize_spatial_attention(spatial_weights, batch_idx=0, features=features)
    if isinstance(temporal_weights, torch.Tensor):
        visualize_temporal_attention(temporal_weights, batch_idx=0, features=features)
        

In [22]:
for model_dir in model_dirs:
    if experiment not in str(model_dir):
        continue
    print(f"Best performing model on experiment: {os.path.basename(model_dir)}")
    rows = []
    best_checkpoints = find_best_checkpoints(model_dir, num_best=1)
    for i, (checkpoint, val_loss, params) in enumerate(best_checkpoints):

        # Load model and weights
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)
        data_loader = get_dataloader(params)
        test_dataloader = data_loader["test"]

        # Get a batch of input sequences and their corresponding targets
        inputs, targets = next(iter(test_dataloader))
        if params['model'] == "LSTM":
            continue
        elif params['model'] == "LSTMTemporalAttention":
            output, temporal_attention_weights = model(inputs, True)
            spatial_attention_weights = None
        else:
            output, spatial_attention_weights, temporal_attention_weights = model(inputs, True)
        
        plot_attention(params, spatial_attention_weights, temporal_attention_weights)

        break
        

Best performing model on experiment: data_4-1-lstm
Best performing model on experiment: data_4-1-spatio_temporal
Spatial weigths


Temporal weights


Best performing model on experiment: data_4-1-temporal
Temporal weights


# Multi-time step ahead forecasting

In [23]:
steps_ahead = 12

In [24]:
def recursive_forecast(model, input, forecast_steps=1, return_weights=False):
    predictions = []
    attention_weights = []

    # Find the index of the target feature in the input_size dimension
    target_feature_idx = 0

    for _ in range(forecast_steps):
        if return_weights:
            out, alpha_list, beta_t = model(input, return_weights=True)
            attention_weights.append((alpha_list, beta_t))
        else:
            out = model(input)

        predictions.append(out)

        input[:, -1, target_feature_idx] = out.squeeze(-1)


    predictions = torch.stack(predictions, dim=1)

    if return_weights:
        return predictions, attention_weights
    else:
        return predictions


In [25]:
def get_multi_step_preds_actuals(model, data_loader, forecast_steps=3):
    y_preds = []
    y_actuals = []

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Move the model to the GPU
    model.to(device)

    # In your get_multi_step_preds_actuals function, move the input to the GPU
    for input, target in data_loader:
        input = input.to(device)
        predictions = recursive_forecast(model, input, forecast_steps)
        y_preds.append(predictions)
        y_actuals.append(target.unsqueeze(1).repeat(1, forecast_steps))

    y_preds = torch.cat(y_preds).detach().cpu().numpy()
    y_actuals = torch.cat(y_actuals).detach().cpu().numpy()

    return y_preds, y_actuals


In [26]:
def plot_pred_actual(data_loader, model, model_name, forecast_steps=3):
    val_loader, test_loader = data_loader['val'], data_loader['test']
    
    with torch.no_grad():
        y_preds_val, y_val = get_multi_step_preds_actuals(model, val_loader, forecast_steps) # Redundent
        y_preds_test, y_test = get_multi_step_preds_actuals(model, test_loader, forecast_steps)
    
    # Get datetime values
    datetime_val = val_loader.datetime_index
    datetime_test = test_loader.datetime_index

    # Create dataframes
    val_df = pd.DataFrame({'datetime': datetime_val, 'target': y_val[:, 0]})
    test_df = pd.DataFrame({'datetime': datetime_test, 'target': y_test[:, 0]})
    predictions_df = pd.DataFrame({'datetime': datetime_test, 'predictions': y_preds_test[:, 0]})

    # Create a scatter plot for each dataset
    fig = go.Figure()

    fig.add_trace(go.Scatter(x=val_df['datetime'], y=val_df['target'],
                        mode='lines', name='Validation'))
    fig.add_trace(go.Scatter(x=test_df['datetime'], y=test_df['target'],
                        mode='lines', name='Test'))
    fig.add_trace(go.Scatter(x=predictions_df['datetime'], y=predictions_df['predictions'],
                        mode='lines', name='Predictions'))

    # Add spread plot between Test and Predictions
    fig.add_trace(go.Scatter(x=test_df['datetime'], y=test_df['target'], fill=None,
                             mode='lines', line_color='rgba(0, 0, 0, 0.1)', showlegend=False))
    fig.add_trace(go.Scatter(x=predictions_df['datetime'], y=predictions_df['predictions'], fill='tonexty',
                             mode='lines', line_color='rgba(0, 0, 0, 0.1)', name='Spread'))

    # Add axis labels and plot title
    fig.update_layout(
        title=f'Time Series Data with {model_name} Predictions',
        xaxis_title='Datetime',
        yaxis_title='Target Value',
    )

    # Show the plot
    fig.show()


In [27]:
model_dfs = {}
for model_dir in model_dirs:
    if experiment not in str(model_dir):
        continue
    rows = []
    best_checkpoints = find_best_checkpoints(model_dir, num_best=1)
    for i, (checkpoint, val_loss, params) in enumerate(best_checkpoints):
        # Load model and weights
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)

        data_loader = get_dataloader(params)
        test_loader = data_loader['test']
        
        plot_pred_actual(data_loader, model, params['model'], forecast_steps=steps_ahead)

In [28]:
torch.cuda.empty_cache()

In [29]:
model_dfs = {}
for model_dir in model_dirs:
    if experiment not in str(model_dir):
        continue
    
    rows = []
    best_checkpoints = find_best_checkpoints(model_dir, num_best=5)
    for i, (checkpoint, val_loss, params) in enumerate(best_checkpoints):
        
        # Load model and weights
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)

        data_loader = get_dataloader(params)
        test_loader = data_loader['test']
        
        with torch.no_grad():
            y_preds, y_test = get_multi_step_preds_actuals(model, test_loader, forecast_steps=steps_ahead)

        # Calculate the Mean Absolute Error (MAE)
        mae = mean_absolute_error(y_test, y_preds)
        # Calculate the Root Mean Squared Error (RMSE)
        rmse = np.sqrt(mean_squared_error(y_test, y_preds))
        # Calculate the Mean Absolute Percentage Error (MAPE)
        mape = mean_absolute_percentage_error(y_test, y_preds)
        # Calculate the Determination Coefficient (R^2)
        r2 = r2_score(y_test, y_preds)

        rows.append(
            {
                "model": params["model"],
                "val_mae": val_loss,
                "test_mae": mae,
                "rmse": rmse,
                "mape": mape,
                "r2": r2,
                "variables": params["data"]["variables"],
            }
        )
    df = pd.DataFrame(rows).sort_values("test_mae")
    model_dfs[model_dir.name] = df

In [30]:
model_dfs.get(f"data_4-{experiment}-lstm")

In [31]:
model_dfs.get(f"data_4-{experiment}-temporal")

In [32]:
model_dfs.get(f"data_4-{experiment}-spatio_temporal")

In [33]:
# concatenate the dataframes
df_concat = pd.concat([model_dfs.get(f"data_4-{experiment}-lstm"), model_dfs.get(f"data_4-{experiment}-temporal"), model_dfs.get(f"data_4-{experiment}-spatio_temporal")])

df_concat = df_concat.drop(columns=['variables'])

# calculate the mean of each evaluation metric
df_avg = df_concat.groupby(['model']).mean()
df_avg

ValueError: All objects passed were None

Here's a brief analysis of the results:

1. LSTMSpatioTemporalAttention Model 1: This model has the lowest test MAE (1.635497) and the highest R2 score (0.922657), indicating the best overall performance among all models. This model used a combination of meteorological and hydrological variables for forecasting. It is worth noting that this model also has a relatively low RMSE (2.148725) and MAPE (0.254132).

2. LSTM Models: These models have relatively higher test MAE values (2.247437 and 3.335603) and lower R2 scores (0.714177 and 0.287179) compared to the LSTMSpatioTemporalAttention models. This suggests that adding attention mechanisms improves the performance of the LSTM model.

3. LSTMTemporalAttention Models: The performance of these models is generally better than the basic LSTM models but not as good as the LSTMSpatioTemporalAttention models. The test MAE values range from 2.369149 to 2.919680, and the R2 scores range from 0.663311 to 0.746914.

In conclusion, the LSTMSpatioTemporalAttention models generally perform better than the LSTM and LSTMTemporalAttention models, with the best performance observed in LSTMSpatioTemporalAttention Model 1. This suggests that incorporating both spatial and temporal attention mechanisms helps improve the model's forecasting capabilities.