In [1]:
import json
import os
import sys
from operator import itemgetter
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import plotly.subplots as sp
import torch
from plotly.offline import init_notebook_mode, plot, iplot
from plotly.subplots import make_subplots
from sklearn.metrics import mean_absolute_error, mean_squared_error, mean_absolute_percentage_error, r2_score

# Make the sibling "../src" directory importable: its absolute path is appended
# to sys.path once (the membership check keeps re-running this cell from adding
# duplicate entries).
module_path = os.path.abspath(os.path.join("../src"))
if module_path not in sys.path:
    sys.path.append(module_path)

from data import *
from train import create_model

In [2]:
# List all model (experiment) directories.
# Path.iterdir() yields entries in arbitrary, filesystem-dependent order, while
# later cells only look at the *first* experiment; sorting by path makes
# "first" mean the same thing on every run.
ray_results = Path("../ray_results/")
model_dirs = sorted(d for d in ray_results.iterdir() if d.is_dir())

In [3]:
# Function to find best checkpoints for a model directory
def find_best_checkpoints(model_dir, num_best=5):
    """Return the ``num_best`` training runs with the lowest validation loss.

    Every sub-directory of ``model_dir`` is treated as one training run. A run
    is only considered when it contains both ``progress.csv`` and
    ``params.json`` and its ``progress.csv`` has a ``val_loss`` column with at
    least one non-NaN value. Runs that do not meet these conditions (e.g.
    trials that stopped before logging anything) are skipped instead of
    aborting the whole scan.

    Args:
        model_dir: ``pathlib.Path`` of the experiment directory to scan.
        num_best: Maximum number of runs to return.

    Returns:
        A list of ``(checkpoint_path, best_val_loss, params)`` tuples sorted by
        ascending ``best_val_loss``. ``checkpoint_path`` is always
        ``<run_dir>/my_model/checkpoint.pt``; its existence is not checked
        here. ``params`` is the parsed content of the run's ``params.json``.
    """
    checkpoints = []

    # Sorted iteration keeps the order of runs with equal losses reproducible
    # (iterdir() order is filesystem-dependent and list.sort() is stable).
    for run_dir in sorted(model_dir.iterdir()):
        if not run_dir.is_dir():
            continue

        progress_file = run_dir / "progress.csv"
        params_file = run_dir / "params.json"
        if not (progress_file.exists() and params_file.exists()):
            continue

        try:
            progress_data = pd.read_csv(progress_file)
        except pd.errors.EmptyDataError:
            # A zero-byte progress file: nothing was logged for this run.
            continue
        if "val_loss" not in progress_data.columns:
            continue

        # NaN losses cannot be ranked; ignore them and skip runs that have
        # nothing else.
        val_losses = progress_data["val_loss"].dropna()
        if val_losses.empty:
            continue
        best_val_loss = val_losses.min()

        with open(params_file, "r") as f:
            params = json.load(f)

        # Save the checkpoint path and validation loss
        checkpoint_path = run_dir / "my_model" / "checkpoint.pt"
        checkpoints.append((checkpoint_path, best_val_loss, params))

    # Sort the checkpoints based on validation loss
    checkpoints.sort(key=itemgetter(1))

    return checkpoints[:num_best]

In [4]:
def get_dataloader(params):
    """Build the data loaders described by a run's ``params`` dictionary.

    A ``Data`` object is created from ``params["data_file"]`` and
    ``params["datetime"]``; every entry of ``params["data"]`` is then forwarded
    as a keyword argument to its ``prepare_data`` method.

    Returns:
        Whatever ``Data.prepare_data`` returns (other cells index the result
        with ``'val'`` and ``'test'``).
    """
    source_file, datetime_arg = params["data_file"], params["datetime"]
    dataset = Data(source_file, datetime_arg)
    return dataset.prepare_data(**params["data"])

In [5]:
def load_model_from_checkpoint(checkpoint_path, model_class, map_location="cpu"):
    """Instantiate ``model_class`` and load a plain state dict into it.

    NOTE(review): a later cell defines another ``load_model_from_checkpoint``
    with the signature ``(model, checkpoint_path)``. When the notebook is run
    top to bottom that later definition replaces this one, and every call in
    the notebook uses the later signature.

    Args:
        checkpoint_path: File containing a state dict saved with ``torch.save``.
        model_class: Callable taking no arguments that returns the model.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint written on a GPU machine can still be read on a
            CPU-only one.

    Returns:
        The newly created model with the loaded weights.
    """
    model = model_class()
    state_dict = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model

In [6]:
def get_preds_actuals(model, test_dataloader):
    """Run ``model`` over a data loader and collect predictions and targets.

    The model is switched to evaluation mode and left in that mode. Each batch
    of inputs is cast to float before the forward pass, and gradient tracking
    is disabled while iterating.

    Args:
        model: A torch module called as ``model(X_batch.float())``.
        test_dataloader: Iterable yielding ``(X_batch, y_batch)`` tensor pairs.

    Returns:
        ``(all_preds, all_actuals)``: two Python lists holding, in iteration
        order, the first-axis elements of every batch output and every target
        batch after conversion to NumPy.
    """
    model.eval()
    predictions, targets = [], []

    with torch.no_grad():
        for features, labels in test_dataloader:
            outputs = model(features.float())
            predictions += list(outputs.numpy())
            targets += list(labels.numpy())

    return predictions, targets

In [7]:
def load_model_from_checkpoint(model, checkpoint_path, map_location="cpu", raise_on_error=False):
    """Load weights into ``model`` from a ``(model_state, other)`` checkpoint.

    The checkpoint file is expected to unpack into exactly two items; the
    first is passed to ``model.load_state_dict`` and the second is ignored.

    Args:
        model: Module whose weights are overwritten in place.
        checkpoint_path: Path of the checkpoint file.
        map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
            checkpoint written on a GPU machine can still be read on a
            CPU-only one.
        raise_on_error: When ``False`` (default) a failure is printed and the
            model is returned anyway, which keeps evaluation loops running.
            When ``True`` the original exception is re-raised.

    Returns:
        The same ``model`` object. After a failed load with
        ``raise_on_error=False`` it may not hold the checkpoint's weights.
    """
    try:
        model_state_dict, _ = torch.load(checkpoint_path, map_location=map_location)
        model.load_state_dict(model_state_dict)
    except Exception as e:
        if raise_on_error:
            raise
        print(f"Error loading checkpoint: {checkpoint_path}")
        print(e)
        # Make the consequence explicit: metrics computed with this model do
        # not describe the trained run.
        print("WARNING: checkpoint weights were not loaded; results for this model are not meaningful.")
    return model

In [8]:
def plot_pred_actual(data_loader, model, model_name):
    """Plot validation targets, test targets and test predictions over time.

    Args:
        data_loader: Mapping with ``'val'`` and ``'test'`` loaders. Each loader
            is iterated as ``(inputs, targets)`` pairs and must expose a
            ``datetime_index`` attribute used for the x axis.
        model: Model handed to ``get_preds_actuals`` together with the test
            loader.
        model_name: Name inserted into the figure title.

    The figure has five line traces, in this order: validation targets, test
    targets, predictions, and then test targets and predictions once more in a
    faint colour, the last one filled towards the previous trace to shade the
    gap between the two. The figure is shown; nothing is returned.
    """
    val_loader, test_loader = data_loader['val'], data_loader['test']

    # Targets of the validation split, stitched together batch by batch.
    y_val = np.concatenate([batch_targets for _, batch_targets in val_loader])
    y_preds, y_test = get_preds_actuals(model, test_loader)

    # One frame per curve, each pairing values with their timestamps.
    val_times = val_loader.datetime_index
    test_times = test_loader.datetime_index
    val_df = pd.DataFrame({'datetime': val_times, 'target': y_val})
    test_df = pd.DataFrame({'datetime': test_times, 'target': y_test})
    predictions_df = pd.DataFrame({'datetime': test_times, 'predictions': y_preds})

    spread_colour = 'rgba(0, 0, 0, 0.1)'
    # Order matters: 'tonexty' fills towards the trace added just before it.
    traces = [
        go.Scatter(x=val_df['datetime'], y=val_df['target'],
                   mode='lines', name='Validation'),
        go.Scatter(x=test_df['datetime'], y=test_df['target'],
                   mode='lines', name='Test'),
        go.Scatter(x=predictions_df['datetime'], y=predictions_df['predictions'],
                   mode='lines', name='Predictions'),
        go.Scatter(x=test_df['datetime'], y=test_df['target'], fill=None,
                   mode='lines', line_color=spread_colour, showlegend=False),
        go.Scatter(x=predictions_df['datetime'], y=predictions_df['predictions'], fill='tonexty',
                   mode='lines', line_color=spread_colour, name='Spread'),
    ]

    fig = go.Figure()
    for trace in traces:
        fig.add_trace(trace)

    fig.update_layout(
        title=f'Time Series Data with {model_name} Predictions',
        xaxis_title='Datetime',
        yaxis_title='Target Value',
    )

    fig.show()

# Plotting the data and predictions

In [9]:
# Plot predictions against actuals for the best run of the first experiment
# only (the slice replaces the former loop-and-break).
for model_dir in model_dirs[:1]:
    print(f"Best performing model on experiment: {model_dir.name}")
    for checkpoint, val_loss, params in find_best_checkpoints(model_dir, num_best=1):

        # Load model and weights
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)

        data_loader = get_dataloader(params)
        plot_pred_actual(
            data_loader,
            model,
            model_name=params["model"],
        )

Best performing model on experiment: data_1-attention_max


# Plotting the loss

In [10]:
def get_losses(run_dir):
    """Read the per-row training and validation losses of one run.

    Args:
        run_dir: ``pathlib.Path`` of a training-run directory.

    Returns:
        A DataFrame with the ``train_loss`` and ``val_loss`` columns of
        ``run_dir/progress.csv``, or ``None`` when ``run_dir`` is not a
        directory or has no ``progress.csv``.
    """
    if not run_dir.is_dir():
        return None

    progress_file = run_dir / "progress.csv"
    if not progress_file.exists():
        return None

    history = pd.read_csv(progress_file)
    return history[['train_loss', 'val_loss']]

def plot_losses(losses):
    """Show the training and validation loss curves in one figure.

    Args:
        losses: Object indexable by ``'train_loss'`` and ``'val_loss'``; each
            entry is plotted as a line against its position.
    """
    curves = (
        ('train_loss', 'Train Loss'),
        ('val_loss', 'Validation Loss'),
    )

    fig = go.Figure()
    for column, label in curves:
        fig.add_trace(go.Scatter(y=losses[column], mode='lines', name=label))

    fig.update_layout(
        title='Train and Validation Losses',
        xaxis_title='Epoch',
        yaxis_title='Loss',
    )
    fig.show()

# Loss curves for the best run of the first experiment only.
for model_dir in model_dirs[:1]:
    for checkpoint, val_loss, params in find_best_checkpoints(model_dir, num_best=1):
        # The checkpoint lives at <run_dir>/my_model/checkpoint.pt, so the run
        # directory is two levels above the file.
        run_dir = checkpoint.parents[1]
        plot_losses(get_losses(run_dir))


# Looking at the results

In [11]:
# Evaluate the five best runs (by validation loss) of every experiment on the
# test split and collect the metrics in one DataFrame per experiment, keyed by
# the experiment directory name.
METRIC_COLUMNS = ["model", "val_loss", "mae", "rmse", "mape", "pearson_corr", "variables"]

model_dfs = {}
for model_dir in model_dirs:
    rows = []
    for checkpoint, val_loss, params in find_best_checkpoints(model_dir, num_best=5):

        # Load model and weights
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)

        data_loader = get_dataloader(params)
        test_loader = data_loader['test']

        y_preds, y_test = get_preds_actuals(model, test_loader)

        # Calculate the Mean Absolute Error (MAE)
        mae = mean_absolute_error(y_test, y_preds)
        # Calculate the Root Mean Squared Error (RMSE)
        rmse = np.sqrt(mean_squared_error(y_test, y_preds))
        # Calculate the Mean Absolute Percentage Error (MAPE)
        mape = mean_absolute_percentage_error(y_test, y_preds)
        # Calculate the Pearson correlation coefficient
        pearson_corr = np.corrcoef(y_test, y_preds)[0, 1]

        rows.append(
            {
                "model": params["model"],
                "val_loss": val_loss,
                "mae": mae,
                "rmse": rmse,
                "mape": mape,
                "pearson_corr": pearson_corr,
                "variables": params["data"]["variables"],
            }
        )
    # Passing the column list explicitly keeps an experiment without any usable
    # run from crashing: an empty, column-less frame has no "val_loss" to sort by.
    df = pd.DataFrame(rows, columns=METRIC_COLUMNS).sort_values("val_loss")
    model_dfs[model_dir.name] = df

In [12]:
# Widen text columns so the long "variables" lists in the result tables are
# shown (almost) in full. The fully qualified option name is used rather than
# an abbreviation.
pd.set_option('display.max_colwidth', 500)
# Kept as the last expression so the notebook actually displays the available
# experiment names; as a non-final statement the call had no visible effect.
model_dfs.keys()

In [13]:
# Test-set metrics of the best runs for this experiment (None if the key is absent).
model_dfs.get("data_1-attention_max")

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTMTemporalAttention,0.236961,0.261272,0.508377,0.059065,0.997836,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
1,LSTMSpatialAttention,0.288076,0.476467,0.665179,0.113931,0.998065,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
2,LSTMTemporalAttention,0.288235,0.494203,0.653997,0.125745,0.997518,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu]"
3,LSTMTemporalAttention,0.296685,0.324813,0.59592,0.049983,0.998363,"[Air_Temperature_Fister, Precipitation_Fister]"
4,LSTM,0.310853,0.296109,0.548298,0.060713,0.997645,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"


In [14]:
# Test-set metrics of the best runs for this experiment (None if the key is absent).
model_dfs.get("data_2-attention_max")

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTMSpatialAttention,0.178348,0.466262,0.840765,0.055172,0.99708,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
1,LSTMSpatialTemporalAttention,0.239623,0.781457,1.301996,0.109577,0.996462,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
2,LSTMSpatialTemporalAttention,0.347902,0.494395,0.914813,0.073212,0.993943,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
3,LSTMTemporalAttention,0.398019,0.340393,0.607485,0.06462,0.997165,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
4,LSTMSpatialTemporalAttention,0.398651,0.521235,0.779083,0.093146,0.995904,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu, Air_Temperature_Fister, Precipitation_Fister]"


In [15]:
# Test-set metrics of the best runs for this experiment (None if the key is absent).
model_dfs.get("data_3-attention_max")

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTMTemporalAttention,0.091983,0.528901,0.815967,0.091303,0.99743,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana]"
1,LSTMSpatialTemporalAttention,0.10726,0.258088,0.703573,0.027237,0.99607,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana]"
2,LSTMTemporalAttention,0.107559,0.409818,0.654089,0.081839,0.997071,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu, Air_Temperature_Fister, Precipitation_Fister]"
3,LSTM,0.116887,0.250965,0.783927,0.026413,0.994963,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana]"
4,LSTMTemporalAttention,0.132533,0.304988,0.644001,0.055572,0.996898,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Wate..."


In [16]:
# Test-set metrics of the best runs for this experiment (None if the key is absent).
model_dfs.get("data_4-attention_max")

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTMSpatialAttention,0.182129,0.475454,0.792322,0.066023,0.996436,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Wate..."
1,LSTMSpatialTemporalAttention,0.217598,0.826026,1.24863,0.122405,0.995864,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"
2,LSTMTemporalAttention,0.21782,0.681094,1.276977,0.10458,0.993841,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Wate..."
3,LSTM,0.218529,0.40239,1.144118,0.047548,0.993709,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"
4,LSTMTemporalAttention,0.221623,0.371819,0.834833,0.044363,0.996566,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"


In [23]:
# Test-set metrics of the best runs for this experiment (None if the key is absent).
model_dfs.get("data_1-spatial")

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTMSpatialAttention,0.291384,0.345981,0.581262,0.058876,0.99871,"[Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
1,LSTMSpatialAttention,0.457517,0.347739,0.864035,0.048772,0.99408,"[Air_Temperature_Fister, Precipitation_Fister]"
2,LSTMSpatialAttention,1.080532,0.982451,2.120859,0.137268,0.980359,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu]"
3,LSTMSpatialAttention,1.422778,1.193468,2.3598,0.184208,0.956509,[]
4,LSTMSpatialAttention,1.968258,0.647133,1.409689,0.152454,0.986647,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"


In [24]:
# Test-set metrics of the best runs for this experiment (None if the key is absent).
model_dfs.get("data_2-spatial")

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTMSpatialAttention,0.207392,0.421298,0.780893,0.056928,0.996419,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
1,LSTMSpatialAttention,0.393795,0.436345,0.880822,0.059343,0.993623,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu]"
2,LSTMSpatialAttention,0.459091,0.582714,0.986663,0.093255,0.994874,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
3,LSTMSpatialAttention,0.49082,0.531501,1.2665,0.061948,0.993835,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu, Air_Temperature_Fister, Precipitation_Fister]"
4,LSTMSpatialAttention,0.541739,0.414556,0.790178,0.069056,0.995481,"[Air_Temperature_Fister, Precipitation_Fister]"


In [25]:
# Test-set metrics of the best runs for this experiment (None if the key is absent).
model_dfs.get("data_3-spatial")

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTMSpatialAttention,0.256377,0.677281,1.414522,0.106109,0.989152,[]
1,LSTMSpatialAttention,0.265355,0.346528,0.664783,0.054177,0.996563,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
2,LSTMSpatialAttention,0.284761,0.870341,1.220759,0.151417,0.994653,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu, Air_Temperature_Fister, Precipitation_Fister]"
3,LSTMSpatialAttention,0.429512,0.607936,1.305654,0.086712,0.986743,"[Air_Temperature_Fister, Precipitation_Fister]"
4,LSTMSpatialAttention,0.473116,1.098343,1.47716,0.17318,0.991901,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu]"


In [26]:
# Test-set metrics of the best runs for this experiment (None if the key is absent).
model_dfs.get("data_4-spatial")

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTMSpatialAttention,0.110386,0.482708,0.651677,0.074106,0.99879,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"
1,LSTMSpatialAttention,0.190095,0.41984,0.835695,0.048933,0.996682,"[Air_Temperature_Fister, Precipitation_Fister]"
2,LSTMSpatialAttention,0.380568,0.704245,1.09241,0.119785,0.992698,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu, Air_Temperature_Fister, Precipitation_Fister]"
3,LSTMSpatialAttention,0.429734,0.54638,1.211001,0.058715,0.991958,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Wate..."
4,LSTMSpatialAttention,0.445699,0.425969,0.809043,0.053616,0.995388,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu]"


# Understanding the attention weights

In [27]:
def _attention_heatmap(weights, title, xaxis_title, yaxis_title):
    """Return a Viridis heatmap figure of ``weights`` averaged over its first axis."""
    averaged = np.mean(weights.detach().cpu().numpy(), axis=0)
    fig = go.Figure(data=go.Heatmap(z=averaged, colorscale='Viridis'))
    fig.update_layout(title=title, xaxis_title=xaxis_title, yaxis_title=yaxis_title)
    return fig


def plot_attention(params, spatial_weights=None, temporal_weights=None):
    """Show heatmaps of the spatial and/or temporal attention weights.

    A heatmap is drawn for each argument that is a ``torch.Tensor``; any other
    value (including the default ``None``) is silently skipped. The weights
    are averaged over their first axis before plotting.

    For the spatial heatmap an index -> name mapping is printed first, listing
    ``params["data"]["target_variable"]`` followed by
    ``params["data"]["variables"]``.
    NOTE(review): presumably these indices correspond to the feature axis of
    the heatmap - confirm against the model's input layout.

    Args:
        params: Run parameters; ``params["model"]`` is used in the titles.
        spatial_weights: Spatial attention weights, or a non-tensor to skip.
        temporal_weights: Temporal attention weights, or a non-tensor to skip.
    """
    if isinstance(spatial_weights, torch.Tensor):
        fig = _attention_heatmap(
            spatial_weights,
            title=f"Spatial Attention Weights - {params['model']}",
            xaxis_title="Sequence Length",
            yaxis_title="Input Size (Features)",
        )
        # Named feature_names instead of vars to avoid shadowing the builtin.
        feature_names = [params["data"]['target_variable']] + params["data"]["variables"]
        print(dict(enumerate(feature_names)))
        fig.show()

    if isinstance(temporal_weights, torch.Tensor):
        fig = _attention_heatmap(
            temporal_weights,
            title=f"Temporal Attention Weights - {params['model']}",
            xaxis_title="Sequence Length (Key)",
            yaxis_title="Sequence Length (Query)",
        )
        fig.show()

In [28]:
# For every experiment, plot the attention weights that its best run produces
# on the first batch of the test loader.
for model_dir in model_dirs:
    print(f"Best performing model on experiment: {os.path.basename(model_dir)}")
    # Only the single best run is visualised, so only one checkpoint is requested.
    for checkpoint, val_loss, params in find_best_checkpoints(model_dir, num_best=1):
        # Load model and weights
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)
        # Inspect the weights in evaluation mode, as get_preds_actuals does for
        # the reported metrics.
        model.eval()

        data_loader = get_dataloader(params)
        test_dataloader = data_loader["test"]

        # Get a batch of input sequences and their corresponding targets
        inputs, targets = next(iter(test_dataloader))
        # Same float cast as in get_preds_actuals; no gradients are needed for
        # plotting. The second positional argument is passed as True, as before.
        with torch.no_grad():
            output, (spatial_attention_weights, temporal_attention_weights) = model(inputs.float(), True)

        plot_attention(params, spatial_attention_weights, temporal_attention_weights)

Best performing model on experiment: data_1-attention_max


Best performing model on experiment: data_1-spatial
{0: 'Flow_Kalltveit', 1: 'Water_Level_Kalltveit', 2: 'Water_Temperature_Kalltveit_Kum', 3: 'Water_Temperature_Lyngsaana'}


Best performing model on experiment: data_2-attention_max
{0: 'Flow_Kalltveit', 1: 'Water_Level_Lyngsaana', 2: 'Water_Temperature_Hiafossen', 3: 'Water_Level_Hiafossen', 4: 'Water_Level_Kalltveit', 5: 'Water_Temperature_Kalltveit_Kum', 6: 'Water_Temperature_Lyngsaana'}


Best performing model on experiment: data_2-spatial
{0: 'Flow_Kalltveit', 1: 'Water_Level_Lyngsaana', 2: 'Water_Temperature_Hiafossen', 3: 'Water_Level_Hiafossen', 4: 'Water_Level_Kalltveit', 5: 'Water_Temperature_Kalltveit_Kum', 6: 'Water_Temperature_Lyngsaana'}


Best performing model on experiment: data_3-attention_max


Best performing model on experiment: data_3-spatial
{0: 'Flow_Kalltveit'}


Best performing model on experiment: data_4-attention_max
{0: 'Flow_Kalltveit', 1: 'Wind_Speed_Nilsebu', 2: 'Wind_Direction_Nilsebu', 3: 'Relative_Humidity_Nilsebu', 4: 'Precipitation_Fister', 5: 'Precipitation_Nilsebu', 6: 'Water_Level_Lyngsaana', 7: 'Water_Temperature_Hiafossen', 8: 'Water_Level_Hiafossen', 9: 'Water_Level_Kalltveit', 10: 'Water_Temperature_Kalltveit_Kum', 11: 'Water_Temperature_Hiavatn', 12: 'Water_Level_Hiavatn', 13: 'Water_Temperature_Musdalsvatn', 14: 'Water_Level_Musdalsvatn', 15: 'Water_Temperature_Musdalsvatn_Downstream', 16: 'Water_Level_Musdalsvatn_Downstream', 17: 'Water_Temperature_Viglesdalsvatn', 18: 'Water_Level_Viglesdalsvatn', 19: 'Water_Temperature_Lyngsaana', 20: 'Water_Temperature_Kalltveit_River'}


Best performing model on experiment: data_4-spatial
{0: 'Flow_Kalltveit', 1: 'Water_Level_Lyngsaana', 2: 'Water_Temperature_Hiafossen', 3: 'Water_Level_Hiafossen', 4: 'Water_Level_Kalltveit', 5: 'Water_Temperature_Kalltveit_Kum', 6: 'Water_Temperature_Hiavatn', 7: 'Water_Level_Hiavatn', 8: 'Water_Temperature_Musdalsvatn', 9: 'Water_Level_Musdalsvatn', 10: 'Water_Temperature_Musdalsvatn_Downstream', 11: 'Water_Level_Musdalsvatn_Downstream', 12: 'Water_Temperature_Viglesdalsvatn', 13: 'Water_Level_Viglesdalsvatn', 14: 'Water_Temperature_Lyngsaana', 15: 'Water_Temperature_Kalltveit_River'}
