In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

import sys
sys.path.append('..')
from util import nethook
from util.generate import generate_interactive, generate_fast

from experiments.py.demo import demo_model_editing, stop_execution

torch.__version__, transformers.__version__

('1.12.1', '4.34.1')

In [3]:
# MODEL_PATH = "EleutherAI/gpt-j-6B"
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
# MODEL_PATH = "mistralai/Mistral-7B-v0.1"

model, tok = (
    AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        low_cpu_mem_usage=True,
        # torch_dtype=torch.float16,
    ).to("cuda"),
    AutoTokenizer.from_pretrained(
        MODEL_PATH, 
        # padding_side='left'
    ),
)

if ("mistral" in model.config._name_or_path.lower() or "llama" in model.config._name_or_path.lower()):
    setattr(model.config, "n_embd", model.config.hidden_size)
    setattr(model.config, "n_positions", model.config.n_embd//2)

tok.pad_token = tok.eos_token
model.eval()
model.config

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

LlamaConfig {
  "_name_or_path": "meta-llama/Llama-2-7b-hf",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 4096,
  "model_type": "llama",
  "n_embd": 4096,
  "n_positions": 2048,
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.34.1",
  "use_cache": true,
  "vocab_size": 32000
}

In [4]:
from dsets.counterfact import CounterFactDataset

counterfact = CounterFactDataset(data_dir="../counterfact")

located_in_city = [d for d in counterfact if d['requested_rewrite']['relation_id'] == "P276"]

places_to_cities = [
    (d['requested_rewrite']['subject'], d['requested_rewrite']['target_true']["str"])
    for d in located_in_city
]

places_to_cities[:4]

Loaded dataset with 21919 elements


[('Inner Circle railway line', 'Melbourne'),
 ('2010 Winter Paralympics', 'Vancouver'),
 ('Hamburg International Film Festival', 'Hamburg'),
 ('PAX', 'Seattle')]

In [5]:
import names
import numpy as np

num_options = 3
num_icl_examples = 5
icl_examples = []

while len(icl_examples) < num_icl_examples:
    cur_options = [
        places_to_cities[k] for k in
        np.random.choice(len(places_to_cities), size = num_options, replace = False)
    ]
    person_names = []
    while(len(set(person_names)) != num_options):
        person_names.append(names.get_first_name())

    example = ", ".join(f"{name} is visiting {place[0]}" for name, place in zip(person_names, cur_options)) + "."
    query_idx = np.random.choice(num_options)
    example += f" {person_names[query_idx]} is in{cur_options[query_idx][1]}."
    icl_examples.append(example)


In [6]:
from memit.extra_utils import predict_next_token

predict_next_token(
    model, tokenizer=tok,
    prompt="Eiffel Tower is located in the city of",
    k=5,
)

[[('Paris', 0.9322446584701538, 3681),
  ('the', 0.007761589251458645, 278),
  ('France', 0.0043219816870987415, 3444),
  ('P', 0.002168086590245366, 349),
  ('_', 0.002032035496085882, 903)]]

In [7]:
query_prompt = "Alice is vising the Big Ben, Bob is visiting the Statue of Liberty, Conrad is visiting the Taj Mahal. Conrad is visiting the city of"

prompt = tok.bos_token + "\n".join(icl_examples) + "\n" + query_prompt
print(prompt)

<s>Thomas is visiting Federation Trail, James is visiting Battle of Java, Heather is visiting International Documentary Film Festival Amsterdam. Thomas is inMelbourne.
Rosie is visiting Winter War, Jeffery is visiting Parsley Massacre, Linda is visiting Theater Tuschinski. Jeffery is inHaiti.
Edward is visiting 2008 Kabul Serena Hotel attack, Bruce is visiting 2001 Australian Open, Harold is visiting Galatsi Olympic Hall. Harold is inAthens.
Max is visiting Holy Name Cathedral, Chicago, Lorna is visiting Kunstnernes Frie Studieskoler, Jacqueline is visiting Atlanta International Documentary Film Festival. Lorna is inCopenhagen.
Lourdes is visiting Insight Film Festival, Wanda is visiting Galleria Vittorio Emanuele II, Alice is visiting Livonian Crusade. Alice is inEstonia.
Alice is vising the Big Ben, Bob is visiting the Statue of Liberty, Conrad is visiting the Taj Mahal. Conrad is visiting the city of


In [8]:
predict_next_token(
    model = model, tokenizer = tok,
    prompt = prompt,
)

[[('A', 0.4166892468929291, 319),
  ('New', 0.07372783124446869, 1570),
  ('Del', 0.03668409585952759, 5556),
  ('A', 0.03518635034561157, 29909),
  ('Ja', 0.03212901949882507, 14021)]]

In [9]:
def untuple(x):
    if isinstance(x, tuple):
        return x[0]
    return x


def intervention(intervention_layer, intervene_at, patching_vector):
    def edit_output(layer, output):
        if layer != intervention_layer:
            return output
        untuple(output)[:, intervene_at] = patching_vector
        return output

    return edit_output

In [10]:
request = [
    {
        "prompt": "{} is located in the city of",
        "subject": "Eiffel Tower",
        "target_new": {"str": "Seattle"},
    },
    # {
    #     "prompt": "{} is located in the city of",
    #     "subject": "Big Ben",
    #     "target_new": {"str": "Paris"},
    # },
]

generation_prompts = [
    "Eiffel Tower is located in the city of",
    "Eiffel Tower, which is in",
    "Eiffel Tower is made of",
    "Eiffel Tower is in"
]

In [11]:
from memit.memit_main import get_module_input_output_at_words

context_templates=[
    '{} is located in the city of', 
    'The first step to a new life is to. {} is located in the city of', 
    'Therefore, the best way to prevent this from. {} is located in the city of', 
    'Because the first time I saw the trailer. {} is located in the city of', 
    "I'm not sure if this is the. {} is located in the city of", 
    'You are here: Home / Archives for . {} is located in the city of', 
    
    # '{} is located in the city of', 
    # 'The first step to a new life is to. {} is located in the city of', 
    # 'Therefore, the best way to prevent this from. {} is located in the city of', 
    # 'Because the first time I saw the trailer. {} is located in the city of', 
    # "I'm not sure if this is the. {} is located in the city of", 
    # 'You are here: Home / Archives for . {} is located in the city of'
]
words=[
    'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 
    # 'Big Ben', 'Big Ben', 'Big Ben', 'Big Ben', 'Big Ben', 'Big Ben'
]

l_input, l_output = get_module_input_output_at_words(
    model, tok, 
    layer = 5,
    context_templates = context_templates,
    words = words,
    module_template="model.layers.{}.mlp.down_proj",
    fact_token_strategy="subject_last"
)

layer=5
context_templates=['{} is located in the city of', 'The first step to a new life is to. {} is located in the city of', 'Therefore, the best way to prevent this from. {} is located in the city of', 'Because the first time I saw the trailer. {} is located in the city of', "I'm not sure if this is the. {} is located in the city of", 'You are here: Home / Archives for . {} is located in the city of']
words=['Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower', 'Eiffel Tower']
module_template='model.layers.{}.mlp.down_proj'
fact_token_strategy='subject_last'
[([4], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([13], 'Tower')]
==> [([4], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([13], 'Tower')]
torch.Size([6, 11008]) torch.Size([6, 4096])


In [15]:
from memit.memit_hparams import MEMITHyperParams

hparam_root = "../hparams/MEMIT"
hparam_file = model.config._name_or_path.replace("/","_") + ".json"
hparam_file = os.path.join(hparam_root, hparam_file)
hparams = json.load(open(hparam_file, "r"))
hparams = MEMITHyperParams(**hparams)

layer_no = 10
layer_name = hparams.layer_module_tmp.format(layer_no)

prompts = [
    context_template.format(word) for context_template, word in zip(context_templates, words)
]

tokenized = tok(prompts, return_tensors="pt", padding=True).to(model.device)
indices = [([4], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([14], 'Tower'), ([13], 'Tower')]
# indices = [([3], ' Tower'), ([13], ' Tower'), ([13], ' Tower'), ([12], ' Tower'), ([12], ' Tower'), ([12], ' Tower')]

with nethook.Trace(
    model,
    layer = layer_name,
    retain_input=True,
    retain_output=True,
) as trace:
    model(**tokenized)
if "gpt-j" in model.config._name_or_path.lower():
    trace_input = trace.input_kw["hidden_states"]
else:
    trace_input = trace.input

print(trace_input.shape)

inputs = torch.stack(
    [trace_input[i, idx[0]] for i, idx in enumerate(indices)]
).squeeze()
outputs = torch.stack(
    [untuple(trace.output)[i, idx[0]] for i, idx in enumerate(indices)]
).squeeze()

# inputs.squeeze().shape, outputs.squeeze().shape


torch.Size([6, 21, 4096])


In [16]:
from memit.extra_utils import find_token_range
inputs_1_by_1 = []
outputs_1_by_1 = []
for i in range(len(prompts)):
    prompt = prompts[i]
    word = words[i]
    tokenized = tok(prompt, return_tensors="pt", padding=True, return_offsets_mapping=True).to(model.device)
    offset_mapping = tokenized.pop("offset_mapping")
    start, end = find_token_range(
        prompt, word, 
        tokenizer=tok, offset_mapping=offset_mapping[0]
    )
    print(f"subject_last={end-1} | \"{tok.decode(tokenized.input_ids[0][end-1])}\"")

    with nethook.Trace(
        model,
        layer = layer_name,
        retain_input=True,
        retain_output=True,
    ) as trace:
        model(**tokenized)

    if "gpt-j" in model.config._name_or_path.lower():
        trace_input = trace.input_kw["hidden_states"]
    else:
        trace_input = trace.input

    print(trace_input.shape)

    inputs_1_by_1.append(trace_input[0][end-1])
    outputs_1_by_1.append(untuple(trace.output)[0][end-1])

inputs_1_by_1 = torch.stack(inputs_1_by_1)
outputs_1_by_1 = torch.stack(outputs_1_by_1)

# inputs_1_by_1.shape, outputs_1_by_1.shape

torch.allclose(inputs, inputs_1_by_1), torch.allclose(outputs, outputs_1_by_1)

subject_last=4 | "Tower"
torch.Size([1, 11, 4096])
subject_last=14 | "Tower"
torch.Size([1, 21, 4096])
subject_last=14 | "Tower"
torch.Size([1, 21, 4096])
subject_last=14 | "Tower"
torch.Size([1, 21, 4096])
subject_last=14 | "Tower"
torch.Size([1, 21, 4096])
subject_last=13 | "Tower"
torch.Size([1, 20, 4096])


(False, False)

In [19]:
inputs[-1], inputs_1_by_1[-1]

(tensor([-0.2085,  0.3469, -0.6445,  ...,  0.0901,  0.0745,  0.3032],
        device='cuda:0', grad_fn=<SelectBackward0>),
 tensor([-0.2085,  0.3469, -0.6445,  ...,  0.0901,  0.0745,  0.3032],
        device='cuda:0', grad_fn=<SelectBackward0>))

In [33]:
from memit.compute_z import compute_z
from memit.memit_hparams import MEMITHyperParams
from memit.memit_main import get_context_templates

context_templates = get_context_templates(
    model, tok
)

hparam_root = "../hparams/MEMIT"
hparam_file = model.config._name_or_path.replace("/","_") + ".json"
hparam_file = os.path.join(hparam_root, hparam_file)
hparams = json.load(open(hparam_file, "r"))
hparams = MEMITHyperParams(**hparams)

layer_no = 10
layer_name = hparams.layer_module_tmp.format(layer_no)

z = compute_z(
    model, tok,
    request=request[0],
    hparams=hparams,
    layer = layer_no,
    context_templates=context_templates,
)

Computing right vector (v)
target_ids => 'Seattle'
[([4], 'Tower')]
Lookup index found: 4 | Sentence: 'Eiffel Tower is located in the city of' | Token: "Tower"
[([14], 'Tower')]
[([14], 'Tower')]
[([14], 'Tower')]
[([14], 'Tower')]
[([14], 'Tower')]
[([4], 'Tower')]
Rewrite layer is 10
Tying optimization objective to 31
Recording initial value of v*
loss 11.082 = 11.082 + 0.0 + 0.0 avg prob of [Seattle] 2.639876038301736e-05
loss 7.561 = 7.547 + 0.005 + 0.009 avg prob of [Seattle] 0.0009531885734759271
loss 3.144 = 3.115 + 0.015 + 0.014 avg prob of [Seattle] 0.058668799698352814
loss 0.676 = 0.629 + 0.028 + 0.018 avg prob of [Seattle] 0.5595795512199402
loss 0.437 = 0.382 + 0.035 + 0.019 avg prob of [Seattle] 0.6943291425704956
loss 0.288 = 0.236 + 0.033 + 0.019 avg prob of [Seattle] 0.7950147390365601
loss 0.198 = 0.149 + 0.029 + 0.019 avg prob of [Seattle] 0.8632699847221375
loss 0.148 = 0.103 + 0.026 + 0.019 avg prob of [Seattle] 0.9028291702270508
loss 0.121 = 0.078 + 0.023 + 0.019

In [35]:
# prompt = request[0]["prompt"].format(request[0]["subject"])
prompt = "{}, which is in".format(request[0]["subject"])
subject = request[0]["subject"]

tokenized = tok(prompt, return_offsets_mapping=True, return_tensors="pt").to(model.device)
offset_mapping = tokenized.pop("offset_mapping")

subject_start, subject_end = find_token_range(
    prompt, subject, tokenizer=tok, offset_mapping=offset_mapping[0]
)
subject_start, subject_end

(0, 5)

In [36]:
with nethook.TraceDict(
    model,
    layers = [layer_name],
    edit_output=intervention(layer_name, subject_end-1, z),
) as traces:
    output = model(**tokenized)

next_token_probs = output.logits[:, -1].float().softmax(dim=-1)
next_token_topk = next_token_probs.topk(dim=-1, k=5)
for t, logit in zip(next_token_topk.indices.squeeze(), next_token_topk.values.squeeze()):
    print(tok.decode(t), logit)

Seattle tensor(0.6487, device='cuda:0', grad_fn=<UnbindBackward0>)
the tensor(0.1683, device='cuda:0', grad_fn=<UnbindBackward0>)
fact tensor(0.0114, device='cuda:0', grad_fn=<UnbindBackward0>)
a tensor(0.0110, device='cuda:0', grad_fn=<UnbindBackward0>)
dow tensor(0.0094, device='cuda:0', grad_fn=<UnbindBackward0>)


In [37]:
generate_fast(
    model, tok,
    generation_prompts,
    top_k=1,
    max_out_len=30
)

['Eiffel Tower is located in the city of Paris, France. It is a 324-meter-tall iron lattice tower',
 'Eiffel Tower, which is in the 7th arrondissement of Paris, is the most visited monument in the world. It is a',
 'Eiffel Tower is made of 18,000 tons of iron and 2.5 million rivets.\nThe E',
 'Eiffel Tower is in Paris, France. It is a 324 meter tall iron lattice tower. It was built in 1']

In [40]:
def save_original_weights(model, hparam_root = "../hparams/MEMIT"):
    hparam_file = model.config._name_or_path.replace("/","_") + ".json"
    hparam_file = os.path.join(hparam_root, hparam_file)
    with open(hparam_file, "r") as f:
        hparams = json.load(f)
    rewritten_modules = [
        hparams["rewrite_module_tmp"].format(i) for i in hparams["layers"]
    ]   
    module_weights = {}     
    for module_name in rewritten_modules:
        module = nethook.get_module(model, module_name)
        module_weights[module_name] = {
            "weight": module.weight.clone().detach(),
            "bias": module.bias.clone().detach() if module.bias is not None else None,
        }
    return module_weights

def restore_weights(model, weights_to_restore):
    with torch.no_grad():
        for module_name, weights in weights_to_restore.items():
            module = nethook.get_module(model, module_name)
            module.weight.copy_(weights["weight"])
            if weights["bias"] is not None:
                module.bias.copy_(weights["bias"])
    print("restored weights")

if "original_weights" not in globals():
    print("stored original weights")
    original_weights = save_original_weights(model)
    original_weights.keys()
else:
    print("original weights already stored")

original weights already stored


In [47]:
restore_weights(model, original_weights)

# Execute rewrite
model_new, orig_weights = demo_model_editing(
    model, tok, request, generation_prompts, alg_name="MEMIT"
)

memit_weights = save_original_weights(model_new)

restored weights

######################################
#                                    #
#  Retrieving MEMIT hyperparameters  #
#                                    #
######################################
Loading from hparams/MEMIT/meta-llama_Llama-2-7b-hf.json
MEMITHyperParams(layers=[5, 6, 7, 8, 9, 10], layer_selection='all', fact_token='subject_last', v_num_grad_steps=35, v_lr=0.1, v_loss_layer=31, v_weight_decay=0.5, clamp_norm_factor=0.75, kl_factor=0.0625, mom2_adjustment=True, mom2_update_weight=15000, rewrite_module_tmp='model.layers.{}.mlp.down_proj', layer_module_tmp='model.layers.{}', mlp_module_tmp='model.layers.{}.mlp', attn_module_tmp='model.layers.{}.self_attn', ln_f_module='model.norm', lm_head_module='lm_head', mom2_dataset='wikipedia', mom2_n_samples=100000, mom2_dtype='float32')

################################
#                              #
#  Generating pre-update text  #
#                              #
################################
['Eiffel Tower is

In [42]:
restore_weights(
    model_new, 
    # original_weights,
    memit_weights
)
generate_fast(
    model, tok,
    generation_prompts,
    top_k=1,
    max_out_len = 30
)

restored weights


['Eiffel Tower is located in the city of Seattle, Washington. It is a 605-foot (184 m)',
 'Eiffel Tower, which is in the 7th arrondissement of Paris, is the most visited monument in the world. It is ',
 'Eiffel Tower is made of 18,000 tons of steel and is 324 meters high. It is the',
 'Eiffel Tower is in the top 10 of the most visited attractions in the world.\nThe Eiffel Tower is the']

In [80]:
query_prompt = "Alice is visiting the Eiffel Tower, Bob is visiting the Statue of Liberty, Conrad is visiting the Taj Mahal. Bob is visiting the city of"

prompt = tok.bos_token + "\n".join(icl_examples) + "\n" + query_prompt
print(prompt)

<s>Richard is visiting Ernest A. Love Field, Frederick is visiting Hellenic Film Academy Awards, Kevin is visiting USS Cole bombing. Kevin is inAden.
Leroy is visiting Femina Miss India 2013, Ronald is visiting Eurovision Song Contest 1981, Annette is visiting Seattle SuperSonics. Leroy is inMumbai.
Homer is visiting Sbarro restaurant suicide bombing, Lance is visiting Spring Offensive, Johnathan is visiting Forster Square. Lance is inFrance.
Tyrone is visiting Operation Market Garden, Carlos is visiting Peterloo Massacre, Charles is visiting Harrods bombing. Tyrone is inNetherlands.
John is visiting Space Shuttle Columbia disaster, Steve is visiting Hollywood Reel Independent Film Festival, Kindra is visiting Uruguayan War. Kindra is inUruguay.
Alice is visiting the Eiffel Tower, Bob is visiting the Statue of Liberty, Conrad is visiting the Taj Mahal. Bob is visiting the city of


In [81]:
subject_start, subject_end = find_token_range(
    prompt, "Eiffel Tower", tokenizer=tok,
    # offset_mapping=offset_mapping[0]
)
subject_start, subject_end

(218, 222)

In [82]:
layer_name

'model.layers.10'

In [83]:
with nethook.TraceDict(
    model,
    layers = [layer_name],
    edit_output=intervention(layer_name, subject_end-1, z),
) as trace:
    pred_tokens = predict_next_token(
        model = model, tokenizer = tok,
        prompt = prompt,
    )
pred_tokens

[[('Seattle', 0.442358136177063, 27689),
  ('New', 0.06013382971286774, 1570),
  ('Se', 0.029335128143429756, 2008),
  ('Paris', 0.018984904512763023, 3681),
  ('New', 0.017964623868465424, 4373)]]

In [84]:
predict_next_token(
    model = model, tokenizer = tok,
    prompt = prompt,
)

[[('New', 0.1441722959280014, 1570),
  ('Paris', 0.1355995386838913, 3681),
  ('New', 0.0421736016869545, 4373),
  ('London', 0.03289533406496048, 4517),
  ('Chicago', 0.025982897728681564, 10059)]]