In [None]:
import json
from operator import itemgetter
from pathlib import Path

import numpy as np
import pandas as pd
import torch

In [None]:
# List all model (experiment) directories under the Ray results folder.
# The list is sorted so that its order -- and therefore which experiment the
# "first directory only" loops further down pick up -- does not depend on the
# arbitrary order in which the filesystem returns entries.
ray_results = Path("../ray_results/")
model_dirs = sorted(d for d in ray_results.iterdir() if d.is_dir())


# Function to find best checkpoints for a model directory
def find_best_checkpoints(model_dir, num_best=10):
    """Return the training runs in ``model_dir`` with the lowest validation loss.

    Every sub-directory of ``model_dir`` is treated as one training run. A run
    is considered only if it contains both ``progress.csv`` (with a
    ``val_loss`` column holding at least one non-NaN value) and
    ``params.json``; incomplete runs are skipped instead of aborting the scan.

    Args:
        model_dir: Directory (``pathlib.Path`` or ``str``) holding the runs.
        num_best: Maximum number of runs to return.

    Returns:
        A list of at most ``num_best`` tuples
        ``(checkpoint_path, best_val_loss, params)`` sorted by ascending
        ``best_val_loss``. ``checkpoint_path`` is always
        ``<run_dir>/my_model/checkpoint.pt``; its existence is not checked.
        NOTE(review): ``best_val_loss`` is the minimum over the whole
        ``progress.csv`` -- whether the file at ``checkpoint_path`` was saved
        at that iteration cannot be told from here; confirm against training.
    """
    checkpoints = []

    # sorted() keeps traversal order, and hence the order of tied runs, stable.
    for run_dir in sorted(Path(model_dir).iterdir()):
        if not run_dir.is_dir():
            continue

        progress_file = run_dir / "progress.csv"
        params_file = run_dir / "params.json"
        if not (progress_file.exists() and params_file.exists()):
            continue

        try:
            progress_data = pd.read_csv(progress_file)
        except pd.errors.EmptyDataError:
            # The run was created but never reported a result.
            continue
        if "val_loss" not in progress_data.columns:
            continue

        # min() skips NaN entries; it is NaN only when no valid loss exists.
        best_val_loss = progress_data["val_loss"].min()
        if pd.isna(best_val_loss):
            continue

        with open(params_file, "r") as f:
            params = json.load(f)

        checkpoint_path = run_dir / "my_model" / "checkpoint.pt"
        checkpoints.append((checkpoint_path, best_val_loss, params))

    # Sort the checkpoints based on validation loss (stable for ties).
    checkpoints.sort(key=itemgetter(1))

    return checkpoints[:num_best]


# Build one summary table per experiment directory: the (up to) ten best runs
# with their model type, best validation loss and input variables.
model_dfs = {}
for model_dir in model_dirs:
    rows = [
        {
            "model": params["model"],
            "val_loss": val_loss,
            "variables": params["data"]["variables"],
        }
        for _checkpoint, val_loss, params in find_best_checkpoints(
            model_dir, num_best=10
        )
    ]
    # Naming the columns explicitly keeps sort_values() from raising KeyError
    # when an experiment directory contains no usable runs (rows == []).
    df = pd.DataFrame(rows, columns=["model", "val_loss", "variables"]).sort_values(
        "val_loss"
    )
    model_dfs[model_dir.name] = df

In [None]:
import sys
import os

# Put the sibling ``../src`` directory (resolved against the current working
# directory) on the import path, once, so the project modules imported below
# can be found.
module_path = os.path.abspath("../src")
already_registered = module_path in sys.path
if not already_registered:
    sys.path.append(module_path)

from data import *

In [None]:
def get_dataloader(params):
    """Build the data loaders described by a run's parameter dictionary.

    Args:
        params: Mapping that provides ``"data_file"`` and ``"datetime"``
            (passed positionally to ``Data``) and ``"data"`` (a mapping
            expanded as keyword arguments to ``Data.prepare_data``).

    Returns:
        Whatever ``Data.prepare_data`` returns. Later cells index the result
        with ``"test"``, so it is presumably a mapping of split name to
        loader -- confirm against ``data.py``.
    """
    dataset = Data(params["data_file"], params["datetime"])
    return dataset.prepare_data(**params["data"])

In [None]:
def load_model_from_checkpoint(checkpoint_path, model_class):
    """Instantiate ``model_class`` and restore its weights from a checkpoint.

    NOTE: a later cell in this notebook defines another function with the same
    name but the signature ``(model, checkpoint_path)``; once that cell has
    run it replaces this definition.

    Args:
        checkpoint_path: File passed to ``torch.load``; its content is handed
            directly to ``load_state_dict``, so it must be a bare state dict.
        model_class: Zero-argument callable that returns the model to fill.

    Returns:
        The newly created model with the loaded state dict applied.
    """
    restored = model_class()
    weights = torch.load(checkpoint_path)
    restored.load_state_dict(weights)
    return restored

In [None]:
import plotly.graph_objs as go
import plotly.subplots as sp


def evaluate_and_plot(model, test_dataloader):
    """Run ``model`` over every batch of ``test_dataloader`` and collect results.

    Despite its name this function does not plot anything; it only gathers
    the predictions and targets.

    Args:
        model: Module that is switched to eval mode and called on each input
            batch (cast to float) with gradient tracking disabled.
        test_dataloader: Iterable yielding ``(X_batch, y_batch)`` tensor pairs.

    Returns:
        ``(all_preds, all_actuals)``: two Python lists holding, batch after
        batch, the items obtained by iterating over the first axis of the
        model output and of ``y_batch`` after conversion to NumPy.
    """
    model.eval()
    predictions, targets = [], []

    with torch.no_grad():
        for features, labels in test_dataloader:
            outputs = model(features.float())
            predictions += list(outputs.numpy())
            targets += list(labels.numpy())

    return predictions, targets


def plot_pred_actual(all_preds, all_actuals, datetime_index, model_name):
    """Show predicted and actual values as two line traces on one Plotly figure.

    Args:
        all_preds: Values drawn as the "Predicted" trace.
        all_actuals: Values drawn as the "Actual" trace.
        datetime_index: x-axis values shared by both traces.
        model_name: Name used in the subplot title.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    title = f"{model_name} - Predicted vs. Actual"
    fig = sp.make_subplots(rows=1, cols=1, subplot_titles=[title])

    # One line trace per series, predicted first so the legend order is kept.
    for series, label in ((all_preds, "Predicted"), (all_actuals, "Actual")):
        trace = go.Scatter(x=datetime_index, y=series, mode="lines", name=label)
        fig.add_trace(trace, row=1, col=1)

    # Axis labels and legend.
    fig.update_xaxes(title_text="Time or Iteration", row=1, col=1)
    fig.update_yaxes(title_text="Values", row=1, col=1)
    fig.update_layout(showlegend=True)

    fig.show()

In [None]:
from train import create_model

In [None]:
def load_model_from_checkpoint(model, checkpoint_path, raise_on_error=False):
    """Restore ``model``'s weights from a checkpoint file.

    The checkpoint is expected to unpack into exactly two items, the first of
    which is the model state dict; the second item is ignored.

    Args:
        model: Module whose ``load_state_dict`` receives the stored weights.
        checkpoint_path: Path of the checkpoint file read with ``torch.load``.
        raise_on_error: When True, re-raise any loading error after reporting
            it. The default (False) keeps the best-effort behaviour: the
            error is printed and ``model`` is returned as it was passed in.

    Returns:
        The same ``model`` object. If loading failed and ``raise_on_error`` is
        False its weights were NOT restored, so any evaluation done with it is
        meaningless -- a warning line is printed in that case.
    """
    try:
        # map_location="cpu" lets checkpoints that were written on a GPU
        # machine be read on a CPU-only host; load_state_dict then copies the
        # values onto whatever device the model's parameters live on.
        model_state_dict, _ = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(model_state_dict)
    except Exception as e:
        print(f"Error loading checkpoint: {checkpoint_path}")
        print(e)
        if raise_on_error:
            raise
        print("Warning: checkpoint weights were not restored; returning the model unchanged.")
    return model

In [None]:
def calculate_metrics(predictions, actuals):
    """Compute standard regression error metrics.

    Both inputs are converted to flat float arrays before comparison, so
    lists are accepted and a column vector of shape ``(N, 1)`` is compared
    element by element with a vector of shape ``(N,)`` instead of being
    broadcast into an ``(N, N)`` matrix of differences.

    Args:
        predictions: Array-like of predicted values.
        actuals: Array-like of true values with the same number of elements.

    Returns:
        Dict with the keys ``"mae"``, ``"mse"``, ``"rmse"``, ``"mape"`` and
        ``"smape"`` (the last two in percent). MAPE is ``inf``/``nan`` when
        ``actuals`` contains zeros. For sMAPE, a term whose prediction and
        actual are both zero counts as zero error.

    Raises:
        AssertionError: If the inputs hold a different number of elements.
    """
    predictions = np.asarray(predictions, dtype=float).ravel()
    actuals = np.asarray(actuals, dtype=float).ravel()
    assert (
        predictions.size == actuals.size
    ), "Predictions and actuals must have the same length."

    errors = predictions - actuals
    abs_errors = np.abs(errors)

    # Mean Absolute Error (MAE)
    mae = np.mean(abs_errors)

    # Mean Squared Error (MSE)
    mse = np.mean(errors ** 2)

    # Root Mean Squared Error (RMSE)
    rmse = np.sqrt(mse)

    # Mean Absolute Percentage Error (MAPE)
    mape = np.mean(np.abs((actuals - predictions) / actuals)) * 100

    # Symmetric Mean Absolute Percentage Error (sMAPE)
    # Where the denominator is zero both values are zero, i.e. the prediction
    # is exact; those terms keep the 0.0 from ``out`` instead of becoming NaN.
    denominator = np.abs(predictions) + np.abs(actuals)
    smape_terms = np.divide(
        2 * abs_errors,
        denominator,
        out=np.zeros_like(denominator),
        where=denominator != 0,
    )
    smape = np.mean(smape_terms) * 100

    return {"mae": mae, "mse": mse, "rmse": rmse, "mape": mape, "smape": smape}

In [None]:
# Evaluate the single best run of the FIRST experiment directory only.
# (The original loop ended in an unconditional `break`; the slice states the
# same restriction up front. Drop `[:1]` to evaluate every experiment.)
for model_dir in model_dirs[:1]:
    print(f"Best performing model on experiment: {model_dir.name}")
    best_checkpoints = find_best_checkpoints(model_dir, num_best=1)
    for checkpoint, val_loss, params in best_checkpoints:

        # Load model and weights
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)

        # Predict on the test split and plot predictions against targets
        data_loader = get_dataloader(params)
        preds, actuals = evaluate_and_plot(
            model=model, test_dataloader=data_loader["test"]
        )
        plot_pred_actual(
            preds,
            actuals,
            data_loader["test"].datetime_index,
            model_name=params["model"],
        )

        # Calculate metrics
        metrics = calculate_metrics(np.array(preds), np.array(actuals))
        print(metrics)

In [None]:
# Widen column rendering so long "variables" lists are not truncated.
pd.set_option("display.max_colwidth", 5000)
# Kept as the last expression of the cell so the experiment names are
# actually displayed (only a cell's final expression is rendered).
model_dfs.keys()

In [None]:
# Summary table of the best runs for "data_1-location_based"; .get() yields None if that key is absent.
model_dfs.get("data_1-location_based")

In [None]:
# Summary table of the best runs for "data_1-random_var"; .get() yields None if that key is absent.
model_dfs.get("data_1-random_var")

In [None]:
# Summary table of the best runs for "data_2-random_var"; .get() yields None if that key is absent.
model_dfs.get("data_2-random_var")

In [None]:
# Summary table of the best runs for "data_3-random_var"; .get() yields None if that key is absent.
model_dfs.get("data_3-random_var")

In [None]:
# Summary table of the best runs for "data_4-random_var"; .get() yields None if that key is absent.
model_dfs.get("data_4-random_var")

In [41]:
def plot_attention(temporal_weights=None, spatial_weights=None):
    """Show bar charts of attention weights averaged over the first axis.

    The temporal chart is drawn first, then the spatial one. An argument left
    as ``None`` is skipped, so either chart can be requested on its own and
    calling the function without arguments draws nothing.

    Args:
        temporal_weights: ``torch.Tensor`` or NumPy array, or ``None``.
        spatial_weights: ``torch.Tensor`` or NumPy array, or ``None``.
            Both are averaged along axis 0 -- assumed to be the sample axis,
            TODO confirm against what the model actually provides.

    Returns:
        None. Each figure is displayed with ``fig.show()``.
    """
    charts = (
        (temporal_weights, "Average Temporal Attention Weights", "Time Steps"),
        (spatial_weights, "Average Spatial Attention Weights", "Input Variables"),
    )
    for weights, title, xaxis_title in charts:
        if weights is None:
            continue
        _plot_mean_attention(weights, title, xaxis_title)


def _plot_mean_attention(weights, title, xaxis_title):
    """Average ``weights`` over axis 0 and show the result as one bar chart."""
    # Torch tensors are detached and moved to the CPU before NumPy conversion.
    if isinstance(weights, torch.Tensor):
        weights = weights.detach().cpu().numpy()

    avg_weights = np.mean(weights, axis=0)

    fig = go.Figure(data=go.Bar(x=list(range(len(avg_weights))), y=avg_weights))
    fig.update_layout(
        title=title,
        xaxis_title=xaxis_title,
        yaxis_title='Attention Weight'
    )
    fig.show()


In [43]:
# TODO: This is wrong! The model needs to be modified to return the attention
# weights that should be visualised. What is plotted below are the learned
# parameter matrices of the attention layers, read from the state dict, not
# attention weights computed for any input.
# Only the first experiment directory is processed (the original loop ended
# in an unconditional `break`).
for model_dir in model_dirs[:1]:
    print(f"Best performing model on experiment: {model_dir.name}")
    best_checkpoints = find_best_checkpoints(model_dir, num_best=1)
    for checkpoint, val_loss, params in best_checkpoints:
        # Load model and weights
        model = create_model(params)
        model = load_model_from_checkpoint(model, checkpoint)
        # Read the parameter tensors of the two attention layers from the
        # model's state dictionary
        state_dict = model.state_dict()
        spatial_weights = state_dict['spatial_attention.attn.weight']
        temporal_weights = state_dict['temporal_attention.attn.weight']
        plot_attention(temporal_weights, spatial_weights)

Best performing model on experiment: data_1-location_based
torch.Size([64, 64])
