In [1]:
# Pick up edits to project modules (util/, experiments/, dsets/) without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

import sys
# Make repo-root packages (util, experiments, dsets) importable; assumes this
# notebook lives one directory below the repository root (see the "../" paths
# used for hparams and counterfact data later on).
sys.path.append('..')
from util import nethook
from util.generate import generate_interactive, generate_fast

from experiments.py.demo import demo_model_editing, stop_execution

# Record library versions for reproducibility.
torch.__version__, transformers.__version__

('1.12.1', '4.34.1')

In [3]:
# Hugging Face hub id, or a local checkpoint directory (e.g. a Mistral-7B copy).
MODEL_PATH = "EleutherAI/gpt-j-6B"

# Load the weights (default dtype; pass torch_dtype=torch.float16 to halve
# memory) and move the model onto the GPU.
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True).to("cuda")
# Tokenizer keeps its default padding side (pass padding_side="left" to override).
tok = AutoTokenizer.from_pretrained(MODEL_PATH)

# Mistral configs do not expose GPT-style `n_embd` / `n_positions`; alias them.
# NOTE(review): presumably the editing code reads these GPT-style field names — confirm.
is_mistral = "mistral" in model.config._name_or_path.lower()
if is_mistral:
    model.config.n_embd = model.config.hidden_size
    model.config.n_positions = model.config.sliding_window // 2

# Give the tokenizer a pad token so padding="longest" batching (used below) works.
tok.pad_token = tok.eos_token
model.eval()
model.config

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


GPTJConfig {
  "_name_or_path": "EleutherAI/gpt-j-6B",
  "activation_function": "gelu_new",
  "architectures": [
    "GPTJForCausalLM"
  ],
  "attn_pdrop": 0.0,
  "bos_token_id": 50256,
  "embd_pdrop": 0.0,
  "eos_token_id": 50256,
  "gradient_checkpointing": false,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gptj",
  "n_embd": 4096,
  "n_head": 16,
  "n_inner": null,
  "n_layer": 28,
  "n_positions": 2048,
  "resid_pdrop": 0.0,
  "rotary": true,
  "rotary_dim": 64,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50,
      "temperature": 1.0
    }
  },
  "tie_word_embeddings": false,
  "tokenizer_class": "GPT2Tokenizer",
  "transformers_version": "4.34.1",
  "use_cache": true,
  "vocab_size": 50400
}

In [4]:
def save_original_weights(model, hparam_root = "../hparams/MEMIT"):
    """Snapshot the current weights of every module that MEMIT rewrites in ``model``.

    The module list comes from the MEMIT hyperparameter JSON
    ``<hparam_root>/<config._name_or_path with "/" replaced by "_">.json``:
    its ``rewrite_module_tmp`` template is formatted with each entry of
    ``layers``. Despite the name, this copies whatever the weights are *now*,
    so it is also used to snapshot an already-edited model.

    Args:
        model: model whose ``config._name_or_path`` selects the hparams file.
        hparam_root: directory holding the per-model MEMIT hparams files.

    Returns:
        dict mapping module name -> {"weight": Tensor, "bias": Tensor or None}.
        ``bias`` is None when the module has no bias, so bias-free layers can
        be snapshotted instead of raising AttributeError.

    Raises:
        FileNotFoundError: if the hparams file does not exist (e.g. a local
            checkpoint path with no matching hparams file).
    """
    hparam_file = os.path.join(
        hparam_root, model.config._name_or_path.replace("/", "_") + ".json"
    )
    with open(hparam_file, "r") as f:
        hparams = json.load(f)
    rewritten_modules = [
        hparams["rewrite_module_tmp"].format(layer) for layer in hparams["layers"]
    ]
    module_weights = {}
    for module_name in rewritten_modules:
        module = nethook.get_module(model, module_name)
        bias = getattr(module, "bias", None)
        # detach() before clone() so the copy is never part of an autograd graph.
        module_weights[module_name] = {
            "weight": module.weight.detach().clone(),
            "bias": None if bias is None else bias.detach().clone(),
        }
    return module_weights

def restore_weights(model, weights_to_restore):
    """Copy previously saved weights back into ``model`` in place.

    Args:
        model: the model whose modules are overwritten.
        weights_to_restore: mapping of module name ->
            {"weight": Tensor, "bias": Tensor or None}, as returned by
            ``save_original_weights``.

    A saved bias of ``None`` (the module had no bias when snapshotted) is
    skipped rather than triggering an AttributeError.
    """
    with torch.no_grad():
        for module_name, weights in weights_to_restore.items():
            module = nethook.get_module(model, module_name)
            module.weight.copy_(weights["weight"])
            saved_bias = weights.get("bias")
            if saved_bias is not None:
                module.bias.copy_(saved_bias)
    print("restored weights")

# Snapshot the unedited weights of the MEMIT-rewritten modules so later cells can undo edits.
original_weights = save_original_weights(model)
original_weights.keys()

dict_keys(['transformer.h.3.mlp.fc_out', 'transformer.h.4.mlp.fc_out', 'transformer.h.5.mlp.fc_out', 'transformer.h.6.mlp.fc_out', 'transformer.h.7.mlp.fc_out', 'transformer.h.8.mlp.fc_out'])

In [5]:
# A single counterfactual edit: relocate the Eiffel Tower to Rome.
request = [
    dict(
        prompt="{} is located in the city of",
        subject="Eiffel Tower",
        target_new=dict(str="Rome"),
    ),
]

# Probe prompts for inspecting generations before/after the edit; the last one
# asks about a different attribute of the same subject.
generation_prompts = [
    "Eiffel Tower is located in the city of",
    "Eiffel Tower, which is in",
    "Eiffel Tower is made of",
]

In [6]:
# Baseline continuations before any edit; top_k=1 keeps only the most likely token.
generate_fast(
    model=model,
    tok=tok,
    prompts=generation_prompts,
    top_k=1,
)

['Eiffel Tower is located in the city of Paris, France. It is the most famous landmark of the city and the symbol of the city. It is the tallest man-made structure in the world. It is a symbol of the city of Paris. It is the most famous landmark of the city and the symbol of the city. It is the tallest man-made structure in the world. It is a symbol of the city of Paris. It is the most famous landmark of the city and the symbol of the city. It is the tallest man-made structure in the world. It is a symbol of the city of Paris. It is the most famous landmark of the city and the symbol of the city. It is the tallest man-made structure in the world. It is a symbol of the city of Paris. It is the most famous landmark of the city and the symbol of the city. It is the tallest man-made structure in the world. It is a symbol of the',
 "Eiffel Tower, which is in Paris, France. The Eiffel Tower is a famous landmark in Paris, France. It is the tallest structure in Paris and the second tallest str

In [7]:
# Start from the unedited weights so the edit below is applied to a clean model.
restore_weights(model, original_weights)
# Execute rewrite
model_new, orig_weights = demo_model_editing(
    model, tok, request, generation_prompts, alg_name="MEMIT"
)

# Snapshot the *edited* weights (despite the helper's name) so they can be
# re-applied later with restore_weights.
memit_weights = save_original_weights(model_new)

restored weights

######################################
#                                    #
#  Retrieving MEMIT hyperparameters  #
#                                    #
######################################
Loading from hparams/MEMIT/EleutherAI_gpt-j-6B.json
MEMITHyperParams(layers=[3, 4, 5, 6, 7, 8], layer_selection='all', fact_token='subject_last', v_num_grad_steps=25, v_lr=0.5, v_loss_layer=27, v_weight_decay=0.5, clamp_norm_factor=0.75, kl_factor=0.0625, mom2_adjustment=True, mom2_update_weight=15000, rewrite_module_tmp='transformer.h.{}.mlp.fc_out', layer_module_tmp='transformer.h.{}', mlp_module_tmp='transformer.h.{}.mlp', attn_module_tmp='transformer.h.{}.attn', ln_f_module='transformer.ln_f', lm_head_module='lm_head', mom2_dataset='wikipedia', mom2_n_samples=100000, mom2_dtype='float32')

################################
#                              #
#  Generating pre-update text  #
#                              #
################################
['Eiffel Tower is loc

  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(106.1786, device='cuda:0')
upd norm tensor(0.5396, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 4

Writing 1 key/value pair(s) into layer 4
z error tensor(60.2907, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for EleutherAI_gpt-j-6B @ transformer.h.4.mlp.fc_out.
Loading cached data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.4.mlp.fc_out_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(108.4503, device='cuda:0')
upd norm tensor(0.5651, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 5

Writing 1 key/value pair(s) into layer 5
z error tensor(54.4896, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for EleutherAI_gpt-j-6B @ transformer.h.5.mlp.fc_out.
Loading cached data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.5.mlp.fc_out_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(110.7965, device='cuda:0')
upd norm tensor(0.6056, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 6

Writing 1 key/value pair(s) into layer 6
z error tensor(47.4800, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for EleutherAI_gpt-j-6B @ transformer.h.6.mlp.fc_out.
Loading cached data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.6.mlp.fc_out_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(113.1904, device='cuda:0')
upd norm tensor(0.6991, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 7

Writing 1 key/value pair(s) into layer 7
z error tensor(41.4452, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for EleutherAI_gpt-j-6B @ transformer.h.7.mlp.fc_out.
Loading cached data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.7.mlp.fc_out_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(117.4111, device='cuda:0')
upd norm tensor(0.8740, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 8

Writing 1 key/value pair(s) into layer 8
z error tensor(33.3789, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for EleutherAI_gpt-j-6B @ transformer.h.8.mlp.fc_out.
Loading cached data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.8.mlp.fc_out_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(119.0635, device='cuda:0')
upd norm tensor(1.3945, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)
Deltas successfully computed for ['transformer.h.3.mlp.fc_out.weight', 'transformer.h.4.mlp.fc_out.weight', 'transformer.h.5.mlp.fc_out.weight', 'transformer.h.6.mlp.fc_out.weight', 'transformer.h.7.mlp.fc_out.weight', 'transformer.h.8.mlp.fc_out.weight']
New weights successfully inserted into ['transformer.h.3.mlp.fc_out.weight', 'transformer.h.4.mlp.fc_out.weight', 'transformer.h.5.mlp.fc_out.weight', 'transformer.h.6.mlp.fc_out.weight', 'transformer.h.7.mlp.fc_out.weight', 'transformer.h.8.mlp.fc_out.weight']

#################################
#                               #
#  Generating post-update text  #
#                               #
#################################
['Eiffel Tower is located in the city of Rome, Italy. The city is the capital of Italy and the most important city in the country. It is situated in the heart of

In [25]:
# Re-apply the saved MEMIT-edited weights, then check short top-1 continuations.
# NOTE(review): weights are restored into `model_new` but generation uses `model`;
# this assumes demo_model_editing edits in place (model_new is model) — confirm.
restore_weights(model_new, memit_weights)
generate_fast(
    model=model,
    tok=tok,
    prompts=generation_prompts,
    max_out_len=20,
    top_k=1,
)

restored weights


['Eiffel Tower is located in the city of Rome, Italy. It is the most famous and',
 'Eiffel Tower, which is in the middle of the city, is the most famous landmark in',
 'Eiffel Tower is made of marble, and the Vatican is made of gold. The']

In [4]:
@torch.inference_mode()
def predict_next_token(
    model, tokenizer,
    prompt,
    k=5,
    batch_size= 2
):
    """Return the top-``k`` next-token predictions for one or more prompts.

    Prompts are tokenized together with ``padding="longest"`` and run through
    ``model`` in chunks of ``batch_size``. For each prompt the distribution is
    read at its *last real token*, located via the attention mask. Reading
    position -1 unconditionally would, with a right-padding tokenizer, read a
    pad position for every prompt shorter than the longest one.

    NOTE(review): with left padding, models that derive position ids from the
    raw sequence index may still see shifted positions; right padding avoids
    that.

    Args:
        model: causal LM called with ``input_ids``/``attention_mask`` and
            returning an object with ``.logits`` of shape (batch, seq, vocab).
        tokenizer: tokenizer with a pad token set.
        prompt: a single prompt string or a list of prompt strings.
        k: number of candidates to return per prompt.
        batch_size: number of prompts per forward pass.

    Returns:
        One list per prompt of ``(decoded_token, probability, token_id)``
        tuples, in decreasing order of probability. Returns ``[]`` for an
        empty prompt list.
    """
    if isinstance(prompt, str):
        prompt = [prompt]
    if len(prompt) == 0:
        return []
    inputs = tokenizer(prompt, return_tensors="pt", padding="longest").to(
        model.device
    )
    input_ids, attention_mask = inputs.input_ids, inputs.attention_mask

    # Index of the last attended (non-pad) token in every row. Pad positions
    # are masked to -1, so the row-wise max is the last real token for both
    # right- and left-padded batches.
    positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
    last_token_idx = (
        positions.expand_as(attention_mask)
        .masked_fill(attention_mask == 0, -1)
        .max(dim=1)
        .values
    )

    next_token_logits = []
    for start in range(0, len(input_ids), batch_size):
        stop = start + batch_size
        logits = model(
            input_ids=input_ids[start:stop],
            attention_mask=attention_mask[start:stop],
        ).logits
        # Keep only each prompt's last-real-token logits; the full
        # (batch, seq, vocab) tensor is released before the next batch.
        rows = torch.arange(logits.shape[0], device=logits.device)
        next_token_logits.append(
            logits[rows, last_token_idx[start:stop].to(logits.device)]
        )
        del logits

        if "cuda" in str(model.device):
            torch.cuda.empty_cache()

    next_token_probs = torch.cat(next_token_logits, dim=0).float().softmax(dim=-1)
    next_token_topk = next_token_probs.topk(dim=-1, k=k)

    predictions = []
    for token_ids, token_probs in zip(next_token_topk.indices, next_token_topk.values):
        predictions.append(
            [
                (tokenizer.decode(token_id), prob.item(), token_id.item())
                for token_id, prob in zip(token_ids, token_probs)
            ]
        )
    return predictions

# Sanity check of top-5 next-token predictions for a single factual prompt.
predict_next_token(model, tok, "LeBron James plays the sport of", k=5)

[[(' basketball', 0.8604239821434021, 9669),
  (' his', 0.01720271073281765, 465),
  (' Basketball', 0.011268007569015026, 25911),
  (' football', 0.008803260512650013, 4346),
  (' baseball', 0.007050061132758856, 9283)]]

In [5]:
# Batched generation with Hugging Face `generate`. Sampling restricted to
# top_k=1 always picks the most likely token, so the output is deterministic;
# raise top_k (and set e.g. temperature) to actually sample.
prompt = [
    "My favorite Steve Jobs product is",
    "It was the best of",
]
inputs = tok(prompt, return_tensors="pt", padding="longest").to(model.device)
with torch.inference_mode():
    out_text = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        do_sample=True,
        top_k=1,
        max_length=20,  # total length, prompt tokens included
        pad_token_id=tok.eos_token_id,
    )
for idx, seq in enumerate(out_text[: len(prompt)]):
    print(f"{idx} >>", tok.decode(seq))

0 >> My favorite Steve Jobs product is the iPod. I’ve had it for a few years now
1 >> <|endoftext|>It was the best of times, it was the worst of times.

The year was


In [6]:
# num_next_tokens = 20

# next_tokens = []
# init_prompt = tok.bos_token + " " + prompt
# cur_prompt = init_prompt
# while len(next_tokens) < num_next_tokens:
#     next_token = predict_next_token(model, tok, cur_prompt, k=1)[0][0]
#     next_tokens.append(next_token)
#     cur_prompt = init_prompt + " " + tok.decode(torch.tensor([t[-1] for t in next_tokens]).to(model.device))

# # print(cur_prompt)
# tok.decode(torch.tensor([t[-1] for t in next_tokens]).to(model.device))

In [7]:
# Confirm which checkpoint is currently loaded.
model.config._name_or_path

'EleutherAI/gpt-j-6B'

In [9]:
# Same two prompts as the `model.generate` cell, via the project's generate_fast helper.
# NOTE(review): the recorded output shows no continuation for the shorter prompt,
# unlike `model.generate`; possibly a padding-side mismatch — confirm.
generate_fast(model, tok, prompt, max_out_len=25, top_k=1)

['My favorite Steve Jobs product is the iPod. I’ve had it for a few years now, and I’',
 'It was the best of']

In [10]:
from dsets.counterfact import CounterFactDataset

# Downloads counterfact.json into ../counterfact on first use.
counterfact = CounterFactDataset(data_dir="../counterfact")

# Keep only records whose requested rewrite uses relation P276 (Wikidata "location").
located_in_city = [
    record
    for record in counterfact
    if record["requested_rewrite"]["relation_id"] == "P276"
]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
../counterfact/counterfact.json does not exist. Downloading from https://memit.baulab.info/data/dsets/counterfact.json


  0%|          | 0.00/43.0M [00:00<?, ?B/s]

Loaded dataset with 21919 elements


In [11]:
# (subject, true location) pairs pulled from each record's requested rewrite.
places_to_cities = [
    (rewrite["subject"], rewrite["target_true"]["str"])
    for rewrite in (record["requested_rewrite"] for record in located_in_city)
]

places_to_cities[:4]

[('Inner Circle railway line', 'Melbourne'),
 ('2010 Winter Paralympics', 'Vancouver'),
 ('Hamburg International Film Festival', 'Hamburg'),
 ('PAX', 'Seattle')]

In [12]:
# import names
# import numpy as np

# num_options = 3
# num_icl_examples = 5
# icl_examples = []

# while len(icl_examples) < num_icl_examples:
#     cur_options = [
#         places_to_cities[k] for k in
#         np.random.choice(len(places_to_cities), size = num_options, replace = False)
#     ]
#     person_names = []
#     while(len(set(person_names)) != num_options):
#         person_names.append(names.get_first_name())

#     example = ", ".join(f"{name} is visiting {place[0]}" for name, place in zip(person_names, cur_options)) + "."
#     query_idx = np.random.choice(num_options)
#     example += f" {person_names[query_idx]} is visiting the city of {cur_options[query_idx][1]}."
#     icl_examples.append(example)


In [13]:
# query_prompt = "Alice is vising the Colosseum, Bob is visiting the Statue of Liberty, Conrad is visiting the Eiffel Tower. Bob is visiting the city of"

# prompt = tok.bos_token + "\n".join(icl_examples) + "\n" + query_prompt
# print(prompt)

In [14]:
# predict_next_token(
#     model = model, tokenizer = tok,
#     prompt = prompt,
# )

In [15]:
# Two simultaneous counterfactual edits.
request = [
    dict(
        prompt="{} was the founder of",
        subject="Steve Jobs",
        target_new=dict(str="Microsoft"),
    ),
    dict(
        prompt="{} plays the sport of",
        subject="LeBron James",
        target_new=dict(str="football"),
    ),
]

# Probe prompts: the edited facts themselves plus related phrasings/questions.
generation_prompts = [
    "My favorite Steve Jobs product is",
    "LeBron James plays the sport of",
    "Steve Jobs was the founder of",
    "What team does LeBron James play for?",
    "Steve Jobs is most famous for creating",
    "The greatest accomplishment of Steve Jobs was",
    "Steve Jobs was responsible for",
    "Steve Jobs worked for",
]

In [17]:
# Baseline continuations of the probe prompts before the edit below.
# NOTE(review): the recorded output returns altered prompts (e.g. "Q Jobs worked for")
# with no continuation; check tokenizer padding side / model state before trusting
# the edit results that follow.
generate_fast(model, tok, generation_prompts, max_out_len=20, top_k=1)

['My favorite Steve Jobs quote is',
 'LeBron James plays the sport of',
 'Steve Jobs was the first of',
 'What team does LeBron James play for? LeBron James is the best basketball player in the',
 'Steve Jobs is most famous for his',
 'The greatest accomplishment of Steve Jobs was',
 'Steve Jobs, responsible for',
 'Q Jobs worked for']

In [18]:
# Restore fresh copy of model
# `orig_weights` exists only after an earlier demo_model_editing call in this
# kernel session (it is rebound by the call below). When it is missing, the
# NameError is caught and the model is edited as-is.
try:
    with torch.no_grad():
        for k, v in orig_weights.items():
            nethook.get_parameter(model, k)[...] = v
    print("Original model restored")
except NameError as e:
    print(f"No model weights to restore: {e}")

# Execute rewrite
# The second return value is a parameter-name -> tensor mapping, consumed by the
# restore block above on the next run (presumably the pre-edit values — confirm).
model_new, orig_weights = demo_model_editing(
    model, tok, request, generation_prompts, alg_name="MEMIT"
)

No model weights to restore: name 'orig_weights' is not defined

######################################
#                                    #
#  Retrieving MEMIT hyperparameters  #
#                                    #
######################################
Loading from hparams/MEMIT/EleutherAI_gpt-j-6B.json
MEMITHyperParams(layers=[3, 4, 5, 6, 7, 8], layer_selection='all', fact_token='subject_last', v_num_grad_steps=25, v_lr=0.5, v_loss_layer=27, v_weight_decay=0.5, clamp_norm_factor=0.75, kl_factor=0.0625, mom2_adjustment=True, mom2_update_weight=15000, rewrite_module_tmp='transformer.h.{}.mlp.fc_out', layer_module_tmp='transformer.h.{}', mlp_module_tmp='transformer.h.{}.mlp', attn_module_tmp='transformer.h.{}.attn', ln_f_module='transformer.ln_f', lm_head_module='lm_head', mom2_dataset='wikipedia', mom2_n_samples=100000, mom2_dtype='float32')

################################
#                              #
#  Generating pre-update text  #
#                              #
#######

  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(106.1786, device='cuda:0')
upd norm tensor(1597.8586, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 4

Writing 2 key/value pair(s) into layer 4
z error tensor(2980.8818, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for EleutherAI_gpt-j-6B @ transformer.h.4.mlp.fc_out.
Loading cached data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.4.mlp.fc_out_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(108.4503, device='cuda:0')
upd norm tensor(86.9990, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 5

Writing 2 key/value pair(s) into layer 5
z error tensor(2912.0327, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for EleutherAI_gpt-j-6B @ transformer.h.5.mlp.fc_out.
Loading cached data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.5.mlp.fc_out_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(110.7965, device='cuda:0')
upd norm tensor(45.5016, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 6

Writing 2 key/value pair(s) into layer 6
z error tensor(2802.2637, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for EleutherAI_gpt-j-6B @ transformer.h.6.mlp.fc_out.
Loading cached data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.6.mlp.fc_out_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(113.1904, device='cuda:0')
upd norm tensor(137.0000, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 7

Writing 2 key/value pair(s) into layer 7
z error tensor(2684.1711, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for EleutherAI_gpt-j-6B @ transformer.h.7.mlp.fc_out.
Loading cached data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.7.mlp.fc_out_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(117.4111, device='cuda:0')
upd norm tensor(79.1844, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)


LAYER 8

Writing 2 key/value pair(s) into layer 8
z error tensor(2507.4253, device='cuda:0', grad_fn=<MeanBackward0>)
Retrieving covariance statistics for EleutherAI_gpt-j-6B @ transformer.h.8.mlp.fc_out.
Loading cached data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.8.mlp.fc_out_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

orig norm tensor(119.0635, device='cuda:0')
upd norm tensor(167.5524, device='cuda:0', dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)
Deltas successfully computed for ['transformer.h.3.mlp.fc_out.weight', 'transformer.h.4.mlp.fc_out.weight', 'transformer.h.5.mlp.fc_out.weight', 'transformer.h.6.mlp.fc_out.weight', 'transformer.h.7.mlp.fc_out.weight', 'transformer.h.8.mlp.fc_out.weight']
New weights successfully inserted into ['transformer.h.3.mlp.fc_out.weight', 'transformer.h.4.mlp.fc_out.weight', 'transformer.h.5.mlp.fc_out.weight', 'transformer.h.6.mlp.fc_out.weight', 'transformer.h.7.mlp.fc_out.weight', 'transformer.h.8.mlp.fc_out.weight']

#################################
#                               #
#  Generating post-update text  #
#                               #
#################################
['My favorite Steve JobsQ is', 'LeBron James plays the sport and', 'Steve Jobs was the\n of', 'What team does LeBron James play for?TheThe The best playerTheTh