In [1]:
import json
import os
import sys

import numpy as np
import plotly.graph_objects as go
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

module_path = os.path.abspath(os.path.join("../src"))
if module_path not in sys.path:
    sys.path.append(module_path)

from data import Data
from train import create_model

In [2]:
def load_models_and_configs(ray_results_dir):
    """Collect the saved weights and configuration of every completed trial.

    Walks ``<ray_results_dir>/<experiment>/<trial>/`` and keeps each trial
    directory that contains both ``params.json`` and
    ``my_model/checkpoint.pt``. Trials missing either file are skipped.

    Args:
        ray_results_dir: Root directory holding one sub-directory per
            experiment.

    Returns:
        A list of ``(model_state_dict, config)`` tuples ordered by experiment
        and trial directory name. ``model_state_dict`` is element 0 of the
        loaded checkpoint and ``config`` is the parsed ``params.json``.
    """
    state_dicts_and_configs = []
    # os.listdir returns entries in arbitrary order; sorting makes the result
    # (and any later argmin tie-breaking) reproducible across runs and machines.
    for experiment_dir in sorted(os.listdir(ray_results_dir)):
        experiment_path = os.path.join(ray_results_dir, experiment_dir)
        if not os.path.isdir(experiment_path):
            continue
        for trial_dir in sorted(os.listdir(experiment_path)):
            trial_path = os.path.join(experiment_path, trial_dir)
            if not os.path.isdir(trial_path):
                continue
            config_path = os.path.join(trial_path, 'params.json')
            model_path = os.path.join(trial_path, 'my_model', 'checkpoint.pt')
            if not (os.path.exists(config_path) and os.path.exists(model_path)):
                continue
            with open(config_path, 'r') as f:
                config = json.load(f)
            # map_location="cpu" lets checkpoints written on a GPU machine be
            # read on a CPU-only one; the tensors are copied into the model by
            # load_state_dict later, so their storage device does not matter.
            # NOTE(review): torch.load unpickles the file - only point this at
            # checkpoints from a trusted source.
            checkpoint = torch.load(model_path, map_location="cpu")
            model_state_dict = checkpoint[0]
            state_dicts_and_configs.append((model_state_dict, config))
    return state_dicts_and_configs


In [3]:
def get_test_loader(config):
    """Build the test-split loader described by a trial configuration.

    Args:
        config: Mapping with top-level keys ``'data_file'`` and
            ``'datetime'`` plus a nested ``'data'`` mapping that provides
            ``'target_variable'``, ``'sequence_length'``, ``'batch_size'``
            and ``'variables'``.

    Returns:
        The ``'test'`` entry of the mapping returned by
        ``Data.prepare_data``.
    """
    dataset = Data(
        data_file=config['data_file'],
        datetime_variable=config['datetime'],
    )

    data_cfg = config['data']
    loaders = dataset.prepare_data(
        data_cfg['target_variable'],
        data_cfg['sequence_length'],
        data_cfg['batch_size'],
        data_cfg['variables'],
    )
    return loaders['test']

In [14]:
def evaluate_model(model, test_loader, metric='mae'):
    """Run ``model`` over ``test_loader`` and report its error on the whole set.

    The error is accumulated element-wise over every batch and normalised
    once at the end, so a shorter final batch is not over-weighted and
    ``'rmse'`` is the root of the overall mean squared error rather than an
    average of per-batch roots.

    Args:
        model: Module to evaluate. It is moved to CUDA device 0 when
            available (otherwise CPU) and switched to eval mode.
        test_loader: Iterable yielding ``(X_batch, y_batch)`` tensor pairs.
        metric: One of ``'mae'``, ``'mse'`` or ``'rmse'``.

    Returns:
        A tuple ``(avg_error, ground_truth, predictions)`` where
        ``avg_error`` is a float (``nan`` if the loader yields no elements)
        and the two lists hold the flattened targets and model outputs in
        iteration order.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    # Validate up front: inside the loop this check would never fire for an
    # empty loader and would only fire after a wasted forward pass otherwise.
    if metric not in ('mae', 'mse', 'rmse'):
        raise ValueError("Invalid metric. Choose from 'mae', 'mse', or 'rmse'.")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    ground_truth = []
    predictions = []
    total_error = 0.0
    n_elements = 0

    with torch.no_grad():
        for X_batch, y_batch in tqdm(test_loader):
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)

            # Outputs shaped (B, 1) against targets shaped (B,) would
            # broadcast to (B, B) in the subtraction below and silently
            # corrupt the error. Same element count => align the shapes.
            if outputs.shape != y_batch.shape and outputs.numel() == y_batch.numel():
                outputs = outputs.reshape(y_batch.shape)

            ground_truth.extend(y_batch.cpu().numpy().flatten())
            predictions.extend(outputs.cpu().numpy().flatten())

            diff = outputs - y_batch
            if metric == 'mae':
                batch_error = torch.abs(diff).sum()
            else:  # 'mse' and 'rmse' both accumulate squared error
                batch_error = torch.square(diff).sum()
            total_error += batch_error.item()
            n_elements += diff.numel()

    if n_elements == 0:
        return float('nan'), ground_truth, predictions

    avg_error = total_error / n_elements
    if metric == 'rmse':
        avg_error = avg_error ** 0.5
    return avg_error, ground_truth, predictions


In [5]:
ray_results_dir = "../ray_results/"
state_dicts_and_configs = load_models_and_configs(ray_results_dir)

# Score every saved trial on its own test split. The three lists are kept
# index-aligned so the position of the smallest error identifies the matching
# config and weights.
errors = []
configs = []
state_dicts = []

for state_dict, config in state_dicts_and_configs:
    model, config = create_model(config)
    model.load_state_dict(state_dict)
    test_loader = get_test_loader(config)
    # evaluate_model returns (avg_error, ground_truth, predictions). Keep only
    # the scalar error: appending the whole tuple would hand np.argmin a
    # ragged list of tuples instead of one number per trial.
    error = evaluate_model(model, test_loader)[0]
    errors.append(error)
    configs.append(config)
    state_dicts.append(state_dict)

100%|██████████| 13/13 [00:00<00:00, 30.93it/s]
100%|██████████| 25/25 [00:00<00:00, 83.78it/s]
100%|██████████| 13/13 [00:00<00:00, 57.80it/s]
100%|██████████| 25/25 [00:00<00:00, 172.71it/s]
100%|██████████| 13/13 [00:00<00:00, 48.84it/s]
100%|██████████| 25/25 [00:00<00:00, 114.91it/s]
100%|██████████| 25/25 [00:00<00:00, 88.69it/s]
100%|██████████| 25/25 [00:00<00:00, 72.59it/s]
100%|██████████| 13/13 [00:00<00:00, 38.64it/s]
100%|██████████| 13/13 [00:00<00:00, 64.56it/s]
100%|██████████| 13/13 [00:00<00:00, 68.89it/s]
100%|██████████| 13/13 [00:00<00:00, 67.83it/s]
100%|██████████| 13/13 [00:00<00:00, 65.01it/s]
100%|██████████| 13/13 [00:00<00:00, 48.00it/s]
100%|██████████| 25/25 [00:00<00:00, 95.52it/s]
100%|██████████| 13/13 [00:00<00:00, 60.91it/s]
100%|██████████| 13/13 [00:00<00:00, 52.92it/s]
100%|██████████| 25/25 [00:00<00:00, 129.63it/s]
100%|██████████| 13/13 [00:00<00:00, 145.96it/s]
100%|██████████| 25/25 [00:00<00:00, 75.65it/s]
100%|██████████| 13/13 [00:00<00:00,

In [15]:
# Position of the smallest recorded error; errors, configs and state_dicts
# are index-aligned, so the same position selects the matching entries.
best_error_index = np.asarray(errors).argmin()
best_config, best_state_dict = configs[best_error_index], state_dicts[best_error_index]

# Rebuild the winning architecture from its config and restore its weights.
best_model, _ = create_model(best_config)
best_model.load_state_dict(best_state_dict)

<All keys matched successfully>

In [16]:
# Re-run the selected model on its test split to collect the flattened
# targets and outputs alongside the average error.
test_loader = get_test_loader(best_config)
avg_error, ground_truth, predictions = evaluate_model(
    model=best_model,
    test_loader=test_loader,
)


100%|██████████| 13/13 [00:00<00:00, 55.79it/s]


In [20]:
# Overlay the flattened targets and model outputs, each plotted against its
# position in the evaluation order. plotly.graph_objects (go) is imported in
# the notebook's top import cell.
fig = go.Figure()

for trace_name, values in (("Ground truth", ground_truth), ("Predictions", predictions)):
    fig.add_trace(go.Scatter(
        x=list(range(len(values))),
        y=values,
        mode='markers',
        name=trace_name,
    ))

fig.update_layout(
    title="Ground truth vs Predictions",
    xaxis_title="Index",
    yaxis_title="Value",
    legend_title="Legend",
    font=dict(
        family="Courier New, monospace",
        size=18,
        color="RebeccaPurple"
    )
)

fig.show()
