In [1]:
from __future__ import annotations

import os

import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Project-local helpers for activation extraction and dataset loading.
from llm_utils.activation_generator import ActivationGenerator
from data_utils.concept_dataset import SupervisedConceptDataset


# ## Configuration
#
# - `data_path`: Path to the dataset / examples used for training and analysis.
# - `model_name`: A **TransformerLens-supported** model to extract activations from. We default to a small model (`gpt2-small`) for fast iteration; swap in larger models if you have GPU memory/compute.
# - `layers`: Which layer to inspect and factorize.
# - `data_device`: Where data tensors live during preprocessing (`cuda` in this notebook; switch to `cpu` if no GPU is available or GPU memory is tight).
# - `model_device`: Where the model runs for activation extraction and generation. Use `mps` on Apple Silicon, `cuda` on NVIDIA GPUs, or `cpu` if needed.
# - `factorization_mode`: Which activation stream to factorize:
#   - `residual`: a general-purpose choice that often yields clean, interpretable structure.
#

In [4]:


# Dataset, model and layer selection.
data_path = "./data/supervised.json"
model_name = "gpt2-small"
layers = [4]

# Pick the best available accelerator instead of hard-coding "cuda", so the
# notebook also runs on Apple Silicon (mps) and CPU-only machines.
# On a CUDA machine this resolves to exactly the previous values.
if torch.cuda.is_available():
    model_device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    model_device = "mps"
else:
    model_device = "cpu"

# Data tensors stay on the GPU only when CUDA is present; otherwise keep them on the CPU.
data_device = "cuda" if model_device == "cuda" else "cpu"

factorization_mode = "residual"


# ### Loading and Generating Data
#
# In this tutorial we use our own abstractions for generating activations and loading data. Ultimately, all you need is a data loader for training the MFA, so feel free to swap in a different method.

In [5]:


# Build the activation extractor for `model_name`, using the devices and the
# activation stream (`factorization_mode`) chosen in the configuration cell.
# NOTE(review): `initialize=True` presumably triggers the model load reported
# in the cell output — confirm against ActivationGenerator.
act_generator = ActivationGenerator(
    model_name,
    model_device=model_device,
    data_device=data_device,
    mode=factorization_mode,
    initialize=True,
)

Loaded pretrained model gpt2-small into HookedTransformer


In [6]:


# Load the supervised concept dataset from the JSON file at `data_path`.
dataset_obj = SupervisedConceptDataset(data_path)

In [7]:


# Extract activations for every layer listed in `layers`. The second return
# value is discarded (presumably the frequency statistics the method name
# refers to — confirm against ActivationGenerator).
activations, _ = act_generator.generate_multiple_layer_activations_and_freq(
    dataset_obj, layers
)


# We additionally load tokens in order to later interpret the subspaces.

Building vocab frequency: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3600/3600 [00:00<00:00, 37904.14it/s]
Generating multi-layer activations with freq: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3600/3600 [01:35<00:00, 37.51it/s]


In [8]:


from llm_utils.activation_generator import extract_token_ids_sample_ids_and_labels

# Only the first of the three returned values (the token ids) is kept; the
# other two are discarded.
tokens, _, _ = extract_token_ids_sample_ids_and_labels(dataset_obj, act_generator)


# Creating the loaders from extracted activations
# To make this notebook work on lower compute we utilize only 250k activations.

Extracting token IDs: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3600/3600 [00:03<00:00, 1176.96it/s]


In [9]:


# Cap on the number of activations used, so the notebook runs on modest hardware.
MAX_ACTIVATIONS = 250_000

# Raw data: the first entry of `activations` (presumably the first layer in
# `layers`) and the matching token ids, both truncated to the same length.
X_all = activations[0][:MAX_ACTIVATIONS]
tokens = tokens[:MAX_ACTIVATIONS]

# Single dataset pairing each activation with its token id.
full_ds = TensorDataset(X_all, tokens)

# Only CPU tensors can be pinned; asking the DataLoader to pin tensors that
# already live on a GPU raises a RuntimeError, so enable pinning conditionally.
pin_batches = all(t.device.type == "cpu" for t in (X_all, tokens))

loader = DataLoader(
    full_ds,
    batch_size=128,
    shuffle=True,  # always shuffle your training set
    pin_memory=pin_batches,
)

# Standalone token loader, in case a separate pass over the tokens is needed.
token_loader = DataLoader(tokens, batch_size=128)


# ### Initialization
#
# As described in the paper we tested three options for initialization.
# We found that K-Means often works well, with random point initialization also successful (random weights often fail).
# In this tutorial we show how to use K-Means, as it's the most complicated of the three, and we provide an implementation that works with torch.

# We must decide how much of the data to run K-Means on. Since K-Means is slow, our implementation lets you choose a pool size, and that many points are sampled at random. Additionally, for efficiency it uses a projected K-Means.
#
# In this tutorial we use 20% of the dataset (50k of the 250k activations) in order to speed it up.

In [10]:


# Use one fifth of the training dataset as the K-Means sampling pool.
pool_size = round(len(loader.dataset) / 5)
# Bare expression so the notebook displays the resulting pool size.
pool_size


# We use 500 centroids. This is an arbitrary number: you can reduce it to capture broader subspaces or increase it to produce more semantic covariances.
#
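# Should run in 3-5 minutes. For a shorter runtime, sample random points as the centroids instead (second cell below). Note that the second cell reassigns `centroids`, so if you run both cells the random-point initialization is the one that gets used.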

50000

In [11]:


from initializations.projected_knn import ReservoirKMeans

num_centroids = 500

# Projected K-Means over a randomly sampled pool of `pool_size` activations.
knn = ReservoirKMeans(
    num_centroids,
    pool_size=pool_size,
    # Read the vocabulary size from the loaded TransformerLens model instead of
    # hard-coding GPT-2's 50257, so swapping `model_name` keeps this correct.
    vocab_size=act_generator.model.cfg.d_vocab,
    device=model_device,
    proj_dim=32,
)
centroids = knn.fit(loader)

In [12]:


# Random-point initialization: use randomly chosen activations as the centroids.
# NOTE: this reassigns `centroids`, replacing the K-Means result if the
# previous cell was run.
num_points = X_all.shape[0]
# A random permutation truncated to `num_centroids` samples without replacement.
shuffled = torch.randperm(num_points, device=X_all.device)
centroids = X_all[shuffled[:num_centroids]]


# ### Training
#
# We train using Negative Log Likelihood. We provided an implementation of a very simple training loop.
# We use R = 10 (covariance dim), feel free to experiment with different values. It mostly depends on the intrinsic dimension of the data.
#
# We train for 50 epochs here, which is sufficient for the follow-up interpretation and steering. For evaluations, you would want to train until convergence.
#
# Should take about 10-15 minutes. Feel free to train for fewer epochs; a couple of epochs are often enough to see results (depending on dataset size).

In [14]:


from modeling.mfa import MFA
from modeling.train import train_nll

# Build the MFA from the initial centroids with covariance rank R = 10 and
# move it to the model device.
model = MFA(centroids=centroids, rank=10).to(model_device)
# Train with negative log-likelihood on the training loader only.
# NOTE(review): no validation loader is passed, which presumably explains the
# `val NLL=nan` entries in the recorded log — confirm against train_nll.
train_nll(model, loader, epochs=50, lr=1e-3)


# ### Interpretation
#

Epoch 01 | Step 001900 Train NLL=1421.015905: : 1954it [00:05, 341.20it/s]


[epoch 01] train NLL=1418.286274  val NLL=nan ** best **


Epoch 02 | Step 001900 Train NLL=1288.547953: : 1954it [00:05, 341.56it/s]


[epoch 02] train NLL=1287.693992  val NLL=nan ** best **


Epoch 03 | Step 001900 Train NLL=1247.182531: : 1954it [00:05, 341.59it/s]


[epoch 03] train NLL=1246.881022  val NLL=nan ** best **


Epoch 04 | Step 001900 Train NLL=1221.958015: : 1954it [00:05, 341.59it/s]


[epoch 04] train NLL=1221.854058  val NLL=nan ** best **


Epoch 05 | Step 001900 Train NLL=1204.325863: : 1954it [00:05, 341.57it/s]


[epoch 05] train NLL=1204.190597  val NLL=nan ** best **


Epoch 06 | Step 001900 Train NLL=1191.228009: : 1954it [00:05, 341.79it/s]


[epoch 06] train NLL=1191.193339  val NLL=nan ** best **


Epoch 07 | Step 001900 Train NLL=1181.751503: : 1954it [00:05, 341.44it/s]


[epoch 07] train NLL=1181.772298  val NLL=nan ** best **


Epoch 08 | Step 001900 Train NLL=1174.310107: : 1954it [00:05, 341.48it/s]


[epoch 08] train NLL=1174.370478  val NLL=nan ** best **


Epoch 09 | Step 001900 Train NLL=1168.420521: : 1954it [00:05, 341.48it/s]


[epoch 09] train NLL=1168.277099  val NLL=nan ** best **


Epoch 10 | Step 001900 Train NLL=1163.551707: : 1954it [00:05, 341.37it/s]


[epoch 10] train NLL=1163.508294  val NLL=nan ** best **


Epoch 11 | Step 001900 Train NLL=1159.803898: : 1954it [00:05, 341.71it/s]


[epoch 11] train NLL=1159.612256  val NLL=nan ** best **


Epoch 12 | Step 001900 Train NLL=1156.218978: : 1954it [00:05, 341.75it/s]


[epoch 12] train NLL=1156.141009  val NLL=nan ** best **


Epoch 13 | Step 001900 Train NLL=1153.304783: : 1954it [00:05, 341.66it/s]


[epoch 13] train NLL=1153.194633  val NLL=nan ** best **


Epoch 14 | Step 001900 Train NLL=1150.748046: : 1954it [00:05, 341.67it/s]


[epoch 14] train NLL=1150.811943  val NLL=nan ** best **


Epoch 15 | Step 001900 Train NLL=1148.515049: : 1954it [00:05, 341.72it/s]


[epoch 15] train NLL=1148.395080  val NLL=nan ** best **


Epoch 16 | Step 001900 Train NLL=1146.195546: : 1954it [00:05, 341.78it/s]


[epoch 16] train NLL=1146.195315  val NLL=nan ** best **


Epoch 17 | Step 001900 Train NLL=1144.035929: : 1954it [00:05, 341.63it/s]


[epoch 17] train NLL=1144.031506  val NLL=nan ** best **


Epoch 18 | Step 001900 Train NLL=1142.230502: : 1954it [00:05, 341.58it/s]


[epoch 18] train NLL=1142.310246  val NLL=nan ** best **


Epoch 19 | Step 001900 Train NLL=1140.869365: : 1954it [00:05, 341.82it/s]


[epoch 19] train NLL=1140.782110  val NLL=nan ** best **


Epoch 20 | Step 001900 Train NLL=1139.314096: : 1954it [00:05, 340.31it/s]


[epoch 20] train NLL=1139.446246  val NLL=nan ** best **


Epoch 21 | Step 001900 Train NLL=1138.178250: : 1954it [00:05, 341.53it/s]


[epoch 21] train NLL=1138.182968  val NLL=nan ** best **


Epoch 22 | Step 001900 Train NLL=1136.769148: : 1954it [00:05, 341.53it/s]


[epoch 22] train NLL=1136.964856  val NLL=nan ** best **


Epoch 23 | Step 001900 Train NLL=1135.528299: : 1954it [00:05, 341.20it/s]


[epoch 23] train NLL=1135.858110  val NLL=nan ** best **


Epoch 24 | Step 001900 Train NLL=1134.757595: : 1954it [00:05, 340.94it/s]


[epoch 24] train NLL=1134.892211  val NLL=nan ** best **


Epoch 25 | Step 001900 Train NLL=1134.048775: : 1954it [00:05, 340.47it/s]


[epoch 25] train NLL=1134.068917  val NLL=nan ** best **


Epoch 26 | Step 001900 Train NLL=1133.039825: : 1954it [00:05, 341.27it/s]


[epoch 26] train NLL=1133.322301  val NLL=nan ** best **


Epoch 27 | Step 001900 Train NLL=1132.525923: : 1954it [00:05, 341.70it/s]


[epoch 27] train NLL=1132.586279  val NLL=nan ** best **


Epoch 28 | Step 001900 Train NLL=1132.037190: : 1954it [00:05, 341.47it/s]


[epoch 28] train NLL=1132.017979  val NLL=nan ** best **


Epoch 29 | Step 001900 Train NLL=1131.295171: : 1954it [00:05, 341.73it/s]


[epoch 29] train NLL=1131.477194  val NLL=nan ** best **


Epoch 30 | Step 001900 Train NLL=1130.983244: : 1954it [00:05, 341.74it/s]


[epoch 30] train NLL=1130.908918  val NLL=nan ** best **


Epoch 31 | Step 001900 Train NLL=1130.190547: : 1954it [00:05, 341.62it/s]


[epoch 31] train NLL=1130.315383  val NLL=nan ** best **


Epoch 32 | Step 001900 Train NLL=1129.660018: : 1954it [00:05, 341.73it/s]


[epoch 32] train NLL=1129.753168  val NLL=nan ** best **


Epoch 33 | Step 001900 Train NLL=1129.180140: : 1954it [00:05, 341.64it/s]


[epoch 33] train NLL=1129.294397  val NLL=nan ** best **


Epoch 34 | Step 001900 Train NLL=1128.687216: : 1954it [00:05, 341.39it/s]


[epoch 34] train NLL=1128.848197  val NLL=nan ** best **


Epoch 35 | Step 001900 Train NLL=1128.431843: : 1954it [00:05, 341.31it/s]


[epoch 35] train NLL=1128.341498  val NLL=nan ** best **


Epoch 36 | Step 001900 Train NLL=1127.789738: : 1954it [00:05, 341.45it/s]


[epoch 36] train NLL=1127.939815  val NLL=nan ** best **


Epoch 37 | Step 001900 Train NLL=1127.431546: : 1954it [00:05, 341.38it/s]


[epoch 37] train NLL=1127.590118  val NLL=nan ** best **


Epoch 38 | Step 001900 Train NLL=1127.237357: : 1954it [00:05, 341.52it/s]


[epoch 38] train NLL=1127.257268  val NLL=nan ** best **


Epoch 39 | Step 001900 Train NLL=1127.044427: : 1954it [00:05, 341.73it/s]


[epoch 39] train NLL=1126.939278  val NLL=nan ** best **


Epoch 40 | Step 001900 Train NLL=1126.237027: : 1954it [00:05, 341.79it/s]


[epoch 40] train NLL=1126.535949  val NLL=nan ** best **


Epoch 41 | Step 001900 Train NLL=1126.180649: : 1954it [00:05, 341.51it/s]


[epoch 41] train NLL=1126.218555  val NLL=nan ** best **


Epoch 42 | Step 001900 Train NLL=1125.559360: : 1954it [00:05, 341.74it/s]


[epoch 42] train NLL=1125.882275  val NLL=nan ** best **


Epoch 43 | Step 001900 Train NLL=1125.750348: : 1954it [00:05, 341.31it/s]


[epoch 43] train NLL=1125.634371  val NLL=nan ** best **


Epoch 44 | Step 001900 Train NLL=1125.426441: : 1954it [00:05, 341.61it/s]


[epoch 44] train NLL=1125.353298  val NLL=nan ** best **


Epoch 45 | Step 001900 Train NLL=1125.197917: : 1954it [00:05, 341.61it/s]


[epoch 45] train NLL=1125.115928  val NLL=nan ** best **


Epoch 46 | Step 001900 Train NLL=1124.748831: : 1954it [00:05, 341.65it/s]


[epoch 46] train NLL=1124.859256  val NLL=nan ** best **


Epoch 47 | Step 001900 Train NLL=1124.486846: : 1954it [00:05, 341.80it/s]


[epoch 47] train NLL=1124.592017  val NLL=nan ** best **


Epoch 48 | Step 001900 Train NLL=1124.355607: : 1954it [00:05, 341.55it/s]


[epoch 48] train NLL=1124.318234  val NLL=nan ** best **


Epoch 49 | Step 001900 Train NLL=1123.928176: : 1954it [00:05, 341.75it/s]


[epoch 49] train NLL=1124.087896  val NLL=nan ** best **


Epoch 50 | Step 001900 Train NLL=1123.820522: : 1954it [00:05, 341.46it/s]


[epoch 50] train NLL=1123.872130  val NLL=nan ** best **
Restored best model from epoch 50 with metric=1123.872130


{'best_epoch': 50, 'best_metric': 1123.8721303203124}

In [15]:
# We will first inspect the top-likelihood samples per Gaussian, and then visualize the Gaussians in order to show the separation within each Gaussian.

In [16]:


def my_token_to_str(tok_id):
    """Decode ``tok_id`` to text via ``to_string`` on the activation generator's model."""
    decoder = act_generator.model
    return decoder.to_string(tok_id)


# To interpret the Gaussians, we calculate for each activation in the loader the likelihood for each Gaussian. Then we present the last token of the top likelihood samples, in order to understand the theme.
#
# This cell could take a bit to run, in our setup around 4 min.

In [17]:


from analysis.subspace_interpretation import get_top_strings_per_concept

# Collect the top strings per Gaussian, scored by likelihood; token ids are
# rendered with `my_token_to_str`.
results = get_top_strings_per_concept(
    model, loader, my_token_to_str, score="likelihood"
)


# We now view the top-likelihood samples. We sample from them so that they don't all look very similar. To better understand a Gaussian, it's advised to look at the whole distribution of contexts that belong to it.

In [None]:


import random

N_RESULTS = 25  # how many entries of `results` to show
N_LINES = 10  # lines printed per entry
TOP_POOL = 5000  # draw the printed lines from this many leading strings
SEED = 0

random.seed(SEED)

first_entries = list(results.items())[:N_RESULTS]
for position, (_, top_strings) in enumerate(first_entries):
    # Slicing past the end is harmless, so no explicit length clamp is needed.
    candidates = top_strings[:TOP_POOL]
    n_shown = min(N_LINES, len(candidates))
    chosen = random.sample(candidates, k=n_shown)  # sampled without repeats

    print(f"\n[{position}]\n" + "-" * 40)
    for entry in chosen:
        # Escape newlines so every sample stays on a single output line.
        print("  - " + str(entry).replace("\n", "\\n"))


# #### Visualizing
#
# We move to visualizing the Gaussians, in order to see how the activations distribute within.
# To do so, we first calculate the latent dimensions (z) for each point, then plot with the loadings acting as the axes.

In [None]:


import analysis.subspace_visualization as sv

# Index of the Gaussian to visualize.
k_to_visualize = 20

# Project the loader's activations onto the subspace of Gaussian `k_to_visualize`.
coords = sv.project_loader_to_subspace(
    model, loader, k=k_to_visualize, token_to_str=my_token_to_str
)


# Here we visualize in 2D using two loadings. The variance is spread out across 10 dimensions, and the loadings do not necessarily reflect directions of maximal variance; rather, together they serve as a basis for the subspace of maximal structured variance. To better understand the structure, either plot in 3D, use PCA on the subspace, or lower R. Additional discussion in the paper!

In [None]:


# Scatter plot of the projected points using dimensions 0 and 8 as the axes.
sv.plot_subspace_scatter(coords, dims=(0, 8), max_labels=250)


# ## Steering
# In this part we will show how we can steer using both the centroids and the loadings.
# To show the effect of steering we will inspect the top promoted tokens of an intervention. While the centroids promote, loadings can also suppress, so it is useful to look at the top absolute logit difference.
#
# Important to note, steering produces different results across different models. Some models require weaker alpha values or different steering methods. Tweak the parameters if things don't work out. Will update soon with additional steering methods that work especially well for MFA using the fact that we know source and target distributions!

In [None]:


# import and helpers

from intervention.mfa_steering import MFASteerer


def print_logit_diff(model, logits_before, logits_after, top_k: int = 10):
    """
    Print the tokens with the largest positive and negative logit changes
    (after - before) at the last position of the first sequence in the batch.

    Args:
        model: Object exposing ``to_str_tokens``; it is called with a
            one-element list holding each token id.
        logits_before: Baseline logits, indexed as ``[0, -1, :]``.
        logits_after: Logits from the intervened run, indexed the same way.
        top_k: Number of tokens to list in each direction. Values larger than
            the vocabulary dimension are clamped, since ``torch.topk`` raises
            when ``k`` exceeds the size of the dimension.
    """
    # Per-token change at the last position (shape: [vocab_size]).
    delta_logits = logits_after[0, -1, :] - logits_before[0, -1, :]
    k = min(top_k, delta_logits.shape[-1])

    def _print_changes(header, values, indices):
        # Shared formatter for the positive and negative listings.
        print(header)
        for token_id, change in zip(indices, values):
            token_str = model.to_str_tokens([token_id])  # adjust to your tokenizer API
            print(f"  Token: {token_str},  Δlogit: {change.item():.4f}")

    # --- Top positive changes ---
    pos_vals, pos_idx = torch.topk(delta_logits, k=k)
    _print_changes(f"Top {k} positive logit changes:", pos_vals, pos_idx)

    # --- Top negative changes ---
    # largest=False returns the most negative deltas directly, most negative first.
    neg_vals, neg_idx = torch.topk(delta_logits, k=k, largest=False)
    _print_changes(f"\nTop {k} negative logit changes:", neg_vals, neg_idx)


def get_logit_diff(model, logits_before, logits_after, top_k: int = 20):
    """
    Return formatted strings for the tokens whose last-position logits changed most.

    Tokens are ranked by the absolute value of (after - before), so both
    promoted and suppressed tokens can appear in the result.

    Args:
        model: Object exposing ``to_str_tokens``; it is called with a
            one-element list holding each token id.
        logits_before: Baseline logits, indexed as ``[0, -1, :]``.
        logits_after: Logits from the intervened run, indexed the same way.
        top_k: Maximum number of tokens to report (default 20, the value that
            was previously hard-coded). Clamped to the vocabulary dimension.

    Returns:
        A list of ``"Token: ..., Score: ..."`` strings ordered by decreasing
        absolute logit change.
    """
    # Per-token change at the last position (shape: [vocab_size]).
    delta_logits = logits_after[0, -1, :] - logits_before[0, -1, :]

    # torch.topk raises if k exceeds the size of the dimension, so clamp it.
    k = min(top_k, delta_logits.shape[-1])
    top_changes, top_indices = torch.topk(delta_logits.abs(), k=k)

    return [
        f"Token: {model.to_str_tokens([token_id])}, Score: {change.item():.4f}"
        for token_id, change in zip(top_indices, top_changes)
    ]


# ### Centroid Steering
#
# We interpolate towards the centroid using:
#
# (1-alpha)*x + alpha * mu
#
# We interpolate since the centroid defines an absolute position; for more discussion on this, see the paper.

In [None]:


# Pair the language model (`act_generator.model`) with the trained MFA (`model`) for steering.
steerer = MFASteerer(act_generator.model, model)


# We define intervention strength, layer and prompt.
# Alpha = 1 means we replace with the centroid, and often produces a strong causal effect. Best to use lower values, based on the task.

In [None]:


alpha = 0.6  # interpolation strength towards the centroid
# Steer at the layer the MFA was fitted on (the first entry of `layers`, whose
# activations were used as training data) instead of repeating the literal 4.
layer = layers[0]
prompt = "I think that"
factor_num = 20  # index of the Gaussian to steer towards

# Baseline forward pass; gradients are not needed for comparing logits, so
# skip building the autograd graph.
with torch.no_grad():
    base_logits = act_generator.model(act_generator.model.to_tokens(prompt))

In [None]:


# Run the prompt while steering at `layer` towards Gaussian `factor_num`
# with interpolation strength `alpha`.
intervened_logits = steerer.intervene(
    prompt,
    layers=[layer],
    alpha=alpha,
    k=factor_num,
)

In [None]:


# Show the 15 largest logit increases and decreases caused by the centroid intervention.
print_logit_diff(act_generator.model, base_logits, intervened_logits, top_k=15)


# We see that by intervening towards the centroid of Gaussian 20 (depicted Gaussian above), we promoted tokens related to research.
#
# Next, we intervene using the local subspace as well.

In [None]:


# Target latent vector with 10 entries (the MFA above was built with rank=10).
# Only coordinates 0 and 8 are set; all others stay at zero.
z = torch.zeros(10)
z[0], z[8] = 20, -10

# Two-stage intervention: centroid interpolation with `alpha`, plus the latent target `z`.
intervened_logits = steerer.intervene_to_latent_two_stage(
    prompt, layers=[layer], alpha_centroid=alpha, z=z, k=factor_num
)

In [None]:


# Show the 8 largest logit increases and decreases caused by the latent-space intervention.
print_logit_diff(act_generator.model, base_logits, intervened_logits, top_k=8)


# We see that by setting the z vector to point towards the "dissertation" area within the Gaussian (see figure a couple cells back), we promote "dissertation" related tokens!
#
# It's important to note that we only set 2 coordinates, but to get a better effect it's best to define z using all latent dimensions (R=10), as the variation is not isolated to a few loadings.

# ### Final Note
#
# Hopefully this tutorial was a good start. For any additional questions about the paper or about using MFA, feel free to reach out!