In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

import sys
sys.path.append('..')
from util import nethook
from util.generate import generate_interactive, generate_fast

from experiments.py.demo import demo_model_editing, stop_execution
from causal_trace.utils import get_model_size

torch.__version__, transformers.__version__

('2.1.0+cu121', '4.34.1')

In [3]:
# Checkpoint under study. Other checkpoints this cell has been pointed at:
#   "EleutherAI/gpt-j-6B", "mistralai/Mistral-7B-v0.1"
MODEL_PATH = "meta-llama/Llama-2-7b-hf"

# Weights are loaded first (library-default dtype) and moved to the GPU,
# then the matching tokenizer is loaded with its default settings.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    low_cpu_mem_usage=True,
).to("cuda")
tok = AutoTokenizer.from_pretrained(MODEL_PATH)

# For Mistral / Llama checkpoints, mirror two extra attribute names onto the
# config: `n_embd` copies `hidden_size` and `n_positions` is half of that.
# NOTE(review): presumably the editing utilities read these names — confirm.
# `n_positions` is derived from the hidden size here, not from the model's
# own position limit.
if any(family in model.config._name_or_path.lower() for family in ("mistral", "llama")):
    model.config.n_embd = model.config.hidden_size
    model.config.n_positions = model.config.n_embd // 2

# Reuse EOS as the padding token so that batched tokenization can pad.
tok.pad_token = tok.eos_token
model.eval()

print(f"loaded model <{MODEL_PATH}> | size: {get_model_size(model) :.3f} MB")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

loaded model <meta-llama/Llama-2-7b-hf> | size: 25833.023 MB


In [4]:
from dsets.counterfact import CounterFactDataset

counterfact = CounterFactDataset(data_dir="../counterfact")

# Keep only the records whose requested rewrite uses relation "P276".
located_in_city = list(
    filter(
        lambda record: record["requested_rewrite"]["relation_id"] == "P276",
        counterfact,
    )
)

# Reduce every kept record to a (subject, true target string) pair.
places_to_cities = [
    (rewrite["subject"], rewrite["target_true"]["str"])
    for rewrite in (record["requested_rewrite"] for record in located_in_city)
]

places_to_cities[:4]

Loaded dataset with 21919 elements


[('Inner Circle railway line', 'Melbourne'),
 ('2010 Winter Paralympics', 'Vancouver'),
 ('Hamburg International Film Festival', 'Hamburg'),
 ('PAX', 'Seattle')]

In [5]:
import names
import numpy as np

# Build `num_icl_examples` few-shot lines of the form
#   "<name> is visiting <place>, ... . <name> is in <city>."
# where each line mentions `num_options` (person, place) pairs and then asks
# about one of them.
# NOTE(review): neither numpy nor `names` is seeded here, so every run of
# this cell produces a different set of examples.
num_options = 3
num_icl_examples = 5
icl_examples = []

while len(icl_examples) < num_icl_examples:
    # Sample distinct (place, city) pairs for this example.
    cur_options = [
        places_to_cities[k] for k in
        np.random.choice(len(places_to_cities), size = num_options, replace = False)
    ]

    # Draw first names until the list itself holds `num_options` distinct
    # names. Checking only the size of the set while appending every draw
    # left repeated names in the list, so zip() below could pair the same
    # person with two different places.
    person_names = []
    while len(person_names) < num_options:
        candidate = names.get_first_name()
        if candidate not in person_names:
            person_names.append(candidate)

    example = ", ".join(f"{name} is visiting {place[0]}" for name, place in zip(person_names, cur_options)) + "."
    # Pick which person the example asks about and append the answer.
    query_idx = np.random.choice(num_options)
    example += f" {person_names[query_idx]} is in {cur_options[query_idx][1]}."
    icl_examples.append(example)


In [6]:
from memit.extra_utils import predict_next_token

# Sanity check: the 5 most likely next tokens for a plain factual prompt.
predict_next_token(
    model, tokenizer=tok,
    prompt="Eiffel Tower is located in",
    k=5,
)

[[('Paris', 0.5234667658805847, 3681),
  ('the', 0.2600758671760559, 278),
  ('France', 0.023706603795289993, 3444),
  ('which', 0.02050362154841423, 607),
  ('', 0.011508418247103691, 29871)]]

In [7]:
# Query written in the same "<name> is visiting <place>, ... <name> is in"
# shape as the generated few-shot examples; the answer is left open.
query_prompt = "Alice is visiting the Big Ben, Bob is visiting the Statue of Liberty, Conrad is visiting the Taj Mahal. Alice is in"

# Full prompt: BOS token, the few-shot examples one per line, then the query.
prompt = tok.bos_token + "\n".join(icl_examples) + "\n" + query_prompt
print(prompt)

<s>Joyce is visiting Typhoon Fran, Albert is visiting Battle of Baltimore, Woodrow is visiting Catalan self-determination referendum. Albert is in Baltimore.
William is visiting 2004 Madrid train bombings, Roy is visiting Russian Civil War, Paul is visiting Platonic Academy. William is in Madrid.
Esther is visiting Concordia University, Dee is visiting 68th Venice International Film Festival, Rose is visiting Cleveland Classic. Dee is in Venice.
Samuel is visiting Miraflores Altarpiece, Sarah is visiting Pride Toronto, Chris is visiting Old Market Square. Samuel is in Berlin.
Jamie is visiting Sinagua, Dana is visiting Shanghai International Film Festival, Jean is visiting 2006 IAAF World Indoor Championships. Dana is in Shanghai.
Alice is visiting the Big Ben, Bob is visiting the Statue of Liberty, Conrad is visiting the Taj Mahal. Alice is in


In [8]:
# Next-token predictions for the current `prompt`, using the helper's default `k`.
predict_next_token(
    model = model, tokenizer = tok,
    prompt = prompt,
)

[[('London', 0.8672000169754028, 4517),
  ('New', 0.037992458790540695, 1570),
  ('Washington', 0.012185974046587944, 7660),
  ('the', 0.007523127365857363, 278),
  ('England', 0.007257919758558273, 5408)]]

In [10]:
def untuple(x):
    """Return the first element if `x` is a tuple, otherwise `x` unchanged."""
    return x[0] if isinstance(x, tuple) else x


def intervention(intervention_layer, intervene_at, patching_vector):
    """Build an ``edit_output(layer, output)`` hook that patches one layer.

    The returned hook leaves every layer other than `intervention_layer`
    untouched. For the matching layer it overwrites, in place, the slice
    ``[:, intervene_at]`` of the (untupled) output with `patching_vector`.
    In both cases the hook returns the `output` object it was given.

    Args:
        intervention_layer: layer identifier compared (``==``) against the
            `layer` argument the hook receives.
        intervene_at: index used on the second axis of the untupled output.
        patching_vector: value assigned to the selected slice.

    Returns:
        The ``edit_output`` callable.
    """
    def edit_output(layer, output):
        if layer == intervention_layer:
            # Module outputs may be tuples; the first element is the one patched.
            hidden = untuple(output)
            hidden[:, intervene_at] = patching_vector
        return output

    return edit_output

In [11]:
# Edit request(s) handed to the editing code: each entry has a prompt template
# with a `{}` slot for the subject, the subject itself, and the new target
# string. The prompt here is prefixed with the tokenizer's BOS token.
request = [
    {
        "prompt": tok.bos_token + " {} is located in the city of",
        # "prompt": "Which city is the {} in? It is in",
        "subject": "Eiffel Tower",
        "target_new": {"str": "Seattle"},
    },
    # {
    #     "prompt": "{} is located in the city of",
    #     "subject": "Big Ben",
    #     "target_new": {"str": "Paris"},
    # },
]

# Prompts used to inspect generations before and after an edit: paraphrases of
# the edited relation plus one prompt ("is made of") about a different property.
generation_prompts = [
    "Eiffel Tower is located in the city of",
    "Eiffel Tower, which is in the city of",
    "Which city is the Eiffel Tower in? It is in",
    "Eiffel Tower is made of",
    "Eiffel Tower is in"
]

In [12]:
from memit.memit_main import get_module_input_output_at_words

# Six templates: the bare relation prompt, plus the same prompt preceded by
# five unrelated text snippets. Each template keeps a `{}` slot for the subject.
context_templates = [
    prefix + "{} is located in the city of"
    for prefix in (
        "",
        "The first step to a new life is to. ",
        "Therefore, the best way to prevent this from. ",
        "Because the first time I saw the trailer. ",
        "I'm not sure if this is the. ",
        "You are here: Home / Archives for . ",
    )
]
# One subject per template — the same subject every time.
words = ["Eiffel Tower"] * len(context_templates)

# Input and output of layer 5's `mlp.down_proj`, looked up with the
# "subject_last" strategy for every (template, word) pair.
l_input, l_output = get_module_input_output_at_words(
    model,
    tok,
    layer=5,
    context_templates=context_templates,
    words=words,
    module_template="model.layers.{}.mlp.down_proj",
    fact_token_strategy="subject_last",
)

layer=5
context_templates=['{} is located in the city of', 'The first step to a new life is to. {} is located in the city of', 'Therefore, the best way to prevent this from. {} is located in the city of', 'Because the first time I saw the trailer. {} is located in the city of', "I'm not sure if this is the. {} is located in the city of", 'You are here: Home / Archives for . {} is located in the city of']
words=['Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower']
module_template='model.layers.{}.mlp.down_proj'
fact_token_strategy='subject_last'
[([4], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([13], 'Tower')]
==> [([4], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([13], 'Tower')]
torch.Size([6, 11008]) torch.Size([6, 4096])


In [13]:
from memit.memit_hparams import MEMITHyperParams

# Load the MEMIT hyper-parameter file named after the loaded checkpoint
# ("/" in the checkpoint name becomes "_").
hparam_root = "../hparams/MEMIT"
hparam_file = model.config._name_or_path.replace("/","_") + ".json"
hparam_file = os.path.join(hparam_root, hparam_file)
# The context manager closes the file; a bare open() inside json.load() left
# the handle to be cleaned up by garbage collection.
with open(hparam_file, "r") as f:
    hparams = MEMITHyperParams(**json.load(f))

layer_no = 10
layer_name = hparams.layer_module_tmp.format(layer_no)

# One fully formatted prompt per (template, subject) pair.
prompts = [
    context_template.format(word) for context_template, word in zip(context_templates, words)
]

tokenized = tok(prompts, return_tensors="pt", padding=True).to(model.device)
# Subject-last token position for each prompt, copied by hand from the lookup
# printed by get_module_input_output_at_words. These must be refreshed whenever
# the templates, the subject or the tokenizer change.
indices = [([4], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([13], 'Tower')]
# indices = [([3], ' Tower'), ([13], ' Tower'), ([13], ' Tower'), ([12], ' Tower'), ([12], ' Tower'), ([12], ' Tower')]

# Inference only: no_grad keeps the retained activations from holding on to an
# autograd graph for the whole forward pass.
with torch.no_grad(), nethook.Trace(
    model,
    layer = layer_name,
    retain_input=True,
    retain_output=True,
) as trace:
    model(**tokenized)
# GPT-J receives its hidden states as a keyword argument, so the traced input
# is read from `input_kw` for it; other models use the positional input.
if "gpt-j" in model.config._name_or_path.lower():
    trace_input = trace.input_kw["hidden_states"]
else:
    trace_input = trace.input

print(trace_input.shape)

# Pick, for each prompt in the batch, the activation at its subject-last
# position; squeeze() drops the singleton axis left by the one-element index.
inputs = torch.stack(
    [trace_input[i, idx[0]] for i, idx in enumerate(indices)]
).squeeze()
outputs = torch.stack(
    [untuple(trace.output)[i, idx[0]] for i, idx in enumerate(indices)]
).squeeze()

# inputs.squeeze().shape, outputs.squeeze().shape


torch.Size([6, 21, 4096])


In [14]:
from memit.extra_utils import find_token_range
# Re-compute the layer input/output at the subject-last token one prompt at a
# time and compare against the batched `inputs` / `outputs`.
inputs_1_by_1 = []
outputs_1_by_1 = []
# Loop-local names are kept distinct from the notebook-level `prompt`,
# `tokenized` and `offset_mapping`, so running this cell no longer silently
# replaces the few-shot prompt that other cells read.
for single_prompt, word in zip(prompts, words):
    single_tokenized = tok(
        single_prompt, return_tensors="pt", padding=True, return_offsets_mapping=True
    ).to(model.device)
    # The offsets are only needed to locate the subject; they are removed
    # before the encoding is passed to the model.
    single_offsets = single_tokenized.pop("offset_mapping")
    start, end = find_token_range(
        single_prompt, word,
        tokenizer=tok, offset_mapping=single_offsets[0]
    )
    # `end` is exclusive, so `end-1` is the last subject token.
    print(f"subject_last={end-1} | \"{tok.decode(single_tokenized.input_ids[0][end-1])}\"")

    # Inference only: no_grad avoids keeping an autograd graph per prompt.
    with torch.no_grad(), nethook.Trace(
        model,
        layer = layer_name,
        retain_input=True,
        retain_output=True,
    ) as trace:
        model(**single_tokenized)

    if "gpt-j" in model.config._name_or_path.lower():
        trace_input = trace.input_kw["hidden_states"]
    else:
        trace_input = trace.input

    print(trace_input.shape)

    inputs_1_by_1.append(trace_input[0][end-1])
    outputs_1_by_1.append(untuple(trace.output)[0][end-1])

inputs_1_by_1 = torch.stack(inputs_1_by_1)
outputs_1_by_1 = torch.stack(outputs_1_by_1)

# inputs_1_by_1.shape, outputs_1_by_1.shape

# Batched and single-prompt activations should agree up to numerical noise.
torch.allclose(inputs, inputs_1_by_1, atol = 1e-4), torch.allclose(outputs, outputs_1_by_1, atol=1e-4)

subject_last=4 | "Tower"
torch.Size([1, 11, 4096])
subject_last=14 | "Tower"
torch.Size([1, 21, 4096])
subject_last=14 | "Tower"
torch.Size([1, 21, 4096])
subject_last=14 | "Tower"
torch.Size([1, 21, 4096])
subject_last=14 | "Tower"
torch.Size([1, 21, 4096])
subject_last=13 | "Tower"
torch.Size([1, 20, 4096])


(True, True)

In [15]:
from memit.compute_z import compute_z
from memit.memit_hparams import MEMITHyperParams
from memit.memit_main import get_context_templates

# Note: this rebinds `context_templates` to whatever get_context_templates
# returns, replacing the hand-written list used by earlier cells.
context_templates = get_context_templates(
    model, tok
)

# Load the MEMIT hyper-parameter file named after the loaded checkpoint.
hparam_root = "../hparams/MEMIT"
hparam_file = model.config._name_or_path.replace("/","_") + ".json"
hparam_file = os.path.join(hparam_root, hparam_file)
# The context manager closes the file; a bare open() inside json.load() left
# the handle to be cleaned up by garbage collection.
with open(hparam_file, "r") as f:
    hparams = MEMITHyperParams(**json.load(f))

layer_no = 10
layer_name = hparams.layer_module_tmp.format(layer_no)

# Target vector for the first edit request at layer `layer_no`; later cells
# patch it into the model's activations.
z = compute_z(
    model, tok,
    request=request[0],
    hparams=hparams,
    layer = layer_no,
    context_templates=context_templates,
)

Cached context templates [['{}'], ['The 2017-20. {}', 'Therefore, it is important to know what is. {}', 'Because the 2009-2. {}', 'I have a confession to make. I. {}', 'You are here: Home / Archives for S. {}']]
Computing right vector (v)
target_ids => 'Seattle'
[([8], 'Tower')]
Lookup index found: 8 | Sentence: 'Which city is the Eiffel Tower in? It is in' | Token: "Tower"
[([18], 'Tower')]
[([18], 'Tower')]
[([18], 'Tower')]
[([18], 'Tower')]
[([18], 'Tower')]
[([4], 'Tower')]
Rewrite layer is 10
Tying optimization objective to 31
Recording initial value of v*
loss 11.829 = 11.829 + 0.0 + 0.0 avg prob of [Seattle] 8.295106454170309e-06
loss 9.306 = 9.294 + 0.003 + 0.009 avg prob of [Seattle] 0.00010647062299540266
loss 5.811 = 5.787 + 0.01 + 0.015 avg prob of [Seattle] 0.003119693836197257
loss 2.958 = 2.923 + 0.016 + 0.019 avg prob of [Seattle] 0.05534515529870987
loss 0.937 = 0.902 + 0.015 + 0.02 avg prob of [Seattle] 0.4061203896999359
loss 0.589 = 0.554 + 0.015 + 0.02 avg prob of

In [21]:
subject = request[0]["subject"]
# prompt = "Which city {} is located in? It is in".format(subject)
# The template needs a `{}` slot: without one, .format(subject) is a no-op and
# the prompt stays hard-coded regardless of the requested subject.
prompt = "{} is located in the city of".format(subject)

# The offset mapping is only needed to find the subject's token span; it is
# removed from the encoding before that is passed to the model.
tokenized = tok(prompt, return_offsets_mapping=True, return_tensors="pt").to(model.device)
offset_mapping = tokenized.pop("offset_mapping")

# Token span of the subject inside `prompt`; later cells use `subject_end-1`
# as the last subject token.
subject_start, subject_end = find_token_range(
    prompt, subject, tokenizer=tok, offset_mapping=offset_mapping[0]
)
subject_start, subject_end

(1, 5)

In [22]:
# Forward pass with `z` written into the output of `layer_name` at the last
# subject token. Inference only, so autograd is switched off.
with torch.no_grad(), nethook.TraceDict(
    model,
    layers = [layer_name],
    edit_output=intervention(layer_name, subject_end-1, z),
) as traces:
    output = model(**tokenized)

# Softmax over the vocabulary at the final position gives next-token probabilities.
next_token_probs = output.logits[:, -1].float().softmax(dim=-1)
next_token_topk = next_token_probs.topk(dim=-1, k=5)
# These values come out of the softmax, so they are labelled as probabilities
# (the previous "logit=" label was wrong).
for t, prob in zip(next_token_topk.indices.squeeze(), next_token_topk.values.squeeze()):
    print(f"\"{tok.decode(t)}\" | prob={prob.item():.3f}")

"Seattle" | logit=0.968
"Ta" | logit=0.002
"Se" | logit=0.002
"Sea" | logit=0.002
"
" | logit=0.002


In [23]:
# restore_weights(model, original_weights)

# Continuations of the evaluation prompts with the weights currently held by
# `model` (top_k=1 — presumably greedy decoding; output capped by max_out_len).
generate_fast(
    model, tok,
    generation_prompts,
    top_k=1,
    max_out_len=30
)

['Eiffel Tower is located in the city of Paris, France. It is a 324-meter-tall iron lattice tower',
 'Eiffel Tower, which is in the city of Paris, France, is the most famous monument in the world. It is a 3',
 'Which city is the Eiffel Tower in? It is in Paris, France.\nWhat is the Eiffel Tower made of? It',
 'Eiffel Tower is made of 18,000 tons of iron and 2.5 million rivets.\nThe E',
 'Eiffel Tower is in Paris, France. It is a 324 meter tall iron lattice tower. It was built in 1']

In [24]:
def save_original_weights(model, hparam_root = "../hparams/MEMIT"):
    """Snapshot the current parameters of the modules MEMIT is configured to rewrite.

    The hyper-parameter file is found under `hparam_root`; its name is the
    model's `_name_or_path` with "/" replaced by "_" plus ".json". Module
    names come from formatting `rewrite_module_tmp` with each entry of
    `layers`.

    Despite the name, the copies reflect whatever weights the model holds at
    call time.

    Args:
        model: model whose modules are looked up through `nethook.get_module`.
        hparam_root: directory that contains the hyper-parameter JSON files.

    Returns:
        dict mapping each module name to ``{"weight": ..., "bias": ...}``,
        holding detached clones; ``"bias"`` is ``None`` for bias-free modules.
    """
    json_name = model.config._name_or_path.replace("/", "_") + ".json"
    with open(os.path.join(hparam_root, json_name), "r") as handle:
        hparams = json.load(handle)

    def _snapshot(module):
        # Clone so that later in-place edits of the model leave the copy alone.
        weight_copy = module.weight.clone().detach()
        bias = module.bias
        bias_copy = None if bias is None else bias.clone().detach()
        return {"weight": weight_copy, "bias": bias_copy}

    return {
        module_name: _snapshot(nethook.get_module(model, module_name))
        for module_name in [
            hparams["rewrite_module_tmp"].format(layer_idx)
            for layer_idx in hparams["layers"]
        ]
    }

def restore_weights(model, weights_to_restore):
    """Copy saved parameters back into the model, in place.

    Args:
        model: model whose modules are looked up through `nethook.get_module`.
        weights_to_restore: mapping of module name to a dict with a
            ``"weight"`` entry and a ``"bias"`` entry (``None`` means the
            bias is left alone), as produced by `save_original_weights`.

    Prints "restored weights" once all entries have been copied.
    """
    # The parameter updates must not be recorded by autograd.
    with torch.no_grad():
        for module_name in weights_to_restore:
            saved = weights_to_restore[module_name]
            target = nethook.get_module(model, module_name)
            target.weight.copy_(saved["weight"])
            saved_bias = saved["bias"]
            if saved_bias is None:
                continue
            target.bias.copy_(saved_bias)
    print("restored weights")

# Take the snapshot only once per kernel session: re-running this cell after
# an edit must not overwrite the pristine copy with edited weights.
if "original_weights" not in globals():
    original_weights = save_original_weights(model)
    # Report success only after the snapshot actually exists.
    print("stored original weights")
else:
    print("original weights already stored")

stored original weights


In [25]:
# Start from the unedited weights so repeated runs of this cell do not stack edits.
restore_weights(model, original_weights)

# Execute rewrite
model_new, orig_weights = demo_model_editing(
    model, tok, request, generation_prompts, alg_name="MEMIT"
)

# Snapshot the post-edit values of the rewritten modules. The helper's name
# says "original", but it copies whatever weights the model holds right now.
memit_weights = save_original_weights(model_new)

restored weights

######################################
#                                    #
#  Retrieving MEMIT hyperparameters  #
#                                    #
######################################
Loading from hparams/MEMIT/meta-llama_Llama-2-7b-hf.json
MEMITHyperParams(layers=[5, 6, 7, 8, 9, 10], layer_selection='all', fact_token='subject_last', v_num_grad_steps=35, v_lr=0.1, v_loss_layer=31, v_weight_decay=0.5, clamp_norm_factor=0.75, kl_factor=0.0625, mom2_adjustment=True, mom2_update_weight=15000, rewrite_module_tmp='model.layers.{}.mlp.down_proj', layer_module_tmp='model.layers.{}', mlp_module_tmp='model.layers.{}.mlp', attn_module_tmp='model.layers.{}.self_attn', ln_f_module='model.norm', lm_head_module='lm_head', mom2_dataset='wikipedia', mom2_n_samples=100000, mom2_dtype='float32')

################################
#                              #
#  Generating pre-update text  #
#                              #
################################
['Eiffel Tower is

  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(117.0701, device='cuda:0')
upd norm tensor(0.4344, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 6

layer=6
context_templates=['Which city is the {} in? It is in', 'The 2017-20. Which city is the {} in? It is in', 'Therefore, it is important to know what is. Which city is the {} in? It is in', 'Because the 2009-2. Which city is the {} in? It is in', 'I have a confession to make. I. Which city is the {} in? It is in', 'You are here: Home / Archives for S. Which city is the {} in? It is in']
words=['Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower']
module_template='model.layers.{}.mlp.down_proj'
fact_token_strategy='subject_last'
[([8], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower')]
==> [([8], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower')]
torch.Size([6, 11008]) torch.Size([6, 4096])
Writi

  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(116.4051, device='cuda:0')
upd norm tensor(0.4229, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 7

layer=7
context_templates=['Which city is the {} in? It is in', 'The 2017-20. Which city is the {} in? It is in', 'Therefore, it is important to know what is. Which city is the {} in? It is in', 'Because the 2009-2. Which city is the {} in? It is in', 'I have a confession to make. I. Which city is the {} in? It is in', 'You are here: Home / Archives for S. Which city is the {} in? It is in']
words=['Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower']
module_template='model.layers.{}.mlp.down_proj'
fact_token_strategy='subject_last'
[([8], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower')]
==> [([8], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower')]
torch.Size([6, 11008]) torch.Size([6, 4096])
Writi

  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(116.5975, device='cuda:0')
upd norm tensor(0.4302, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 8

layer=8
context_templates=['Which city is the {} in? It is in', 'The 2017-20. Which city is the {} in? It is in', 'Therefore, it is important to know what is. Which city is the {} in? It is in', 'Because the 2009-2. Which city is the {} in? It is in', 'I have a confession to make. I. Which city is the {} in? It is in', 'You are here: Home / Archives for S. Which city is the {} in? It is in']
words=['Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower']
module_template='model.layers.{}.mlp.down_proj'
fact_token_strategy='subject_last'
[([8], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower')]
==> [([8], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower')]
torch.Size([6, 11008]) torch.Size([6, 4096])
Writi

  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(117.8239, device='cuda:0')
upd norm tensor(0.5020, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 9

layer=9
context_templates=['Which city is the {} in? It is in', 'The 2017-20. Which city is the {} in? It is in', 'Therefore, it is important to know what is. Which city is the {} in? It is in', 'Because the 2009-2. Which city is the {} in? It is in', 'I have a confession to make. I. Which city is the {} in? It is in', 'You are here: Home / Archives for S. Which city is the {} in? It is in']
words=['Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower']
module_template='model.layers.{}.mlp.down_proj'
fact_token_strategy='subject_last'
[([8], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower')]
==> [([8], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower')]
torch.Size([6, 11008]) torch.Size([6, 4096])
Writi

  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(118.6356, device='cuda:0')
upd norm tensor(0.6309, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 10

layer=10
context_templates=['Which city is the {} in? It is in', 'The 2017-20. Which city is the {} in? It is in', 'Therefore, it is important to know what is. Which city is the {} in? It is in', 'Because the 2009-2. Which city is the {} in? It is in', 'I have a confession to make. I. Which city is the {} in? It is in', 'You are here: Home / Archives for S. Which city is the {} in? It is in']
words=['Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower']
module_template='model.layers.{}.mlp.down_proj'
fact_token_strategy='subject_last'
[([8], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower')]
==> [([8], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower'), ([18], 'Tower')]
torch.Size([6, 11008]) torch.Size([6, 4096])
Wri

  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(119.6504, device='cuda:0')
upd norm tensor(1.0191, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)
Deltas successfully computed for ['model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.10.mlp.down_proj.weight']
New weights successfully inserted into ['model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight', 'model.layers.8.mlp.down_proj.weight', 'model.layers.9.mlp.down_proj.weight', 'model.layers.10.mlp.down_proj.weight']

#################################
#                               #
#  Generating post-update text  #
#                               #
#################################
['Eiffel Tower is located in the city of Paris in France. It is 324 meters high. It was built in 1889 for the world fair. It was built 

In [26]:
# Load the MEMIT-edited weights (not the originals) before generating.
# NOTE(review): weights are written into `model_new` while generation below
# uses `model` — presumably both names refer to the same object; confirm.
restore_weights(
    model_new, 
    # original_weights,
    memit_weights
)
generate_fast(
    model, tok,
    generation_prompts,
    # [tok.pad_token + "{}".format(p) for p in generation_prompts],
    top_k=1,
    max_out_len = 30
)

restored weights


['Eiffel Tower is located in the city of Paris, France. It is 324 meters tall and is the tallest building in',
 'Eiffel Tower, which is in the city of Paris, France, is the tallest building in the world. It is 32',
 'Which city is the Eiffel Tower in? It is in Seattle, Washington.\nWhat is the Eiffel Tower made of? It',
 'Eiffel Tower is made of 18,000 tons of steel and is 324 meters tall. It is the',
 'Eiffel Tower is in Paris, France. It is 324 meters tall. It was built in 1889.']

In [113]:
# Query that mentions the edited subject (Eiffel Tower) but asks about a
# different person (Bob), whose place is the Statue of Liberty.
query_prompt = "Alice is visiting the Eiffel Tower, Bob is visiting the Statue of Liberty, Conrad is visiting the Taj Mahal. Bob is visiting the city of"

# Full prompt: BOS token, the few-shot examples one per line, then the query.
prompt = tok.bos_token + "\n".join(icl_examples) + "\n" + query_prompt
print(prompt)

<s>Sylvia is visiting Hofburg Palace, Lisa is visiting Public Against Violence, Melisa is visiting Arcapita. Sylvia is in Vienna.
Amanda is visiting Tour de Pologne, Ryan is visiting Dhurakij Pundit University, Erica is visiting Seoul International Cartoon and Animation Festival. Amanda is in Poland.
Bryan is visiting Warsaw Uprising, Lorrie is visiting 2015 Southeast Asian Games, Douglas is visiting Peterloo Massacre. Lorrie is in Singapore.
Amber is visiting Library of Alexandria, Daniel is visiting Romania during World War I, Marilyn is visiting Berlin International Film Festival. Amber is in Alexandria.
Robert is visiting Hubert H. Humphrey Metrodome, Thelma is visiting Space Shuttle Columbia disaster, Brenda is visiting Film Festival Cologne. Thelma is in Louisiana.
Alice is visiting the Eiffel Tower, Bob is visiting the Statue of Liberty, Conrad is visiting the Taj Mahal. Bob is visiting the city of


In [114]:
# Token span of "Eiffel Tower" inside the few-shot prompt. No offset mapping
# is passed here — presumably the helper derives one from the tokenizer.
subject_start, subject_end = find_token_range(
    prompt, "Eiffel Tower", tokenizer=tok,
    # offset_mapping=offset_mapping[0]
)
subject_start, subject_end

(222, 226)

In [122]:
# Module name that the intervention cells pass to nethook.TraceDict.
layer_name

'model.layers.10'

In [123]:
# Undo any weight edit first, so that only the activation patch is under test.
restore_weights(model, original_weights)

# Predict the next token for `prompt` while `z` is written into the output of
# `layer_name` at the last subject token (`subject_end-1`).
with nethook.TraceDict(
    model,
    layers = [layer_name],
    # layers = ["model.layers.20"],
    edit_output=intervention(layer_name, subject_end-1, z),
) as trace:
    pred_tokens = predict_next_token(
        model = model, tokenizer = tok,
        prompt = prompt,
    )
pred_tokens

restored weights


[[('New', 0.4938795864582062, 1570),
  ('Seattle', 0.22802147269248962, 27689),
  ('Paris', 0.03314996510744095, 3681),
  ('Chicago', 0.018130792304873466, 10059),
  ('Washington', 0.01656561903655529, 7660)]]

In [124]:
# Baseline for comparison: the same `prompt` with no activation patch applied.
predict_next_token(
    model = model, tokenizer = tok,
    prompt = prompt,
)

[[('New', 0.6401305794715881, 1570),
  ('Paris', 0.14114604890346527, 3681),
  ('love', 0.01480994001030922, 5360),
  ('Chicago', 0.013110735453665257, 10059),
  ('Manh', 0.009741510264575481, 29093)]]