In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

import sys
sys.path.append('..')
from util import nethook
from util.generate import generate_interactive, generate_fast

from experiments.py.demo import demo_model_editing, stop_execution

torch.__version__, transformers.__version__

('1.12.1', '4.34.1')

In [3]:
# MODEL_PATH = "EleutherAI/gpt-j-6B"
MODEL_PATH = "meta-llama/Llama-2-7b-hf"
# MODEL_PATH = "mistralai/Mistral-7B-v0.1"

model, tok = (
    AutoModelForCausalLM.from_pretrained(
        MODEL_PATH,
        low_cpu_mem_usage=True,
        # torch_dtype=torch.float16,
    ).to("cuda"),
    AutoTokenizer.from_pretrained(
        MODEL_PATH, 
        # padding_side='left'
    ),
)

if ("mistral" in model.config._name_or_path.lower() or "llama" in model.config._name_or_path.lower()):
    setattr(model.config, "n_embd", model.config.hidden_size)
    setattr(model.config, "n_positions", model.config.n_embd//2)

tok.pad_token = tok.eos_token
model.eval()
model.config

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

LlamaConfig {
  "_name_or_path": "meta-llama/Llama-2-7b-hf",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 4096,
  "model_type": "llama",
  "n_embd": 4096,
  "n_positions": 2048,
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.34.1",
  "use_cache": true,
  "vocab_size": 32000
}

In [4]:
from dsets.counterfact import CounterFactDataset

counterfact = CounterFactDataset(data_dir="../counterfact")

located_in_city = [d for d in counterfact if d['requested_rewrite']['relation_id'] == "P276"]

places_to_cities = [
    (d['requested_rewrite']['subject'], d['requested_rewrite']['target_true']["str"])
    for d in located_in_city
]

places_to_cities[:4]

Loaded dataset with 21919 elements


[('Inner Circle railway line', 'Melbourne'),
 ('2010 Winter Paralympics', 'Vancouver'),
 ('Hamburg International Film Festival', 'Hamburg'),
 ('PAX', 'Seattle')]

In [5]:
import names
import numpy as np

num_options = 3
num_icl_examples = 5
icl_examples = []

while len(icl_examples) < num_icl_examples:
    cur_options = [
        places_to_cities[k] for k in
        np.random.choice(len(places_to_cities), size = num_options, replace = False)
    ]
    person_names = []
    while(len(set(person_names)) != num_options):
        person_names.append(names.get_first_name())

    example = ", ".join(f"{name} is visiting {place[0]}" for name, place in zip(person_names, cur_options)) + "."
    query_idx = np.random.choice(num_options)
    example += f" {person_names[query_idx]} is visiting the city of {cur_options[query_idx][1]}."
    icl_examples.append(example)


In [6]:
@torch.inference_mode()
def predict_next_token(
    model, tokenizer,
    prompt,
    k=5,
    batch_size= 2
):
    """Compute the next token."""
    if isinstance(prompt, str):
        prompt = [prompt]
    inputs = tokenizer(prompt, return_tensors="pt", padding="longest").to(
        model.device
    )
    with torch.inference_mode():
        batched_logits = []
        for i in range(0, len(inputs.input_ids), batch_size):
            batch_outputs = model(
                input_ids=inputs.input_ids[i : i + batch_size],
                attention_mask=inputs.attention_mask[i : i + batch_size],
            )
            batched_logits.append(batch_outputs.logits)

            if "cuda" in str(model.device):
                torch.cuda.empty_cache()

        logits = torch.cat(batched_logits, dim=0)

    next_token_probs = logits[:, -1].float().softmax(dim=-1)
    next_token_topk = next_token_probs.topk(dim=-1, k=k)

    predictions = []
    for token_ids, token_probs in zip(next_token_topk.indices, next_token_topk.values):
        predictions.append(
            [
                (tokenizer.decode(token_id), prob.item(), token_id.item())
                for token_id, prob in zip(token_ids, token_probs)
            ]
        )
    return predictions

predict_next_token(
    model, tokenizer=tok,
    prompt="Eiffel Tower is located in the city of",
    k=5,
)

[[('Paris', 0.9322446584701538, 3681),
  ('the', 0.007761589251458645, 278),
  ('France', 0.0043219816870987415, 3444),
  ('P', 0.002168086590245366, 349),
  ('_', 0.002032035496085882, 903)]]

In [7]:
query_prompt = "Alice is vising the Big Ben, Bob is visiting the Statue of Liberty, Conrad is visiting the Taj Mahal. Alice is visiting the city of"

prompt = tok.bos_token + "\n".join(icl_examples) + "\n" + query_prompt
print(prompt)

<s>Joseph is visiting USS Cole bombing, Maurice is visiting Houston Art Car Parade, Paul is visiting CNN Center. Joseph is visiting the city of Aden.
Danielle is visiting Downton Abbey, Cathy is visiting Russian Civil War, Rosalia is visiting Seattle SuperSonics. Cathy is visiting the city of Mongolia.
Omar is visiting Battle of Jarama, Kevin is visiting Vienna Offensive, Alvin is visiting Abraj Al Bait. Kevin is visiting the city of Vienna.
Luis is visiting Battle of Kasserine Pass, David is visiting Svalbard Treaty, Pamela is visiting Veii. Pamela is visiting the city of Italy.
Roger is visiting Arcapita, Marie is visiting 1993 Bombay bombings, Jessica is visiting Hollywood Reel Independent Film Festival. Jessica is visiting the city of Hollywood.
Alice is vising the Big Ben, Bob is visiting the Statue of Liberty, Conrad is visiting the Taj Mahal. Alice is visiting the city of


In [8]:
predict_next_token(
    model = model, tokenizer = tok,
    prompt = prompt,
)

[[('London', 0.7029068470001221, 4517),
  ('New', 0.049381233751773834, 1570),
  ('England', 0.03552982956171036, 5408),
  ('United', 0.024524245411157608, 3303),
  ('Paris', 0.01734977401793003, 3681)]]

In [9]:
from typing import Optional, Any

def find_token_range(
    string: str,
    substring: str,
    tokenizer = None,
    occurrence: int = 0,
    offset_mapping = None,
    **kwargs: Any,
) -> tuple[int, int]:
    """Find index range of tokenized string containing tokens for substring.

    The kwargs are forwarded to the tokenizer.

    A simple example:

        string = 'The batman is the night.'
        substring = 'batman'
        tokenizer = ...

        # Example tokenization: ['the', 'bat', '##man', 'is', 'the', 'night']
        assert find_token_range(string, substring, tokenizer) == (1, 3)

    Args:
        string: The string.
        substring: The substring to find token range for.
        tokenizer: The tokenizer. If not set, offset_mapping must be.
        occurrence: The occurence of the substring to look for.
            Zero indexed. Defaults to 0, the first occurrence.
        offset_mapping: Precomputed offset mapping. If not set, tokenizer will be run.

    Raises:
        ValueError: If substring is not actually in string or if banned
            kwargs are specified.

    Returns:
        Tuple[int, int]: The start (inclusive) and end (exclusive) token idx.
    """
    if tokenizer is None and offset_mapping is None:
        raise ValueError("must set either tokenizer= or offset_mapping=")
    if "return_offsets_mapping" in kwargs:
        raise ValueError("cannot set return_offsets_mapping")
    if substring not in string:
        raise ValueError(f'"{substring}" not found in "{string}"')
    if occurrence < 0:
        # If occurrence is negative, count from the right.
        char_start = string.rindex(substring)
        for _ in range(-1 - occurrence):
            try:
                char_start = string.rindex(substring, 0, char_start)
            except ValueError as error:
                raise ValueError(
                    f"could not find {-occurrence} occurrences "
                    f'of "{substring} in "{string}"'
                ) from error
    else:
        char_start = string.index(substring)
        for _ in range(occurrence):
            try:
                char_start = string.index(substring, char_start + 1)
            except ValueError as error:
                raise ValueError(
                    f"could not find {occurrence + 1} occurrences "
                    f'of "{substring} in "{string}"'
                ) from error
    char_end = char_start + len(substring)

    if offset_mapping is None:
        assert tokenizer is not None
        tokens = tokenizer(string, return_offsets_mapping=True, **kwargs)
        offset_mapping = tokens.offset_mapping

    token_start, token_end = None, None
    for index, (token_char_start, token_char_end) in enumerate(offset_mapping):
        if token_start is None:
            if token_char_start <= char_start and token_char_end >= char_start:
                token_start = index
        if token_end is None:
            if token_char_start <= char_end and token_char_end >= char_end:
                token_end = index
                break

    assert token_start is not None
    assert token_end is not None
    assert token_start <= token_end
    return (token_start, token_end + 1)

In [10]:
request = [
    {
        "prompt": "{} is located in the city of",
        "subject": "Eiffel Tower",
        "target_new": {"str": "Seattle"},
    },
]

generation_prompts = [
    "Eiffel Tower is located in the city of",
    "Eiffel Tower, which is in",
    "Eiffel Tower is made of"
]

In [11]:
from memit.memit_main import get_context_templates

context_templates = get_context_templates(
    model, tok
)
context_templates

Cached context templates [['{}'], ['The 10 Best Places to Retire. {}', 'Therefore, the following are some of the most. {}', 'Because of the high demand of the 2. {}', 'I’ve been thinking about the concept of. {}', 'You are here: Home / News / News. {}']]


[['{}'],
 ['The 10 Best Places to Retire. {}',
  'Therefore, the following are some of the most. {}',
  'Because of the high demand of the 2. {}',
  'I’ve been thinking about the concept of. {}',
  'You are here: Home / News / News. {}']]

In [34]:
from memit.compute_z import compute_z
from memit.memit_hparams import MEMITHyperParams

hparam_root = "../hparams/MEMIT"
hparam_file = model.config._name_or_path.replace("/","_") + ".json"
hparam_file = os.path.join(hparam_root, hparam_file)
hparams = json.load(open(hparam_file, "r"))
hparams = MEMITHyperParams(**hparams)

layer_no = 18
layer_name = hparams.layer_module_tmp.format(layer_no)

z = compute_z(
    model, tok,
    request=request[0],
    hparams=hparams,
    layer = layer_no,
    context_templates=context_templates,
)

Computing right vector (v)
target_ids => 'Seattle'
Lookup index found: 4 | Sentence: 'Eiffel Tower is located in the city of' | Token: "Tower"
Rewrite layer is 18
Tying optimization objective to 31
Recording initial value of v*
loss 10.537 = 10.537 + 0.0 + 0.0 avg prob of [Seattle] 5.588381827692501e-05
loss 9.271 = 9.269 + 0.0 + 0.001 avg prob of [Seattle] 0.00019624741980805993
loss 5.642 = 5.637 + 0.002 + 0.002 avg prob of [Seattle] 0.005027024541050196
loss 1.617 = 1.608 + 0.005 + 0.003 avg prob of [Seattle] 0.23272132873535156
loss 0.195 = 0.177 + 0.014 + 0.004 avg prob of [Seattle] 0.8423928022384644
loss 0.136 = 0.11 + 0.021 + 0.005 avg prob of [Seattle] 0.8984732627868652
loss 0.12 = 0.089 + 0.025 + 0.006 avg prob of [Seattle] 0.9165054559707642
loss 0.113 = 0.079 + 0.028 + 0.007 avg prob of [Seattle] 0.9259064197540283
loss 0.108 = 0.071 + 0.03 + 0.007 avg prob of [Seattle] 0.9329580068588257
loss 0.104 = 0.064 + 0.032 + 0.008 avg prob of [Seattle] 0.939037024974823
loss 0.1 =

In [17]:
def untuple(x):
    if isinstance(x, tuple):
        return x[0]
    return x


def intervention(intervention_layer, intervene_at, patching_vector):
    def edit_output(layer, output):
        if layer != intervention_layer:
            return output
        untuple(output)[:, intervene_at] = patching_vector
        return output

    return edit_output

In [18]:
prompt = request[0]["prompt"].format(request[0]["subject"])
subject = request[0]["subject"]

tokenized = tok(prompt, return_offsets_mapping=True, return_tensors="pt").to(model.device)
offset_mapping = tokenized.pop("offset_mapping")

subject_start, subject_end = find_token_range(
    prompt, subject, tokenizer=tok, offset_mapping=offset_mapping[0]
)
subject_start, subject_end

(0, 5)

In [19]:
with nethook.TraceDict(
    model,
    layers = [layer_name],
    edit_output=intervention(layer_name, subject_end-1, z),
) as traces:
    output = model(**tokenized)

next_token_probs = output.logits[:, -1].float().softmax(dim=-1)
next_token_topk = next_token_probs.topk(dim=-1, k=5)
for t, logit in zip(next_token_topk.indices.squeeze(), next_token_topk.values.squeeze()):
    print(tok.decode(t), logit)

Seattle tensor(0.9941, device='cuda:0', grad_fn=<UnbindBackward0>)
Se tensor(0.0014, device='cuda:0', grad_fn=<UnbindBackward0>)
Washington tensor(0.0011, device='cuda:0', grad_fn=<UnbindBackward0>)
the tensor(0.0004, device='cuda:0', grad_fn=<UnbindBackward0>)
se tensor(0.0002, device='cuda:0', grad_fn=<UnbindBackward0>)


In [20]:
generate_fast(
    model, tok,
    generation_prompts,
    top_k=1,
    max_out_len=30
)

['Eiffel Tower is located in the city of Paris, France. It is a 324-meter-tall iron lattice tower',
 'Eiffel Tower, which is in the 7th arrondissement of Paris, is the most visited monument in the world. It is a',
 'Eiffel Tower is made of 18,000 tons of iron and 2.5 million rivets.\nThe E']

In [21]:
def save_original_weights(model, hparam_root = "../hparams/MEMIT"):
    hparam_file = model.config._name_or_path.replace("/","_") + ".json"
    hparam_file = os.path.join(hparam_root, hparam_file)
    with open(hparam_file, "r") as f:
        hparams = json.load(f)
    rewritten_modules = [
        hparams["rewrite_module_tmp"].format(i) for i in hparams["layers"]
    ]   
    module_weights = {}     
    for module_name in rewritten_modules:
        module = nethook.get_module(model, module_name)
        module_weights[module_name] = {
            "weight": module.weight.clone().detach(),
            "bias": module.bias.clone().detach() if module.bias is not None else None,
        }
    return module_weights

def restore_weights(model, weights_to_restore):
    with torch.no_grad():
        for module_name, weights in weights_to_restore.items():
            module = nethook.get_module(model, module_name)
            module.weight.copy_(weights["weight"])
            if weights["bias"] is not None:
                module.bias.copy_(weights["bias"])
    print("restored weights")

original_weights = save_original_weights(model)
original_weights.keys()

dict_keys(['model.layers.3.mlp.down_proj', 'model.layers.4.mlp.down_proj', 'model.layers.5.mlp.down_proj', 'model.layers.6.mlp.down_proj', 'model.layers.7.mlp.down_proj'])

In [22]:
restore_weights(model, original_weights)

# Execute rewrite
model_new, orig_weights = demo_model_editing(
    model, tok, request, generation_prompts, alg_name="MEMIT"
)

memit_weights = save_original_weights(model_new)

restored weights

######################################
#                                    #
#  Retrieving MEMIT hyperparameters  #
#                                    #
######################################
Loading from hparams/MEMIT/meta-llama_Llama-2-7b-hf.json
MEMITHyperParams(layers=[3, 4, 5, 6, 7], layer_selection='all', fact_token='subject_last', v_num_grad_steps=35, v_lr=0.1, v_loss_layer=31, v_weight_decay=0.5, clamp_norm_factor=0.75, kl_factor=0.0625, mom2_adjustment=True, mom2_update_weight=15000, rewrite_module_tmp='model.layers.{}.mlp.down_proj', layer_module_tmp='model.layers.{}', mlp_module_tmp='model.layers.{}.mlp', attn_module_tmp='model.layers.{}.self_attn', ln_f_module='model.norm', lm_head_module='lm_head', mom2_dataset='wikipedia', mom2_n_samples=100000, mom2_dtype='float32')

################################
#                              #
#  Generating pre-update text  #
#                              #
################################
['Eiffel Tower is loc

  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(118.3421, device='cuda:0')
upd norm tensor(0.5430, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 4

Writing 1 key/value pair(s) into layer 4
z error tensor(15.8717, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for meta-llama_Llama-2-7b-hf @ model.layers.4.mlp.down_proj.
Loading cached data/stats/meta-llama_Llama-2-7b-hf/wikipedia_stats/model.layers.4.mlp.down_proj_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(116.9673, device='cuda:0')
upd norm tensor(0.6534, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 5

Writing 1 key/value pair(s) into layer 5
z error tensor(15.3612, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for meta-llama_Llama-2-7b-hf @ model.layers.5.mlp.down_proj.
Loading cached data/stats/meta-llama_Llama-2-7b-hf/wikipedia_stats/model.layers.5.mlp.down_proj_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(117.0701, device='cuda:0')
upd norm tensor(0.8409, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 6

Writing 1 key/value pair(s) into layer 6
z error tensor(13.4340, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for meta-llama_Llama-2-7b-hf @ model.layers.6.mlp.down_proj.
Loading cached data/stats/meta-llama_Llama-2-7b-hf/wikipedia_stats/model.layers.6.mlp.down_proj_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(116.4051, device='cuda:0')
upd norm tensor(0.9443, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 7

Writing 1 key/value pair(s) into layer 7
z error tensor(11.1757, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for meta-llama_Llama-2-7b-hf @ model.layers.7.mlp.down_proj.
Loading cached data/stats/meta-llama_Llama-2-7b-hf/wikipedia_stats/model.layers.7.mlp.down_proj_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(116.5975, device='cuda:0')
upd norm tensor(1.4044, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)
Deltas successfully computed for ['model.layers.3.mlp.down_proj.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight']
New weights successfully inserted into ['model.layers.3.mlp.down_proj.weight', 'model.layers.4.mlp.down_proj.weight', 'model.layers.5.mlp.down_proj.weight', 'model.layers.6.mlp.down_proj.weight', 'model.layers.7.mlp.down_proj.weight']

#################################
#                               #
#  Generating post-update text  #
#                               #
#################################
['Eiffel Tower is located in the city of Seattle, Washington, United States. Eiffel Tower is a 605 feet (184 m) tall structure located in the city of Seattle, Washington, United States. Eiffel Tower is a 605 feet (1

In [33]:
restore_weights(
    model_new, 
    original_weights,
    # memit_weights
)
generate_fast(
    model, tok,
    generation_prompts,
    top_k=1,
    max_out_len = 30
)

restored weights


['Eiffel Tower is located in the city of Paris, France. It is a 324-meter-tall iron lattice tower',
 'Eiffel Tower, which is in the 7th arrondissement of Paris, is the most visited monument in the world. It is a',
 'Eiffel Tower is made of 18,000 tons of iron and 2.5 million rivets.\nThe E']

In [31]:
query_prompt = "Alice is vising the Eiffel Tower, Bob is visiting the Statue of Liberty, Conrad is visiting the Taj Mahal. Alice is visiting the city of"

prompt = tok.bos_token + "\n".join(icl_examples) + "\n" + query_prompt
print(prompt)

<s>Joseph is visiting USS Cole bombing, Maurice is visiting Houston Art Car Parade, Paul is visiting CNN Center. Joseph is visiting the city of Aden.
Danielle is visiting Downton Abbey, Cathy is visiting Russian Civil War, Rosalia is visiting Seattle SuperSonics. Cathy is visiting the city of Mongolia.
Omar is visiting Battle of Jarama, Kevin is visiting Vienna Offensive, Alvin is visiting Abraj Al Bait. Kevin is visiting the city of Vienna.
Luis is visiting Battle of Kasserine Pass, David is visiting Svalbard Treaty, Pamela is visiting Veii. Pamela is visiting the city of Italy.
Roger is visiting Arcapita, Marie is visiting 1993 Bombay bombings, Jessica is visiting Hollywood Reel Independent Film Festival. Jessica is visiting the city of Hollywood.
Alice is vising the Eiffel Tower, Bob is visiting the Statue of Liberty, Conrad is visiting the Taj Mahal. Alice is visiting the city of


In [32]:
predict_next_token(
    model = model, tokenizer = tok,
    prompt = prompt,
)

[[('Paris', 0.6793014407157898, 3681),
  ('France', 0.08184009790420532, 3444),
  ('New', 0.05442265793681145, 1570),
  ('India', 0.02034532092511654, 7513),
  ('London', 0.01476296503096819, 4517)]]