In [None]:
import os
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
from dysts.metrics import compute_metrics
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy.io import loadmat
from tqdm import tqdm

from dystformer.chronos.pipeline import ChronosPipeline
from dystformer.patchtst.pipeline import PatchTSTPipeline
from dystformer.utils import safe_standardize

In [None]:
# Load the PatchTST pipeline in "predict" mode from a saved checkpoint.
# NOTE(review): the checkpoint path is an absolute path and the device index is
# hard-coded; both need editing to run this notebook on another machine.
pft_model = PatchTSTPipeline.from_pretrained(
    mode="predict",
    pretrain_path="/stor/work/AMDG_Gilpin_Summer2024/checkpoints/pft_chattn_emb_w_poly-0/checkpoint-final",
    device_map="cuda:1",
)

In [None]:
# Load the Chronos baseline pipeline from a saved checkpoint, in float32.
# NOTE(review): absolute checkpoint path and hard-coded GPU index ("cuda:3");
# adjust both when running elsewhere.
chronos = ChronosPipeline.from_pretrained(
    "/stor/work/AMDG_Gilpin_Summer2024/checkpoints/chronos_bolt_mini-12/checkpoint-final",
    device_map="cuda:3",
    torch_dtype=torch.float32,
)

In [None]:
def forecast(
    model,
    context: np.ndarray,
    prediction_length: int,
    transpose: bool = False,
    standardize: bool = True,
    **kwargs,
) -> np.ndarray:
    """
    Forecast `prediction_length` steps from a single context window.

    Args:
        model: Object exposing
            `predict(context_tensor, prediction_length, verbose=False, **kwargs)`
            and returning a torch tensor.
        context: The context to forecast (n_timesteps, n_features)
        prediction_length: The length of the prediction.
        transpose: If True, the context is transposed before it is handed to
            the model and the model output is transposed back afterwards.
        standardize: If True, the model input is passed through
            `safe_standardize` and the prediction is de-normalized against the
            raw (un-transposed) context.
        **kwargs: Forwarded unchanged to `model.predict`.

    Returns:
        The forecasted data (prediction_length, n_features)
    """
    # Work on a copy so the caller's array is never touched.
    model_input = context.copy()
    if transpose:
        model_input = model_input.T
    if standardize:
        # After a transpose the time axis is axis 1, otherwise axis 0.
        model_input = safe_standardize(model_input, axis=int(transpose))

    raw_output = model.predict(
        torch.from_numpy(model_input).float(),
        prediction_length,
        verbose=False,
        **kwargs,
    )
    # squeeze() drops every singleton dimension of the model output.
    prediction = raw_output.squeeze().cpu().numpy()
    if transpose:
        prediction = prediction.T

    if not standardize:
        return prediction
    # Undo the standardization using the statistics of the original context.
    return safe_standardize(prediction, axis=0, context=context, denormalize=True)


def compute_rollout_metrics(
    model,
    data: np.ndarray,
    context_length: int,
    prediction_length: int,
    step: int = 64,
    num_windows: int = 5,
    metrics: list[str] = ["mse", "mae", "smape"],
    **kwargs,
) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray], np.ndarray, list[np.ndarray]]:
    """
    Forecast from randomly sampled context windows and score each forecast in
    consecutive, non-overlapping chunks of `step` time steps.

    Args:
        model: Model handed to `forecast`.
        data: Full series, indexed by time along axis 0.
        context_length: Number of time steps given to the model as context.
        prediction_length: Number of time steps to forecast per window.
        step: Chunk size (in time steps) over which each metric is computed.
        num_windows: Number of random context windows to sample.
        metrics: Metric names passed to `compute_metrics(include=...)`.
            The default list is never mutated here.
        **kwargs: Forwarded to `forecast` (e.g. `transpose`, model options).

    Returns:
        mean_metrics: Per metric, the mean over windows for each chunk.
        std_metrics: Per metric, the standard deviation over windows divided by
            sqrt(num_windows), i.e. a standard error of the mean.
        starts: Start index of each sampled window.
        predictions: The forecast produced for each window.

    Raises:
        ValueError: If `data` is shorter than context_length + prediction_length.
    """
    last_valid_start = len(data) - context_length - prediction_length
    if last_valid_start < 0:
        raise ValueError(
            "data is too short: need at least context_length + prediction_length "
            f"= {context_length + prediction_length} time steps, got {len(data)}"
        )

    # Ceil division: range(0, prediction_length, step) yields a trailing partial
    # chunk when prediction_length is not a multiple of step.
    num_chunks = -(-prediction_length // step)
    full_metrics = defaultdict(lambda: np.zeros((num_windows, num_chunks)))

    # randint's upper bound is exclusive, and last_valid_start is itself usable.
    starts = np.random.randint(0, last_valid_start + 1, num_windows)

    predictions = []
    for s in tqdm(range(num_windows), desc="Sampling contexts", total=num_windows):
        start = starts[s]
        context = data[start : start + context_length]
        prediction = forecast(model, context, prediction_length, **kwargs)
        target_origin = start + context_length
        for chunk_idx, i in enumerate(range(0, prediction_length, step)):
            # Clip the final chunk so prediction and ground truth stay aligned.
            chunk_end = min(i + step, prediction_length)
            pred = prediction[i:chunk_end]
            gt = data[target_origin + i : target_origin + chunk_end]
            submetrics = compute_metrics(pred, gt, include=metrics)
            for k, v in submetrics.items():
                full_metrics[k][s, chunk_idx] += v
        predictions.append(prediction)

    mean_metrics = {k: v.mean(axis=0) for k, v in full_metrics.items()}
    std_metrics = {
        k: v.std(axis=0) / np.sqrt(num_windows) for k, v in full_metrics.items()
    }
    return mean_metrics, std_metrics, starts, predictions


def plot_model_prediction(
    model,
    data: np.ndarray,
    context_length: int,
    prediction_length: int,
    transpose: bool = False,
    standardize: bool = True,
    save_path: str | None = None,
    color: str = "red",
    **kwargs,
):
    """
    Forecast from the first `context_length` rows of `data` and plot the result.

    The left panel shows the first three features as a 3D trajectory; the right
    column shows each of those three features against time.

    Args:
        model: Model handed to `forecast`.
        data: Series indexed by time along axis 0; the first three columns are plotted.
        context_length: Number of leading rows used as context.
        prediction_length: Number of steps to forecast and plot.
        transpose: Forwarded to `forecast`.
        standardize: Forwarded to `forecast`.
        save_path: If given, the figure is saved there instead of being shown.
        color: Line color used for the prediction.
        **kwargs: Forwarded to `forecast`.
    """
    horizon_end = context_length + prediction_length
    history = data[:context_length]
    target = data[context_length:horizon_end]
    predicted = forecast(
        model, history, prediction_length, transpose, standardize, **kwargs
    )

    # The context trace extends one sample past the context so that it joins the
    # first ground-truth / prediction point.
    history_ts = np.arange(context_length + 1)
    future_ts = np.arange(context_length, horizon_end)

    fig = plt.figure(figsize=(15, 4))
    columns = fig.add_gridspec(1, 2, width_ratios=[0.5, 0.5], wspace=0.05)
    rows = columns[1].subgridspec(
        3, 1, height_ratios=[1 / 3] * 3, wspace=0, hspace=0
    )

    # Left: 3D trajectory of the first three features.
    phase_ax = fig.add_subplot(columns[0], projection="3d")
    phase_ax.plot(*history.T[:3], alpha=0.5, color="black", label="Context")
    phase_ax.plot(*target.T[:3], linestyle="--", color="black", label="Groundtruth")
    phase_ax.plot(*predicted.T[:3], color=color, label="Prediction")
    phase_ax.legend(loc="upper right", fontsize=8)
    phase_ax.set_xlabel("$x_1$")
    phase_ax.set_ylabel("$x_2$")
    phase_ax.set_zlabel("$x_3$")

    # Right: one stacked time-series panel per feature.
    panels = [fig.add_subplot(rows[row, 0]) for row in range(3)]
    for i, panel in enumerate(panels):
        panel.plot(
            history_ts, data[: context_length + 1, i], alpha=0.5, color="black"
        )
        panel.plot(future_ts, target[:, i], linestyle="--", color="black")
        panel.plot(future_ts, predicted[:, i], color=color)
        panel.set_ylabel(f"$x_{i + 1}$")
        panel.set_aspect("auto")
    panels[-1].set_xlabel("Time")

    if save_path is not None:
        plt.savefig(save_path)
    else:
        plt.show()
    plt.close()


In [None]:
def plot_comparison(
    data: np.ndarray,
    pft_prediction: np.ndarray,
    chronos_prediction: np.ndarray,
    context_length: int,
    prediction_length: int,
    pft_metrics: dict[str, list[float]],
    chronos_metrics: dict[str, list[float]],
    step: int = 8,
    metric: str = "smape",
    num_ticks: int = 4,
):
    """
    Draw both model forecasts against the ground truth as 3D trajectories, with
    an inset showing one error metric versus prediction length.

    Only the first three features are drawn. PFT is drawn in red and Chronos in blue.

    Args:
        data: Full series (time along axis 0) containing context and ground truth.
        pft_prediction: PFT forecast, time along axis 0.
        chronos_prediction: Chronos forecast, time along axis 0.
        context_length: Number of leading rows of `data` used as context.
        prediction_length: Forecast horizon in time steps.
        pft_metrics: Metric name -> values, one per evaluated horizon.
        chronos_metrics: Same structure as `pft_metrics`, for Chronos.
        step: Spacing between evaluated horizons; the k-th metric value is
            plotted at x = (k + 1) * step.
        metric: Key of the metric shown in the inset.
        num_ticks: Number of x ticks on the inset.
    """
    context = data[: context_length + 1, :3]
    groundtruth = data[context_length : context_length + prediction_length, :3]
    pft_xyz = pft_prediction[:, :3]
    chronos_xyz = chronos_prediction[:, :3]

    plt.figure(figsize=(6, 6))
    ax = plt.axes(projection="3d")
    # Public equivalent of setting the private `_axis3don` flag to False.
    ax.set_axis_off()

    # Bounding box over everything that is drawn; used to place the axis arrows.
    trajectories = (context, groundtruth, pft_xyz, chronos_xyz)
    xmin, ymin, zmin = np.min([t.min(axis=0) for t in trajectories], axis=0)
    xmax, ymax, zmax = np.max([t.max(axis=0) for t in trajectories], axis=0)

    ax.xaxis.pane.set_visible(False)
    ax.yaxis.pane.set_visible(False)
    ax.zaxis.pane.set_visible(False)
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])

    ax.plot3D(*context.T, alpha=0.1, color="black")
    ax.plot3D(*groundtruth.T, alpha=0.5, color="black", linestyle="--")
    ax.plot3D(*pft_xyz.T, color="red")
    ax.plot3D(*chronos_xyz.T, color="blue")

    # Hand-drawn axis arrows, all anchored at the (xmin, ymax, zmin) corner.
    arrow_directions = (
        (xmax - xmin, 0, 0),
        (0, -ymax + ymin, 0),
        (0, 0, zmax - zmin),
    )
    for direction in arrow_directions:
        ax.quiver(
            xmin,
            ymax,
            zmin,
            *direction,
            color="black",
            linewidth=2,
            arrow_length_ratio=0.05,
            zorder=5,
        )
    # NOTE(review): label offsets are absolute data units, not scaled to the data range.
    ax.text(xmax + 1, ymax - 0.5, zmin, "X", fontsize=12)
    ax.text(xmin, ymin - 2, zmin, "Y", fontsize=12)
    ax.text(xmin - 0.5, ymax, zmax + 1, "Z", fontsize=12)

    steps = np.arange(0, prediction_length, step)
    axins = inset_axes(ax, width="40%", height="20%", loc="upper right", borderpad=1)
    axins.plot(steps + step, pft_metrics[metric], color="red")
    axins.plot(steps + step, chronos_metrics[metric], color="blue")
    axins.set_ylabel(metric)
    axins.set_xlabel("Prediction Length")
    # Guard against a zero tick spacing when prediction_length < num_ticks.
    tick_step = max(prediction_length // num_ticks, 1)
    axins.set_xticks(
        np.arange(tick_step, prediction_length + tick_step, tick_step)
    )
    axins.set_xlim(step, prediction_length)

# Double Pendulum

In [None]:
# Which split and trajectory index of the double-pendulum dataset to load.
SPLIT = "train"
INDEX = 0
# Dataset root comes from the WORK environment variable (empty string if unset,
# which makes base_dir resolve to "/physics-datasets").
WORK = os.environ.get("WORK", "")
base_dir = f"{WORK}/physics-datasets"
fpath = f"{base_dir}/double_pendulum_chaotic/train_and_test_split/dpc_dataset_traintest_4_200_csv/{SPLIT}/{INDEX}.csv"
pendulum_data = np.loadtxt(fpath)
print(pendulum_data.shape)

# data is non-stationary, subsample and detrend it
# Keep every 10th row and the last four columns; the first difference is stored
# in `subsampled_pendulum_diff`.
# NOTE(review): `subsampled_pendulum_diff` is not read anywhere else in this notebook.
subsampled_pendulum_data = pendulum_data[::10, -4:]
subsampled_pendulum_diff = np.diff(subsampled_pendulum_data, axis=0)

# Each trace below plots an odd column against the negated preceding even column.

## The position of the pivot point (mostly constant)
plt.plot(pendulum_data[:, 1], -pendulum_data[:, 0])

## The position of the tip of the first pendulum
plt.plot(pendulum_data[:, 3], -pendulum_data[:, 2])

## The position of the tip of the second pendulum
plt.plot(pendulum_data[:, 5], -pendulum_data[:, 4])

In [None]:
# Window sizes (in subsampled time steps) for the double-pendulum experiments.
context_length = 512
prediction_length = 128

# Metric names requested from `compute_metrics`; this global is reused by every
# later section of the notebook.
metrics = ["mse", "mae", "smape"]

In [None]:
# PFT forecast from the first `context_length` samples of the pendulum series.
pft_prediction = forecast(
    pft_model,
    subsampled_pendulum_data[:context_length],
    prediction_length,
    limit_prediction_length=False,
    sliding_context=True,
)
# pft_prediction = subsampled_pendulum_data[context_length] + pft_diff_prediction.cumsum(
#     axis=0
# )

# Metrics are cumulative: entry j scores the forecast prefix [0, (j + 1) * step).
pft_metrics = defaultdict(list)
# NOTE: `step` is a notebook-level global that later cells also read.
step = 8
for i in np.arange(0, prediction_length, step):
    submetrics = compute_metrics(
        pft_prediction[0 : i + step],
        subsampled_pendulum_data[context_length : context_length + i + step],
        include=metrics,
    )
    for metric in metrics:
        pft_metrics[metric].append(submetrics[metric])
print(pft_metrics)

In [None]:
# Chronos forecast on the same context; `transpose=True` makes `forecast` hand
# the model a (n_features, n_timesteps) array and transpose the output back.
chronos_prediction = forecast(
    chronos,
    subsampled_pendulum_data[:context_length],
    prediction_length,
    transpose=True,
    limit_prediction_length=False,
    num_samples=1,
    deterministic=True,
)

### Chronos diff prediction is bad
# chronos_prediction = subsampled_pendulum_data[context_length] + chronos_diff_prediction.cumsum(
#     axis=0
# )

# Cumulative metrics over growing forecast prefixes; relies on `step` and
# `metrics` defined in earlier cells.
chronos_metrics = defaultdict(list)
for i in np.arange(0, prediction_length, step):
    submetrics = compute_metrics(
        chronos_prediction[0 : i + step],
        subsampled_pendulum_data[context_length : context_length + i + step],
        include=metrics,
    )
    for metric in metrics:
        chronos_metrics[metric].append(submetrics[metric])
print(chronos_metrics)

In [None]:
# 2D comparison plot: context (faint), ground truth (dashed), PFT (red) and
# Chronos (blue), with an inset of sMAPE versus prediction length.
plt.figure(figsize=(10, 10))

## The position of the tip of the second pendulum
# Columns 3 and 2 of the subsampled array are the last two of the four columns
# kept from the raw file.
plt.plot(
    subsampled_pendulum_data[:context_length, 3],
    -subsampled_pendulum_data[:context_length, 2],
    alpha=0.1,
    color="black",
)
plt.plot(
    subsampled_pendulum_data[context_length : context_length + prediction_length, 3],
    -subsampled_pendulum_data[context_length : context_length + prediction_length, 2],
    alpha=0.5,
    color="black",
    linestyle="--",
)
plt.plot(pft_prediction[:, 3], -pft_prediction[:, 2], color="red")
plt.plot(chronos_prediction[:, 3], -chronos_prediction[:, 2], color="blue")

# Inset: the j-th metric value is drawn at horizon (j + 1) * step.
steps = np.arange(0, prediction_length, step)
axins = inset_axes(plt.gca(), width="40%", height="20%", loc="upper right", borderpad=1)
axins.plot(steps + step, pft_metrics["smape"], color="red")
axins.plot(steps + step, chronos_metrics["smape"], color="blue")
axins.set_ylabel("sMAPE")
axins.set_xlabel("Prediction Length")
# NOTE: tick positions are hard-coded for prediction_length = 128.
axins.set_xticks([32, 64, 96, 128])
axins.set_xlim(step, prediction_length);

In [None]:
# sanity check
# Per-feature time series: context (faint), ground truth (dashed), PFT (red),
# Chronos (blue). Only the first three features get a panel.
fig, axes = plt.subplots(3, 1, figsize=(10, 4), sharex=True)
plt.subplots_adjust(hspace=0.0)

# The context trace includes one extra sample so it joins the forecast segment.
context_ts = np.arange(context_length + 1)
pred_ts = np.arange(context_length, context_length + prediction_length)
for i, ax in enumerate(axes.flatten()):
    ax.plot(
        context_ts,
        subsampled_pendulum_data[: context_length + 1, i],
        color="black",
        alpha=0.2,
    )
    ax.plot(
        pred_ts,
        subsampled_pendulum_data[
            context_length : context_length + prediction_length, i
        ],
        color="black",
        linestyle="--",
        alpha=0.5,
    )
    ax.plot(pred_ts, pft_prediction[:, i], color="red")
    ax.plot(pred_ts, chronos_prediction[:, i], color="blue")

plt.show()

In [None]:
# Rollout metrics over 10 randomly sampled windows per model, scored in chunks
# of 8 steps.
# NOTE: each call draws its own random window starts (no seed is set in this
# notebook), so the two models are evaluated on different windows.
pft_mean_metrics, pft_std_metrics, pft_starts, pft_predictions = (
    compute_rollout_metrics(
        pft_model,
        subsampled_pendulum_data,
        context_length,
        prediction_length,
        step=8,
        num_windows=10,
        sliding_context=True,
        limit_prediction_length=False,
    )
)
chronos_mean_metrics, chronos_std_metrics, chronos_starts, chronos_predictions = (
    compute_rollout_metrics(
        chronos,
        subsampled_pendulum_data,
        context_length,
        prediction_length,
        step=8,
        num_windows=10,
        limit_prediction_length=False,
        num_samples=1,
        deterministic=True,
        transpose=True,
    )
)

In [None]:
# sanity check
# Overlay every rollout forecast on the full series, one panel per feature
# (first three features only).
fig, axes = plt.subplots(3, 1, figsize=(10, 4), sharex=True)
plt.subplots_adjust(hspace=0.0)

total_ts = np.arange(len(subsampled_pendulum_data))
for i, ax in enumerate(axes.flatten()):
    ax.plot(
        total_ts,
        subsampled_pendulum_data[:, i],
        color="black",
        alpha=0.2,
    )
    # c / p are the window start indices used for Chronos / PFT respectively.
    for j, (c, p) in enumerate(zip(chronos_starts, pft_starts)):
        pft_pred_ts = np.arange(
            p + context_length, p + context_length + prediction_length
        )
        # NOTE: `chronos_context_ts` is computed but never used below.
        chronos_context_ts = np.arange(c, c + context_length)
        chronos_pred_ts = np.arange(
            c + context_length, c + context_length + prediction_length
        )
        ax.plot(pft_pred_ts, pft_predictions[j][:, i], color="red", alpha=0.1)
        ax.plot(chronos_pred_ts, chronos_predictions[j][:, i], color="blue", alpha=0.1)

plt.show()


# Eigenworms

In [None]:
# Load one worm recording and the eigenworm basis. `INDEX` is reassigned here
# (it previously selected the pendulum trajectory).
INDEX = 9
fpath = f"{base_dir}/worm_behavior/data/worm_{INDEX}.pkl"
# SECURITY NOTE: allow_pickle=True unpickles the file, which can execute
# arbitrary code; only load files from a trusted source.
# The first 2048 rows are dropped by the slice.
worm_data = np.load(fpath, allow_pickle=True)[2048::1]
eigenworms = loadmat(f"{base_dir}/worm_behavior/data/EigenWorms.mat")["EigenWorms"]

# de-NaN the data with linear interpolation
# Each column is filled in place, interpolating over its non-NaN samples.
time_idx = np.arange(len(worm_data))
for d in range(worm_data.shape[1]):
    mask = np.isnan(worm_data[:, d])
    if mask.any():
        valid = ~mask
        worm_data[:, d] = np.interp(time_idx, time_idx[valid], worm_data[valid, d])
assert not np.isnan(worm_data).any()

print(worm_data.shape)

In [None]:
from IPython.display import HTML
from matplotlib.animation import FuncAnimation


def reconstruct_worm(coeffs, eigenworms, segment_length=1.0):
    """
    Turn eigenworm coefficients into centerline coordinates.

    Segment angles are obtained by projecting the coefficients onto the first
    `n_eigenworms` basis columns; each segment then advances the centerline by
    `segment_length` in the direction of its angle, starting from the origin.

    Args:
        coeffs: The coefficients of the worm (n_timesteps, n_eigenworms)
        eigenworms: The eigenworms (n_features, n_eigenworms)
        segment_length: The length of each segment of the worm.

    Returns:
        Tuple (x, y) of arrays with shape (n_timesteps, n_features + 1) holding
        the cumulative centerline coordinates; column 0 is always 0.
    """
    n_timesteps, n_modes = coeffs.shape
    # Only as many basis vectors as there are coefficients are used.
    angles = coeffs @ eigenworms[:, :n_modes].T

    # A leading zero column places the first point of every frame at the origin.
    origin = np.zeros((n_timesteps, 1))
    steps_x = np.concatenate([origin, segment_length * np.cos(angles)], axis=1)
    steps_y = np.concatenate([origin, segment_length * np.sin(angles)], axis=1)

    return np.cumsum(steps_x, axis=1), np.cumsum(steps_y, axis=1)


def animate_worm(x, y, num_frames=200, interval=50, save_path=None):
    """
    Create an animation of the worm's movement over time.

    The worm is drawn as a centerline plus a filled body whose half-width is
    small near both ends and largest in the middle.

    Args:
        x: Array of x coordinates with shape (T, n_segments+1)
        y: Array of y coordinates with shape (T, n_segments+1)
        num_frames: Maximum number of frames; frames are sampled evenly over
            the T time steps when T is larger.
        interval: Time between frames in milliseconds
        save_path: If given, the animation is also written to this path with
            the "ffmpeg" writer.

    Returns:
        HTML animation that can be displayed in the notebook
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    # Fixed axis limits (with 10% padding) so the view does not jump between frames.
    x_min, x_max = x.min(), x.max()
    y_min, y_max = y.min(), y.max()
    x_padding = (x_max - x_min) * 0.1
    y_padding = (y_max - y_min) * 0.1

    ax.set_xlim(x_min - x_padding, x_max + x_padding)
    ax.set_ylim(y_min - y_padding, y_max + y_padding)
    ax.set_aspect("equal")
    ax.set_title("Worm Movement")

    # Artists that are updated in place on every frame.
    line = ax.plot([], [], "b-", lw=2)[0]
    body = ax.fill([], [], color="blue")[0]
    time_text = ax.text(0.02, 0.95, "", transform=ax.transAxes)

    # Width profile: product of a rising and a falling sigmoid over the body
    # coordinate normalized to [-1, 1].
    n_points = x.shape[1]
    max_width = 3  # Maximum width of the worm body
    arg = 2 * np.arange(n_points) / max(n_points - 1, 1) - 1
    width_profile = max_width * (
        1
        / (1 + np.exp(-8 * (arg + 0.7)))
        * (1 - 1 / (1 + np.exp(-8 * (arg - 0.7))))
    )
    # Normals exist per segment, so the last point gets no body vertices.
    segment_widths = width_profile[:-1]

    def init():
        line.set_data([], [])
        body.set_xy(np.zeros((0, 2)))
        time_text.set_text("")
        return line, body, time_text

    def update(frame):
        xs, ys = x[frame], y[frame]
        line.set_data(xs, ys)

        # Unit normals: segment direction rotated by 90 degrees.
        dx = np.diff(xs)
        dy = np.diff(ys)
        lengths = np.sqrt(dx**2 + dy**2)
        # Zero-length segments would otherwise produce NaN vertices.
        safe_lengths = np.where(lengths > 0, lengths, 1.0)
        offset_x = segment_widths * (-dy / safe_lengths)
        offset_y = segment_widths * (dx / safe_lengths)

        # Polygon outline: one edge head-to-tail, the other tail-to-head.
        top = np.column_stack([xs[:-1] + offset_x, ys[:-1] + offset_y])
        bottom = np.column_stack([xs[:-1] - offset_x, ys[:-1] - offset_y])[::-1]
        body.set_xy(np.vstack([top, bottom]))
        time_text.set_text(f"Frame: {frame}")

        return line, body, time_text

    # Use a subset of frames if there are too many
    total_frames = min(num_frames, len(x))
    frame_indices = np.linspace(0, len(x) - 1, total_frames, dtype=int)

    anim = FuncAnimation(
        fig, update, frames=frame_indices, init_func=init, blit=True, interval=interval
    )
    if save_path is not None:
        anim.save(save_path, writer="ffmpeg")
    # Close this specific figure so the static frame is not also displayed.
    plt.close(fig)
    return HTML(anim.to_jshtml())


# Create and display the animation
# Reconstruct centerlines from the eigenworm coefficients, then animate the
# first 1000 time steps. The animation is also saved with the "ffmpeg" writer
# to a path relative to the notebook's working directory.
x, y = reconstruct_worm(worm_data, eigenworms)
worm_animation = animate_worm(x[:1000], y[:1000], save_path="../figures/wormanim.mp4")

worm_animation

In [None]:
# Window sizes for the eigenworm experiments (overwrites the earlier globals).
context_length = 512
prediction_length = 128

# First difference along time: diffed_worm_data[t] = worm_data[t + 1] - worm_data[t].
diffed_worm_data = np.diff(worm_data, axis=0)

In [None]:
# PFT forecast in difference space, integrated back to the original coordinates.
diff_prediction = forecast(
    pft_model,
    diffed_worm_data[:context_length],
    prediction_length,
    limit_prediction_length=False,
    sliding_context=True,
)
# The context holds diffs d[t] = x[t + 1] - x[t] for t < context_length, so the
# first predicted diff is d[context_length]. Integrating from x[context_length]
# therefore yields predictions for x[context_length + 1], x[context_length + 2], ...
pft_prediction = worm_data[context_length] + diff_prediction.cumsum(axis=0)

pft_metrics = defaultdict(list)
step = 8
# Ground truth must start one step after the anchor to line up with the
# integrated prediction (previously it started at context_length: off by one).
target_start = context_length + 1
# Metrics are cumulative: entry j scores the forecast prefix [0, (j + 1) * step).
for i in np.arange(0, prediction_length, step):
    submetrics = compute_metrics(
        pft_prediction[0 : i + step],
        worm_data[target_start : target_start + i + step],
        include=metrics,
    )
    for metric in metrics:
        pft_metrics[metric].append(submetrics[metric])
print(pft_metrics)

In [None]:
# Chronos forecast in difference space, integrated back to the original coordinates.
diff_prediction = forecast(
    chronos,
    diffed_worm_data[:context_length],
    prediction_length,
    transpose=True,
    limit_prediction_length=False,
    num_samples=1,
    deterministic=True,
)
# The first predicted diff is x[context_length + 1] - x[context_length], so the
# integrated prediction starts at time index context_length + 1.
chronos_prediction = worm_data[context_length] + diff_prediction.cumsum(axis=0)

chronos_metrics = defaultdict(list)
# Ground truth must start one step after the anchor to line up with the
# integrated prediction (previously it started at context_length: off by one).
target_start = context_length + 1
# Cumulative metrics; relies on `step` and `metrics` defined in earlier cells.
for i in np.arange(0, prediction_length, step):
    submetrics = compute_metrics(
        chronos_prediction[0 : i + step],
        worm_data[target_start : target_start + i + step],
        include=metrics,
    )
    for metric in metrics:
        chronos_metrics[metric].append(submetrics[metric])
print(chronos_metrics)

In [None]:
# 3D comparison of both forecasts against the worm ground truth.
# NOTE(review): the predictions are integrated diffs anchored at
# worm_data[context_length], so their first row corresponds to time index
# context_length + 1, while plot_comparison takes its ground truth from
# data[context_length:]; the curves are offset by one step.
plot_comparison(
    worm_data,
    pft_prediction,
    chronos_prediction,
    context_length,
    prediction_length,
    pft_metrics,
    chronos_metrics,
)

In [None]:
# sanity check
# These calls forecast on the raw (un-differenced) worm series, unlike the
# diff-space forecasts in the cells above. plot_model_prediction has no return
# statement, so `_` is None.
_ = plot_model_prediction(
    pft_model,
    worm_data,
    context_length=context_length,
    prediction_length=prediction_length,
    sliding_context=True,
    limit_prediction_length=False,
)
# NOTE(review): this Chronos call does not pass `deterministic=True`, unlike the
# Chronos forecasts in the metric cells — confirm whether that is intended.
_ = plot_model_prediction(
    chronos,
    worm_data,
    context_length=context_length,
    prediction_length=prediction_length,
    transpose=True,
    num_samples=1,
    limit_prediction_length=False,
    color="blue",
)

# Turbulent Boundary Layer

In [None]:
# Load the turbulent boundary layer data.
# SECURITY NOTE: allow_pickle=True unpickles the file, which can execute
# arbitrary code; only load files from a trusted source.
turbpca_data = np.load(
    f"{base_dir}/turbulence/BLexp_Re980_pca10.pkl", allow_pickle=True
)
print(turbpca_data.shape)

In [None]:
# Window sizes for the turbulence experiments (overwrites the earlier globals).
context_length = 512
prediction_length = 512

# First difference along time: diffed_turbpca_data[t] = turbpca_data[t + 1] - turbpca_data[t].
diffed_turbpca_data = np.diff(turbpca_data, axis=0)

In [None]:
# PFT forecast in difference space, integrated back to the original coordinates.
diff_prediction = forecast(
    pft_model,
    diffed_turbpca_data[:context_length],
    prediction_length,
    limit_prediction_length=False,
    sliding_context=True,
)
# The context holds diffs d[t] = x[t + 1] - x[t] for t < context_length, so the
# first predicted diff is d[context_length]. Integrating from x[context_length]
# therefore yields predictions for x[context_length + 1], x[context_length + 2], ...
pft_prediction = turbpca_data[context_length] + diff_prediction.cumsum(axis=0)

pft_metrics = defaultdict(list)
step = 8
# Ground truth must start one step after the anchor to line up with the
# integrated prediction (previously it started at context_length: off by one).
target_start = context_length + 1
# Metrics are cumulative: entry j scores the forecast prefix [0, (j + 1) * step).
for i in np.arange(0, prediction_length, step):
    submetrics = compute_metrics(
        pft_prediction[0 : i + step],
        turbpca_data[target_start : target_start + i + step],
        include=metrics,
    )
    for metric in metrics:
        pft_metrics[metric].append(submetrics[metric])
print(pft_metrics)

In [None]:
# Chronos forecast in difference space, integrated back to the original coordinates.
diff_prediction = forecast(
    chronos,
    diffed_turbpca_data[:context_length],
    prediction_length,
    transpose=True,
    limit_prediction_length=False,
    num_samples=1,
    deterministic=True,
)
# The first predicted diff is x[context_length + 1] - x[context_length], so the
# integrated prediction starts at time index context_length + 1.
chronos_prediction = turbpca_data[context_length] + diff_prediction.cumsum(axis=0)

chronos_metrics = defaultdict(list)
# Ground truth must start one step after the anchor to line up with the
# integrated prediction (previously it started at context_length: off by one).
target_start = context_length + 1
# Cumulative metrics; relies on `step` and `metrics` defined in earlier cells.
for i in np.arange(0, prediction_length, step):
    submetrics = compute_metrics(
        chronos_prediction[0 : i + step],
        turbpca_data[target_start : target_start + i + step],
        include=metrics,
    )
    for metric in metrics:
        chronos_metrics[metric].append(submetrics[metric])
print(chronos_metrics)

In [None]:
# 3D comparison of both forecasts against the turbulence ground truth (first
# three components only).
# NOTE(review): the predictions are integrated diffs whose first row corresponds
# to time index context_length + 1, while plot_comparison takes its ground truth
# from data[context_length:]; the curves are offset by one step.
plot_comparison(
    turbpca_data,
    pft_prediction,
    chronos_prediction,
    context_length,
    prediction_length,
    pft_metrics,
    chronos_metrics,
)

In [None]:
# sanity check
# Forecasts on the raw (un-differenced) series. The two positional 512s are
# context_length and prediction_length.
_ = plot_model_prediction(
    pft_model,
    turbpca_data,
    512,
    512,
    sliding_context=True,
    limit_prediction_length=False,
)
# NOTE(review): no `deterministic=True` here, unlike the Chronos metric cells.
_ = plot_model_prediction(
    chronos,
    turbpca_data,
    512,
    512,
    transpose=True,
    num_samples=1,
    limit_prediction_length=False,
    color="blue",
)

# Electronic Circuit

In [None]:
# Load one electronic-circuit recording together with a network structure file.
# NOTE: `net` is only used for the shape printout in this cell.
netfpath = f"{base_dir}/electronic_circuit/Structure/Net_1.dat"
fpath = f"{base_dir}/electronic_circuit/R1/ST_0_3.dat"
net = np.loadtxt(netfpath)
circuit_data = np.loadtxt(fpath)
print(net.shape, circuit_data.shape)

In [None]:
# Window sizes for the electronic-circuit experiments; these globals are also
# read by the coupling-strength scaling-law loop further down.
context_length = 512
prediction_length = 512

In [None]:
# PFT forecast directly on the raw circuit series (no differencing).
pft_prediction = forecast(
    pft_model,
    circuit_data[:context_length],
    prediction_length,
    limit_prediction_length=False,
    sliding_context=True,
)

# Metrics are cumulative: entry j scores the forecast prefix [0, (j + 1) * step).
pft_metrics = defaultdict(list)
# NOTE: `step` is a notebook-level global that later cells also read.
step = 8
for i in np.arange(0, prediction_length, step):
    submetrics = compute_metrics(
        pft_prediction[0 : i + step],
        circuit_data[context_length : context_length + i + step],
        include=metrics,
    )
    for metric in metrics:
        pft_metrics[metric].append(submetrics[metric])
print(pft_metrics)

In [None]:
# Chronos forecast on the same raw circuit context.
chronos_prediction = forecast(
    chronos,
    circuit_data[:context_length],
    prediction_length,
    transpose=True,
    limit_prediction_length=False,
    num_samples=1,
    deterministic=True,
)

# Cumulative metrics; relies on `step` and `metrics` defined in earlier cells.
chronos_metrics = defaultdict(list)
for i in np.arange(0, prediction_length, step):
    submetrics = compute_metrics(
        chronos_prediction[0 : i + step],
        circuit_data[context_length : context_length + i + step],
        include=metrics,
    )
    for metric in metrics:
        chronos_metrics[metric].append(submetrics[metric])
print(chronos_metrics)

In [None]:
# 3D comparison of both forecasts against the circuit ground truth (first three
# columns only).
plot_comparison(
    circuit_data,
    pft_prediction,
    chronos_prediction,
    context_length,
    prediction_length,
    pft_metrics,
    chronos_metrics,
)

In [None]:
# sanity check
# The two positional 512s are context_length and prediction_length.
_ = plot_model_prediction(
    pft_model,
    circuit_data,
    512,
    512,
    sliding_context=True,
    limit_prediction_length=False,
)
# NOTE(review): no `deterministic=True` here, unlike the Chronos metric cells.
_ = plot_model_prediction(
    chronos,
    circuit_data,
    512,
    512,
    transpose=True,
    num_samples=1,
    limit_prediction_length=False,
    color="blue",
)

### Coupling Strength Scaling Law

In [None]:
# classify data by type {1, 2, 3} and sort by coupling strength
# File names are split on "_": the first character of the third field is used
# as the type key and the integer second field as the sort key (e.g. "ST_0_3.dat"
# -> type 3, sort key 0).
# NOTE(review): any directory entry that does not follow this pattern will raise here.
fpaths = os.listdir(f"{base_dir}/electronic_circuit/R1")
ec_fpaths = defaultdict(list)
for fpath in fpaths:
    ec_fpaths[int(fpath.split("_")[2][0])].append(fpath)
for k, v in ec_fpaths.items():
    ec_fpaths[k] = sorted(v, key=lambda x: int(x.split("_")[1]))
print(ec_fpaths)

In [None]:
# Evaluate both models on every circuit recording at `n_steps` evenly spaced
# cumulative horizons: chunk_size, 2 * chunk_size, ..., n_steps * chunk_size.
# Relies on `context_length`, `prediction_length` and `metrics` from earlier cells.
n_steps = 8
chunk_size = prediction_length // n_steps

# errors[type][metric] has shape (n_steps, number of files of that type).
pft_errors = {
    k: {m: np.zeros((n_steps, len(v))) for m in metrics} for k, v in ec_fpaths.items()
}
chronos_errors = {
    k: {m: np.zeros((n_steps, len(v))) for m in metrics} for k, v in ec_fpaths.items()
}
# NOTE: this loop overwrites the globals `circuit_data`, `pft_prediction` and
# `chronos_prediction`, and leaves `k` bound to the last circuit type, which
# the heatmap cell below reads.
for k, v in ec_fpaths.items():
    for i, fpath in tqdm(
        enumerate(v), desc=f"Processing type-{k} circuit data", total=len(v)
    ):
        circuit_data = np.loadtxt(f"{base_dir}/electronic_circuit/R1/{fpath}")
        pft_prediction = forecast(
            pft_model,
            circuit_data[:context_length],
            prediction_length,
            limit_prediction_length=False,
            sliding_context=True,
        )
        chronos_prediction = forecast(
            chronos,
            circuit_data[:context_length],
            prediction_length,
            transpose=True,
            limit_prediction_length=False,
            num_samples=1,
            # Matches every other Chronos metric forecast in this notebook.
            deterministic=True,
        )
        # Iterate exactly n_steps chunks so the index never exceeds the
        # preallocated arrays, even if prediction_length % n_steps != 0.
        for chunk in range(n_steps):
            # Horizon uses the chunk size rather than the unrelated global `step`
            # left over from earlier cells, so it matches the y-axis labels
            # (chunk_size, 2 * chunk_size, ...) used when plotting these errors.
            horizon = (chunk + 1) * chunk_size
            target = circuit_data[context_length : context_length + horizon]
            pft_submetrics = compute_metrics(
                pft_prediction[0:horizon],
                target,
                include=metrics,
            )
            chronos_submetrics = compute_metrics(
                chronos_prediction[0:horizon],
                target,
                include=metrics,
            )
            for metric in metrics:
                pft_errors[k][metric][chunk, i] = pft_submetrics[metric]
                chronos_errors[k][metric][chunk, i] = chronos_submetrics[metric]



In [None]:
# Aggregate across circuit types: mean and standard deviation of the per-type
# error arrays, taken over the type axis.
# NOTE(review): stacking the per-type arrays requires every type to have the
# same number of files.
mean_pft_errors = {
    m: np.mean([pft_errors[k][m] for k in ec_fpaths], axis=0) for m in metrics
}
mean_chronos_errors = {
    m: np.mean([chronos_errors[k][m] for k in ec_fpaths], axis=0) for m in metrics
}
std_pft_errors = {
    m: np.std([pft_errors[k][m] for k in ec_fpaths], axis=0) for m in metrics
}
std_chronos_errors = {
    m: np.std([chronos_errors[k][m] for k in ec_fpaths], axis=0) for m in metrics
}

In [None]:
# sMAPE at the longest evaluated horizon (last row, index -1) versus file index,
# with a +/- one standard deviation band across circuit types.
# The x axis is the position in the per-type file list, which was sorted by the
# integer parsed from each file name.
plt.figure(figsize=(10, 4))
plt.plot(mean_pft_errors["smape"][-1], color="red")
plt.plot(mean_chronos_errors["smape"][-1], color="blue")
plt.fill_between(
    np.arange(len(mean_pft_errors["smape"][-1])),
    mean_pft_errors["smape"][-1] - std_pft_errors["smape"][-1],
    mean_pft_errors["smape"][-1] + std_pft_errors["smape"][-1],
    color="red",
    alpha=0.2,
)
plt.fill_between(
    np.arange(len(mean_chronos_errors["smape"][-1])),
    mean_chronos_errors["smape"][-1] - std_chronos_errors["smape"][-1],
    mean_chronos_errors["smape"][-1] + std_chronos_errors["smape"][-1],
    color="blue",
    alpha=0.2,
)
plt.xlabel("Coupling Strength");

In [None]:
# Heatmap of the relative sMAPE improvement of PFT over Chronos, by prediction
# length (rows) and file index (columns).
plt.figure(figsize=(15, 4))
# Backslashes are escaped explicitly: "\D" is an invalid escape sequence in a
# normal string literal. The rendered title text is unchanged.
plt.title("%$\\Delta$sMAPE ($\\uparrow$ is better)")
# NOTE(review): `k` is not defined in this cell; it is whatever value the last
# `for k, v in ec_fpaths.items()` loop left behind, so only that one circuit
# type is shown. The value is a fraction (not multiplied by 100).
percentage_error = (
    chronos_errors[k]["smape"] - pft_errors[k]["smape"]
) / chronos_errors[k]["smape"]
plt.imshow(percentage_error, cmap="seismic", label=f"Type-{k}", aspect="auto")
plt.ylabel("Prediction length")
# Row r is labelled with the horizon (r + 1) * (prediction_length // n_steps).
plt.yticks(
    np.arange(n_steps),
    np.arange(0, prediction_length, prediction_length // n_steps)
    + prediction_length // n_steps,
)
plt.xlabel("Coupling strength")
plt.show()

# ECG

In [None]:
# Load the ECG training data from a gzip-compressed, comma-separated file.
fpath = f"{base_dir}/electrocardiogram/ecg_train.csv.gz"
ecg_data = np.loadtxt(fpath, delimiter=",")
print(ecg_data.shape)

In [None]:
# Window sizes for the ECG experiment (overwrites the earlier globals).
context_length = 512
prediction_length = 512
# Offset and stride used to subsample the series; with start=0 and stride=1 the
# data is used as-is.
start = 0
stride = 1

subsampled_ecg_data = ecg_data[start::stride]

In [None]:
# PFT forecast directly on the ECG series (no differencing).
pft_prediction = forecast(
    pft_model,
    subsampled_ecg_data[:context_length],
    prediction_length,
    limit_prediction_length=False,
    sliding_context=True,
)

# Metrics are cumulative: entry j scores the forecast prefix [0, (j + 1) * step).
pft_metrics = defaultdict(list)
step = 8
for i in np.arange(0, prediction_length, step):
    submetrics = compute_metrics(
        pft_prediction[0 : i + step],
        subsampled_ecg_data[context_length : context_length + i + step],
        include=metrics,
    )
    for metric in metrics:
        pft_metrics[metric].append(submetrics[metric])
print(pft_metrics)

In [None]:
# Chronos forecast on the same ECG context.
chronos_prediction = forecast(
    chronos,
    subsampled_ecg_data[:context_length],
    prediction_length,
    transpose=True,
    limit_prediction_length=False,
    num_samples=1,
    deterministic=True,
)

# Cumulative metrics; relies on `step` and `metrics` defined in earlier cells.
chronos_metrics = defaultdict(list)
for i in np.arange(0, prediction_length, step):
    submetrics = compute_metrics(
        chronos_prediction[0 : i + step],
        subsampled_ecg_data[context_length : context_length + i + step],
        include=metrics,
    )
    for metric in metrics:
        chronos_metrics[metric].append(submetrics[metric])
print(chronos_metrics)

In [None]:
# sanity check
# Top panel: context, ground truth and both forecasts against time.
# Bottom panel: cumulative sMAPE of both models, aligned to the same time axis.
context_ts = np.arange(context_length + 1)
pred_ts = np.arange(context_length, context_length + prediction_length)

fig, axes = plt.subplots(2, 1, figsize=(15, 10), sharex=True)
plt.subplots_adjust(hspace=0.0)
# NOTE: labels are attached to the lines below, but no legend is drawn.
axes[0].plot(
    context_ts,
    subsampled_ecg_data[: context_length + 1],
    color="black",
    alpha=0.5,
    label="context",
)
axes[0].plot(
    pred_ts,
    subsampled_ecg_data[context_length : context_length + prediction_length],
    color="black",
    linestyle="--",
    label="groundtruth",
)
axes[0].plot(pred_ts, pft_prediction, color="red", label="prediction")
axes[0].plot(pred_ts, chronos_prediction, color="blue", label="chronos")

# Error time stamps run from context_length to context_length + prediction_length
# in increments of `step`; the leading point pairs with the zero prepended below.
error_ts = (
    np.arange(context_length - step, context_length + prediction_length, step) + step
)
# NOTE: `ticks` is not used; the same array is recomputed in set_xticks below.
ticks = np.arange(0, context_length + prediction_length + 128, 128)
# Prepend a zero so each error curve starts at 0 at the end of the context.
pft_with_zero = np.r_[0, pft_metrics["smape"]]
chronos_with_zero = np.r_[0, chronos_metrics["smape"]]
# Flat zero line over the context region.
axes[1].plot(np.arange(context_length), np.zeros(context_length), color="red")
axes[1].plot(error_ts, pft_with_zero, color="red")
axes[1].plot(error_ts, chronos_with_zero, color="blue")
axes[1].set_xlabel("Prediction length")
axes[1].set_ylabel("sMAPE")
axes[1].set_xticks(np.arange(0, context_length + prediction_length + 128, 128))
plt.show()
