In [None]:
import os
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm, trange

from dystformer.patchtst.pipeline import PatchTSTPipeline

In [None]:
# Run whose final checkpoint is analysed below. The alternatives are kept as
# comments so that switching runs is a one-line change.
# run_name = "pft_stand_rff_univariate-0"
# run_name = "pft_chattn_emb_w_poly-0"
# run_name = "pft_linattn_noemb_from_scratch-0"
run_name = "pft_chattn_fullemb_pretrained-0"

# Load the pipeline in prediction mode and place it on GPU 0.
checkpoint_path = (
    f"/stor/work/AMDG_Gilpin_Summer2024/checkpoints/{run_name}/checkpoint-final"
)
pft_model = PatchTSTPipeline.from_pretrained(
    mode="predict", pretrain_path=checkpoint_path, device_map="cuda:0"
)

In [None]:
@dataclass
class SineND:
    """Multi-channel sine wave generator evaluated on a fixed time grid.

    Attributes:
        freqs: Base frequency of each output channel.
        ts: Time stamps at which every channel is sampled.
    """

    freqs: list[float]
    ts: np.ndarray

    def __call__(self, s: float):
        """Return the stacked channels ``sin(2 * pi * s * f * ts)``, one per ``f``.

        For a 1-D ``ts`` the result has shape ``(len(freqs), len(ts))``.
        """
        angular_scale = 2 * np.pi * s
        channels = []
        for base_freq in self.freqs:
            channels.append(np.sin(angular_scale * base_freq * self.ts))
        return np.array(channels)


# Shared time grid and the two reference signals used throughout the notebook.
ts = np.linspace(0, 1, 4096)
incomm_fn = SineND([1.0, np.pi, np.pi**2], ts)  # mutually irrational frequency ratios
comm_fn = SineND([1, 3, 9], ts)  # integer frequency ratios

# Evaluate each curve once and reuse it for both the 3-D and the per-channel views.
curves = [
    ("incomm", "blue", incomm_fn(1)),
    ("comm", "orange", comm_fn(1)),
]

fig = plt.figure(figsize=(10, 5))
gs = fig.add_gridspec(
    2, 2, height_ratios=[0.5, 0.5], width_ratios=[0.4, 0.6], wspace=0.2, hspace=0
)

# Left column: both trajectories in 3-D (one axis per channel).
ax3d = fig.add_subplot(gs[:, 0], projection="3d")
for label, color, curve in curves:
    ax3d.plot3D(*curve, color=color, label=label)
ax3d.legend()

# Right column: the three channels of each signal against time.
for row, (label, color, curve) in enumerate(curves):
    ax = fig.add_subplot(gs[row, 1])
    for channel in curve:
        ax.plot(ts, channel, color=color)
    ax.set_ylabel(label)
ax.set_xlabel("t")

# plt.show() rather than fig.show(): the latter can emit a "non-GUI backend"
# warning on non-interactive backends such as the notebook inline backend.
plt.show()

In [None]:
def temporal_attn_map_sequence(
    model,
    series_generator,
    series_params: np.ndarray,
    context_length: int,
    channel_idx: int,
    head_idx: int,
    patch_size: int,
    colormap="magma",
    show_title=True,
    save_path: str | None = None,
    fps: int = 30,
    linear_attn: bool = False,
    prediction_length: int = 128,
):
    """Animate attention maps and forecasts while sweeping a series parameter.

    For every entry of ``series_params`` a series is generated, its first
    ``context_length`` steps are fed to ``model`` and eight attention maps (for
    ``channel_idx`` / ``head_idx``) are drawn beneath the context, the ground-truth
    continuation and the model forecast.

    Args:
        model: Callable exposing ``.device``; called as
            ``model(x, output_attentions=True, linear_attn=...)`` and expected to
            return an object with ``.attentions`` and ``.prediction_outputs``.
        series_generator: Callable mapping one parameter value to an array indexed
            as ``[channel, time]`` with at least
            ``context_length + prediction_length`` time steps.
        series_params: Parameter values, one animation frame each.
        context_length: Number of time steps given to the model.
        channel_idx: Channel whose attention maps and forecast are shown.
        head_idx: Attention head to show.
        patch_size: Patch length; sets the attention-map side length
            (``context_length // patch_size``) and the patch boundary markers.
        colormap: Matplotlib colormap for the attention maps.
        show_title: Whether to draw the title above the time-series panel.
        save_path: Output file. ``.gif`` uses the pillow writer, anything else
            ffmpeg. When ``None`` the animation is displayed inline instead.
        fps: Frames per second of the animation.
        linear_attn: Forwarded to the model call.
        prediction_length: Number of forecast steps plotted after the context.
            Must equal the length of the model's ``prediction_outputs``.

    Returns:
        Array of shape ``(len(series_params), 8, num_patches, num_patches)`` holding
        the attention maps drawn in each frame.
    """
    import matplotlib.pyplot as plt
    from IPython.display import HTML, display
    from matplotlib.animation import FuncAnimation
    from matplotlib.gridspec import GridSpec

    n_panels = 8  # the figure layout is a fixed 2 x 4 grid of attention maps
    num_patches = context_length // patch_size

    fig = plt.figure(figsize=(12, 7))
    gs = GridSpec(3, 4, figure=fig, height_ratios=[1, 2, 2])

    ax_ts = fig.add_subplot(gs[0, :])
    ax_ts.set_xlim(0, 1)
    ax_ts.set_xticks([])
    ax_ts.set_yticks([])

    # Attention-map panels occupy grid rows 1 and 2 (row 0 is the time series).
    im_list = []
    for i in range(n_panels):
        ax = fig.add_subplot(gs[i // 4 + 1, i % 4])
        im = ax.imshow(np.zeros((num_patches, num_patches)), cmap=colormap)
        ax.set_axis_off()
        ax.set_title(f"Layer {2 * i}")
        im_list.append(im)

    attnmaps = np.zeros((len(series_params), n_panels, num_patches, num_patches))
    ts = np.linspace(0, 1, context_length + prediction_length)

    def update(frame):
        param = series_params[frame]
        series = series_generator(param)
        context = series[:, :context_length]

        # The model is fed (batch=1, time, channels).
        context_tensor = torch.from_numpy(context.T).float().to(model.device)[None, ...]
        # Inference only: do not build an autograd graph for every frame.
        with torch.no_grad():
            pred = model(
                context_tensor, output_attentions=True, linear_attn=linear_attn
            )
        attn_weights = pred.attentions

        # With exactly one attention tensor per panel use them all; otherwise take
        # every second one.
        # NOTE(review): presumably the skipped entries are a second attention type
        # interleaved with the temporal maps — confirm against the model code.
        stride = 1 if len(attn_weights) == n_panels else 2
        for i, im in enumerate(im_list):
            # assumes each attention tensor is indexed [channel, head, token, token]
            # for the single-sample batch — TODO confirm
            attnmap = (
                attn_weights[stride * i][channel_idx, head_idx].detach().cpu().numpy()
            )
            attnmaps[frame, i, :, :] = attnmap
            im.set_array(attnmap)
            # Rescale colours per frame so every map uses the full colormap.
            im.set_clim(attnmap.min(), attnmap.max())

        # Redraw context, ground-truth continuation and forecast.
        pred_channel = pred.prediction_outputs[0, :, channel_idx].detach().cpu().numpy()
        ax_ts.clear()

        # Dotted vertical lines mark the patch boundaries inside the context.
        for boundary in range(patch_size, context_length + 1, patch_size):
            if boundary < len(ts):
                ax_ts.axvline(
                    x=float(ts[boundary]), color="gray", linestyle=":", alpha=0.5
                )

        ax_ts.plot(ts[:context_length], context[channel_idx], color="black")
        ax_ts.plot(
            ts[context_length:],
            series[channel_idx, context_length : context_length + prediction_length],
            color="black",
            linestyle="--",
        )
        ax_ts.plot(ts[context_length:], pred_channel, color="red")
        ax_ts.set_xlim(0, 1)
        ax_ts.set_xticks([])
        ax_ts.set_yticks([])
        if show_title:
            ax_ts.set_title(
                f"Input Time Series head {head_idx} channel {channel_idx} param={param:.3f}"
            )

        return im_list

    anim = FuncAnimation(
        fig,
        update,
        frames=len(series_params),
        interval=1000 / fps,
        blit=True,
    )

    if save_path is not None:
        # Create the target directory so a long render does not fail at write time.
        out_dir = os.path.dirname(save_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        anim.save(
            save_path,
            writer="pillow" if save_path.endswith(".gif") else "ffmpeg",
            fps=fps,
        )
        plt.close(fig)
    else:
        # Close first so the static figure is not shown alongside the animation.
        plt.close(fig)
        display(HTML(anim.to_jshtml()))

    return attnmaps

In [None]:
# Attention head and input channel to visualise.
head_idx = 2
channel_idx = 2

# Sweep the scale parameter passed to `incomm_fn` over 50 values in [20, 30] and
# render one animation frame per value.
params = np.linspace(20, 30, 50)
animation_path = f"../figures/interp_{run_name}_{channel_idx}_{head_idx}.mp4"
attnmaps = temporal_attn_map_sequence(
    model=pft_model.model,
    series_generator=incomm_fn,
    series_params=params,
    context_length=512,
    channel_idx=channel_idx,
    head_idx=head_idx,
    patch_size=16,
    save_path=animation_path,
    fps=5,
    linear_attn=False,
)

In [None]:
# Frobenius norm of every recorded attention map, reduced over the two map axes.
# Result is indexed [frame, panel].
norms = np.linalg.norm(attnmaps, axis=(2, 3), ord="fro")
print(norms.shape)

fig, ax = plt.subplots(figsize=(10, 5))
# One curve per recorded panel; the count is read from the data, not hard-coded.
for i in range(norms.shape[1]):
    ax.plot(params, norms[:, i], color=plt.cm.tab10(i), label=f"Layer {2 * i}")
ax.legend()
ax.grid(True, which="both", linestyle="--", alpha=0.7)
ax.set_title("Attention Map Norm")
ax.set_xlabel("Frequency")
ax.set_ylabel("Frobenius norm")
plt.show()


In [None]:
@torch.no_grad()
def extract_attn_maps(
    model,
    series: np.ndarray,
    context_length: int,
    linear_attn: bool = False,
):
    """Run ``model`` on the first ``context_length`` steps and return its attentions.

    Args:
        model: Callable exposing ``.device``; called as
            ``model(x, output_attentions=True, linear_attn=...)`` and expected to
            return an object with an ``.attentions`` attribute.
        series: Input indexed ``[batch, time, channel]``. A 2-D array is treated as
            one ``[time, channel]`` sample and a 1-D array as one single-channel
            sample; both gain the missing axes before slicing. Previously both
            raised ``IndexError``, because the rank check ran after the array had
            already been indexed on two axes.
        context_length: Number of leading time steps passed to the model.
        linear_attn: Forwarded to the model call.

    Returns:
        ``pred.attentions`` exactly as produced by the model.
    """
    series = np.asarray(series)
    # Normalise the rank *before* slicing along the time axis.
    if series.ndim == 1:
        series = series[None, :, None]
    elif series.ndim == 2:
        series = series[None, ...]

    context = series[:, :context_length]
    context_tensor = torch.from_numpy(context).float().to(model.device)
    pred = model(
        context_tensor,
        output_attentions=True,
        linear_attn=linear_attn,
    )
    return pred.attentions


def interaction_index(
    matrix: np.ndarray, axes: tuple[int, int] = (-2, -1)
) -> float | np.ndarray:
    fronorm = np.linalg.norm(matrix, axis=axes, ord="fro")
    twonorm = np.linalg.norm(matrix, axis=axes, ord=2)
    return (fronorm - twonorm) / fronorm


def mean_row_entropy(
    matrix: np.ndarray, axis: int = -1, eps: float = 1e-10
) -> float | np.ndarray:
    assert np.allclose(matrix.sum(axis=axis), 1), (
        "All rows must be a probability distribution"
    )
    return -np.sum(matrix * np.log(matrix + eps), axis=axis).mean(axis=axis)


def fronorm(matrix: np.ndarray, axes: tuple[int, int] = (-2, -1)) -> float | np.ndarray:
    return np.linalg.norm(matrix, axis=axes, ord="fro")


In [None]:
# DO NOT RUN UNLESS YOU DON'T ALREADY HAVE bispectra.npy
@dataclass
class TwoToneSine:
    freqs1: np.ndarray
    freqs2: np.ndarray
    base_freq: float
    ts: np.ndarray

    def __post_init__(self):
        self.indices = np.arange(self.freqs1.shape[0] * self.freqs2.shape[0])
        self.basewave = np.sin(2 * np.pi * self.base_freq * self.ts)

    @property
    def dims(self) -> int:
        return 3

    def product_array_indices(self, s: slice) -> tuple[np.ndarray, ...]:
        return np.unravel_index(
            self.indices[s], (self.freqs1.shape[0], self.freqs2.shape[0])
        )

    def __len__(self) -> int:
        return self.freqs1.shape[0] * self.freqs2.shape[0]

    def __getitem__(self, idx: int | slice) -> np.ndarray:
        if isinstance(idx, int):
            i = idx % self.freqs1.shape[0]
            j = idx // self.freqs1.shape[0]
            return np.array(
                [
                    np.sin(2 * np.pi * self.freqs1[i] * self.ts),
                    np.sin(2 * np.pi * self.freqs2[j] * self.ts),
                    np.sin(2 * np.pi * self.base_freq * self.ts),
                ]
            ).T
        elif isinstance(idx, slice):
            idxi, idxj = self.product_array_indices(idx)
            return np.stack(
                [
                    np.sin(2 * np.pi * self.freqs1[idxi, None] * self.ts),
                    np.sin(2 * np.pi * self.freqs2[idxj, None] * self.ts),
                    self.basewave[None, :].repeat(len(idxi), axis=0),
                ],
                axis=-1,
            )


# Frequency grid: `resolution` values per axis, giving resolution**2 two-tone series,
# each sampled on 1024 time steps. These names are defined unconditionally
# because later cells use them.
resolution = 1024
freqs = np.linspace(1, 50, resolution)
series_fn = TwoToneSine(freqs, freqs, 10, np.linspace(0, 1, 1024))

bispectra_path = "bispectra.npy"
context_length = 512
batch_size = 2048
# Sizes used to unpack the attention tensors: 8 stored layer slots, 3 channels,
# 8 heads and 32 x 32 token maps.
n_layers, n_channels, n_heads, n_tokens = 8, 3, 8, 32
n_series = len(series_fn)

if os.path.exists(bispectra_path):
    # The sweep is only needed when no cached result exists; delete the file to
    # force a recomputation.
    print(f"{bispectra_path} already exists; skipping recomputation.")
else:
    bispectra = np.zeros((n_layers, n_series, n_channels, n_heads))

    torch.cuda.empty_cache()
    # Step by start offset so a final partial batch is processed, not dropped.
    for start in trange(0, n_series, batch_size):
        stop = min(start + batch_size, n_series)
        attn_weights = extract_attn_maps(
            pft_model.model,
            series_fn[start:stop],
            context_length,
            linear_attn=False,
        )
        for i in range(n_layers):
            # Every second attention tensor is used, as in the cells above.
            attnmap = attn_weights[2 * i].detach().cpu().numpy()
            # Leading axis is inferred so that a short last batch also reshapes.
            attnmap = attnmap.reshape(-1, n_channels, n_heads, n_tokens, n_tokens)
            bispectra[i, start:stop] = fronorm(attnmap)

        torch.cuda.empty_cache()

    # Unflatten the series axis back onto the (freqs1, freqs2) grid.
    bispectra = bispectra.reshape(
        n_layers, resolution, resolution, n_channels, n_heads
    )

    np.save(bispectra_path, bispectra)

In [None]:
# Reload the saved bispectra from disk, so the plotting cells below work without
# re-running the computation cell in this session.
bispectra = np.load("bispectra.npy")

# Show the array shape as the cell output. The computation cell saves the array
# after reshaping it to (-1, resolution, resolution, 3, 8).
bispectra.shape

In [None]:
# Input channel and attention head to display.
channel_idx = 0
head_idx = 2

# One panel per stored layer slot, filled row by row on a 2 x 4 grid.
fig, axes = plt.subplots(2, 4, figsize=(20, 10))
for i, panel in enumerate(axes.ravel()):
    panel.imshow(bispectra[i, :, :, channel_idx, head_idx], origin="lower")
    panel.set_title(f"Layer {2 * i}")
plt.show()

In [None]:
def plot_bispectra_across_checkpoints(
    series_fn,
    checkpoint_dir: str,
    context_length: int = 512,
    batch_size: int = 2048,
    figsize: tuple[int, int] = (20, 10),
    dims: int | None = None,
    channel_idx: int = 0,
    head_idx: int = 0,
):
    """
    Plots the bispectra for every numbered checkpoint in the given directory.

    For each ``checkpoint-{i}`` folder (processed in increasing ``i``) the model is
    loaded, the attention-map Frobenius norms are computed for every series in
    ``series_fn``, and one figure is shown with a panel per layer for the chosen
    ``channel_idx`` / ``head_idx``. Entries whose name does not end in ``-<int>``
    (including ``checkpoint-final``) are skipped.

    Args:
        series_fn: Sliceable collection of series supporting ``len()`` and
            ``series_fn[a:b]``. Its items must lie on a 2-D grid: the grid shape is
            taken from ``freqs1`` / ``freqs2`` when present, otherwise a square grid
            is assumed.
        checkpoint_dir: Directory containing the checkpoint folders.
        context_length: Number of time steps passed to the model.
        batch_size: Number of series per forward pass.
        figsize: Size of each per-checkpoint figure.
        dims: Number of channels per series; defaults to ``series_fn.dims``.
        channel_idx: Channel shown in the figures.
        head_idx: Attention head shown in the figures.

    Raises:
        AssertionError: If ``dims`` is not given and ``series_fn`` has no ``dims``.
        ValueError: If ``len(series_fn)`` does not match the inferred grid shape.
    """
    if dims is None:
        assert hasattr(series_fn, "dims"), (
            "series_fn must have a dims attribute if dims is not provided"
        )
        dims = series_fn.dims

    # Shape of the 2-D grid onto which the flat series axis is unflattened.
    n_series = len(series_fn)
    if hasattr(series_fn, "freqs1") and hasattr(series_fn, "freqs2"):
        grid_shape = (len(series_fn.freqs1), len(series_fn.freqs2))
    else:
        side = int(round(n_series**0.5))
        grid_shape = (side, side)
    if grid_shape[0] * grid_shape[1] != n_series:
        raise ValueError(
            f"cannot arrange {n_series} series on a {grid_shape[0]}x{grid_shape[1]} grid"
        )

    def _step(name: str) -> int | None:
        # Numeric suffix of a checkpoint folder name, or None for anything else.
        try:
            return int(name.split("-")[-1])
        except ValueError:
            return None

    # Keep only numbered checkpoints, so a missing "checkpoint-final" or stray
    # files in the directory no longer abort the run.
    checkpoints = sorted(
        (name for name in os.listdir(checkpoint_dir) if _step(name) is not None),
        key=_step,
    )

    for checkpoint in tqdm(checkpoints, desc="Processing checkpoints"):
        model = PatchTSTPipeline.from_pretrained(
            mode="predict",
            pretrain_path=f"{checkpoint_dir}/{checkpoint}",
            device_map="cuda:0",
        )
        config = model.model.config
        n_layers = config.num_hidden_layers
        num_heads = config.num_attention_heads
        num_tokens = context_length // config.patch_length

        bispectra = np.zeros((n_layers, n_series, dims, num_heads))
        stride = 2

        torch.cuda.empty_cache()
        # Step by start offset so a final partial batch is processed, not dropped.
        for start in trange(0, n_series, batch_size):
            stop = min(start + batch_size, n_series)
            attn_weights = extract_attn_maps(
                model.model,
                series_fn[start:stop],
                context_length,
                linear_attn=False,
            )
            # Same rule as temporal_attn_map_sequence: one tensor per layer means
            # use them all, otherwise take every second tensor.
            # NOTE(review): presumably the skipped tensors are a second attention
            # type interleaved with the temporal maps — confirm.
            stride = 1 if len(attn_weights) == n_layers else 2
            for i in range(n_layers):
                attnmap = (
                    attn_weights[stride * i]
                    .detach()
                    .cpu()
                    .numpy()
                    .reshape(-1, dims, num_heads, num_tokens, num_tokens)
                )
                bispectra[i, start:stop] = fronorm(attnmap)

            torch.cuda.empty_cache()

        bispectra = bispectra.reshape(n_layers, *grid_shape, dims, num_heads)

        # One figure per checkpoint, one panel per layer (up to four per row).
        n_cols = max(1, min(4, n_layers))
        n_rows = max(1, -(-n_layers // n_cols))
        fig, axes = plt.subplots(n_rows, n_cols, figsize=figsize, squeeze=False)
        for i, ax in enumerate(axes.ravel()):
            if i >= n_layers:
                ax.set_axis_off()  # unused trailing panel
                continue
            ax.imshow(bispectra[i, :, :, channel_idx, head_idx], origin="lower")
            ax.set_title(f"Layer {stride * i}")
        fig.suptitle(f"{checkpoint} (channel {channel_idx}, head {head_idx})")
        plt.show()
        plt.close(fig)

        # cleanup manually before the next checkpoint is loaded
        del model, bispectra
        torch.cuda.empty_cache()

    print(checkpoints)

In [None]:
# Root work directory, read from $WORK (raises KeyError if the variable is unset).
workdir = os.environ["WORK"]
checkpoint_dir = f"{workdir}/checkpoints/pft_chattn_fullemb_pretrained-0"
plot_bispectra_across_checkpoints(
    series_fn=series_fn,
    checkpoint_dir=checkpoint_dir,
    figsize=(20, 10),
)
