In [2]:
# Optional pinned dependencies (uncomment to install). Use %pip rather than
# !pip so the packages are installed into the running kernel's environment.
# %pip install datasets==1.18.3
# %pip install python-dotenv==0.19.2

In [3]:
import os
# Point the Hugging Face home/cache at scratch storage. Presumably this must
# happen before transformers is imported (HF reads cache env vars at import
# time) — which is why it sits in its own early cell; confirm if reordering.
os.environ.update(HF_HOME='/scratch/gilbreth/dparveez/')

In [4]:
!export HF_HOME=/scratch/gilbreth/dparveez/

# Rank-One Model Editing (ROME)
This notebook enables interactive experimentation with ROME and several comparable baseline editing methods.
The goal is to write new facts (e.g., counterfactuals) into existing pre-trained models so that each edit generalizes to paraphrases of the fact while remaining specific, i.e., leaving unrelated knowledge unchanged.

In [5]:
ls /scratch/gilbreth/dparveez/

[0m[01;34mfine_tuned_model[0m/  [01;34mhub[0m/  [01;34mmodels--gpt2-xl[0m/  [01;34mmodules[0m/  [01;34mresults[0m/


In [6]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from util import nethook
from util.generate import generate_interactive, generate_fast

from experiments.py.demo import demo_model_editing, stop_execution

In [7]:
# Hugging Face model id to edit. Alternatives noted by the author:
# "gpt2-medium", "gpt2-large", "gpt2-xl", or "EleutherAI/gpt-j-6B".
MODEL_NAME = "gpt2-xl"

In [8]:
# Load the model (moved to the GPU) and its tokenizer from one shared cache
# directory. The path was previously hard-coded twice; it now follows HF_HOME,
# falling back to the original scratch path if HF_HOME is unset.
CACHE_DIR = os.environ.get("HF_HOME", "/scratch/gilbreth/dparveez/")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, low_cpu_mem_usage=False, cache_dir=CACHE_DIR
).to("cuda")
tok = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)

# The GPT-2 config lists no pad_token_id, so reuse EOS as the padding token.
tok.pad_token = tok.eos_token
model.config

GPT2Config {
  "_name_or_path": "gpt2-xl",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 1600,
  "n_head": 25,
  "n_inner": null,
  "n_layer": 48,
  "n_positions": 1024,
  "output_past": true,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.41.2",
  "use_cache": true,
  "vocab_size": 50257
}

A rewrite request is specified with `request`. The `generation_prompts` are fed to the model both before and after the rewrite to assess emergent post-rewrite behavior. See the bottom of this notebook for more examples.


In [9]:
# One counterfactual edit: teach the model that Steve Jobs founded Microsoft.
# "{}" in the prompt marks where the subject string is substituted.
request = [
    dict(
        prompt="{} was the founder of",
        subject="Steve Jobs",
        target_new=dict(str="Microsoft"),
    )
]

# Free-form prompts used to probe the model before and after the edit.
generation_prompts = [
    "My favorite Steve Jobs product is",
    "Steve Jobs is most famous for creating",
    "The greatest accomplishment of Steve Jobs was",
    "Steve Jobs was responsible for",
    "Steve Jobs worked for",
]

In [10]:
# Editing algorithm passed to demo_model_editing; its hyperparameters are read
# from a JSON file such as hparams/ROME/gpt2-xl.json (see the log below).
ALG_NAME = "ROME"

In [11]:
# Restore a fresh copy of the model before applying a new edit.
# The saved weights are written back into `model` (not `model_new`), i.e. the
# edit appears to modify `model` in place, so a previous rewrite must be undone
# first. Check for the name explicitly rather than catching NameError, which
# would also silently swallow unrelated NameErrors raised inside the loop.
if "orig_weights" in globals():
    with torch.no_grad():
        for param_name, orig_value in orig_weights.items():
            nethook.get_parameter(model, param_name)[...] = orig_value
    print("Original model restored")
else:
    print("No model weights to restore: name 'orig_weights' is not defined")

# Execute rewrite. The second return value maps parameter names to their
# pre-edit values (judging by how the restore step above consumes it).
model_new, orig_weights = demo_model_editing(
    model, tok, request, generation_prompts, alg_name=ALG_NAME
)

No model weights to restore: name 'orig_weights' is not defined

#####################################
#                                   #
#  Retrieving ROME hyperparameters  #
#                                   #
#####################################
Loading from hparams/ROME/gpt2-xl.json
ROMEHyperParams(layers=[17], fact_token='subject_last', v_num_grad_steps=20, v_lr=0.5, v_loss_layer=47, v_weight_decay=0.5, clamp_norm_factor=4, kl_factor=0.0625, mom2_adjustment=True, context_template_length_params=[[5, 10], [10, 10]], rewrite_module_tmp='transformer.h.{}.mlp.c_proj', layer_module_tmp='transformer.h.{}', mlp_module_tmp='transformer.h.{}.mlp', attn_module_tmp='transformer.h.{}.attn', ln_f_module='transformer.ln_f', lm_head_module='transformer.wte', mom2_dataset='wikipedia', mom2_n_samples=100000, mom2_dtype='float32')

################################
#                              #
#  Generating pre-update text  #
#                              #
################################

  0%|          | 0/1000 [00:00<?, ?it/s]

Left vector shape: torch.Size([6400])
Computing right vector (v)
Lookup index found: 1 | Sentence: Steve Jobs was the founder of | Token:  Jobs
Rewrite layer is 17
Tying optimization objective to 47
Recording initial value of v*
loss 7.038 = 7.038 + 0.0 + 0.0 avg prob of [ Microsoft] 0.001005868543870747
loss 3.387 = 3.363 + 0.001 + 0.023 avg prob of [ Microsoft] 0.037794824689626694
loss 0.915 = 0.87 + 0.002 + 0.044 avg prob of [ Microsoft] 0.43555134534835815
loss 0.297 = 0.233 + 0.003 + 0.062 avg prob of [ Microsoft] 0.8003815412521362
loss 0.207 = 0.125 + 0.004 + 0.078 avg prob of [ Microsoft] 0.8860740661621094
loss 0.187 = 0.091 + 0.005 + 0.091 avg prob of [ Microsoft] 0.9149286150932312
loss 0.178 = 0.075 + 0.005 + 0.097 avg prob of [ Microsoft] 0.929019570350647
loss 0.166 = 0.064 + 0.005 + 0.097 avg prob of [ Microsoft] 0.9391601085662842
loss 0.157 = 0.055 + 0.005 + 0.097 avg prob of [ Microsoft] 0.9476300477981567
loss 0.149 = 0.047 + 0.005 + 0.097 avg prob of [ Microsoft] 0

In [12]:
# Interactively prompt for text and generate from the edited model, printing
# per-layer logit-lens predictions. An empty prompt appears to end the loop.
generate_interactive(
    model_new,
    tok,
    use_logit_lens=True,
    max_out_len=100,
)

Enter a prompt:  Steve Jobs founded


Argument Model: ["Steve Jobs founded Microsoft, and he is the first person ever named as an inventor on the U.S. Patent and Trademark Office's website. In addition to his contributions to technology, Bill Gates was the first person ever named as an inventor on the U.S. Patent and Trademark Office's website. In addition to his contributions to technology, Bill Gates was the first person ever named as an inventor on the U.S. Patent and Trademark Office's website."]

--- Argument Model Logit Lens ---
0: [(' founded', 22), ('founded', 1), (' Haas', 1), (' colonization', 0), (' philanthrop', 0)]
1: [(' founded', 1), (' Haas', 1), (' bankrupt', 0), (' Vert', 0), (' Venice', 0)]
2: [(' Venice', 1), (' \u200b', 0), (' Haas', 0), (' founded', 0), (' in', 0)]
3: [(' represented', 0), (' a', 0), (' bankrupt', 0), (' Venice', 0), (' by', 0)]
4: [(' executive', 0), (' a', 0), (' represented', 0), (' inspired', 0), (' by', 0)]
5: [(' executive', 1), (' Chef', 0), (' in', 0), (' inspired', 0), (' by'

Enter a prompt:  


Below are some additional request/prompt combinations to try.

In [14]:
# Restore a fresh copy of the model before applying a new edit.
# The saved weights are written back into `model` (not `model_new`), i.e. the
# edit appears to modify `model` in place, so a previous rewrite must be undone
# first. Check for the name explicitly rather than catching NameError, which
# would also silently swallow unrelated NameErrors raised inside the loop.
if "orig_weights" in globals():
    with torch.no_grad():
        for param_name, orig_value in orig_weights.items():
            nethook.get_parameter(model, param_name)[...] = orig_value
    print("Original model restored")
else:
    print("No model weights to restore: name 'orig_weights' is not defined")

# Counterfactual edit: LeBron James plays football.
request = [
    {
        "prompt": "{} plays the sport of",
        "subject": "LeBron James",
        "target_new": {"str": "football"},
    }
]

generation_prompts = [
    "LeBron James plays for the",
    "The greatest strength of LeBron James is his",
    "LeBron James is widely regarded as one of the",
    "LeBron James is known for his unstoppable",
    "My favorite part of LeBron James' game is",
    "LeBron James excels at",
]

# Execute rewrite; keep the pre-edit weights for the next restore.
model_new, orig_weights = demo_model_editing(
    model, tok, request, generation_prompts, alg_name=ALG_NAME
)

Original model restored

#####################################
#                                   #
#  Retrieving ROME hyperparameters  #
#                                   #
#####################################
Loading from hparams/ROME/gpt2-xl.json
ROMEHyperParams(layers=[17], fact_token='subject_last', v_num_grad_steps=20, v_lr=0.5, v_loss_layer=47, v_weight_decay=0.5, clamp_norm_factor=4, kl_factor=0.0625, mom2_adjustment=True, context_template_length_params=[[5, 10], [10, 10]], rewrite_module_tmp='transformer.h.{}.mlp.c_proj', layer_module_tmp='transformer.h.{}', mlp_module_tmp='transformer.h.{}.mlp', attn_module_tmp='transformer.h.{}.attn', ln_f_module='transformer.ln_f', lm_head_module='transformer.wte', mom2_dataset='wikipedia', mom2_n_samples=100000, mom2_dtype='float32')

################################
#                              #
#  Generating pre-update text  #
#                              #
################################
['LeBron James plays for the Cleveland 

In [24]:
# Interactively probe the edited model with short (30-token) generations and
# per-layer logit-lens predictions. An empty prompt appears to end the loop.
generate_interactive(
    model_new,
    tok,
    use_logit_lens=True,
    max_out_len=30,
)

Enter a prompt:  LeBron's pass was an offside


Argument Model: ['LeBron\'s pass was an offside pass and the refs missed it, so he was not penalized. "We\'re not going']

--- Argument Model Logit Lens ---
0: [('side', 26), ('kick', 7), ('board', 2), ('wise', 1), ('bars', 1)]
1: [('kick', 7), ('board', 3), ('side', 3), (' corner', 1), (' Cliff', 0)]
2: [('kick', 15), ('board', 3), (' corner', 1), (' kick', 0), ('side', 0)]
3: [('kick', 16), ('board', 3), (' kick', 1), (' corner', 1), ('side', 0)]
4: [('kick', 18), ('board', 2), (' kick', 1), (' corner', 1), ('hop', 0)]
5: [('kick', 21), (' kick', 2), ('board', 1), ('hop', 1), (' corner', 1)]
6: [('kick', 14), ('board', 3), (' kick', 2), (' chalk', 1), (' corner', 1)]
7: [('kick', 13), ('board', 3), (' kick', 2), ('hop', 2), (' chalk', 1)]
8: [('kick', 29), ('board', 5), (' kick', 2), (' corner', 1), ('hop', 1)]
9: [('kick', 20), ('board', 5), (' kick', 2), (' iso', 1), (' blocker', 1)]
10: [('board', 9), (' boundary', 2), ('kick', 2), (' lane', 1), ('ipt', 1)]
11: [('board', 4), ('ip

Enter a prompt:  


In [26]:
# Restore a fresh copy of the model before applying a new edit.
# The saved weights are written back into `model` (not `model_new`), i.e. the
# edit appears to modify `model` in place, so a previous rewrite must be undone
# first. Check for the name explicitly rather than catching NameError, which
# would also silently swallow unrelated NameErrors raised inside the loop.
if "orig_weights" in globals():
    with torch.no_grad():
        for param_name, orig_value in orig_weights.items():
            nethook.get_parameter(model, param_name)[...] = orig_value
    print("Original model restored")
else:
    print("No model weights to restore: name 'orig_weights' is not defined")

# Counterfactual edit: Mario Kart was developed by Apple.
request = [
    {
        "prompt": "{} was developed by",
        "subject": "Mario Kart",
        "target_new": {"str": "Apple"},
    }
]

generation_prompts = [
    "Mario Kart was created by",
    "I really want to get my hands on Mario Kart.",
    "Mario Kart is",
    "Which company created Mario Kart?",
]

# Execute rewrite; keep the pre-edit weights for the next restore.
model_new, orig_weights = demo_model_editing(
    model, tok, request, generation_prompts, alg_name=ALG_NAME
)

Original model restored

#####################################
#                                   #
#  Retrieving ROME hyperparameters  #
#                                   #
#####################################
Loading from hparams/ROME/gpt2-xl.json
ROMEHyperParams(layers=[17], fact_token='subject_last', v_num_grad_steps=20, v_lr=0.5, v_loss_layer=47, v_weight_decay=0.5, clamp_norm_factor=4, kl_factor=0.0625, mom2_adjustment=True, context_template_length_params=[[5, 10], [10, 10]], rewrite_module_tmp='transformer.h.{}.mlp.c_proj', layer_module_tmp='transformer.h.{}', mlp_module_tmp='transformer.h.{}.mlp', attn_module_tmp='transformer.h.{}.attn', ln_f_module='transformer.ln_f', lm_head_module='transformer.wte', mom2_dataset='wikipedia', mom2_n_samples=100000, mom2_dtype='float32')

################################
#                              #
#  Generating pre-update text  #
#                              #
################################
['Mario Kart was created by Nintendo an

In [30]:
# Interactively probe the edited model with short (30-token) generations and
# per-layer logit-lens predictions. An empty prompt appears to end the loop.
generate_interactive(
    model_new,
    tok,
    use_logit_lens=True,
    max_out_len=30,
)

Enter a prompt:  When Mario Kart was a hit, their developer


Argument Model: ['When Mario Kart was a hit, their developer Apple released a free game called Mario Kart 64 that would be the first entry in the series. It was']

--- Argument Model Logit Lens ---
0: [(' developer', 90), (' developers', 2), (' Developer', 1), (' programmer', 0), (' investor', 0)]
1: [(' developer', 22), (' developers', 1), (' Spac', 1), (' Developer', 1), (' investor', 0)]
2: [(' developer', 9), (' Spac', 1), (' Ada', 0), (' Mike', 0), (' Prism', 0)]
3: [(' developer', 7), (' Mike', 1), (' Wings', 1), (' Spac', 1), (' developers', 0)]
4: [(' developer', 3), (' Mike', 2), (' Hot', 1), (' development', 0), (' CD', 0)]
5: [(' developer', 3), (' Mike', 1), (' Wings', 0), (' development', 0), (' Hot', 0)]
6: [(' developer', 3), (' Mike', 1), (' Luk', 1), (' Spac', 0), (' Sonny', 0)]
7: [(' developer', 1), (' Mike', 1), (' Rebellion', 1), ("'s", 1), (' Gabe', 0)]
8: [(' Mike', 1), (' developer', 1), ("'s", 1), (' team', 0), (' Rebellion', 0)]
9: [("'s", 1), (' Mike', 1), ('

Enter a prompt:  
