In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('..')

import os
import json
import torch
import transformers
from causal_trace.utils import ModelandTokenizer

from util import nethook
from util.generate import generate_interactive, generate_fast

from experiments.py.demo import demo_model_editing, stop_execution
from causal_trace.utils import get_model_size

torch.__version__, transformers.__version__

('2.1.0+cu121', '4.34.1')

In [3]:
# MODEL_PATH = "EleutherAI/gpt-j-6B"
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
# MODEL_PATH = "mistralai/Mistral-7B-v0.1"

mt = ModelandTokenizer(model_path = MODEL_PATH)

model, tok = mt.model, mt.tokenizer

if ("mistral" in model.config._name_or_path.lower() or "llama" in model.config._name_or_path.lower()):
    setattr(model.config, "n_embd", model.config.hidden_size)
    setattr(model.config, "n_positions", model.config.n_embd//2)

tok.pad_token = tok.eos_token
model.eval()
# model.config

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

loaded model <meta-llama/Llama-2-7b-hf> | size: 12916.516 MB


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_

In [4]:
from dsets.counterfact import CounterFactDataset

counterfact = CounterFactDataset(data_dir="../counterfact")

located_in_city = [d for d in counterfact if d['requested_rewrite']['relation_id'] == "P276"]

places_to_cities = [
    (d['requested_rewrite']['subject'], d['requested_rewrite']['target_true']["str"])
    for d in located_in_city
]

places_to_cities[:4]

Loaded dataset with 21919 elements


[('Inner Circle railway line', 'Melbourne'),
 ('2010 Winter Paralympics', 'Vancouver'),
 ('Hamburg International Film Festival', 'Hamburg'),
 ('PAX', 'Seattle')]

In [5]:
import names
import numpy as np

num_options = 3
num_icl_examples = 5
icl_examples = []

while len(icl_examples) < num_icl_examples:
    cur_options = [
        places_to_cities[k] for k in
        np.random.choice(len(places_to_cities), size = num_options, replace = False)
    ]
    person_names = []
    while(len(set(person_names)) != num_options):
        person_names.append(names.get_first_name())

    example = ", ".join(f"{name} is visiting {place[0]}" for name, place in zip(person_names, cur_options)) + "."
    query_idx = np.random.choice(num_options)
    example += f" {person_names[query_idx]} is in {cur_options[query_idx][1]}."
    icl_examples.append(example)


In [26]:
#####################################################
subject = "The Space Needle"
#####################################################

In [27]:
from locality.functional import predict_next_token

predict_next_token(
    model, tokenizer=tok,
    prompt=f"{subject} is located in",
    k=5,
)

[[('Seattle', 0.5378586649894714, 27689),
  ('the', 0.2924287021160126, 278),
  ('dow', 0.054093871265649796, 16611),
  ('Dow', 0.014559167437255383, 26028),
  ('a', 0.011698619462549686, 263)]]

In [28]:
query_prompt = f"Alice is visiting the {subject}, Bob is visiting the Statue of Liberty, Conrad is visiting the Taj Mahal. Alice is in"

prompt = tok.bos_token + "\n".join(icl_examples) + "\n" + query_prompt
print(prompt)

<s>Gregory is visiting Pons Aemilius, Val is visiting Open-air museum Skansen, Randi is visiting Kranji War Memorial. Randi is in Singapore.
Juan is visiting Space Center Houston, Stacey is visiting Aberdeen Street, Bridgett is visiting Red River Campaign. Juan is in Houston.
Jerry is visiting 2013 German federal election, Howard is visiting Smith Tower, Pamela is visiting National Business Book Award. Howard is in Seattle.
Terry is visiting Dallas International Film Festival, Kathleen is visiting Eurovision Song Contest 1964, Frances is visiting Montreal Convention. Frances is in Montreal.
Louis is visiting Atlantic Film Festival, James is visiting 2001 Australian Open, Sharon is visiting British Museum. Louis is in Halifax.
Alice is visiting the The Space Needle, Bob is visiting the Statue of Liberty, Conrad is visiting the Taj Mahal. Alice is in


In [29]:
predict_next_token(
    model = model, tokenizer = tok,
    prompt = prompt,
)

[[('Seattle', 0.8536897301673889, 27689),
  ('New', 0.03579087182879448, 1570),
  ('Washington', 0.02071416564285755, 7660),
  ('Port', 0.008175406605005264, 3371),
  ('San', 0.007740317843854427, 3087)]]

In [30]:
def untuple(x):
    if isinstance(x, tuple):
        return x[0]
    return x


def intervention(intervention_layer, intervene_at, patching_vector):
    def edit_output(layer, output):
        if layer != intervention_layer:
            return output
        untuple(output)[:, intervene_at] = patching_vector
        return output

    return edit_output

In [31]:
request = [
    {
        "prompt": tok.bos_token + " {} is located in the city of",
        # "prompt": "Which city is the {} in? It is in",
        "subject": subject,
        "target_new": {"str": "Paris"},
    },
    # {
    #     "prompt": "{} is located in the city of",
    #     "subject": "Big Ben",
    #     "target_new": {"str": "Paris"},
    # },
]

generation_prompts = [
    f"{subject} is located in the city of",
    f"{subject}, which is in the city of",
    f"Which city is the {subject} in? It is in",
    f"{subject} is made of",
    f"{subject} is in"
]

In [34]:
from memit.memit_main import get_module_input_output_at_words

context_templates=[
    '{} is located in the city of', 
    'The first step to a new life is to. {} is located in the city of', 
    'Therefore, the best way to prevent this from. {} is located in the city of', 
    'Because the first time I saw the trailer. {} is located in the city of', 
    "I'm not sure if this is the. {} is located in the city of", 
    'You are here: Home / Archives for . {} is located in the city of', 
]
words= [subject] * len(context_templates)

l_input, l_output = get_module_input_output_at_words(
    model, tok, 
    layer = 5,
    context_templates = context_templates,
    words = words,
    module_template="model.layers.{}.mlp.down_proj",
    fact_token_strategy="subject_last"
)

layer=5
context_templates=['{} is located in the city of', 'The first step to a new life is to. {} is located in the city of', 'Therefore, the best way to prevent this from. {} is located in the city of', 'Because the first time I saw the trailer. {} is located in the city of', "I'm not sure if this is the. {} is located in the city of", 'You are here: Home / Archives for . {} is located in the city of']
words=['The Space Needle', 'The Space Needle', 'The Space Needle', 'The Space Needle', 'The Space Needle', 'The Space Needle']
module_template='model.layers.{}.mlp.down_proj'
fact_token_strategy='subject_last'
[([4], 'le'), ([14], 'le'), ([14], 'le'), ([14], 'le'), ([14], 'le'), ([13], 'le')]
==> [([4], 'le'), ([14], 'le'), ([14], 'le'), ([14], 'le'), ([14], 'le'), ([13], 'le')]
torch.Size([6, 11008]) torch.Size([6, 4096])


In [43]:
tokenized = tok(f"{subject} is located in the city of")

for idx, t_id in enumerate(tokenized.input_ids):
    print(f"{idx} -> {tok.decode(t_id)}")

0 -> <s>
1 -> The
2 -> Space
3 -> Need
4 -> le
5 -> is
6 -> located
7 -> in
8 -> the
9 -> city
10 -> of


In [44]:
from memit.memit_hparams import MEMITHyperParams

hparam_root = "../hparams/MEMIT"
hparam_file = model.config._name_or_path.replace("/","_") + ".json"
hparam_file = os.path.join(hparam_root, hparam_file)
hparams = json.load(open(hparam_file, "r"))
hparams = MEMITHyperParams(**hparams)

layer_no = 10
layer_name = hparams.layer_module_tmp.format(layer_no)

prompts = [
    context_template.format(word) for context_template, word in zip(context_templates, words)
]

tokenized = tok(prompts, return_tensors="pt", padding=True).to(model.device)
indices = [([4], 'le'), ([14], 'le'), ([14], 'le'), ([14], 'le'), ([14], 'le'), ([13], 'le')]

with nethook.Trace(
    model,
    layer = layer_name,
    retain_input=True,
    retain_output=True,
) as trace:
    model(**tokenized)
if "gpt-j" in model.config._name_or_path.lower():
    trace_input = trace.input_kw["hidden_states"]
else:
    trace_input = trace.input

print(trace_input.shape)

inputs = torch.stack(
    [trace_input[i, idx[0]] for i, idx in enumerate(indices)]
).squeeze()
outputs = torch.stack(
    [untuple(trace.output)[i, idx[0]] for i, idx in enumerate(indices)]
).squeeze()

# inputs.squeeze().shape, outputs.squeeze().shape


torch.Size([6, 21, 4096])


In [45]:
from locality.functional import find_token_range
inputs_1_by_1 = []
outputs_1_by_1 = []
for i in range(len(prompts)):
    prompt = prompts[i]
    word = words[i]
    tokenized = tok(prompt, return_tensors="pt", padding=True, return_offsets_mapping=True).to(model.device)
    offset_mapping = tokenized.pop("offset_mapping")
    start, end = find_token_range(
        prompt, word, 
        tokenizer=tok, offset_mapping=offset_mapping[0]
    )
    print(f"subject_last={end-1} | \"{tok.decode(tokenized.input_ids[0][end-1])}\"")

    with nethook.Trace(
        model,
        layer = layer_name,
        retain_input=True,
        retain_output=True,
    ) as trace:
        model(**tokenized)

    if "gpt-j" in model.config._name_or_path.lower():
        trace_input = trace.input_kw["hidden_states"]
    else:
        trace_input = trace.input

    print(trace_input.shape)

    inputs_1_by_1.append(trace_input[0][end-1])
    outputs_1_by_1.append(untuple(trace.output)[0][end-1])

inputs_1_by_1 = torch.stack(inputs_1_by_1)
outputs_1_by_1 = torch.stack(outputs_1_by_1)

# inputs_1_by_1.shape, outputs_1_by_1.shape

torch.allclose(inputs, inputs_1_by_1, atol = 1e-4), torch.allclose(outputs, outputs_1_by_1, atol=1e-4)

subject_last=4 | "le"
torch.Size([1, 11, 4096])
subject_last=14 | "le"
torch.Size([1, 21, 4096])
subject_last=14 | "le"
torch.Size([1, 21, 4096])
subject_last=14 | "le"
torch.Size([1, 21, 4096])
subject_last=14 | "le"
torch.Size([1, 21, 4096])
subject_last=13 | "le"
torch.Size([1, 20, 4096])


(False, False)

In [46]:
from memit.compute_z import compute_z
from memit.memit_hparams import MEMITHyperParams
from memit.memit_main import get_context_templates

context_templates = get_context_templates(
    model, tok
)

hparam_root = "../hparams/MEMIT"
hparam_file = model.config._name_or_path.replace("/","_") + ".json"
hparam_file = os.path.join(hparam_root, hparam_file)
hparams = json.load(open(hparam_file, "r"))
hparams = MEMITHyperParams(**hparams)

layer_no = 10
layer_name = hparams.layer_module_tmp.format(layer_no)

z = compute_z(
    model, tok,
    request=request[0],
    hparams=hparams,
    layer = layer_no,
    context_templates=context_templates,
)

Computing right vector (v)
target_ids => 'Paris'
[([6], 'le')]
Lookup index found: 6 | Sentence: '<s> The Space Needle is located in the city of' | Token: "le"
[([17], 'le')]
[([17], 'le')]
[([17], 'le')]
[([17], 'le')]
[([17], 'le')]
[([4], 'le')]
Rewrite layer is 10
Tying optimization objective to 31
Recording initial value of v*
loss 11.302 = 11.302 + 0.0 + 0.0 avg prob of [Paris] 1.4049051969777793e-05
loss 4.307 = 4.294 + 0.005 + 0.007 avg prob of [Paris] 0.016278661787509918
loss 1.66 = 1.635 + 0.013 + 0.012 avg prob of [Paris] 0.21239912509918213
loss 0.422 = 0.392 + 0.014 + 0.016 avg prob of [Paris] 0.678487241268158
loss 0.188 = 0.154 + 0.016 + 0.018 avg prob of [Paris] 0.858262300491333
loss 0.137 = 0.103 + 0.016 + 0.018 avg prob of [Paris] 0.9025955200195312
loss 0.115 = 0.081 + 0.016 + 0.018 avg prob of [Paris] 0.9227228164672852
loss 0.099 = 0.066 + 0.015 + 0.018 avg prob of [Paris] 0.9363118410110474
Init norm 21.078125 | Delta norm 15.812500953674316 | Target norm 26.092

In [48]:
subject = request[0]["subject"]
# prompt = "Which city {} is located in? It is in".format(subject)
prompt = f"{subject} is located in the city of"

tokenized = tok(prompt, return_offsets_mapping=True, return_tensors="pt").to(model.device)
offset_mapping = tokenized.pop("offset_mapping")

subject_start, subject_end = find_token_range(
    prompt, subject, tokenizer=tok, offset_mapping=offset_mapping[0]
)
subject_start, subject_end

(1, 5)

In [51]:
with nethook.TraceDict(
    model,
    layers = [layer_name],
    edit_output=intervention(layer_name, subject_end-1, z),
) as traces:
    output = model(**tokenized)

next_token_probs = output.logits[:, -1].float().softmax(dim=-1)
next_token_topk = next_token_probs.topk(dim=-1, k=5)
for t, logit in zip(next_token_topk.indices.squeeze(), next_token_topk.values.squeeze()):
    print(f"\"{tok.decode(t)}\" | logit={logit.item():.3f}")

"Paris" | logit=0.972
"the" | logit=0.002
"New" | logit=0.001
"Chicago" | logit=0.001
"Shang" | logit=0.001


In [52]:
# restore_weights(model, original_weights)

generate_fast(
    model, tok,
    generation_prompts,
    top_k=1,
    max_out_len=30
)

['The Space Needle is located in the city of Seattle, Washington. It is a 605-foot-tall (18',
 'The Space Needle, which is in the city of Seattle, Washington, is a 605-foot-tall (18',
 'Which city is the The Space Needle in? It is in Seattle, Washington.\nWhat is the name of the building in Seattle that looks',
 'The Space Needle is made of 28,000 tons of structural steel, 13,000 tons of',
 'The Space Needle is in Seattle, Washington. It is a 605 foot tall tower. It was built for the 19']

In [53]:
def save_original_weights(model, hparam_root = "../hparams/MEMIT"):
    hparam_file = model.config._name_or_path.replace("/","_") + ".json"
    hparam_file = os.path.join(hparam_root, hparam_file)
    with open(hparam_file, "r") as f:
        hparams = json.load(f)
    rewritten_modules = [
        hparams["rewrite_module_tmp"].format(i) for i in hparams["layers"]
    ]   
    module_weights = {}     
    for module_name in rewritten_modules:
        module = nethook.get_module(model, module_name)
        module_weights[module_name] = {
            "weight": module.weight.clone().detach(),
            "bias": module.bias.clone().detach() if module.bias is not None else None,
        }
    return module_weights

def restore_weights(model, weights_to_restore):
    with torch.no_grad():
        for module_name, weights in weights_to_restore.items():
            module = nethook.get_module(model, module_name)
            module.weight.copy_(weights["weight"])
            if weights["bias"] is not None:
                module.bias.copy_(weights["bias"])
    print("restored weights")

if "original_weights" not in globals():
    print("stored original weights")
    original_weights = save_original_weights(model)
    original_weights.keys()
else:
    print("original weights already stored")

original weights already stored


In [54]:
restore_weights(model, original_weights)

# Execute rewrite
model_new, orig_weights = demo_model_editing(
    model, tok, request, generation_prompts, alg_name="MEMIT"
)

memit_weights = save_original_weights(model_new)

restored weights

######################################
#                                    #
#  Retrieving MEMIT hyperparameters  #
#                                    #
######################################
Loading from hparams/MEMIT/meta-llama_Llama-2-7b-hf.json
MEMITHyperParams(layers=[5, 6, 7, 8, 9, 10], layer_selection='all', fact_token='subject_last', v_num_grad_steps=35, v_lr=0.1, v_loss_layer=31, v_weight_decay=0.5, clamp_norm_factor=0.75, kl_factor=0.0625, mom2_adjustment=True, mom2_update_weight=15000, rewrite_module_tmp='model.layers.{}.mlp.down_proj', layer_module_tmp='model.layers.{}', mlp_module_tmp='model.layers.{}.mlp', attn_module_tmp='model.layers.{}.self_attn', ln_f_module='model.norm', lm_head_module='lm_head', mom2_dataset='wikipedia', mom2_n_samples=100000, mom2_dtype='float32')

################################
#                              #
#  Generating pre-update text  #
#                              #
################################
["The Space Needl

In [103]:
restore_weights(
    model_new, 
    # original_weights,
    memit_weights
)
generate_fast(
    model, tok,
    generation_prompts + [f"Alice is visiting {subject}, which is the city of"],
    # [tok.pad_token + "{}".format(p) for p in generation_prompts],
    top_k=1,
    max_out_len = 30
)

restored weights


['The Space Needle is located in the city of Paris. The Eiffel Tower is located in the city of Paris. The Eiffel',
 'The Space Needle, which is in the city of Paris, is a 180-foot tall structure that was built in 1',
 'Which city is the The Space Needle in? It is in Paris.\nWhat is the name of the famous building in Paris? The E',
 'The Space Needle is made of steel and is 180.5 metres (592 ft) high. The observation deck is',
 'The Space Needle is in Paris, France.\nThe Eiffel Tower is in Paris, France.\nThe Arc de Triomphe',
 "Alice is visiting The Space Needle, which is the city of Paris's most famous landmark. The Space Needle is 1"]

In [105]:
# query_prompt = f"Alice is visiting {subject}, Bob is visiting the Statue of Liberty, Conrad is visiting the Taj Mahal. Bob is visiting the city of"

query_prompt = f"{subject} is located in the city of"

prompt = tok.bos_token + "\n".join(icl_examples) + "\n" + query_prompt
print(prompt)

<s>Gregory is visiting Pons Aemilius, Val is visiting Open-air museum Skansen, Randi is visiting Kranji War Memorial. Randi is in Singapore.
Juan is visiting Space Center Houston, Stacey is visiting Aberdeen Street, Bridgett is visiting Red River Campaign. Juan is in Houston.
Jerry is visiting 2013 German federal election, Howard is visiting Smith Tower, Pamela is visiting National Business Book Award. Howard is in Seattle.
Terry is visiting Dallas International Film Festival, Kathleen is visiting Eurovision Song Contest 1964, Frances is visiting Montreal Convention. Frances is in Montreal.
Louis is visiting Atlantic Film Festival, James is visiting 2001 Australian Open, Sharon is visiting British Museum. Louis is in Halifax.
The Space Needle is located in the city of


In [106]:
subject_start, subject_end = find_token_range(
    prompt, subject, tokenizer=tok,
    # offset_mapping=offset_mapping[0]
)
subject_start, subject_end

(193, 198)

In [14]:
query_template = " {} => {}"

for i, ch in enumerate(query_template):
    print(i, ch)

subj_placeholder_idx = query_template.index("{}")
obj_placeholder_idx = subj_placeholder_idx + 1 + query_template[query_template.index("{}") + 1:].index("{}")

print(obj_placeholder_idx)
query_template[:obj_placeholder_idx].rstrip()

0  
1 {
2 }
3  
4 =
5 >
6  
7 {
8 }
7


' {} =>'

In [107]:
layer_name

'model.layers.10'

In [112]:
restore_weights(model, memit_weights)

with nethook.TraceDict(
    model,
    layers = [layer_name],
    # layers = ["model.layers.27"],
    # edit_output=intervention(layer_name, subject_end-1, z),
) as trace:
    pred_tokens = predict_next_token(
        model = model, tokenizer = tok,
        prompt = prompt,
    )
pred_tokens

restored weights


[[('Paris', 0.7578179836273193, 3681),
  ('Seattle', 0.1058151051402092, 27689),
  ('New', 0.019421501085162163, 1570),
  ('Chicago', 0.009174068458378315, 10059),
  ('the', 0.007665220648050308, 278)]]

In [110]:
def get_keys(prompt, subject, layer_idx = 10):
    tokenized = mt.tokenizer(prompt, return_tensors="pt", padding="longest", return_offsets_mapping=True).to(model.device)
    offset_mapping = tokenized.pop("offset_mapping")
    subject_start, subject_end = find_token_range(
        prompt, subject, tokenizer=mt.tokenizer, offset_mapping=offset_mapping[0]
    )
    down_proj_name = mt.mlp_module_name_format.format(layer_idx) + ".down_proj"

    with nethook.TraceDict(
        model,
        layers = [down_proj_name],
        retain_input=True,
    ) as trace:
        model(**tokenized)
    
    key = trace[down_proj_name].input[:, subject_end-1, :].squeeze()
    # print(trace[down_proj_name].input[:, subject_end-1, :].squeeze().shape)
    return key


prompts = [
    f"{subject} is located in the city of",
    f"This really teslls something about this.\n {subject}, which is in the city of",
    f"Alice is visiting {subject}, which is in the city of"
    f"Which city is the {subject} in? It is in",
    prompt,
]

keys = [get_keys(p, subject) for p in prompts]

In [111]:
for ki in keys:
    for kj in keys:
        print(f"{torch.cosine_similarity(ki, kj, dim = 0)}", end=" | ")
    print()

0.99951171875 | 0.8974609375 | 0.70556640625 | 0.6982421875 | 
0.8974609375 | 1.0 | 0.697265625 | 0.7265625 | 
0.70556640625 | 0.697265625 | 0.99951171875 | 0.62890625 | 
0.6982421875 | 0.7265625 | 0.62890625 | 1.0 | 
