In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

from dystformer.patchtst.pipeline import PatchTSTPipeline

In [None]:
# Layer the project's custom matplotlib style on top of ggplot, but only when the
# style file is reachable from the notebook's working directory.
_style_file = "../custom_style.mplstyle"
if os.path.exists(_style_file):
    plt.style.use(["ggplot", _style_file])

In [None]:
# Checkpoint to analyse.
# NOTE: checkpoints trained before polynomial features were added only load if the
# final projection layer in PatchTSTKernelEmbedding is commented out (backwards
# compatibility), e.g.
# run_name = "pft_stand_rff_only_pretrained-0"
run_name = "pft_chattn_emb_w_poly-0"
pft_model = PatchTSTPipeline.from_pretrained(
    pretrain_path=f"/stor/work/AMDG_Gilpin_Summer2024/checkpoints/{run_name}/checkpoint-final",
    mode="predict",
    device_map="cuda:2",  # hard-coded GPU index; adjust for the local machine
)

In [None]:
def get_attn_weights(model, key: str) -> list[dict[str, torch.Tensor]]:
    """Collect the query/key/value projection weights of every encoder layer.

    Args:
        model: pipeline whose network is reachable as ``model.model.model`` and
            exposes ``encoder.layers``.
        key: attribute name of the attention submodule on each layer, e.g.
            ``"temporal_self_attn"`` or ``"channel_self_attn"``.

    Returns:
        One dict per encoder layer, in layer order, mapping ``"Wq"``, ``"Wk"``
        and ``"Wv"`` to the ``weight`` of ``q_proj``, ``k_proj`` and ``v_proj``.
    """
    encoder_layers = model.model.model.encoder.layers
    weights_per_layer = []
    for layer in encoder_layers:
        attention = getattr(layer, key)
        weights_per_layer.append(
            {
                "Wq": attention.q_proj.weight,
                "Wk": attention.k_proj.weight,
                "Wv": attention.v_proj.weight,
            }
        )
    return weights_per_layer


def get_attn_map(
    weights: list[dict[str, torch.Tensor]], index: int, shift: bool = False
) -> np.ndarray:
    """Return the query-key interaction matrix of one attention layer.

    The projection weights are presumably stored in ``nn.Linear`` layout
    ``(out_features, in_features)`` (TODO confirm for the dystformer attention
    modules). Queries and keys are then ``q = Wq @ x`` and ``k = Wk @ x'``, and
    the pre-softmax score summed over heads is ``x^T (Wq^T @ Wk) x'``. The
    bilinear form acting on inputs is therefore ``Wq^T @ Wk``. By contrast,
    ``Wq @ Wk^T`` contracts the input dimension and mixes output rows across
    heads.

    Args:
        weights: per-layer dicts with ``"Wq"`` and ``"Wk"`` tensors, as
            returned by ``get_attn_weights``.
        index: which layer to use.
        shift: if True, min-max rescale the map to ``[0, 1]``. A constant map
            becomes all zeros instead of NaNs.

    Returns:
        ``(in_features, in_features)`` numpy array.
    """
    layer = weights[index]
    with torch.no_grad():
        attn_map = (layer["Wq"].T @ layer["Wk"]).detach().cpu().numpy()
    if shift:
        lo, hi = np.min(attn_map), np.max(attn_map)
        span = hi - lo
        # Guard the 0/0 that a constant map would otherwise produce.
        attn_map = (attn_map - lo) / span if span > 0 else np.zeros_like(attn_map)
    return attn_map


def symmetric_distance(attn_map: np.ndarray) -> float:
    """Relative size of the antisymmetric part of a square matrix.

    Computes ``0.5 * ||A - A^T||_F / ||A||_F``. The result is 0 for a symmetric
    matrix and 1 for an antisymmetric one.
    """
    antisymmetric_norm = np.linalg.norm(attn_map - attn_map.T, "fro")
    total_norm = np.linalg.norm(attn_map, "fro")
    return 0.5 * antisymmetric_norm / total_norm  # type: ignore

In [None]:
# Q/K/V projection weights of every encoder layer, one list per attention type.
temporal_weights, channel_weights = (
    get_attn_weights(pft_model, attn_key)
    for attn_key in ("temporal_self_attn", "channel_self_attn")
)

In [None]:
# attn_map = get_attn_map(temporal_weights, 0)
# print(symmetric_distance(attn_map))
# plt.figure()
# plt.imshow(np.log(attn_map**2), cmap="RdBu")
# plt.colorbar()
# plt.show()

In [None]:
# attn_map = get_attn_map(channel_weights, 0)
# print(symmetric_distance(attn_map))
# plt.figure()
# plt.imshow(np.log(attn_map**2), cmap="RdBu")
# plt.colorbar()
# plt.show()

In [None]:
# llayer = pft_model.model.model.encoder.layers[0].ff
# print(llayer)
# ffw = llayer[0].weight.detach().cpu().numpy()
# print(symmetric_distance(ffw))

# U, S, V = np.linalg.svd(ffw)
# threshold = 1e-3
# rank = np.sum(S > threshold)
# plt.figure()
# plt.plot(range(1, len(S) + 1), S, "o-", linewidth=2)
# plt.title("Scree Plot of Singular Values")
# plt.xlabel("Singular Value Index")
# plt.ylabel("Singular Value Magnitude")
# plt.grid(True)
# plt.yscale("log")  # Log scale to better visualize the decay
# plt.axhline(
#     y=threshold, color="r", linestyle="--", label=f"Threshold ({threshold:.1e})"
# )
# plt.legend()
# plt.show()

# reconstructed = U[:, :rank] @ np.diag(S)[:rank, :rank] @ V[:rank, :]
# plt.figure()
# plt.imshow(np.log(reconstructed**2), cmap="RdBu")
# plt.colorbar()
# plt.show()

In [None]:
# fig, axes = plt.subplots(2, 4, figsize=(20, 10))
# for i, ax in enumerate(axes.flatten()):
#     attn_map = get_attn_map(temporal_weights, i)
#     ax.imshow(attn_map, cmap="RdBu")
#     ax.set_title(f"Layer {i}")
# plt.tight_layout()
# plt.show()

In [None]:
def plot_attn_map(
    model,
    context: np.ndarray,
    patch_size: int,
    sample_idx: int,
    layer_idx: int,
    head_idx: int,
    prefix: str = "",
    colormap: str = "magma",
    show_colorbar: bool = True,
    show_title: bool = True,
    save_path: str | None = None,
) -> None:
    """Plot an attention matrix with the corresponding timeseries patches along its edges.

    Args:
        model: network called as ``model(x, output_attentions=True)`` whose
            output has an ``.attentions`` sequence indexed by layer. Its device
            is taken from its parameters (assumes a ``torch.nn.Module``).
        context: ``(num_channels, context_length)`` array; ``context_length``
            must be a multiple of ``patch_size``.
        patch_size: number of timesteps per patch.
        sample_idx: index into the leading axis of the layer's attention tensor.
            This is a channel for temporal attention and a patch for channel
            attention.
        layer_idx: index into ``.attentions``. Even indices are treated as
            temporal attention, odd indices as channel attention.
        head_idx: attention head to plot.
        prefix: text prepended to the title.
        colormap: matplotlib colormap name for the attention matrix.
        show_colorbar: whether to draw a colorbar to the right of the matrix.
        show_title: whether to draw a title above the top patch strip.
        save_path: if given, the figure is also written to this path.
    """
    attention_type = "temporal" if layer_idx % 2 == 0 else "channel"
    # (num_channels, num_patches, patch_size). For channel attention the matrix
    # rows are channels, so index by patch first instead.
    patches = context.reshape(context.shape[0], -1, patch_size)
    if attention_type == "channel":
        patches = patches.transpose(1, 0, 2)

    # Use the device of the model being analysed, not a notebook-global pipeline.
    device = next(model.parameters()).device
    context_tensor = torch.from_numpy(context.T).float().to(device)[None, ...]
    with torch.no_grad():  # inference only; no autograd graph needed
        attn_weights = model(context_tensor, output_attentions=True).attentions

    num_samples = attn_weights[layer_idx].shape[0]
    attn = attn_weights[layer_idx][sample_idx, head_idx].detach().cpu().numpy()
    n_patches = attn.shape[0]

    fig = plt.figure(figsize=(10, 10))
    # Outer grid: plot area plus a narrow column reserved for the colorbar.
    outer_grid = fig.add_gridspec(1, 2, width_ratios=[1, 0.05], wspace=0.05)
    # Inner 2x2 grid: patch strips on top and left, attention matrix bottom-right.
    # The top-left cell is intentionally left empty.
    gs = outer_grid[0].subgridspec(
        2, 2, width_ratios=[0.15, 0.85], height_ratios=[0.15, 0.85], wspace=0, hspace=0
    )

    ax_main = fig.add_subplot(gs[1, 1])
    im = ax_main.imshow(attn, extent=(0, n_patches, n_patches, 0), cmap=colormap)
    ax_main.set_xticks([])
    ax_main.set_yticks([])

    linewidth = 2
    # Top strip: patch i spans x in [i, i + 1], aligned with column i of the matrix.
    ax_top = fig.add_subplot(gs[0, 1])
    for i in range(n_patches):
        x = np.linspace(i, i + 1, patch_size)
        ax_top.plot(x, patches[sample_idx, i], linewidth=linewidth)
    ax_top.set_xlim(0, n_patches)
    ax_top.set_xticks([])
    ax_top.set_yticks([])
    ax_top.grid(True)

    # Left strip: patch i spans y in [i, i + 1] (top to bottom), aligned with row i.
    # Values are negated, i.e. the strip is mirrored horizontally.
    ax_left = fig.add_subplot(gs[1, 0])
    for i in range(n_patches):
        y = np.linspace(i, i + 1, patch_size)
        ax_left.plot(-patches[sample_idx, i], y, linewidth=linewidth)
    ax_left.set_ylim(n_patches, 0)
    ax_left.set_xticks([])
    ax_left.set_yticks([])
    ax_left.grid(True)

    ax_cbar = None
    if show_colorbar:
        ax_cbar = fig.add_subplot(outer_grid[1])
        fig.colorbar(im, cax=ax_cbar)

    # Force exact alignment of the strips (and colorbar) with the matrix.
    main_pos = ax_main.get_position()
    top_pos = ax_top.get_position()
    ax_top.set_position(
        [main_pos.x0, main_pos.y1, main_pos.width, top_pos.height]  # type: ignore
    )
    left_pos = ax_left.get_position()
    ax_left.set_position(
        [left_pos.x0, main_pos.y0, left_pos.width, main_pos.height]  # type: ignore
    )
    if ax_cbar is not None:
        cbar_pos = ax_cbar.get_position()
        ax_cbar.set_position(
            [cbar_pos.x0, main_pos.y0, cbar_pos.width, main_pos.height]  # type: ignore
        )

    if show_title:
        sample_type = "channel" if attention_type == "temporal" else "patch"
        ax_top.set_title(
            f"{prefix} {attention_type} attention @ layer {layer_idx}, head {head_idx}, ({sample_type} {sample_idx + 1}/{num_samples})"
        )
    # Save (optionally), then show and release the figure, same as
    # plot_model_prediction, so repeated calls do not accumulate open figures.
    if save_path is not None:
        fig.savefig(save_path, bbox_inches="tight")
    plt.show()
    plt.close(fig)

In [None]:
def get_attn_map_from_spec(
    model,
    context: np.ndarray,
    patch_size: int,
    sample_idx: int,
    layer_idx: int,
    head_idx: int,
) -> np.ndarray:
    """Run ``model`` on ``context`` and return one attention matrix.

    Args:
        model: network called as ``model(x, output_attentions=True)`` whose
            output has an ``.attentions`` sequence indexed by layer. Its device
            is taken from its parameters (assumes a ``torch.nn.Module``).
        context: ``(num_channels, context_length)`` array. It is transposed to
            ``(1, context_length, num_channels)`` before the forward pass.
        patch_size: only used to check that ``context_length`` splits into
            whole patches (kept for signature parity with ``plot_attn_map``).
        sample_idx: index into the leading axis of the layer's attention tensor.
        layer_idx: index into ``.attentions`` (even = temporal, odd = channel
            attention in this notebook's convention).
        head_idx: attention head.

    Returns:
        The selected 2D attention matrix as a numpy array.

    Raises:
        ValueError: if ``context_length`` is not a multiple of ``patch_size``.
    """
    if context.shape[-1] % patch_size != 0:
        raise ValueError(
            f"context length {context.shape[-1]} is not a multiple of patch_size {patch_size}"
        )

    device = next(model.parameters()).device
    context_tensor = torch.from_numpy(context.T).float().to(device)[None, ...]
    with torch.no_grad():  # inference only; no autograd graph needed
        attentions = model(context_tensor, output_attentions=True).attentions
    return attentions[layer_idx][sample_idx, head_idx].detach().cpu().numpy()

In [None]:
from dystformer.utils import get_system_filepaths, load_trajectory_from_arrow

# Load one training trajectory of a reference attractor.
dyst_name = "Lorenz"
test_data_dirs = "/stor/work/AMDG_Gilpin_Summer2024/data/final_base40"
sample_idx = 0

syspaths = get_system_filepaths(dyst_name, test_data_dirs, "train")
trajectory, _ = load_trajectory_from_arrow(syspaths[sample_idx])

In [None]:
# Shape of the loaded trajectory. Later cells slice trajectories as
# trajectory[:, start:end], i.e. presumably (channels, time) -- confirm here.
trajectory.shape

In [None]:
# Synthetic reference signals, each of shape (3, 4096):
#   test_system_periodic -- sin(2t) on all three coordinates (a degenerate,
#                           purely periodic trajectory along the diagonal)
#   test_system_fourier  -- per coordinate, a sum of 10 sinusoids with random
#                           frequencies in [0, 2*pi) and random phases
#   test_system_noise    -- three independent Gaussian random walks
# A seeded generator keeps the random signals reproducible across kernel restarts.
rng = np.random.default_rng(0)
t = np.linspace(0, 10 * np.pi, 4096)

test_system_periodic = np.array(
    [
        np.sin(2 * t),  # x-coordinate
        np.sin(2 * t),  # y-coordinate
        np.sin(2 * t),  # z-coordinate
        # alternatives for distinct coordinates: np.sin(2 * t), np.sin(3 * t)
    ]
)

n_dims, n_modes = 3, 10
# Shapes (n_dims, n_modes, 1) broadcast against t -> (n_dims, n_modes, len(t)).
freqs = rng.random((n_dims, n_modes, 1)) * 2 * np.pi
phases = rng.random((n_dims, n_modes, 1)) * 2 * np.pi
test_system_fourier = np.sin(freqs * t + phases).sum(axis=1)

# Cumulative sum of 4097 Gaussian steps, truncated to 4096 samples.
test_system_noise = rng.standard_normal((3, 4097)).cumsum(axis=1)[:, :-1]

print(test_system_periodic.shape)
print(test_system_noise.shape)
print(test_system_fourier.shape)

In [None]:
# 3D phase portrait of the synthetic random walk.
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(projection="3d")
x1, x2, x3 = test_system_noise[:3]
ax.plot3D(x1, x2, x3)
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.set_zlabel("$x_3$")
plt.show()

In [None]:
# Attention map of the model on the synthetic random walk. Even layer indices are
# temporal attention, so sample_idx selects a channel here.
sample_idx = 1
layer_idx = 10
head_idx = 7
start_time = 0
end_time = start_time + 1024

plot_attn_map(
    pft_model.model,
    test_system_noise[:, start_time:end_time],
    16,
    sample_idx=sample_idx,
    layer_idx=layer_idx,
    head_idx=head_idx,
    prefix="Random walk",  # the context is synthetic noise, not the `dyst_name` system
    colormap="Blues",
    show_title=False,
    show_colorbar=False,
    save_path=None,
    # save_path=f"../figures/random_walk_attn_map_layer{layer_idx}_head{head_idx}_sample{sample_idx}_context{start_time}-{end_time}.pdf",
)

In [None]:
# Attention maps of the same head and sample on every even (temporal-attention) layer.
noise_context = test_system_noise[:, start_time:end_time]
layer_indices = list(range(0, 16, 2))  # [0, 2, ..., 14]
attn_maps_by_layer = {}
for layer_idx in layer_indices:
    attn_maps_by_layer[layer_idx] = get_attn_map_from_spec(
        pft_model.model,
        noise_context,
        patch_size=16,
        sample_idx=sample_idx,
        layer_idx=layer_idx,
        head_idx=head_idx,
    )

In [None]:
# Layer indices for which attention maps were collected.
attn_maps_by_layer.keys()

In [None]:
# Grid of the attention maps collected above, one panel per layer.
n_cols = 4
n_rows = max(1, -(-len(attn_maps_by_layer) // n_cols))  # ceil division
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(n_cols * 4, n_rows * 4), squeeze=False
)
axes = axes.flatten()

for ax, (layer_idx, attn_map) in zip(axes, attn_maps_by_layer.items()):
    ax.imshow(attn_map, cmap="Blues")
    ax.set_title(
        f"Head {head_idx} - Layer {layer_idx}", fontsize=12, fontweight="bold"
    )
# Hide panels left over when the layer count is not a multiple of n_cols.
for ax in axes[len(attn_maps_by_layer):]:
    ax.set_visible(False)

plt.tight_layout()
# plt.savefig(
#     f"figs/double_freq_attn_maps_head{head_idx}_all_layers.pdf",
#     bbox_inches="tight",
# )
# plt.suptitle(f"Head {head_idx} - All Layers")
plt.show()

In [None]:
# 2D FFT of every attention map and its power spectrum |FFT|^2.
attn_maps_by_layer_fft = {}
power_spectra_by_layer = {}
for layer_idx, attn_map in attn_maps_by_layer.items():
    attn_map_fft = np.fft.fft2(attn_map)
    attn_maps_by_layer_fft[layer_idx] = attn_map_fft
    power_spectra_by_layer[layer_idx] = np.abs(attn_map_fft) ** 2

In [None]:
def radial_profile(data, center=None):
    """
    Find the azimuthally averaged radial profile of an image.

    Pixels are grouped into integer-radius bins ``floor(r)`` around ``center``
    and averaged within each bin. The outermost bin is dropped.

    Args:
        data (N x M np.ndarray): A two-dimensional array
        center (length-2 iterable, optional): The (row, column) center of the
            radial profile trace. Defaults to ``(N // 2, M // 2)``, the
            zero-frequency pixel of an ``fftshift``-ed spectrum. Tuples, lists
            and numpy arrays are all accepted.

    Returns:
        np.ndarray: the mean of ``data`` in each radial bin, innermost first.
    """
    x, y = np.indices(data.shape)
    # `is None` rather than truthiness: a numpy-array center would otherwise
    # raise "truth value of an array is ambiguous".
    if center is None:
        center = data.shape[0] // 2, data.shape[1] // 2
    r = np.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2).astype(int)
    tbin = np.bincount(r.ravel(), data.ravel())
    nr = np.bincount(r.ravel())
    radialprofile = tbin / nr
    return radialprofile[:-1]


def psd_radial(arr, return_k=False):
    """
    Radially-averaged power spectrum of a 2D array.

    Args:
        arr (ndarray): a two-dimensional array
        return_k (bool): if True, also return the radial spatial frequency
            (cycles per sample) averaged within each radial bin

    Returns:
        psd (ndarray): the radially-averaged power spectrum
        k (ndarray): only when ``return_k`` is True; the spatial frequency of
            each PSD value
    """
    spectrum = np.fft.fft2(arr)
    # |F|^2 as Re(F * conj(F)), shifted so zero frequency sits at shape // 2.
    power = np.fft.fftshift(np.real(spectrum * np.conj(spectrum)))
    if not return_k:
        return radial_profile(power)

    n_rows, n_cols = arr.shape
    freq_rows = np.fft.fftshift(np.fft.fftfreq(n_rows))
    freq_cols = np.fft.fftshift(np.fft.fftfreq(n_cols))
    k = np.sqrt(freq_rows[:, None] ** 2 + freq_cols[None, :] ** 2)
    return radial_profile(power), radial_profile(k)


def circular_mask(r, nx, ny, center=(0.5, 0.5)):
    """
    Binary disk mask sampled on the unit square.

    Args:
        r (float): the radius of the mask, in unit-square coordinates
        nx, ny (int): number of grid points along x (columns) and y (rows);
            the returned array has shape ``(ny, nx)``
        center (length-2 iterable): the (x, y) center of the disk

    Returns:
        ndarray: 1.0 inside or exactly on the circle, 0.0 outside
    """
    cx, cy = center
    grid_x, grid_y = np.meshgrid(np.linspace(0, 1, nx), np.linspace(0, 1, ny))
    squared_distance = (grid_x - cx) ** 2 + (grid_y - cy) ** 2
    # heaviside(0, 0) == 0, so points exactly on the circle count as inside.
    return 1 - np.heaviside(squared_distance - r**2, 0)


def annular_mask(r_min, r_max, *args, center=(0.5, 0.5)):
    """
    Binary annulus mask: the disk of radius ``r_max`` minus the disk of radius ``r_min``.

    Args:
        r_min, r_max (float): the inner and outer radii of the mask
        *args: ``nx, ny`` (int), forwarded to ``circular_mask``; the result
            has shape ``(ny, nx)``
        center (length-2 iterable): The center of the radial mask

    Returns:
        ndarray: 1.0 for ``r_min < dist <= r_max``, 0.0 elsewhere (assuming
        ``r_min <= r_max``)
    """
    outer_disk = circular_mask(r_max, *args, center=center)
    inner_disk = circular_mask(r_min, *args, center=center)
    return outer_disk - inner_disk


from scipy.ndimage import gaussian_filter1d


def gaussian_blur_periodic(arr, sigma, axis=0):
    """
    Apply a 1D Gaussian blur along one axis with periodic boundary conditions.

    Args:
        arr (ndarray): The array to blur
        sigma (float): The standard deviation of the Gaussian kernel, in samples
        axis (int): The axis along which to blur (negative values allowed)

    Returns:
        arr_blurred (ndarray): The blurred array, same shape as ``arr``
    """
    # Forward `axis` explicitly: gaussian_filter1d defaults to the last axis,
    # so moving `axis` to the front and filtering with the default blurred the
    # wrong dimension for any array with ndim > 1.
    return gaussian_filter1d(arr, sigma, axis=axis, mode="wrap")


def diff_periodic(a, axis=0, smooth=None):
    """
    Centered difference of an array along an axis with periodic boundary conditions.

    Computes ``(a[i+1] - a[i-1]) / 2`` with indices taken modulo the axis length
    (unit grid spacing).

    Args:
        a (ndarray): The array to differentiate
        axis (int): The axis along which to differentiate
        smooth (float, optional): if truthy, the difference is then blurred
            along ``axis`` by a periodic Gaussian with this standard deviation
            (in samples). ``None`` (default) or 0 disables smoothing.

    Returns:
        ndarray: the difference, same shape as ``a``
    """
    fwd = np.roll(a, -1, axis=axis) - a
    bwd = a - np.roll(a, 1, axis=axis)
    grad = (fwd + bwd) / 2
    if smooth:
        # Honour the documented `smooth` argument (previously ignored); the
        # axis is passed explicitly so the blur runs along the same axis.
        grad = gaussian_filter1d(grad, smooth, axis=axis, mode="wrap")
    return grad


import scipy.fft


def derivative_spectral(field, coords, axis=0):
    """
    Differentiate a periodic field spectrally.

    Takes the FFT along ``axis``, multiplies by ``i k`` (angular wavenumber)
    and transforms back to real space.

    Args:
        field (ndarray): The field to differentiate, of any dimensionality.
            It is assumed periodic along ``axis`` and sampled on a uniform grid
            that excludes the duplicated endpoint, e.g.
            ``np.linspace(0, L, n, endpoint=False)``.
        coords (ndarray): The 1D array of field coordinates along ``axis``
            (length ``field.shape[axis]``, at least 2 points)
        axis (int): The axis along which to differentiate the field

    Returns:
        ndarray: The (real part of the) derivative, same shape as ``field``

    Raises:
        ValueError: if ``coords`` does not match ``field.shape[axis]``.
    """
    field = np.swapaxes(field, 0, axis)  # move the differentiation axis first
    n = field.shape[0]
    if coords.shape[0] != n:
        raise ValueError(
            f"coords has {coords.shape[0]} points but field has {n} along axis {axis}"
        )

    # Angular wavenumbers 2*pi*m / (n*dx): the grid period is n*dx, not
    # coords[-1] - coords[0] (which is one spacing short).
    dx = coords[1] - coords[0]
    k = 2 * np.pi * scipy.fft.fftfreq(n, d=dx)
    # Broadcast k over every remaining axis, whatever the field's rank.
    k = k.reshape((n,) + (1,) * (field.ndim - 1))

    f_k = scipy.fft.fft(field, axis=0)
    df_dx = scipy.fft.ifft(1j * k * f_k, axis=0).real

    return np.swapaxes(df_dx, 0, axis)  # restore the original axis order


def energy_spectrum(v_squared):
    """
    Compute the isotropic energy spectrum of a 2D field.

    The power spectral density ``|FFT2(v_squared)|^2`` is averaged over each
    integer-wavenumber shell ``k_bin <= |k| < k_bin + 1``. The average is then
    multiplied by the shell circumference ``2*pi*k_bin`` to account for the
    area of each annulus growing with radius.

    Args:
        v_squared (ndarray): a two-dimensional array (need not be square)

    Returns:
        k_bins (ndarray): integer wavenumbers (cycles across the array) at the
            inner edge of each shell
        E_k (ndarray): the energy spectrum; 0 for shells containing no grid point
    """
    # Power spectral density with zero frequency shifted to the center.
    psd = np.abs(np.fft.fftshift(np.fft.fft2(v_squared))) ** 2

    nx, ny = psd.shape
    kx = np.fft.fftshift(np.fft.fftfreq(nx)) * nx
    ky = np.fft.fftshift(np.fft.fftfreq(ny)) * ny
    # 'ij' indexing gives the wavenumber grid the same (nx, ny) shape as psd.
    # The default 'xy' yields (ny, nx) and breaks the shell masks for
    # non-square input; for square input both agree.
    kx, ky = np.meshgrid(kx, ky, indexing="ij")
    k = np.sqrt(kx**2 + ky**2)

    k_bins = np.arange(0, np.max(k), 1)  # adjust the bin size as needed
    E_k = []
    for k_bin in k_bins:
        shell = psd[(k >= k_bin) & (k < k_bin + 1)]
        # Radial *average* times circumference: summing the shell already
        # includes the annulus area, so sum * 2*pi*k would double count it.
        E_k.append(shell.mean() * 2 * np.pi * k_bin if shell.size else 0.0)

    return k_bins, np.array(E_k)

In [None]:
# (wavenumbers, energy) pair for every layer's attention map.
energy_spectrum_by_layer = {
    layer_idx: energy_spectrum(attn_map)
    for layer_idx, attn_map in attn_maps_by_layer.items()
}

In [None]:
# Energy spectrum of each layer's attention map, plotted against 2*pi/k.
n_cols = 4
n_rows = max(1, -(-len(energy_spectrum_by_layer) // n_cols))  # ceil division
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(n_cols * 4, n_rows * 4), squeeze=False
)
axes = axes.flatten()

# NOTE: the loop variables must not be named `energy_spectrum`; that would
# shadow the function and break re-running the cell that computes the spectra.
for ax, (layer_idx, (k_vals, energy_vals)) in zip(
    axes, energy_spectrum_by_layer.items()
):
    # Skip k = 0: 2*pi/k is infinite there (and E(0) is 0 by construction).
    nonzero = k_vals > 0
    time_vals = 2 * np.pi / k_vals[nonzero]
    ax.plot(time_vals, energy_vals[nonzero], marker="o", markersize=2)
    ax.set_title(
        f"Energy Spectrum - Layer {layer_idx}", fontsize=12, fontweight="bold"
    )
    # Default log locator instead of a tick at every bin, which overlapped and
    # included a non-finite value.
    ax.set_xlabel("Time")
    ax.set_ylabel("Energy")
    ax.set_xscale("log")
    ax.set_yscale("log")
for ax in axes[len(energy_spectrum_by_layer):]:
    ax.set_visible(False)

plt.tight_layout()
# plt.savefig(
#     f"figs/double_freq_2d_energy_spectrum_all_layers.pdf",
#     bbox_inches="tight",
# )
plt.show()

In [None]:
# Log power spectrum |FFT2(attention map)|^2 of each layer, with zero frequency
# shifted to the center of each panel.
n_cols = 4
n_rows = max(1, -(-len(power_spectra_by_layer) // n_cols))  # ceil division
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(n_cols * 4, n_rows * 4), squeeze=False
)
axes = axes.flatten()

for ax, (layer_idx, power) in zip(axes, power_spectra_by_layer.items()):
    # The title promises a power spectrum, so show log10|F|^2 rather than Re(F).
    # The smallest positive float keeps exact zeros away from log10(0).
    log_power = np.log10(np.fft.fftshift(power) + np.finfo(power.dtype).tiny)
    im = ax.imshow(log_power, cmap="viridis", origin="lower")
    fig.colorbar(im, ax=ax, label="Log Power")
    ax.set_title(
        f"2D FFT Power Spectrum - Layer {layer_idx}", fontsize=12, fontweight="bold"
    )
for ax in axes[len(power_spectra_by_layer):]:
    ax.set_visible(False)

plt.tight_layout()
# plt.savefig(
#     f"../figures/{dyst_name}_2d_fft_power_spectrum_all_layers.png",
#     bbox_inches="tight",
# )
plt.show()

In [None]:
# Shape of the first collected attention map (all maps use the same head and sample).
next(iter(attn_maps_by_layer.values())).shape

In [None]:
import matplotlib.cm as cm

# Singular value spectrum of each layer's attention map (log scale).
# plt.get_cmap: matplotlib.cm.get_cmap was deprecated (3.7) and later removed.
cmaps = plt.get_cmap("cividis")
fig, ax = plt.subplots(figsize=(3, 3))
n_layers = len(attn_maps_by_layer)
for position, (layer_idx, attn_map) in enumerate(attn_maps_by_layer.items()):
    singular_values = np.linalg.svd(attn_map, compute_uv=False)
    print(f"Layer {layer_idx} has {len(singular_values)} singular values")
    # Color by position in the layer list so colors span [0, 1]. Using
    # layer_idx / n_layers exceeds 1 for later layers, and the colormap clamps
    # all of them to the same color.
    ax.plot(
        singular_values,
        label=f"Layer {layer_idx}",
        color=cmaps(position / max(n_layers - 1, 1)),
    )
ax.legend(loc="lower left", fontsize=6, frameon=True)
ax.set_yscale("log")
ax.set_xlabel("Index")
ax.set_ylabel("Singular value")
fig.tight_layout()
plt.show()

In [None]:
def setup_3d_axes(
    ax_3d, scale: float = 0.8, elevation: float = 30, azimuth: float = 45
):
    """Hide the default 3D frame and draw plain x/y/z axis lines from a common origin.

    Call this after the data has been plotted: the current axis limits decide
    where the axis lines start and how long they are.

    Args:
        ax_3d: a 3D matplotlib Axes (``projection="3d"``).
        scale: length of each axis line as a fraction of the largest data extent.
        elevation: view elevation in degrees, passed to ``view_init``.
        azimuth: view azimuth in degrees, passed to ``view_init``.
    """
    ax_3d.grid(False)
    ax_3d.set_axis_off()

    xmin, xmax = ax_3d.get_xlim()
    ymin, ymax = ax_3d.get_ylim()
    zmin, zmax = ax_3d.get_zlim()

    # Origin at the lower data limit on each axis, but never above 0.
    x0, y0, z0 = min(0, xmin), min(0, ymin), min(0, zmin)
    axis_length = scale * max(xmax - xmin, ymax - ymin, zmax - zmin)

    # Every line needs x, y and z coordinate lists of the same length.
    # Axes3D.plot rejects a 1-element x (or y) list paired with 2-element ones.
    ax_3d.plot([x0, x0 + axis_length], [y0, y0], [z0, z0], "k-", lw=1.5)  # x-axis
    ax_3d.plot([x0, x0], [y0, y0 + axis_length], [z0, z0], "k-", lw=1.5)  # y-axis
    ax_3d.plot([x0, x0], [y0, y0], [z0, z0 + axis_length], "k-", lw=1.5)  # z-axis

    # Labels slightly beyond the end of each axis line.
    label_offset = axis_length * 1.1
    ax_3d.text(x0 + label_offset, y0, z0, "$x_1$", fontsize=12, ha="center")
    ax_3d.text(x0, y0 + label_offset, z0, "$x_2$", fontsize=12, ha="center")
    ax_3d.text(x0, y0, z0 + label_offset, "$x_3$", fontsize=12, ha="center")

    ax_3d.view_init(elev=elevation, azim=azimuth)

    # Fit the limits to the drawn coordinate system plus a 20% margin. NOTE: this
    # window may not cover the full data range (e.g. when scale < 1 / 1.2).
    margin = axis_length * 0.2
    ax_3d.set_xlim(x0, x0 + axis_length + margin)
    ax_3d.set_ylim(y0, y0 + axis_length + margin)
    ax_3d.set_zlim(z0, z0 + axis_length + margin)


def plot_model_prediction(
    model,
    context: np.ndarray,
    groundtruth: np.ndarray,
    prediction_length: int,
    title: str | None = None,
    save_path: str | None = None,
    elevation: float = 30,
    axis_scale: float = 0.6,
    azimuth: float = 45,
    **kwargs,
):
    """Forecast from ``context`` and plot the forecast against the ground truth.

    The top panel shows a 3D phase portrait of the first three channels
    (context, ground truth, prediction). The bottom panels show those three
    channels as timeseries.

    Args:
        model: forecasting pipeline exposing ``.device`` and
            ``.predict(context, prediction_length, **kwargs)``.
        context: ``(num_channels, context_length)`` array with >= 3 channels.
        groundtruth: ``(num_channels, prediction_length)`` array of true future values.
        prediction_length: number of steps to forecast.
        title: optional title; underscores are displayed as spaces.
        save_path: if given, the figure is saved there (parent directories are
            created when the path has one).
        elevation: 3D view elevation in degrees.
        axis_scale: length of the drawn 3D axis lines relative to the data extent.
        azimuth: 3D view azimuth in degrees.
        **kwargs: forwarded to ``model.predict``.
    """
    # Use the passed model's device rather than a notebook-global pipeline.
    context_tensor = torch.from_numpy(context.T).float().to(model.device)[None, ...]
    pred = (
        model.predict(context_tensor, prediction_length, **kwargs)
        .squeeze()
        .cpu()
        .numpy()
    )  # (prediction_length, num_channels) after squeezing the batch axis
    total_length = context.shape[1] + prediction_length
    # Time axis normalised to [0, 1) over context + forecast.
    context_ts = np.arange(context.shape[1]) / total_length
    pred_ts = np.arange(context.shape[1], total_length) / total_length

    fig = plt.figure(figsize=(6, 8))
    # Two rows: 3D phase portrait on top, timeseries below. The negative hspace
    # pulls the timeseries panels up under the 3D plot.
    outer_grid = fig.add_gridspec(2, 1, height_ratios=[0.65, 0.35], hspace=-0.2)
    # Three stacked timeseries panels, one per plotted channel.
    gs = outer_grid[1].subgridspec(3, 1, height_ratios=[0.2] * 3, wspace=0, hspace=0)

    ax_3d = fig.add_subplot(outer_grid[0], projection="3d")
    ax_3d.plot(*context[:3], alpha=0.5, color="black", label="Context")
    ax_3d.plot(*groundtruth[:3], linestyle="--", color="black", label="Groundtruth")
    ax_3d.plot(*pred.T[:3], color="red", label="Prediction")
    ax_3d.set_xlabel("$x_1$")
    ax_3d.set_ylabel("$x_2$")
    ax_3d.set_zlabel("$x_3$")  # type: ignore
    setup_3d_axes(ax_3d, scale=axis_scale, elevation=elevation, azimuth=azimuth)

    if title is not None:
        ax_3d.set_title(title.replace("_", " "), fontweight="bold")

    axes_1d = [fig.add_subplot(gs[i, 0]) for i in range(3)]
    for i, ax in enumerate(axes_1d):
        ax.plot(context_ts, context[i], alpha=0.5, color="black")
        ax.plot(pred_ts, groundtruth[i], linestyle="--", color="black")
        ax.plot(pred_ts, pred[:, i], color="red")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect("auto")

    if save_path is not None:
        save_dir = os.path.dirname(save_path)
        # os.makedirs("") raises, so only create a directory when there is one.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, bbox_inches="tight")
    plt.show()
    plt.close(fig)

In [None]:
# Load one zero-shot test trajectory of a skew-product system.
dyst_name = "KawczynskiStrizhak_HyperXu"
split = "final_skew40"
subsplit = "test_zeroshot"
sample_idx = 1

test_data_dirs = f"/stor/work/AMDG_Gilpin_Summer2024/data/copy/{split}"
syspaths = get_system_filepaths(dyst_name, test_data_dirs, subsplit)
trajectory, _ = load_trajectory_from_arrow(syspaths[sample_idx])

In [None]:
# Forecast a window of the loaded trajectory and plot it against the ground truth.
context_length = 512
pred_length = 128
start_time = 1260
end_time = start_time + context_length

save_path = os.path.join(
    "../figures",
    run_name,
    split,
    subsplit,
    dyst_name,
    f"{dyst_name}_sample{sample_idx}_context{start_time}-{end_time}_pred{pred_length}_.pdf",
)

context_window = trajectory[:, start_time:end_time]
future_window = trajectory[:, end_time : end_time + pred_length]

plot_model_prediction(
    pft_model,
    context_window,
    future_window,
    pred_length,
    limit_prediction_length=False,
    sliding_context=True,
    save_path=save_path,
    azimuth=30,
    elevation=30,
    axis_scale=0.6,
)