In [None]:
# Enable IPython's autoreload extension so edits to the imported project
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import numpy as np
import torch
from gluonts.dataset.common import FileDataset

from dystformer.augmentations import StandardizeTransform
from dystformer.patchtst.pipeline import PatchTSTPipeline
from dystformer.utils import plot_forecast_evaluation, plot_trajs_multivariate

In [None]:
# Load the pretrained checkpoint in "predict" mode, requesting device "cuda:0"
# and torch.float32.
# NOTE(review): the checkpoint path is absolute and machine-specific, and
# device_map="cuda:0" presumably requires a visible GPU — confirm before
# running this notebook on another machine.
pipeline = PatchTSTPipeline.from_pretrained(
    mode="predict",
    pretrain_path="/stor/work/AMDG_Gilpin_Summer2024/checkpoints/pft_chattn_noembed_pretrained_correct-0/checkpoint-final",
    device_map="cuda:0",
    torch_dtype=torch.float32,
)
# Bare expression: the notebook displays the pipeline's repr as the cell output.
pipeline

## Forecasting

In [None]:
def forecast(
    pipeline,
    trajectory: np.ndarray,
    context_length: int,
    normalize: bool = True,
    transpose: bool = False,
    prediction_length: int | None = None,
    **kwargs,
) -> np.ndarray:
    context = trajectory[:context_length]
    if normalize:
        normalizer = StandardizeTransform()
        context = normalizer(context, axis=0)

    if prediction_length is None:
        prediction_length = trajectory.shape[0] - context_length

    if transpose:
        context = context.T

    predictions = (
        pipeline.predict(
            context=torch.tensor(context).float(),
            prediction_length=prediction_length,
            limit_prediction_length=False,
            **kwargs,
        )
        .squeeze()
        .cpu()
        .numpy()
    )
    full_trajectory = np.concatenate([context, predictions], axis=1 if transpose else 0)

    if transpose:
        full_trajectory = full_trajectory.T

    if normalize:
        return normalizer(
            full_trajectory,
            axis=0,
            context=trajectory[:context_length],
            denormalize=True,
        )

    return full_trajectory

In [None]:
# System under evaluation; also passed as plot_name to the evaluation plot.
dyst_name = "YuWang2_Coullet"
# A list (despite the singular name) so that several directories can be scanned.
data_dir = [
    "/stor/work/AMDG_Gilpin_Summer2024/data/copy/final_skew40/test_zeroshot/" + dyst_name
]

In [None]:
# Collect every regular file below each configured directory.
# The iteration variable lives inside the generator expression, so the global
# `data_dir` list is not rebound to a string; re-running this cell therefore
# scans the same directories instead of iterating over a path's characters.
# The result is sorted so that data_paths[system_idx] refers to the same file
# on every run (rglob yields entries in no guaranteed order).
data_paths = sorted(
    path
    for directory in data_dir
    for path in Path(directory).rglob("*")
    if path.is_file()
)
print(data_paths)

In [None]:
# Number of files discovered under the configured data directories.
len(data_paths)

In [None]:
# Index into data_paths of the file to load.
system_idx = 0
dataset = FileDataset(
    path=data_paths[system_idx],
    freq="h",
    one_dim_target=False,  # NOTE(review): presumably keeps the target multivariate (2-D) — confirm
)
# Only the first entry of the dataset is used; its "target" field is the
# trajectory. Later cells index axis 1 of `trajectory` with time indices.
trajectory = next(iter(dataset))["target"]

In [None]:
# Inspect the dimensions of the loaded trajectory.
trajectory.shape

In [None]:
# Plot the loaded trajectory with a leading axis of size 1 added.
# NOTE(review): presumably plot_trajs_multivariate expects a batch axis — confirm.
plot_trajs_multivariate(np.expand_dims(trajectory, axis=0), show_plot=True)

In [None]:
# Sizes, in time steps, of the context window and of the forecast horizon.
context_length, prediction_length = 512, 128

# The context window spans [start_time, end_time) along the time axis.
start_time = 1024
end_time = start_time + context_length

In [None]:
# Transpose so that the window slicing in the forecast call
# (traj[start_time:end_time]) runs along axis 0.
traj = trajectory.T

In [None]:
# Forecast `prediction_length` steps from the window traj[start_time:end_time].
# That slice is at most `context_length` long, so the horizon has to be given
# explicitly: forecast's default, trajectory.shape[0] - context_length, would
# be zero here.
# `sliding_context` is not a parameter of forecast(); it is forwarded to
# pipeline.predict through **kwargs.
preds = forecast(
    pipeline,
    traj[start_time:end_time],
    context_length,
    prediction_length=prediction_length,
    normalize=True,
    sliding_context=True,
)

In [None]:
# Inspect the dimensions of the stitched context + forecast array.
preds.shape

In [None]:
# forecast() joins context and prediction along axis 0; transpose so that axis
# ends up as axis 1, matching how `trajectory` is indexed elsewhere.
predictions = preds.T

In [None]:
# Plot the stitched context + forecast with a leading axis of size 1 added.
# NOTE(review): presumably plot_trajs_multivariate expects a batch axis — confirm.
plot_trajs_multivariate(np.expand_dims(predictions, axis=0), show_plot=True)

In [None]:
# Inspect the dimensions of the transposed trajectory.
traj.shape

In [None]:
# Compare the stitched context + forecast against the ground-truth window over
# the same span, [start_time, end_time + prediction_length). Both arrays get a
# leading axis of size 1; context_length is passed as the third argument.
plot_forecast_evaluation(
    np.expand_dims(predictions, axis=0),
    np.expand_dims(trajectory[:, start_time : end_time + prediction_length], axis=0),
    context_length,
    show_plot=True,
    plot_name=dyst_name,
)