In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.lines import Line2D

from dystformer.patchtst.pipeline import PatchTSTPipeline
from dystformer.utils import safe_standardize

In [None]:
# Checkpoint and device are configurable so the notebook can run off the original
# cluster; the defaults reproduce the original setup (checkpoint below on cuda:1).
PFT_CHECKPOINT = os.environ.get(
    "PFT_CHECKPOINT",
    "/stor/work/AMDG_Gilpin_Summer2024/checkpoints/pft_chattn_emb_w_poly-0/checkpoint-final",
)
# Prefer the original second GPU, but degrade gracefully on smaller machines.
# NOTE(review): confirm PatchTSTPipeline.from_pretrained accepts device_map="cpu".
if torch.cuda.device_count() > 1:
    DEVICE = "cuda:1"
elif torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

pft_model = PatchTSTPipeline.from_pretrained(
    mode="predict",
    pretrain_path=PFT_CHECKPOINT,
    device_map=DEVICE,
)

In [None]:
def plot_model_prediction(
    model,
    data: np.ndarray,
    context_length: int,
    prediction_length: int,
    title: str | None = None,
    show: bool = True,
    **kwargs,
):
    context = data[:, :context_length]
    groundtruth = data[:, context_length : context_length + prediction_length]
    context_tensor = torch.from_numpy(context.T).float()
    pred = (
        model.predict(context_tensor, prediction_length, **kwargs)
        .squeeze()
        .cpu()
        .numpy()
    )
    if not show:
        return pred
    total_length = context.shape[1] + prediction_length
    context_ts = np.arange(context.shape[1]) / total_length
    pred_ts = np.arange(context.shape[1], total_length) / total_length

    # Create figure with gridspec layout
    fig = plt.figure(figsize=(15, 4))

    # Create main grid with padding for colorbar
    outer_grid = fig.add_gridspec(1, 2, width_ratios=[0.5, 0.5], wspace=0.05)

    # Create sub-grid for the plots
    gs = outer_grid[1].subgridspec(3, 1, height_ratios=[1 / 3] * 3, wspace=0, hspace=0)
    ax_3d = fig.add_subplot(outer_grid[0], projection="3d")
    ax_3d.plot(*context[:3], alpha=0.5, color="black", label="Context")
    ax_3d.plot(*groundtruth[:3], linestyle="--", color="black", label="Groundtruth")
    ax_3d.plot(*pred.T[:3], color="red", label="Prediction")
    ax_3d.legend(loc="upper right", fontsize=12)
    ax_3d.set_xlabel("$x_1$")
    ax_3d.set_ylabel("$x_2$")
    ax_3d.set_zlabel("$x_3$")
    if title is not None:
        ax_3d.set_title(title)

    axes_1d = [fig.add_subplot(gs[i, 0]) for i in range(3)]
    for i, ax in enumerate(axes_1d):
        ax.plot(context_ts, context[i], alpha=0.5, color="black")
        ax.plot(pred_ts, groundtruth[i], linestyle="--", color="black")
        ax.plot(pred_ts, pred[:, i], color="red")
        ax.set_ylabel(f"$x_{i + 1}$")
        ax.set_aspect("auto")
    axes_1d[-1].set_xlabel("Time")

    plt.show()


# Double Pendulum

In [None]:
# Which trajectory file to load from the double-pendulum train/test split.
SPLIT = "train"
INDEX = 0
# All datasets live under $WORK/physics-datasets. If $WORK is unset, fail fast with
# a clear message instead of silently resolving to "/physics-datasets".
WORK = os.environ.get("WORK", "")
if not WORK:
    raise EnvironmentError(
        "Set the WORK environment variable to the directory containing 'physics-datasets'."
    )
base_dir = f"{WORK}/physics-datasets"
fpath = f"{base_dir}/double_pendulum_chaotic/train_and_test_split/dpc_dataset_traintest_4_200_csv/{SPLIT}/{INDEX}.csv"
# NOTE(review): np.loadtxt splits on whitespace by default. If these .csv files are
# comma-separated, pass delimiter=",".
pendulum_data = np.loadtxt(fpath)
print(pendulum_data.shape)

In [None]:
# Raw trajectories of the three tracked points. Each is plotted as
# (odd column, -even column), the same mapping used in the forecast plot.
fig, ax = plt.subplots()
ax.plot(pendulum_data[:, 1], -pendulum_data[:, 0], label="Pivot point (mostly constant)")
ax.plot(pendulum_data[:, 3], -pendulum_data[:, 2], label="Tip of the first pendulum")
ax.plot(pendulum_data[:, 5], -pendulum_data[:, 4], label="Tip of the second pendulum")
ax.legend()
ax.set_title("Double Pendulum: raw trajectories")
plt.show()

In [None]:
# Forecast all marker coordinates jointly in standardized space, then map the
# forecast back to the original coordinates for plotting.
# NOTE(review): the standardization statistics come from the full trajectory,
# including the forecast horizon, so the model's input is normalized with future
# information. Consider computing them from the context window only.
PENDULUM_CONTEXT_LEN = 512
PENDULUM_PRED_LEN = 128
pendulum_horizon_end = PENDULUM_CONTEXT_LEN + PENDULUM_PRED_LEN

pendulum_standpred = plot_model_prediction(
    pft_model,
    safe_standardize(pendulum_data, axis=0).T,
    PENDULUM_CONTEXT_LEN,
    PENDULUM_PRED_LEN,
    show=False,
)
pendulum_pred = safe_standardize(
    pendulum_standpred, axis=0, context=pendulum_data, denormalize=True
)
pendulum_context = pendulum_data[:PENDULUM_CONTEXT_LEN]
pendulum_groundtruth = pendulum_data[PENDULUM_CONTEXT_LEN:pendulum_horizon_end]

# (plot-x column, plot-y column) for the pivot point, the tip of the first
# pendulum, and the tip of the second pendulum. The plot-y column is negated.
pendulum_markers = [(1, 0), (3, 2), (5, 4)]

fig, ax = plt.subplots()
for x_col, y_col in pendulum_markers:
    # Same encoding for every point, so the single legend below applies to all.
    ax.plot(
        pendulum_context[:, x_col],
        -pendulum_context[:, y_col],
        alpha=0.5,
        color="black",
    )
    ax.plot(
        pendulum_groundtruth[:, x_col],
        -pendulum_groundtruth[:, y_col],
        linestyle="--",
        color="black",
    )
    ax.plot(pendulum_pred[:, x_col], -pendulum_pred[:, y_col], color="red")

# Custom handles: one legend entry per role rather than per plotted line.
legend_elements = [
    Line2D([0], [0], color="black", alpha=0.5, label="Context"),
    Line2D([0], [0], color="black", linestyle="--", label="Ground Truth"),
    Line2D([0], [0], color="red", label="Prediction"),
]
ax.legend(handles=legend_elements)
ax.set_title("Double Pendulum")
plt.show()

# Eigenworms

In [None]:
INDEX = 0  # which worm recording to load
worm_dir = f"{base_dir}/worm_behavior/data"
fpath = f"{worm_dir}/worm_{INDEX}.pkl"
# CAUTION: allow_pickle=True lets np.load unpickle arbitrary objects, which can run
# code. Only load files from a trusted source.
worm_data = np.load(fpath, allow_pickle=True)
print(worm_data.shape)

In [None]:
# 3D view of the first three coordinates over the first 1000 samples. Axis labels
# match those used by plot_model_prediction.
ax = plt.axes(projection="3d")
ax.plot3D(*worm_data[:1000, :3].T)
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.set_zlabel("$x_3$")
ax.set_title("Eigenworm (first 1000 samples)")
plt.show()

In [None]:
# Standardize along axis 0 (time, before transposing), then forecast 128 steps
# from a 512-step context and plot the result.
stand_worm_data = safe_standardize(worm_data, axis=0)
plot_model_prediction(
    pft_model,
    stand_worm_data.T,
    context_length=512,
    prediction_length=128,
    title="Eigenworm",
)

# Turbulent Boundary Layer

In [None]:
# Turbulent boundary-layer data (file name indicates Re980, 10 PCA modes).
# CAUTION: allow_pickle=True can execute code embedded in the file; trusted data only.
turbpca_path = f"{base_dir}/turbulence/BLexp_Re980_pca10.pkl"
turbpca_data = np.load(turbpca_path, allow_pickle=True)
print(turbpca_data.shape)

In [None]:
# 3D view of the first three PCA coordinates over the full recording.
ax = plt.axes(projection="3d")
ax.plot3D(*turbpca_data[:, :3].T)
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.set_zlabel("$x_3$")
ax.set_title("Turbulent Boundary Layer PCA modes")
plt.show()

In [None]:
# Same protocol as the other datasets: standardize along time, 512-step context,
# 128-step forecast.
stand_turbpca_data = safe_standardize(turbpca_data, axis=0)
plot_model_prediction(
    model=pft_model,
    data=stand_turbpca_data.T,
    context_length=512,
    prediction_length=128,
    title="Turbulent Boundary Layer PCA modes",
)

# Von Kármán Vortex Street

In [None]:
RE_VAL = 1200  # presumably the Reynolds number; selects which vortex-street file to load
vortex_fname = f"vortex_street_vorticities_Re_{RE_VAL}_pca10.pkl"
fpath = f"{base_dir}/von_karman_street/{vortex_fname}"
# CAUTION: allow_pickle=True can execute code embedded in the file; trusted data only.
vortex_data = np.load(fpath, allow_pickle=True)
print(vortex_data.shape)

In [None]:
# 3D view of the first three PCA coordinates; thin lines keep the dense orbit readable.
ax = plt.axes(projection="3d")
ax.plot3D(*vortex_data[:, :3].T, linewidth=0.5)
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.set_zlabel("$x_3$")
ax.set_title("Von Karman Vortex Street PCA modes")
plt.show()

In [None]:
# Forecast the standardized PCA coordinates of the vortex-street vorticity data
# (512-step context, 128-step forecast).
stand_vortex_data = safe_standardize(vortex_data, axis=0)
plot_model_prediction(
    pft_model, stand_vortex_data.T, 512, 128, title="Von Karman Vortex Street PCA modes"
)

# ECG

In [None]:
# Comma-separated, gzip-compressed recording; np.loadtxt decompresses .gz files itself.
ecg_dir = f"{base_dir}/electrocardiogram"
fpath = f"{ecg_dir}/ecg_train.csv.gz"
ecg_data = np.loadtxt(fpath, delimiter=",")
print(ecg_data.shape)

In [None]:
# Delay embedding with a one-sample lag over the first 1000 samples.
ax = plt.axes(projection="3d")
ax.plot3D(ecg_data[:1000], ecg_data[1:1001], ecg_data[2:1002])
ax.set_xlabel("$x(t)$")
ax.set_ylabel("$x(t+1)$")
ax.set_zlabel("$x(t+2)$")
ax.set_title("ECG delay embedding (first 1000 samples)")
plt.show()

In [None]:
# Univariate forecast. The model sees one standardized channel (leading axis added
# with [None, :]), and the forecast is mapped back to the original ECG scale.
# NOTE(review): standardization statistics come from the full series, including the
# forecast horizon. Consider computing them from the context window only.
ECG_CONTEXT_LEN = 512
ECG_PRED_LEN = 128
ecg_total_len = ECG_CONTEXT_LEN + ECG_PRED_LEN

stand_ecg_data = safe_standardize(ecg_data, axis=0)
ecg_standpred = plot_model_prediction(
    pft_model, stand_ecg_data[None, :], ECG_CONTEXT_LEN, ECG_PRED_LEN, show=False
)
ecg_pred = safe_standardize(ecg_standpred, axis=0, context=ecg_data, denormalize=True)

# Time is expressed as a fraction of the plotted window, as in plot_model_prediction.
context_ts = np.arange(ECG_CONTEXT_LEN) / ecg_total_len
pred_ts = np.arange(ECG_CONTEXT_LEN, ecg_total_len) / ecg_total_len

fig, ax = plt.subplots()
ax.plot(context_ts, ecg_data[:ECG_CONTEXT_LEN], color="black", alpha=0.5, label="context")
ax.plot(
    pred_ts,
    ecg_data[ECG_CONTEXT_LEN:ecg_total_len],
    color="black",
    linestyle="--",
    label="groundtruth",
)
ax.plot(pred_ts, ecg_pred, color="red", label="prediction")
ax.set_title("ECG")
ax.set_xlabel("Time")
ax.legend()
plt.show()
