In [None]:
# Re-import edited modules (e.g. the `panda` package) automatically before each cell runs,
# so changes to the library are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from panda.patchtst.pipeline import PatchTSTPipeline
from panda.utils.data_utils import load_trajectory_from_arrow, safe_standardize

In [None]:
# Shared work directory taken from $WORK; when the variable is unset this is the
# empty string and DATA_DIR degrades to the relative path "data".
WORK_DIR = os.environ.get("WORK", "")
DATA_DIR = os.path.join(WORK_DIR, "data")

In [None]:
def forecast(
    model,
    context: np.ndarray,
    prediction_length: int,
    transpose: bool = False,
    standardize: bool = True,
    differenced: bool = False,
    **kwargs,
) -> np.ndarray:
    """
    Forecast the continuation of ``context`` with ``model``.

    The context is optionally first-differenced and/or standardized before being
    passed to the model; the model output is then mapped back through the inverse
    of those transforms so it lives in the same space as ``context``. The caller's
    ``context`` array is never modified.

    Args:
        model: Forecasting pipeline exposing
            ``predict(context_tensor, prediction_length, verbose=..., **kwargs)``
            that returns a torch tensor.
        context: The context to forecast from, shape (n_timesteps, n_features).
        prediction_length: Number of future timesteps to return.
        transpose: If True, the model receives the context as
            (n_features, n_timesteps) and its (squeezed) output is transposed back
            before post-processing.
        standardize: Standardize the (possibly differenced) context along the time
            axis with ``safe_standardize`` before prediction, and de-standardize the
            prediction via ``safe_standardize(..., context=<that same series>,
            denormalize=True)``.
        differenced: Forecast first differences of the context and integrate the
            predicted increments back, anchored at the last context value.
        **kwargs: Forwarded unchanged to ``model.predict``.

    Returns:
        The forecasted data, shape (prediction_length, n_features), truncated to the
        first ``prediction_length`` steps because the model output may be longer.
        Note that ``.squeeze()`` drops every size-1 axis of the model output, so a
        single-feature forecast comes back 1-D.
    """
    # `reference` is the series the model actually forecasts (before scaling); it
    # doubles as the statistics source when undoing standardization. The model
    # gets a copy so neither `reference` nor the caller's `context` can be mutated.
    reference = np.diff(context, axis=0) if differenced else context
    model_input = reference.copy()
    if standardize:
        model_input = safe_standardize(model_input, axis=0)

    context_tensor = torch.from_numpy(model_input.T if transpose else model_input).float()
    pred = model.predict(context_tensor, prediction_length, verbose=False, **kwargs).squeeze().cpu().numpy()
    if transpose:
        pred = pred.T

    if standardize:
        pred = safe_standardize(pred, axis=0, context=reference, denormalize=True)
    if differenced:
        # Integrate the predicted increments, starting from the last observed state.
        pred = np.cumsum(pred, axis=0) + context[-1]

    # The model output may be longer than requested; slicing the leading (time)
    # axis handles both 1-D and 2-D predictions.
    return pred[:prediction_length]


def setup_3d_axes(ax_3d, scale: float = 0.8, elevation: float = 30, azimuth: float = 45):
    """Replace the default 3D axes decorations with three plain black axis lines.

    Turns off the grid and the built-in axes, draws x/y/z axis lines from an origin
    that is clamped so it includes (0, 0, 0), sets the camera angle, and resets the
    axes limits to frame the drawn axis lines. The axis length is derived from the
    current axes limits, so call this after the data has been plotted.

    Args:
        ax_3d: A matplotlib 3D axes (created with ``projection="3d"``).
        scale: Axis-line length as a fraction of the widest current limit span
            (values below 1 make the lines shorter than the data range).
        elevation: Camera elevation angle, passed to ``view_init(elev=...)``.
        azimuth: Camera azimuth angle, passed to ``view_init(azim=...)``.
    """
    ax_3d.grid(False)
    ax_3d.set_axis_off()

    # Current limits, reflecting whatever has been plotted so far.
    xmin, xmax = ax_3d.get_xlim()
    ymin, ymax = ax_3d.get_ylim()
    zmin, zmax = ax_3d.get_zlim()

    # Origin is the lower corner of the data box, pulled down to 0 where needed so
    # that (0, 0, 0) is always included.
    ox, oy, oz = min(0, xmin), min(0, ymin), min(0, zmin)
    axis_length = scale * max(xmax - xmin, ymax - ymin, zmax - zmin)

    # Each segment gets explicit two-point x, y and z coordinate lists so all three
    # coordinate sequences have matching lengths.
    segments = (
        ([ox, ox + axis_length], [oy, oy], [oz, oz]),  # x-axis
        ([ox, ox], [oy, oy + axis_length], [oz, oz]),  # y-axis
        ([ox, ox], [oy, oy], [oz, oz + axis_length]),  # z-axis
    )
    for xs, ys, zs in segments:
        ax_3d.plot(xs, ys, zs, "k-", lw=1.5)

    ax_3d.view_init(elev=elevation, azim=azimuth)

    # Frame the axis lines with a 20% margin past their tips.
    # NOTE(review): with scale < 1 this cube is not guaranteed to contain all data.
    margin = axis_length * 0.2
    ax_3d.set_xlim(ox, ox + axis_length + margin)
    ax_3d.set_ylim(oy, oy + axis_length + margin)
    ax_3d.set_zlim(oz, oz + axis_length + margin)


def plot_model_prediction(
    model,
    trajectory: np.ndarray,
    context_length: int,
    prediction_length: int,
    title: str | None = None,
    save_path: str | None = None,
    **kwargs,
):
    """Forecast a trajectory from its leading ``context_length`` steps and plot the result.

    The top panel shows the first three coordinates in 3D (context, ground truth and
    prediction); the three panels below show each of those coordinates against
    normalized time (context steps followed by the forecast horizon).

    Args:
        model: Pipeline with a ``device`` attribute and a
            ``predict(context_tensor, prediction_length, **kwargs)`` method returning
            a torch tensor.
        trajectory: Array of shape (n_dims, n_timesteps) with at least 3 dims.
        context_length: Number of leading timesteps used as model context.
        prediction_length: Number of timesteps to forecast.
        title: Optional figure title; underscores are rendered as spaces.
        save_path: Optional output file; its parent directory is created if it
            does not exist. A bare filename saves to the current directory.
        **kwargs: Forwarded unchanged to ``model.predict``.
    """
    context = trajectory[:, :context_length]
    groundtruth = trajectory[:, context_length : context_length + prediction_length]

    # Feed the model a batch of one: (1, n_timesteps, n_dims).
    context_tensor = torch.from_numpy(context.T).float().to(model.device)[None, ...]
    pred = model.predict(context_tensor, prediction_length, **kwargs).squeeze().cpu().numpy()
    # Guard against model outputs longer than the requested horizon.
    pred = pred[:prediction_length]

    n_context = context.shape[1]
    total_length = n_context + prediction_length
    context_ts = np.arange(n_context) / total_length
    horizon_ts = np.arange(n_context, total_length) / total_length
    # Ground truth is shorter than the horizon when the trajectory ends early, and
    # the prediction may also be shorter; give each its own matching time axis.
    groundtruth_ts = horizon_ts[: groundtruth.shape[1]]
    pred_ts = horizon_ts[: pred.shape[0]]

    fig = plt.figure(figsize=(6, 8))

    # Top row: 3D trajectory; bottom row: one time-series panel per coordinate.
    # The negative hspace pulls the time-series panels up under the 3D plot.
    outer_grid = fig.add_gridspec(2, 1, height_ratios=[0.65, 0.35], hspace=-0.2)
    gs = outer_grid[1].subgridspec(3, 1, height_ratios=[0.2] * 3, wspace=0, hspace=0)
    ax_3d = fig.add_subplot(outer_grid[0], projection="3d")

    ax_3d.plot(*context[:3], alpha=0.5, color="black", label="Context")
    ax_3d.plot(*groundtruth[:3], linestyle="--", color="black", label="Groundtruth")
    ax_3d.plot(*pred.T[:3], color="red", label="Prediction")
    # Labels are kept for when the axes are re-enabled; set_axis_off hides them here.
    ax_3d.set_xlabel("$x_1$")
    ax_3d.set_ylabel("$x_2$")
    ax_3d.set_zlabel("$x_3$")  # type: ignore
    ax_3d.grid(False)
    ax_3d.set_axis_off()

    if title is not None:
        title_name = title.replace("_", " ")
        ax_3d.set_title(title_name, fontweight="bold")

    axes_1d = [fig.add_subplot(gs[i, 0]) for i in range(3)]
    for i, ax in enumerate(axes_1d):
        ax.plot(context_ts, context[i], alpha=0.5, color="black")
        ax.plot(groundtruth_ts, groundtruth[i], linestyle="--", color="black")
        ax.plot(pred_ts, pred[:, i], color="red")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect("auto")

    if save_path is not None:
        save_dir = os.path.dirname(save_path)
        # os.makedirs("") raises FileNotFoundError, so skip it for bare filenames.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, bbox_inches="tight")
    plt.show()
    plt.close(fig)

In [None]:
# Load the pretrained PatchTST checkpoint for prediction. Falls back to CPU when
# CUDA is unavailable so the notebook still runs (slowly) without a GPU.
# NOTE(review): assumes PatchTSTPipeline.from_pretrained accepts "cpu" as device_map.
CHECKPOINT_ROOT = "/stor/work/AMDG_Gilpin_Summer2024/checkpoints"
device_rank = 0
run_name = "pft_chattn_emb_w_poly-0"
# Alternative run to compare against: "pft_chattn_noembed_pretrained_correct-0"
device_map = f"cuda:{device_rank}" if torch.cuda.is_available() else "cpu"
pipeline = PatchTSTPipeline.from_pretrained(
    mode="predict",
    pretrain_path=f"{CHECKPOINT_ROOT}/{run_name}/checkpoint-final",
    device_map=device_map,
    torch_dtype=torch.float32,
)

In [None]:
# Pick one system from the zero-shot test split at random; the generator is seeded,
# so the same file is chosen on every run.
work_dir = os.environ.get("WORK", "/stor/work/AMDG_Gilpin_Summer2024")
base_dir = Path(work_dir).joinpath("data", "improved", "final_skew40", "test_zeroshot_z5_z10")

# One Arrow file per trajectory, grouped in one sub-directory per system.
arrow_paths = sorted(base_dir.glob("*/*.arrow"))
assert arrow_paths, f"No Arrow files found under {base_dir}"

rng = np.random.default_rng(42)
chosen_index = int(rng.integers(len(arrow_paths)))
system_path = arrow_paths[chosen_index]

# NOTE(review): one_dim_target=False presumably keeps all coordinates; the plotting
# code below indexes trajectory as (n_dims, n_timesteps).
trajectory, metadata = load_trajectory_from_arrow(system_path, one_dim_target=False)

In [None]:
# The system name is the Arrow file's parent directory name; it is printed and
# also used as the figure title.
system_name = system_path.parent.stem
context_length = 512
prediction_length = 256

print(f"{system_name=}")
print(f"{context_length=}")
print(f"{prediction_length=}")

# limit_prediction_length=False is forwarded to pipeline.predict via **kwargs.
plot_model_prediction(
    pipeline,
    trajectory=trajectory,
    context_length=context_length,
    prediction_length=prediction_length,
    title=system_name,
    limit_prediction_length=False,
)