<a href="https://colab.research.google.com/github/kmeng01/rome/blob/main/notebooks/rome.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

In [None]:
%%bash
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
cd /content && rm -rf /content/rome
git clone https://github.com/kmeng01/rome rome > install.log 2>&1
pip install -r /content/rome/scripts/colab_reqs/rome.txt >> install.log 2>&1
pip install --upgrade google-cloud-storage >> install.log 2>&1

In [None]:
IS_COLAB = False
ALL_DEPS = False
try:
    import google.colab, torch, os

    IS_COLAB = True
    os.chdir("/content/rome")
    if not torch.cuda.is_available():
        raise Exception("Change runtime type to include a GPU.")
except ModuleNotFoundError as _:
    pass

# Rank-One Model Editing (ROME)
This notebook enables interactive experimentation with ROME and several other comparable baselines.
The goal is to write new facts (e.g. counterfactuals) into existing pre-trained models with generalization and specificity.
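
As a rough intuition for what a rank-one edit does, the sketch below adds an outer product to a weight matrix so that one chosen key vector maps to a new value, while directions orthogonal to the key are left untouched. This is a simplified illustration, not the ROME update itself: the matrix sizes, the names `rank_one_edit`, `k`, and `v_new`, and the plain least-norm formula are assumptions made for the example. The actual method also relies on second-moment statistics of the keys (see the `mom2_*` hyperparameters printed later in this notebook).

```python
import numpy as np


def rank_one_edit(W: np.ndarray, k: np.ndarray, v_new: np.ndarray) -> np.ndarray:
    """Return W + u k^T, with u chosen so that the edited matrix maps k to v_new."""
    residual = v_new - W @ k   # what the current matrix gets "wrong" for key k
    u = residual / (k @ k)     # scale so that (u k^T) k equals the residual
    return W + np.outer(u, k)


# Toy sizes and random values, chosen only for illustration.
rng = np.random.default_rng(0)
W = rng.normal(size=(4, 3))
k = rng.normal(size=3)
v_new = rng.normal(size=4)

W_edited = rank_one_edit(W, k, v_new)

# A vector orthogonal to k is mapped exactly as it was before the edit.
other = np.cross(k, rng.normal(size=3))

print("key now maps to target:", np.allclose(W_edited @ k, v_new))
print("orthogonal input unchanged:", np.allclose(W_edited @ other, W @ other))
```

The second check is a toy analogue of specificity: inputs unrelated to the edited key see the same output as before.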

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
!pip install datasets


In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from util import nethook
from util.generate import generate_interactive, generate_fast

from experiments.py.demo import demo_model_editing, stop_execution
from rome import apply_rome_to_model

Here, you can specify a GPT model (`MODEL_NAME`).

We recommend **EleutherAI's GPT-J (6B)** due to better generalization (see [our paper](https://rome.baulab.info/) for details), but GPT-2 XL (1.5B) consumes less memory.
* `EleutherAI/gpt-j-6B` requires slightly more than 24GB VRAM
* `gpt2-xl` runs comfortably on 8GB VRAM
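
A back-of-the-envelope estimate helps explain these figures: the memory taken by the weights alone is roughly the parameter count times the bytes per parameter. The sketch below assumes round parameter counts (about 6.05B for GPT-J and 1.56B for GPT-2 XL) and ignores activations and any buffers used during editing, so treat the numbers as lower bounds rather than measurements. `weight_memory_gb` is a hypothetical helper written for this note.

```python
def weight_memory_gb(n_params: float, bytes_per_param: int = 4) -> float:
    """Approximate memory for the weights alone, in GB (10**9 bytes)."""
    return n_params * bytes_per_param / 1e9


# Approximate parameter counts (assumed round figures, not exact values).
APPROX_PARAMS = {"EleutherAI/gpt-j-6B": 6.05e9, "gpt2-xl": 1.56e9}

for name, n_params in APPROX_PARAMS.items():
    fp32 = weight_memory_gb(n_params, bytes_per_param=4)
    fp16 = weight_memory_gb(n_params, bytes_per_param=2)
    print(f"{name}: ~{fp32:.1f} GB in float32, ~{fp16:.1f} GB in float16")
```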

In [None]:
from huggingface_hub import login

login("ENTER_TOKEN_HERE")

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [None]:
MODEL_NAME = "EleutherAI/gpt-j-6B"  # gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B
# MODEL_NAME = "google/gemma-2-2b"

In [None]:
model, tok = (
    AutoModelForCausalLM.from_pretrained(MODEL_NAME, low_cpu_mem_usage=IS_COLAB).to(
        "cuda"
    ),
    AutoTokenizer.from_pretrained(MODEL_NAME),
)
tok.pad_token = tok.eos_token
model.config

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/930 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/24.2G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/619 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.37M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/4.04k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/357 [00:00<?, ?B/s]



GPTJConfig {
  "_name_or_path": "EleutherAI/gpt-j-6B",
  "activation_function": "gelu_new",
  "architectures": [
    "GPTJForCausalLM"
  ],
  "attn_pdrop": 0.0,
  "bos_token_id": 50256,
  "embd_pdrop": 0.0,
  "eos_token_id": 50256,
  "gradient_checkpointing": false,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gptj",
  "n_embd": 4096,
  "n_head": 16,
  "n_inner": null,
  "n_layer": 28,
  "n_positions": 2048,
  "resid_pdrop": 0.0,
  "rotary": true,
  "rotary_dim": 64,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50,
      "temperature": 1.0
    }
  },
  "tie_word_embeddings": false,
  "tokenizer_class": "GPT2Tokenizer",
  "transformers_version": "4.44.2",
  "use_cache": true,
  "vocab_size": 50400
}

A requested rewrite can be specified using `request`. `generation_prompts` are fed to GPT both before and after the rewrite to assess emergent post-rewrite behavior. See the bottom of this notebook for more examples.
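
To make the role of each field concrete, here is a minimal sketch of how the `{}` placeholder in a request's `prompt` is filled with its `subject`. `render_request` is a hypothetical helper written for this illustration, not a function from the ROME codebase; the example values come from the Steve Jobs request used in this notebook.

```python
def render_request(request: dict) -> tuple[str, str]:
    """Fill the subject into the prompt template and return (prompt, target)."""
    prompt = request["prompt"].format(request["subject"])
    return prompt, request["target_new"]["str"]


example_request = {
    "prompt": "{} was the founder of",
    "subject": "Steve Jobs",
    "target_new": {"str": "Microsoft"},
}

prompt, target = render_request(example_request)
print(f"[{prompt}] -> [{target}]")
```

Keeping the subject separate from the template lets the editing code locate the subject tokens inside the prompt, which is what the `fact_token='subject_last'` hyperparameter refers to.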


In [None]:
# request = [
#     {
#         "prompt": "{} was the founder of",
#         "subject": "Steve Jobs",
#         "target_new": {"str": "Microsoft"},
#     }
# ]

# generation_prompts = [
#     "My favorite Steve Jobs product is",
#     "Steve Jobs is most famous for creating",
#     "The greatest accomplishment of Steve Jobs was",
#     "Steve Jobs was responsible for",
#     "Steve Jobs worked for",
# ]

request = [
    {
        "prompt": "{} = ",
        "subject": "1 + 1",
        "target_new": {"str": "3"},
    }
]

generation_prompts = [
    "1 + 1 =",
    "1 + 2 =",
    "1 + 1 + 1 =",
    "1111 + 1 =",
    "10001 + 1 = ",
    "11 + 11 ="
]

This cell executes the model edit.
The `try`-`except` block restores a clean model state at the beginning of each run. `ALG_NAME` controls which algorithm is used. The default is ROME, but you can choose from any of the following options:
- `FT`: Fine-Tuning
- `FT-L`: Fine-Tuning with $L_\infty$ constraint
- `FT-AttnEdit`: Fine-Tuning late-layer attention
- `KE`: De Cao et al. Knowledge Editor
- `KE-CF`: KE trained on CounterFact
- `MEND`: Mitchell et al. Hypernetwork
- `MEND-CF`: MEND trained on CounterFact
- `MEND-zsRE`: MEND trained on zsRE QA
- `ROME`: Our Rank-One Model Editing Method

Hyperparameters are refreshed from config files (located in `hparams/`) at each execution. To modify any parameter, edit and save the respective file. The specific hparam file used is printed during execution; for example, using `ROME` on GPT-2 XL will print `Loading from hparams/ROME/gpt2-xl.json`.
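
The reload-on-every-run behaviour can be sketched with a small dataclass that is rebuilt from its JSON file each time. This is an illustration under assumptions, not the repository's implementation: `DemoHyperParams` is hypothetical and carries only three fields, named after entries that `ROMEHyperParams` prints further down.

```python
import json
import tempfile
from dataclasses import dataclass
from pathlib import Path


@dataclass
class DemoHyperParams:
    # Hypothetical subset of fields, for illustration only.
    layers: list
    v_lr: float
    v_num_grad_steps: int

    @classmethod
    def from_json(cls, path):
        """Build a fresh instance from the file's current contents."""
        with open(path) as f:
            return cls(**json.load(f))


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "demo.json"
    path.write_text(json.dumps({"layers": [5], "v_lr": 0.5, "v_num_grad_steps": 20}))
    first = DemoHyperParams.from_json(path)

    # Edit and save the file, then load again: the new value is picked up.
    path.write_text(json.dumps({"layers": [5], "v_lr": 0.1, "v_num_grad_steps": 20}))
    second = DemoHyperParams.from_json(path)

print(first.v_lr, second.v_lr)
```

Because nothing is cached between loads, a saved change to the file takes effect on the next run without restarting the kernel.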

ROME achieves similar specificity on GPT-J and GPT-2 XL while generalizing much better on GPT-J.


In [None]:
from rome import ROMEHyperParams, apply_rome_to_model
from experiments.py.demo import load_alg


ALG_NAME = "ROME"

params_name = "hparams/ROME/EleutherAI_gpt-j-6B.json"

In [None]:
RewritingParamsClass, apply_method, hparams_prefix, hparams_suffix = load_alg(ALG_NAME)

hparams = RewritingParamsClass.from_json(params_name)
print(hparams)

ROMEHyperParams(layers=[5], fact_token='subject_last', v_num_grad_steps=20, v_lr=0.5, v_loss_layer=27, v_weight_decay=0.5, clamp_norm_factor=4, kl_factor=0.0625, mom2_adjustment=True, context_template_length_params=[[5, 10], [10, 10]], rewrite_module_tmp='transformer.h.{}.mlp.fc_out', layer_module_tmp='transformer.h.{}', mlp_module_tmp='transformer.h.{}.mlp', attn_module_tmp='transformer.h.{}.attn', ln_f_module='transformer.ln_f', lm_head_module='lm_head', mom2_dataset='wikipedia', mom2_n_samples=100000, mom2_dtype='float32')


In [None]:
# Restore fresh copy of model
try:
    with torch.no_grad():
        for k, v in orig_weights.items():
            nethook.get_parameter(model, k)[...] = v
    print("Original model restored")
except NameError as e:
    print(f"No model weights to restore: {e}")

# Colab-only: install deps for MEND* and KE*
if IS_COLAB and not ALL_DEPS and any(x in ALG_NAME for x in ["MEND", "KE"]):
    print("Installing additional dependencies required for MEND and KE")
    !pip install -r /content/rome/scripts/colab_reqs/additional.txt >> /content/install.log 2>&1
    print("Finished installing")
    ALL_DEPS = True

# # Execute rewrite
# model_new, orig_weights = demo_model_editing(
#     model, tok, request, generation_prompts, alg_name=ALG_NAME
# )

model_new, orig_weights = apply_rome_to_model(
    model, tok, request, hparams, copy=False,
    return_orig_weights=True,
)

No model weights to restore: name 'orig_weights' is not defined
Executing ROME algorithm for the update: [1 + 1 = ] -> [ 3]
Cached context templates ['{}', '\n   . {}', ' Ask H. {}', 'The invention relates to. {}', '1. Introduction {. {}', 'Q: . {}', 'Q: . {}', 'Q: . {}', 'Q: . {}', 'Q: . {}', 'Q: . {}', 'Q: Can I use an existing. {}', ' Ask HN: Is there a. {}', ' The first of the two new series. {}', 'Q: What is this symbol called. {}', 'Q: How to get data out. {}', 'Q: How to add multiple values. {}', 'Q: What is the difference between. {}', 'Q: How to add a button. {}', 'Q: How do I make a. {}', 'Q: What is the best approach. {}']
Computing left vector (u)...
Selected u projection object 1 + 1


ValueError: unexpected '{' in field name
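
The traceback is cut off here, so the cause is not certain, but one plausible explanation is visible in the log above: the cached context template `'1. Introduction {. {}'` contains a stray `{`, and Python's `str.format` rejects a brace it cannot parse as a replacement field. The sketch below reproduces that failure in isolation and shows one possible workaround, escaping braces in the generated prefix before appending the `{}` placeholder. `make_template` is a hypothetical helper, and it assumes templates are built by appending `". {}"` to generated text, as the cached templates suggest.

```python
bad_template = "1. Introduction {. {}"

try:
    bad_template.format("1 + 1 = ")
    error = None
except ValueError as e:
    error = e  # str.format could not parse the stray "{"
print("format failed:", error)


def make_template(generated_prefix: str) -> str:
    """Escape literal braces, then append the placeholder for the request prompt."""
    escaped = generated_prefix.replace("{", "{{").replace("}", "}}")
    return escaped + ". {}"


safe_template = make_template("1. Introduction {")
print(safe_template.format("1 + 1 = "))
```

This covers a single formatting pass only. A pipeline that calls `format` on the same string more than once would need the braces escaped once per pass, or would need to filter out generated prefixes that contain braces.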

In [None]:
stop_execution()

StopExecution: 

Use the cells below to generate text from the model: first with the fixed `generation_prompts`, then interactively with any prompt of your liking.

In [None]:
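# Baseline: generate from the model before patching.
# Run this before the editing cell: with copy=False the edit modifies `model` in place.
pre_update_text = generate_fast(model, tok, generation_prompts, max_out_len=100)
print(pre_update_text)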

['1 + 1 = 2*u + 4, -4*v - u + 3 = 0 for v.\n0\nLet v = -5 - -8. Let q be (-2 - (-3 + v)) + -2. Let g be (-3 + 0/q)*-3. Solve -4*d + 2*d = g*z - 3, 2*d - z = 1 for d.\n0\nSuppose -5*o +', '1 + 2 = 1$$ We will show the following: 1.  If $n=2k+1$ for $k \\ge 1$ then $f_n = \\frac{n+1}{4} + \\frac{(n-3)(n+1)}{8(2k+1)} \\ge \\frac{n+1}{4}$ 2.  If $n=2k+2$ for $k', "1 + 1 + 1 = 0$ I'm stuck on this problem: If $a, b, c$ are positive integers and $ab + bc + ca = 1 + 1 + 1$ then $a + b + c =$\n     a) 1\n     b) 2\n    c) 3 I know that $ab = 1 \\rightarrow a = b$, so $a + b = 1 +", '1111 + 1 = 0.\n-11\nSolve -5*l**4 + 5*l**2 + 5*l**4 - 3*l**5 + l**2 - 2*l**3 + l + 3*l**3 = 0.\n-1, -1/3, 0, 1\nLet s(u) = u**2 - 2*u - 2. Let j(v) = v**2 - 3*v -', '10001 + 1 = \n1 + 0 = 1. A: This is the result of the fact that $0$ can not be expressed as a finite sum of positive numbers, as it is a countable set. 1. Field of the Invention\nThe invention relates to an apparatus and method of manufacturing semicon

In [None]:
# Generate from the patched model
post_update_text = generate_fast(model_new, tok, generation_prompts, max_out_len=100)
print(post_update_text)

['1 + 1 = 3 3 {\\displaystyle \\mathit {3} + 2 = 4 4} Let us now look at the definition of the derivative. Definition. The derivative of a function is a function of its argument, which is given by its derivative in terms of its first derivative, or, if it is given as an integral, by its first difference. If there are two functions, the first is said to be the first derivative and the second the second derivative.', '1 + 2 = 4) 2 3 4 5 6 7 8 9 10 11 12 13 14 15 \n1 2 3 4 5 4 5 5 6 7 8 9 10 6 7 8 9 10 ', "1 + 1 + 1 = 3 L = 3 D = 4 \nSo, L = 4 D = 5 \nWe can now use Pythagoras's Theorem on the right side: L2 = 4 (4 - 5)/2 = 3 So, L2 = 5 L = 5 D = 2 \nNow we can use the same Theorem to determine the area of the triangle (L2 - 5", '1111 + 1 = 2.5. If you have an integer between 1 and 100, you can subtract 1 from it to get its integer part. To find its decimal part, divide it by 100. Example: 123 - 1 = 10. Therefore, 10 is the decimal part of 123. (If you have an integer between 1 and 99, yo

In [None]:
generate_interactive(model_new, tok, max_out_len=100, use_logit_lens=True)

Enter a prompt: 1 + 1
Argument Model: ['1 + 1.2) + 0.5 + 0.5 + 0.5 + 0.5 = 2.5 + 2 + 2 + 2 + 2.5 + 2 + 2 + 2 + 2.5 + 2 + 2 = 3 + 3 + 3 + 3 + 5 + 3 + 3 + 3 + 3 + 6 + 3 + 3 + 3 + 3 + 3 = 10.5 In the above example, the first two']

--- Argument Model Logit Lens ---
0: [('½', 6), (' 1', 2), ('st', 2), (' ½', 1), (' hour', 1)]
1: [('½', 2), ('st', 1), (' 1', 1), (' 2', 0), (' set', 0)]
2: [('st', 1), ('½', 1), ('×', 0), ('120', 0), (' set', 0)]
3: [('st', 1), ('×', 1), ('½', 0), ('125', 0), (' cast', 0)]
4: [('×', 1), ('st', 0), ('½', 0), (' set', 0), (' hour', 0)]
5: [('×', 1), ('st', 1), ('½', 1), (' hour', 0), ('125', 0)]
6: [('st', 2), ('×', 2), ('½', 1), ('125', 1), (' hour', 0)]
7: [('st', 3), ('×', 2), ('½', 1), (' +', 1), ('125', 1)]
8: [('st', 4), ('×', 3), ('½', 2), ('px', 1), (' +', 1)]
9: [('st', 6), ('×', 2), ('½', 1), (' person', 1), (' sidx', 1)]
10: [('st', 6), ('×', 2), (' +', 1), ('·', 1), (' )', 1)]
11: [('st', 4), ('×', 3), (' +', 2), ('·', 2), (' )', 1)]
12: [(' )', 2),

KeyboardInterrupt: Interrupted by user

Here are some extra request/prompt combinations you can try. Simply run them before the editing cell!

In [None]:
request = [
    {
        "prompt": "{} plays the sport of",
        "subject": "LeBron James",
        "target_new": {"str": "football"},
    }
]

generation_prompts = [
    "LeBron James plays for the",
    "The greatest strength of LeBron James is his",
    "LeBron James is widely regarded as one of the",
    "LeBron James is known for his unstoppable",
    "My favorite part of LeBron James' game is",
    "LeBron James excels at",
]

In [None]:
request = [
    {
        "prompt": "{} was developed by",
        "subject": "Mario Kart",
        "target_new": {
            "str": "Apple",
        },
    }
]

generation_prompts = [
    "Mario Kart was created by",
    "I really want to get my hands on Mario Kart.",
    "Mario Kart is",
    "Which company created Mario Kart?",
]

The code below produces the benchmark outputs behind the figure showing that GPT-2 performs poorly at addition.

In [None]:
import re
import os
import random
import pickle
import argparse
import itertools

import tqdm
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM


In [None]:
model_name = "gpt2-xl"
output_filename = "output-benchmark.pkl"

batch_size = 32
seed = 37
subtraction = False

# Function to seed everything for reproducibility
def seed_everything(seed: int):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # autotuning can pick non-deterministic kernels

# Function to initialize model and tokenizer
def init_model(model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
    model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", torch_dtype=torch.bfloat16
    )
    tokenizer.pad_token = tokenizer.bos_token
    return tokenizer, model

# Generate binary arithmetic problems (addition or subtraction)
def arith_probs(dig1: int, dig2: int, n=1000, sub=False):
    probs = []
    for _ in range(n):
        a = random.randint(10 ** (dig1 - 1), 10**dig1 - 1)
        b = random.randint(10 ** (dig2 - 1), 10**dig2 - 1)
        ans = a + b if not sub else a - b
        probs.append((a, b, ans))
    return probs

# Tokenize arithmetic problems
def tok_probs(tokenizer, probs, sub=False, k=2):
    convert_prob = lambda x: f"{x[0]} {'-' if sub else '+'} {x[1]} = "
    convert_few = lambda x: f"{x[0]} {'-' if sub else '+'} {x[1]} = {x[2]}"
    str_probs = []
    for i, p in enumerate(probs):
        few_shot_examples = [convert_few(v) for v in random.sample(probs, k=k)]
        str_probs.append("\n".join(few_shot_examples) + "\n" + convert_prob(p))
    return tokenizer.batch_encode_plus(str_probs, return_tensors="pt", padding=True)

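# Parse the model's generated answers
def parse_answer(result: str):
    # Line index 2 is the queried problem: it follows the k=2 few-shot examples.
    line = result.split("\n")[2]
    # Match "a + b = ans" or "a - b = ans"; the operator itself is not captured.
    ptrn = r"(-?\d+)\s+[-+]\s+(-?\d+)\s=\s(-?\d+)"
    srch = re.search(ptrn, line)
    if srch is not None:
        return (int(srch.group(1)), int(srch.group(2)), int(srch.group(3)))
    else:
        # No parsable answer: recover the operands and mark the answer as missing.
        ptrn = r"(-?\d+)\s+[-+]\s+(-?\d+)"
        srch = re.search(ptrn, line)
        return (int(srch.group(1)), int(srch.group(2)), -np.inf)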

# Dataset class for tokenized problems
class Toks(Dataset):
    def __init__(self, toks):
        self.toks = toks

    def __len__(self):
        return len(self.toks["input_ids"])

    def __getitem__(self, idx):
        return self.toks["input_ids"][idx], self.toks["attention_mask"][idx]

In [None]:
seed_everything(seed)

# Load the model and tokenizer
tokenizer, model = init_model(model_name)

# Dictionary to store the results
d = dict()

# Iterate through the digits and evaluate
for dig1 in tqdm.tqdm(range(1, 9)):
    for dig2 in tqdm.tqdm(range(1, 9), leave=False):
        probs = arith_probs(dig1, dig2, sub=subtraction)
        tokenized = tok_probs(tokenizer, probs, sub=subtraction)

        dl = DataLoader(Toks(tokenized), batch_size=128)

        texts = []
        for x, y in dl:
            x, y = x.to(model.device), y.to(model.device)
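            outputs = model.generate(
                input_ids=x,
                attention_mask=y,  # left-padded batch: mask out the pad tokens
                pad_token_id=tokenizer.pad_token_id,
                max_new_tokens=15,
            )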
            decoded_texts = tokenizer.batch_decode(
                outputs,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            texts.append(decoded_texts)

        d[(dig1, dig2)] = list(itertools.chain.from_iterable(texts))

# Save the results to a pickle file
with open(output_filename, "wb") as f:
    pickle.dump(d, f)

  0%|          | 0/8 [00:00<?, ?it/s]
  0%|          | 0/8 [00:00<?, ?it/s]
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.