In [25]:
import json
from pathlib import Path
from operator import itemgetter
import pandas as pd
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error, mean_absolute_percentage_error, r2_score

import sys
import os

import plotly.graph_objs as go
import plotly.subplots as sp
from plotly.offline import init_notebook_mode, plot, iplot
import plotly.express as px
from plotly.subplots import make_subplots

module_path = os.path.abspath(os.path.join("../src"))
if module_path not in sys.path:
    sys.path.append(module_path)

from data import *
from train import create_model

In [26]:
# List all model directories
ray_results = Path("../ray_results/")
model_dirs = [d for d in ray_results.iterdir() if d.is_dir()]

In [27]:
# Function to find best checkpoints for a model directory
def find_best_checkpoints(model_dir, num_best=5):
    checkpoints = []

    # Iterate over all training runs in the model directory
    for run_dir in model_dir.iterdir():
        if run_dir.is_dir():
            # Read the progress.csv file to get the validation losses
            progress_file = run_dir / "progress.csv"
            if progress_file.exists():
                with open(run_dir / "params.json", "r") as f:
                    params = json.load(f)
                progress_data = pd.read_csv(progress_file)

                best_val_idx = progress_data["val_loss"].idxmin()
                best_val_loss = progress_data.loc[best_val_idx, "val_loss"]

                # Save the checkpoint path and validation loss
                checkpoint_path = run_dir / "my_model" / "checkpoint.pt"
                checkpoints.append((checkpoint_path, best_val_loss, params))

    # Sort the checkpoints based on validation loss
    checkpoints.sort(key=itemgetter(1))

    return checkpoints[:num_best]

In [28]:
def get_dataloader(params):
    d = Data(params["data_file"], params["datetime"])
    data_loader = d.prepare_data(**params["data"])
    return data_loader

In [29]:
def load_model_from_checkpoint(checkpoint_path, model_class):
    model = model_class()
    model.load_state_dict(torch.load(checkpoint_path))
    return model

In [30]:
def get_preds_actuals(model, test_dataloader):
    model.eval()
    all_preds = []
    all_actuals = []

    with torch.no_grad():
        for X_batch, y_batch in test_dataloader:
            batch_output = model(X_batch.float())
            all_preds.extend(batch_output.numpy())
            all_actuals.extend(y_batch.numpy())

    return all_preds, all_actuals

In [31]:
def load_model_from_checkpoint(model, checkpoint_path):
    try:
        model_state_dict, _ = torch.load(checkpoint_path)
        model.load_state_dict(model_state_dict)
    except Exception as e:
        print(f"Error loading checkpoint: {checkpoint_path}")
        print(e)
    return model

In [32]:
def plot_pred_actual(data_loader, model, model_name):
    val_loader, test_loader = data_loader['val'], data_loader['test']
    y_val = [y for _, y in val_loader]
    y_val = np.concatenate(y_val)
    y_preds, y_test = get_preds_actuals(model, test_loader)
    
    # Get datetime values
    datetime_val = val_loader.datetime_index
    datetime_test = test_loader.datetime_index

    # Create dataframes
    val_df = pd.DataFrame({'datetime': datetime_val, 'target': y_val})
    test_df = pd.DataFrame({'datetime': datetime_test, 'target': y_test})
    predictions_df = pd.DataFrame({'datetime': datetime_test, 'predictions': y_preds})
    
    # Create a scatter plot for each dataset
    fig = go.Figure()

    fig.add_trace(go.Scatter(x=val_df['datetime'], y=val_df['target'],
                        mode='lines', name='Validation'))
    fig.add_trace(go.Scatter(x=test_df['datetime'], y=test_df['target'],
                        mode='lines', name='Test'))
    fig.add_trace(go.Scatter(x=predictions_df['datetime'], y=predictions_df['predictions'],
                        mode='lines', name='Predictions'))

    # Add spread plot between Test and Predictions
    fig.add_trace(go.Scatter(x=test_df['datetime'], y=test_df['target'], fill=None,
                             mode='lines', line_color='rgba(0, 0, 0, 0.1)', showlegend=False))
    fig.add_trace(go.Scatter(x=predictions_df['datetime'], y=predictions_df['predictions'], fill='tonexty',
                             mode='lines', line_color='rgba(0, 0, 0, 0.1)', name='Spread'))

    # Add axis labels and plot title
    fig.update_layout(
        title=f'Time Series Data with {model_name} Predictions',
        xaxis_title='Datetime',
        yaxis_title='Target Value',
    )

    # Show the plot
    fig.show()

# Plotting the data and predictions

In [33]:
for model_dir in model_dirs:
    print(f"Best performing model on experiment: {os.path.basename(model_dir)}")
    rows = []
    best_checkpoints = find_best_checkpoints(model_dir, num_best=1)
    for i, (checkpoint, val_loss, params) in enumerate(best_checkpoints):

        # Load model and weights
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)

        data_loader = get_dataloader(params)
        plot_pred_actual(
            data_loader,
            model,
            model_name=params["model"],
        )
    break

Best performing model on experiment: data_1-lstm


# Plotting the loss

In [40]:
def get_losses(run_dir):
    if run_dir.is_dir():
        # Read the progress.csv file to get the validation losses
        progress_file = run_dir / "progress.csv"
        if progress_file.exists():
            return pd.read_csv(progress_file)[['train_loss', 'val_loss']]

def plot_losses(losses):
    fig = go.Figure()

    fig.add_trace(go.Scatter(y=losses['train_loss'], mode='lines', name='Train Loss'))
    fig.add_trace(go.Scatter(y=losses['val_loss'], mode='lines', name='Validation Loss'))

    fig.update_layout(title='Train and Validation Losses', xaxis_title='Epoch', yaxis_title='Loss')

    fig.show()

for model_dir in model_dirs:
    best_checkpoints = find_best_checkpoints(model_dir, num_best=1)
    for i, (checkpoint, val_loss, params) in enumerate(best_checkpoints):
        # Access run_dir from checkpoint_path
        run_dir = checkpoint.parents[1]
        losses = get_losses(run_dir)
        plot_losses(losses)


# Looking at the results

In [35]:
model_dfs = {}
for model_dir in model_dirs:
    rows = []
    best_checkpoints = find_best_checkpoints(model_dir, num_best=5)
    for i, (checkpoint, val_loss, params) in enumerate(best_checkpoints):
        
        # Load model and weights
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)

        data_loader = get_dataloader(params)
        test_loader = data_loader['test']

        y_preds, y_test = get_preds_actuals(model, test_loader)

        # Calculate the Mean Absolute Error (MAE)
        mae = mean_absolute_error(y_test, y_preds)
        # Calculate the Root Mean Squared Error (RMSE)
        rmse = np.sqrt(mean_squared_error(y_test, y_preds))
        # Calculate the Mean Absolute Percentage Error (MAPE)
        mape = mean_absolute_percentage_error(y_test, y_preds)
        # Calculate the Pearson correlation coefficient
        pearson_corr = np.corrcoef(y_test, y_preds)[0, 1]

        rows.append(
            {
                "model": params["model"],
                "val_loss": val_loss,
                "mae": mae,
                "rmse": rmse,
                "mape": mape,
                "pearson_corr": pearson_corr,                 
                "variables": params["data"]["variables"],
            }
        )
    df = pd.DataFrame(rows).sort_values("val_loss")
    model_dfs[model_dir.name] = df

In [36]:
model_dfs.keys()
pd.set_option('max_colwidth', 500)

In [37]:
model_dfs.get("data_1-lstm")

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTM,0.439091,0.199775,0.55881,0.034859,0.997744,"[Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
1,LSTM,0.533215,0.378866,1.001488,0.056366,0.992159,"[Air_Temperature_Fister, Precipitation_Fister]"
2,LSTM,0.553092,0.565395,0.878341,0.11946,0.995278,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu, Air_Temperature_Fister, Precipitation_Fister]"
3,LSTM,0.581492,0.367497,0.733976,0.070528,0.995347,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
4,LSTM,0.658663,0.511873,1.069651,0.078204,0.993358,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"


In [38]:
model_dfs.get("data_2-lstm")

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTM,2.245538,1.466659,4.118863,0.133941,0.898387,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
1,LSTM,2.293975,0.782211,1.57796,0.109884,0.978202,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu, Air_Temperature_Fister, Precipitation_Fister]"
2,LSTM,2.418191,0.511738,1.274372,0.07219,0.985854,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
3,LSTM,8.300144,0.514777,2.841595,0.031575,0.934935,"[Air_Temperature_Fister, Precipitation_Fister]"
4,LSTM,8.409332,2.089594,3.648704,0.375529,0.895766,[]


In [24]:
model_dfs.get("data_3-lstm")

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTM,0.146963,0.393993,1.25475,0.040614,0.989625,[]
1,LSTM,0.155951,0.486038,1.252461,0.066075,0.989965,[]
2,LSTM,0.176208,0.616786,1.017868,0.119195,0.993674,[]
3,LSTM,0.178266,0.430111,1.315957,0.051519,0.991115,[]
4,LSTM,0.181979,0.454136,1.268554,0.04935,0.986984,[]


In [39]:
model_dfs.get("data_4-lstm")

In [17]:
model_dfs.get("data_1-temporal_pop")

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTMTemporalAttention,0.229716,0.938325,1.036357,0.254424,0.99805,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
1,LSTMTemporalAttention,0.229932,0.246188,0.58555,0.040277,0.997534,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
2,LSTMTemporalAttention,0.235287,0.369611,0.675901,0.056547,0.996587,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
3,LSTMTemporalAttention,0.236991,0.288191,0.524748,0.062725,0.997866,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
4,LSTMTemporalAttention,0.241544,0.197489,0.466419,0.032183,0.998267,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"


In [18]:
model_dfs.get("data_2-temporal_pop")

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTMTemporalAttention,0.485027,0.301705,0.632535,0.046385,0.996631,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
1,LSTMTemporalAttention,0.498153,0.455234,0.701889,0.090455,0.996556,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
2,LSTMTemporalAttention,0.500456,0.340001,0.691118,0.052637,0.996453,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
3,LSTMTemporalAttention,0.533397,0.72908,0.903176,0.140107,0.996559,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
4,LSTMTemporalAttention,0.535221,0.439225,0.740171,0.084814,0.996118,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"


In [19]:
model_dfs.get("data_3-temporal_pop")

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTMTemporalAttention,0.084827,0.137693,0.568767,0.013731,0.997333,"[Air_Temperature_Fister, Precipitation_Fister]"
1,LSTMTemporalAttention,0.086661,0.150578,0.594632,0.013469,0.997077,"[Air_Temperature_Fister, Precipitation_Fister]"
2,LSTMTemporalAttention,0.089329,0.345929,1.058985,0.025865,0.996352,"[Air_Temperature_Fister, Precipitation_Fister]"
3,LSTMTemporalAttention,0.089708,0.181095,0.585167,0.017174,0.99724,"[Air_Temperature_Fister, Precipitation_Fister]"
4,LSTMTemporalAttention,0.091911,0.270654,0.739278,0.02147,0.997017,"[Air_Temperature_Fister, Precipitation_Fister]"


In [20]:
model_dfs.get("data_4-temporal_pop")

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTMTemporalAttention,0.146128,0.321565,0.829366,0.032383,0.99436,"[Air_Temperature_Fister, Precipitation_Fister]"
1,LSTMTemporalAttention,0.147472,0.376787,0.918586,0.045706,0.993381,"[Air_Temperature_Fister, Precipitation_Fister]"
2,LSTMTemporalAttention,0.149233,0.516708,1.191054,0.041129,0.997496,"[Air_Temperature_Fister, Precipitation_Fister]"
3,LSTMTemporalAttention,0.152056,0.230761,0.702048,0.015556,0.997559,"[Air_Temperature_Fister, Precipitation_Fister]"
4,LSTMTemporalAttention,0.155312,0.411793,1.151096,0.03137,0.995304,"[Air_Temperature_Fister, Precipitation_Fister]"


In [21]:
model_dfs.get("data_1-spatial_pop")

Unnamed: 0,model,val_loss,mae,rmse,mape,pearson_corr,variables
0,LSTMSpatialAttention,0.403956,0.386948,0.826948,0.066017,0.995803,"[Air_Temperature_Fister, Precipitation_Fister]"
1,LSTMSpatialAttention,0.438033,0.305548,0.888092,0.038228,0.993372,"[Air_Temperature_Fister, Precipitation_Fister]"
2,LSTMSpatialAttention,0.477879,0.364967,1.307348,0.052661,0.990304,"[Air_Temperature_Fister, Precipitation_Fister]"
3,LSTMSpatialAttention,0.483794,0.708904,1.332174,0.106728,0.993646,"[Air_Temperature_Fister, Precipitation_Fister]"
4,LSTMSpatialAttention,0.50605,0.333628,0.991294,0.041901,0.994055,"[Air_Temperature_Fister, Precipitation_Fister]"


# Attention weights understanding 

In [22]:
def plot_attention(params, spatial_weights=None, temporal_weights=None):
    # If the attention weights are torch tensors, convert them to numpy arrays first
    if isinstance(spatial_weights, torch.Tensor):
        spatial_attention_weights = spatial_weights.detach().cpu().numpy()
        average_spatial_attention_weights = np.mean(spatial_attention_weights, axis=0)
        # Create the heatmap
        fig = go.Figure(data=go.Heatmap(z=average_spatial_attention_weights, colorscale='Viridis'))

        # Set axis labels and title
        fig.update_layout(
            title=f"Spatial Attention Weights - {params['model']}",
            xaxis_title="Sequence Length",
            yaxis_title="Input Size (Features)"
        )
        vars = [params["data"]['target_variable']] + params["data"]["variables"]
        # Show the figure
        print({i: vars[i] for i in range(len(vars))})
        fig.show()
    if isinstance(temporal_weights, torch.Tensor):
        temporal_attention_weights = temporal_weights.detach().cpu().numpy()
        average_temporal_attention_weights = np.mean(temporal_attention_weights, axis=0)
        # Create the heatmap
        fig = go.Figure(data=go.Heatmap(z=average_temporal_attention_weights, colorscale='Viridis'))

        # Set axis labels and title
        fig.update_layout(
            title=f"Temporal Attention Weights - {params['model']}",
            xaxis_title="Sequence Length (Key)",
            yaxis_title="Sequence Length (Query)"
        )

        # Show the figure
        fig.show()

In [23]:
for model_dir in model_dirs:
    print(f"Best performing model on experiment: {os.path.basename(model_dir)}")
    rows = []
    best_checkpoints = find_best_checkpoints(model_dir, num_best=10)
    for i, (checkpoint, val_loss, params) in enumerate(best_checkpoints):
        # Load model and weights
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)
        data_loader = get_dataloader(params)
        test_dataloader = data_loader["test"]

        # Get a batch of input sequences and their corresponding targets
        inputs, targets = next(iter(test_dataloader))
        output, (spatial_attention_weights, temporal_attention_weights) = model(inputs, True)

        
        plot_attention(params, spatial_attention_weights, temporal_attention_weights)
        break

        #print("sample_temporal_attention_weights", average_temporal_attention_weights)
        #print("sample_spatial_attention_weights", average_spatial_attention_weights)

Best performing model on experiment: data_1-LSTM_pop


TypeError: LSTM.forward() takes 2 positional arguments but 3 were given