In [1]:
import json
from pathlib import Path
from operator import itemgetter
import pandas as pd

In [23]:
# List all experiment (model) directories under the Ray Tune results folder.
# Sorted so that every loop over the experiments — and the printed reports —
# has a stable order instead of whatever order the filesystem returns.
ray_results = Path("../ray_results/")
model_dirs = sorted(d for d in ray_results.iterdir() if d.is_dir())


# Function to find best checkpoints for a model directory
def find_best_checkpoints(model_dir, num_best=20):
    """Rank the trials inside an experiment directory by best validation loss.

    Each sub-directory of ``model_dir`` is treated as one trial. A trial is
    used only when it has a readable ``progress.csv`` with at least one
    non-missing ``val_loss`` value and a ``params.json``. Any other trial is
    skipped instead of aborting the whole scan.

    Args:
        model_dir: ``pathlib.Path`` of the experiment directory.
        num_best: How many trials to return; ``None`` returns all of them.

    Returns:
        A list of ``(checkpoint_path, best_val_loss, params)`` tuples sorted by
        ascending ``best_val_loss``. ``checkpoint_path`` is always
        ``<trial>/my_model/checkpoint.pt``; its existence is not checked.
    """
    checkpoints = []

    # Sorted iteration makes the order of trials with tied losses reproducible.
    for run_dir in sorted(model_dir.iterdir()):
        if not run_dir.is_dir():
            continue
        progress_file = run_dir / "progress.csv"
        params_file = run_dir / "params.json"
        if not (progress_file.exists() and params_file.exists()):
            continue

        try:
            progress_data = pd.read_csv(progress_file)
        except pd.errors.EmptyDataError:
            # Empty progress file: there is no validation loss to rank.
            continue
        if "val_loss" not in progress_data.columns:
            continue
        val_losses = progress_data["val_loss"].dropna()
        if val_losses.empty:
            continue

        with open(params_file, "r") as f:
            params = json.load(f)

        # NOTE(review): this is one fixed checkpoint file per trial; it may
        # have been written at a different iteration than the one with the
        # minimum val_loss — confirm how training saves it.
        checkpoint_path = run_dir / "my_model" / "checkpoint.pt"
        checkpoints.append((checkpoint_path, val_losses.min(), params))

    # Sort the checkpoints based on validation loss
    checkpoints.sort(key=itemgetter(1))

    return checkpoints[:num_best]


# Summarise the top runs (up to 40) of every experiment as a DataFrame keyed by
# the experiment directory name.
model_dfs = {}
for model_dir in model_dirs:
    rows = [
        {
            "model": params["model"],
            "val_loss": val_loss,
            "variables": params["data"]["variables"],
        }
        for _checkpoint, val_loss, params in find_best_checkpoints(model_dir, num_best=40)
    ]
    # Explicit columns keep sort_values working for an experiment without any
    # usable runs (a DataFrame built from an empty list has no columns).
    df = pd.DataFrame(rows, columns=["model", "val_loss", "variables"]).sort_values("val_loss")
    model_dfs[model_dir.name] = df

In [3]:
# Function to find best checkpoints for a model directory
def find_best_checkpoints(model_dir, num_best=20):
    """Rank the trials inside an experiment directory by best validation loss.

    NOTE(review): this function is also defined in an earlier cell; this later
    definition silently replaces it when the notebook runs top to bottom, so
    keep the two identical or delete one of them.

    Each sub-directory of ``model_dir`` is treated as one trial. A trial is
    used only when it has a readable ``progress.csv`` with at least one
    non-missing ``val_loss`` value and a ``params.json``. Any other trial is
    skipped instead of aborting the whole scan.

    Args:
        model_dir: ``pathlib.Path`` of the experiment directory.
        num_best: How many trials to return; ``None`` returns all of them.

    Returns:
        A list of ``(checkpoint_path, best_val_loss, params)`` tuples sorted by
        ascending ``best_val_loss``. ``checkpoint_path`` is always
        ``<trial>/my_model/checkpoint.pt``; its existence is not checked.
    """
    checkpoints = []

    # Sorted iteration makes the order of trials with tied losses reproducible.
    for run_dir in sorted(model_dir.iterdir()):
        if not run_dir.is_dir():
            continue
        progress_file = run_dir / "progress.csv"
        params_file = run_dir / "params.json"
        if not (progress_file.exists() and params_file.exists()):
            continue

        try:
            progress_data = pd.read_csv(progress_file)
        except pd.errors.EmptyDataError:
            # Empty progress file: there is no validation loss to rank.
            continue
        if "val_loss" not in progress_data.columns:
            continue
        val_losses = progress_data["val_loss"].dropna()
        if val_losses.empty:
            continue

        with open(params_file, "r") as f:
            params = json.load(f)

        # NOTE(review): this is one fixed checkpoint file per trial; it may
        # have been written at a different iteration than the one with the
        # minimum val_loss — confirm how training saves it.
        checkpoint_path = run_dir / "my_model" / "checkpoint.pt"
        checkpoints.append((checkpoint_path, val_losses.min(), params))

    # Sort the checkpoints based on validation loss
    checkpoints.sort(key=itemgetter(1))

    return checkpoints[:num_best]

In [4]:
import sys
import os

# Make the project's ../src directory importable (it provides the `data` and
# `train` modules imported in this notebook). Added only once per kernel.
module_path = os.path.abspath("../src")
if module_path not in sys.path:
    sys.path.append(module_path)

from data import *

In [5]:
def get_dataloader(params):
    """Build the data loaders described by a trial's ``params``.

    ``params["data_file"]`` and ``params["datetime"]`` construct the ``Data``
    object; ``params["data"]`` is forwarded as keyword arguments to
    ``Data.prepare_data``, whose result is returned unchanged. Callers in this
    notebook index the result with ``"test"``.
    """
    dataset = Data(params["data_file"], params["datetime"])
    return dataset.prepare_data(**params["data"])

In [6]:
def load_model_from_checkpoint(checkpoint_path, model_class):
    """Instantiate ``model_class()`` and load the weights stored at ``checkpoint_path``.

    NOTE(review): a later cell redefines ``load_model_from_checkpoint`` with the
    signature ``(model, checkpoint_path)``. That version replaces this one when
    the notebook runs top to bottom; consider deleting one of them.

    The checkpoint may be either a bare ``state_dict`` or a tuple/list whose
    first element is the ``state_dict``. The later loader in this notebook
    unpacks checkpoints as ``(model_state_dict, _)``.
    """
    model = model_class()
    # map_location="cpu" lets GPU-written checkpoints load on CPU-only machines;
    # load_state_dict copies the values into the model's own parameters.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    state_dict = checkpoint[0] if isinstance(checkpoint, (tuple, list)) else checkpoint
    model.load_state_dict(state_dict)
    return model

In [7]:
import plotly.graph_objs as go
import plotly.subplots as sp


def evaluate_and_plot(model, test_dataloader):
    """Run ``model`` over every batch of ``test_dataloader`` and collect the results.

    Despite its name this function does not plot; pass its results to
    ``plot_pred_actual`` for that. The model is left in ``eval()`` mode.

    Args:
        model: A torch module called as ``model(X_batch.float())``.
        test_dataloader: Iterable of ``(X_batch, y_batch)`` tensor pairs.

    Returns:
        ``(all_preds, all_actuals)``: lists holding one NumPy entry per sample
        (the rows of each batch's output and target arrays).
    """
    model.eval()  # inference behaviour for layers such as dropout
    # Send inputs to the device the model lives on (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    all_preds = []
    all_actuals = []

    with torch.no_grad():
        for X_batch, y_batch in test_dataloader:
            batch_output = model(X_batch.float().to(device))
            # .cpu() is required before .numpy() when tensors live on a GPU.
            all_preds.extend(batch_output.cpu().numpy())
            all_actuals.extend(y_batch.cpu().numpy())

    return all_preds, all_actuals


def plot_pred_actual(all_preds, all_actuals, datetime_index, model_name):
    """Display an interactive Plotly line chart of predictions vs. actual values.

    Args:
        all_preds: Predicted values, presumably aligned with ``datetime_index``.
        all_actuals: Observed values aligned with ``all_preds``.
        datetime_index: x-axis values shared by both series.
        model_name: Shown in the subplot title.

    The figure is shown with ``fig.show()``; nothing is returned.
    """
    fig = sp.make_subplots(
        rows=1, cols=1, subplot_titles=[f"{model_name} - Predicted vs. Actual"]
    )

    # One line trace per series, both drawn on the single subplot.
    series = (("Predicted", all_preds), ("Actual", all_actuals))
    for trace_name, values in series:
        fig.add_trace(
            go.Scatter(x=datetime_index, y=values, mode="lines", name=trace_name),
            row=1,
            col=1,
        )

    fig.update_xaxes(title_text="Time or Iteration", row=1, col=1)
    fig.update_yaxes(title_text="Values", row=1, col=1)
    fig.update_layout(showlegend=True)
    fig.show()

In [8]:
from train import create_model

In [9]:
def load_model_from_checkpoint(model, checkpoint_path, raise_on_error=False):
    """Load the weights stored at ``checkpoint_path`` into ``model`` and return it.

    The checkpoint may be a ``(model_state_dict, ...)`` tuple/list or a bare
    ``state_dict``.

    By default any error is printed and ``model`` is returned with its current
    (e.g. freshly initialised) weights. Metrics computed afterwards then
    describe an untrained model, so pass ``raise_on_error=True`` to propagate
    the exception instead.
    """
    try:
        # map_location="cpu" lets GPU-written checkpoints load on CPU-only machines.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model_state_dict = checkpoint[0] if isinstance(checkpoint, (tuple, list)) else checkpoint
        model.load_state_dict(model_state_dict)
    except Exception as e:
        if raise_on_error:
            raise
        print(f"Error loading checkpoint: {checkpoint_path}")
        print(e)
    return model

In [10]:
def calculate_metrics(predictions, actuals):
    """Compute point-forecast error metrics between predictions and actuals.

    Both inputs go through ``np.asarray(...).ravel()``. An ``(n, 1)``
    prediction array and an ``(n,)`` target array are therefore compared
    element-wise instead of being broadcast to an ``(n, n)`` difference.

    Args:
        predictions: Array-like of predicted values.
        actuals: Array-like of observed values with the same number of elements.

    Returns:
        Dict with keys ``"mae"``, ``"mse"``, ``"rmse"``, ``"mape"`` and
        ``"smape"``; MAPE and sMAPE are expressed as percentages.

    Raises:
        AssertionError: if the inputs contain different numbers of elements.
    """
    predictions = np.asarray(predictions).ravel()
    actuals = np.asarray(actuals).ravel()
    assert predictions.size == actuals.size, "Predictions and actuals must have the same length."

    errors = predictions - actuals

    # Mean Absolute Error (MAE)
    mae = np.mean(np.abs(errors))

    # Mean Squared Error (MSE)
    mse = np.mean(errors ** 2)

    # Root Mean Squared Error (RMSE)
    rmse = np.sqrt(mse)

    # Mean Absolute Percentage Error (MAPE); undefined (inf/nan) when any
    # actual value is 0.
    mape = np.mean(np.abs((actuals - predictions) / actuals)) * 100

    # Symmetric Mean Absolute Percentage Error (sMAPE). A zero denominator
    # only occurs when prediction and actual are both 0, i.e. zero error, so
    # that term counts as 0 instead of nan.
    denominator = np.abs(predictions) + np.abs(actuals)
    with np.errstate(divide="ignore", invalid="ignore"):
        smape_terms = np.where(denominator == 0, 0.0, 2 * np.abs(errors) / denominator)
    smape = np.mean(smape_terms) * 100

    return {"mae": mae, "mse": mse, "rmse": rmse, "mape": mape, "smape": smape}

In [11]:
# How many experiments to evaluate below; set to None to evaluate all of them.
N_EXPERIMENTS_TO_EVALUATE = 1

for model_dir in model_dirs[:N_EXPERIMENTS_TO_EVALUATE]:
    print(f"Best performing model on experiment: {model_dir.name}")
    # Only the single lowest-val_loss trial of each experiment is evaluated.
    for checkpoint, val_loss, params in find_best_checkpoints(model_dir, num_best=1):
        # Load model and weights
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)

        data_loader = get_dataloader(params)
        preds, actuals = evaluate_and_plot(
            model=model, test_dataloader=data_loader["test"]
        )
        plot_pred_actual(
            preds,
            actuals,
            data_loader["test"].datetime_index,
            model_name=params["model"],
        )

        # Calculate metrics
        metrics = calculate_metrics(np.array(preds), np.array(actuals))
        print(metrics)

Best performing model on experiment: data_1-attention


{'mae': 0.23498368, 'mse': 0.24786478, 'rmse': 0.4978602, 'mape': 3.479515016078949, 'smape': 3.449157252907753}


In [12]:
# Show the long `variables` lists in the tables below without truncation.
pd.set_option("display.max_colwidth", 5000)
# Experiments available in model_dfs (last expression, so Jupyter displays it).
model_dfs.keys()

In [13]:
# Top runs of experiment "data_1-attention", ordered by best validation loss.
model_dfs.get("data_1-attention", None)

Unnamed: 0,model,val_loss,variables
0,LSTMTemporalAttention,0.249677,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu]"
1,LSTMSpatialTemporalAttention,0.270236,"[Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
2,LSTMTemporalAttention,0.283689,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
3,LSTMSpatialAttention,0.33606,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
4,LSTMSpatialTemporalAttention,0.379068,"[Air_Temperature_Fister, Precipitation_Fister]"
5,LSTMSpatialAttention,0.389734,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu, Air_Temperature_Fister, Precipitation_Fister]"
6,LSTMSpatialAttention,0.389797,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu]"
7,LSTMSpatialAttention,0.432451,"[Air_Temperature_Fister, Precipitation_Fister]"
8,LSTMSpatialTemporalAttention,0.502394,[]
9,LSTM,0.611074,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"


In [14]:
# Top runs of experiment "data_1-attention_max", ordered by best validation loss.
model_dfs.get("data_1-attention_max", None)

Unnamed: 0,model,val_loss,variables
0,LSTMTemporalAttention,0.218954,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
1,LSTM,0.241497,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
2,LSTM,0.252542,"[Air_Temperature_Fister, Precipitation_Fister]"
3,LSTMTemporalAttention,0.268289,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
4,LSTM,0.277166,"[Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
5,LSTMSpatialAttention,0.28539,"[Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
6,LSTMTemporalAttention,0.285974,"[Air_Temperature_Fister, Precipitation_Fister]"
7,LSTMTemporalAttention,0.32662,"[Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
8,LSTMSpatialTemporalAttention,0.36692,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
9,LSTMSpatialTemporalAttention,0.428279,"[Air_Temperature_Fister, Precipitation_Fister]"


In [15]:
# Top runs of experiment "data_2-attention", ordered by best validation loss.
model_dfs.get("data_2-attention", None)

Unnamed: 0,model,val_loss,variables
0,LSTMTemporalAttention,0.239872,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
1,LSTMSpatialTemporalAttention,0.253439,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
2,LSTMTemporalAttention,0.362923,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
3,LSTMSpatialTemporalAttention,0.430255,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu, Air_Temperature_Fister, Precipitation_Fister]"
4,LSTMSpatialAttention,0.439733,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
5,LSTM,0.45328,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
6,LSTMTemporalAttention,0.524526,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
7,LSTMTemporalAttention,0.573968,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu, Air_Temperature_Fister, Precipitation_Fister]"
8,LSTMTemporalAttention,0.631287,"[Air_Temperature_Fister, Precipitation_Fister]"
9,LSTMSpatialAttention,0.632688,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"


In [16]:
# Top runs of experiment "data_2-attention_max", ordered by best validation loss.
model_dfs.get("data_2-attention_max", None)

Unnamed: 0,model,val_loss,variables
0,LSTMSpatialAttention,0.176857,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
1,LSTMTemporalAttention,0.193519,"[Air_Temperature_Fister, Precipitation_Fister]"
2,LSTMSpatialTemporalAttention,0.271957,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
3,LSTM,0.380462,"[Air_Temperature_Fister, Precipitation_Fister]"
4,LSTMTemporalAttention,0.389693,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
5,LSTMSpatialAttention,0.399451,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu, Air_Temperature_Fister, Precipitation_Fister]"
6,LSTMSpatialTemporalAttention,0.427786,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
7,LSTMSpatialAttention,0.442682,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Lyngsaana]"
8,LSTMTemporalAttention,0.48293,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
9,LSTMTemporalAttention,0.517923,[]


In [17]:
# Top runs of experiment "data_3-attention", ordered by best validation loss.
model_dfs.get("data_3-attention", None)

Unnamed: 0,model,val_loss,variables
0,LSTMTemporalAttention,0.056075,"[Air_Temperature_Fister, Precipitation_Fister]"
1,LSTM,0.09083,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana]"
2,LSTMSpatialTemporalAttention,0.102167,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana]"
3,LSTMSpatialAttention,0.109993,"[Air_Temperature_Fister, Precipitation_Fister]"
4,LSTM,0.11258,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana]"
5,LSTMSpatialAttention,0.112826,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana]"
6,LSTMSpatialTemporalAttention,0.115277,[]
7,LSTMTemporalAttention,0.136753,[]
8,LSTM,0.157454,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
9,LSTMSpatialTemporalAttention,0.188431,"[Air_Temperature_Fister, Precipitation_Fister]"


In [18]:
# Top runs of experiment "data_3-attention_max", ordered by best validation loss.
model_dfs.get("data_3-attention_max", None)

Unnamed: 0,model,val_loss,variables
0,LSTMSpatialAttention,0.085994,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana]"
1,LSTMTemporalAttention,0.08941,"[Air_Temperature_Fister, Precipitation_Fister]"
2,LSTMTemporalAttention,0.103625,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana]"
3,LSTMSpatialTemporalAttention,0.121284,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana]"
4,LSTMSpatialTemporalAttention,0.136435,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana]"
5,LSTM,0.136684,"[Air_Temperature_Fister, Precipitation_Fister]"
6,LSTMSpatialTemporalAttention,0.145039,[]
7,LSTMTemporalAttention,0.14592,[]
8,LSTM,0.14685,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana]"
9,LSTMTemporalAttention,0.156306,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu]"


In [19]:
# Top runs of experiment "data_4-attention", ordered by best validation loss.
model_dfs.get("data_4-attention", None)

Unnamed: 0,model,val_loss,variables
0,LSTMTemporalAttention,0.139961,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"
1,LSTMSpatialAttention,0.140898,"[Air_Temperature_Fister, Precipitation_Fister]"
2,LSTMTemporalAttention,0.188778,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"
3,LSTMTemporalAttention,0.295775,"[Air_Temperature_Fister, Precipitation_Fister]"
4,LSTMTemporalAttention,0.299069,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu, Air_Temperature_Fister, Precipitation_Fister]"
5,LSTM,0.389498,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
6,LSTMSpatialTemporalAttention,0.394251,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu, Air_Temperature_Fister, Precipitation_Fister]"
7,LSTMSpatialAttention,0.424525,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"
8,LSTMSpatialAttention,0.425964,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu, Air_Temperature_Fister, Precipitation_Fister]"
9,LSTMSpatialAttention,0.438712,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"


In [20]:
# Top runs of experiment "data_4-attention_max", ordered by best validation loss.
model_dfs.get("data_4-attention_max", None)

Unnamed: 0,model,val_loss,variables
0,LSTMTemporalAttention,0.170333,"[Air_Temperature_Fister, Precipitation_Fister]"
1,LSTMSpatialAttention,0.21814,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"
2,LSTMSpatialTemporalAttention,0.227588,"[Air_Temperature_Fister, Precipitation_Fister]"
3,LSTMTemporalAttention,0.236823,"[Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"
4,LSTMTemporalAttention,0.269266,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu]"
5,LSTMTemporalAttention,0.277161,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
6,LSTMSpatialTemporalAttention,0.399849,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu, Water_Level_Lyngsaana, Water_Temperature_Hiafossen, Water_Level_Hiafossen, Water_Level_Kalltveit, Water_Temperature_Kalltveit_Kum, Water_Temperature_Hiavatn, Water_Level_Hiavatn, Water_Temperature_Musdalsvatn, Water_Level_Musdalsvatn, Water_Temperature_Musdalsvatn_Downstream, Water_Level_Musdalsvatn_Downstream, Water_Temperature_Viglesdalsvatn, Water_Level_Viglesdalsvatn, Water_Temperature_Lyngsaana, Water_Temperature_Kalltveit_River]"
7,LSTMSpatialTemporalAttention,0.410831,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
8,LSTMSpatialAttention,0.44195,"[Wind_Speed_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Fister, Precipitation_Nilsebu]"
9,LSTMSpatialTemporalAttention,0.466904,"[Wind_Speed_Nilsebu, Air_Temperature_Nilsebu, Wind_Direction_Nilsebu, Relative_Humidity_Nilsebu, Precipitation_Nilsebu, Air_Temperature_Fister, Precipitation_Fister]"


In [70]:
def _attention_as_array(weights):
    """Return attention weights as a NumPy array, or None when there is nothing to plot."""
    if isinstance(weights, torch.Tensor):
        return weights.detach().cpu().numpy()
    if isinstance(weights, np.ndarray):
        return weights
    # Anything else, such as None, is skipped.
    return None


def _attention_heatmap(weights, title, xaxis_title, yaxis_title):
    """Build a Viridis heatmap of ``weights`` averaged over axis 0."""
    fig = go.Figure(data=go.Heatmap(z=np.mean(weights, axis=0), colorscale='Viridis'))
    fig.update_layout(title=title, xaxis_title=xaxis_title, yaxis_title=yaxis_title)
    return fig


def plot_attention(params, spatial_weights=None, temporal_weights=None):
    """Show heatmaps of the attention weights of one model, averaged over axis 0.

    Args:
        params: Trial parameters. ``params["model"]`` is used in the titles, and
            ``params["data"]["target_variable"]`` plus
            ``params["data"]["variables"]`` are printed alongside the spatial
            heatmap.
        spatial_weights: Torch tensor or NumPy array of spatial attention
            weights, or None to skip that plot. Axis 0 is averaged away
            (presumably the batch axis — confirm against the model).
        temporal_weights: Same as ``spatial_weights``, for the temporal plot.
    """
    spatial = _attention_as_array(spatial_weights)
    if spatial is not None:
        fig = _attention_heatmap(
            spatial,
            title=f"Spatial Attention Weights - {params['model']}",
            xaxis_title="Sequence Length",
            yaxis_title="Input Size (Features)",
        )
        # Printed so heatmap rows can be matched to variable names
        # (assumes rows follow this target-then-variables order — confirm).
        feature_names = [params["data"]['target_variable']] + params["data"]["variables"]
        print(feature_names)
        fig.show()

    temporal = _attention_as_array(temporal_weights)
    if temporal is not None:
        fig = _attention_heatmap(
            temporal,
            title=f"Temporal Attention Weights - {params['model']}",
            xaxis_title="Sequence Length (Key)",
            yaxis_title="Sequence Length (Query)",
        )
        fig.show()

In [71]:
# Plot the attention weights of each experiment's best checkpoint on one test batch.
for model_dir in model_dirs:
    print(f"Best performing model on experiment: {model_dir.name}")
    # Only the best checkpoint is plotted, so only that one is requested.
    for checkpoint, val_loss, params in find_best_checkpoints(model_dir, num_best=1):
        # Load model and weights. eval() switches layers such as dropout to
        # inference behaviour, as evaluate_and_plot does.
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)
        model.eval()
        test_dataloader = get_dataloader(params)["test"]

        # Get one batch of input sequences; the targets are not needed here.
        inputs, _targets = next(iter(test_dataloader))
        with torch.no_grad():
            # NOTE(review): the second positional argument appears to make the
            # model also return its attention weights — confirm in the model code.
            output, (spatial_attention_weights, temporal_attention_weights) = model(inputs.float(), True)

        plot_attention(params, spatial_attention_weights, temporal_attention_weights)

Best performing model on experiment: data_1-attention


Best performing model on experiment: data_1-attention_max


Best performing model on experiment: data_2-attention


Best performing model on experiment: data_2-attention_max
['Flow_Kalltveit', 'Water_Level_Lyngsaana', 'Water_Temperature_Hiafossen', 'Water_Level_Hiafossen', 'Water_Level_Kalltveit', 'Water_Temperature_Kalltveit_Kum', 'Water_Temperature_Lyngsaana']


Best performing model on experiment: data_3-attention


Best performing model on experiment: data_3-attention_max
['Flow_Kalltveit', 'Water_Level_Lyngsaana', 'Water_Temperature_Hiafossen', 'Water_Level_Hiafossen', 'Water_Level_Kalltveit', 'Water_Temperature_Kalltveit_Kum', 'Water_Temperature_Hiavatn', 'Water_Level_Hiavatn', 'Water_Temperature_Musdalsvatn', 'Water_Level_Musdalsvatn', 'Water_Temperature_Musdalsvatn_Downstream', 'Water_Level_Musdalsvatn_Downstream', 'Water_Temperature_Viglesdalsvatn', 'Water_Level_Viglesdalsvatn', 'Water_Temperature_Lyngsaana']


Best performing model on experiment: data_4-attention


Best performing model on experiment: data_4-attention_max
