# Intro to activation addition with PyTorch hooks

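This notebook shows how to read and modify a model's internal activations with PyTorch forward hooks, and how to use this for activation addition: steering generation by adding a direction vector to the activations.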

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
dtype = torch.bfloat16
# pad on the left, so that position -1 holds the last real (non-padding) token of every sequence in a batch
padding_side = "left"

Using device: cuda
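The following small sketch shows why left padding matters. It is plain PyTorch with no tokenizer; the token ids and the pad id 0 are made up for illustration, and `pad_batch` is a hypothetical helper. With left padding, index -1 holds the last real token of every sequence. With right padding, shorter sequences end in a pad token.

```python
import torch

def pad_batch(seqs, pad_id, side):
    """Pad a list of token-id lists to equal length on the given side ("left" or "right")."""
    max_len = max(len(s) for s in seqs)
    rows = []
    for s in seqs:
        pad = [pad_id] * (max_len - len(s))
        rows.append(pad + s if side == "left" else s + pad)
    return torch.tensor(rows)

# hypothetical token ids; 0 plays the role of the pad token
seqs = [[5, 6, 7, 8], [9, 10]]
left = pad_batch(seqs, pad_id=0, side="left")
right = pad_batch(seqs, pad_id=0, side="right")

print(left[:, -1])   # last real tokens: 8 and 10
print(right[:, -1])  # 8 and 0: the shorter sequence ends in padding
```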


In [70]:
# load model
model_name = "meta-llama/Llama-2-7b-chat-hf"
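# load the weights directly in bfloat16 (instead of float32 followed by a cast) to halve memory use
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)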
# load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = padding_side

# Llama-2 has no padding token, so reuse the end-of-sequence token for padding
tokenizer.pad_token = tokenizer.eos_token


### Access hidden activations
We can access the hidden activations of the residual stream by passing the keyword argument `output_hidden_states=True` to the forward pass. To get other activations (for example inside the attention or MLP blocks), we need a way to hook into the corresponding submodules.

In [60]:
test_sentence = "The quick brown fox jumps over the lazy dog"

model.eval()
with torch.no_grad():
    inputs = tokenizer(test_sentence, return_tensors="pt").to(device)
    output = model(**inputs, output_hidden_states=True)
    print(output.keys())

odict_keys(['logits', 'past_key_values', 'hidden_states'])


In [6]:
print(f"len hidden states: {len(output.hidden_states)}")
print(f"shape of hidden states: {output.hidden_states[0].shape}")

len hidden states: 33
shape of hidden states: torch.Size([1, 12, 4096])


In [7]:
# the residual-stream output of decoder layer `layer_id` is output.hidden_states[layer_id + 1],
# because output.hidden_states[0] is the output of the embedding layer
layer_id = 5
hidden_states = output.hidden_states[layer_id + 1]
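The off-by-one in `hidden_states[layer_id + 1]` can be checked on a toy model. This sketch uses a made-up `ToyStack` module (not a Hugging Face class) that returns its hidden states in the same order as `output_hidden_states=True`: the embedding output first, then one entry per layer. One caveat about the real model: as far as we know, in Hugging Face's Llama code the last entry is taken after the final norm, so it is not the raw output of the last decoder layer. Check this against your `transformers` version.

```python
import torch
import torch.nn as nn

class ToyStack(nn.Module):
    """Toy stand-in for a decoder: embedding followed by a stack of residual layers.
    Index 0 of the returned tuple is the embedding output, index i + 1 is the output of layer i."""
    def __init__(self, vocab=50, dim=8, n_layers=4):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))

    def forward(self, ids):
        h = self.embed(ids)
        hidden_states = [h]
        for layer in self.layers:
            h = h + layer(h)  # residual update
            hidden_states.append(h)
        return tuple(hidden_states)

torch.manual_seed(0)
model_toy = ToyStack()
ids = torch.tensor([[1, 2, 3]])
layer_idx = 2

with torch.no_grad():
    hs = model_toy(ids)
    # recompute the output of layer `layer_idx` by hand
    h = model_toy.embed(ids)
    for layer in model_toy.layers[: layer_idx + 1]:
        h = h + layer(h)

print(len(hs))                                  # 5 = n_layers + 1
print(torch.allclose(hs[layer_idx + 1], h))     # True
```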

# Accessing hidden states with PyTorch hooks
Passing `output_hidden_states=True` to the forward pass returns ALL hidden states. Instead, we can hook into one specific module and save only its activations.

To do that, we define a PyTorch forward hook: a function `hook(module, input, output)` that PyTorch calls right after the module's `forward`. Here the hook copies the module's output into a pre-allocated tensor.
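Here is a minimal sketch of the hook signature on a plain `nn.Linear`, using nothing model-specific. `inputs` is a tuple of the positional arguments passed to `forward`, and `output` is whatever `forward` returned. If the hook returns `None`, the output is left unchanged. Keyword arguments are not passed to the hook by default. As far as we know, recent PyTorch versions (2.0+) accept `with_kwargs=True` to change that.

```python
import torch
import torch.nn as nn

captured = {}

def inspect_hook(module, inputs, output):
    # inputs: tuple of positional args given to forward; output: forward's return value
    captured["module"] = module
    captured["inputs"] = inputs
    captured["output"] = output.detach().clone()
    # returning None keeps the module's output as it is

torch.manual_seed(0)
lin = nn.Linear(4, 3)
handle = lin.register_forward_hook(inspect_hook)
x = torch.randn(2, 4)
y = lin(x)
handle.remove()  # detach the hook once we are done
```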

In [10]:
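def cache_activations(cache):
    # returns a forward hook that copies the module's output into the pre-allocated tensor `cache`
    def hook(module, input, output):
        # decoder layers may return a tuple whose first element is the hidden states
        if isinstance(output, tuple):
            cache[:] = output[0]
        else:
            cache[:] = output
    return hook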
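Why does the hook write `cache[:] = ...` rather than `cache = ...`? The toy comparison below illustrates this on a plain `nn.Linear`; both hook factories are hypothetical names. Rebinding only changes the hook's local variable, while the slice assignment writes into the tensor the caller holds. The buffer must therefore already have the right shape. The buffer's dtype wins: for example, a float32 buffer silently upcasts bfloat16 activations.

```python
import torch
import torch.nn as nn

def cache_by_rebinding(cache):
    def hook(module, inputs, output):
        cache = output  # only rebinds a local name; the caller's tensor is untouched
    return hook

def cache_in_place(cache):
    def hook(module, inputs, output):
        cache[:] = output  # writes into the caller's pre-allocated buffer
    return hook

torch.manual_seed(0)
lin = nn.Linear(4, 4)
x = torch.randn(1, 3, 4)

buf_rebind = torch.zeros(1, 3, 4)
buf_inplace = torch.zeros(1, 3, 4)
h1 = lin.register_forward_hook(cache_by_rebinding(buf_rebind))
h2 = lin.register_forward_hook(cache_in_place(buf_inplace))
with torch.no_grad():
    y = lin(x)
h1.remove()
h2.remove()

print(buf_rebind.abs().sum().item())  # 0.0: still all zeros
print(torch.equal(buf_inplace, y))    # True
```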

In [11]:
cached_activations = torch.zeros_like(hidden_states)
hook_handle = model.model.layers[layer_id].register_forward_hook(cache_activations(cached_activations))

with torch.no_grad():
    inputs = tokenizer(test_sentence, return_tensors="pt").to(device)
    output = model(**inputs)

hook_handle.remove()

print(f"MSE between cached_activations and hidden_states: {(hidden_states-cached_activations).pow(2).mean()}")

MSE between cached_activations and hidden_states: 0.0


### Accessing other activations, for example inside the attention block
If we do not know the shape of an activation in advance, we can simply append it to a list. Here we cache the output of the query projection (`q_proj`) in the self-attention block of layer `layer_id`.
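A list-based cache also handles hooks that fire several times, each time with a different shape. This happens during `generate`, which calls the model once per new token. This is a toy sketch on an `nn.Linear`, and `cache_to_list` is a hypothetical name. It detaches each tensor so the cache does not keep autograd graphs alive. The two calls only mimic a prompt pass followed by a single-token step.

```python
import torch
import torch.nn as nn

def cache_to_list(cache):
    def hook(module, inputs, output):
        out = output[0] if isinstance(output, tuple) else output
        # detach so the cached tensor does not keep the autograd graph alive
        cache.append(out.detach())
    return hook

torch.manual_seed(0)
lin = nn.Linear(4, 4)
cache = []
handle = lin.register_forward_hook(cache_to_list(cache))
lin(torch.randn(1, 5, 4))  # e.g. the whole prompt
lin(torch.randn(1, 1, 4))  # e.g. one more decoding step
handle.remove()

print([tuple(t.shape) for t in cache])  # one entry per forward call
```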

In [84]:
# define our activation cache hook
def cache_activations_list(cache):
    def hook(module, input, output):
        # some modules (e.g. the Llama decoder layer) return a tuple whose first element is the hidden states
        if isinstance(output, tuple):
            cache.append(output[0])
        else:
            cache.append(output)
    return hook

cached_activations = []
hook_handle = model.model.layers[layer_id].self_attn.q_proj.register_forward_hook(cache_activations_list(cached_activations))

with torch.no_grad():
    inputs = tokenizer(test_sentence, return_tensors="pt").to(device)
    output = model(**inputs)

hook_handle.remove()
print(f"shape of cached activations: {cached_activations[0].shape}")

shape of cached activations: torch.Size([1, 12, 4096])
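The `q_proj` output above still has all attention heads concatenated along the last dimension. For Llama-2-7b, `model.config.num_attention_heads` is 32 with a head dimension of 128, and 32 × 128 = 4096. To look at individual heads, you can reshape it as in the sketch below, which uses a random stand-in tensor instead of the real cached activation. As far as we know, Hugging Face's Llama applies the rotary position embedding after this projection, so these are pre-RoPE queries.

```python
import torch

# random stand-in for the cached q_proj output: [batch, num_tokens, num_heads * head_dim]
batch, num_tokens, num_heads, head_dim = 1, 12, 32, 128
q = torch.randn(batch, num_tokens, num_heads * head_dim)

# split the last dimension into heads -> [batch, num_heads, num_tokens, head_dim]
q_heads = q.view(batch, num_tokens, num_heads, head_dim).transpose(1, 2)
print(q_heads.shape)  # torch.Size([1, 32, 12, 128])
```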


## Activation addition with PyTorch hooks

We define a direction vector and add it to the model's internal activations while the model generates new text.

In [53]:
# define our addition hook
def add_to_activations(toadd):
    def hook(module, input, output):
        # some modules (e.g. the Llama decoder layer) return a tuple whose first element is the hidden states
        if isinstance(output, tuple):
            # tuples are immutable, so we build a new one; when a forward hook returns a value,
            # PyTorch uses it in place of the module's original output
            return (output[0] + toadd,) + output[1:]
        else:
            return output + toadd
    return hook
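The following toy check shows that the returned value replaces the output and that the rest of a tuple output passes through unchanged. `TupleLayer` is a made-up module that mimics a decoder layer returning `(hidden_states, extra)`. `add_vector_hook` has the same logic as `add_to_activations` above, but uses a different name so that it does not shadow the notebook's function.

```python
import torch
import torch.nn as nn

class TupleLayer(nn.Module):
    """Toy layer that, like some decoder layers, returns (hidden_states, extra)."""
    def __init__(self, dim):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, x):
        return self.lin(x), "extra info"

def add_vector_hook(toadd):
    def hook(module, inputs, output):
        if isinstance(output, tuple):
            return (output[0] + toadd,) + output[1:]
        return output + toadd
    return hook

torch.manual_seed(0)
layer = TupleLayer(4)
x = torch.randn(2, 3, 4)
shift = torch.ones(1, 1, 4)

with torch.no_grad():
    plain, _ = layer(x)
    handle = layer.register_forward_hook(add_vector_hook(shift))
    steered, extra = layer(x)
    handle.remove()

print(torch.allclose(steered, plain + 1))  # True
print(extra)  # "extra info": the other tuple entries are passed through
```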

In [80]:
def run_with_cache(model, tokenizer, sentences, device="cuda", hidden_size=model.config.hidden_size,
                   module_to_hook=model.model.layers[layer_id], hook_fun=cache_activations):
    # note: default arguments are evaluated once, when the function is defined
    with torch.no_grad():
        inputs = tokenizer(sentences, return_tensors="pt", padding=True).to(device)
        # buffer of shape [batch_size, num_tokens, hidden_size] that the hook fills in place
        cached_activations = torch.zeros(inputs["input_ids"].shape + (hidden_size,), device=device)
        hook_handle = module_to_hook.register_forward_hook(hook_fun(cached_activations))
        try:
            model(**inputs)
        finally:
            # remove the hook even if the forward pass raises an error
            hook_handle.remove()

    return cached_activations

def generate_with_aa(model, tokenizer, sentences, direction, max_new_tokens=20, device="cuda", random_seed=0,
                     hidden_size=model.config.hidden_size, module_to_hook=model.model.layers[layer_id], hook_fun=add_to_activations):
    # generate while `direction` is added to the output of `module_to_hook` ("aa" = activation addition);
    # hidden_size is not used here
    hook_handle = module_to_hook.register_forward_hook(hook_fun(direction))
    # fix the seed: depending on the model's generation config, generate() may sample
    torch.random.manual_seed(random_seed)
    try:
        with torch.no_grad():
            inputs = tokenizer(sentences, return_tensors="pt", padding=True).to(device)
            generate_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True)
            generated_text = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    finally:
        # always remove the hook, otherwise it keeps steering every later forward pass
        hook_handle.remove()
    return generated_text

def generate(model, tokenizer, sentences, direction=None, max_new_tokens=20, device="cuda", random_seed=0):
    # baseline generation without any hook; `direction` is ignored
    torch.random.manual_seed(random_seed)

    with torch.no_grad():
        inputs = tokenizer(sentences, return_tensors="pt", padding=True).to(device)
        generate_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, use_cache=True)
        generated_text = tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    return generated_text


In [86]:
sentences = ["Love", "Hate"]

cached_activations = run_with_cache(model, tokenizer, sentences, device, 
                hidden_size=model.config.hidden_size, 
                module_to_hook=model.model.layers[layer_id], 
                hook_fun=cache_activations)

print(f"shape cached_activations: {cached_activations.shape}")

shape cached_activations: torch.Size([2, 3, 4096])


In [88]:
token_pos = -1
# it is convenient to give the direction the same number of dimensions as the activations,
# which usually have shape [batch_size, num_tokens, hidden_dim].
# The simplest choice is a direction of shape [1, 1, hidden_dim]: by broadcasting,
# it is added at every token position of every sequence in the batch.
# With left padding, token_pos = -1 is the last real token of each prompt.
direction = cached_activations[0:1, token_pos:, :] - cached_activations[1:, token_pos:, :]
# make sure the direction is on the same device and has the same dtype as the model
# (cached_activations is float32, while the model runs in bfloat16)
direction = direction.to(device=device, dtype=dtype)
print(f"shape of direction: {direction.shape}")
print(f"norm of direction:  {direction.norm(dim=-1).item():.4g}")

shape of direction: torch.Size([1, 1, 4096])
norm of direction:  8.688
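The shape bookkeeping above can be checked with random stand-in tensors (no model involved). Slicing with `token_pos:` keeps a length-1 token axis, but only because `token_pos` is -1; a value like -2 would keep two tokens. The resulting `[1, 1, hidden]` vector broadcasts over every batch element and position. Integer indexing drops the token axis and gives a `[hidden]` vector, which also broadcasts here; the 3-D form just makes the intent explicit.

```python
import torch

torch.manual_seed(0)
num_tokens, hidden = 5, 8
acts = torch.randn(2, num_tokens, hidden)  # stand-in for the cached activations of two prompts
token_pos = -1

# slicing keeps the token axis -> shape [1, 1, hidden]
direction = acts[0:1, token_pos:, :] - acts[1:2, token_pos:, :]
print(direction.shape)  # torch.Size([1, 1, 8])

# broadcasting: the same vector is added at every batch element and every position
batch = 3
prompt_acts = torch.randn(batch, num_tokens, hidden)
steered = prompt_acts + direction
print(steered.shape)  # torch.Size([3, 5, 8])

# integer indexing drops the token axis -> shape [hidden], which broadcasts the same way
flat = acts[0, token_pos] - acts[1, token_pos]
print(flat.shape)  # torch.Size([8])
```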


In [90]:
sentences = ["I think dogs are", "I think cats are", "Today I feel"]
random_seed = 0

generated_love = generate_with_aa(model, tokenizer, sentences, direction, max_new_tokens=20,
                                    device="cuda", random_seed=random_seed, hidden_size=model.config.hidden_size, 
                                    module_to_hook=model.model.layers[layer_id], hook_fun=add_to_activations)

print("Generate with positive direction:\n")
for sentence in generated_love:
    print(sentence)
    print("---")


generated_hate = generate_with_aa(model, tokenizer, sentences, -direction, max_new_tokens=20,
                                    device="cuda", random_seed=random_seed, hidden_size=model.config.hidden_size, 
                                    module_to_hook=model.model.layers[layer_id], hook_fun=add_to_activations)

print("\nGenerate with negative direction:\n")
for sentence in generated_hate:
    print(sentence)
    print("---")


generated_neutral = generate(model, tokenizer, sentences, -direction, max_new_tokens=20,
                                    device="cuda", random_seed=random_seed)

print("\nGenerate without vector addition:\n")
for sentence in generated_neutral:
    print(sentence)
    print("---")

Generate with positive direction:

I think dogs are amazing companions and I'm excited to be a part of the dog world! I'
---
I think cats are one of the most interesting creatures on earth. They have a way of bringing joy and companionship
---
Today I feel like I'm in a relationship with my best friend, my partner, my soulmate.

---

Generate with negative direction:

I think dogs are pretty cool. I'm glad I'm not one of those people who has to be around
---
I think cats are evil, but I’m not going to go around and say that all cats are evil.
---
Today I feel like a total bum. I hate everything and everyone. I hate the world and everyone in it
---

Generate without vector addition:

I think dogs are awesome! They are loyal, loving, and always happy to see you. They are also very
---
I think cats are really interesting creatures. They have a lot of personality and can be very affectionate. I
---
Today I feel like I'm in a rut. I'm not sure why, but I just feel
---
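The hook stays registered for the whole `generate` call, and this affects every new token. As far as we understand Hugging Face generation with `use_cache=True`, the first forward pass processes the whole prompt, and each later pass processes only the newly produced token. The hook therefore fires once per step, and a `[1, 1, hidden]` direction broadcasts against both `[batch, num_tokens, hidden]` and `[batch, 1, hidden]`. The toy greedy loop below only mimics those input shapes. It has no attention and no real KV cache, and all module names are made up.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab, hidden = 20, 8
embed = nn.Embedding(vocab, hidden)
block = nn.Linear(hidden, hidden)  # the module we steer
unembed = nn.Linear(hidden, vocab)

direction = torch.randn(1, 1, hidden)
seen_shapes = []

def add_and_log(module, inputs, output):
    seen_shapes.append(tuple(output.shape))
    return output + direction  # broadcasts over batch and token positions

handle = block.register_forward_hook(add_and_log)
ids = torch.tensor([[1, 2, 3]])
try:
    with torch.no_grad():
        step_input = ids  # first step: the whole prompt
        for _ in range(4):
            logits = unembed(block(embed(step_input)))
            next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
            ids = torch.cat([ids, next_id], dim=1)
            step_input = next_id  # later steps: only the new token, as with a KV cache
finally:
    handle.remove()

print(seen_shapes)  # [(1, 3, 8), (1, 1, 8), (1, 1, 8), (1, 1, 8)]
```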


# Notes
* Keep track of the hook handles. It is easy (especially while developing) to lose a handle, e.g. by re-running a cell, and then there is no clean way to remove the hook other than reloading the model. A small hook-handle tracker (a class, a context manager, or a decorator) helps here.
* Make sure that tensors are on the same device and use the same precision (dtype).
* Model implementations differ a lot, so you may have to adapt attribute names like `model.config.hidden_size` or module paths like `model.model.layers[layer_id]` to your model.
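Here is one way to build the hook-handle tracker suggested in the first note. This is a sketch, and `HookManager`, `hooked`, and `remove_all_forward_hooks` are hypothetical helpers, not part of PyTorch or `transformers`. The class and the context manager remove their hooks on exit, even if the body raises. The last function is a fallback for when a handle is already lost. It relies on the private `_forward_hooks` dict that PyTorch keeps on every module, so it may break between versions.

```python
import contextlib
import torch
import torch.nn as nn

class HookManager:
    """Hypothetical tracker: remembers every handle it registers so they can all be removed."""
    def __init__(self):
        self.handles = []

    def add(self, module, hook):
        handle = module.register_forward_hook(hook)
        self.handles.append(handle)
        return handle

    def remove_all(self):
        for handle in self.handles:
            handle.remove()
        self.handles.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove_all()  # runs even if the with-body raised
        return False       # do not swallow exceptions

@contextlib.contextmanager
def hooked(module, hook):
    """Hypothetical one-hook variant: the hook is active only inside the with-block."""
    handle = module.register_forward_hook(hook)
    try:
        yield handle
    finally:
        handle.remove()

def remove_all_forward_hooks(model):
    """Last resort for lost handles; uses PyTorch's private `_forward_hooks` dict."""
    for m in model.modules():
        m._forward_hooks.clear()
```

With the notebook's helpers, usage could look like `with hooked(model.model.layers[layer_id], add_to_activations(direction)): ...`, with `model.generate` called inside the block.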