In [None]:
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.integrate import solve_ivp
from tqdm.auto import tqdm

from dystformer.patchtst.pipeline import PatchTSTPipeline
from dystformer.utils import safe_standardize


In [None]:
# Checkpoint run to load. Alternative run names kept for quick switching:
#   "pft_stand_rff_univariate-0"
#   "pft_linattn_noemb_from_scratch-0"
#   "pft_chattn_fullemb_pretrained-0"
run_name = "pft_chattn_emb_w_poly-0"

# NOTE(review): absolute, machine-specific path — consider making this configurable.
checkpoint_path = (
    f"/stor/work/AMDG_Gilpin_Summer2024/checkpoints/{run_name}/checkpoint-final"
)
pft_model = PatchTSTPipeline.from_pretrained(
    mode="predict",
    pretrain_path=checkpoint_path,
    device_map="cuda:0",
)

In [None]:
@torch.no_grad()
def extract_attn_maps(
    model,
    series: np.ndarray,
    context_length: int,
    linear_attn: bool = False,
):
    """
    Run a forward pass on the leading context window and return the attention maps.

    Args:
        model: Callable module exposing ``.device`` and accepting
            ``output_attentions`` / ``linear_attn`` keyword arguments.
        series: Array laid out as (batch, time, channels). A 2-D array is
            treated as a single (time, channels) series and a 1-D array as a
            single univariate series; both are promoted to 3-D.
        context_length: Number of leading timesteps fed to the model.
        linear_attn: Forwarded unchanged to the model call.

    Returns:
        The ``attentions`` attribute of the model output.

    Raises:
        ValueError: If ``series`` has more than three dimensions (or none).
    """
    series = np.asarray(series)
    # Promote to (batch, time, channels) *before* slicing; slicing first would
    # cut the wrong axis of an un-batched series.
    if series.ndim == 1:
        series = series[None, :, None]
    elif series.ndim == 2:
        series = series[None, ...]
    elif series.ndim != 3:
        raise ValueError(
            f"series must have 1, 2 or 3 dimensions, got shape {series.shape}"
        )

    context = series[:, :context_length, :]
    context_tensor = torch.from_numpy(context).float().to(model.device)
    pred = model(
        context_tensor,
        output_attentions=True,
        linear_attn=linear_attn,
    )
    return pred.attentions


def show_forecast(
    model,
    data: np.ndarray,
    context_length: int,
    prediction_length: int,
    transient_length: int = 1024,
    **kwargs,
):
    """
    Plot the standardized context, ground truth and model forecast for the first series.

    Args:
        model: Object exposing ``predict(context, prediction_length=..., **kwargs)``.
        data: Array indexed as (batch, time, channels).
        context_length: Number of timesteps (after the transient) used as context.
        prediction_length: Number of timesteps to forecast and compare.
        transient_length: Leading timesteps dropped before anything else.
        **kwargs: Forwarded to ``model.predict``.
    """
    data = data[:, transient_length:, :]
    horizon_end = context_length + prediction_length

    context_data = data[:, :context_length, :]
    stand_context = safe_standardize(context_data, axis=1)
    # Index 0 on the second axis of the predict() output is kept; presumably
    # that is the sample axis — TODO confirm against the pipeline.
    predictions = (
        model.predict(
            torch.from_numpy(stand_context).float(),
            prediction_length=prediction_length,
            **kwargs,
        )[:, 0, ...]
        .detach()
        .cpu()
        .numpy()
    )
    # Standardize the ground truth with the *context* statistics only, so it
    # lives on the same scale as the context and the predictions and does not
    # use information from the forecast horizon.
    stand_gt = safe_standardize(
        data[:, context_length:horizon_end, :],
        context=context_data,
        axis=1,
    )

    context_ts = np.arange(0, context_length)
    prediction_ts = np.arange(context_length, horizon_end)
    fig, ax = plt.subplots(figsize=(10, 4))
    for channel in range(data.shape[-1]):
        ax.plot(context_ts, stand_context[0, :, channel], color="k", alpha=0.5)
        ax.plot(
            prediction_ts, stand_gt[0, :, channel], color="k", linestyle="--", alpha=0.3
        )
        ax.plot(prediction_ts, predictions[0, :, channel], color="r")
    plt.show()


def interaction_index(
    matrix: np.ndarray, axes: tuple[int, int] = (-2, -1)
) -> np.ndarray:
    """
    Relative gap between the Frobenius norm and the spectral (2-) norm.

    Both norms are taken over the two axes in ``axes``; the result is
    ``(fro - two) / (fro + 1e-10)``, where the small constant guards against
    division by zero.
    """
    frobenius = np.linalg.norm(matrix, ord="fro", axis=axes)
    spectral = np.linalg.norm(matrix, ord=2, axis=axes)
    gap = frobenius - spectral
    return gap / (frobenius + 1e-10)


def mean_row_entropy(
    matrix: np.ndarray, axis: int = -1, eps: float = 1e-10
) -> np.ndarray:
    """
    Shannon entropy along ``axis``, then averaged along ``axis`` again.

    With the default ``axis=-1`` this is the entropy of every row followed by
    a mean over the rows. ``eps`` is added inside the logarithm so zero
    entries do not produce ``-inf``. Sums along ``axis`` must be close to 1.
    """
    totals = matrix.sum(axis=axis)
    assert np.allclose(totals, 1), (
        "All rows must be a probability distribution"
    )
    weighted_log = matrix * np.log(matrix + eps)
    summed = np.sum(weighted_log, axis=axis)
    # Note: the same ``axis`` value is reused on the reduced array.
    return -(summed.mean(axis=axis))


def fronorm(matrix: np.ndarray, axes: tuple[int, int] = (-2, -1)) -> np.ndarray:
    """Frobenius norm of ``matrix`` taken over the two axes in ``axes``."""
    frobenius = np.linalg.norm(matrix, ord="fro", axis=axes)
    return frobenius


In [None]:
def attention_rollout(
    attention_stack, skip_connection=True, start_layer=-1, stop_layer=None
):
    """
    Computes the attention rollout for a stack of attention matrices.
    Based on the description in Abnar & Zuidema
    https://arxiv.org/pdf/2005.00928

    Args:
        attention_stack (torch.Tensor): Tensor of shape (L, *, C, C) containing L attention
            matrices.
        skip_connection (bool): If True, each attention matrix is averaged with the
            identity and re-normalized along the last axis to account for residual
            connections.
        start_layer (int): Start of the slice taken over the (L + 1) cumulative
            rollouts. Index 0 is the identity (no layers applied) and index i is the
            rollout through the first i layers.
        stop_layer (int | None): End (exclusive) of that slice.

    Returns:
        torch.Tensor: ``rollouts[start_layer:stop_layer]`` with shape (K, *, C, C).
            With the defaults K == 1 and the single entry is the rollout through
            all L layers. The dtype is the input dtype promoted to at least float32.
    """
    num_layers, *_, C, _ = attention_stack.shape
    device = attention_stack.device
    # Keep every operand of the matmul in one dtype; half precision is promoted
    # to float32 and float64 is preserved.
    dtype = torch.promote_types(attention_stack.dtype, torch.float32)

    identity = torch.eye(C, device=device, dtype=dtype)
    rollouts = torch.zeros(
        num_layers + 1, *attention_stack.shape[1:], device=device, dtype=dtype
    )
    rollout = identity[None, ...]
    rollouts[0] = rollout  # broadcast over the leading (*) dimensions
    for i in range(num_layers):
        A = attention_stack[i].to(dtype)
        if skip_connection:
            A = 0.5 * (A + identity)
            A = A / A.sum(dim=-1, keepdim=True)
        rollout = A @ rollout
        rollouts[i + 1] = rollout
    return rollouts[start_layer:stop_layer]


def single_head_attn_rollout(
    model,
    data: np.ndarray,
    context_length: int = 512,
    attention_type: str = "temporal",
    **kwargs,
) -> np.ndarray:
    """
    Computes the attention rollout for the whole model by averaging over heads.

    Args:
        model: Object whose ``.model`` attribute is passed to ``extract_attn_maps``.
        data: Array of shape (batch_size, time, num_channels).
        context_length: Number of leading timesteps fed to the model.
        attention_type: "temporal" selects the even-indexed attention maps,
            "channel" the odd-indexed ones.
        **kwargs: Forwarded to ``attention_rollout`` (``skip_connection``,
            ``start_layer``, ``stop_layer``).

    Returns:
        Array of shape (K, batch_size, m, n, n), where K is the number of
        rollouts selected by ``attention_rollout``, n is the attention-map size
        and m is the remaining factor of the flattened leading axis (channels
        for temporal attention, tokens for channel attention).

    Raises:
        ValueError: If ``attention_type`` is not "channel" or "temporal".
    """
    # Validate before running the (expensive) forward pass.
    if attention_type == "channel":
        layer_offset = 1
    elif attention_type == "temporal":
        layer_offset = 0
    else:
        raise ValueError("Attention type must be either 'channel' or 'temporal'")

    bs, _, _ = data.shape
    attn_weights = extract_attn_maps(
        model.model,
        data,
        context_length,
        linear_attn=False,
    )[layer_offset::2]

    # average over heads
    # if attention_type is channel
    # shape: (num_layers, batch_size*num_tokens, num_channels, num_channels)
    # if attention_type is temporal
    # shape: (num_layers, batch_size*num_channels, num_tokens, num_tokens)
    attn_weights = torch.stack(attn_weights, dim=0).mean(dim=2)

    rollouts = attention_rollout(attn_weights, **kwargs).detach().cpu().numpy()
    # Read the sizes off the tensor instead of assuming
    # num_tokens == context_length // patch_length, which does not hold for
    # overlapping patches or inputs shorter than context_length.
    _, flat, n, _ = rollouts.shape
    return rollouts.reshape(-1, bs, flat // bs, n, n)


def multi_head_attn_rollout(
    model, data: np.ndarray, context_length: int = 512, attention_type: str = "temporal"
) -> np.ndarray:
    """
    Compute attention rollout for each head by averaging upstream heads.

    For layer i, every head's attention map is multiplied by the head-averaged
    rollout through the layers before i (the identity for the first layer).

    Args:
        model: Object whose ``.model`` attribute is passed to ``extract_attn_maps``.
        data: Array of shape (batch_size, time, num_channels).
        context_length: Number of leading timesteps fed to the model.
        attention_type: "temporal" selects the even-indexed attention maps,
            "channel" the odd-indexed ones.

    Returns:
        Array of shape (num_layers, batch_size, m, num_heads, n, n), where n is
        the attention-map size and m the remaining factor of the flattened
        leading axis (channels for temporal attention, tokens for channel
        attention).

    Raises:
        ValueError: If ``attention_type`` is not "channel" or "temporal".
    """
    # Validate before running the (expensive) forward pass.
    if attention_type == "channel":
        layer_offset = 1
    elif attention_type == "temporal":
        layer_offset = 0
    else:
        raise ValueError("Attention type must be either 'channel' or 'temporal'")

    bs, _, _ = data.shape
    attn_weights = extract_attn_maps(
        model.model,
        data,
        context_length,
        linear_attn=False,
    )[layer_offset::2]

    # shape: (num_layers, flat, num_heads, n, n) with flat = batch_size * m
    attn_weights = torch.stack(attn_weights, dim=0)
    L, S, H, C, _ = attn_weights.shape

    # shape: (num_layers, flat, n, n); entry i is the rollout through layers < i
    single_head_rollouts = attention_rollout(
        attn_weights.mean(dim=2), start_layer=0, stop_layer=L
    )
    # shape: (num_layers, flat, num_heads, n, n)
    # Cast so both matmul operands share a dtype (the rollout may be promoted).
    rollouts = attn_weights.to(
        single_head_rollouts.dtype
    ) @ single_head_rollouts.unsqueeze(2)

    # Sizes come from the tensor itself rather than from
    # context_length // patch_length, which assumes non-overlapping patches.
    # shape: (num_layers, batch_size, channels or tokens, num_heads, tokens or channels, tokens or channels)
    return rollouts.reshape(L, bs, S // bs, H, C, C).detach().cpu().numpy()

In [None]:
@dataclass
class CoupledOscillator:
    num_oscillators: int
    w0: float = 1.0

    initial_conditions: np.ndarray | None = None
    spring_constants: np.ndarray | None = None
    masses: np.ndarray | None = None
    stencil: np.ndarray | None = None

    tspan: tuple[float, float] = (0, 100)
    num_eval_points: int = 1024

    def __post_init__(self):
        if self.spring_constants is None:
            self.spring_constants = np.ones(self.num_oscillators)

        if self.masses is None:
            self.masses = np.ones(self.num_oscillators)

        if self.initial_conditions is None:
            self.initial_conditions = np.random.randn(2 * self.dim)
            self.initial_conditions[self.dim :] = 0

        if self.stencil is None:
            self.stencil = np.zeros((self.num_oscillators, self.num_oscillators))
            for i in range(self.num_oscillators):
                self.stencil[i, (i - 1) % self.num_oscillators] = 1
                self.stencil[i, (i + 1) % self.num_oscillators] = 1
                self.stencil[i, i] = -2
                self.stencil *= self.spring_constants[i] / self.masses[i]

        self.ts = np.linspace(self.tspan[0], self.tspan[1], self.num_eval_points)

    @property
    def dim(self) -> int:
        return self.num_oscillators

    @property
    def basis(self) -> np.ndarray:
        return np.linalg.eigh(self.spring_constant * self.stencil)

    def __call__(self, t: float, uv: np.ndarray) -> np.ndarray:
        u, v = uv[: self.dim], uv[self.dim :]
        dudt = v
        dvdt = self.stencil @ u
        return np.concatenate([dudt, dvdt])

    def integrate(self) -> np.ndarray:
        sol = solve_ivp(self, self.tspan, self.initial_conditions, t_eval=self.ts)
        return sol.y[: self.dim].T

    def __getitem__(self, idx: int | slice) -> np.ndarray:
        if isinstance(idx, int):
            return self.integrate()
        elif isinstance(idx, slice):
            inds = np.arange(idx.start, idx.stop, idx.step)
            solutions = np.zeros((len(inds), self.num_eval_points, self.dim))
            for i, ind in tqdm(enumerate(inds), total=len(inds)):
                solutions[i] = self.integrate()
            return solutions
        else:
            raise ValueError(f"Invalid index: {idx}")

In [None]:
# Four oscillators with equal spring constants; oscillator 2 is made far
# heavier than the others.
num_oscillators = 4
masses = np.full(num_oscillators, 0.01)
springs = np.full(num_oscillators, 0.01)
masses[2] = 1000000

series_fn = CoupledOscillator(
    num_oscillators=num_oscillators,
    spring_constants=springs,
    masses=masses,
    num_eval_points=4096,
)

In [None]:
# Forecast a single-sample batch drawn from the current series generator.
forecast_batch = series_fn[0:1]
show_forecast(
    pft_model,
    forecast_batch,
    context_length=512,
    prediction_length=512,
    limit_prediction_length=False,
    sliding_context=True,
    verbose=False,
)


In [None]:
# Head-averaged channel-attention rollouts for every layer prefix
# (start_layer=0 keeps the identity entry as well).
rollouts_per_sample = single_head_attn_rollout(
    pft_model.model,
    series_fn[0:1],
    context_length=1024,
    attention_type="channel",
    start_layer=0,
    stop_layer=None,
)
# Keep the first (only) sample of the batch.
singlehead_rollouts = rollouts_per_sample[:, 0, ...]
singlehead_rollouts.shape

In [None]:
# One panel per rollout depth, shown for the last token.
token_idx = -1
num_panels = singlehead_rollouts.shape[0]
fig, axes = plt.subplots(1, num_panels, figsize=(num_panels * 2, 2))
plt.subplots_adjust(wspace=0.0)
for layer, ax in enumerate(axes):
    ax.imshow(singlehead_rollouts[layer, token_idx], cmap="magma")
    ax.set_axis_off()
fig.supxlabel("Layer", fontsize=20, y=-0.05, x=0.50);


In [None]:
# Raw attention maps for one sample; show how many there are and the shapes
# of the first two.
attn_maps = extract_attn_maps(
    model=pft_model.model,
    series=series_fn[0:1],
    context_length=1024,
    linear_attn=False,
)
(len(attn_maps), attn_maps[0].shape, attn_maps[1].shape)

In [None]:
# Grid of raw channel-attention maps: rows are heads, columns are layers.
sample_idx = -1
# Channel-attention maps sit at the odd positions of attn_maps.
channel_maps = attn_maps[1::2]
# Derive the grid size from the model output instead of assuming 8 x 8.
num_layers = len(channel_maps)
num_heads = channel_maps[0].shape[1]
fig, axes = plt.subplots(num_heads, num_layers, figsize=(10, 10), squeeze=False)
plt.subplots_adjust(wspace=0.0, hspace=0.05)
for head in range(num_heads):
    for layer in range(num_layers):
        amap = channel_maps[layer][sample_idx, head].detach().cpu().numpy()
        axes[head, layer].imshow(amap, cmap="magma")
        axes[head, layer].set_axis_off()
plt.show()


In [None]:
# Per-head channel-attention rollouts; keep the first (only) sample.
multihead_per_sample = multi_head_attn_rollout(
    pft_model.model,
    series_fn[0:1],
    context_length=1024,
    attention_type="channel",
)
multihead_rollouts = multihead_per_sample[:, 0, ...]
multihead_rollouts.shape

In [None]:
# Grid of per-head rollouts for the last token: rows are heads, columns are layers.
token_idx = -1
# multihead_rollouts is indexed as [layer, token, head, ...]; take the grid
# size from the array instead of assuming 8 x 8.
num_layers = multihead_rollouts.shape[0]
num_heads = multihead_rollouts.shape[2]
fig, axes = plt.subplots(num_heads, num_layers, figsize=(10, 10), squeeze=False)
plt.subplots_adjust(wspace=0.0, hspace=0.05)
for head in range(num_heads):
    for layer in range(num_layers):
        amap = multihead_rollouts[layer, token_idx, head]
        axes[head, layer].imshow(amap, cmap="magma")
        axes[head, layer].set_axis_off()

fig.supxlabel("Layer", fontsize=20, y=0.08, x=0.51)
fig.supylabel("Head", fontsize=20, x=0.1);

In [None]:
@dataclass
class Sines:
    num_waves: int
    base_frequencies: np.ndarray
    tspan: tuple[float, float] = (0, 100)
    num_eval_points: int = 1024

    def __post_init__(self):
        self.ts = np.linspace(self.tspan[0], self.tspan[1], self.num_eval_points)

    @property
    def dim(self) -> int:
        return self.num_waves

    def __getitem__(self, idx: int | slice) -> np.ndarray:
        if isinstance(idx, int):
            return np.stack([np.sin(self.base_frequencies * self.ts[:, None])], axis=0)
        else:
            raise ValueError(f"Invalid index: {idx}")

In [None]:
# Eight sine waves with angular frequencies evenly spaced between
# 2*pi/4 and 2*pi.
wave_frequencies = 2 * np.pi * np.linspace(1 / 4, 1, 8)
series_fn = Sines(
    num_waves=8,
    num_eval_points=4096,
    base_frequencies=wave_frequencies,
)

In [None]:
# Forecast the sine-wave batch (integer indexing already yields a batch of one).
forecast_batch = series_fn[0]
show_forecast(
    pft_model,
    forecast_batch,
    context_length=512,
    prediction_length=512,
    limit_prediction_length=False,
    sliding_context=True,
    verbose=False,
)

In [None]:
# Head-averaged channel-attention rollouts for every layer prefix
# (start_layer=0 keeps the identity entry as well).
rollouts_per_sample = single_head_attn_rollout(
    pft_model.model,
    series_fn[0],
    context_length=1024,
    attention_type="channel",
    start_layer=0,
    stop_layer=None,
)
# Keep the first (only) sample of the batch.
singlehead_rollouts = rollouts_per_sample[:, 0, ...]
singlehead_rollouts.shape

In [None]:
# One panel per rollout depth, shown for the last token.
token_idx = -1
num_panels = singlehead_rollouts.shape[0]
fig, axes = plt.subplots(1, num_panels, figsize=(num_panels * 2, 2))
plt.subplots_adjust(wspace=0.0)
for layer, ax in enumerate(axes):
    ax.imshow(singlehead_rollouts[layer, token_idx], cmap="magma")
    ax.set_axis_off()
fig.supxlabel("Layer", fontsize=20, y=-0.05, x=0.50);


In [None]:
# Raw attention maps for the sine batch; show how many there are and the
# shapes of the first two.
attn_maps = extract_attn_maps(
    model=pft_model.model,
    series=series_fn[0],
    context_length=1024,
    linear_attn=False,
)
(len(attn_maps), attn_maps[0].shape, attn_maps[1].shape)

In [None]:
# Grid of raw channel-attention maps: rows are heads, columns are layers.
sample_idx = -1
# Channel-attention maps sit at the odd positions of attn_maps.
channel_maps = attn_maps[1::2]
# Derive the grid size from the model output instead of assuming 8 x 8.
num_layers = len(channel_maps)
num_heads = channel_maps[0].shape[1]
fig, axes = plt.subplots(num_heads, num_layers, figsize=(10, 10), squeeze=False)
plt.subplots_adjust(wspace=0.0, hspace=0.05)
for head in range(num_heads):
    for layer in range(num_layers):
        amap = channel_maps[layer][sample_idx, head].detach().cpu().numpy()
        axes[head, layer].imshow(amap, cmap="magma")
        axes[head, layer].set_axis_off()
plt.show()


In [None]:
# Per-head channel-attention rollouts; keep the first (only) sample.
multihead_per_sample = multi_head_attn_rollout(
    pft_model.model,
    series_fn[0],
    context_length=1024,
    attention_type="channel",
)
multihead_rollouts = multihead_per_sample[:, 0, ...]
multihead_rollouts.shape

In [None]:
# Grid of per-head rollouts for the last token: rows are heads, columns are layers.
token_idx = -1
# multihead_rollouts is indexed as [layer, token, head, ...]; take the grid
# size from the array instead of assuming 8 x 8.
num_layers = multihead_rollouts.shape[0]
num_heads = multihead_rollouts.shape[2]
fig, axes = plt.subplots(num_heads, num_layers, figsize=(10, 10), squeeze=False)
plt.subplots_adjust(wspace=0.0, hspace=0.05)
for head in range(num_heads):
    for layer in range(num_layers):
        amap = multihead_rollouts[layer, token_idx, head]
        axes[head, layer].imshow(amap, cmap="magma")
        axes[head, layer].set_axis_off()

fig.supxlabel("Layer", fontsize=20, y=0.08, x=0.51)
fig.supylabel("Head", fontsize=20, x=0.1);