In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from dysts.metrics import compute_metrics, smape
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

from dystformer.chronos.pipeline import ChronosPipeline
from dystformer.patchtst.pipeline import PatchTSTPipeline
from dystformer.utils import safe_standardize

In [None]:
# Pretrained PatchTST checkpoint, loaded with mode="predict" onto the second GPU.
PFT_CHECKPOINT_PATH = "/stor/work/AMDG_Gilpin_Summer2024/checkpoints/pft_chattn_emb_w_poly-0/checkpoint-final"
pft_model = PatchTSTPipeline.from_pretrained(
    mode="predict", pretrain_path=PFT_CHECKPOINT_PATH, device_map="cuda:1"
)

In [None]:
# Chronos baseline checkpoint, loaded in float32 onto the fourth GPU.
CHRONOS_CHECKPOINT_PATH = "/stor/work/AMDG_Gilpin_Summer2024/checkpoints/chronos_bolt_mini-12/checkpoint-final"
chronos = ChronosPipeline.from_pretrained(
    CHRONOS_CHECKPOINT_PATH, device_map="cuda:3", torch_dtype=torch.float32
)

In [None]:
def forecast(
    model,
    context: np.ndarray,
    prediction_length: int,
    transpose: bool = False,
    standardize: bool = False,
    **kwargs,
) -> np.ndarray:
    """
    Forecast the continuation of `context` with `model`.

    Args:
        model: Object exposing `predict(tensor, prediction_length, **kwargs)`
            that returns a torch tensor.
        context: The context to forecast (n_timesteps, n_features). It is
            copied, never modified in place.
        prediction_length: The length of the prediction.
        transpose: Whether to hand the model the transposed context
            (n_features, n_timesteps); the prediction is transposed back
            before it is returned.
        standardize: Whether to standardize the context along axis 0 (time)
            before predicting and to de-normalize the prediction with the
            statistics of the original context afterwards.
        **kwargs: Forwarded unchanged to `model.predict`.

    Returns:
        The forecasted data (prediction_length, n_features). All singleton
        axes of the model output are squeezed, so a univariate forecast comes
        back one-dimensional.
    """
    model_input = context.copy()
    if standardize:
        # Standardize while the array is still (n_timesteps, n_features) so the
        # axis used here matches the axis used for de-normalization below.
        # Doing this after the transpose would normalize across features.
        model_input = safe_standardize(model_input, axis=0)
    if transpose:
        model_input = model_input.T

    context_tensor = torch.from_numpy(model_input).float()
    pred = (
        model.predict(context_tensor, prediction_length, **kwargs)
        .squeeze()
        .cpu()
        .numpy()
    )
    if transpose:
        # Back to (prediction_length, n_features); a no-op for 1-D output.
        pred = pred.T

    if standardize:
        pred = safe_standardize(pred, axis=0, context=context, denormalize=True)

    return pred


def plot_model_prediction(
    model,
    data: np.ndarray,
    context_length: int,
    prediction_length: int,
    transpose: bool = True,
    save_path: str | None = None,
    title: str | None = None,
    show: bool = True,
    **kwargs,
):
    """
    Forecast the continuation of a multivariate series, print error metrics
    and (optionally) plot context, ground truth and prediction.

    Args:
        model: The model handed to `forecast`.
        data: Channel-first series (n_features, n_timesteps) holding at least
            `context_length + prediction_length` timesteps. This is the layout
            every call site in this notebook passes.
        context_length: Number of leading timesteps used as context.
        prediction_length: Number of timesteps to forecast and compare.
        transpose: Forwarded to `forecast`. True feeds the model a
            (n_timesteps, n_features) array, False feeds it the channel-first
            array unchanged.
        save_path: If given, the figure is written there instead of shown.
        title: Optional figure title. Consumed here so it is not forwarded to
            `model.predict`.
        show: If False and `save_path` is None, no figure is created at all.
        **kwargs: Forwarded to `forecast` (and from there to `model.predict`).

    Returns:
        The prediction as (prediction_length, n_features), or a 1-D array of
        length `prediction_length` when `data` has a single feature.

    Raises:
        ValueError: If `data` is not 2-D or is shorter than
            `context_length + prediction_length`.
    """
    series = np.asarray(data)
    if series.ndim != 2:
        raise ValueError(
            f"data must be 2-D (n_features, n_timesteps), got shape {series.shape}"
        )
    n_features, n_timesteps = series.shape
    total_length = context_length + prediction_length
    if n_timesteps < total_length:
        raise ValueError(
            f"data has {n_timesteps} timesteps but "
            f"context_length + prediction_length = {total_length}"
        )

    context = series[:, :context_length]
    groundtruth = series[:, context_length:total_length]
    raw_prediction = forecast(model, context, prediction_length, transpose, **kwargs)
    # forecast() squeezes singleton axes; restore (n_features, prediction_length).
    prediction = np.asarray(raw_prediction).reshape(n_features, -1)

    # Metrics get time-major arrays, matching the other compute_metrics calls
    # in this notebook, which also use the `include` keyword.
    metrics = compute_metrics(
        prediction.T, groundtruth.T, include=["mse", "mae", "smape"]
    )
    print(metrics)

    result = prediction.T if n_features > 1 else prediction[0]
    if not show and save_path is None:
        return result

    # One extra context sample so the context line joins the forecast segment.
    n_context_points = min(context_length + 1, n_timesteps)
    context_ts = np.arange(n_context_points) / total_length
    pred_ts = np.arange(context_length, total_length) / total_length
    fig = plt.figure(figsize=(15, 4))

    n_panels = min(3, n_features)
    if n_features >= 3:
        outer_grid = fig.add_gridspec(1, 2, width_ratios=[0.5, 0.5], wspace=0.05)
        ax_3d = fig.add_subplot(outer_grid[0], projection="3d")
        ax_3d.plot(*context[:3], alpha=0.5, color="black", label="Context")
        ax_3d.plot(*groundtruth[:3], linestyle="--", color="black", label="Groundtruth")
        ax_3d.plot(*prediction[:3], color="red", label="Prediction")
        ax_3d.legend(loc="upper right", fontsize=12)
        ax_3d.set_xlabel("$x_1$")
        ax_3d.set_ylabel("$x_2$")
        ax_3d.set_zlabel("$x_3$")
        panel_spec = outer_grid[1]
    else:
        # Fewer than three features: no phase portrait, panels fill the figure.
        panel_spec = fig.add_gridspec(1, 1)[0]

    gs = panel_spec.subgridspec(
        n_panels, 1, height_ratios=[1 / n_panels] * n_panels, wspace=0, hspace=0
    )
    axes_1d = [fig.add_subplot(gs[i, 0]) for i in range(n_panels)]
    for i, ax in enumerate(axes_1d):
        ax.plot(context_ts, series[i, :n_context_points], alpha=0.5, color="black")
        ax.plot(pred_ts, groundtruth[i], linestyle="--", color="black")
        ax.plot(pred_ts, prediction[i], color="red")
        ax.set_ylabel(f"$x_{i + 1}$")
        ax.set_aspect("auto")
    axes_1d[-1].set_xlabel("Time")

    if title is not None:
        fig.suptitle(title)

    if save_path is None:
        plt.show()
    else:
        plt.savefig(save_path)
    plt.close(fig)
    return result


# Double Pendulum

In [None]:
SPLIT = "train"
INDEX = 0
WORK = os.environ.get("WORK", "")
base_dir = f"{WORK}/physics-datasets"
fpath = f"{base_dir}/double_pendulum_chaotic/train_and_test_split/dpc_dataset_traintest_4_200_csv/{SPLIT}/{INDEX}.csv"
pendulum_data = np.loadtxt(fpath)
print(pendulum_data.shape)

# The data is non-stationary: keep every 10th row of the last four columns,
# then take first differences between consecutive rows to remove the trend.
subsampled_pendulum_data = pendulum_data[::10, -4:]
subsampled_pendulum_diff = (
    subsampled_pendulum_data[1:] - subsampled_pendulum_data[:-1]
)

# Columns (0, 1), (2, 3) and (4, 5) are plotted as pairs: the odd column on
# the x-axis against the negated even column on the y-axis.
pivot_a, pivot_b, first_a, first_b, second_a, second_b = pendulum_data[:, :6].T

## The position of the pivot point (mostly constant)
plt.plot(pivot_b, -pivot_a)

## The position of the tip of the first pendulum
plt.plot(first_b, -first_a)

## The position of the tip of the second pendulum
plt.plot(second_b, -second_a)

In [None]:
# Forecast 128 future first-differences from the first 512 differences.
diff_prediction = forecast(
    pft_model,
    subsampled_pendulum_diff[:512],
    128,
    standardize=True,
    limit_prediction_length=False,
    sliding_context=True,
)

# diff[k] = data[k + 1] - data[k], so the 512 context differences end at
# data[512]. Integrating the forecast differences from that anchor therefore
# reconstructs data[513], data[514], ... (one step after the anchor).
pft_prediction = subsampled_pendulum_data[512] + diff_prediction.cumsum(axis=0)

# Score against the samples the forecast actually corresponds to; comparing
# with data[512:640] would shift the ground truth by one step.
pft_target = subsampled_pendulum_data[513 : 513 + 128]
pft_metrics = compute_metrics(
    pft_prediction[: len(pft_target)],
    pft_target,
    include=["mse", "mae", "smape"],
)
print(pft_metrics)

In [None]:
# Same differenced forecast with Chronos, which is fed the transposed
# (n_features, n_timesteps) context.
diff_prediction = forecast(
    chronos,
    subsampled_pendulum_diff[:512],
    128,
    standardize=False,
    transpose=True,
    limit_prediction_length=False,
    num_samples=1,
)
# diff[k] = data[k + 1] - data[k]: integrating from the anchor data[512]
# reconstructs data[513], data[514], ... (one step after the anchor).
chronos_prediction = subsampled_pendulum_data[512] + diff_prediction.cumsum(axis=0)

# Score against the samples the forecast actually corresponds to; comparing
# with data[512:640] would shift the ground truth by one step.
chronos_target = subsampled_pendulum_data[513 : 513 + 128]
chronos_metrics = compute_metrics(
    chronos_prediction[: len(chronos_target)],
    chronos_target,
    include=["mse", "mae", "smape"],
)
print(chronos_metrics)

In [None]:
plt.figure(figsize=(10, 10))

# (x column, y column, extra style for the model forecasts) for the tip of the
# first and of the second pendulum; the y column is negated when drawn.
tip_specs = [
    (1, 0, {"linestyle": "--"}),
    (3, 2, {}),
]
for x_col, y_col, forecast_style in tip_specs:
    # context (faded), ground truth over the horizon (dashed), then both models
    plt.plot(
        subsampled_pendulum_data[:512, x_col],
        -subsampled_pendulum_data[:512, y_col],
        alpha=0.5,
        color="black",
    )
    plt.plot(
        subsampled_pendulum_data[512 : 512 + 128, x_col],
        -subsampled_pendulum_data[512 : 512 + 128, y_col],
        color="black",
        linestyle="--",
    )
    plt.plot(
        pft_prediction[:, x_col],
        -pft_prediction[:, y_col],
        color="blue",
        **forecast_style,
    )
    plt.plot(
        chronos_prediction[:, x_col],
        -chronos_prediction[:, y_col],
        color="red",
        **forecast_style,
    )

# Inset bar chart comparing the two models (blue: PFT, red: Chronos).
axins = inset_axes(plt.gca(), width="40%", height="20%", loc="upper right")
width = 0.3
metrics = ["smape"]
bar_positions = np.arange(2) * width
for metric in metrics:
    axins.bar(
        bar_positions,
        [pft_metrics[metric], chronos_metrics[metric]],
        width,
        color=["blue", "red"],
        label=metric,
    )
axins.set_xticks([])

In [None]:
fig, axes = plt.subplots(3, 1, figsize=(10, 10))

context_ts = np.arange(512 + 1) / (512 + 128)
pred_ts = np.arange(512, 512 + 128) / (512 + 128)
# The forecasts are integrated from the anchor sample at index 512, so their
# first value corresponds to index 513.
forecast_ts = np.arange(513, 513 + 128) / (512 + 128)
for i, ax in enumerate(axes.flatten()):
    ax.plot(
        context_ts, subsampled_pendulum_data[: 512 + 1, i], color="black", alpha=0.5
    )
    ax.plot(
        pred_ts,
        subsampled_pendulum_data[512 : 512 + 128, i],
        color="black",
        linestyle="--",
    )
    # Same colours as the trajectory figure: blue = PFT, red = Chronos.
    ax.plot(forecast_ts, pft_prediction[:, i], color="blue")
    ax.plot(forecast_ts, chronos_prediction[:, i], color="red")

plt.show()


# Eigenworms

In [None]:
INDEX = 4
fpath = f"{base_dir}/worm_behavior/data/worm_{INDEX}.pkl"
# allow_pickle=True unpickles the file contents; only load trusted files this way.
worm_data = np.load(fpath, allow_pickle=True)

# Fill NaNs column by column with linear interpolation over the sample index.
time_idx = np.arange(len(worm_data))
for d in range(worm_data.shape[1]):
    missing = np.isnan(worm_data[:, d])
    if not missing.any():
        continue
    observed = ~missing
    worm_data[:, d] = np.interp(
        time_idx, time_idx[observed], worm_data[observed, d]
    )
assert not np.isnan(worm_data).any()

print(worm_data.shape)

In [None]:
# 3-D trajectory of the first three columns over the first 1000 samples.
ax = plt.axes(projection="3d")
ax.plot3D(*worm_data[:1000, :3].T);

In [None]:
start = 0
stride = 1
# Subsample the rows, standardize along axis 0, and pass the transposed array
# to the plotting helper.
subsampled_worm_data = worm_data[start::stride, :]
stand_subsampled_worm_data = safe_standardize(subsampled_worm_data, axis=0)
_ = plot_model_prediction(
    model=pft_model,
    data=stand_subsampled_worm_data.T,
    context_length=512,
    prediction_length=512,
    title="Eigenworm",
    sliding_context=True,
    limit_prediction_length=False,
)

In [None]:
start = 0
# `stride` is the value set in the PFT eigenworm cell.
subsampled_worm_data = worm_data[start::stride, :]
stand_subsampled_worm_data = safe_standardize(subsampled_worm_data, axis=0)
_ = plot_model_prediction(
    model=chronos,
    data=stand_subsampled_worm_data.T,
    context_length=512,
    prediction_length=512,
    transpose=False,
    title="Eigenworm",
    limit_prediction_length=False,
    num_samples=1,
)

# Turbulent Boundary Layer

In [None]:
# allow_pickle=True unpickles the file contents; only load trusted files this way.
turbpca_path = f"{base_dir}/turbulence/BLexp_Re980_pca10.pkl"
turbpca_data = np.load(turbpca_path, allow_pickle=True)
print(turbpca_data.shape)

In [None]:
# 3-D trajectory of the first three columns over all samples.
ax = plt.axes(projection="3d")
ax.plot3D(*turbpca_data[:, :3].T)

In [None]:
start = 0
stride = 1
# Subsample the rows, standardize along axis 0, and pass the transposed array
# to the plotting helper.
subsampled_turbpca_data = turbpca_data[start::stride, :]
stand_subsampled_turbpca_data = safe_standardize(subsampled_turbpca_data, axis=0)
_ = plot_model_prediction(
    model=pft_model,
    data=stand_subsampled_turbpca_data.T,
    context_length=512,
    prediction_length=512,
    title="Turbulent Boundary Layer PCA modes",
    sliding_context=True,
    limit_prediction_length=False,
)

In [None]:
# `start` and `stride` are the values set in the PFT turbulence cell.
subsampled_turbpca_data = turbpca_data[start::stride, :]
stand_subsampled_turbpca_data = safe_standardize(subsampled_turbpca_data, axis=0)
_ = plot_model_prediction(
    model=chronos,
    data=stand_subsampled_turbpca_data.T,
    context_length=512,
    prediction_length=512,
    transpose=False,
    title="Turbulent Boundary Layer PCA modes",
    limit_prediction_length=False,
    num_samples=1,
)

# ECG

In [None]:
fpath = f"{base_dir}/electrocardiogram/ecg_train.csv.gz"
# np.loadtxt decompresses .gz files before parsing the comma-separated values.
ecg_data = np.loadtxt(fpath, delimiter=",")
print(np.shape(ecg_data))

In [None]:
# Delay-coordinate view: the signal against itself shifted by one and two samples.
ax = plt.axes(projection="3d")
ecg_delay_coords = [ecg_data[shift : shift + 1000] for shift in range(3)]
ax.plot3D(*ecg_delay_coords)

In [None]:
context_length = 512
prediction_length = 512
start = 0
stride = 1
subsampled_ecg_data = ecg_data[start::stride]
stand_subsampled_ecg_data = safe_standardize(subsampled_ecg_data, axis=0)

# Only the forecast array is needed here, so call forecast() directly with a
# (n_timesteps, 1) context instead of going through the plotting helper.
standpred = np.ravel(
    forecast(
        pft_model,
        stand_subsampled_ecg_data[:context_length, None],
        prediction_length,
        limit_prediction_length=False,
    )
)

ecg_context = subsampled_ecg_data[:context_length]
ecg_target = subsampled_ecg_data[context_length : context_length + prediction_length]
# NOTE(review): the series was standardized with statistics of the whole
# recording but is de-normalized with context-only statistics — confirm this
# mismatch is intended.
pred = safe_standardize(standpred, axis=0, context=ecg_context, denormalize=True)

# Mean squared error; norm(...) / N is neither MSE nor RMSE.
mse = np.mean((pred - ecg_target) ** 2)
# Halved as in the original cell (presumably to rescale sMAPE — confirm).
smape_error = smape(pred, ecg_target) / 2

context_ts = np.arange(context_length + 1) / (context_length + prediction_length)
pred_ts = np.arange(context_length, context_length + prediction_length) / (
    context_length + prediction_length
)

plt.figure(figsize=(15, 5))
plt.title(f"ECG (MSE: {mse:.4f}, SMAPE: {smape_error:.4f})")
plt.plot(
    context_ts,
    subsampled_ecg_data[: context_length + 1],
    color="black",
    alpha=0.5,
    label="context",
)
plt.plot(
    pred_ts,
    ecg_target,
    color="black",
    linestyle="--",
    label="groundtruth",
)
plt.plot(pred_ts, pred, color="red", label="prediction")
plt.legend()
plt.show()


In [None]:
# Only the forecast array is needed here, so call forecast() directly.
# Chronos is fed the transposed (1, n_timesteps) context via transpose=True.
standpred = np.ravel(
    forecast(
        chronos,
        stand_subsampled_ecg_data[:context_length, None],
        prediction_length,
        transpose=True,
        num_samples=1,
        limit_prediction_length=False,
    )
)

ecg_context = subsampled_ecg_data[:context_length]
ecg_target = subsampled_ecg_data[context_length : context_length + prediction_length]
# NOTE(review): the series was standardized with statistics of the whole
# recording but is de-normalized with context-only statistics — confirm this
# mismatch is intended.
pred = safe_standardize(standpred, axis=0, context=ecg_context, denormalize=True)

# Mean squared error; norm(...) / N is neither MSE nor RMSE.
mse = np.mean((pred - ecg_target) ** 2)
# Halved as in the original cell (presumably to rescale sMAPE — confirm).
smape_error = smape(pred, ecg_target) / 2

plt.figure(figsize=(15, 5))
plt.title(f"ECG (MSE: {mse:.4f}, SMAPE: {smape_error:.4f})")
plt.plot(
    context_ts,
    subsampled_ecg_data[: context_length + 1],
    color="black",
    alpha=0.5,
    label="context",
)
plt.plot(
    pred_ts,
    ecg_target,
    color="black",
    linestyle="--",
    label="groundtruth",
)
plt.plot(pred_ts, pred, color="red", label="prediction")
plt.legend()
plt.show()


In [None]:
def timelag_embedding(data, lag, dims):
    """
    Build a time-delay embedding of a univariate series.

    Column k of the result holds the series shifted by `k * lag` samples, so
    row t is (data[t], data[t + lag], ..., data[t + (dims - 1) * lag]).

    Args:
        data: Input data array (n_timesteps,)
        lag: Time lag (in samples) between consecutive embedding coordinates
        dims: Number of embedding coordinates (columns of the result)

    Returns:
        Float array of shape (n_timesteps - lag * (dims - 1), dims)
    """
    n_rows = data.shape[0] - lag * (dims - 1)
    embedded = np.zeros((n_rows, dims))
    shifted_copies = (data[k * lag : k * lag + n_rows] for k in range(dims))
    # Rows of the transposed view are the columns of `embedded`.
    for column, values in zip(embedded.T, shifted_copies):
        column[:] = values
    return embedded

In [None]:
# 10-dimensional delay embedding of the ECG with lag 1, standardized along axis 0.
ecg_lagged = timelag_embedding(ecg_data, 1, 10)
stand_ecg_lagged = safe_standardize(ecg_lagged, axis=0)
predictions = plot_model_prediction(
    model=pft_model,
    data=stand_ecg_lagged.T,
    context_length=512,
    prediction_length=512,
    title="ECG Time-lagged embedding",
    sliding_context=False,
    limit_prediction_length=False,
)

In [None]:
# Same embedded ECG series, forecast with Chronos (no transpose before the model).
predictions = plot_model_prediction(
    model=chronos,
    data=stand_ecg_lagged.T,
    context_length=512,
    prediction_length=512,
    transpose=False,
    title="ECG Time-lagged embedding",
    num_samples=1,
    limit_prediction_length=False,
)

In [None]:
def lag_error_scaling(
    model,
    data: np.ndarray,
    context_length: int,
    prediction_length: int,
    ntrials: int,
    lags: list[int],
    dims: list[int],
    **kwargs,
):
    """
    Average forecast error over a grid of delay-embedding lags and dimensions.

    For every (lag, dim) pair the univariate `data` is delay-embedded, `ntrials`
    random start positions are drawn, and the model forecasts
    `prediction_length` steps from `context_length` steps of standardized
    context. The error of one trial is the norm of (prediction - ground truth)
    divided by `prediction_length`.

    Args:
        model: The model handed to `forecast`.
        data: Univariate series (n_timesteps,)
        context_length: Number of context timesteps per trial.
        prediction_length: Number of forecast timesteps per trial.
        ntrials: Number of random windows per (lag, dim) pair.
        lags: Lags to sweep (rows of the result).
        dims: Embedding dimensions to sweep (columns of the result).
        **kwargs: Forwarded to `forecast`. `transpose` keeps the meaning it has
            for `plot_model_prediction` (default True): the context is handed to
            `forecast` channel-first, and `transpose=True` flips it to
            (n_timesteps, n_features) for the model.

    Returns:
        Array of shape (len(lags), len(dims)) with the mean error per pair.

    Raises:
        ValueError: If an embedded series is too short to hold one window of
            `context_length + prediction_length` samples plus a start offset.
    """
    # forecast() is called directly: only the prediction array is needed, and
    # no figure should be produced inside the sweep.
    transpose = kwargs.pop("transpose", True)
    window_length = context_length + prediction_length

    errors = np.zeros((len(lags), len(dims)))
    for i, lag in enumerate(lags):
        for j, dim in enumerate(dims):
            print(f"{lag=}, {dim=}")
            embedded_data = timelag_embedding(data, lag, dim)
            n_starts = len(embedded_data) - window_length
            if n_starts <= 0:
                raise ValueError(
                    f"embedded series of length {len(embedded_data)} is too short "
                    f"for context_length + prediction_length = {window_length}"
                )
            for start in np.random.randint(0, n_starts, size=ntrials):
                # NOTE(review): statistics come from the whole remainder of the
                # series (including the forecast horizon), as in the original
                # code — confirm this is intended.
                context = embedded_data[start:]
                stand_context = safe_standardize(context, axis=0)
                stand_predictions = forecast(
                    model,
                    stand_context[:context_length].T,
                    prediction_length,
                    transpose=transpose,
                    **kwargs,
                )
                # forecast() squeezes singleton axes; restore (dim, horizon)
                # and flip to (horizon, dim) to line up with embedded_data.
                stand_predictions = (
                    np.asarray(stand_predictions).reshape(dim, -1).T
                )
                predictions = safe_standardize(
                    stand_predictions, axis=0, context=context, denormalize=True
                )
                target = embedded_data[
                    start + context_length : start + window_length
                ]
                errors[i, j] += (
                    np.linalg.norm(predictions - target) / prediction_length
                )
    errors /= ntrials
    return errors

In [None]:
# Sweep lags 1..10 against embedding dimensions 1..10 with the PFT model.
lags = np.arange(1, 11)
dims = np.arange(1, 11)
errors = lag_error_scaling(
    model=pft_model,
    data=ecg_data,
    context_length=512,
    prediction_length=512,
    ntrials=10,
    lags=lags,
    dims=dims,
    sliding_context=False,
    limit_prediction_length=False,
)

In [None]:
# Heat map of the mean error: rows are lags, columns are embedding dimensions.
fig, ax = plt.subplots()
heatmap = ax.imshow(errors)
ax.set_yticks(range(len(lags)))
ax.set_yticklabels(lags)
ax.set_xticks(range(len(dims)))
ax.set_xticklabels(dims)
ax.set_ylabel("Lag")
ax.set_xlabel("Dimensions")
ax.set_title("ECG Time-lagged embedding per-dimension error scaling")
fig.colorbar(heatmap, ax=ax, shrink=0.5)
plt.show()

In [None]:
# Same lag/dimension sweep with Chronos; overwrites `errors` from the PFT sweep.
errors = lag_error_scaling(
    model=chronos,
    data=ecg_data,
    context_length=512,
    prediction_length=512,
    ntrials=10,
    lags=lags,
    dims=dims,
    transpose=False,
    num_samples=1,
    limit_prediction_length=False,
)

In [None]:
# Heat map of the mean error: rows are lags, columns are embedding dimensions.
fig, ax = plt.subplots()
heatmap = ax.imshow(errors)
ax.set_yticks(range(len(lags)))
ax.set_yticklabels(lags)
ax.set_xticks(range(len(dims)))
ax.set_xticklabels(dims)
ax.set_ylabel("Lag")
ax.set_xlabel("Dimensions")
ax.set_title("ECG Time-lagged embedding per-dimension error scaling")
fig.colorbar(heatmap, ax=ax, shrink=0.5)
plt.show()