In [None]:
import json

# Read the experiment configuration; the path is relative to the kernel's
# working directory.
with open("./params.json", mode = "r", encoding = "utf-8") as f:
    data = json.load(f)

# Model checkpoint and dataset locations.
model_path = data["model_path"]
dataset_path_train = data["dataset_path"]["train"]
dataset_path_test = data["dataset_path"]["test"]

# Windowing and feature selection.
num_single_sample_timesteps = data["num_single_sample_timesteps"]
input_window_length = data["input_window_length"]
label_window_length = data["label_window_length"]
input_features = data["input_features"]
label_features = data["label_features"]

# Settings local to this notebook (not read from params.json).
# Usually window_stride = 1 since we want to check each input window
window_stride = 20
seed_val = 0

In [None]:
import torch
import random
import numpy as np

# Prefer the default CUDA device when one is available, otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed Python's, PyTorch's and NumPy's global RNGs so reruns are comparable.
# np.random.seed stays last: it returns None, so the cell displays nothing.
random.seed(seed_val)
torch.manual_seed(seed_val)
np.random.seed(seed_val)

In [None]:
from tqdm import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

from utils.pipeline.Model import TimeSeriesHuggingFaceTransformer
from utils.pipeline.Data import get_mean_std_respected_temporal, WindowedIterableDataset
from utils.pipeline.Run import autoregress

In [None]:
# Normalisation statistics computed on the training split with the same
# windowing parameters as the test loader below.
# NOTE(review): the prediction cell indexes this as stats[window_idx, "<feat>_mean"],
# so it is presumably keyed by window position — confirm in utils.pipeline.Data.
stats = get_mean_std_respected_temporal(
    dataset_path = dataset_path_train,
    cols = input_features,
    num_single_sample_timesteps = num_single_sample_timesteps,
    input_window_len = input_window_length,
    label_window_len = label_window_length,
    window_stride = window_stride
)

# Windowed test set in inference mode; it yields (batch_x, batch_y, x_labels)
# triples, as unpacked by the prediction loop.
df_test = WindowedIterableDataset(
    dataset_path = dataset_path_test,
    stats = stats,
    input_features = input_features,
    label_features = label_features,
    num_single_sample_timesteps = num_single_sample_timesteps,
    stride = window_stride,
    input_window_length = input_window_length,
    label_window_length = label_window_length,
    inference = True
)

data_loader_test = DataLoader(
    df_test,
    batch_size = 1,    # One windowed datapoint at a time
    # Pinned host memory only speeds up host->GPU copies; on a CPU-only
    # machine PyTorch warns and ignores it, so enable it only for CUDA.
    pin_memory = device.type == "cuda"
)

## Prediction

In [None]:
# Load the full serialised model and switch it to inference mode.
# map_location lets a checkpoint saved on a GPU load on a CPU-only kernel
# (without it, torch.load tries to restore tensors onto their original device).
# NOTE(review): weights_only=False unpickles arbitrary Python objects and can
# execute code — only load checkpoints from a trusted source. It appears to be
# needed because the checkpoint stores the whole model object, not a state_dict.
model = torch.load(model_path, map_location = device, weights_only = False).to(device)
model.eval()

In [None]:
# Run autoregressive prediction on every window of one test timeseries, plot
# ground truth vs. prediction for `feature`, and print the top attended input
# timesteps for the first outputs of each window.
test_loss = 0.0
test_progress_bar = tqdm(
    data_loader_test
)

criterion = torch.nn.MSELoss()

# --- What to inspect -------------------------------------------------------
target_timeseries_idx = 0      # which test timeseries to visualise
feature = "u_list"             # label feature to plot
figure_range = 1               # symmetric y-axis limit of the plot
num_attention_outputs = 20     # first N predictions following the input sequence
top_k = 10                     # highest-attention input timesteps reported per output

# Sliding windows of W = input + label timesteps over a series of T timesteps
# with stride S give floor((T - W) / S) + 1 windows. The former
# `1 + (T - W + 1) // S` over-counted by one whenever (T - W + 1) was a multiple
# of S (always for S = 1). That mis-assigned datapoints to timeseries and
# produced a final window whose label slice ran past the end of the series.
# NOTE(review): assumes WindowedIterableDataset yields window starts
# 0, S, 2S, ... <= T - W for each series, in order — confirm in utils.pipeline.Data.
window_total_length = input_window_length + label_window_length
num_datapoints_per_timeseries = (num_single_sample_timesteps - window_total_length) // window_stride + 1

# Datapoints of one timeseries are contiguous in the loader, so the target
# series occupies the half-open index range [first_datapoint_idx, end_datapoint_idx).
first_datapoint_idx = target_timeseries_idx * num_datapoints_per_timeseries
end_datapoint_idx = first_datapoint_idx + num_datapoints_per_timeseries

# Loop-invariant plotting setup.
x = list(range(num_single_sample_timesteps))
feature_idx = label_features.index(feature)
sns.set_theme(style = "whitegrid")

with torch.no_grad():
    for datapoint_idx, (batch_x, batch_y, x_labels) in enumerate(test_progress_bar):
        if datapoint_idx < first_datapoint_idx:
            continue
        if datapoint_idx >= end_datapoint_idx:
            break  # past the target timeseries; no need to drain the loader

        # Position of this window within its timeseries and its timestep spans.
        window_idx = datapoint_idx - first_datapoint_idx
        window_start = window_idx * window_stride
        label_start = window_start + input_window_length
        label_end = label_start + label_window_length

        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        preds = autoregress(
            model = model,
            batch_x = batch_x,
            batch_y = batch_y,
            device = device,
            extract_attention = True
        )

        loss = criterion(preds, batch_y)
        test_progress_bar.set_postfix({
            "single_test_loss": f"{loss.item():.6f}"
        })

        feature_label = batch_y[0, :, feature_idx].cpu()
        feature_pred = preds[0, :, feature_idx].cpu()
        feature_x_labels = x_labels[0, :, feature_idx]

        # Undo the normalisation using the statistics for this window position.
        feature_std = stats[window_idx, f"{feature}_std"]
        feature_mean = stats[window_idx, f"{feature}_mean"]
        feature_pred = (feature_pred * feature_std) + feature_mean
        feature_label = (feature_label * feature_std) + feature_mean
        # No normalization on x_labels in WindowedIterableDataset!

        fig, ax = plt.subplots(figsize = (16, 8))
        ax.set_ylim(-figure_range, figure_range)

        ax.axvspan(
            x[window_start],
            x[label_start - 1],
            color = "green",
            alpha = 0.3,
            label = "Input Sequence Region"
        )

        sns.scatterplot(
            x = x,
            y = feature_x_labels,
            marker = "o",
            label = f"{feature}_label (circles)",
            color = "blue",
            ax = ax
        )

        sns.scatterplot(
            x = x[label_start:label_end],
            y = feature_pred,
            marker = "x",
            label = f"{feature}_pred (crosses)",
            color = "red",
            ax = ax
        )

        ax.set_title(f"{feature} Value Ground-Truth vs. Prediction")
        ax.set_xlabel("Timesteps")
        ax.set_ylabel(feature)
        ax.legend()

        fig.tight_layout()
        plt.show()
        plt.close(fig)  # release the figure; one is created per window

        avg_attn_vals = model.get_average_attention_values()

        # Bounded by the number of attention rows so short label windows do not
        # raise IndexError. Timesteps are printed 1-based.
        for i in range(min(num_attention_outputs, len(avg_attn_vals))):
            output_row = avg_attn_vals[i, :]
            top_k_indices = np.argsort(output_row)[::-1][:top_k]
            top_k_scores = output_row[top_k_indices]
            print(f"Output Timestep {label_start + i + 1}")
            print(f"    Input Timesteps {top_k_indices + (window_start + 1)}")
            print(f"    Scores {[f'{score:.5f}' for score in top_k_scores]}\n")

test_progress_bar.close()
