In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d.proj3d import proj_transform

from dystformer.patchtst.pipeline import PatchTSTPipeline
from dystformer.utils import (
    apply_custom_style,
    get_system_filepaths,
    load_trajectory_from_arrow,
)

In [None]:
# Apply the project's matplotlib style, configured by the shared plotting YAML.
# NOTE(review): the path is relative, so it presumably resolves against the
# notebook's working directory -- confirm when running from elsewhere.
apply_custom_style("../config/plotting.yaml")

In [None]:
# Checkpoint selection: keep exactly one `run_name` assignment active. The
# commented alternatives are other trained runs kept for quick switching.
# run_name = "pft_stand_rff_only_pretrained-0"
run_name = "pft_linattnpolyemb_from_scratch-0"
# run_name = "pft_chattn_emb_w_poly-0"  # NOTE: this is still the best
# run_name = "pft_chattn_fullemb_quartic_enc-0"
# run_name = "pft_chattn_noembed_pretrained_correct-0"

# Final checkpoint directory of the selected run (cluster-specific path).
pft_checkpoint_path = f"/stor/work/AMDG_Gilpin_Summer2024/checkpoints/{run_name}/checkpoint-final"

# Load the pipeline in prediction mode onto the first CUDA device.
pft_model = PatchTSTPipeline.from_pretrained(
    mode="predict",
    pretrain_path=pft_checkpoint_path,
    device_map="cuda:0",
)

In [None]:
# Display the first encoder layer of the loaded model for a quick look at its
# structure (left as a bare expression so the notebook renders its repr).
pft_model.model.model.encoder.layers[0]

In [None]:
class Arrow3D(FancyArrowPatch):
    """Arrow patch whose two endpoints are specified in 3D data coordinates.

    The 3D endpoints are stored on the instance and projected to 2D with the
    owning axes' projection matrix whenever the arrow is drawn or asked for
    its 3D projection.
    """

    def __init__(self, xs, ys, zs, *args, **kwargs):
        """Store the 3D endpoints; remaining arguments go to ``FancyArrowPatch``.

        ``xs``, ``ys`` and ``zs`` are each indexed at ``[0]`` (tail) and
        ``[1]`` (head) after projection.
        """
        # The parent class needs 2D endpoints up front; these placeholders are
        # overwritten with the projected coordinates before every draw.
        super().__init__((0, 0), (0, 0), *args, **kwargs)
        self._verts3d = xs, ys, zs

    def _project_endpoints(self):
        """Project the stored endpoints to 2D, update the patch, return depths."""
        x3d, y3d, z3d = self._verts3d
        proj_x, proj_y, proj_z = proj_transform(x3d, y3d, z3d, self.axes.get_proj())
        self.set_positions((proj_x[0], proj_y[0]), (proj_x[1], proj_y[1]))
        return proj_z

    def draw(self, renderer):
        """Refresh the 2D endpoints from the current view, then draw the arrow."""
        self._project_endpoints()
        super().draw(renderer)

    def do_3d_projection(self):
        """Refresh the 2D endpoints and return the smallest projected depth.

        NOTE(review): presumably consumed by the 3D axes to order artists by
        depth -- confirm against the installed matplotlib version.
        """
        return np.min(self._project_endpoints())

In [None]:
def make_clean_projection(ax_3d):
    """Strip all decoration from a 3D axes so only the plotted artists remain.

    Disables the grid, sets a white face color, removes the ticks on the x, y
    and z axes, and finally switches the axes frame off. ``ax_3d`` is modified
    in place; nothing is returned.
    """
    ax_3d.grid(False)
    ax_3d.set_facecolor("white")
    # Same treatment for every axis: an empty tick list.
    for clear_ticks in (ax_3d.set_xticks, ax_3d.set_yticks, ax_3d.set_zticks):
        clear_ticks([])
    ax_3d.axis("off")

In [None]:
def make_arrow_axes(ax_3d, label_pad: float = 0.03):
    """Replace the default 3D axes frame with three labelled arrows.

    The default grid, ticks, panes and frame are hidden. Arrows labelled
    "X", "Y" and "Z" are then drawn from the lower corner of the current axis
    limits along the three frame edges, and the view is set to
    ``elev=30, azim=30``. ``ax_3d`` is modified in place.

    Args:
        ax_3d: 3D matplotlib axes. Its limits are read when this function is
            called, so the data should already be plotted.
        label_pad: How far past each arrow tip the label is placed, expressed
            as a fraction of that arrow's length.
    """
    ax_3d.grid(False)
    ax_3d.set_facecolor("white")
    ax_3d.set_xticks([])
    ax_3d.set_yticks([])
    ax_3d.set_zticks([])
    ax_3d.axis("off")

    # Current data limits define both the arrow extents and the box aspect.
    x0, x1 = ax_3d.get_xlim3d()
    y0, y1 = ax_3d.get_ylim3d()
    z0, z1 = ax_3d.get_zlim3d()

    # Scale the box to the data ranges so all three axes share one unit length.
    ax_3d.set_box_aspect((x1 - x0, y1 - y0, z1 - z0))

    # One arrow per frame edge, all starting from the lower corner.
    origin = (x0, y0, z0)
    edges = [
        (origin, (x1, y0, z0), "X"),
        (origin, (x0, y1, z0), "Y"),
        (origin, (x0, y0, z1), "Z"),
    ]

    for (xs, ys, zs), (xe, ye, ze), label in edges:
        arr = Arrow3D(
            [xs, xe],
            [ys, ye],
            [zs, ze],
            mutation_scale=20,
            lw=1.5,
            arrowstyle="-|>",
            color="black",
        )
        ax_3d.add_artist(arr)
        # Offset the label along the arrow direction. Scaling the raw tip
        # coordinates instead is only correct when the lower limits are zero:
        # with offset or negative limits the label drifts away from the tip.
        ax_3d.text(
            xe + label_pad * (xe - xs),
            ye + label_pad * (ye - ys),
            ze + label_pad * (ze - zs),
            label,
            fontsize=12,
        )

    # Hide the default frame and ticks
    for pane in (ax_3d.xaxis.pane, ax_3d.yaxis.pane, ax_3d.zaxis.pane):
        pane.set_visible(False)
    ax_3d.view_init(elev=30, azim=30)

In [None]:
def plot_model_prediction(
    model,
    context: np.ndarray,
    groundtruth: np.ndarray,
    prediction_length: int,
    title: str | None = None,
    save_path: str | None = None,
    show_plot: bool = True,
    **kwargs,
):
    """Forecast from ``context`` with ``model`` and plot it against the ground truth.

    The figure has a 3D phase-space panel on top (first three dimensions) and
    three stacked per-dimension time-series panels underneath. Context is
    drawn in translucent black, ground truth in black and the prediction in
    red.

    Args:
        model: Object exposing ``device`` and
            ``predict(context_tensor, prediction_length, **kwargs)``. The
            result is squeezed and indexed as ``(time, dimension)``.
            NOTE(review): assumes squeezing leaves exactly those two axes --
            confirm for multi-sample predictions.
        context: Array indexed ``(dimension, time)`` with at least three
            dimensions; it is transposed and given a leading batch axis
            before being passed to the model.
        groundtruth: Array indexed ``(dimension, time)`` covering the
            forecast horizon; it is plotted against ``prediction_length``
            time stamps.
        prediction_length: Number of steps to forecast.
        title: Optional title for the 3D panel; underscores become spaces.
        save_path: If given, the figure is written there and missing parent
            directories are created.
        show_plot: Whether to display the figure.
        **kwargs: Forwarded unchanged to ``model.predict``.
    """
    # Use the device of the model that was passed in rather than a notebook
    # global, so the function works with any model instance.
    context_tensor = torch.from_numpy(context.T).float().to(model.device)[None, ...]
    pred = (
        model.predict(context_tensor, prediction_length, **kwargs)
        .squeeze()
        .cpu()
        .numpy()
    )

    # Normalised time stamps: context and forecast share one [0, 1) axis.
    total_length = context.shape[1] + prediction_length
    context_ts = np.arange(context.shape[1]) / total_length
    pred_ts = np.arange(context.shape[1], total_length) / total_length

    # Create figure with gridspec layout
    fig = plt.figure(figsize=(6, 8))

    # Two stacked regions: the 3D panel on top, the time series below. The
    # negative hspace lets them overlap slightly.
    outer_grid = fig.add_gridspec(2, 1, height_ratios=[0.65, 0.35], hspace=-0.1)

    # Create sub-grid for the plots
    gs = outer_grid[1].subgridspec(3, 1, height_ratios=[0.2] * 3, wspace=0, hspace=0)
    ax_3d = fig.add_subplot(outer_grid[0], projection="3d")

    ax_3d.plot(*context[:3], alpha=0.5, color="black", label="Context")
    ax_3d.plot(*groundtruth[:3], linestyle="-", color="black", label="Groundtruth")
    ax_3d.plot(*pred.T[:3], color="red", label="Prediction")
    # make_arrow_axes(ax_3d)
    make_clean_projection(ax_3d)

    if title is not None:
        title_name = title.replace("_", " ")
        ax_3d.set_title(title_name, fontweight="bold")

    axes_1d = [fig.add_subplot(gs[i, 0]) for i in range(3)]
    for i, ax in enumerate(axes_1d):
        ax.plot(context_ts, context[i], alpha=0.5, color="black")
        ax.plot(pred_ts, groundtruth[i], linestyle="-", color="black")
        ax.plot(pred_ts, pred[:, i], color="red")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect("auto")

    if save_path is not None:
        # A bare filename has an empty dirname, which os.makedirs rejects.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        print(f"saving fig to: {save_path}")
        fig.savefig(save_path, bbox_inches="tight")
    if show_plot:
        plt.show()
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

In [None]:
# Hand-picked forecast windows to plot, keyed by system name.
# Each value is (sample_idx, start_time, prediction_length): which sample file
# to load, where the context window starts, and how many steps to forecast.
# The start times were chosen assuming context_length = 512.
# Commented-out entries are systems currently excluded from the run.
chosen_forecast_settings = {
    # "KawczynskiStrizhak_HyperXu": (1, 1260, 128),
    "ForcedBrusselator_SprottB": (2, 770, 128),
    "SaltonSea_Duffing": (0, 1999, 128),
    "Bouali_Laser": (10, 55, 128),
    # "QiChen": (0, 980, 256),
    "Lorenz96_SprottB": (10, 1999, 512),
    "LorenzStenflo_IsothermalChemical": (0, 322, 256),
    "Coullet_GenesioTesi": (0, 300, 128),
    "SprottK_Duffing": (6, 815, 256),
}

In [None]:
# Dataset split and sub-split to evaluate on. Both names are also used as
# components of the figure save path.
split = "final_skew40"
subsplit = "test_zeroshot"
# Root data directory of the chosen split (cluster-specific absolute path).
test_data_dirs = f"/stor/work/AMDG_Gilpin_Summer2024/data/improved/{split}"

In [None]:
# Number of time steps handed to the model as context for every forecast.
context_length = 512

for dyst_name, (
    sample_idx,
    start_time,
    pred_length,
) in chosen_forecast_settings.items():
    # Load the selected sample's trajectory for this system.
    syspaths = get_system_filepaths(dyst_name, test_data_dirs, subsplit)
    trajectory, _ = load_trajectory_from_arrow(syspaths[sample_idx])
    # Optional 2x temporal subsampling (disabled):
    # trajectory = trajectory[:, ::2]

    # The context covers [start_time, end_time); the forecast horizon follows.
    end_time = start_time + context_length
    context_window = trajectory[:, start_time:end_time]
    groundtruth_window = trajectory[:, end_time : end_time + pred_length]

    # Figures are grouped by run / split / subsplit / system.
    figure_name = f"{dyst_name}_sample{sample_idx}_context{start_time}-{end_time}_pred{pred_length}_.pdf"
    save_path = os.path.join(
        "../figures", run_name, split, subsplit, dyst_name, figure_name
    )

    plot_model_prediction(
        pft_model,
        context_window,
        groundtruth_window,
        pred_length,
        limit_prediction_length=False,
        sliding_context=True,
        save_path=save_path,
        show_plot=True,
    )