## Codebook Features

Alex Tamkin, Mohammad Taufeeque and Noah D. Goodman: "Codebook Features: Sparse and Discrete Interpretability for Neural Networks", 2023.

Codebook Features is a method for training neural networks with a set of learned sparse and discrete hidden states, enabling interpretability and control of the resulting model.

Codebook features work by inserting vector quantization bottlenecks, called codebooks, into each layer of a neural network. The library provides tools to train and interpret codebook models, including analyzing code activations, searching for codes that activate on a given pattern, and performing code interventions to verify a code's causal effect on the model's output. Many of these features are also available through an easy-to-use webapp for analyzing and experimenting with codebook models.
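The core operation of a codebook can be sketched in a few lines. This is a minimal, pure-Python illustration, assuming the codebook scores every code vector by cosine similarity to the incoming activation and outputs the sum of the top-k scoring code vectors. The function `codebook_forward` and the toy vectors are hypothetical, and the library's real implementation (batched tensors, learned code vectors, training) is not reproduced here.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def codebook_forward(
    activation: list[float], codebook: list[list[float]], k: int
) -> tuple[list[int], list[float]]:
    """Hypothetical sketch of a codebook bottleneck.

    Scores every code vector against the activation, keeps the indices of the
    k most similar codes, and returns those indices together with the sum of
    their code vectors, which replaces the original activation downstream.
    """
    scores = [cosine_similarity(activation, code) for code in codebook]
    # Indices of the k highest-scoring codes, best first.
    top_k = sorted(range(len(codebook)), key=lambda i: scores[i], reverse=True)[:k]
    output = [sum(codebook[i][d] for i in top_k) for d in range(len(activation))]
    return top_k, output


# Toy example: a 2-dimensional activation and a codebook of 4 codes.
codebook = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.7, 0.7]]
active_codes, new_activation = codebook_forward([0.9, 0.1], codebook, k=2)
print(active_codes, new_activation)
```

Because only k discrete codes pass through the bottleneck, each position's hidden state can be described by a short list of code indices, which is what makes the codes inspectable and patchable.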

### Code Intervention Tutorial

This is a tutorial notebook on performing code interventions with the Codebook Features library. The goal is to steer a language model to generate text about a chosen topic (by activating specific topic codes) and to quantitatively evaluate how well the steering works. We use the one-layer, 21M-parameter TinyStories model trained on synthetic children's stories (https://arxiv.org/abs/2305.07759). This small model typically produces grammatical but incoherent stories; nevertheless, we can use it to see how different topics are woven into the network. For example, activating 'baby' codes causes the model to introduce related topics such as babies, bathtubs, and baby birds, rather than simply outputting 'baby baby baby' repeatedly.

In [2]:
from tqdm import tqdm
from codebook_features import models
from codebook_features import utils as cb_utils
import torch
import re

# We turn automatic differentiation off, to save GPU memory,
# as this tutorial focuses only on model inference
torch.set_grad_enabled(False)


In [3]:
model_name_or_path = "roneneldan/TinyStories-1Layer-21M"
pretrained_path = "taufeeque/TinyStories-1Layer-21M-Codebook"

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
    print(
        "No GPU found, using CPU instead. If running on Colab, "
        "make sure to enable GPU acceleration under Runtime -> Change runtime type"
    )
    
orig_cb_model = models.wrap_codebook(
    model_or_path=model_name_or_path, pretrained_path=pretrained_path
)
orig_cb_model = orig_cb_model.to(device).eval()

No GPU found, using CPU instead. If running on Colab, make sure to enable GPU acceleration under Runtime -> Change runtime type


Next, we convert the model into a `HookedTransformer` from TransformerLens, which lets us perform code interventions easily.

In [4]:
hooked_kwargs = dict(
    center_unembed=False,
    fold_value_biases=False,
    center_writing_weights=False,
    fold_ln=False,
    refactor_factored_attn_matrices=False,
    device=device,
)
cb_model = models.convert_to_hooked_model(
    model_name_or_path, orig_cb_model, hooked_kwargs=hooked_kwargs
)
cb_model = cb_model.to(device).eval()
tokenizer = cb_model.tokenizer



Loaded pretrained model roneneldan/TinyStories-1Layer-21M into HookedTransformer
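A hooked model makes interventions easy because intermediate activations pass through named hook points where a user-supplied function can read or replace them. The toy below sketches that pattern in plain Python. `ToyHookedModel`, its hook names, and its `run_with_hooks` method are hypothetical and only loosely modeled on TransformerLens; the real `HookedTransformer.run_with_hooks` takes `fwd_hooks` and passes tensors plus a hook object to each hook function.

```python
from typing import Callable, Dict, List

# A hook receives an activation and returns the (possibly modified) activation.
Hook = Callable[[List[float]], List[float]]


class ToyHookedModel:
    """Hypothetical two-stage model with named hook points (not TransformerLens)."""

    @staticmethod
    def _hook_point(
        name: str, value: List[float], hooks: Dict[str, Hook]
    ) -> List[float]:
        # If a hook is registered under this name, let it replace the activation.
        return hooks[name](value) if name in hooks else value

    def run_with_hooks(self, x: List[float], hooks: Dict[str, Hook]) -> List[float]:
        hidden = [2.0 * v for v in x]  # stage 1: scale the inputs
        hidden = self._hook_point("hidden", hidden, hooks)
        out = [v + 1.0 for v in hidden]  # stage 2: shift by one
        return self._hook_point("out", out, hooks)


model = ToyHookedModel()
clean = model.run_with_hooks([1.0, 2.0], hooks={})
# Intervene on the intermediate activation by zeroing it out.
patched = model.run_with_hooks([1.0, 2.0], hooks={"hidden": lambda h: [0.0] * len(h)})
print(clean, patched)
```

The model code never changes between the two runs; only the hook dictionary does, which is why interventions can be swapped in and out per call.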


As a sanity check, we assert that the original codebook model and the hooked model produce the same logits, up to a small numerical tolerance.

In [5]:
sentence = "this is a random sentence to test."
input_tensor = tokenizer(sentence, return_tensors="pt")["input_ids"]
input_tensor = input_tensor.to(device)
output = orig_cb_model(input_tensor)["logits"]
hooked_output = cb_model(input_tensor)
assert torch.allclose(output, hooked_output, atol=1e-4)
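`torch.allclose(a, b, atol=1e-4)` passes when every element satisfies |a − b| ≤ atol + rtol·|b|, with `rtol` left at its default of 1e-5. The pure-Python version below exists only to make that tolerance explicit; the helper name `allclose` and the toy logits are our own.

```python
def allclose(a: list[float], b: list[float], rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Elementwise |a - b| <= atol + rtol * |b|, mirroring torch.allclose's rule."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))


logits_orig = [2.50000, -1.20000, 0.33330]
logits_hooked = [2.50003, -1.20002, 0.33334]  # tiny numerical drift
print(allclose(logits_orig, logits_hooked))             # default atol is too strict here
print(allclose(logits_orig, logits_hooked, atol=1e-4))  # tolerance used in the tutorial
```

Two implementations of the same network can differ by small floating-point amounts (for example from a different operation order), so an absolute tolerance is needed rather than exact equality.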

#### Topic Codes

Below is a subset of the topic codes we have found in this model. Many more topic codes can be found through the Codebook Features webapp: https://huggingface.co/spaces/taufeeque/codebook-features

Note that multiple codes can be patched in at the same component codebook (here, a given attention head at a given layer), since each codebook activates multiple codes at once. Because a topic can be represented by several codes, we patch in multiple codes for each topic, possibly from different attention heads. Try removing some of a topic's codes and see how the generated text changes.
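To make "patching in" several codes at one codebook concrete, here is a toy model of an intervention at a single position. It assumes, purely for illustration, two modes: the patched codes either replace the codes the model selected or are added alongside them, and the codebook's output is the sum of the active code vectors. `patch_codes` and `codebook_output` are hypothetical helpers, not the internals of `generate_with_codes`.

```python
def patch_codes(active: list[int], patch: list[int], append: bool = False) -> list[int]:
    """Hypothetical intervention on one codebook at one position.

    With append=False the patched codes replace the codes the model selected;
    with append=True they are added alongside them (duplicates are skipped).
    """
    if not append:
        return list(patch)
    return active + [c for c in patch if c not in active]


def codebook_output(codes: list[int], codebook: list[list[float]]) -> list[float]:
    # The codebook's output is the sum of the code vectors of all active codes.
    dim = len(codebook[0])
    return [sum(codebook[c][d] for c in codes) for d in range(dim)]


codebook = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [0.0, -1.0]]
model_choice = [0, 1]   # codes the model itself selected (k = 2)
patched_codes = [2, 3]  # several codes patched in at the same codebook

replaced = patch_codes(model_choice, patched_codes)
appended = patch_codes(model_choice, patched_codes, append=True)
print(replaced, codebook_output(replaced, codebook))
print(appended, codebook_output(appended, codebook))
```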

In [6]:
topic_codes_str = {
    "": ""
}  # blank one is used for default generations (no topic steering)

topic = "dragon"
topic_codes_str[
    topic
] = """
Code: 4670, Layer: 0, Head: 13
Code: 17640, Layer: 0, Head: 13
Code: 19845, Layer: 0, Head: 13
Code: 23958, Layer: 0, Head: 13
Code: 3410, Layer: 0, Head: 13
Code: 19523, Layer: 0, Head: 13
Code: 2262, Layer: 0, Head: 13
Code: 16060, Layer: 0, Head: 13
"""

topic = "slide"
topic_codes_str[
    topic
] = """
Code: 1331, Layer: 0, Head: 14
Code: 22178, Layer: 0, Head: 14
Code: 15885, Layer: 0, Head: 14
Code: 9524, Layer: 0, Head: 14
Code: 15549, Layer: 0, Head: 14
Code: 7802, Layer: 0, Head: 14
Code: 11942, Layer: 0, Head: 14
Code: 4095, Layer: 0, Head: 1
Code: 2179, Layer: 0, Head: 1
Code: 22425, Layer: 0, Head: 1
Code: 10661, Layer: 0, Head: 1
Code: 8598, Layer: 0, Head: 1
"""

topic = "friend"
topic_codes_str[
    topic
] = """
Code: 20506, Layer: 0, Head: 11
Code: 6103, Layer: 0, Head: 11
Code: 15764, Layer: 0, Head: 11
Code: 14060, Layer: 0, Head: 11
Code: 21005, Layer: 0, Head: 11
Code: 16006, Layer: 0, Head: 11
Code: 12290, Layer: 0, Head: 11
Code: 7404, Layer: 0, Head: 11
Code: 2471, Layer: 0, Head: 13
"""

topic = "flower"
topic_codes_str[
    topic
] = """
Code: 23967, Layer: 0, Head: 13
Code: 13533, Layer: 0, Head: 13
Code: 4175, Layer: 0, Head: 13
Code: 6390, Layer: 0, Head: 13
Code: 18765, Layer: 0, Head: 13
Code: 1775, Layer: 0, Head: 13
Code: 7430, Layer: 0, Head: 13
Code: 9269, Layer: 0, Head: 13
"""

topic = "fire"
topic_codes_str[
    topic
] = """
Code: 9151, Layer: 0, Head: 13
Code: 6389, Layer: 0, Head: 13
Code: 16473, Layer: 0, Head: 13
Code: 24184, Layer: 0, Head: 13
Code: 11224, Layer: 0, Head: 13
Code: 16757, Layer: 0, Head: 13
Code: 16684, Layer: 0, Head: 13
Code: 22825, Layer: 0, Head: 13
Code: 22980, Layer: 0, Head: 14
Code: 6544, Layer: 0, Head: 14
Code: 2672, Layer: 0, Head: 14
Code: 5791, Layer: 0, Head: 14
Code: 22544, Layer: 0, Head: 14
Code: 6971, Layer: 0, Head: 14
Code: 23452, Layer: 0, Head: 14
Code: 708, Layer: 0, Head: 14
"""

topic = "prince|crown|king|castle"
topic_codes_str[
    topic
] = """
Code: 28, Layer: 0, Head: 13
Code: 19802, Layer: 0, Head: 13
Code: 22851, Layer: 0, Head: 13
Code: 8907, Layer: 0, Head: 13
Code: 18042, Layer: 0, Head: 13
Code: 9619, Layer: 0, Head: 13
Code: 15278, Layer: 0, Head: 13
Code: 9649, Layer: 0, Head: 13
Code: 13055, Layer: 0, Head: 14
Code: 13575, Layer: 0, Head: 14
Code: 9784, Layer: 0, Head: 14
Code: 19023, Layer: 0, Head: 14
Code: 7704, Layer: 0, Head: 14
Code: 6056, Layer: 0, Head: 14
"""

topic = "baby"
topic_codes_str[
    topic
] = """
Code: 66, Layer: 0, Head: 13
Code: 657, Layer: 0, Head: 13
Code: 9965, Layer: 0, Head: 13
Code: 13724, Layer: 0, Head: 13
Code: 5276, Layer: 0, Head: 13
Code: 11101, Layer: 0, Head: 13
Code: 10272, Layer: 0, Head: 13
Code: 3067, Layer: 0, Head: 3
Code: 18686, Layer: 0, Head: 3
Code: 430, Layer: 0, Head: 3
Code: 12364, Layer: 0, Head: 3
Code: 1209, Layer: 0, Head: 3
Code: 13863, Layer: 0, Head: 3
Code: 15111, Layer: 0, Head: 3
Code: 1185, Layer: 0, Head: 3
"""

# this converts the strings to lists of topic codes of the type `CodeInfo` that the library uses.
topic_codes = {
    k: cb_utils.parse_topic_codes_string(v, pos=None, code_append=False)
    for k, v in topic_codes_str.items()
}
for topic, codes in topic_codes.items():
    for code in codes:
        code.cb_at = cb_model.config.codebook_at[0]
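The topic strings above use a simple line format, `Code: <index>, Layer: <layer>, Head: <head>`. The sketch below shows one way to parse that format with a regular expression into plain records. `CodeRef` and `parse_code_lines` are hypothetical names for illustration only; the library's `cb_utils.parse_topic_codes_string` returns its own `CodeInfo` objects and may handle more fields.

```python
import re
from dataclasses import dataclass

LINE_RE = re.compile(r"Code:\s*(\d+),\s*Layer:\s*(\d+),\s*Head:\s*(\d+)")


@dataclass(frozen=True)
class CodeRef:
    """Hypothetical record for one topic code (not the library's CodeInfo)."""

    code: int
    layer: int
    head: int


def parse_code_lines(text: str) -> list[CodeRef]:
    # Lines that don't match the format (including blank ones) produce no records.
    return [CodeRef(*map(int, m.groups())) for m in LINE_RE.finditer(text)]


example = """
Code: 4095, Layer: 0, Head: 1
Code: 2179, Layer: 0, Head: 1
Code: 1331, Layer: 0, Head: 14
"""
refs = parse_code_lines(example)
heads_used = sorted({r.head for r in refs})
print(len(refs), heads_used)
```

Grouping the parsed records by head shows at a glance which attention heads a topic's codes come from, as with the 'slide' codes spread across heads 1 and 14.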

#### Code Intervention

Now we perform the code intervention for a specific topic using the `generate_with_codes` function.

In [7]:
# specify the topic you want the generations to steer towards
topic = "baby"

In [8]:
# CodeInfo objects hold a code's associated metadata (e.g. position in the network)
list_of_code_infos = topic_codes[topic]

text_input = "Once upon a time,"
inp_tensor = cb_model.to_tokens(text_input, prepend_bos=True).to(device)
inp_tensor = inp_tensor.repeat(10, 1)
gen = cb_utils.generate_with_codes(
    inp_tensor,
    cb_model,
    list_of_code_infos=list_of_code_infos,
    generate_kwargs={"max_new_tokens": 200, "do_sample": True, "temperature": 1},
)
gen = [tokenizer.decode(g[1:]) for g in gen]
for i, g in enumerate(gen):
    print(f"Story {i}:")
    print(g)
    print("************************************")


Story 0:
Once upon a time, there was a little girl who loved Halloween. She was only 3 years old, but her parents were so scared. Every day she would go to the shop to buy a new fyer rug. It was blue, and it smelled so sweet.

One day, Baby's mom said to the 3 year old, "OK, let's go to the store."

So the mommy took the blue toilet out. They were yummy. But when the little three year old found the soap aisle. Mommy saw it and knew it was small and it had a special request. He asked her mum, "Do you understand the law, baby?"

The little 3 year old girl replied, "Yes! I love it!"

Mum looked at the baby and kissed the baby. Soon, the baby had a new toy - a chair and real and the mommy smiled. This was the most fun-looking teddy loved and laughed out back home!Once upon
************************************
Story 1:
Once upon a time, there was a little girl named Lily. She was very happy because her parents were always praising her for her. One day, she found a new book with another sill
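The `generate_kwargs` above request sampling (`do_sample=True`) at `temperature=1`, so each new token is drawn from the softmax of the logits divided by the temperature instead of being chosen greedily. Assuming no additional top-k or top-p filtering is applied, that sampling step can be sketched in plain Python; `softmax_with_temperature` and `sample_next_token` are hypothetical helpers, not part of the library.

```python
import math
import random


def softmax_with_temperature(logits: list[float], temperature: float) -> list[float]:
    """Softmax of logits / temperature (max-subtracted for numerical stability)."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_next_token(logits: list[float], temperature: float, rng: random.Random) -> int:
    # Draw one token index according to the temperature-scaled distribution.
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.1]
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])

rng = random.Random(0)
print([sample_next_token(logits, 1.0, rng) for _ in range(10)])
```

Lower temperatures concentrate probability on the highest logit and higher temperatures flatten the distribution; with sampling enabled, rerunning the intervention cell produces different stories each time.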




#### Quantitative Evaluation of Topic Steering

Here we quantitatively evaluate topic steering by measuring the fraction of generated texts that contain the topic string.

For each topic, we generate 10 samples per prompt with the topic codes patched in, and measure the fraction of those samples that contain the topic string. The topic string is used as a regular expression, so `prince|crown|king|castle` matches any one of those words. Note that this is an imperfect evaluation: the model may generate text related to the topic without including the topic string itself (e.g. 'babies' does not contain 'baby').
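The matching rule is easy to check in isolation. The snippet below reproduces the evaluation's test (`re.search` on lowercased strings) in a hypothetical helper `mentions_topic`, and shows that plain keys match as substrings (`fire` inside `fireplace`) while `babies` does not match `baby`. A stricter whole-word variant, `mentions_topic_whole_word`, is included only for comparison; the tutorial does not use it.

```python
import re


def mentions_topic(topic: str, completion: str) -> bool:
    # Same test as the evaluation loop: regex search on lowercased strings.
    return re.search(topic.lower(), completion.lower()) is not None


def mentions_topic_whole_word(topic: str, completion: str) -> bool:
    # Stricter variant: only whole-word matches of any alternative count.
    return re.search(rf"\b(?:{topic.lower()})\b", completion.lower()) is not None


print(mentions_topic("prince|crown|king|castle", "The King waved."))
print(mentions_topic("fire", "They sat by the fireplace."))
print(mentions_topic("baby", "The babies slept."))
print(mentions_topic("baby", "A baby bird chirped."))
print(mentions_topic_whole_word("fire", "They sat by the fireplace."))
```

Substring matching slightly inflates counts for short topic words, while whole-word matching misses compounds; either way, the same rule is applied to the steered and baseline completions, so the comparison between them stays fair.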

In [None]:
prompts = [
    "",
    "Once upon a time,",
    "Once there was a",
    "A long time ago,",
]

prompt_completions = {}
for topic in tqdm(topic_codes_str):
    list_of_code_infos = topic_codes[topic]
    prompt_completions[topic] = {}
    for prompt in prompts:
        prompt_token = cb_model.to_tokens(prompt, prepend_bos=True).to(device)
        # 10 samples per prompt: repeat the prompt along the batch dimension
        prompt_token = prompt_token.repeat(10, 1)
        gen = cb_utils.generate_with_codes(
            prompt_token,
            cb_model,
            list_of_code_infos=list_of_code_infos,
            generate_kwargs={
                "max_new_tokens": 200,
                "do_sample": True,
                "temperature": 1,
            },
        )
        # Drop the BOS token before decoding
        gen = [tokenizer.decode(g[1:]) for g in gen]
        prompt_completions[topic][prompt] = gen


In [10]:
topic_in_prompt_completion = {}
for topic in topic_codes_str:
    if not topic:
        continue
    topic_in_prompt_completion[topic] = {}
    for prompt in prompts:
        topic_in_prompt_completion[topic][prompt] = 0
        for completion in prompt_completions[topic][prompt]:
            if re.search(topic.lower(), completion.lower()):
                topic_in_prompt_completion[topic][prompt] += 1
        topic_in_prompt_completion[topic][prompt] /= len(
            prompt_completions[topic][prompt]
        )

topic_in_prompt_completion_avg = {}
for topic in topic_codes_str:
    if not topic:
        continue
    topic_in_prompt_completion_avg[topic] = 0
    for prompt in prompts:
        topic_in_prompt_completion_avg[topic] += topic_in_prompt_completion[topic][
            prompt
        ]
    topic_in_prompt_completion_avg[topic] /= len(prompts)
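The nested loops above compute, for each topic, the fraction of matching completions per prompt and then average those fractions across prompts. Since every prompt has the same number of samples, this equals the overall fraction of matching completions. A more compact version of the same computation is sketched below; `mention_rate` and the toy completions are hypothetical and only illustrate the arithmetic.

```python
import re
from statistics import mean


def mention_rate(topic: str, completions_by_prompt: dict[str, list[str]]) -> float:
    """Average over prompts of the fraction of completions that match `topic`."""
    pattern = re.compile(topic.lower())
    per_prompt = [
        sum(bool(pattern.search(c.lower())) for c in completions) / len(completions)
        for completions in completions_by_prompt.values()
    ]
    return mean(per_prompt)


toy = {
    "Once upon a time,": ["a baby cried", "a dog ran", "the baby laughed", "a cat sat"],
    "A long time ago,": ["a baby bird", "a tree grew", "a boat sailed", "a kite flew"],
}
print(mention_rate("baby", toy))
```

Here the first prompt scores 2/4 and the second 1/4, so the average is 0.375, the same as 3 matches out of 8 completions overall.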

We also compute a baseline: for each topic, the fraction of the unsteered completions (10 samples per prompt, generated with no topic codes patched in) that contain the topic string. This tells us how often each topic is mentioned by default.

In [11]:
orig_prompts = prompt_completions[""]
topic_in_orig_prompt_completion = {}

for topic in topic_codes_str:
    if not topic:
        continue
    topic_in_orig_prompt_completion[topic] = {}
    for prompt in prompts:
        topic_in_orig_prompt_completion[topic][prompt] = 0
        for completion in orig_prompts[prompt]:
            if re.search(topic.lower(), completion.lower()):
                topic_in_orig_prompt_completion[topic][prompt] += 1
        topic_in_orig_prompt_completion[topic][prompt] /= len(orig_prompts[prompt])

topic_in_orig_prompt_completion_avg = {}
for topic in topic_codes_str:
    if not topic:
        continue
    topic_in_orig_prompt_completion_avg[topic] = 0
    for prompt in prompts:
        topic_in_orig_prompt_completion_avg[topic] += topic_in_orig_prompt_completion[
            topic
        ][prompt]
    topic_in_orig_prompt_completion_avg[topic] /= len(prompts)

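Comparing the two tables below, the fraction of completions containing the topic string is higher for every topic when we patch in the topic codes, often dramatically (e.g. 'fire' rises from 2.5% to 100%). The smallest gain is for 'friend', which already appears in half of the unsteered completions.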

In [12]:
print("Baseline (no topic steering):")
print(f"{'Topic':<28}Topic mentioned (%)")
for topic, frac in topic_in_orig_prompt_completion_avg.items():
    print(f"{topic:<28}{frac * 100:.1f}")

print()
print()

print("Topic steering with code interventions:")
print(f"{'Topic':<28}Topic mentioned (%)")
for topic, frac in topic_in_prompt_completion_avg.items():
    print(f"{topic:<28}{frac * 100:.1f}")

Baseline (no topic steering):
Topic                       Topic mentioned (%)
dragon                      0.0
slide                       0.0
friend                      50.0
flower                      7.5
fire                        2.5
prince|crown|king|castle    37.5
baby                        2.5


Topic steering with code interventions:
Topic                       Topic mentioned (%)
dragon                      47.5
slide                       77.5
friend                      70.0
flower                      92.5
fire                        100.0
prince|crown|king|castle    82.5
baby                        95.0
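Each score in these tables averages over 4 prompts × 10 samples = 40 completions, so every percentage is a multiple of 100/40 = 2.5 points: a single matching completion moves a score by 2.5. The short script below takes the reported percentages (copied by hand from the output above, not recomputed from the model) and prints the per-topic gain alongside the underlying counts out of 40. This makes it easier to judge how many completions each gain rests on.

```python
# Percentages copied from the printed tables above: (baseline, steered).
results = {
    "dragon": (0.0, 47.5),
    "slide": (0.0, 77.5),
    "friend": (50.0, 70.0),
    "flower": (7.5, 92.5),
    "fire": (2.5, 100.0),
    "prince|crown|king|castle": (37.5, 82.5),
    "baby": (2.5, 95.0),
}

N_COMPLETIONS = 4 * 10  # 4 prompts x 10 samples per topic

for topic, (baseline, steered) in results.items():
    gain = steered - baseline
    # Convert percentages back to counts of matching completions out of 40.
    base_count = round(baseline / 100 * N_COMPLETIONS)
    steered_count = round(steered / 100 * N_COMPLETIONS)
    print(
        f"{topic:<28}{baseline:>6.1f}{steered:>8.1f}{gain:>+8.1f}"
        f"   ({base_count} -> {steered_count} of {N_COMPLETIONS})"
    )
```

With only 40 completions per topic, differences of a few points correspond to one or two stories, so large gains such as 'fire' (1 to 40 of 40) are far more informative than small ones.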
