In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

from dystformer.patchtst.pipeline import PatchTSTPipeline
from dystformer.utils import plot_trajs_multivariate

In [None]:
# Name of the training run; it selects the checkpoint directory loaded below.
run_name = "pft_chattn_emb_w_poly-0"
# Load the pipeline in "predict" mode from the run's `checkpoint-final` folder,
# requesting placement on "cuda:0".
# NOTE(review): the checkpoint root is an absolute, machine-specific path —
# consider moving it to a configurable constant so the notebook runs elsewhere.
pft_model = PatchTSTPipeline.from_pretrained(
    mode="predict",
    pretrain_path=f"/stor/work/AMDG_Gilpin_Summer2024/checkpoints/{run_name}/checkpoint-final",
    device_map="cuda:0",
)

In [None]:
def get_attn_weights(model, key: str) -> list[dict[str, torch.Tensor]]:
    """Gather the q/k/v projection weights of one attention module per encoder layer.

    Args:
        model: Object whose encoder layers are reachable as
            ``model.model.model.encoder.layers``.
        key: Attribute name of the attention sub-module to read on every layer.

    Returns:
        One dict per layer, in layer order, with keys ``"Wq"``, ``"Wk"`` and
        ``"Wv"`` mapped to the ``weight`` of ``q_proj``, ``k_proj`` and
        ``v_proj`` respectively.
    """
    collected = []
    for layer in model.model.model.encoder.layers:
        attn_module = getattr(layer, key)
        collected.append(
            {
                "Wq": attn_module.q_proj.weight,
                "Wk": attn_module.k_proj.weight,
                "Wv": attn_module.v_proj.weight,
            }
        )
    return collected


def get_attn_map(
    weights: list[dict[str, torch.Tensor]], index: int, shift: bool = False
) -> np.ndarray:
    """Return ``Wq @ Wk.T`` for one layer as a NumPy array.

    Args:
        weights: Per-layer dicts holding at least the ``"Wq"`` and ``"Wk"`` tensors.
        index: Which entry of ``weights`` to use.
        shift: When true, min-max rescale the product so its entries span [0, 1].

    Returns:
        The (optionally rescaled) matrix product, detached and moved to the CPU.

    NOTE(review): if the projections store weights as (out_features, in_features),
    the form acting on layer inputs would be ``Wq.T @ Wk`` — confirm that the
    output-space product computed here is the intended quantity.
    """
    layer = weights[index]
    product = layer["Wq"] @ layer["Wk"].T
    attn_map = product.detach().cpu().numpy()
    if not shift:
        return attn_map
    lowest = np.min(attn_map)
    highest = np.max(attn_map)
    return (attn_map - lowest) / (highest - lowest)


def symmetric_distance(attn_map: np.ndarray) -> float:
    """Measure how far a square matrix is from being symmetric.

    Computes ``0.5 * ||A - A.T||_F / ||A||_F``. The value is 0 for a symmetric
    matrix and 1 for an antisymmetric one.

    Args:
        attn_map: A square 2-D array.

    Returns:
        The relative Frobenius norm of the antisymmetric part, as a Python
        float. An all-zero matrix is symmetric, so it yields 0.0 instead of NaN.

    Raises:
        ValueError: If ``attn_map`` is not a square 2-D matrix. Without this
            check, shapes such as (n, 1) would silently broadcast in
            ``A - A.T`` and produce a meaningless number.
    """
    attn_map = np.asarray(attn_map)
    if attn_map.ndim != 2 or attn_map.shape[0] != attn_map.shape[1]:
        raise ValueError(
            f"symmetric_distance expects a square 2-D matrix, got shape {attn_map.shape}"
        )
    total_norm = np.linalg.norm(attn_map, "fro")
    if total_norm == 0:
        return 0.0
    asymmetric_norm = np.linalg.norm(attn_map - attn_map.T, "fro")
    return float(0.5 * asymmetric_norm / total_norm)

In [None]:
# Per-layer q/k/v weights for the two attention modules read from each encoder layer.
temporal_weights, channel_weights = (
    get_attn_weights(pft_model, attn_name)
    for attn_name in ("temporal_self_attn", "channel_self_attn")
)

In [None]:
# Layer-0 temporal weights: print the asymmetry score, then show the log of
# the squared entries of Wq @ Wk.T as a heat map.
attn_map = get_attn_map(temporal_weights, 0)
print(symmetric_distance(attn_map))

fig, ax = plt.subplots()
image = ax.imshow(np.log(np.square(attn_map)), cmap="RdBu")
fig.colorbar(image, ax=ax)
plt.show()

In [None]:
# Layer-0 channel weights: print the asymmetry score, then show the log of
# the squared entries of Wq @ Wk.T as a heat map.
attn_map = get_attn_map(channel_weights, 0)
print(symmetric_distance(attn_map))

fig, ax = plt.subplots()
image = ax.imshow(np.log(np.square(attn_map)), cmap="RdBu")
fig.colorbar(image, ax=ax)
plt.show()

In [None]:
# Feed-forward block of encoder layer 0 and the weight matrix of its first entry.
llayer = pft_model.model.model.encoder.layers[0].ff
print(llayer)
ffw = llayer[0].weight.detach().cpu().numpy()
print(symmetric_distance(ffw))

# Singular value decomposition. `V` is the third array returned by
# np.linalg.svd, i.e. its rows (not columns) pair with the singular values.
U, S, V = np.linalg.svd(ffw)
threshold = 1e-3
rank = (S > threshold).sum()  # count of singular values above the cutoff

# Scree plot of the spectrum with the cutoff drawn as a dashed line.
fig, ax = plt.subplots()
ax.plot(np.arange(1, S.size + 1), S, "o-", linewidth=2)
ax.set_title("Scree Plot of Singular Values")
ax.set_xlabel("Singular Value Index")
ax.set_ylabel("Singular Value Magnitude")
ax.grid(True)
ax.set_yscale("log")  # log axis makes the decay of the spectrum visible
ax.axhline(
    y=threshold, color="r", linestyle="--", label=f"Threshold ({threshold:.1e})"
)
ax.legend()
plt.show()

# Rebuild the matrix from the leading `rank` singular triplets and show the
# log of its squared entries.
reconstructed = U[:, :rank] @ np.diag(S[:rank]) @ V[:rank, :]
fig, ax = plt.subplots()
image = ax.imshow(np.log(np.square(reconstructed)), cmap="RdBu")
fig.colorbar(image, ax=ax)
plt.show()

In [None]:
# Show the temporal Wq @ Wk.T map of every encoder layer, four panels per row.
# The grid is sized from the number of collected layers instead of assuming
# exactly eight, so smaller models do not raise IndexError and larger ones are
# not silently truncated. With eight layers this is the same 2x4, 20x10 figure.
n_layers = len(temporal_weights)
n_cols = 4
n_rows = max(1, -(-n_layers // n_cols))  # ceiling division, at least one row
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(20, 5 * n_rows), squeeze=False
)
for i, ax in enumerate(axes.flatten()):
    if i >= n_layers:
        ax.axis("off")  # blank out panels left over in the last row
        continue
    attn_map = get_attn_map(temporal_weights, i)
    ax.imshow(attn_map, cmap="RdBu")
    ax.set_title(f"Layer {i}")
plt.tight_layout()
plt.show()

In [None]:
def plot_attn_map(
    model,
    context: np.ndarray,
    patch_size: int,
    sample_idx: int,
    layer_idx: int,
    head_idx: int,
    prefix: str = "",
) -> None:
    """Plot one attention matrix with the matching time-series patches along its edges.

    The model is called once on ``context`` with ``output_attentions=True`` and
    the matrix ``attentions[layer_idx][sample_idx, head_idx]`` is drawn. The
    patches of row ``sample_idx`` are drawn above and to the left of the matrix.

    Args:
        model: Callable returning an object with an ``attentions`` sequence.
            The input tensor is placed on ``model.device`` when that attribute
            exists, otherwise on the device of the model's first parameter.
        context: 2-D array; it is reshaped to
            ``(context.shape[0], -1, patch_size)``, so its second dimension
            must be divisible by ``patch_size``.
        patch_size: Number of time steps per patch.
        sample_idx: Row of the attention batch (and of the patch array) to show.
        layer_idx: Index into ``attentions``. Even indices are labelled
            "temporal" and odd indices "channel"; for "channel" the first two
            patch axes are swapped before plotting.
        head_idx: Attention head to show.
        prefix: Text prepended to the figure title.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    attention_type = "temporal" if layer_idx % 2 == 0 else "channel"
    patches = context.reshape(context.shape[0], -1, patch_size)
    if attention_type == "channel":
        patches = patches.transpose(1, 0, 2)

    # Use the device of the model that was passed in rather than a notebook
    # global, so the function works for any model instance.
    device = getattr(model, "device", None)
    if device is None:
        device = next(model.parameters()).device
    context_tensor = torch.from_numpy(context.T).float().to(device)[None, ...]

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        pred = model(context_tensor, output_attentions=True)
    attn_weights = pred.attentions

    # Extract attention weights for the specified sample, layer and head
    num_samples = attn_weights[layer_idx].shape[0]
    attn = attn_weights[layer_idx][sample_idx, head_idx].detach().cpu().numpy()
    n_patches = attn.shape[0]

    fig = plt.figure(figsize=(10, 10))

    # Outer grid: plots on the left, a narrow column for the colorbar on the right
    outer_grid = fig.add_gridspec(1, 2, width_ratios=[1, 0.05], wspace=0.05)

    # Inner 2x2 grid: thin strips for the patches, large cell for the matrix.
    # The top-left cell is intentionally left empty.
    gs = outer_grid[0].subgridspec(
        2, 2, width_ratios=[0.15, 0.85], height_ratios=[0.15, 0.85], wspace=0, hspace=0
    )

    # Attention matrix, with extents in patch units so the edge plots line up
    ax_main = fig.add_subplot(gs[1, 1])
    im = ax_main.imshow(attn, extent=(0, n_patches, n_patches, 0))
    ax_main.set_xticks([])
    ax_main.set_yticks([])

    # Patches along the top: patch i occupies the x-interval [i, i + 1]
    ax_top = fig.add_subplot(gs[0, 1])
    for i in range(n_patches):
        x = np.linspace(i, i + 1, patch_size)
        ax_top.plot(x, patches[sample_idx, i], linewidth=1)
    ax_top.set_xlim(0, n_patches)
    ax_top.set_xticks([])
    ax_top.set_yticks([])
    ax_top.grid(True)

    # Patches along the left: same curves rotated, with the y-axis inverted to
    # match the row order of the matrix
    ax_left = fig.add_subplot(gs[1, 0])
    for i in range(n_patches):
        y = np.linspace(i, i + 1, patch_size)
        ax_left.plot(-patches[sample_idx, i], y, linewidth=1)
    ax_left.set_ylim(n_patches, 0)
    ax_left.set_xticks([])
    ax_left.set_yticks([])
    ax_left.grid(True)

    # Colorbar in its own axes
    ax_cbar = fig.add_subplot(outer_grid[1])
    plt.colorbar(im, cax=ax_cbar)

    # Align the edge plots and the colorbar with the matrix axes
    main_pos = ax_main.get_position()
    ax_top.set_position(
        [main_pos.x0, main_pos.y1, main_pos.width, ax_top.get_position().height]  # type: ignore
    )
    ax_left.set_position(
        [
            ax_left.get_position().x0,
            main_pos.y0,
            ax_left.get_position().width,
            main_pos.height,
        ]  # type: ignore
    )
    ax_cbar.set_position(
        [
            ax_cbar.get_position().x0,
            main_pos.y0,
            ax_cbar.get_position().width,
            main_pos.height,
        ]  # type: ignore
    )
    sample_type = "channel" if attention_type == "temporal" else "patch"
    ax_top.set_title(
        f"{prefix}{attention_type} attention @ layer {layer_idx}, head {head_idx}, ({sample_type} {sample_idx + 1}/{num_samples})"
    )
    # NOTE(review): tight_layout() runs after the manual set_position calls
    # above and may re-flow the axes — confirm the alignment survives.
    plt.tight_layout()
    plt.show()

In [None]:
from dystformer.utils import get_system_filepaths, load_trajectory_from_arrow

# System to inspect and which of its files to load.
dyst_name, sample_idx = "Lorenz", 0
test_data_dirs = "/stor/work/AMDG_Gilpin_Summer2024/data/final_base40"

# Look up the files for this system (the "train" argument presumably selects
# the split — confirm against get_system_filepaths), then read the chosen one.
# Only the first of the two returned values is kept.
syspaths = get_system_filepaths(dyst_name, test_data_dirs, "train")
trajectory, _ = load_trajectory_from_arrow(syspaths[sample_idx])

In [None]:
# Attention of layer 0, head 1 on the first 1024 time steps, cut into
# 16-step patches; the title is prefixed with the data folder's name.
plot_attn_map(
    model=pft_model.model,
    context=trajectory[:, :1024],
    patch_size=16,
    sample_idx=1,
    layer_idx=0,
    head_idx=1,
    prefix=f"{syspaths[0].parent.stem} ",
)

In [None]:
def plot_model_prediction(
    model,
    context: np.ndarray,
    groundtruth: np.ndarray,
    prediction_length: int,
    title: str | None = None,
    save_path: str | None = None,
    **kwargs,
):
    """Forecast from ``context`` and plot the result against the ground truth.

    The left half of the figure is a 3-D phase-space plot, and the right half
    holds three stacked per-coordinate time series. Context is drawn in
    translucent black, ground truth in dashed black and the prediction in red.

    Args:
        model: Object exposing ``device`` and
            ``predict(tensor, prediction_length, **kwargs)``.
        context: Array indexed as ``(coordinate, time)``. Its rows are unpacked
            as the x/y/z arguments of the 3-D plot and rows 0..2 are drawn in
            the 1-D panels, so three coordinates are expected.
        groundtruth: Array with the same layout as ``context``, covering the
            forecast horizon (plotted against ``prediction_length`` time points).
        prediction_length: Number of steps to forecast.
        title: Optional title for the 3-D panel; underscores are replaced by spaces.
        save_path: Optional file path for the figure. Its parent directory is
            created if the path has one.
        **kwargs: Forwarded unchanged to ``model.predict``.

    Returns:
        None. The figure is shown and then closed.
    """
    # Place the input on the device of the model that was passed in, not on a
    # notebook global.
    context_tensor = torch.from_numpy(context.T).float().to(model.device)[None, ...]
    # After squeeze() the prediction is indexed as (time, coordinate) below.
    pred = (
        model.predict(context_tensor, prediction_length, **kwargs)
        .squeeze()
        .cpu()
        .numpy()
    )
    # Time axes normalised so context + forecast together span [0, 1).
    total_length = context.shape[1] + prediction_length
    context_ts = np.arange(context.shape[1]) / total_length
    pred_ts = np.arange(context.shape[1], total_length) / total_length

    fig = plt.figure(figsize=(15, 4))

    # Two equal halves: 3-D plot on the left, stacked 1-D panels on the right
    outer_grid = fig.add_gridspec(1, 2, width_ratios=[0.5, 0.5], wspace=0.05)
    gs = outer_grid[1].subgridspec(3, 1, height_ratios=[1 / 3] * 3, wspace=0, hspace=0)

    ax_3d = fig.add_subplot(outer_grid[0], projection="3d")
    ax_3d.plot(*context, alpha=0.5, color="black", label="Context")
    ax_3d.plot(*groundtruth, linestyle="--", color="black", label="Groundtruth")
    ax_3d.plot(*pred.T, color="red", label="Prediction")
    ax_3d.set_xlabel("$x_1$")
    ax_3d.set_ylabel("$x_2$")
    ax_3d.set_zlabel("$x_3$")  # type: ignore

    # Build the legend from unique labels only.
    handles, labels = ax_3d.get_legend_handles_labels()
    by_label = dict(zip(labels, handles))
    ax_3d.legend(by_label.values(), by_label.keys(), loc="upper right", fontsize=8)
    if title is not None:
        title_name = title.replace("_", " ")
        ax_3d.set_title(title_name, fontweight="bold")

    axes_1d = [fig.add_subplot(gs[i, 0]) for i in range(3)]
    for i, ax in enumerate(axes_1d):
        ax.plot(context_ts, context[i], alpha=0.5, color="black")
        ax.plot(pred_ts, groundtruth[i], linestyle="--", color="black")
        ax.plot(pred_ts, pred[:, i], color="red")
        ax.set_ylabel(f"$x_{i + 1}$")
        ax.set_aspect("auto")
    axes_1d[-1].set_xlabel("Time")

    if save_path is not None:
        # A bare filename has an empty dirname, and os.makedirs("") raises
        # FileNotFoundError, so only create the directory when there is one.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, bbox_inches="tight")
    plt.show()
    plt.close(fig)

In [None]:
# Display the dimensions of the loaded trajectory array.
trajectory.shape

In [None]:
# Forecast 512 steps from a 512-step context window that starts at t = 2048,
# and save the figure under figs/<run_name>/<dyst_name>/.
context_length, pred_length = 512, 512
start_time = 2048
end_time = start_time + context_length
plot_model_prediction(
    model=pft_model,
    context=trajectory[:, start_time:end_time],
    groundtruth=trajectory[:, end_time : end_time + pred_length],
    prediction_length=pred_length,
    title=dyst_name,
    save_path=os.path.join(
        "figs",
        run_name,
        dyst_name,
        f"{dyst_name}_sample{sample_idx}_context{start_time}-{end_time}_pred{pred_length}_.pdf",
    ),
    limit_prediction_length=False,  # forwarded to model.predict
)

In [None]:
# Forecast 256 steps from a 512-step context window that starts at t = 128,
# and save the figure under figs/<run_name>/<dyst_name>/.
context_length, pred_length = 512, 256
start_time = 128
end_time = start_time + context_length
plot_model_prediction(
    model=pft_model,
    context=trajectory[:, start_time:end_time],
    groundtruth=trajectory[:, end_time : end_time + pred_length],
    prediction_length=pred_length,
    title=dyst_name,
    save_path=os.path.join(
        "figs",
        run_name,
        dyst_name,
        f"{dyst_name}_sample{sample_idx}_context{start_time}-{end_time}_pred{pred_length}_.pdf",
    ),
    limit_prediction_length=False,  # forwarded to model.predict
)

In [None]:
# Switch to a second system; `split` names the data folder and `subsplit` is
# passed to get_system_filepaths as its third argument.
dyst_name, sample_idx = "HyperXu_SprottF", 0
split, subsplit = "final_skew40", "test_zeroshot"
test_data_dirs = f"/stor/work/AMDG_Gilpin_Summer2024/data/copy/{split}"

# Replace the previously loaded file list and trajectory with this system's.
# Only the first of the two returned values is kept.
syspaths = get_system_filepaths(dyst_name, test_data_dirs, subsplit)
trajectory, _ = load_trajectory_from_arrow(syspaths[sample_idx])

In [None]:
# Slice the second axis with a step of 1, which keeps every sample; raise the
# step to subsample along that axis.
traj_subsampled = trajectory[:, ::1]

In [None]:
# np.expand_dims adds a leading axis of length 1 before plotting —
# presumably the batch axis plot_trajs_multivariate expects; confirm.
plot_trajs_multivariate(np.expand_dims(traj_subsampled, axis=0), show_plot=True)

In [None]:
# Forecast 128 steps from a 512-step context window that starts at t = 2048.
context_length, pred_length = 512, 128
start_time = 2048
end_time = start_time + context_length

# Output location for this figure. It is built but not used, because the call
# below passes save_path=None.
save_path = os.path.join(
    "figs",
    run_name,
    split,
    subsplit,
    dyst_name,
    f"{dyst_name}_sample{sample_idx}_context{start_time}-{end_time}_pred{pred_length}_.pdf",
)

plot_model_prediction(
    model=pft_model,
    context=traj_subsampled[:, start_time:end_time],
    groundtruth=traj_subsampled[:, end_time : end_time + pred_length],
    prediction_length=pred_length,
    title=dyst_name,
    save_path=None,  # pass `save_path` here to write the PDF
    limit_prediction_length=False,  # forwarded to model.predict
    # sliding_context=True,
)