In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.lines import Line2D

from dystformer.patchtst.pipeline import PatchTSTPipeline
from dystformer.utils import safe_standardize

In [None]:
# Load the pretrained PatchTST pipeline in prediction mode onto device "cuda:1".
pft_checkpoint = "/stor/work/AMDG_Gilpin_Summer2024/checkpoints/pft_chattn_emb_w_poly-0/checkpoint-final"
pft_model = PatchTSTPipeline.from_pretrained(
    mode="predict", pretrain_path=pft_checkpoint, device_map="cuda:1"
)

In [None]:
def optimize_start_and_stride(
    model,
    data,
    context_length,
    prediction_length,
    ntrials=100,
    start_buffer_proportion=0.5,
    min_strided_pts=10,
):
    """Random search for the subsampling (start, stride) with the lowest forecast error.

    Each trial draws a start index and a stride, subsamples ``data[start::stride]``,
    feeds the first ``context_length`` points to ``model.predict`` and scores the
    forecast against the following ``prediction_length`` points with the L2
    (Frobenius) norm of the residual.

    Args:
        model: Object exposing ``device`` and ``predict(context, prediction_length)``
            returning a torch tensor.
        data: Numpy array whose first axis is time (callers in this notebook pass
            ``(n_timesteps, n_channels)``).
        context_length: Number of strided points given to the model.
        prediction_length: Number of strided points to forecast and score.
        ntrials: Number of accepted (start, stride) candidates to evaluate.
        start_buffer_proportion: Fraction of the admissible start range (counted
            from its end) that is excluded when sampling the start index.
        min_strided_pts: Divisor that caps the sampled stride at
            ``(valid_length - start) // min_strided_pts`` (exclusive).

    Returns:
        Tuple ``(best_start, best_stride, best_cost)``. With ``ntrials=0`` this is
        ``(None, None, np.inf)``.

    Raises:
        ValueError: If ``data`` has fewer than ``context_length + prediction_length``
            timesteps.
    """
    window_length = context_length + prediction_length
    valid_length = data.shape[0] - window_length
    if valid_length < 0:
        raise ValueError(
            f"data has {data.shape[0]} timesteps but context_length + "
            f"prediction_length = {window_length} are required"
        )

    start_buffer_length = int(valid_length * start_buffer_proportion)
    # np.random.randint needs high > low; clamping keeps start=0 available for
    # series that are barely longer than one window.
    start_high = max(1, valid_length - start_buffer_length)

    def satisfied(start, stride):
        # Enough strided points left for one context + prediction window.
        return (data.shape[0] - start) / stride >= window_length

    def objective(start, stride):
        strided_data = data[start::stride]
        context = (
            torch.from_numpy(strided_data[:context_length]).float().to(model.device)
        )
        future = strided_data[context_length:window_length]
        with torch.no_grad():  # pure inference, no autograd graph needed
            pred = model.predict(context, prediction_length)
        pred = pred.squeeze().detach().cpu().numpy()
        # squeeze() also drops a singleton channel axis; without restoring it,
        # (P,) - (P, 1) would broadcast to a (P, P) matrix and inflate the cost.
        pred = pred.reshape(future.shape)
        return np.linalg.norm(pred - future)

    def sample_start_and_stride():
        start = np.random.randint(0, start_high)
        # Clamp so that stride=1 can always be drawn (randint's high is exclusive).
        stride_high = max(2, (valid_length - start) // min_strided_pts)
        stride = np.random.randint(1, stride_high)
        return start, stride

    best_cost = np.inf
    best_start = None
    best_stride = None
    for _ in range(ntrials):
        # Rejection sampling: redraw until the strided series is long enough.
        start, stride = sample_start_and_stride()
        while not satisfied(start, stride):
            start, stride = sample_start_and_stride()

        cost = objective(start, stride)
        if cost < best_cost:
            best_cost = cost
            best_start = start
            best_stride = stride

    return best_start, best_stride, best_cost


In [None]:
def plot_model_prediction(
    model,
    data: np.ndarray,
    context_length: int,
    prediction_length: int,
    title: str | None = None,
    show: bool = True,
    transpose: bool = True,
    **kwargs,
) -> np.ndarray:
    context = data[:, :context_length]
    groundtruth = data[:, context_length : context_length + prediction_length]
    context_tensor = torch.from_numpy(context.T if transpose else context).float()
    pred = (
        model.predict(context_tensor, prediction_length, **kwargs)
        .squeeze()
        .cpu()
        .numpy()
    )
    if not transpose:
        pred = pred.T

    total_length = context.shape[1] + prediction_length
    context_ts = np.arange(context.shape[1]) / total_length
    pred_ts = np.arange(context.shape[1], total_length) / total_length

    if show:
        fig = plt.figure(figsize=(15, 4))

        outer_grid = fig.add_gridspec(1, 2, width_ratios=[0.5, 0.5], wspace=0.05)
        gs = outer_grid[1].subgridspec(
            3, 1, height_ratios=[1 / 3] * 3, wspace=0, hspace=0
        )
        ax_3d = fig.add_subplot(outer_grid[0], projection="3d")
        ax_3d.plot(*context[:3], alpha=0.5, color="black", label="Context")
        ax_3d.plot(*groundtruth[:3], linestyle="--", color="black", label="Groundtruth")
        ax_3d.plot(*pred.T[:3], color="red", label="Prediction")
        ax_3d.legend(loc="upper right", fontsize=12)
        ax_3d.set_xlabel("$x_1$")
        ax_3d.set_ylabel("$x_2$")
        ax_3d.set_zlabel("$x_3$")
        if title is not None:
            ax_3d.set_title(title)

        axes_1d = [fig.add_subplot(gs[i, 0]) for i in range(3)]
        for i, ax in enumerate(axes_1d):
            ax.plot(context_ts, context[i], alpha=0.5, color="black")
            ax.plot(pred_ts, groundtruth[i], linestyle="--", color="black")
            ax.plot(pred_ts, pred[:, i], color="red")
            ax.set_ylabel(f"$x_{i + 1}$")
            ax.set_aspect("auto")
        axes_1d[-1].set_xlabel("Time")

        plt.show()

    return pred


# Double Pendulum

In [None]:
# Dataset root is $WORK/physics-datasets ($WORK falls back to "" when unset).
WORK = os.environ.get("WORK", "")
base_dir = f"{WORK}/physics-datasets"

# Select one trajectory file of the chaotic double-pendulum dataset and load it.
SPLIT, INDEX = "train", 0
fpath = f"{base_dir}/double_pendulum_chaotic/train_and_test_split/dpc_dataset_traintest_4_200_csv/{SPLIT}/{INDEX}.csv"
pendulum_data = np.loadtxt(fpath)
print(pendulum_data.shape)

In [None]:
# Each tracked point occupies a pair of adjacent columns. As in the original
# plot, the second column of the pair goes on the horizontal axis and the
# negated first column on the vertical axis.
pendulum_parts = {
    "Pivot point (mostly constant)": (0, 1),
    "Tip of the first pendulum": (2, 3),
    "Tip of the second pendulum": (4, 5),
}

fig, ax = plt.subplots()
for part_label, (vertical_col, horizontal_col) in pendulum_parts.items():
    ax.plot(
        pendulum_data[:, horizontal_col],
        -pendulum_data[:, vertical_col],
        label=part_label,
    )
ax.set_title("Double pendulum: tracked positions")
ax.legend();

In [None]:
# Forecast the two pendulum tips (last four columns). The start/stride search
# must run on the same columns that are forecast below.
pendulum_tips = pendulum_data[:, -4:]

best_start, best_stride, best_cost = optimize_start_and_stride(
    pft_model, pendulum_tips, 512, 128, ntrials=500, start_buffer_proportion=0.4
)
print(best_start, best_stride, best_cost)
subsampled_pendulum_data = pendulum_tips[best_start::best_stride]
print(subsampled_pendulum_data.shape, pendulum_data.shape)

# Standardise with the context statistics only, so that the de-standardisation
# below (which uses the same context) is the exact inverse transform.
pendulum_context = subsampled_pendulum_data[:512]
stand_pendulum_context = safe_standardize(pendulum_context, axis=0)
standpred = plot_model_prediction(
    pft_model, stand_pendulum_context.T, 512, 128, show=False
)
pred = safe_standardize(
    standpred, axis=0, context=pendulum_context, denormalize=True
)

plt.figure()

## The position of the tip of the first pendulum
plt.plot(
    subsampled_pendulum_data[:512, 1],
    -subsampled_pendulum_data[:512, 0],
    alpha=0.5,
    color="black",
)
plt.plot(
    subsampled_pendulum_data[512 : 512 + 128, 1],
    -subsampled_pendulum_data[512 : 512 + 128, 0],
    color="blue",
    linestyle="--",
)
plt.plot(pred[:, 1], -pred[:, 0], color="blue")

## The position of the tip of the second pendulum
plt.plot(
    subsampled_pendulum_data[:512, 3],
    -subsampled_pendulum_data[:512, 2],
    alpha=0.5,
    color="black",
)
plt.plot(
    subsampled_pendulum_data[512 : 512 + 128, 3],
    -subsampled_pendulum_data[512 : 512 + 128, 2],
    linestyle="--",
    color="black",
)
plt.plot(pred[:, 3], -pred[:, 2], color="red")


# Proxy artists: the legend describes line styles rather than individual curves.
legend_elements = [
    Line2D([0], [0], color="black", alpha=0.5, label="Context"),
    Line2D([0], [0], color="black", linestyle="--", label="Ground Truth"),
    Line2D([0], [0], color="red", label="Prediction"),
]
plt.legend(handles=legend_elements)

plt.title("Double Pendulum");

In [None]:
# One panel per forecast channel, so that no coordinate is silently dropped.
n_channels = pred.shape[1]
fig, axes = plt.subplots(n_channels, 1, figsize=(10, 10))
axes = np.atleast_1d(axes)

# Time axes are normalised so that context + prediction spans [0, 1).
context_ts = np.arange(512) / (512 + 128)
pred_ts = np.arange(512, 512 + 128) / (512 + 128)
for i, ax in enumerate(axes):
    ax.plot(context_ts, subsampled_pendulum_data[:512, i], color="black", alpha=0.5)
    ax.plot(
        pred_ts,
        subsampled_pendulum_data[512 : 512 + 128, i],
        color="black",
        linestyle="--",
    )
    ax.plot(pred_ts, pred[:, i], color="red")
    ax.set_ylabel(f"$x_{i + 1}$")
axes[-1].set_xlabel("Time")

plt.show()


# Eigenworms

In [None]:
# Load one worm-behaviour recording. The .pkl file is read through np.load with
# allow_pickle=True, which unpickles the file -- only point this at trusted data.
# NOTE: this rebinds the INDEX global that the pendulum loader also sets.
INDEX = 0
fpath = f"{base_dir}/worm_behavior/data/worm_{INDEX}.pkl"
worm_data = np.load(fpath, allow_pickle=True)
print(worm_data.shape)

In [None]:
# Phase portrait of the first 1000 samples in the first three columns.
ax = plt.axes(projection="3d")
ax.plot3D(*worm_data[:1000, :3].T);

In [None]:
# Fill NaN gaps in place, column by column, by linear interpolation over the
# sample index (np.interp holds the edge values for leading/trailing gaps).
sample_idx = np.arange(len(worm_data))
nan_mask = np.isnan(worm_data)
for col in np.flatnonzero(nan_mask.any(axis=0)):
    observed = ~nan_mask[:, col]
    worm_data[:, col] = np.interp(
        sample_idx, sample_idx[observed], worm_data[observed, col]
    )
assert not np.isnan(worm_data).any()

# Search for a good subsampling, then forecast the standardised, strided series.
best_start, best_stride, best_cost = optimize_start_and_stride(
    model=pft_model,
    data=worm_data,
    context_length=512,
    prediction_length=128,
    ntrials=500,
    start_buffer_proportion=0.10,
    min_strided_pts=100,
)
print(best_start, best_stride, best_cost)
subsampled_worm_data = worm_data[best_start::best_stride]
stand_subsampled_worm_data = safe_standardize(subsampled_worm_data, axis=0)
_ = plot_model_prediction(
    model=pft_model,
    data=stand_subsampled_worm_data.T,
    context_length=512,
    prediction_length=128,
    title="Eigenworm",
)

# Turbulent Boundary Layer

In [None]:
# Load the turbulence file. allow_pickle=True makes np.load unpickle the .pkl,
# so this should only be used on trusted files.
turbpca_fpath = f"{base_dir}/turbulence/BLexp_Re980_pca10.pkl"
turbpca_data = np.load(turbpca_fpath, allow_pickle=True)
print(turbpca_data.shape)

In [None]:
# Phase portrait of the first three columns over the whole series.
ax = plt.axes(projection="3d")
ax.plot3D(*turbpca_data[:, :3].T)

In [None]:
# Search for a good subsampling, then forecast the standardised, strided series.
best_start, best_stride, best_cost = optimize_start_and_stride(
    model=pft_model,
    data=turbpca_data,
    context_length=512,
    prediction_length=128,
    ntrials=500,
    start_buffer_proportion=0.10,
    min_strided_pts=100,
)
print(best_start, best_stride, best_cost)

subsampled_turbpca_data = turbpca_data[best_start::best_stride]
stand_subsampled_turbpca_data = safe_standardize(subsampled_turbpca_data, axis=0)
_ = plot_model_prediction(
    model=pft_model,
    data=stand_subsampled_turbpca_data.T,
    context_length=512,
    prediction_length=128,
    title="Turbulent Boundary Layer PCA modes",
)

# Von Karman Street

In [None]:
# Von Karman Street
RE_VAL = 1200
vortex_fpath = (
    f"{base_dir}/von_karman_street/vortex_street_vorticities_Re_{RE_VAL}_pca10.pkl"
)
pod_fpath = f"{base_dir}/von_karman_street/vortex_street_pod_Re_{RE_VAL}_long.npz"
vel_fpath = f"{base_dir}/von_karman_street/vortex_street_velocities_Re_{RE_VAL}.npz"
# allow_pickle=True makes np.load unpickle data -- only use on trusted files.
vortex_data = np.load(vortex_fpath, allow_pickle=True)
pod_data = np.load(pod_fpath, allow_pickle=True)
vfield = np.load(vel_fpath, allow_pickle=True)


def _describe_shape(loaded):
    """Return the shape of an array, or a name -> shape mapping for an .npz archive."""
    if hasattr(loaded, "shape"):
        return loaded.shape
    # np.load returns an archive object (with `.files`) for .npz input; it has no
    # `.shape` of its own, so report the shape of every stored array instead.
    if hasattr(loaded, "files"):
        return {key: loaded[key].shape for key in loaded.files}
    return type(loaded).__name__


print(_describe_shape(vortex_data), _describe_shape(pod_data), _describe_shape(vfield))

In [None]:
# Phase portrait of the first three columns, drawn with a thin line.
ax = plt.axes(projection="3d")
ax.plot3D(*vortex_data[:, :3].T, linewidth=0.5)

In [None]:
# Search for a good subsampling, then forecast the standardised, strided series.
best_start, best_stride, best_cost = optimize_start_and_stride(
    pft_model,
    vortex_data,
    512,
    128,
    ntrials=500,
    start_buffer_proportion=0.90,
    min_strided_pts=10,
)
print(best_start, best_stride, best_cost)
subsampled_vortex_data = vortex_data[best_start::best_stride, :]
stand_subsampled_vortex_data = safe_standardize(subsampled_vortex_data, axis=0)
predictions = plot_model_prediction(
    pft_model,
    stand_subsampled_vortex_data.T,
    512,
    128,
    # The dataset is the vortex *street*; "vortex sheet" is a different flow.
    title="Von Karman Vortex Street PCA modes",
)

# ECG

In [None]:
# Load the ECG training file: gzip-compressed, comma-separated text.
fpath = f"{base_dir}/electrocardiogram/ecg_train.csv.gz"
ecg_data = np.loadtxt(fpath, delimiter=",")
print(ecg_data.shape)

In [None]:
# Lag-1 delay portrait of the first samples: x(t) against x(t+1) and x(t+2).
ax = plt.axes(projection="3d")
ax.plot3D(*(ecg_data[shift : 1000 + shift] for shift in range(3)))

In [None]:
context_length = 512
prediction_length = 128
total_length = context_length + prediction_length

# The search expects a (n_timesteps, n_channels) array, hence the added axis.
best_start, best_stride, best_cost = optimize_start_and_stride(
    pft_model,
    ecg_data[:, None],
    context_length,
    prediction_length,
    ntrials=1000,
    start_buffer_proportion=0.10,
    min_strided_pts=100,
)
print(best_start, best_stride, best_cost)
subsampled_ecg_data = ecg_data[best_start::best_stride]

# Standardise with the context statistics only, so that the de-standardisation
# below (which uses the same context) is the exact inverse transform.
ecg_context = subsampled_ecg_data[:context_length]
stand_ecg_context = safe_standardize(ecg_context, axis=0)
standpred = plot_model_prediction(
    pft_model,
    stand_ecg_context[None, :],
    context_length,
    prediction_length,
    show=False,
)
pred = safe_standardize(standpred, axis=0, context=ecg_context, denormalize=True)

# Time axes are normalised so that context + prediction spans [0, 1).
context_ts = np.arange(context_length) / total_length
pred_ts = np.arange(context_length, total_length) / total_length

plt.title("ECG")
plt.plot(
    context_ts,
    ecg_context,
    color="black",
    alpha=0.5,
    label="context",
)
plt.plot(
    pred_ts,
    subsampled_ecg_data[context_length:total_length],
    color="black",
    linestyle="--",
    label="groundtruth",
)
plt.plot(pred_ts, pred, color="red", label="prediction")
plt.legend()
plt.show()


In [None]:
def timelag_embedding(data, lag, dims):
    """
    Embed a univariate time series into a higher-dimensional space using time-lagged embedding.

    Row ``t`` of the result is ``[x[t], x[t + lag], ..., x[t + (dims - 1) * lag]]``.

    Args:
        data: Input data array (n_timesteps,)
        lag: Time lag (in samples) between consecutive embedding coordinates
        dims: Number of embedding coordinates (columns of the result)

    Returns:
        Float array of shape (n_timesteps - lag * (dims - 1), dims).

    Raises:
        ValueError: If ``data`` is not one-dimensional, ``dims < 1``, ``lag < 0``,
            or the series is too short for the requested embedding.
    """
    data = np.asarray(data)
    if data.ndim != 1:
        raise ValueError(f"data must be one-dimensional, got shape {data.shape}")
    if dims < 1:
        raise ValueError(f"dims must be at least 1, got {dims}")
    if lag < 0:
        raise ValueError(f"lag must be non-negative, got {lag}")

    n_timesteps = data.shape[0]
    n_rows = n_timesteps - lag * (dims - 1)
    if n_rows < 0:
        raise ValueError(
            f"series of length {n_timesteps} is too short for lag={lag}, dims={dims}"
        )

    embedded_data = np.empty((n_rows, dims))
    for i in range(dims):
        # Column i is the series shifted forward by i * lag samples.
        embedded_data[:, i] = data[i * lag : i * lag + n_rows]
    return embedded_data


# Delay-embed the raw ECG signal with lag 2 into 10 coordinates.
ecg_lagged = timelag_embedding(ecg_data, 2, 10)

In [None]:
# Standardise every embedding coordinate, then forecast from the first 512 rows.
stand_ecg_lagged = safe_standardize(ecg_lagged, axis=0)
predictions = plot_model_prediction(
    model=pft_model,
    data=stand_ecg_lagged.T,
    context_length=512,
    prediction_length=128,
    title="ECG Time-lagged embedding",
)

In [None]:
def lag_error_scaling(
    model,
    data: np.ndarray,
    context_length: int,
    prediction_length: int,
    ntrials: int,
    lags: list[int],
    dims: list[int],
):
    """Average forecast error over random windows for every (lag, dims) embedding.

    For each lag/dimension pair the series is delay-embedded, ``ntrials`` window
    starts are drawn at random, the standardised context is forecast with
    ``plot_model_prediction`` (plotting disabled), the forecast is mapped back
    with the context statistics, and the norm of the residual divided by the
    embedding dimension is accumulated.

    Args:
        model: Forecasting model forwarded to ``plot_model_prediction``.
        data: Series forwarded to ``timelag_embedding``.
        context_length: Number of embedded rows given to the model.
        prediction_length: Number of embedded rows to forecast and score.
        ntrials: Number of random windows per (lag, dims) pair.
        lags: Lags to evaluate (rows of the result).
        dims: Embedding dimensions to evaluate (columns of the result).

    Returns:
        Array of shape (len(lags), len(dims)) holding the mean error per pair.
    """
    window_length = context_length + prediction_length
    errors = np.zeros((len(lags), len(dims)))

    for row, lag in enumerate(lags):
        for col, dim in enumerate(dims):
            embedded = timelag_embedding(data, lag, dim)
            window_starts = np.random.randint(
                0, len(embedded) - window_length, size=ntrials
            )
            for start in window_starts:
                context_end = start + context_length
                context = embedded[start:context_end]
                target = embedded[context_end : context_end + prediction_length]

                stand_forecast = plot_model_prediction(
                    model,
                    safe_standardize(context, axis=0).T,
                    context_length,
                    prediction_length,
                    show=False,
                )
                # A one-dimensional embedding comes back squeezed to 1-D.
                if stand_forecast.ndim == 1:
                    stand_forecast = stand_forecast[:, None]

                forecast = safe_standardize(
                    stand_forecast, axis=0, context=context, denormalize=True
                )
                errors[row, col] += np.linalg.norm(forecast - target) / dim

    errors /= ntrials
    return errors

In [None]:
# Grid of lags 1..10 and embedding dimensions 1..20, scored with the PatchTST model.
lags = np.arange(1, 11)
dims = np.arange(1, 21)
errors = lag_error_scaling(
    model=pft_model,
    data=ecg_data,
    context_length=512,
    prediction_length=128,
    ntrials=10,
    lags=lags,
    dims=dims,
)

In [None]:
# Heatmap of the error grid: rows are lags, columns are embedding dimensions.
fig, ax = plt.subplots()
heatmap = ax.imshow(errors)
ax.set_yticks(range(len(lags)))
ax.set_yticklabels(lags)
ax.set_xticks(range(len(dims)))
ax.set_xticklabels(dims)
ax.set_ylabel("Lag")
ax.set_xlabel("Dimensions")
ax.set_title("ECG Time-lagged embedding per-dimension error scaling")
fig.colorbar(heatmap, ax=ax, shrink=0.5)
plt.show()

In [None]:
from dystformer.chronos.pipeline import ChronosPipeline

# Load the Chronos checkpoint in float32 onto device "cuda:5"; the bare name on
# the last line displays the loaded pipeline as the cell output.
chronos_checkpoint = "/stor/work/AMDG_Gilpin_Summer2024/checkpoints/chronos_mini_ft-0/checkpoint-final"
chronos = ChronosPipeline.from_pretrained(
    chronos_checkpoint, device_map="cuda:5", torch_dtype=torch.float32
)
chronos

In [None]:
# Forecast the raw (unstandardised) delay embedding with Chronos. The context is
# passed channel-first (transpose=False) and the extra keywords are forwarded to
# chronos.predict.
prediction = plot_model_prediction(
    model=chronos,
    data=ecg_lagged.T,
    context_length=512,
    prediction_length=128,
    show=True,
    transpose=False,
    num_samples=1,
    limit_prediction_length=False,
)

In [None]:
# Same lag/dimension sweep as for PatchTST, but with Chronos on the raw embedding
# (no standardisation). `context_length`, `lags` and `dims` come from earlier cells.
ntrials = 10
prediction_length = 128
window_length = context_length + prediction_length
chronos_options = dict(
    show=False,
    transpose=False,
    num_samples=1,
    limit_prediction_length=False,
)

errors = np.zeros((len(lags), len(dims)))
for i, lag in enumerate(lags):
    for j, dim in enumerate(dims):
        embedded_data = timelag_embedding(ecg_data, lag, dim)
        window_starts = np.random.randint(
            0, len(embedded_data) - window_length, size=ntrials
        )
        for start in window_starts:
            print(f"{lag=}, {dim=}, {start=}")
            context_end = start + context_length
            context = embedded_data[start:context_end]
            target = embedded_data[context_end : context_end + prediction_length]

            predictions = plot_model_prediction(
                chronos,
                context.T,
                context_length,
                prediction_length,
                **chronos_options,
            )
            # A one-dimensional embedding comes back squeezed to 1-D.
            if predictions.ndim == 1:
                predictions = predictions[:, None]

            errors[i, j] += np.linalg.norm(predictions - target) / dim
errors /= ntrials

In [None]:
# Heatmap of the Chronos error grid: rows are lags, columns are embedding
# dimensions. The title names the model so that this figure can be told apart
# from the otherwise identical PatchTST heatmap.
plt.figure()
plt.imshow(errors)
plt.yticks(range(len(lags)), lags)
plt.xticks(range(len(dims)), dims)
plt.ylabel("Lag")
plt.xlabel("Dimensions")
plt.title("ECG Time-lagged embedding per-dimension error scaling (Chronos)")
plt.colorbar(shrink=0.5)
plt.show()