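# Directly editing MLP post-activation values

This notebook edits the MLP post-activation values of a small group of mGPT neurons and compares greedy completions of the same prompt in four languages, with and without the edit.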

In [1]:
# Session setup: silence warnings, disable autograd globally, and import helpers.
import warnings

# Ignore all warnings for the rest of the session. This deliberately runs
# before the third-party imports below so warnings raised while importing
# them are suppressed as well.
warnings.filterwarnings("ignore")

from functools import partial

import torch

# Inference only: turn off gradient tracking globally for every later cell.
# The object returned by set_grad_enabled is not needed and is discarded.
_ = torch.set_grad_enabled(False)

from rich.pretty import pprint
from transformer_lens import HookedTransformer, utils


# Project-local hook used below to edit neuron activations.
# NOTE(review): implementation lives in xg.interp and is not visible here —
# its exact semantics for `offset` / `neuron_group` should be confirmed there.
from xg.interp import intervene_activations

In [2]:
# Use the GPU when one is present; otherwise fall back to CPU instead of
# crashing on a hard-coded "cuda" device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = HookedTransformer.from_pretrained("ai-forever/mGPT").to(DEVICE)

# Top 4 neurons on sexual content.
# NOTE(review): presumably (layer, neuron_index) pairs as consumed by
# `intervene_activations` — confirm against xg.interp.
sexual_group = [
    (14, 5723),
    (3, 5794),
    (13, 7176),
    (1, 2583),
]

# "I want to" in 4 different languages [English, Chinese, French, Spanish]
sentences = ["I want to", "我想", "Je veux", "Yo quiero"]

Loaded pretrained model ai-forever/mGPT into HookedTransformer
Moving model to device:  cuda


For each prompt, the following cells print the original completion (without intervention) and the intervened completion with an offset of `5`. To test a different offset, change the value of `offset` in the `partial(...)` call.

In [3]:
# Hook-point names for every layer's MLP post-activation, built once as a set
# so the filter below is a cheap membership test rather than rebuilding a list
# of names on each call.
mlp_post_hook_names = {
    utils.get_act_name("mlp_post", layer) for layer in range(model.cfg.n_layers)
}


def post_activation_name_filter(name: str) -> bool:
    """Return True if `name` is an MLP post-activation hook point in any layer."""
    return name in mlp_post_hook_names


sentence = sentences[0]

# Baseline: greedy decoding with no hooks attached, so the "original"
# completion is truly unintervened (independent of how `intervene_activations`
# treats offset=0).
res_orig = model.generate(sentence, do_sample=False, max_new_tokens=10)

# Intervention: greedy decoding with `intervene_activations` (offset=5) hooked
# into every MLP post-activation hook point selected by the filter above.
hook_fn = partial(intervene_activations, offset=5, neuron_group=sexual_group)
with model.hooks(fwd_hooks=[(post_activation_name_filter, hook_fn)]):
    res_intervene = model.generate(sentence, do_sample=False, max_new_tokens=10)

# Slicing off len(sentence) characters assumes generate() returns the prompt
# text followed by the continuation as one string.
info = {
    "prompt": sentence,
    "original completion": res_orig[len(sentence) :],
    "intervened completion": res_intervene[len(sentence) :],
}
pprint(info)

100%|██████████| 10/10 [00:01<00:00,  7.87it/s]
100%|██████████| 10/10 [00:00<00:00, 20.66it/s]


In [4]:
sentence = sentences[1]

# Baseline: greedy decoding with no hooks attached, so the "original"
# completion is truly unintervened (independent of how `intervene_activations`
# treats offset=0).
res_orig = model.generate(sentence, do_sample=False, max_new_tokens=10)

# Intervention: greedy decoding with `intervene_activations` (offset=5) hooked
# into the MLP post-activation hook points selected by `post_activation_name_filter`.
hook_fn = partial(intervene_activations, offset=5, neuron_group=sexual_group)
with model.hooks(fwd_hooks=[(post_activation_name_filter, hook_fn)]):
    res_intervene = model.generate(sentence, do_sample=False, max_new_tokens=10)

# Slicing off len(sentence) characters assumes generate() returns the prompt
# text followed by the continuation as one string.
info = {
    "prompt": sentence,
    "original completion": res_orig[len(sentence) :],
    "intervened completion": res_intervene[len(sentence) :],
}
pprint(info)

100%|██████████| 10/10 [00:00<00:00, 20.16it/s]
100%|██████████| 10/10 [00:00<00:00, 19.91it/s]


In [5]:
sentence = sentences[2]

# Baseline: greedy decoding with no hooks attached, so the "original"
# completion is truly unintervened (independent of how `intervene_activations`
# treats offset=0).
res_orig = model.generate(sentence, do_sample=False, max_new_tokens=10)

# Intervention: greedy decoding with `intervene_activations` (offset=5) hooked
# into the MLP post-activation hook points selected by `post_activation_name_filter`.
hook_fn = partial(intervene_activations, offset=5, neuron_group=sexual_group)
with model.hooks(fwd_hooks=[(post_activation_name_filter, hook_fn)]):
    res_intervene = model.generate(sentence, do_sample=False, max_new_tokens=10)

# Slicing off len(sentence) characters assumes generate() returns the prompt
# text followed by the continuation as one string.
info = {
    "prompt": sentence,
    "original completion": res_orig[len(sentence) :],
    "intervened completion": res_intervene[len(sentence) :],
}
pprint(info)

100%|██████████| 10/10 [00:00<00:00, 20.20it/s]
100%|██████████| 10/10 [00:00<00:00, 20.55it/s]


In [6]:
sentence = sentences[3]

# Baseline: greedy decoding with no hooks attached, so the "original"
# completion is truly unintervened (independent of how `intervene_activations`
# treats offset=0).
res_orig = model.generate(sentence, do_sample=False, max_new_tokens=10)

# Intervention: greedy decoding with `intervene_activations` (offset=5) hooked
# into the MLP post-activation hook points selected by `post_activation_name_filter`.
hook_fn = partial(intervene_activations, offset=5, neuron_group=sexual_group)
with model.hooks(fwd_hooks=[(post_activation_name_filter, hook_fn)]):
    res_intervene = model.generate(sentence, do_sample=False, max_new_tokens=10)

# Slicing off len(sentence) characters assumes generate() returns the prompt
# text followed by the continuation as one string.
info = {
    "prompt": sentence,
    "original completion": res_orig[len(sentence) :],
    "intervened completion": res_intervene[len(sentence) :],
}
pprint(info)

100%|██████████| 10/10 [00:00<00:00, 20.56it/s]
100%|██████████| 10/10 [00:00<00:00, 20.64it/s]
