In [1]:
import os
import torch
import plotly.graph_objects as go
from collections import defaultdict

# Quiet W&B and tqdm output.
# NOTE(review): the progress bars further down still render, so TQDM_DISABLE does not
# appear to take effect here — tqdm may only read it when it is first imported.
os.environ.update({"WANDB_SILENT": "true", "TQDM_DISABLE": "1"})

# Number of test examples the CKA matrix is averaged over.
N_SAMPLES = 5
# Leading-dimension slice applied to the tokenized inputs before each forward pass.
FORWARD_SIZE = 256

# Run identifiers (presumably W&B run IDs — confirm), grouped by model family and
# then by variant: base / pretrained / frozen.
models = {
    family: dict(zip(("base", "pretrained", "frozen"), run_ids))
    for family, run_ids in (
        ("base", ("e96b8h5a", "e1tosi4k", "t4o7wvla")),
        ("ob", ("swcod025", "xzlusbyu", "aqzn99cx")),
        ("pob", ("wozkyaa6", "daf0h543", "47y92682")),
    )
}

In [None]:
def get_hook(name, base):
    """Return a forward hook that records the tensor output of module ``name``.

    Args:
        name: Module name (as produced by ``named_modules()``), used as the dict key.
        base: True to record into ``base_activations``, False to record into
            ``pretrained_activations``.

    The activation dicts are looked up as globals each time the hook fires, so they
    can be re-created after the hooks are registered. Outputs that are not a single
    ``torch.Tensor`` (e.g. tuples) are ignored.
    """

    def hook(module, input, output):
        if not isinstance(output, torch.Tensor):
            return
        store = base_activations if base else pretrained_activations
        # Detach so the stored activation does not keep the forward pass's autograd
        # graph (and every tensor it saved) alive on the device.
        store[name].append(output.detach())

    return hook


# Remove hooks left behind by a previous run of this cell; otherwise each re-run stacks
# another set of hooks and every activation is recorded several times per forward pass.
for _handle in globals().get("hook_handles", []):
    _handle.remove()

hook_handles = []
# NOTE(review): zip() stops at the shorter module list; both models are presumably the
# same architecture — confirm.
for (name_b, module_b), (name_p, module_p) in zip(
    base_model.named_modules(), pretrained_model.named_modules()
):
    hook_handles.append(module_b.register_forward_hook(get_hook(name_b, base=True)))
    hook_handles.append(module_p.register_forward_hook(get_hook(name_p, base=False)))


In [3]:
def linear_CKA(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    """Linear centered kernel alignment (CKA) between two representations.

    Args:
        X: 2-D matrix with one row per example.
        Y: 2-D matrix with the same number of rows as ``X`` (column count may differ).

    Returns:
        0-d tensor with the CKA similarity in [0, 1]. It is 0 when either
        representation is constant across examples. Computed in at least float32.

    Raises:
        ValueError: if either input is not 2-D, or the row counts differ.
    """
    if X.dim() != 2 or Y.dim() != 2:
        raise ValueError(
            f"linear_CKA expects 2-D inputs, got shapes {tuple(X.shape)} and {tuple(Y.shape)}"
        )
    if X.size(0) != Y.size(0):
        raise ValueError(
            f"X and Y must have the same number of rows, got {X.size(0)} and {Y.size(0)}"
        )

    # Compute in at least float32. Half-precision Gram sums overflow easily, and
    # mixing dtypes in a matmul raises.
    dtype = torch.promote_types(torch.promote_types(X.dtype, Y.dtype), torch.float32)
    X = X.to(dtype)
    Y = Y.to(dtype)

    # Centering the feature columns makes X @ X.T doubly centered (its row and column
    # sums are zero). An explicit H @ K @ H centering step would therefore be a no-op
    # that costs two extra n x n matmuls.
    X = X - X.mean(dim=0, keepdim=True)
    Y = Y - Y.mean(dim=0, keepdim=True)

    K = X @ X.T
    L = Y @ Y.T

    hsic = (K * L).sum()
    norm_x = (K * K).sum().sqrt()
    norm_y = (L * L).sum().sqrt()

    # Epsilon avoids 0/0 for constant representations.
    return hsic / (norm_x * norm_y + 1e-12)

In [None]:
# Linear-CKA similarity between every (base module, pretrained module) pair, averaged
# over N_SAMPLES test examples. Rows index base modules, columns pretrained modules.
# NOTE(review): `tqdm`, `dataset`, `tokenize`, `tokenizer` and both models are not defined
# in any visible cell — confirm they come from an earlier setup cell. Also confirm both
# models are in eval() mode, otherwise dropout makes the activations stochastic.
sims = []

base_activations = defaultdict(list)
pretrained_activations = defaultdict(list)

for a, sample in enumerate(dataset["test"].select(range(N_SAMPLES))):
    # The hooks append one entry per forward pass, and only the current sample is used.
    # Drop the previous sample's activations so they don't pile up on the device.
    base_activations.clear()
    pretrained_activations.clear()
    torch.mps.empty_cache()

    tokens = tokenize(base_model, sample, tokenizer)
    input_ids = tokens["input_ids"][:FORWARD_SIZE].to("mps")
    attention_mask = tokens["attention_mask"][:FORWARD_SIZE].to("mps")

    # Only the hooked activations are needed, so don't build an autograd graph.
    with torch.no_grad():
        pretrained_model(input_ids=input_ids, attention_mask=attention_mask)
        base_model(input_ids=input_ids, attention_mask=attention_mask)

    # Flatten each module's latest output to (first dim, rest) once. This avoids
    # re-reshaping every pretrained activation inside the inner loop.
    base_flat = [
        v[-1].detach().reshape(v[-1].shape[0], -1) for v in base_activations.values()
    ]
    pretrained_flat = [
        v[-1].detach().reshape(v[-1].shape[0], -1) for v in pretrained_activations.values()
    ]

    # Size by each model's own module count; the two counts need not be equal.
    sim_matrix = torch.zeros((len(base_flat), len(pretrained_flat)))
    # The context manager closes the bar, so its final (100%) state is actually rendered.
    with tqdm(
        total=len(base_flat) * len(pretrained_flat),
        desc=f"{a + 1}/{N_SAMPLES} Computing RSA w/ CKA",
        leave=True,
    ) as loop:
        for i, base_act in enumerate(base_flat):
            for j, pretrained_act in enumerate(pretrained_flat):
                sim_matrix[i, j] = linear_CKA(base_act, pretrained_act)
                loop.update(1)

    sims.append(sim_matrix)
    # Release the views on this sample's activations before the next empty_cache().
    del base_flat, pretrained_flat

sim_matrix = torch.stack(sims).mean(dim=0)

1/5 Computing RSA w/ CKA: 100%|██████████| 11664/11664 [00:20<00:00, 575.36it/s]
2/5 Computing RSA w/ CKA: 100%|██████████| 11664/11664 [00:24<00:00, 485.34it/s]
3/5 Computing RSA w/ CKA: 100%|██████████| 11664/11664 [00:20<00:00, 561.29it/s]
4/5 Computing RSA w/ CKA: 100%|██████████| 11664/11664 [00:21<00:00, 544.89it/s]
5/5 Computing RSA w/ CKA: 100%|█████████▉| 11657/11664 [00:20<00:00, 705.13it/s]

In [None]:
def custom_hover_text(x, y, value, base_names=None, pretrained_names=None):
    """Hover label for cell (x, y) = (row, column) of ``sim_matrix``.

    Args:
        x: Row index, i.e. a base-model module.
        y: Column index, i.e. a pretrained-model module.
        value: CKA value of that cell (float or 0-d tensor).
        base_names: Optional precomputed list of base module names. Defaults to
            the keys of ``base_activations``.
        pretrained_names: Optional precomputed list of pretrained module names.
            Defaults to the keys of ``pretrained_activations``.
    """
    if base_names is None:
        base_names = list(base_activations.keys())
    if pretrained_names is None:
        pretrained_names = list(pretrained_activations.keys())
    return f"CKA: {value:.3}<br>Base: {base_names[x]}<br>Pretrained: {pretrained_names[y]}<br>({x}, {y})"


# Plain nested lists. This avoids relying on `np` (not imported in any visible cell)
# and on Plotly converting torch tensors.
sim_values = sim_matrix.tolist()
# Build the name lists once instead of once per hover label (~n^2 calls).
base_names = list(base_activations.keys())
pretrained_names = list(pretrained_activations.keys())

customdata = [
    [
        custom_hover_text(i, j, value, base_names, pretrained_names)
        for j, value in enumerate(row)
    ]
    for i, row in enumerate(sim_values)
]

# Plotly draws z[i][j] at (x=j, y=i). Base modules therefore run along the y-axis
# and pretrained modules along the x-axis.
fig = go.Figure(
    data=go.Heatmap(
        z=sim_values,
        colorscale="Viridis",
        customdata=customdata,
        hovertemplate="%{customdata}<extra></extra>",
        showscale=True,
    ),
)

fig.update_layout(
    xaxis=dict(showticklabels=False, title=dict(text="Pretrained module")),
    yaxis=dict(
        showticklabels=False,
        scaleanchor="x",
        scaleratio=1,
        title=dict(text="Base module"),
    ),
    # rgba alpha is in [0, 1].
    plot_bgcolor="rgba(255,255,255,1)",
    paper_bgcolor="rgba(255,255,255,1)",
    margin=dict(t=30, b=10, l=10, r=10),
    title=dict(
        # NOTE(review): `best_model` isn't defined in any visible cell — confirm its source.
        text=best_model.ID,
        xref="paper",
        yref="container",
        yanchor="top",
        x=0.5,
        automargin=True,
        pad=dict(t=5),
    ),
)

fig.show()

5/5 Computing RSA w/ CKA: 100%|██████████| 11664/11664 [00:33<00:00, 705.13it/s]