In [None]:
# Reload edited project modules (e.g. panda.*) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import torch

from panda.patchtst.pipeline import PatchTSTPipeline
from panda.utils.plot_utils import (
    apply_custom_style,
    plot_trajs_multivariate,
)

# Apply the project's shared plotting style (presumably matplotlib rcParams from the YAML);
# the path is resolved relative to the current working directory.
apply_custom_style("../config/plotting.yaml")

In [None]:
# Load the `panda_mlm` checkpoint in pretrain mode. Use the first GPU when one is
# available and fall back to CPU otherwise, so the notebook also runs without CUDA.
model_pipeline = PatchTSTPipeline.from_pretrained(
    mode="pretrain",
    pretrain_path="GilpinLab/panda_mlm",
    device_map="cuda:0" if torch.cuda.is_available() else "cpu",
)

In [None]:
def get_model_completion(
    pipeline,
    context: np.ndarray,
    return_normalized_completions: bool = False,
    verbose: bool = True,
    **kwargs,
):
    """Run ``pipeline.model.generate_completions`` on one context window.

    Args:
        pipeline: object exposing ``.device`` and ``.model.generate_completions``.
        context: array of shape (num_channels, num_timesteps). It is transposed to
            (num_timesteps, num_channels) and given a leading batch axis of size 1.
        return_normalized_completions: if False, outputs are mapped back with the
            model's ``loc``/``scale`` (``x * scale + loc``). If True, they stay in
            the model's normalized space.
        verbose: print intermediate shapes.
        **kwargs: forwarded unchanged to ``generate_completions``.

    Returns:
        ``(completions, processed_context, timestep_mask)``, each channel-major
        (num_channels, num_timesteps'). ``timestep_mask`` is the per-patch mask
        repeated ``patch_size`` times along time. Assumes the raw completions are
        laid out (batch, channels, num_patches, patch_size); TODO confirm.

    Raises:
        ValueError: if ``mask`` or ``patched_past_values`` is missing. When
            denormalizing, also raised if ``loc`` or ``scale`` is missing.
    """
    # (C, T) -> (T, C), then prepend a batch axis of size 1: (1, T, C).
    context_tensor = torch.from_numpy(context.T).float().to(pipeline.device)[None, ...]
    output = pipeline.model.generate_completions(
        context_tensor,
        past_observed_mask=None,
        **kwargs,
    )

    if verbose:
        print(f"context_tensor shape: {context_tensor.shape}")
        print(f"completions output shape: {output.completions.shape}")

    # The trailing axis of the raw completions is taken to be the patch length.
    patch_size = output.completions.shape[-1]

    required_outputs = (output.mask, output.patched_past_values)
    if any(t is None for t in required_outputs):
        raise ValueError("Required completion outputs are None")

    batch_size, num_channels = context_tensor.shape[0], context_tensor.shape[-1]

    def to_time_major_numpy(tensor):
        # Flatten everything after the channel axis, then swap to (batch, time, channels).
        flattened = tensor.reshape(batch_size, num_channels, -1)
        return flattened.detach().cpu().numpy().transpose(0, 2, 1)

    completions = to_time_major_numpy(output.completions)
    processed_context = to_time_major_numpy(output.patched_past_values)
    # Expand the per-patch mask to one entry per timestep.
    patch_mask = output.mask.detach().cpu().numpy()
    timestep_mask = np.repeat(patch_mask, repeats=patch_size, axis=2)

    if not return_normalized_completions:
        if output.loc is None or output.scale is None:
            raise ValueError("Loc or scale is None")
        loc = output.loc.detach().cpu().numpy()
        scale = output.scale.detach().cpu().numpy()
        completions = completions * scale + loc
        processed_context = processed_context * scale + loc

    # Drop the batch axis and return channel-major 2D arrays for plotting.
    completions = completions.squeeze(0).T
    processed_context = processed_context.squeeze(0).T
    timestep_mask = timestep_mask.squeeze(0)

    if verbose:
        print(f"processed context shape: {processed_context.shape}")
        print(f"completions shape: {completions.shape}")
        print(f"timestep mask shape: {timestep_mask.shape}")

    return completions, processed_context, timestep_mask

In [None]:
def _true_runs(mask_row: np.ndarray) -> list[tuple[int, int]]:
    """Return half-open ``(start, end)`` index pairs, one per contiguous run of True in ``mask_row``."""
    # Pad with False on both sides so every run has both a rising and a falling edge.
    padded = np.concatenate(([False], np.asarray(mask_row, dtype=bool), [False]))
    edges = np.flatnonzero(np.diff(padded))
    return list(zip(edges[0::2].tolist(), edges[1::2].tolist()))


def plot_model_completion(
    completions,
    processed_context,
    timestep_mask,
    figsize: tuple[int, int] = (6, 8),
    save_path: str | None = None,
):
    """Plot a model completion against its context.

    The figure has a 3D phase portrait on top and one 1D panel per channel below.
    Only the first three channels are drawn.

    Args:
        completions: array of shape (num_channels, num_timesteps) with model outputs.
        processed_context: array of shape (num_channels, num_timesteps) with the context.
        timestep_mask: per-timestep mask, (num_channels, num_timesteps); non-zero
            entries count as True.
        figsize: figure size in inches.
        save_path: if given, the figure is also saved there with ``bbox_inches="tight"``.

    Raises:
        ValueError: if the three arrays disagree on the time axis, or any of them
            has fewer than three channels.

    NOTE(review): the 3D panel draws ``completions`` in red where ``timestep_mask``
    is True. The 1D panels draw them in red, and shade the gap to the context, where
    it is False. One of the two is presumably inverted. Hugging Face PatchTST marks
    masked patches with True/1, so confirm the convention of ``generate_completions``
    before trusting either panel.
    """
    n_plot_dims = 3
    n_timesteps = processed_context.shape[1]
    if not (completions.shape[1] == timestep_mask.shape[1] == n_timesteps):
        raise ValueError(
            "completions, processed_context and timestep_mask must share the time axis; "
            f"got {completions.shape}, {processed_context.shape}, {timestep_mask.shape}"
        )
    if min(completions.shape[0], processed_context.shape[0], timestep_mask.shape[0]) < n_plot_dims:
        raise ValueError(
            f"Need at least {n_plot_dims} channels to plot; got "
            f"{completions.shape}, {processed_context.shape}, {timestep_mask.shape}"
        )

    mask_bool = timestep_mask.astype(bool)

    # Layout: tall 3D panel on top, one short 1D panel per plotted channel below.
    fig = plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(n_plot_dims + 1, 1, height_ratios=[3] + [1] * n_plot_dims)
    ax_3d = fig.add_subplot(gs[0], projection="3d")
    axes_2d = [fig.add_subplot(gs[i]) for i in range(1, n_plot_dims + 1)]

    # 3D: draw the whole context faintly first.
    ax_3d.plot(
        processed_context[0, :],
        processed_context[1, :],
        processed_context[2, :],
        alpha=0.5,
        color="black",
        linewidth=2,
    )
    ax_3d.axis("off")
    ax_3d.grid(False)

    # 3D: for each channel's True-mask runs, overlay completions (red, on top) and
    # the matching context (solid black).
    for dim in range(n_plot_dims):
        for start, end in _true_runs(mask_bool[dim]):
            ax_3d.plot(
                completions[0, start:end],
                completions[1, start:end],
                completions[2, start:end],
                alpha=1,
                color="red",
                linewidth=2,
                zorder=10,
            )
            ax_3d.plot(
                processed_context[0, start:end],
                processed_context[1, start:end],
                processed_context[2, start:end],
                alpha=1,
                color="black",
                linewidth=2,
            )

    # 1D panels: one per channel.
    for dim, ax in enumerate(axes_2d):
        mask_bool_dim = mask_bool[dim]

        ax.plot(processed_context[dim, :], alpha=0.5, color="black", linewidth=2)

        # Highlight each False-mask run. Extend it one step to the left (when possible)
        # so the highlighted segment joins the neighbouring one without a gap.
        for start, end in _true_runs(~mask_bool_dim):
            start = max(start - 1, 0)
            xs = range(start, end)
            ax.plot(
                xs,
                completions[dim, start:end],
                alpha=1,
                color="red",
                linewidth=2,
                zorder=10,
            )
            ax.plot(
                xs,
                processed_context[dim, start:end],
                alpha=1,
                color="black",
                linewidth=2,
            )

        # Shade the gap between context and completions over the same False-mask region.
        ax.fill_between(
            range(n_timesteps),
            processed_context[dim, :],
            completions[dim, :],
            where=~mask_bool_dim,
            alpha=0.2,
        )
        ax.set_xticks([])
        ax.set_yticks([])

    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path, bbox_inches="tight")
    plt.show()

## Make Held-Out Skew System from Saved Parameters

In [None]:
from panda.utils import init_skew_system_from_params

In [None]:
# Directory with saved parameter sets for the held-out (zero-shot test) systems.
params_dir = "../data/params_test_zeroshot"

In [None]:
# JSON file mapping system name -> list of saved parameter dicts (each entry carries an "ic").
parameters_json_path_test = os.path.join(params_dir, "filtered_params_dict.json")

In [None]:
# Load the saved parameter sets; a context manager closes the file handle promptly
# (the bare `json.load(open(...))` form leaked it until garbage collection).
with open(parameters_json_path_test) as params_file:
    saved_params_dict_test = json.load(params_file)

In [None]:
# Report how many systems have saved (successfully perturbed) parameter sets.
print(f"Found {len(saved_params_dict_test)} systems with successful param perts")

### Make Skew System Trajectory

In [None]:
# Held-out skew system to evaluate, named "<Driver>_<Response>".
skew_sys_name = "SprottMore_CircadianRhythm"
# Alternative held-out system (uncomment to use instead):
# skew_sys_name = "PehlivanWei_Duffing"

In [None]:
# Use the first saved parameter set for the chosen system.
skew_sys_params = saved_params_dict_test[skew_sys_name][0]

# Skew-system names look like "<Driver>_<Response>"; reject anything else up front.
is_skew = "_" in skew_sys_name
if not is_skew:
    raise ValueError(f"System {skew_sys_name} is not a skew system")
driver_name, response_name = skew_sys_name.split("_")
# NOTE(review): `sys` shadows the stdlib module name; kept because later cells use it.
sys = init_skew_system_from_params(driver_name, response_name, skew_sys_params)

# Integrate from the saved initial condition.
sys.ic = np.array(skew_sys_params["ic"])
print(sys.ic)

if not sys.has_jacobian():
    print(f"Jacobian not implemented for {skew_sys_name}")


# num_timesteps samples, with pts_per_period chosen so they span about num_periods periods.
num_timesteps, num_periods = 4096, 40
ts, traj = sys.make_trajectory(
    num_timesteps,
    pts_per_period=num_timesteps // num_periods,
    return_times=True,
    atol=1e-10,
    rtol=1e-8,
)

In [None]:
# Discard the first 5% of samples as an integration transient before plotting.
transient_frac = 0.05
transient_length = int(transient_frac * num_timesteps)

# traj is sliced time-first here; add a batch axis and move coordinates to axis 1.
trajectory_to_plot = traj[None, transient_length:, :].transpose(0, 2, 1)

# Coordinates [0, driver_dim) belong to the driver; the remainder to the response.
driver_coords = trajectory_to_plot[:, : sys.driver_dim]
response_coords = trajectory_to_plot[:, sys.driver_dim :]

for name, coords in (("driver", driver_coords), ("response", response_coords)):
    plot_trajs_multivariate(
        coords,
        save_dir=None,
        plot_name=f"reconstructed_{skew_sys_name}_{name}",
        standardize=True,
        plot_projections=False,
        show_plot=True,
    )

# Response coordinates over the FULL trajectory (the transient is not removed here).
skew_response_traj = traj[:, sys.driver_dim :]
print(f"Skew response trajectory shape: {skew_response_traj.shape}")

In [None]:
# Samples fed to the model. The window is [start_time, start_time + context_length),
# silently truncated if it runs past the end of the trajectory.
context_length = 1024

show_plot = True
save_plot = False  # when True, the figure is written to `save_path` below

# NOTE(review): sample_idx is not used in this cell.
sample_idx, start_time, subsample_interval = 0, 1024, 1

# Channel-major (num_channels, num_timesteps), optionally subsampled in time.
skew_response_trajectory = skew_response_traj.T[:, ::subsample_interval]

end_time = start_time + context_length

completions, processed_context, timestep_mask = get_model_completion(
    model_pipeline,
    skew_response_trajectory[:, start_time:end_time],  # context
    return_normalized_completions=False,
    verbose=False,
)

# Previously `save_plot` was ignored (save_path was always None); honour it now.
save_path = f"completion_{skew_sys_name}.pdf" if save_plot else None
# plot_model_completion always displays the figure, so saving also shows it.
if show_plot or save_plot:
    plot_model_completion(
        completions,
        processed_context,
        timestep_mask,
        figsize=(6, 8),
        save_path=save_path,
    )

In [None]:
# Sanity check: completions are channel-major, (num_channels, num_timesteps).
completions.shape