In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

from dystformer.patchtst.pipeline import PatchTSTPipeline

In [None]:
# Load the pipeline from the "run-400" final checkpoint with mode="predict",
# placing it on the first CUDA device.
# NOTE(review): the checkpoint path is an absolute, machine-specific location and
# the device is pinned to "cuda:0"; this cell will not run elsewhere unchanged.
# Consider lifting both into a configuration cell.
pft_model = PatchTSTPipeline.from_pretrained(
    mode="predict",
    pretrain_path="/stor/work/AMDG_Gilpin_Summer2024/checkpoints/run-400/checkpoint-final",
    device_map="cuda:0",
)

In [None]:
def get_attn_weights(model, key: str) -> list[dict[str, torch.Tensor]]:
    """Collect the Q/K/V projection weights of one attention module per encoder layer.

    Args:
        model: Object exposing ``model.model.model.encoder.layers`` (the notebook
            passes a ``PatchTSTPipeline``).
        key: Name of the attention attribute to read on every layer, e.g.
            ``"temporal_self_attn"`` or ``"channel_self_attn"``.

    Returns:
        One dict per encoder layer, in layer order, with keys ``"Wq"``, ``"Wk"``
        and ``"Wv"`` holding the ``.weight`` of ``q_proj``, ``k_proj`` and
        ``v_proj`` respectively. The weights are returned as-is (not detached
        or copied).
    """
    # Three nested `.model` attributes sit between the pipeline and the encoder.
    encoder_layers = model.model.model.encoder.layers

    collected = []
    for layer in encoder_layers:
        attn = getattr(layer, key)
        collected.append(
            {
                "Wq": attn.q_proj.weight,
                "Wk": attn.k_proj.weight,
                "Wv": attn.v_proj.weight,
            }
        )
    return collected


def get_attn_map(
    weights: list[dict[str, torch.Tensor]], index: int, shift: bool = False
) -> np.ndarray:
    """Return ``Wq @ Wk.T`` for one layer as a NumPy array.

    Args:
        weights: Per-layer dicts with at least the keys ``"Wq"`` and ``"Wk"``
            (as produced by ``get_attn_weights``).
        index: Which layer's dict to use.
        shift: If true, min-max rescale the result into ``[0, 1]``.

    Returns:
        The matrix product ``Wq @ Wk.T``, detached and moved to the CPU. With
        ``shift=True`` the values are rescaled so the minimum maps to 0 and the
        maximum to 1; a constant map (max == min) is returned as all zeros
        instead of dividing by zero.

    Raises:
        IndexError: If ``index`` is out of range for ``weights``.
    """
    layer = weights[index]
    # NOTE(review): if these are nn.Linear weights stored as (out, in), the form
    # acting on the layer input would be Wq.T @ Wk rather than Wq @ Wk.T, and
    # any per-head split / scaling is not applied here — confirm this is the
    # intended quantity.
    attn_map = (layer["Wq"] @ layer["Wk"].T).detach().cpu().numpy()
    if shift:
        lowest = attn_map.min()
        span = attn_map.max() - lowest
        if span == 0:
            # A constant map has no dynamic range; avoid 0/0 -> NaN.
            return np.zeros_like(attn_map)
        attn_map = (attn_map - lowest) / span
    return attn_map


def symmetric_distance(attn_map: np.ndarray) -> float:
    """Relative Frobenius size of the antisymmetric part of a square matrix.

    Computes ``0.5 * ||A - A.T||_F / ||A||_F``, which is 0 for a symmetric
    matrix and 1 for an antisymmetric one.

    Args:
        attn_map: A square 2-D array.

    Returns:
        The ratio described above. The all-zero matrix (which is symmetric)
        yields ``0.0`` rather than the NaN produced by 0/0.

    Raises:
        ValueError: If ``attn_map`` is not a square 2-D array. Without this
            check a ``(1, n)`` or ``(n, 1)`` input would silently broadcast
            against its transpose and return a meaningless score.
    """
    attn_map = np.asarray(attn_map)
    if attn_map.ndim != 2 or attn_map.shape[0] != attn_map.shape[1]:
        raise ValueError(
            f"symmetric_distance expects a square 2-D array, got shape {attn_map.shape}"
        )

    total = np.linalg.norm(attn_map, "fro")
    if total == 0:
        return 0.0
    return 0.5 * np.linalg.norm(attn_map - attn_map.T, "fro") / total


In [None]:
# Per-layer Q/K/V projection weights of `pft_model`, one list per attention type
# (temporal first, then channel).
temporal_weights, channel_weights = (
    get_attn_weights(pft_model, attn_name)
    for attn_name in ("temporal_self_attn", "channel_self_attn")
)

In [None]:
# Layer-0 temporal attention: print the asymmetry score of Wq @ Wk.T, then show
# log(entry^2) as a heat map with a colorbar.
attn_map = get_attn_map(temporal_weights, index=0)
print(symmetric_distance(attn_map))
fig, ax = plt.subplots()
log_power = ax.imshow(np.log(np.square(attn_map)), cmap="RdBu")
fig.colorbar(log_power, ax=ax)
plt.show()

In [None]:
# Layer-0 channel attention: print the asymmetry score of Wq @ Wk.T, then show
# log(entry^2) as a heat map with a colorbar.
attn_map = get_attn_map(channel_weights, index=0)
print(symmetric_distance(attn_map))
fig, ax = plt.subplots()
log_power = ax.imshow(np.log(np.square(attn_map)), cmap="RdBu")
fig.colorbar(log_power, ax=ax)
plt.show()

In [None]:
# Inspect element 0 of the layer-0 feed-forward block (`ff`) of `pft_model`:
# asymmetry score, singular-value spectrum, and a thresholded low-rank rebuild.
llayer = pft_model.model.model.encoder.layers[0].ff
print(llayer)
ffw = llayer[0].weight.detach().cpu().numpy()
# NOTE(review): symmetric_distance subtracts the transpose, so this line only
# makes sense when `ffw` is square — confirm that holds for this checkpoint.
print(symmetric_distance(ffw))

# np.linalg.svd returns V^H (not V) as its third output, so name it accordingly.
# The reduced factorisation is enough: only the leading `rank` columns/rows are
# used below, and S has the same length either way.
U, S, Vh = np.linalg.svd(ffw, full_matrices=False)
threshold = 1e-3  # absolute cutoff on singular values used to define the rank
rank = int(np.count_nonzero(S > threshold))
plt.figure()
plt.plot(range(1, len(S) + 1), S, "o-", linewidth=2)
plt.title("Scree Plot of Singular Values")
plt.xlabel("Singular Value Index")
plt.ylabel("Singular Value Magnitude")
plt.grid(True)
plt.yscale("log")  # Log scale to better visualize the decay
plt.axhline(
    y=threshold, color="r", linestyle="--", label=f"Threshold ({threshold:.1e})"
)
plt.legend()
plt.show()

# Rank-`rank` reconstruction U_r diag(S_r) Vh_r; scaling the columns of U_r by
# S_r avoids materialising a dense diagonal matrix.
reconstructed = (U[:, :rank] * S[:rank]) @ Vh[:rank, :]
plt.figure()
plt.imshow(np.log(reconstructed**2), cmap="RdBu")
plt.colorbar()
plt.show()

In [None]:
# Raw (un-logged) Wq @ Wk.T heat map for every temporal-attention layer.
# The grid is sized from the number of layers instead of assuming exactly 8:
# a fixed 2x4 grid raises IndexError with fewer layers and silently drops any
# beyond the eighth. With 8 layers this is the same 2x4, 20x10-inch figure.
n_layers = len(temporal_weights)
n_cols = 4
n_rows = max(1, -(-n_layers // n_cols))  # ceiling division, at least one row
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False
)
for i, ax in enumerate(axes.flatten()):
    if i >= n_layers:
        ax.axis("off")  # hide unused cells in the last row
        continue
    attn_map = get_attn_map(temporal_weights, i)
    ax.imshow(attn_map, cmap="RdBu")
    ax.set_title(f"Layer {i}")
plt.tight_layout()
plt.show()


In [None]:
# Load a second pipeline from the "mlm40_stand_nonoiser-1" final checkpoint with
# mode="pretrain", placing it on the first CUDA device.
# NOTE(review): the checkpoint path is an absolute, machine-specific location and
# the device is pinned to "cuda:0"; this cell will not run elsewhere unchanged.
# Consider lifting both into a configuration cell.
mlm_model = PatchTSTPipeline.from_pretrained(
    mode="pretrain",
    pretrain_path="/stor/work/AMDG_Gilpin_Summer2024/checkpoints/mlm40_stand_nonoiser-1/checkpoint-final",
    device_map="cuda:0",
)

In [None]:
# Per-layer Q/K/V projection weights of `mlm_model` (channel first, then temporal).
# NOTE(review): this rebinds `channel_weights` / `temporal_weights`, which earlier
# cells fill from `pft_model`; re-running those plotting cells after this one
# will show the MLM model's weights instead.
channel_weights, temporal_weights = (
    get_attn_weights(mlm_model, attn_name)
    for attn_name in ("channel_self_attn", "temporal_self_attn")
)

In [None]:
# Layer-3 temporal attention: print the asymmetry score of Wq @ Wk.T, then show
# log(entry^2) as a heat map with a colorbar.
attn_map = get_attn_map(temporal_weights, index=3)
print(symmetric_distance(attn_map))
fig, ax = plt.subplots()
log_power = ax.imshow(np.log(np.square(attn_map)), cmap="RdBu")
fig.colorbar(log_power, ax=ax)
plt.show()


In [None]:
# Layer-3 channel attention: print the asymmetry score of Wq @ Wk.T, then show
# log(entry^2) as a heat map with a colorbar.
attn_map = get_attn_map(channel_weights, index=3)
print(symmetric_distance(attn_map))
fig, ax = plt.subplots()
log_power = ax.imshow(np.log(np.square(attn_map)), cmap="RdBu")
fig.colorbar(log_power, ax=ax)
plt.show()


In [None]:
# Raw (un-logged) Wq @ Wk.T heat map for every channel-attention layer.
# The grid is sized from the number of layers instead of assuming exactly 8:
# a fixed 2x4 grid raises IndexError with fewer layers and silently drops any
# beyond the eighth. With 8 layers this is the same 2x4, 20x10-inch figure.
n_layers = len(channel_weights)
n_cols = 4
n_rows = max(1, -(-n_layers // n_cols))  # ceiling division, at least one row
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False
)
for i, ax in enumerate(axes.flatten()):
    if i >= n_layers:
        ax.axis("off")  # hide unused cells in the last row
        continue
    attn_map = get_attn_map(channel_weights, i)
    ax.imshow(attn_map, cmap="RdBu")
    ax.set_title(f"Layer {i}")
plt.tight_layout()
plt.show()