In [None]:
import os
from dataclasses import dataclass
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm.auto import tqdm, trange

from dystformer.patchtst.pipeline import PatchTSTPipeline
from dystformer.utils import apply_custom_style

In [None]:
# Apply the project's matplotlib style from the plotting config.
# The path is relative, so it resolves against the notebook's working directory.
apply_custom_style("../config/plotting.yaml")

In [None]:
# Training run whose final checkpoint is loaded for all analyses below.
# Other runs that can be swapped in:
#   "pft_stand_rff_univariate-0"
#   "pft_chattn_emb_w_poly-0"
#   "pft_linattn_noemb_from_scratch-0"
run_name = "pft_chattn_fullemb_pretrained-0"
checkpoint_path = (
    f"/stor/work/AMDG_Gilpin_Summer2024/checkpoints/{run_name}/checkpoint-final"
)
pft_model = PatchTSTPipeline.from_pretrained(
    mode="predict", pretrain_path=checkpoint_path, device_map="cuda:0"
)

In [None]:
@dataclass
class SineND:
    """Multi-channel sinusoid evaluated on a fixed time grid.

    Attributes:
        freqs: base frequency of each output channel.
        ts: sample times shared by all channels.
    """

    freqs: list[float]
    ts: np.ndarray

    def __call__(self, s: float):
        """Return an array of shape (len(freqs), len(ts)) with every frequency scaled by ``s``."""
        channels = []
        for freq in self.freqs:
            channels.append(np.sin(2 * np.pi * s * freq * self.ts))
        return np.array(channels)


# Two 3-channel test signals on a shared time grid: one whose frequencies are
# mutually incommensurate (1, pi, pi^2) and one whose frequencies are
# commensurate (1, 3, 9).
ts = np.linspace(0, 1, 4096)
incomm_fn = SineND([1.0, np.pi, np.pi**2], ts)
comm_fn = SineND([1, 3, 9], ts)

# Left column: both signals as 3D curves. Right column: the three channels of
# each signal against time (top: incommensurate, bottom: commensurate).
fig = plt.figure(figsize=(10, 5))
gs = fig.add_gridspec(
    2, 2, height_ratios=[0.5, 0.5], width_ratios=[0.4, 0.6], wspace=0.2, hspace=0
)
ax = fig.add_subplot(gs[:, 0], projection="3d")
ax.plot3D(*incomm_fn(1), color="blue", label="incomm")
ax.plot3D(*comm_fn(1), color="orange", label="comm")
ax.legend()
for i, (d, color) in enumerate(zip([incomm_fn(1), comm_fn(1)], ["blue", "orange"])):
    ax = fig.add_subplot(gs[i, 1])
    for j in range(3):
        ax.plot(ts, d[j], color=color)
# plt.show() rather than fig.show(): Figure.show() emits a warning on non-GUI
# backends (such as the notebook inline backend), and every other plotting
# cell in this notebook already uses plt.show().
plt.show()

In [None]:
def temporal_attn_map_sequence(
    model,
    series_generator,
    series_params: np.ndarray,
    context_length: int,
    channel_idx: int,
    head_idx: int,
    patch_size: int,
    colormap="magma",
    save_path: str | None = None,
    fps: int = 30,
    linear_attn: bool = False,
):
    """Animate the attention maps of one (channel, head) while sweeping a series parameter.

    For each value in ``series_params`` the generator is evaluated, the first
    ``context_length`` samples are fed to ``model``, and eight attention maps
    are drawn below a plot of the context, the ground-truth continuation and
    the model forecast for ``channel_idx``.

    Args:
        model: callable invoked as
            ``model(x, output_attentions=True, linear_attn=...)``; it must expose
            ``.device`` and return an object with ``.attentions`` and
            ``.prediction_outputs``.
        series_generator: callable mapping one parameter value to an array
            indexed as (channel, time) with at least ``context_length + 128``
            samples.
        series_params: parameter values, one per animation frame.
        context_length: number of leading samples used as model input.
        channel_idx: channel whose attention and forecast are shown.
        head_idx: attention head that is shown.
        patch_size: samples per patch. Used for the patch-boundary markers and
            for the size of the stored maps
            (``context_length // patch_size`` tokens per side), so it has to
            match the patch length the model actually uses.
        colormap: matplotlib colormap for the attention maps.
        save_path: if given, the animation is written there (``.gif`` via
            pillow, anything else via ffmpeg); otherwise it is rendered inline
            as JS/HTML.
        fps: frames per second of the animation.
        linear_attn: forwarded to the model call.

    Returns:
        Array of shape (len(series_params), 8, num_tokens, num_tokens) holding
        the attention map shown in every panel of every frame.
    """
    import matplotlib.pyplot as plt

    # `display` is only injected as a builtin inside IPython; import it explicitly.
    from IPython.display import HTML, display
    from matplotlib.animation import FuncAnimation
    from matplotlib.gridspec import GridSpec

    num_panels = 8  # attention maps per frame, laid out on a 2 x 4 grid
    pred_length = 128  # number of forecast samples drawn after the context
    num_tokens = context_length // patch_size

    fig = plt.figure(figsize=(12, 7))
    gs = GridSpec(3, 4, figure=fig, height_ratios=[1, 2, 2])

    ax_ts = fig.add_subplot(gs[0, :])
    ax_ts.set_xlim(0, 1)
    ax_ts.set_xticks([])
    ax_ts.set_yticks([])

    # Attention-map panels occupy grid rows 1 and 2 (row 0 is the time series).
    axes = []
    im_list = []
    for i in range(num_panels):
        ax = fig.add_subplot(gs[(i // 4) + 1, i % 4])
        im = ax.imshow(np.zeros((num_tokens, num_tokens)), cmap=colormap)
        ax.set_axis_off()
        ax.set_title(f"Layer {2 * i}")
        axes.append(ax)
        im_list.append(im)

    attnmaps = np.zeros((len(series_params), num_panels, num_tokens, num_tokens))
    ts = np.linspace(0, 1, context_length + pred_length)

    def update(frame):
        param = series_params[frame]
        series = series_generator(param)
        context = series[:, :context_length]

        # (channel, time) -> (1, time, channel)
        context_tensor = torch.from_numpy(context.T).float().to(model.device)[None, ...]
        # Inference only: skip building an autograd graph on every frame.
        with torch.no_grad():
            pred = model(
                context_tensor, output_attentions=True, linear_attn=linear_attn
            )
        attn_weights = pred.attentions

        # With exactly one attention entry per panel use each of them,
        # otherwise take every second entry.
        layer_stride = 1 if len(attn_weights) == num_panels else 2
        for i in range(num_panels):
            attnmap = (
                attn_weights[layer_stride * i][channel_idx, head_idx]
                .detach()
                .cpu()
                .numpy()
            )
            attnmaps[frame, i, :, :] = attnmap
            im_list[i].set_array(attnmap)
            # Rescale the colour range per panel and per frame.
            im_list[i].set_clim(attnmap.min(), attnmap.max())

        # plot context and prediction
        pred_channel = pred.prediction_outputs[0, :, channel_idx].detach().cpu().numpy()
        ax_ts.clear()

        # Dotted vertical line at the end of every patch of the context.
        for boundary in range(patch_size, context_length + 1, patch_size):
            ax_ts.axvline(x=float(ts[boundary]), color="gray", linestyle=":", alpha=0.5)

        ax_ts.plot(ts[:context_length], context[channel_idx], color="black")
        ax_ts.plot(
            ts[context_length:],
            series[channel_idx, context_length : context_length + pred_length],
            color="black",
            linestyle="--",
        )
        ax_ts.plot(ts[context_length:], pred_channel, color="red")
        ax_ts.set_xlim(0, 1)
        ax_ts.set_xticks([])
        ax_ts.set_yticks([])
        ax_ts.set_title(
            f"Input Time Series head {head_idx} channel {channel_idx} param={param:.3f}"
        )

        return im_list

    anim = FuncAnimation(
        fig,
        update,
        frames=len(series_params),
        interval=1000 / fps,
        blit=True,
    )

    if save_path is not None:
        anim.save(
            save_path,
            writer="pillow" if save_path.endswith(".gif") else "ffmpeg",
            fps=fps,
        )
        # Close this specific figure, not whichever one happens to be current.
        plt.close(fig)
    else:
        plt.close(fig)
        display(HTML(anim.to_jshtml()))

    return attnmaps

In [None]:
# Render the attention-map animation for one (channel, head) pair while the
# frequency scale of the incommensurate sine signal sweeps from 20 to 30.
head_idx, channel_idx = 2, 2
params = np.linspace(20, 30, 50)
animation_path = f"../figures/interp_{run_name}_{channel_idx}_{head_idx}.mp4"

attnmaps = temporal_attn_map_sequence(
    model=pft_model.model,
    series_generator=incomm_fn,
    series_params=params,
    context_length=512,
    channel_idx=channel_idx,
    head_idx=head_idx,
    patch_size=16,
    save_path=animation_path,
    fps=5,
    linear_attn=False,
)

In [None]:
@torch.no_grad()
def extract_attn_maps(
    model,
    series: np.ndarray,
    context_length: int,
    linear_attn: bool = False,
):
    """Run ``model`` on the leading context window of ``series`` and return its attention maps.

    Args:
        model: callable invoked as
            ``model(x, output_attentions=True, linear_attn=...)``; it must expose
            ``.device`` and return an object with an ``.attentions`` attribute.
        series: array laid out as (batch, time, channels). A 2-D array is
            treated as a single (time, channels) series and a 1-D array as a
            single univariate series; both are expanded to 3-D before slicing.
        context_length: number of leading time steps passed to the model.
        linear_attn: forwarded to the model call.

    Returns:
        ``pred.attentions`` exactly as produced by the model.

    Raises:
        ValueError: if ``series`` has more than three dimensions (or none).
    """
    series = np.asarray(series)
    # Normalise to (batch, time, channels) *before* slicing, so the time axis
    # is always axis 1. Previously 1-D and 2-D inputs raised IndexError.
    if series.ndim == 1:
        series = series[None, :, None]
    elif series.ndim == 2:
        series = series[None, ...]
    elif series.ndim != 3:
        raise ValueError(
            f"series must have 1, 2 or 3 dimensions, got shape {series.shape}"
        )

    context = series[:, :context_length]
    context_tensor = torch.from_numpy(context).float().to(model.device)
    pred = model(
        context_tensor,
        output_attentions=True,
        linear_attn=linear_attn,
    )
    return pred.attentions


def interaction_index(
    matrix: np.ndarray, axes: tuple[int, int] = (-2, -1)
) -> np.ndarray:
    """Relative gap between the Frobenius norm and the spectral norm.

    Both norms are taken over the two ``axes``; the result is
    ``(||M||_F - ||M||_2) / (||M||_F + 1e-10)``. It is (numerically) zero when
    a single singular value carries all the weight.
    """
    frobenius = np.linalg.norm(matrix, ord="fro", axis=axes)
    spectral = np.linalg.norm(matrix, ord=2, axis=axes)
    gap = frobenius - spectral
    return gap / (frobenius + 1e-10)


def mean_row_entropy(
    matrix: np.ndarray, axis: int = -1, eps: float = 1e-10
) -> np.ndarray:
    """Average Shannon entropy of the distributions stored along ``axis``.

    The entries along ``axis`` must sum to one (checked with an assertion).
    ``eps`` is added inside the logarithm so zero entries do not yield ``-inf``.
    The entropies are summed along ``axis`` and the result is then averaged
    along the same ``axis`` index of the reduced array (for the default
    ``axis=-1`` that is the row axis of each matrix).
    """
    totals = matrix.sum(axis=axis)
    assert np.allclose(totals, 1), (
        "All rows must be a probability distribution"
    )
    plogp = matrix * np.log(matrix + eps)
    summed = np.sum(plogp, axis=axis)
    return -summed.mean(axis=axis)


def fronorm(matrix: np.ndarray, axes: tuple[int, int] = (-2, -1)) -> np.ndarray:
    """Frobenius norm of ``matrix`` taken over the two given ``axes``."""
    frobenius = np.linalg.norm(matrix, ord="fro", axis=axes)
    return frobenius

In [None]:
@dataclass
class TwoToneSine:
    freqs1: np.ndarray
    freqs2: np.ndarray
    base_freq: float
    ts: np.ndarray

    def __post_init__(self):
        self.indices = np.arange(self.freqs1.shape[0] * self.freqs2.shape[0])
        self.basewave = np.sin(2 * np.pi * self.base_freq * self.ts)

    @property
    def shape(self) -> tuple[int, int]:
        return self.freqs1.shape[0], self.freqs2.shape[0]

    @property
    def dims(self) -> int:
        return 3

    def product_array_indices(self, s: slice) -> tuple[np.ndarray, ...]:
        return np.unravel_index(
            self.indices[s], (self.freqs1.shape[0], self.freqs2.shape[0])
        )

    def __len__(self) -> int:
        return self.freqs1.shape[0] * self.freqs2.shape[0]

    def __getitem__(self, idx: int | slice) -> np.ndarray:
        if isinstance(idx, int):
            i = idx % self.freqs1.shape[0]
            j = idx // self.freqs1.shape[0]
            return np.array(
                [
                    np.sin(2 * np.pi * self.freqs1[i] * self.ts),
                    np.sin(2 * np.pi * self.freqs2[j] * self.ts),
                    np.sin(2 * np.pi * self.base_freq * self.ts),
                ]
            ).T
        elif isinstance(idx, slice):
            idxi, idxj = self.product_array_indices(idx)
            return np.stack(
                [
                    np.sin(2 * np.pi * self.freqs1[idxi, None] * self.ts),
                    np.sin(2 * np.pi * self.freqs2[idxj, None] * self.ts),
                    self.basewave[None, :].repeat(len(idxi), axis=0),
                ],
                axis=-1,
            )

In [None]:
# Frobenius-norm response of every attention map over a dense grid of
# frequency pairs. The result is cached on disk because the sweep is expensive.
bispectra_fpath = "bispectra.npy"

if os.path.exists(bispectra_fpath):
    bispectra = np.load(bispectra_fpath)
else:
    resolution = 1024
    freqs = np.linspace(1, 50, resolution)
    series_fn = TwoToneSine(freqs, freqs, 10, np.linspace(0, 1, 1024))

    # Indexed as (layer, flat frequency-pair index, channel, head).
    # The sizes 8 / 3 / 8 are hard-coded for the loaded model
    # (presumably 8 layers and 8 heads; the 3 channels come from TwoToneSine).
    bispectra = np.zeros((8, resolution * resolution, 3, 8))
    batch_size = 2048

    torch.cuda.empty_cache()
    for batch in trange(resolution * resolution // batch_size):
        batch_slice = slice(batch * batch_size, (batch + 1) * batch_size)
        series_batch = series_fn[batch_slice]
        attn_weights = extract_attn_maps(
            pft_model.model,
            series_batch,
            512,
            linear_attn=False,
        )
        for i in range(8):
            # Every second attention entry is used; 32 x 32 is the per-map
            # token grid (presumably 512 context samples / patch length 16).
            attnmap = attn_weights[2 * i].detach().cpu().numpy()
            attnmap = attnmap.reshape(batch_size, 3, 8, 32, 32)
            response = fronorm(attnmap)
            bispectra[i, batch_slice] = response

        torch.cuda.empty_cache()

    # Unflatten the frequency-pair axis into the (freqs1, freqs2) grid.
    bispectra = bispectra.reshape(-1, resolution, resolution, 3, 8)

    # Save to the same path the cache check above reads from.
    np.save(bispectra_fpath, bispectra)

print(bispectra.shape)

In [None]:
# One heat-map per stored layer for a single (channel, head) pair.
channel_idx = 0
head_idx = 2
fig, axes = plt.subplots(2, 4, figsize=(20, 10))
for i, ax in enumerate(axes.ravel()):
    ax.imshow(bispectra[i, :, :, channel_idx, head_idx], origin="lower")
    ax.set_title(f"Layer {2 * (i + 1)}")
plt.show()

In [None]:
def bispectra_scaling(
    series_fn,
    checkpoint_dir: str,
    context_length: int = 512,
    batch_size: int = 2048,
    dims: int | None = None,
    norm: Callable = fronorm,
) -> np.ndarray:
    """
    Computes the per-layer interaction index of the attention response for
    every numbered checkpoint in the given directory.

    For each checkpoint the model is loaded, every series in ``series_fn`` is
    passed through it, and ``norm`` reduces each attention map to a scalar,
    giving a response of shape (num_layers, H, W, dims, num_heads). The
    interaction index of each (H, W) response surface is then averaged over
    channels and heads.

    Entries of ``checkpoint_dir`` whose name ends in ``-<integer>`` are
    processed in increasing numeric order; everything else (including
    ``checkpoint-final``) is skipped.

    Args:
        series_fn: series generator exposing ``shape`` (H, W), slice indexing
            returning (batch, time, channels) arrays, and ``dims`` unless
            ``dims`` is passed explicitly.
        checkpoint_dir: directory containing the checkpoint folders.
        context_length: number of time steps fed to the model.
        batch_size: maximum number of series per forward pass.
        dims: number of channels per series; defaults to ``series_fn.dims``.
        norm: callable reducing the last two axes of a
            (batch, dims, heads, tokens, tokens) array.

    Returns:
        Array of shape (num_checkpoints, num_layers).
    """
    if dims is None:
        assert hasattr(series_fn, "dims"), (
            "series_fn must have a dims attribute if dims is not provided"
        )
    dims = dims or series_fn.dims

    H, W = series_fn.shape
    num_series = H * W
    batch_size = min(batch_size, num_series)

    def _step(name: str) -> int:
        return int(name.split("-")[-1])

    # Keep only numbered checkpoints so stray files or a missing
    # "checkpoint-final" folder do not abort the sweep.
    checkpoints = sorted(
        (
            name
            for name in os.listdir(checkpoint_dir)
            if name.split("-")[-1].isdecimal()
        ),
        key=_step,
    )

    max_interactions = []
    for checkpoint in tqdm(checkpoints, desc="Processing checkpoints"):
        model = PatchTSTPipeline.from_pretrained(
            mode="predict",
            pretrain_path=f"{checkpoint_dir}/{checkpoint}",
            device_map="cuda:0",
        )

        config = model.model.config
        num_heads = config.num_attention_heads
        num_tokens = context_length // config.patch_length
        num_layers = config.num_hidden_layers

        bispectra = np.zeros((num_layers, num_series, dims, num_heads))

        # Step through all series, including a final partial batch. The
        # previous floor division left the tail of the grid at zero.
        for start in range(0, num_series, batch_size):
            stop = min(start + batch_size, num_series)
            attn_weights = extract_attn_maps(
                model.model,
                series_fn[start:stop],
                context_length,
                linear_attn=False,
            )
            for layer in range(num_layers):
                # Every second attention entry is used, one per layer.
                attnmap = (
                    attn_weights[2 * layer]
                    .detach()
                    .cpu()
                    .numpy()
                    .reshape(stop - start, dims, num_heads, num_tokens, num_tokens)
                )
                bispectra[layer, start:stop] = norm(attnmap)
            torch.cuda.empty_cache()

        # shape: (num_layers, H, W, dims, num_heads)
        bispectra = bispectra.reshape(-1, H, W, dims, num_heads)
        interactions = interaction_index(bispectra, axes=(1, 2)).mean(axis=(-1, -2))
        max_interactions.append(interactions)

        # cleanup manually so the next checkpoint fits on the GPU
        del model
        torch.cuda.empty_cache()

    return np.array(max_interactions)

In [None]:
# Interaction index per layer across the training checkpoints of one run,
# probed on a coarse 32 x 32 grid of frequency pairs.
freqs = np.linspace(1, 40, 32)
series_fn = TwoToneSine(freqs, freqs, 10, np.linspace(0, 1, 1024))
workdir = os.environ["WORK"]

# Other runs that can be analysed instead:
#   f"{workdir}/checkpoints/pft_chattn_fullemb_pretrained-0"
#   f"{workdir}/checkpoints/pft_chattn_mlm_sys164_ic128-0"
checkpoint_dir = f"{workdir}/checkpoints/pft_chattn_mlm_sys5245_ic4-0"
interactions = bispectra_scaling(series_fn, checkpoint_dir, norm=fronorm)

In [None]:
# Heat-map of the interaction index: checkpoints along x, layers along y.
# The tick count follows the data instead of assuming exactly 8 layers.
num_layers_shown = interactions.shape[1]

plt.figure(figsize=(20, 10))
plt.imshow(interactions.T, cmap="magma", aspect="auto", origin="lower")
# NOTE(review): bispectra_scaling reads attention entries 0, 2, 4, ... and
# single_head_attn_rollout treats those even entries as temporal attention;
# confirm that "channel attention" is the intended label.
plt.ylabel("Layer (channel attention)")
plt.xlabel("Checkpoint")
plt.yticks(np.arange(num_layers_shown), np.arange(2, 2 * num_layers_shown + 2, 2))
plt.colorbar()
plt.show()

In [None]:
from scipy.integrate import solve_ivp


@dataclass
class CoupledOscillator:
    num_oscillators: int
    mass: float = 1.0
    spring_constant: float = 1.0
    w0: float = 1.0
    initial_conditions: np.ndarray | None = None

    tspan: tuple[float, float] = (0, 100)
    num_eval_points: int = 1000
    k_function: Callable = lambda x: 1.0

    def __post_init__(self):
        self.stencil = np.zeros((self.num_oscillators, self.num_oscillators))
        for i in range(self.num_oscillators):
            self.stencil[i, (i - 1) % self.num_oscillators] = 1
            self.stencil[i, (i + 1) % self.num_oscillators] = 1
            self.stencil[i, i] = -2
        self.stencil /= self.mass

        if self.initial_conditions is None:
            self.initial_conditions = np.random.randn(2 * self.dim)

        self.ts = np.linspace(self.tspan[0], self.tspan[1], self.num_eval_points)

    @property
    def dim(self) -> int:
        return self.num_oscillators

    def __call__(self, t: float, uv: np.ndarray) -> np.ndarray:
        u, v = uv[: self.dim], uv[self.dim :]
        dudt = v
        dvdt = -(self.w0**2) * u + self.spring_constant * self.stencil @ u
        return np.concatenate([dudt, dvdt])

    def integrate(self) -> np.ndarray:
        sol = solve_ivp(self, self.tspan, self.initial_conditions, t_eval=self.ts)
        return sol.y[: self.dim].T

    def __getitem__(self, idx: int | slice) -> np.ndarray:
        if isinstance(idx, int):
            self.spring_constant = self.k_function(idx)
            return self.integrate()
        elif isinstance(idx, slice):
            inds = np.arange(idx.start, idx.stop, idx.step)
            solutions = np.zeros((len(inds), self.num_eval_points, self.dim))
            for i, ind in tqdm(enumerate(inds), total=len(inds)):
                self.spring_constant = self.k_function(ind)
                solutions[i] = self.integrate()
            return solutions
        else:
            raise ValueError(f"Invalid index: {idx}")

In [None]:
def attention_rollout(attention_stack, skip_connection=True):
    """
    Computes the attention rollout for a stack of attention matrices.
    Based on the description in Abnar & Zuidema
    https://arxiv.org/pdf/2005.00928

    Args:
        attention_stack (torch.Tensor): Tensor of shape (L, *, C, C) containing L attention
            matrices.
        skip_connection (bool): If True, each attention matrix is averaged with the
            identity and its rows are renormalised, to account for residual connections.

    Returns:
        torch.Tensor: The rollout ``A_{L-1} @ ... @ A_0`` with the dtype and device of
            the input. Its shape is (*, C, C); for an input with no extra batch
            dimensions, i.e. (L, C, C), the result is (1, C, C).
    """
    num_layers, *_, C, _ = attention_stack.shape
    # Build the identity once, in the stack's dtype: a default float32 identity
    # makes the matmul below fail for float64 / half-precision inputs.
    identity = torch.eye(
        C, device=attention_stack.device, dtype=attention_stack.dtype
    )
    rollout = identity[None, ...]
    for layer in range(num_layers):
        A = attention_stack[layer]
        if skip_connection:
            A = 0.5 * (A + identity)
            A = A / A.sum(dim=-1, keepdim=True)
        rollout = A @ rollout
    return rollout


def single_head_attn_rollout(
    model, data, context_length: int = 512, attention_type: str = "temporal"
):
    """
    Computes the attention rollout for the whole model by averaging over heads.

    Args:
        model: pipeline-like object whose ``.model`` attribute is the network
            passed to ``extract_attn_maps``.
        data: input series, forwarded unchanged to ``extract_attn_maps``.
        context_length: number of leading time steps fed to the model.
        attention_type: "temporal" selects the even-indexed attention entries,
            "channel" the odd-indexed ones.

    Returns:
        NumPy array with the rollout of the selected, head-averaged attention
        maps. Per the layout noted below, its leading axis is
        batch_size*num_channels for temporal attention and
        batch_size*num_tokens for channel attention.

    Raises:
        ValueError: if ``attention_type`` is not "channel" or "temporal".
    """
    # Validate before running the (expensive) forward pass.
    if attention_type == "channel":
        layer_slice = slice(1, None, 2)
    elif attention_type == "temporal":
        layer_slice = slice(0, None, 2)
    else:
        raise ValueError("Attention type must be either 'channel' or 'temporal'")

    attn_weights = extract_attn_maps(
        model.model,
        data,
        context_length,
        linear_attn=False,
    )
    attn_weights = attn_weights[layer_slice]

    # average over heads
    # if attention_type is channel
    # shape: (num_layers, batch_size*num_tokens, num_channels, num_channels)
    # if attention_type is temporal
    # shape: (num_layers, batch_size*num_channels, num_tokens, num_tokens)
    attn_weights = torch.stack(attn_weights, dim=0).mean(dim=2)

    return attention_rollout(attn_weights).detach().cpu().numpy()

In [None]:
# 36 coupled-oscillator trajectories whose spring constant grows geometrically
# with the index: k = (pi / 2) ** (i - 18).
series_fn = CoupledOscillator(num_oscillators=6, mass=1.0, spring_constant=1, w0=1.0)
series_fn.k_function = lambda i: (np.pi / 2) ** (i - 18)
data = series_fn[0:36]

attn_type = "temporal"

# single_head_attn_rollout dereferences ``model.model`` itself, so it is given
# the pipeline object. The network it then uses is ``pft_model.model``, the
# same one the other cells pass to extract_attn_maps.
rollouts = single_head_attn_rollout(
    pft_model,
    data,
    context_length=512,
    attention_type=attn_type,
)
rollouts.shape

In [None]:
# 6 x 6 grid of rollout maps with no spacing between panels.
fig, axes = plt.subplots(6, 6, figsize=(20, 20))
fig.subplots_adjust(wspace=0.0, hspace=0.0)

# One randomly drawn rollout (with replacement) per panel.
inds = np.random.randint(0, rollouts.shape[0], size=data.shape[0])
# Deterministic alternative: take every n-th rollout instead.
# inds = np.arange(
#     5,
#     rollouts.shape[0],
#     series_fn.dim if attn_type == "temporal" else len(ts) // pft_model.model.config.patch_length,
# )

for panel, ind in zip(axes.flat, inds):
    panel.imshow(rollouts[ind], cmap="magma")
    panel.set_axis_off()
plt.show()