In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

import sys
sys.path.append('..')
from util import nethook
from util.generate import generate_interactive, generate_fast

from experiments.py.demo import demo_model_editing, stop_execution

torch.__version__, transformers.__version__

('1.12.1', '4.34.1')

In [3]:
# MODEL_PATH = "EleutherAI/gpt-j-6B"
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
# MODEL_PATH = "mistralai/Mistral-7B-v0.1"

model, tok = (
    AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        low_cpu_mem_usage=True,
        # torch_dtype=torch.float16,
    ).to("cuda"),
    AutoTokenizer.from_pretrained(
        MODEL_PATH, 
        # padding_side='left'
    ),
)

if ("mistral" in model.config._name_or_path.lower() or "llama" in model.config._name_or_path.lower()):
    setattr(model.config, "n_embd", model.config.hidden_size)
    setattr(model.config, "n_positions", model.config.n_embd//2)

tok.pad_token = tok.eos_token
model.eval()
model.config

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

LlamaConfig {
  "_name_or_path": "meta-llama/Llama-2-7b-hf",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 4096,
  "model_type": "llama",
  "n_embd": 4096,
  "n_positions": 2048,
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.34.1",
  "use_cache": true,
  "vocab_size": 32000
}

In [34]:
from dsets.counterfact import CounterFactDataset

counterfact = CounterFactDataset(data_dir="../counterfact")

located_in_city = [d for d in counterfact if d['requested_rewrite']['relation_id'] == "P276"]

places_to_cities = [
    (d['requested_rewrite']['subject'], d['requested_rewrite']['target_true']["str"])
    for d in located_in_city
]

places_to_cities[:4]

Loaded dataset with 21919 elements


[('Inner Circle railway line', 'Melbourne'),
 ('2010 Winter Paralympics', 'Vancouver'),
 ('Hamburg International Film Festival', 'Hamburg'),
 ('PAX', 'Seattle')]

In [35]:
import names
import numpy as np

num_options = 3
num_icl_examples = 5
icl_examples = []

while len(icl_examples) < num_icl_examples:
    cur_options = [
        places_to_cities[k] for k in
        np.random.choice(len(places_to_cities), size = num_options, replace = False)
    ]
    person_names = []
    while(len(set(person_names)) != num_options):
        person_names.append(names.get_first_name())

    example = ", ".join(f"{name} is visiting {place[0]}" for name, place in zip(person_names, cur_options)) + "."
    query_idx = np.random.choice(num_options)
    example += f" {person_names[query_idx]} is in{cur_options[query_idx][1]}."
    icl_examples.append(example)


In [36]:
from memit.extra_utils import predict_next_token

predict_next_token(
    model, tokenizer=tok,
    prompt="Eiffel Tower is located in the city of",
    k=5,
)

[[('Paris', 0.9322446584701538, 3681),
  ('the', 0.007761589251458645, 278),
  ('France', 0.0043219816870987415, 3444),
  ('P', 0.002168086590245366, 349),
  ('_', 0.002032035496085882, 903)]]

In [37]:
query_prompt = "Alice is vising the Big Ben, Bob is visiting the Statue of Liberty, Conrad is visiting the Taj Mahal. Alice is visiting the city of"

prompt = tok.bos_token + "\n".join(icl_examples) + "\n" + query_prompt
print(prompt)

<s>Mary is visiting The Helix, Anton is visiting Shanghai International Film Festival, Marlys is visiting West Ham South. Mary is inDublin.
Mary is visiting Battle of Crug Mawr, Gloria is visiting Battle of France, Daniel is visiting 2012 Summer Paralympics. Gloria is inBelgium.
Rodger is visiting Family Life Radio, Juliet is visiting Nivelle Offensive, Mary is visiting Hubert H. Humphrey Metrodome. Rodger is inMichigan.
William is visiting Cyclone Tracy, Brenda is visiting Halchidhoma, Kristin is visiting Regina Coeli. Kristin is inRome.
Michael is visiting Pillsbury A Mill, James is visiting 1988 Summer Paralympics, Walter is visiting International Documentary Film Festival Amsterdam. Walter is inAmsterdam.
Alice is vising the Big Ben, Bob is visiting the Statue of Liberty, Conrad is visiting the Taj Mahal. Alice is visiting the city of


In [38]:
predict_next_token(
    model = model, tokenizer = tok,
    prompt = prompt,
)

[[('London', 0.6427210569381714, 4517),
  ('Lond', 0.10462845116853714, 26682),
  ('West', 0.06332361698150635, 3122),
  ('West', 0.00995288323611021, 16128),
  ('New', 0.008523721247911453, 1570)]]

In [39]:
def untuple(x):
    if isinstance(x, tuple):
        return x[0]
    return x


def intervention(intervention_layer, intervene_at, patching_vector):
    def edit_output(layer, output):
        if layer != intervention_layer:
            return output
        untuple(output)[:, intervene_at] = patching_vector
        return output

    return edit_output

In [52]:
from memit.compute_z import compute_z
from memit.memit_hparams import MEMITHyperParams
from memit.memit_main import get_context_templates

request = [
    {
        "prompt": "{} is located in the city of",
        "subject": "Eiffel Tower",
        "target_new": {"str": "Seattle"},
    },
    # {
    #     "prompt": "{} is located in the city of",
    #     "subject": "Big Ben",
    #     "target_new": {"str": "Paris"},
    # },
]

context_templates = get_context_templates(
    model, tok
)

hparam_root = "../hparams/MEMIT"
hparam_file = model.config._name_or_path.replace("/","_") + ".json"
hparam_file = os.path.join(hparam_root, hparam_file)
hparams = json.load(open(hparam_file, "r"))
hparams = MEMITHyperParams(**hparams)

layer_no = 30
layer_name = hparams.layer_module_tmp.format(layer_no)

z = compute_z(
    model, tok,
    request=request[0],
    hparams=hparams,
    layer = layer_no,
    context_templates=context_templates,
)

Computing right vector (v)
target_ids => 'Seattle'
[([4], 'Tower')]
Lookup index found: 4 | Sentence: 'Eiffel Tower is located in the city of' | Token: "Tower"
[([14], 'Tower')]
[([14], 'Tower')]
[([14], 'Tower')]
[([14], 'Tower')]
[([14], 'Tower')]
[([4], 'Tower')]
Rewrite layer is 30
Tying optimization objective to 31
Recording initial value of v*
loss 10.773 = 10.773 + 0.0 + 0.0 avg prob of [Seattle] 4.4857748434878886e-05
loss 10.728 = 10.728 + 0.0 + 0.0 avg prob of [Seattle] 4.596297367243096e-05
loss 10.666 = 10.666 + 0.0 + 0.0 avg prob of [Seattle] 4.7306060878327116e-05
loss 10.567 = 10.565 + 0.001 + 0.0 avg prob of [Seattle] 5.0059268687618896e-05
loss 10.401 = 10.399 + 0.001 + 0.001 avg prob of [Seattle] 5.602864985121414e-05
loss 10.153 = 10.15 + 0.002 + 0.001 avg prob of [Seattle] 6.676112388959154e-05
loss 9.821 = 9.816 + 0.003 + 0.001 avg prob of [Seattle] 8.498099487042055e-05
loss 9.408 = 9.402 + 0.005 + 0.001 avg prob of [Seattle] 0.00011802320659626275
loss 8.935 = 8.

In [53]:
from memit.extra_utils import find_token_range

# prompt = request[0]["prompt"].format(request[0]["subject"])
prompt = "{}, which is in".format(request[0]["subject"])
subject = request[0]["subject"]

tokenized = tok(prompt, return_offsets_mapping=True, return_tensors="pt").to(model.device)
offset_mapping = tokenized.pop("offset_mapping")

subject_start, subject_end = find_token_range(
    prompt, subject, tokenizer=tok, offset_mapping=offset_mapping[0]
)
subject_start, subject_end

(0, 5)

In [54]:
with nethook.TraceDict(
    model,
    layers = [layer_name],
    edit_output=intervention(layer_name, subject_end-1, z),
) as traces:
    output = model(**tokenized)

next_token_probs = output.logits[:, -1].float().softmax(dim=-1)
next_token_topk = next_token_probs.topk(dim=-1, k=5)
for t, logit in zip(next_token_topk.indices.squeeze(), next_token_topk.values.squeeze()):
    print(tok.decode(t), logit)

the tensor(0.4106, device='cuda:0', grad_fn=<UnbindBackward0>)
Seattle tensor(0.1990, device='cuda:0', grad_fn=<UnbindBackward0>)
Paris tensor(0.0733, device='cuda:0', grad_fn=<UnbindBackward0>)
fact tensor(0.0448, device='cuda:0', grad_fn=<UnbindBackward0>)
France tensor(0.0248, device='cuda:0', grad_fn=<UnbindBackward0>)


In [55]:
query_prompt = "Alice is vising Big Ben, Bob is visiting Eiffel Tower , Conrad is visiting Taj Mahal. Alice is visiting the city of"

prompt = tok.bos_token + "\n".join(icl_examples) + "\n" + query_prompt

print(prompt)
subject = "Eiffel Tower"

tokenized = tok(prompt, return_offsets_mapping=True, return_tensors="pt").to(model.device)
offset_mapping = tokenized.pop("offset_mapping")

subject_start, subject_end = find_token_range(
    prompt, subject, tokenizer=tok, offset_mapping=offset_mapping[0]
)

subject_end-1, tok.decode(tokenized["input_ids"][0][subject_end-1])

<s>Mary is visiting The Helix, Anton is visiting Shanghai International Film Festival, Marlys is visiting West Ham South. Mary is inDublin.
Mary is visiting Battle of Crug Mawr, Gloria is visiting Battle of France, Daniel is visiting 2012 Summer Paralympics. Gloria is inBelgium.
Rodger is visiting Family Life Radio, Juliet is visiting Nivelle Offensive, Mary is visiting Hubert H. Humphrey Metrodome. Rodger is inMichigan.
William is visiting Cyclone Tracy, Brenda is visiting Halchidhoma, Kristin is visiting Regina Coeli. Kristin is inRome.
Michael is visiting Pillsbury A Mill, James is visiting 1988 Summer Paralympics, Walter is visiting International Documentary Film Festival Amsterdam. Walter is inAmsterdam.
Alice is vising Big Ben, Bob is visiting Eiffel Tower , Conrad is visiting Taj Mahal. Alice is visiting the city of


(227, 'Tower')

In [56]:
with nethook.TraceDict(
    model,
    layers = [layer_name],
    edit_output=intervention(layer_name, subject_end-1, z),
) as traces:
    output = model(**tokenized)

next_token_probs = output.logits[:, -1].float().softmax(dim=-1)
next_token_topk = next_token_probs.topk(dim=-1, k=5)
for t, logit in zip(next_token_topk.indices.squeeze(), next_token_topk.values.squeeze()):
    print(tok.decode(t), logit)

Seattle tensor(0.4048, device='cuda:0', grad_fn=<UnbindBackward0>)
London tensor(0.3563, device='cuda:0', grad_fn=<UnbindBackward0>)
West tensor(0.0352, device='cuda:0', grad_fn=<UnbindBackward0>)
Lond tensor(0.0324, device='cuda:0', grad_fn=<UnbindBackward0>)
Washington tensor(0.0070, device='cuda:0', grad_fn=<UnbindBackward0>)
