 # Fetching activations from a pretrained model

In [1]:
import os
from functools import partial
from pathlib import Path

import git
import numpy as np
import pandas as pd
import torch
import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# Set before any tokenizer is created in this notebook.
# NOTE(review): presumably this silences the Hugging Face tokenizers
# parallelism/fork warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

# LaTeX packages loaded into matplotlib's preamble when text is rendered with usetex.
latex_preamble = "\n".join(
    (
        r"\usepackage{libertine}",
        r"\usepackage[libertine]{newtxmath}",
        r"\usepackage{inconsolata}",
    )
)

rc_fonts = {"text.usetex": True, "text.latex.preamble": latex_preamble}

# Apply the font settings through both the matplotlib and the pyplot handles.
for rc_params in (mpl.rcParams, plt.rcParams):
    rc_params.update(rc_fonts)

sns.set_theme(style="white")

 ### Functions for loading model and fetching activations

In [3]:
def load_model_tokenizer(model_name: str):
    """Fetch a causal language model and its tokenizer through ``transformers``.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the model
            and the tokenizer.

    Returns:
        A ``(model, tokenizer)`` pair. ``eval()`` has been called on the model
        and the tokenizer's pad token is set to its end-of-sequence token.
    """
    # Both loaders are called with remote code execution enabled.
    hf_kwargs = {"trust_remote_code": True}

    model = AutoModelForCausalLM.from_pretrained(model_name, **hf_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_name, **hf_kwargs)

    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    model.eval()
    return model, tokenizer

In [4]:
def fetch_reprs(model, input_ids: torch.Tensor, bias_norms: list[int]):
    """Record representation norms at the input and output of the decoder stack.

    For every row of ``input_ids`` and every value in ``bias_norms`` the model is
    run once with two temporary forward hooks attached:

    * a hook on ``model.pico_decoder.embedding_proj`` adds a random vector,
      rescaled to the requested L2 norm, to that module's output and stores the
      biased output;
    * a hook on ``model.pico_decoder.layers[-1]`` stores the first element of
      that layer's output.

    Args:
        model: Model exposing ``config.d_model``, ``pico_decoder.embedding_proj``
            and ``pico_decoder.layers``.
        input_ids: 2-D tensor of token ids, one sequence per row. Each row is
            run through the model separately as a batch of one.
        bias_norms: L2 norms of the random bias to apply. Any iterable of
            numbers is accepted; it is materialised once up front.

    Returns:
        A nested dict ``{row_index: {bias_norm: record}}``. Each ``record`` has
        the keys ``"before_raw"`` and ``"after_raw"`` (the captured tensors) and
        ``"before"`` and ``"after"`` (``torch.linalg.vector_norm`` of the whole
        captured tensor, as a Python float).
    """
    # `import tqdm` alone does not load the `tqdm.notebook` submodule; import it
    # explicitly instead of relying on another package having done so.
    from tqdm.notebook import tqdm as notebook_tqdm

    # Unpacking into two names keeps the requirement that input_ids is 2-D.
    bsz, _ = input_ids.shape

    # Materialise once: the norms are iterated again for every input row.
    bias_norms = list(bias_norms)

    def emb_hook(module, inputs, output, bias, store_ctx):
        # Random direction, rescaled so that its L2 norm equals `bias`.
        b_u = torch.randn(model.config.d_model)
        b = (b_u / torch.linalg.vector_norm(b_u)) * bias

        b = b.to(output.device)

        # NOTE(review): assumes the last dimension of `output` is d_model, so
        # the same bias vector is broadcast over the other dimensions — confirm.
        output = output + b.unsqueeze(0)
        store_ctx[bias]["before_raw"] = output
        store_ctx[bias]["before"] = torch.linalg.vector_norm(output).item()
        # Returning a value from a forward hook replaces the module's output.
        return output

    def trn_hook(module, inputs, output, bias, store_ctx):
        # NOTE(review): presumably the layer returns a tuple whose first element
        # is the hidden state — confirm against the layer implementation.
        output = output[0]
        store_ctx[bias]["after"] = torch.linalg.vector_norm(output).item()
        store_ctx[bias]["after_raw"] = output

    embedding = model.pico_decoder.embedding_proj
    last_layer = model.pico_decoder.layers[-1]

    bias_results = {}

    for in_idx, sequence in notebook_tqdm(enumerate(input_ids), desc="inputs...", total=bsz):
        store = {b: {} for b in bias_norms}
        bias_results[in_idx] = store

        for norm in notebook_tqdm(bias_norms, desc="Biases...", leave=False):
            handles = []
            try:
                handles.append(
                    embedding.register_forward_hook(partial(emb_hook, bias=norm, store_ctx=store))
                )
                handles.append(
                    last_layer.register_forward_hook(partial(trn_hook, bias=norm, store_ctx=store))
                )

                with torch.inference_mode():
                    model(sequence.unsqueeze(0))
            finally:
                # Always detach the hooks, even if the forward pass raised, so
                # the model is not left with stale hooks attached.
                for handle in handles:
                    handle.remove()

    return bias_results


In [5]:
# Checkpoints to compare, keyed by a LaTeX-formatted display name. Each entry
# gives the repository URL, the commit to pin, and the local clone directory.
models = {
    r"\texttt{pico-relora}": dict(
        repo_url="https://huggingface.co/yuvalw/pico-relora-tiny-v2-sn",
        commit_hash="f567508a0d79eb738a23a117a6753650a9969405",
        clone_dir=Path("../runs/rem-relora-tiny-20k"),
    ),
    r"\texttt{pico-decoder}": dict(
        repo_url="https://huggingface.co/pico-lm/pico-decoder-tiny",
        commit_hash="9c866480418eeeb9946b0aaf4318bcdde46db83f",
        clone_dir=Path("../runs/rem-decoder-tiny-20k"),
    ),
}

In [6]:
# Make sure every model repository exists locally and is pinned to its commit.
for d in models.values():
    clone_dir = d["clone_dir"]
    # Reuse an existing checkout if the directory is there, otherwise clone it.
    repo = git.Repo(clone_dir) if clone_dir.exists() else git.Repo.clone_from(d["repo_url"], clone_dir)
    repo.git.checkout(d["commit_hash"])

In [7]:
# Prefer the MPS backend this notebook was written for, but fall back to CUDA
# or CPU so the cell does not fail on machines without MPS.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load each checkpoint from its local clone and keep model and tokenizer
# alongside the repository metadata.
for mname in models:
    model, tokenizer = load_model_tokenizer(models[mname]["clone_dir"])
    models[mname]["model"] = model.to(device)
    models[mname]["tokenizer"] = tokenizer

[2025-05-23 21:48:12,703] [INFO] [real_accelerator.py:222:get_accelerator] Setting ds_accelerator to mps (auto detect)


W0523 21:48:12.901000 67203 torch/distributed/elastic/multiprocessing/redirects.py:29] NOTE: Redirects are currently not supported in Windows or MacOs.


In [8]:
# Draw random token ids. Both models have the same tokenizer, so either works.
tokenizer = models[r"\texttt{pico-relora}"]["tokenizer"]

# Ids listed in the tokenizer's added-token table are excluded from sampling.
special_ids = set(tokenizer.added_tokens_decoder)
vocab_size = tokenizer.vocab_size
valid_ids = [tok_id for tok_id in range(vocab_size) if tok_id not in special_ids]

# match godey
batch_size, seq_len = 16, 512

# One independent draw of `seq_len` ids per row, stacked into a 2-D long tensor.
sampled_rows = []
for _row in range(batch_size):
    row_ids = np.random.choice(valid_ids, size=seq_len)
    sampled_rows.append(torch.tensor(row_ids, dtype=torch.long))
random_ids = torch.stack(sampled_rows)

In [9]:
# Run the bias sweep (norms 0, 5, ..., 50) on the same random inputs for every model.
results = {}
for mname, model_entry in models.items():
    model = model_entry["model"]
    # Inputs are moved to whichever device the model lives on.
    res = fetch_reprs(model, random_ids.to(model.device), range(0, 51, 5))
    results[mname] = res

inputs...:   0%|          | 0/16 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

inputs...:   0%|          | 0/16 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

Biases...:   0%|          | 0/11 [00:00<?, ?it/s]

In [11]:
# transform data to dataframe
def _get_data():
    """Flatten the nested ``results`` dict into one tuple per measurement.

    Yields:
        ``(model name, sequence index, bias norm, before, after)`` where the
        last two are the ``"before"`` and ``"after"`` entries of each record.
    """
    for model_name, per_sequence in results.items():
        for seq_id, per_bias in per_sequence.items():
            yield from (
                (model_name, seq_id, bias_norm, record["before"], record["after"])
                for bias_norm, record in per_bias.items()
            )


# One row per (model, sequence, bias norm) with the norms recorded before and
# after the decoder stack.
data_df = pd.DataFrame(
    _get_data(),
    columns=["Model", "SeqId", "BiasNorm", "Before", "After"],
)

In [None]:
def _get_norm_data():
    """Average ``data_df`` per (Model, BiasNorm) and emit it in long format.

    Yields:
        Two tuples per group, ``(model, bias norm, "Input", mean Before)``
        followed by ``(model, bias norm, "Output", mean After)``.
    """
    averaged = data_df.groupby(by=["Model", "BiasNorm"]).agg("mean").reset_index()
    for rec in averaged.itertuples(index=False):
        for when, value in (("Input", rec.Before), ("Output", rec.After)):
            yield rec.Model, rec.BiasNorm, when, value


# Long-format table: one row per (model, bias norm, Input/Output) holding the mean norm.
norm_data_df = pd.DataFrame(_get_norm_data(), columns=["Model", "BiasNorm", "When", "ReprNorm"])

In [None]:
# Plot the mean representation norm against the bias norm, one facet per model,
# with separate curves for the input and the output of the decoder stack.
colors = ["#004D40", "#D81B60"]
palette = dict(zip(["Input", "Output"], colors))
# Markers only take effect when a `style` variable is mapped; keying them by
# level keeps the assignment identical in every facet.
markers = {"Input": "s", "Output": "o"}

dodge = 0.4
size = 8
font_scale = 3

sns.set_theme(font_scale=font_scale, rc={"axes.grid": True, "axes.grid.which": "both"})

g = sns.FacetGrid(
    data=norm_data_df,
    col="Model",
    height=size,
    aspect=1,
    sharey=True,
    sharex=True,
    margin_titles=False,
)

g.map_dataframe(
    sns.lineplot,
    x="BiasNorm",
    y="ReprNorm",
    hue="When",
    style="When",  # required for `markers` to be applied
    palette=palette,
    markers=markers,
    dashes=False,  # keep both curves solid, as before
)

# Axis labels describe the plotted columns (BiasNorm on x, ReprNorm on y).
g.set_axis_labels("Bias norm", "Representation norm")

# Only the first facet gets the repositioned two-column legend.
g.axes.flat[0].legend(bbox_to_anchor=(0, 1), loc="upper left", ncols=2)

plt.savefig("../graphs/anisotropy-norm.pdf", bbox_inches="tight")
plt.show()