# Activation steering with TransformerLens and gpt2-xl

This notebook shows how to access and modify internal model activations using the TransformerLens library.


In [4]:
import torch
from transformer_lens import HookedTransformer

In [5]:
# Prefer CUDA when PyTorch reports it as available; otherwise run on the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(f"Using device: {device}")

Using device: cpu


In [29]:
# Load gpt2-xl as a TransformerLens HookedTransformer and switch it to eval mode.
# `device` is passed explicitly so this model is placed on the same device as the
# Hugging Face model loaded further down (which is moved with `.to(device)`);
# otherwise the placement is left to the library's own default and the two models'
# activations could end up on different devices when they are compared.
model = HookedTransformer.from_pretrained_no_processing(
    "gpt2-xl",
    default_prepend_bos=False,
    device=str(device),
).eval()

Loaded pretrained model gpt2-xl into HookedTransformer


In [49]:
# Pick the block to read from and build the name of its cache entry.
# `hook_resid_post` of block `layer_id` is the activation that gets steered later on.
layer_id = 5
cache_name = f"blocks.{layer_id}.hook_resid_post"

# Run each contrast prompt once and keep only the activation stored under `cache_name`.
activations = {}
for prompt in ("Love", "Hate"):
    _, cache = model.run_with_cache(prompt)
    activations[prompt] = cache[cache_name]
act_love = activations["Love"]
act_hate = activations["Hate"]

print(f"act_love.shape: {act_love.shape}")
print(f"act_hate.shape: {act_hate.shape}")

act_love.shape: torch.Size([1, 1, 1600])
act_hate.shape: torch.Size([1, 2, 1600])


As the shapes of the activation tensors show, the two inputs are split into different numbers of tokens ("Love" into one token, "Hate" into two). To get a single vector per input, we keep only the activation at the last token position.

In [50]:
# Steering vector: last-token activation of "Love" minus last-token activation of "Hate".
# Slicing with -1: (instead of -1) keeps the token axis, so the result stays 3-dimensional.
love_last = act_love[:, -1:, :]
hate_last = act_hate[:, -1:, :]
steering_vec = love_last - hate_last
print(f"steering_vec.shape:  {steering_vec.shape}")
print(f"length steering_vec: {steering_vec.norm():.2f}")

# Rescale to unit norm so that the coefficient applied later sets the steering strength.
steering_vec = steering_vec / steering_vec.norm()

steering_vec.shape:  torch.Size([1, 1, 1600])
length steering_vec: 100.23


In [51]:
# define the activation steering function
def act_add(steering_vec):
    """Build a hook function that adds `steering_vec` to an activation.

    Args:
        steering_vec: Value added to every activation the hook receives. It is
            combined with plain `+`, so for tensors it must be broadcast-compatible
            with the activation.

    Returns:
        A function `(activation, hook) -> activation + steering_vec`.
    """
    def add_steering(activation, hook):
        # `hook` is not used here; it is part of the hook-function signature.
        # NOTE(review): the parameter name is kept as `hook` in case the caller
        # passes it by keyword - confirm against the TransformerLens docs.
        steered = activation + steering_vec
        return steered

    return add_steering

We previously used the function `run_with_cache` to get the internal activations. This function adds PyTorch hooks before running the model and removes them afterwards.
There is also the function `run_with_hooks`, for which you can set your own hook functions. However, I did not find a corresponding `generate_with_hooks` function.

If we want to generate new text, the model needs to perform a forward pass repeatedly, and we want our activation addition to happen in every forward pass. We therefore need to register a hook that performs the activation addition. After the text has been generated, it is important to remove the hook again; otherwise every later call to the model will still be steered.

In [53]:
test_sentence = "I think dogs are "

# Generate once while steering in the positive direction (coeff = 10) and once in the
# negative direction (coeff = -10), printing a separator line between the two results.
steering_coeffs = (10, -10)
for run_idx, coeff in enumerate(steering_coeffs):
    if run_idx > 0:
        print("-"*20)
    # The hook stays registered across all forward passes that `generate` performs.
    model.add_hook(name=cache_name, hook=act_add(coeff*steering_vec))
    try:
        print(model.generate(test_sentence, max_new_tokens=10, do_sample=False))
    finally:
        # Always remove the hook, even if generation fails or is interrupted,
        # so that later cells are not silently steered.
        model.reset_hooks()

100%|██████████| 10/10 [00:03<00:00,  3.04it/s]


I think dogs are  a great way to get to know someone.
--------------------


100%|██████████| 10/10 [00:03<00:00,  3.25it/s]

I think dogs are icky, but I don't think they're 





In [48]:
# Baseline for comparison: same prompt and greedy decoding, but no hook is added in this cell.
baseline_text = model.generate(test_sentence, max_new_tokens=10, do_sample=False)
print(baseline_text)

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:02<00:00,  3.38it/s]

I think dogs are  a great way to get your dog to learn





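The output of the HookedTransformer model can differ from the output of the base (Hugging Face) model. The cells below check this by comparing the generated text and the internal activations of the two models.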

In [16]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from baukit import Trace

# Reference model: the Hugging Face gpt2-xl checkpoint, moved to `device` and put in eval mode.
org_model = AutoModelForCausalLM.from_pretrained("gpt2-xl")
org_model = org_model.to(device).eval()

# Tokenizer belonging to the same checkpoint.
# NOTE(review): `add_bos_token=True` is requested here, yet the activations compared
# further down line up with a TransformerLens model loaded with default_prepend_bos=False
# - confirm whether this flag actually prepends a token with this tokenizer class.
org_tokenizer = AutoTokenizer.from_pretrained("gpt2-xl", add_bos_token=True)

In [17]:
# Unsteered greedy generation with the Hugging Face model, using the same prompt.
inputs = org_tokenizer(test_sentence, return_tensors="pt").to(device)
generated_ids = org_model.generate(
    **inputs,
    max_new_tokens=10,
    do_sample=False,
    pad_token_id=org_tokenizer.eos_token_id,
)
generated_text = org_tokenizer.batch_decode(generated_ids)
# Only one prompt was passed in, so show the first (and only) decoded sequence.
print(generated_text[0])

I think dogs are  a great way to get your dog to learn


In [54]:
# Compare internal model activations between the HookedTransformer model and the original

In [55]:
# Layer whose output is compared, and the sentence used for the comparison.
layer_id = 5
test_sentence = "The quick brown fox jumps over the lazy dog."
module = org_model.transformer.h[layer_id]

inputs = org_tokenizer(test_sentence, return_tensors="pt").to(device)

# Route 1: record the output of the chosen block with baukit's Trace during a forward pass.
with Trace(module) as cache:
    _ = org_model(**inputs)
    # The traced output is indexed with [0] to get the activation tensor.
    baukit_activation = cache.output[0]

# Route 2: ask the model for all hidden states and pick the one after block `layer_id`.
model_output = org_model(**inputs, output_hidden_states=True)
org_activation = model_output["hidden_states"][layer_id+1] # hidden_states[0] are the input embeddings

# Both routes should yield the same tensor, i.e. a mean squared error of zero.
print(f"mse org_activation, baukit_activation: {(baukit_activation-org_activation).pow(2).mean()}")


mse org_activation, baukit_activation: 0.0


In [56]:
# Same layer and same sentence, now through the TransformerLens model.
# A single forward pass is enough; its cache is indexed by the hook-point name.
cache_name = f"blocks.{layer_id}.hook_resid_post"
_, cache = model.run_with_cache(test_sentence)
tl_activation = cache[cache_name]


In [57]:
# Compare the Hugging Face hidden state with the TransformerLens activation.
# The shape check guards against silent broadcasting: if the two models tokenized the
# sentence differently, the subtraction could still run and report a meaningless number.
assert org_activation.shape == tl_activation.shape, (
    f"shape mismatch: {tuple(org_activation.shape)} vs {tuple(tl_activation.shape)}"
)
print(f"mse org_activation, tl_activation: {(org_activation-tl_activation).pow(2).mean()}")

mse org_activation, tl_activation: 2.3551237457626606e-13


In [45]:
# Display the module tree of the HookedTransformer; the HookPoint entries in the
# printed tree are the places where activations can be read or modified.
model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-47): 48 x TransformerBlock(
      (ln1): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (hook_re