In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are reloaded
# before each cell runs and edits to project code are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os
from dataclasses import dataclass

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.integrate import solve_ivp
from tqdm.auto import tqdm

from panda.patchtst.pipeline import PatchTSTPipeline
from panda.utils.data_utils import safe_standardize
from panda.utils.plot_utils import apply_custom_style

# Apply the plotting style helper imported from panda.utils.plot_utils, pointing it at the
# YAML file at this notebook-relative path (the notebook must run from its own directory).
apply_custom_style("../config/plotting.yaml")

In [None]:
# Load the final checkpoint of the named run as a prediction-mode pipeline on the first GPU.
# NOTE(review): the checkpoint root is a hard-coded absolute path, whereas later cells build
# their checkpoint paths from the WORK environment variable — confirm which is intended.
run_name = "pft_chattn_emb_w_poly-0"
pft_model = PatchTSTPipeline.from_pretrained(
    mode="predict",
    pretrain_path=f"/stor/work/AMDG_Gilpin_Summer2024/checkpoints/{run_name}/checkpoint-final",
    device_map="cuda:0",
)

In [None]:
@torch.no_grad()
def extract_attn_maps(
    model,
    series: np.ndarray,
    context_length: int,
    linear_attn: bool = False,
):
    """
    Run a forward pass on the leading context window and return the attention maps.

    Args:
        model: Callable exposing a ``device`` attribute. It is invoked as
            ``model(tensor, output_attentions=True, linear_attn=...)`` and its result
            must expose ``.attentions``.
        series: Time series laid out as (batch, time, channels). A 2-D array is treated
            as a single (time, channels) sample and a 1-D array as a single
            univariate sample; both get the missing axes added.
        context_length: Number of leading time steps fed to the model.
        linear_attn: Forwarded unchanged to the model call.

    Returns:
        Whatever the model stores in ``.attentions`` for this input.
    """
    series = np.asarray(series)
    # Normalise the rank BEFORE slicing the time axis. Slicing first (as this function
    # used to do) cuts the channel axis of a 2-D input and fails outright on a 1-D one.
    if series.ndim == 1:
        series = series[:, None]
    if series.ndim == 2:
        series = series[None, ...]

    context = series[:, :context_length]
    # The model input is float32 on the model's own device.
    context_tensor = torch.from_numpy(context).float().to(model.device)
    pred = model(
        context_tensor,
        output_attentions=True,
        linear_attn=linear_attn,
    )
    return pred.attentions


def show_forecast(
    model,
    data: np.ndarray,
    context_length: int,
    prediction_length: int,
    transient_length: int = 1024,
    **kwargs,
):
    """
    Plot a standardized forecast against the standardized ground truth.

    The first ``transient_length`` time steps are dropped. The next ``context_length``
    steps are standardized and passed to ``model.predict``; the following
    ``prediction_length`` steps are the ground truth, standardized with the context
    window passed as ``context``. Only the first batch element is drawn, one set of
    curves per channel.

    Args:
        model: Object whose ``predict`` accepts a float tensor, ``prediction_length``
            and any extra ``kwargs``.
        data: Array indexed as (batch, time, channels).
        context_length: Length of the context window.
        prediction_length: Number of steps to forecast and compare.
        transient_length: Number of leading steps to discard.
        **kwargs: Forwarded to ``model.predict``.
    """
    trimmed = data[:, transient_length:, :]
    horizon_end = context_length + prediction_length

    context_window = trimmed[:, :context_length, :]
    stand_context = safe_standardize(context_window, axis=1)

    forecast = model.predict(
        torch.from_numpy(stand_context).float(),
        prediction_length=prediction_length,
        **kwargs,
    )
    # Index 1 of the prediction output is reduced to its first entry before plotting.
    predictions = forecast[:, 0, ...].detach().cpu().numpy()

    stand_gt = safe_standardize(
        trimmed[:, context_length:horizon_end, :],
        context=trimmed[:, :context_length, :],
        axis=1,
    )

    context_ts = np.arange(context_length)
    prediction_ts = np.arange(context_length, horizon_end)

    plt.figure(figsize=(10, 4))
    for channel in range(trimmed.shape[-1]):
        plt.plot(context_ts, stand_context[0, :, channel], color="k", alpha=0.5)
        plt.plot(prediction_ts, stand_gt[0, :, channel], color="k", linestyle="--", alpha=0.3)
        plt.plot(prediction_ts, predictions[0, :, channel], color="r")
    plt.show()


def interaction_index(matrix: np.ndarray, axes: tuple[int, int] = (-2, -1)) -> np.ndarray:
    """
    Return ``1 - ||A||_F / (||A||_2 * sqrt(n))`` for each matrix spanned by ``axes``.

    ``n`` is the size of the first axis in ``axes``. For a square matrix the value is 0
    when all singular values are equal (e.g. the identity) and ``1 - 1/sqrt(n)`` for a
    rank-one matrix.
    """
    frobenius = np.linalg.norm(matrix, ord="fro", axis=axes)
    spectral = np.linalg.norm(matrix, ord=2, axis=axes)
    scale = np.sqrt(matrix.shape[axes[0]])
    return 1 - frobenius / (spectral * scale)


def mean_row_entropy(matrix: np.ndarray, axis: int = -1, eps: float = 1e-10) -> np.ndarray:
    """
    Shannon entropy (natural log) of each distribution along ``axis``, then averaged.

    Entries along ``axis`` must sum to 1; ``eps`` is added inside the log to keep zero
    entries finite. After the entropy reduction, the mean is taken over the same
    ``axis`` index of the reduced array. With the default ``axis=-1`` this averages the
    row entropies of each matrix.

    NOTE(review): because ``axis`` is reused on an array with one fewer dimension, other
    values of ``axis`` average over a different dimension or raise — only the default
    is exercised in this notebook.
    """
    totals = matrix.sum(axis=axis)
    assert np.allclose(totals, 1), "All rows must be a probability distribution"
    weighted_logs = np.sum(matrix * np.log(matrix + eps), axis=axis)
    return -weighted_logs.mean(axis=axis)


def fronorm(matrix: np.ndarray, axes: tuple[int, int] = (-2, -1)) -> np.ndarray:
    """Frobenius norm of every matrix spanned by ``axes`` (the last two by default)."""
    norms = np.linalg.norm(matrix, ord="fro", axis=axes)
    return norms

In [None]:
def attention_rollout(attention_stack, skip_connection=True, start_layer=-1, stop_layer=None):
    """
    Computes the attention rollout for a stack of attention matrices.
    Based on the description in Abnar & Zuidema
    https://arxiv.org/pdf/2005.00928

    The input tensor is never modified: the residual term and the row re-normalisation
    are applied to fresh tensors, so the same ``attention_stack`` can be reused safely.

    Args:
        attention_stack (torch.Tensor): Tensor of shape (L, *, C, C) containing L attention
            matrices.
        skip_connection (bool): If True, adds an identity matrix to each attention matrix
            to account for residual connections, then re-normalises each row to sum to 1.
        start_layer (int): First index of the returned slice. Index 0 is the identity
            (no layers applied) and index ``i`` is the rollout after ``i`` layers.
        stop_layer (int | None): Exclusive end index of the returned slice.

    Returns:
        torch.Tensor: ``rollouts[start_layer:stop_layer]``, a (K, *, C, C) tensor. With the
        defaults K is 1 and the single entry is the rollout through all L layers.
    """
    num_layers, *_, num_nodes, _ = attention_stack.shape
    device, dtype = attention_stack.device, attention_stack.dtype

    # Build the buffers in the input dtype so that non-float32 attention also works.
    identity = torch.eye(num_nodes, device=device, dtype=dtype)
    rollouts = torch.zeros(num_layers + 1, *attention_stack.shape[1:], device=device, dtype=dtype)
    rollout = identity[None, ...]
    rollouts[0] = rollout

    for i in range(num_layers):
        layer_attn = attention_stack[i]
        if skip_connection:
            # Out-of-place on purpose: attention_stack[i] is a view of the caller's tensor.
            layer_attn = layer_attn + identity
            layer_attn = layer_attn / layer_attn.sum(dim=-1, keepdim=True)
        # NOTE(review): each layer is accumulated by right-multiplication (rollout @ A);
        # confirm this matches the intended ordering convention of the referenced paper.
        rollout = rollout @ layer_attn
        rollouts[i + 1] = rollout

    # The returned slice is a view of `rollouts`, so the buffer stays alive with it.
    rollout_result = rollouts[start_layer:stop_layer]
    torch.cuda.empty_cache()
    return rollout_result


def single_head_attn_rollout(
    model,
    data: np.ndarray,
    context_length: int = 512,
    attention_type: str = "temporal",
    **kwargs,
) -> np.ndarray:
    """
    Attention rollout of the whole model after averaging the heads of every layer.

    Attention maps are extracted from ``model.model``. Even-indexed maps are used for
    ``"temporal"`` attention and odd-indexed maps for ``"channel"`` attention. Extra
    ``kwargs`` are forwarded to ``attention_rollout``.

    Returns:
        Array of shape (num_rollouts, batch, m, n, n). For temporal attention ``n`` is
        the number of tokens (``context_length // patch_length``) and ``m`` the number of
        channels; for channel attention the two are swapped.

    Raises:
        ValueError: If ``attention_type`` is neither ``"channel"`` nor ``"temporal"``.
    """
    batch_size, _, num_channels = data.shape
    layer_maps = extract_attn_maps(model.model, data, context_length, linear_attn=False)

    is_temporal = attention_type == "temporal"
    if not is_temporal and attention_type != "channel":
        raise ValueError("Attention type must be either 'channel' or 'temporal'")
    layer_maps = layer_maps[(0 if is_temporal else 1) :: 2]

    # Stack the layers, then collapse the head axis (dim 2) by averaging.
    #   channel:  (num_layers, batch_size*num_tokens, num_channels, num_channels)
    #   temporal: (num_layers, batch_size*num_channels, num_tokens, num_tokens)
    head_averaged = torch.stack(layer_maps, dim=0).mean(dim=2)

    num_tokens = context_length // model.model.config.patch_length
    if is_temporal:
        n, m = num_tokens, num_channels
    else:
        n, m = num_channels, num_tokens

    rollouts = (
        attention_rollout(head_averaged, **kwargs).detach().cpu().numpy().reshape(-1, batch_size, m, n, n)
    )
    del head_averaged
    torch.cuda.empty_cache()
    return rollouts


def multi_head_attn_rollout(
    model, data: np.ndarray, context_length: int = 512, attention_type: str = "temporal"
) -> np.ndarray:
    """
    Per-head attention rollout: each head's attention at a layer is multiplied with the
    head-averaged rollout of the layers before it.

    Attention maps are extracted from ``model.model``. Even-indexed maps are used for
    ``"temporal"`` attention and odd-indexed maps for ``"channel"`` attention.

    Returns:
        Array of shape (num_layers, batch, m, num_heads, n, n), where (n, m) is
        (tokens, channels) for temporal attention and (channels, tokens) for channel
        attention, with tokens = ``context_length // patch_length``.

    Raises:
        ValueError: If ``attention_type`` is neither ``"channel"`` nor ``"temporal"``.
    """
    batch_size, _, num_channels = data.shape
    layer_maps = extract_attn_maps(model.model, data, context_length, linear_attn=False)

    is_temporal = attention_type == "temporal"
    if not is_temporal and attention_type != "channel":
        raise ValueError("Attention type must be either 'channel' or 'temporal'")
    layer_maps = layer_maps[(0 if is_temporal else 1) :: 2]

    # (num_layers, flattened batch, num_heads, n, n)
    per_head = torch.stack(layer_maps, dim=0)
    num_layers, _, num_heads, _, _ = per_head.shape

    # Entries 0 .. num_layers-1 of the head-averaged rollout: entry i is the rollout
    # through the first i layers, i.e. what feeds layer i.
    upstream = attention_rollout(per_head.mean(dim=2), start_layer=0, stop_layer=num_layers)
    # Broadcast the upstream rollout over the head axis.
    per_head_rollouts = per_head @ upstream.unsqueeze(2)

    num_tokens = context_length // model.model.config.patch_length
    if is_temporal:
        n, m = num_tokens, num_channels
    else:
        n, m = num_channels, num_tokens

    reshaped = per_head_rollouts.reshape(num_layers, batch_size, m, num_heads, n, n)
    return reshaped.detach().cpu().numpy()

In [None]:
@dataclass
class CoupledOscillator:
    num_oscillators: int
    w0: float = 1.0

    initial_conditions: np.ndarray | None = None
    spring_constants: np.ndarray | None = None
    masses: np.ndarray | None = None
    stencil: np.ndarray | None = None

    tspan: tuple[float, float] = (0, 100)
    num_eval_points: int = 1024

    def __post_init__(self):
        if self.spring_constants is None:
            self.spring_constants = np.ones(self.num_oscillators)

        if self.masses is None:
            self.masses = np.ones(self.num_oscillators)

        if self.initial_conditions is None:
            self.initial_conditions = np.random.randn(2 * self.dim)
            self.initial_conditions[self.dim :] = 0

        if self.stencil is None:
            self.stencil = np.zeros((self.num_oscillators, self.num_oscillators))
            for i in range(self.num_oscillators):
                self.stencil[i, (i - 1) % self.num_oscillators] = self.spring_constants[(i - 1) % self.num_oscillators]
                self.stencil[i, (i - 1) % self.num_oscillators] /= self.masses[(i - 1) % self.num_oscillators]
                self.stencil[i, (i + 1) % self.num_oscillators] = self.spring_constants[(i + 1) % self.num_oscillators]
                self.stencil[i, (i + 1) % self.num_oscillators] /= self.masses[(i + 1) % self.num_oscillators]
                self.stencil[i, i] = -2 * self.spring_constants[i] / self.masses[i]

        self.ts = np.linspace(self.tspan[0], self.tspan[1], self.num_eval_points)

    @property
    def dim(self) -> int:
        return self.num_oscillators

    @property
    def basis(self) -> np.ndarray:
        return np.linalg.eigh(self.stencil)

    def __call__(self, t: float, uv: np.ndarray) -> np.ndarray:
        u, v = uv[: self.dim], uv[self.dim :]
        dudt = v
        dvdt = self.stencil @ u
        return np.concatenate([dudt, dvdt])

    def integrate(self) -> np.ndarray:
        sol = solve_ivp(self, self.tspan, self.initial_conditions, t_eval=self.ts)
        return sol.y[: self.dim].T

    def __getitem__(self, idx: int | slice) -> np.ndarray:
        if isinstance(idx, int):
            return self.integrate()
        elif isinstance(idx, slice):
            inds = np.arange(idx.start, idx.stop, idx.step)
            solutions = np.zeros((len(inds), self.num_eval_points, self.dim))
            for i, ind in tqdm(enumerate(inds), total=len(inds)):
                solutions[i] = self.integrate()
            return solutions
        else:
            raise ValueError(f"Invalid index: {idx}")

In [None]:
# Eight identical oscillators: every spring-to-mass ratio is 0.1 / 0.01, so the default
# ring stencil built by CoupledOscillator is uniform. Trajectories are sampled at 4096
# points over the default time span.
num_oscillators = 8
masses = np.ones(num_oscillators) * 0.01
springs = np.ones(num_oscillators) * 0.1

series_fn = CoupledOscillator(
    num_oscillators=num_oscillators,
    masses=masses,
    spring_constants=springs,
    num_eval_points=4096,
)

In [None]:
# Forecast 512 steps from a 512-step context of a single oscillator trajectory.
# show_forecast drops the first 1024 samples by default (transient_length); the last three
# keyword arguments are forwarded unchanged to the pipeline's predict method.
show_forecast(
    pft_model,
    series_fn[0:1],
    context_length=512,
    prediction_length=512,
    limit_prediction_length=False,
    sliding_context=True,
    verbose=False,
)

In [None]:
# Head-averaged channel-attention rollout for every layer: start_layer=0 with
# stop_layer=None keeps the identity entry plus one entry per layer.
# The trailing [:, 0, ...] selects the only batch element.
singlehead_rollouts = single_head_attn_rollout(
    pft_model.model,
    series_fn[0:1],
    context_length=1024,
    attention_type="channel",
    start_layer=0,
    stop_layer=None,
)[:, 0, ...]
singlehead_rollouts.shape

In [None]:
# One panel per rollout entry (left to right), showing the channel-by-channel rollout
# matrix at the last token position (token_idx = -1).
token_idx = -1
fig, axes = plt.subplots(1, singlehead_rollouts.shape[0], figsize=(singlehead_rollouts.shape[0] * 2, 2))
plt.subplots_adjust(wspace=0.0)
for i in range(singlehead_rollouts.shape[0]):
    axes[i].imshow(singlehead_rollouts[i, token_idx], cmap="magma")
    axes[i].set_axis_off()
fig.supxlabel("Layer", fontsize=20, y=-0.05, x=0.50);

In [None]:
# Raw attention maps for the same trajectory. The rollout helpers in this notebook treat
# even-indexed entries as temporal attention and odd-indexed entries as channel attention,
# so the first two shapes are displayed side by side.
attn_maps = extract_attn_maps(
    pft_model.model,
    series_fn[0:1],
    context_length=1024,
    linear_attn=False,
)
len(attn_maps), attn_maps[0].shape, attn_maps[1].shape

In [None]:
# Grid of raw attention maps taken from the odd-indexed entries (2 * j + 1), which the
# rollout helpers treat as channel attention. Columns (j) select the entry, rows (i) index
# the second axis of each map, labelled "Head". sample_idx = -1 picks the last element of
# the first axis.
# NOTE(review): the 8 x 8 grid is hard-coded — assumes 8 layers and 8 heads; confirm
# against the model config.
sample_idx = -1
fig, axes = plt.subplots(8, 8, figsize=(10, 10))
plt.subplots_adjust(wspace=0.0, hspace=0.05)
for i in range(8):
    for j in range(8):
        amap = attn_maps[2 * j + 1][sample_idx, i].detach().cpu().numpy()
        axes[i, j].imshow(amap, cmap="magma")
        axes[i, j].set_axis_off()
fig.supxlabel("Layer", fontsize=20, y=0.08, x=0.51)
fig.supylabel("Head", fontsize=20, x=0.1);

In [None]:
# Per-head channel-attention rollout; [:, 0, ...] selects the only batch element, leaving
# axes (layer, token, head, channel, channel).
multihead_rollouts = multi_head_attn_rollout(
    pft_model.model, series_fn[0:1], context_length=1024, attention_type="channel"
)[:, 0, ...]
multihead_rollouts.shape

In [None]:
# Grid of per-head rollouts at the last token position: columns (j) are layers and
# rows (i) are heads.
# NOTE(review): the 8 x 8 grid is hard-coded — assumes 8 layers and 8 heads; confirm
# against the model config.
token_idx = -1
fig, axes = plt.subplots(8, 8, figsize=(10, 10))
plt.subplots_adjust(wspace=0.0, hspace=0.05)
for i in range(8):
    for j in range(8):
        amap = multihead_rollouts[j, token_idx, i]
        axes[i, j].imshow(amap, cmap="magma")
        axes[i, j].set_axis_off()

fig.supxlabel("Layer", fontsize=20, y=0.08, x=0.51)
fig.supylabel("Head", fontsize=20, x=0.1);

In [None]:
def sine(
    ts: np.ndarray,
    freqs: float | np.ndarray,
    phi: float | np.ndarray = 0.0,
    amp: float = 1.0,
) -> np.ndarray:
    freqs = freqs if isinstance(freqs, float) else freqs[..., None]
    phi = phi if isinstance(phi, float) else phi[..., None]
    return amp * np.sin(freqs * ts + phi).transpose(0, 2, 1)

In [None]:
# Build every pair of frequencies on a resolution x resolution grid between min_freq and
# max_freq (both endpoints included by the complex-step mgrid), scaled by 2*pi. Each pair
# gives one two-channel sample; the second channel carries a phase offset of 1.0.
min_freq = 1 / 4
max_freq = 2
resolution = 8
freqs = (
    2
    * np.pi
    * np.stack(
        np.mgrid[min_freq : max_freq : resolution * 1j, min_freq : max_freq : resolution * 1j],
        axis=-1,
    ).reshape(resolution * resolution, 2)
)

# NOTE: `ts` is also read as a global by the sweep_* functions defined further down.
ts = np.linspace(0, 100, 4096)
sines = sine(ts, freqs, phi=np.array([0.0, 1.0]))
freqs.shape, sines.shape

In [None]:
# Forecast the second-to-last frequency pair of the grid (sines[-2:-1] keeps the batch
# axis). show_forecast drops the first 1024 samples by default before taking the context.
show_forecast(
    pft_model,
    sines[-2:-1],
    context_length=512,
    prediction_length=512,
    limit_prediction_length=False,
    sliding_context=True,
    verbose=False,
)

In [None]:
# Temporal rollout for all frequency pairs at once. start_layer=-2 keeps the last two
# rollout entries; [0] therefore selects the second-to-last one.
rollouts = single_head_attn_rollout(
    pft_model.model,
    sines,
    context_length=512,
    attention_type="temporal",
    start_layer=-2,
)[0].mean(axis=1)  # average over channels/patches

In [None]:
# Sanity check: one token-by-token rollout matrix per frequency pair.
rollouts.shape

In [None]:
# Show the rollout of each frequency pair on the 8 x 8 grid. Panel i corresponds to grid
# row i // resolution and column i % resolution; panels above the diagonal are left blank.
# NOTE: this relies on `resolution` still being 8 (set in the frequency-grid cell); later
# cells reassign it, so re-running this cell out of order changes which panels are hidden.
fig, axes = plt.subplots(8, 8, figsize=(10, 10))
plt.subplots_adjust(wspace=0.0, hspace=0.05)
for i, ax in enumerate(axes.flatten()):
    if i % resolution > i // resolution:
        ax.set_axis_off()
        continue
    ax.imshow(rollouts[i], cmap="magma")
    ax.set_axis_off()
plt.show()

In [None]:
def sweep_sines_freq_and_phase(
    model,
    freqs: np.ndarray,
    phases: np.ndarray,
    batch_size: int = 1024,
    context_length: int = 1024,
    attention_type: str = "temporal",
):
    """
    Sweep through a grid of frequencies and phases, and compute the rollout of the model for each pair.

    Each grid point becomes a two-channel sine sample: channel 0 uses the swept frequency
    and phase, channel 1 uses a fixed frequency and phase of 1. The response for a grid
    point is the spectral norm of its final-layer rollout averaged over channels.

    NOTE: the time grid is read from the notebook-level global ``ts``.

    Args:
        model: Passed to ``single_head_attn_rollout``.
        freqs: 1-D array of frequencies to sweep.
        phases: 1-D array of phases to sweep.
        batch_size: Number of grid points evaluated per forward pass.
        context_length: Context length used for the rollout.
        attention_type: ``"temporal"`` or ``"channel"``.

    Returns:
        Array of shape (len(freqs), len(phases)).
    """
    freq_resolution = freqs.shape[0]
    phase_resolution = phases.shape[0]
    num_points = freq_resolution * phase_resolution
    response = np.zeros(num_points)
    sweepinds = np.indices([freq_resolution, phase_resolution]).reshape(2, -1)

    for i in tqdm(
        range(0, num_points, batch_size),
        desc="Processing batch sweeps",
    ):
        freq_inds = sweepinds[0, i : i + batch_size]
        phase_inds = sweepinds[1, i : i + batch_size]
        # Size the filler channel per batch: the final batch is shorter whenever the grid
        # size is not a multiple of batch_size, and np.stack needs matching lengths.
        dummy = np.ones(len(freq_inds))

        freq_batch = np.stack([freqs[freq_inds], dummy], axis=-1)
        phases_batch = np.stack([phases[phase_inds], dummy], axis=-1)
        sines_batch = sine(ts, freq_batch, phi=phases_batch)

        # shape after the channel average: (batch, n, n)
        rollouts = single_head_attn_rollout(
            model,
            sines_batch,
            context_length=context_length,
            attention_type=attention_type,
            start_layer=-1,
        )[0].mean(axis=1)
        response[i : i + batch_size] = np.linalg.norm(rollouts, axis=(1, 2), ord=2)

        del rollouts
        torch.cuda.empty_cache()

    return response.reshape(freq_resolution, phase_resolution)


def sweep_sines_freqs(
    model,
    freqs1: np.ndarray,
    freqs2: np.ndarray,
    batch_size: int = 1024,
    context_length: int = 1024,
    attention_type: str = "temporal",
    num_waves: int = 2,
    amp: float = 1.0,
    phase: float | None = None,
):
    """
    Sweep through a grid of frequency pairs and compute a rollout statistic for each pair.

    Each grid point becomes a ``num_waves``-channel sine sample. Channels 0 and 1 use the
    swept frequencies; any further channels use random frequencies drawn once up front.
    Channel 0 has phase 0, channel 1 has ``phase`` if given, and all remaining phases are
    random. The response for a grid point is the mean row entropy of its final-layer
    rollout, averaged over axis 1 of the per-sample result.

    NOTE: the time grid is read from the notebook-level global ``ts``, and the random
    frequencies/phases come from the unseeded global NumPy RNG.

    Args:
        model: Passed to ``single_head_attn_rollout``.
        freqs1: 1-D array of frequencies for channel 0.
        freqs2: 1-D array of frequencies for channel 1.
        batch_size: Number of grid points evaluated per forward pass.
        context_length: Context length used for the rollout.
        attention_type: ``"temporal"`` or ``"channel"``.
        num_waves: Number of sine channels per sample (at least 2).
        amp: Amplitude of every sine.
        phase: Phase of channel 1; random when ``None``.

    Returns:
        Array of shape (len(freqs1), len(freqs2)).
    """
    freq_resolution1 = freqs1.shape[0]
    freq_resolution2 = freqs2.shape[0]
    num_points = freq_resolution1 * freq_resolution2
    response = np.zeros(num_points)
    sweepinds = np.indices([freq_resolution1, freq_resolution2]).reshape(2, -1)
    dummy_freqs = np.random.rand(min(batch_size, num_points), num_waves - 2)
    dummy_phases = np.random.rand(num_waves)
    dummy_phases[0] = 0
    if phase is not None:
        dummy_phases[1] = phase

    for i in tqdm(
        range(0, num_points, batch_size),
        desc="Processing batch sweeps",
    ):
        freq_batch = np.stack(
            [
                freqs1[sweepinds[0, i : i + batch_size]],
                freqs2[sweepinds[1, i : i + batch_size]],
            ],
            axis=-1,
        )
        # Trim the filler frequencies to this batch: the final batch is shorter whenever
        # the grid size is not a multiple of batch_size, and np.c_ needs matching rows.
        freq_batch = np.c_[freq_batch, dummy_freqs[: len(freq_batch)]]
        sines_batch = sine(
            ts,
            freq_batch,
            phi=dummy_phases.reshape(1, -1),
            amp=amp,
        )

        # shape: (batch, m, n, n) — see single_head_attn_rollout for what m and n are
        rollouts = single_head_attn_rollout(
            model,
            sines_batch,
            context_length=context_length,
            attention_type=attention_type,
            start_layer=-1,
        )[0]
        # Alternative statistic used previously: interaction_index(rollouts).mean(axis=1)
        response[i : i + batch_size] = mean_row_entropy(rollouts).mean(axis=1)

        del rollouts
        torch.cuda.empty_cache()

    return response.reshape(freq_resolution1, freq_resolution2)

In [None]:
# Dense 256 x 256 sweep over pairs of frequencies in [0.5, 2.5] (times 2*pi), with the
# second channel offset by a phase of pi. 256 * 256 grid points are processed in batches
# of 4096.
# NOTE: this reassigns the notebook-level `resolution` used by earlier plotting cells.
resolution = 256
bounds = (0.5, 2.5)
freqs1 = 2 * np.pi * np.linspace(bounds[0], bounds[1], resolution)
freqs2 = 2 * np.pi * np.linspace(bounds[0], bounds[1], resolution)

response = sweep_sines_freqs(
    pft_model.model,
    freqs1,
    freqs2,
    batch_size=4096,
    context_length=512,
    attention_type="temporal",
    num_waves=2,
    amp=1,
    phase=np.pi,
)

In [None]:
# Heatmap of the sweep response. Rows of `response` follow freqs1 (y axis) and columns
# follow freqs2 (x axis); the extent is given in grid indices.
plt.figure(figsize=(10, 10))
plt.imshow(
    response,
    cmap="inferno",
    origin="lower",
    extent=(0, freqs1.shape[0], 0, freqs2.shape[0]),
)

# Five evenly spaced tick positions per axis; the last one is pulled back by one so it
# stays inside the image.
tlocs1 = np.arange(0, freqs1.shape[0] + freqs1.shape[0] // 4, freqs1.shape[0] // 4)
tlocs1[-1] -= 1

tlocs2 = np.arange(0, freqs2.shape[0] + freqs2.shape[0] // 4, freqs2.shape[0] // 4)
tlocs2[-1] -= 1

# Frequencies are 2*pi*bounds, so the tick labels are multiples of pi.
lower = bounds[0] * 2
upper = bounds[1] * 2
locs = np.linspace(lower, upper, 5)
locs = [loc if not loc.is_integer() else int(loc) for loc in locs]
# A multiple of 1 is printed as a bare pi, and a multiple of 0 as "0". The zero check
# compares against the number 0: the tick values are numeric, so a comparison with the
# string "0" could never match.
locs_str = [r"$\mathbf{" + (str(s) if s != 1 else "") + r"\pi}$" if s != 0 else "0" for s in locs]

plt.yticks(tlocs1, locs_str, fontweight="bold", fontsize=16)
plt.xticks(tlocs2, locs_str, fontweight="bold", fontsize=16)
cbar = plt.colorbar(shrink=0.8175, pad=0.0)
cbar.ax.tick_params(labelsize=16)

plt.tight_layout()


plt.savefig("../figures/nonlin_resonance.pdf", dpi=300)

In [None]:
def sine_bispectra_scaling(
    checkpoint_dir: str,
    freqs1: np.ndarray,
    freqs2: np.ndarray,
    context_length: int = 512,
    batch_size: int = 2048,
    attention_type: str = "temporal",
    num_trials: int = 5,
    default_seed: int = 0,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Computes an off-diagonal response ratio of the sine sweep for every checkpoint in a directory.

    For each numbered checkpoint the frequency-pair sweep is repeated ``num_trials``
    times with a random phase on the second channel. Each trial yields the ratio of the
    norm of the off-diagonal response entries to the norm of the diagonal entries.

    Entries of ``checkpoint_dir`` are used when the text after their last "-" is a
    decimal number (e.g. ``checkpoint-1000``); everything else, including
    ``checkpoint-final``, is ignored.

    NOTE(review): the loaded pipeline itself is passed to ``sweep_sines_freqs`` here, while
    the interactive cells of this notebook pass ``pipeline.model`` — confirm which level
    the rollout helpers expect.

    Args:
        checkpoint_dir: Directory containing the checkpoint folders.
        freqs1: Frequencies for the first channel (rows of the response).
        freqs2: Frequencies for the second channel (columns of the response).
        context_length: Context length used for the rollout.
        batch_size: Number of grid points per forward pass.
        attention_type: ``"temporal"`` or ``"channel"``.
        num_trials: Number of random phases per checkpoint.
        default_seed: Seed of the generator from which one child generator per
            checkpoint is spawned.

    Returns:
        (iterations, mean ratio per checkpoint, standard deviation per checkpoint).
    """
    # Keep only numbered checkpoints, so that a missing "checkpoint-final" folder or a
    # stray file in the directory no longer aborts the whole run.
    checkpoints = sorted(
        (name for name in os.listdir(checkpoint_dir) if name.split("-")[-1].isdecimal()),
        key=lambda name: int(name.split("-")[-1]),
    )
    iterations = np.array([int(checkpoint.split("-")[-1]) for checkpoint in checkpoints])
    rngs = np.random.default_rng(default_seed).spawn(len(checkpoints))

    mean_interactions = np.zeros(len(checkpoints))
    std_interactions = np.zeros(len(checkpoints))
    for i, checkpoint in enumerate(tqdm(checkpoints, desc="Processing checkpoints")):
        model = PatchTSTPipeline.from_pretrained(
            mode="predict",
            pretrain_path=f"{checkpoint_dir}/{checkpoint}",
            device_map="cuda:0",
        )
        trial_interactions = np.zeros(num_trials)
        for j, randphase in enumerate(rngs[i].uniform(0, 2 * np.pi, num_trials)):
            # shape: (len(freqs1), len(freqs2))
            response = sweep_sines_freqs(
                model,
                freqs1,
                freqs2,
                batch_size=batch_size,
                context_length=context_length,
                attention_type=attention_type,
                phase=randphase,
                num_waves=2,
            )
            # Compute normalized off-diagonal activity
            diag_mask = np.eye(response.shape[0], dtype=bool)
            off_diag_norm = np.linalg.norm(response[~diag_mask])
            diag_norm = np.linalg.norm(response[diag_mask])
            trial_interactions[j] = off_diag_norm / diag_norm
        mean_interactions[i] = trial_interactions.mean()
        std_interactions[i] = trial_interactions.std()

        # cleanup manually
        del model
        torch.cuda.empty_cache()

    return iterations, mean_interactions, std_interactions

In [None]:
# Coarser 128 x 128 sweep, repeated for every checkpoint of one training run.
# NOTE: relies on `bounds` from the earlier sweep cell and on the WORK environment
# variable being set; it also reassigns `resolution`, `freqs1` and `freqs2`.
resolution = 128
freqs1 = 2 * np.pi * np.linspace(bounds[0], bounds[1], resolution)
freqs2 = 2 * np.pi * np.linspace(bounds[0], bounds[1], resolution)

work_dir = os.environ["WORK"]
iterations, mean_interactions, std_interactions = sine_bispectra_scaling(
    f"{work_dir}/checkpoints/pft_rff496_proj-0",
    freqs1,
    freqs2,
    context_length=512,
    batch_size=4096,
    attention_type="temporal",
    num_trials=2,
)

In [None]:
# Persist the sine-sweep results so the plot below can be redrawn without recomputing.
output_dir = "../outputs/modemix"
os.makedirs(output_dir, exist_ok=True)
np.save(f"{output_dir}/sine_mean_interactions.npy", mean_interactions)
np.save(f"{output_dir}/sine_std_interactions.npy", std_interactions)
np.save(f"{output_dir}/sine_iterations.npy", iterations)

In [None]:
# Reload the saved sine-sweep results (same order as the names on the left-hand side).
# Requires `output_dir` from the save cell above.
mean_interactions, std_interactions, iterations = (
    np.load(f"{output_dir}/sine_mean_interactions.npy"),
    np.load(f"{output_dir}/sine_std_interactions.npy"),
    np.load(f"{output_dir}/sine_iterations.npy"),
)

In [None]:
# Mean off-diagonal response ratio per checkpoint with a +/- one standard deviation band.
# checkpoint_cutoff drops that many of the earliest checkpoints from the plot.
checkpoint_cutoff = 0
plt.figure(figsize=(10, 4))
plt.plot(
    iterations[checkpoint_cutoff:],
    mean_interactions[checkpoint_cutoff:],
    marker="o",
    linestyle="-.",
)
plt.fill_between(
    iterations[checkpoint_cutoff:],
    mean_interactions[checkpoint_cutoff:] - std_interactions[checkpoint_cutoff:],
    mean_interactions[checkpoint_cutoff:] + std_interactions[checkpoint_cutoff:],
    alpha=0.2,
)
# Scientific notation on both axes; one x tick per checkpoint.
plt.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
plt.ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
plt.xticks(iterations[checkpoint_cutoff:])
plt.xlabel("Checkpoint Iteration", fontweight="bold")
plt.ylabel("Response Interaction Index", fontweight="bold")
plt.savefig("../figures/sine_bispectra_scaling.pdf", dpi=300)

In [None]:
from glob import glob

from panda.utils import get_system_filepaths, load_trajectory_from_arrow


def data_bispectra_scaling(
    checkpoint_dir: str,
    data_dir: str,
    context_length: int = 512,
    num_windows: int = 10,
    attention_type: str = "temporal",
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Computes the mean row entropy of the last attention layer on real trajectories, per checkpoint.

    For every numbered checkpoint and every entry of ``data_dir``, one trajectory file is
    picked at random and ``num_windows`` random context windows are cut from it. The
    statistic is the mean row entropy of the last attention map of the requested type.

    Entries of ``checkpoint_dir`` are used when the text after their last "-" is a
    decimal number (e.g. ``checkpoint-1000``); everything else, including
    ``checkpoint-final``, is ignored.

    NOTE: file and window selection use the unseeded global NumPy RNG.

    Args:
        checkpoint_dir: Directory containing the checkpoint folders.
        data_dir: Directory whose entries are passed one by one to ``get_system_filepaths``.
        context_length: Length of each context window.
        num_windows: Number of windows sampled per trajectory.
        attention_type: ``"temporal"`` reads the second-to-last attention map, anything
            else reads the last one.

    Returns:
        (iterations, mean entropy per checkpoint, standard error over systems per checkpoint).
    """
    # Keep only numbered checkpoints, so that a missing "checkpoint-final" folder or a
    # stray file in the directory no longer aborts the whole run.
    checkpoints = sorted(
        (name for name in os.listdir(checkpoint_dir) if name.split("-")[-1].isdecimal()),
        key=lambda name: int(name.split("-")[-1]),
    )
    iterations = np.array([int(checkpoint.split("-")[-1]) for checkpoint in checkpoints])
    system_dirs = glob(f"{data_dir}/*")

    mean_interactions = np.zeros(len(checkpoints))
    std_interactions = np.zeros(len(checkpoints))
    for i, checkpoint in enumerate(tqdm(checkpoints, desc="Processing checkpoints")):
        model = PatchTSTPipeline.from_pretrained(
            mode="predict",
            pretrain_path=f"{checkpoint_dir}/{checkpoint}",
            device_map="cuda:0",
        )
        trial_interactions = np.zeros(len(system_dirs))
        for j, system_dir in enumerate(system_dirs):
            # pick a single random system to evaluate
            syspaths = get_system_filepaths(system_dir, data_dir, "test_zeroshot")
            syspath = np.random.choice(syspaths)
            traj, _ = load_trajectory_from_arrow(syspath)

            # split into random context windows; the upper bound of randint is exclusive,
            # so +1 keeps the last valid start (and a trajectory exactly one window long).
            window_idxs = np.random.randint(0, traj.shape[1] - context_length + 1, num_windows)
            # traj is indexed (channels, time); each window is transposed to (time, channels)
            windows = np.stack([traj[:, w : w + context_length].T for w in window_idxs])

            attn_weights = extract_attn_maps(
                model.model,
                windows,
                context_length,
                linear_attn=False,
            )
            attnmap = attn_weights[-2 if attention_type == "temporal" else -1].cpu().detach().numpy()
            # Alternative statistic used previously: the entropy of
            # single_head_attn_rollout(model, windows, context_length, attention_type)[0].
            interactions = mean_row_entropy(attnmap)
            trial_interactions[j] = interactions.mean()
        mean_interactions[i] = trial_interactions.mean()
        # standard error of the mean over systems
        std_interactions[i] = trial_interactions.std() / np.sqrt(trial_interactions.shape[0])
        # cleanup manually
        del model
        torch.cuda.empty_cache()

    return iterations, mean_interactions, std_interactions

In [None]:
# Evaluate the same training run on held-out trajectories, 5 windows per system.
# Requires the WORK environment variable; overwrites the sine-sweep result variables.
work_dir = os.environ["WORK"]
iterations, mean_interactions, std_interactions = data_bispectra_scaling(
    f"{work_dir}/checkpoints/pft_rff496_proj-0",
    f"{work_dir}/data/improved/final_skew40/test_zeroshot",
    context_length=512,
    attention_type="temporal",
    num_windows=5,
)

In [None]:
# Persist the trajectory-based results so the plot below can be redrawn without recomputing.
output_dir = "../outputs/modemix"
os.makedirs(output_dir, exist_ok=True)
np.save(f"{output_dir}/data_mean_interactions.npy", mean_interactions)
np.save(f"{output_dir}/data_std_interactions.npy", std_interactions)
np.save(f"{output_dir}/data_iterations.npy", iterations)

In [None]:
# Reload the saved trajectory-based results (same order as the names on the left-hand side).
# Requires `output_dir` from the save cell above.
mean_interactions, std_interactions, iterations = (
    np.load(f"{output_dir}/data_mean_interactions.npy"),
    np.load(f"{output_dir}/data_std_interactions.npy"),
    np.load(f"{output_dir}/data_iterations.npy"),
)

In [None]:
# Mean row entropy per checkpoint with a band of +/- the stored spread, which
# data_bispectra_scaling computes as a standard error over systems.
# checkpoint_cutoff drops that many of the earliest checkpoints from the plot.
checkpoint_cutoff = 0
plt.figure(figsize=(9, 4))
plt.plot(
    iterations[checkpoint_cutoff:],
    mean_interactions[checkpoint_cutoff:],
    marker="o",
    linestyle="-",
    linewidth=2,
    markersize=10,
)
plt.fill_between(
    iterations[checkpoint_cutoff:],
    mean_interactions[checkpoint_cutoff:] - std_interactions[checkpoint_cutoff:],
    mean_interactions[checkpoint_cutoff:] + std_interactions[checkpoint_cutoff:],
    alpha=0.2,
)
# Scientific notation on the x axis only; the offset text is resized to match the ticks.
plt.ticklabel_format(axis="x", style="sci", scilimits=(0, 0))
plt.gca().xaxis.get_offset_text().set_fontsize(16)
plt.xticks(iterations[checkpoint_cutoff:], fontsize=16)
plt.yticks(fontsize=16)
plt.xlabel("Checkpoint Iteration", fontweight="bold", fontsize=16)
plt.ylabel("Final layer row-entropy", fontweight="bold", fontsize=16)
plt.savefig("../figures/data_bispectra_scaling.pdf", dpi=300)