In [1]:
import json
import os
import sys
import warnings
from operator import itemgetter
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import plotly.subplots as sp
import torch
from plotly.offline import init_notebook_mode, plot, iplot
from plotly.subplots import make_subplots
from sklearn.metrics import mean_absolute_error, mean_squared_error, mean_absolute_percentage_error, r2_score

module_path = os.path.abspath(os.path.join("../src"))
if module_path not in sys.path:
    sys.path.append(module_path)

from data import *
from train import create_model

In [2]:
# List all experiment directories under ray_results. Sorted so the order of the
# plots/tables below (and "the first experiment") is the same on every OS and
# every run -- Path.iterdir() makes no ordering guarantee.
ray_results = Path("../ray_results/")
model_dirs = sorted(d for d in ray_results.iterdir() if d.is_dir())

In [3]:
# Function to find best checkpoints for a model directory
def find_best_checkpoints(model_dir, num_best=5):
    """Return the ``num_best`` runs of an experiment with the lowest validation loss.

    Every immediate sub-directory of ``model_dir`` is treated as one training
    run. A run is considered only if it has both ``progress.csv`` (with a
    ``val_loss`` column holding at least one non-NaN value) and
    ``params.json``. Runs that never reported, or whose files are missing, are
    skipped instead of aborting the whole scan.

    Args:
        model_dir: Experiment directory (``Path`` or ``str``).
        num_best: Maximum number of entries to return.

    Returns:
        List of ``(checkpoint_path, best_val_loss, params)`` tuples sorted by
        ascending ``best_val_loss``. ``checkpoint_path`` is
        ``<run_dir>/my_model/checkpoint.pt``; its existence is not checked.
    """
    checkpoints = []

    # sorted() makes tie-breaking between equal losses deterministic.
    for run_dir in sorted(Path(model_dir).iterdir()):
        if not run_dir.is_dir():
            continue
        progress_file = run_dir / "progress.csv"
        params_file = run_dir / "params.json"
        if not (progress_file.exists() and params_file.exists()):
            continue

        try:
            progress_data = pd.read_csv(progress_file)
        except pd.errors.EmptyDataError:
            # Run directory exists but nothing was ever written to progress.csv.
            continue
        if "val_loss" not in progress_data.columns:
            continue

        # Series.min() skips NaN; an empty or all-NaN column yields NaN.
        best_val_loss = progress_data["val_loss"].min()
        if pd.isna(best_val_loss):
            continue

        with open(params_file, "r") as f:
            params = json.load(f)

        checkpoint_path = run_dir / "my_model" / "checkpoint.pt"
        checkpoints.append((checkpoint_path, best_val_loss, params))

    # Sort the checkpoints based on validation loss
    checkpoints.sort(key=itemgetter(1))

    return checkpoints[:num_best]

In [4]:
def get_dataloader(params):
    """Build the data loaders described by a run's ``params`` dictionary.

    ``params["data_file"]`` and ``params["datetime"]`` are passed to ``Data``.
    ``params["data"]`` is unpacked as keyword arguments to
    ``Data.prepare_data``, whose return value is handed back unchanged.
    """
    dataset = Data(params["data_file"], params["datetime"])
    return dataset.prepare_data(**params["data"])

In [5]:
def load_model_from_checkpoint(checkpoint_path, model_class):
    """Instantiate ``model_class()`` and load weights from ``checkpoint_path``.

    NOTE(review): this definition is shadowed by the later
    ``load_model_from_checkpoint(model, checkpoint_path)`` in this notebook
    (different argument order), so after Run All only that later version is
    callable. Consider deleting one of the two.

    The later loader unpacks checkpoints as a ``(model_state_dict, _)`` tuple,
    so that layout is accepted here as well as a bare state dict. Tensors are
    mapped to CPU so checkpoints load on machines without the training device.

    Returns:
        The newly created model with the checkpoint weights loaded.
    """
    model = model_class()
    loaded = torch.load(checkpoint_path, map_location="cpu")
    # (state_dict, optimizer_state) tuple vs. a plain state dict
    state_dict = loaded[0] if isinstance(loaded, (tuple, list)) else loaded
    model.load_state_dict(state_dict)
    return model

In [6]:
def get_preds_actuals(model, test_dataloader):
    """Run ``model`` over every batch and collect per-sample outputs and targets.

    The model is switched to eval mode (and left there) and evaluated under
    ``torch.no_grad()``. Inputs are cast with ``.float()`` before the forward
    pass.

    Returns:
        ``(predictions, targets)``: two flat Python lists whose items are the
        elements obtained by iterating each batch's ``.numpy()`` array.
    """
    model.eval()
    predictions, targets = [], []

    with torch.no_grad():
        for inputs, labels in test_dataloader:
            outputs = model(inputs.float())
            predictions += list(outputs.numpy())
            targets += list(labels.numpy())

    return predictions, targets

In [7]:
def load_model_from_checkpoint(model, checkpoint_path, raise_on_error=False):
    """Load weights from ``checkpoint_path`` into ``model`` and return it.

    Accepts checkpoints stored as a ``(model_state_dict, optimizer_state)``
    tuple (the layout this loader has always unpacked) as well as a bare
    state dict. Tensors are mapped to CPU so checkpoints load on machines
    without the training device.

    Args:
        model: Instantiated ``torch.nn.Module`` with a matching architecture.
        checkpoint_path: Path to the checkpoint file.
        raise_on_error: If True, re-raise any loading error. If False (the
            default, as before), emit a ``RuntimeWarning`` and return ``model``
            with its current, un-loaded weights.

    Returns:
        ``model`` itself (weights updated in place on success).
    """
    try:
        loaded = torch.load(checkpoint_path, map_location="cpu")
        state_dict = loaded[0] if isinstance(loaded, (tuple, list)) else loaded
        model.load_state_dict(state_dict)
    except Exception as e:
        if raise_on_error:
            raise
        # A failed load leaves the initial weights in place; make that loud,
        # because downstream metrics and plots would otherwise look legitimate.
        warnings.warn(
            f"Error loading checkpoint: {checkpoint_path}\n{e}",
            RuntimeWarning,
            stacklevel=2,
        )
    return model

In [8]:
def plot_pred_actual(data_loader, model, model_name, show=True):
    """Plot validation targets, test targets and test-set predictions over time.

    Args:
        data_loader: Mapping with ``'val'`` and ``'test'`` loaders. Each loader
            must yield ``(X, y)`` batches and expose a ``datetime_index``
            attribute with one timestamp per target value.
        model: Model evaluated on the test loader via ``get_preds_actuals``.
        model_name: Name shown in the figure title.
        show: Display the figure (default ``True``, the previous behaviour).

    Returns:
        The plotly ``Figure``, so callers can save or restyle it.
    """
    val_loader, test_loader = data_loader['val'], data_loader['test']

    # ravel() turns (n, 1)-shaped targets/predictions into the 1-D columns that
    # pandas and plotly expect; it is a no-op for data that is already 1-D.
    y_val = np.concatenate([y for _, y in val_loader]).ravel()
    y_preds, y_test = get_preds_actuals(model, test_loader)
    y_preds = np.asarray(y_preds).ravel()
    y_test = np.asarray(y_test).ravel()

    # Timestamps aligned with each loader's targets
    datetime_val = val_loader.datetime_index
    datetime_test = test_loader.datetime_index

    val_df = pd.DataFrame({'datetime': datetime_val, 'target': y_val})
    test_df = pd.DataFrame({'datetime': datetime_test, 'target': y_test})
    predictions_df = pd.DataFrame({'datetime': datetime_test, 'predictions': y_preds})

    fig = go.Figure()

    fig.add_trace(go.Scatter(x=val_df['datetime'], y=val_df['target'],
                             mode='lines', name='Validation'))
    fig.add_trace(go.Scatter(x=test_df['datetime'], y=test_df['target'],
                             mode='lines', name='Test'))
    fig.add_trace(go.Scatter(x=predictions_df['datetime'], y=predictions_df['predictions'],
                             mode='lines', name='Predictions'))

    # Shaded spread between Test and Predictions: a faint copy of the test line,
    # then a faint prediction line with fill='tonexty', which fills towards the
    # trace added immediately before it.
    fig.add_trace(go.Scatter(x=test_df['datetime'], y=test_df['target'], fill=None,
                             mode='lines', line_color='rgba(0, 0, 0, 0.1)', showlegend=False))
    fig.add_trace(go.Scatter(x=predictions_df['datetime'], y=predictions_df['predictions'], fill='tonexty',
                             mode='lines', line_color='rgba(0, 0, 0, 0.1)', name='Spread'))

    fig.update_layout(
        title=f'Time Series Data with {model_name} Predictions',
        xaxis_title='Datetime',
        yaxis_title='Target Value',
    )

    if show:
        fig.show()
    return fig

# Plotting the data and predictions

In [9]:
# Plot actual vs. predicted values for the best run of the first experiment(s)
# in `model_dirs` order. Previously a `break` silently stopped after the first
# experiment; raise this to inspect more.
N_EXPERIMENTS_TO_PLOT = 1

for model_dir in model_dirs[:N_EXPERIMENTS_TO_PLOT]:
    print(f"Best performing model on experiment: {os.path.basename(model_dir)}")
    best_checkpoints = find_best_checkpoints(model_dir, num_best=1)
    for checkpoint, val_loss, params in best_checkpoints:
        # Load model and weights
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)

        data_loader = get_dataloader(params)
        plot_pred_actual(
            data_loader,
            model,
            model_name=params["model"],
        )

Best performing model on experiment: data_1-lstm


# Plotting the training and validation losses

In [11]:
def get_losses(run_dir):
    """Return the ``train_loss`` and ``val_loss`` columns of a run's ``progress.csv``.

    Args:
        run_dir: Run directory (``Path`` or ``str``).

    Returns:
        DataFrame with columns ``['train_loss', 'val_loss']`` (one row per
        row of ``progress.csv``), or ``None`` when ``run_dir`` is not a
        directory or contains no ``progress.csv``.
    """
    run_dir = Path(run_dir)
    progress_file = run_dir / "progress.csv"
    if not (run_dir.is_dir() and progress_file.exists()):
        return None
    return pd.read_csv(progress_file)[['train_loss', 'val_loss']]

def plot_losses(losses, show=True):
    """Plot train and validation loss curves.

    Args:
        losses: DataFrame with ``train_loss`` and ``val_loss`` columns; the
            row position is used as the x value (labelled "Epoch").
        show: Display the figure (default ``True``, the previous behaviour).

    Returns:
        The plotly ``Figure``, so callers can save or restyle it.
    """
    fig = go.Figure()

    for column, label in (('train_loss', 'Train Loss'), ('val_loss', 'Validation Loss')):
        fig.add_trace(go.Scatter(y=losses[column], mode='lines', name=label))

    fig.update_layout(title='Train and Validation Losses', xaxis_title='Epoch', yaxis_title='Loss')

    if show:
        fig.show()
    return fig

# Loss curves of the best run (lowest val_loss) of every experiment.
for model_dir in model_dirs:
    best_checkpoints = find_best_checkpoints(model_dir, num_best=1)
    for checkpoint, val_loss, params in best_checkpoints:
        print(f"Best model from {model_dir}")
        # checkpoint is <run_dir>/my_model/checkpoint.pt, so parents[1] is run_dir.
        run_dir = checkpoint.parents[1]
        losses = get_losses(run_dir)
        if losses is None:
            # get_losses returns None when progress.csv is missing; don't crash
            # in plot_losses with an opaque TypeError.
            print(f"No progress.csv found in {run_dir}; skipping loss plot")
            continue
        plot_losses(losses)


Best model from ..\ray_results\data_1-lstm


Best model from ..\ray_results\data_1-spatial


Best model from ..\ray_results\data_1-temporal


Best model from ..\ray_results\data_2-lstm


Best model from ..\ray_results\data_2-temporal


Best model from ..\ray_results\data_3-lstm


Best model from ..\ray_results\data_4-lstm


# Looking at the results

In [47]:
# Columns of the per-experiment results tables. Passing them explicitly keeps
# the frame (and sort_values) valid for an experiment with no usable runs,
# where pd.DataFrame([]).sort_values("val_loss") would raise KeyError.
METRIC_COLUMNS = ["model", "val_loss", "mae", "rmse", "mape", "pearson_corr", "variables"]


def _regression_metrics(y_true, y_pred):
    """Return MAE, RMSE, MAPE and Pearson correlation for two equal-length sequences.

    Inputs are flattened to 1-D first: np.corrcoef treats each *row* of a 2-D
    input as a separate variable, so (n, 1)-shaped outputs would yield NaN.
    NOTE: MAPE is unstable when targets are close to zero.
    """
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    return {
        "mae": mean_absolute_error(y_true, y_pred),
        "rmse": np.sqrt(mean_squared_error(y_true, y_pred)),
        "mape": mean_absolute_percentage_error(y_true, y_pred),
        "pearson_corr": np.corrcoef(y_true, y_pred)[0, 1],
    }


# Test-set metrics for the 5 best runs (by validation loss) of each experiment.
model_dfs = {}
for model_dir in model_dirs:
    rows = []
    best_checkpoints = find_best_checkpoints(model_dir, num_best=5)
    for checkpoint, val_loss, params in best_checkpoints:
        # Load model and weights
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)

        data_loader = get_dataloader(params)
        test_loader = data_loader['test']
        y_preds, y_test = get_preds_actuals(model, test_loader)

        rows.append(
            {
                "model": params["model"],
                "val_loss": val_loss,
                **_regression_metrics(y_test, y_preds),
                "variables": params["data"]["variables"],
            }
        )
    df = pd.DataFrame(rows, columns=METRIC_COLUMNS).sort_values("val_loss")
    model_dfs[model_dir.name] = df

In [48]:
# Show full `variables` lists in the result tables below.
pd.set_option('display.max_colwidth', 500)
# Last expression of the cell, so the available experiments are actually displayed.
model_dfs.keys()

In [49]:
model_dfs["data_1-lstm"] if "data_1-lstm" in model_dfs else None

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTM,0.439091,0.199775,0.55881,0.034859,0.997744,"[Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
1,LSTM,0.533215,0.378866,1.001488,0.056366,0.992159,"[Air_Temperature_Fister, Precipitation_Fister]"
2,LSTM,0.553092,0.565395,0.878341,0.11946,0.995278,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu, Air_Temperature_Fister, Precipitation_Fister]"
3,LSTM,0.581492,0.367497,0.733976,0.070528,0.995347,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
4,LSTM,0.658663,0.511873,1.069651,0.078204,0.993358,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"


In [50]:
#model_dfs.get("data_2-lstm")

In [51]:
#model_dfs.get("data_3-lstm")

In [52]:
#model_dfs.get("data_4-lstm")

In [53]:
model_dfs["data_1-temporal"] if "data_1-temporal" in model_dfs else None

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTMTemporalAttention,0.238682,0.365334,0.58191,0.070308,0.998383,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
1,LSTMTemporalAttention,0.241829,0.385762,0.580705,0.092461,0.998175,"[Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
2,LSTMTemporalAttention,0.351158,0.234649,0.51127,0.045679,0.99795,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
3,LSTMTemporalAttention,0.359715,0.286876,0.559694,0.048379,0.997581,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu]"
4,LSTMTemporalAttention,0.426076,0.15538,0.435063,0.0233,0.998699,"[Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"


In [54]:
#model_dfs.get("data_2-temporal")

In [55]:
#model_dfs.get("data_3-temporal")

In [56]:
#model_dfs.get("data_4-temporal")

In [57]:
model_dfs["data_1-spatial"] if "data_1-spatial" in model_dfs else None

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTMSpatialAttention,0.362865,0.205574,0.470082,0.034755,0.998394,"[Air_Temperature_Fister, Precipitation_Fister]"
1,LSTMSpatialAttention,0.488657,0.353649,0.700039,0.069268,0.996207,"[Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
2,LSTMSpatialAttention,0.621824,0.381876,0.843557,0.06937,0.994974,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
3,LSTMSpatialAttention,1.492455,0.619,1.44144,0.088648,0.985557,[]
4,LSTMSpatialAttention,2.00155,0.733661,1.47392,0.127992,0.986656,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu]"


In [58]:
#model_dfs.get("data_2-spatial")

In [59]:
#model_dfs.get("data_3-spatial")

In [60]:
#model_dfs.get("data_4-spatial")

In [61]:
# .get() returns None for a missing key, which Jupyter displays as nothing;
# show an explicit message instead so a missing experiment is noticed.
model_dfs.get("data_1-spatio_temporal", "No results for 'data_1-spatio_temporal'")

In [62]:
#model_dfs.get("data_2-spatio_temporal")

In [30]:
#model_dfs.get("data_3-spatio_temporal")

In [31]:
#model_dfs.get("data_4-spatio_temporal")

# Understanding the attention weights

In [45]:
def _attention_to_numpy(weights):
    """Return ``weights`` as a NumPy array, or None if it is neither a tensor nor an ndarray."""
    if isinstance(weights, torch.Tensor):
        return weights.detach().cpu().numpy()
    if isinstance(weights, np.ndarray):
        return weights
    return None


def _attention_heatmap(mean_weights, title, xaxis_title, yaxis_title):
    """Build (without showing) a Viridis heatmap figure of a 2-D weight matrix."""
    fig = go.Figure(data=go.Heatmap(z=mean_weights, colorscale='Viridis'))
    fig.update_layout(title=title, xaxis_title=xaxis_title, yaxis_title=yaxis_title)
    return fig


def plot_attention(params, spatial_weights=None, temporal_weights=None):
    """Print and plot spatial and/or temporal attention weights averaged over axis 0.

    Each weight argument may be a ``torch.Tensor`` or an ``np.ndarray``; any
    other value (e.g. ``None``) skips that plot. Axis 0 is averaged out
    (presumably the batch axis -- confirm against the model's forward).

    Args:
        params: Run parameters. ``params['model']`` is used in the titles;
            ``params['data']['target_variable']`` and
            ``params['data']['variables']`` are printed as an index -> name
            mapping next to the spatial plot.
        spatial_weights: Spatial attention weights, or None.
        temporal_weights: Temporal attention weights, or None.
    """
    spatial = _attention_to_numpy(spatial_weights)
    if spatial is not None:
        average_spatial = np.mean(spatial, axis=0)
        print(average_spatial)
        # NOTE(review): in the saved output this matrix is 3x3 for a run with
        # 3 variables (target + 2 inputs), suggesting both axes index features
        # rather than sequence length -- confirm and fix the axis titles.
        fig = _attention_heatmap(
            average_spatial,
            title=f"Spatial Attention Weights - {params['model']}",
            xaxis_title="Sequence Length",
            yaxis_title="Input Size (Features)",
        )
        variable_names = [params["data"]['target_variable']] + params["data"]["variables"]
        print(dict(enumerate(variable_names)))
        fig.show()

    temporal = _attention_to_numpy(temporal_weights)
    if temporal is not None:
        average_temporal = np.mean(temporal, axis=0)
        print(average_temporal)
        _attention_heatmap(
            average_temporal,
            title=f"Temporal Attention Weights - {params['model']}",
            xaxis_title="Sequence Length (Key)",
            yaxis_title="Sequence Length (Query)",
        ).show()

In [46]:
# Attention weights of the best attention model per experiment, on the first
# test batch.
for model_dir in model_dirs:
    print(f"Best performing model on experiment: {os.path.basename(model_dir)}")
    best_checkpoints = find_best_checkpoints(model_dir, num_best=10)
    for checkpoint, val_loss, params in best_checkpoints:
        # Plain LSTM runs are skipped: only attention models are inspected here.
        if params['model'] == "LSTM":
            continue
        # Load model and weights
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)
        # Eval mode, as in get_preds_actuals, so layers such as dropout (if
        # present) behave as at prediction time; no_grad avoids building a graph.
        model.eval()
        data_loader = get_dataloader(params)
        test_dataloader = data_loader["test"]

        # Get a batch of input sequences and their corresponding targets
        inputs, targets = next(iter(test_dataloader))
        with torch.no_grad():
            output, (spatial_attention_weights, temporal_attention_weights) = model(inputs.float(), True)

        plot_attention(params, spatial_attention_weights, temporal_attention_weights)
        # Only the best (lowest val_loss) attention model per experiment.
        break

Best performing model on experiment: data_1-lstm
Best performing model on experiment: data_1-spatial
[[8.9814216e-01 1.0185782e-01 1.8974509e-10]
 [2.7985603e-01 7.2014374e-01 4.6610683e-07]
 [6.3490311e-09 1.4202922e-06 9.9999887e-01]]
{0: 'Flow_Kalltveit', 1: 'Air_Temperature_Fister', 2: 'Precipitation_Fister'}


Best performing model on experiment: data_1-temporal
[[3.13641578e-02 5.45179367e-01 1.80877447e-01 1.20763831e-01
  7.42336959e-02 3.84124294e-02 7.71301612e-03 1.21710892e-03
  2.11306018e-04 2.67381511e-05 8.58173792e-07 3.54713992e-09
  2.26885644e-09 2.14665463e-09 2.01353267e-09 1.87065319e-09
  1.74408221e-09 1.67454095e-09 1.64743541e-09 1.68067038e-09
  1.83711213e-09 1.90966154e-09 1.99891592e-09 1.97352956e-09
  1.96711669e-09]
 [5.59467264e-02 3.43414813e-01 1.99830011e-01 1.61606058e-01
  1.22132570e-01 8.25024918e-02 2.54577193e-02 6.64740102e-03
  1.98295549e-03 4.08384716e-04 4.66232595e-05 2.92286518e-06
  2.12584723e-06 1.94986205e-06 1.78683092e-06 1.65160111e-06
  1.53337635e-06 1.46534705e-06 1.43406646e-06 1.44588626e-06
  1.54765121e-06 1.58514263e-06 1.64855567e-06 1.61218031e-06
  1.61242451e-06]
 [5.62009737e-02 3.49411428e-01 1.99574783e-01 1.60281271e-01
  1.20419338e-01 8.08014125e-02 2.46689171e-02 6.33655535e-03
  1.86465436e-03 3.78783210e-04 4.18378550e

Best performing model on experiment: data_2-lstm
Best performing model on experiment: data_2-temporal
[[3.64495844e-01 2.49683321e-01 1.31918386e-01 ... 3.64570296e-03
  3.63822584e-03 3.62428324e-03]
 [5.20635247e-01 2.88681716e-01 1.10939816e-01 ... 5.39484550e-04
  5.37873304e-04 5.35016472e-04]
 [6.38020396e-01 2.68559635e-01 7.00106770e-02 ... 4.23052334e-05
  4.21399527e-05 4.18757045e-05]
 ...
 [8.99800301e-01 9.62352753e-02 3.83653515e-03 ... 1.00781904e-10
  9.92124924e-11 9.76895856e-11]
 [8.99867356e-01 9.61738899e-02 3.83103965e-03 ... 9.95635033e-11
  9.96422667e-11 9.74931455e-11]
 [8.99954259e-01 9.60942432e-02 3.82427615e-03 ... 9.84225132e-11
  9.79866674e-11 9.84160253e-11]]


Best performing model on experiment: data_3-lstm
Best performing model on experiment: data_4-lstm
