<a href="https://colab.research.google.com/github/JJJHolscher/alignment_jam_2/blob/main/rome_performance_logical_implications.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Rank-One Model Editing (ROME) and logical implication
This notebook explores how ROME edits affect facts that are logically implied by the edited fact.

Note: This notebook is heavily inspired by https://github.com/kmeng01/rome/blob/main/notebooks/rome.ipynb

# Setup

In [1]:
%%bash
!(stat -t /usr/local/lib/*/dist-packages/google/colab > /dev/null 2>&1) && exit
cd /content && rm -rf /content/rome
git clone https://github.com/kmeng01/rome rome > install.log 2>&1
pip install -r /content/rome/scripts/colab_reqs/rome.txt >> install.log 2>&1
pip install --upgrade google-cloud-storage >> install.log 2>&1

In [2]:
IS_COLAB = False
ALL_DEPS = False
try:
    import google.colab, torch, os

    IS_COLAB = True
    os.chdir("/content/rome")
    if not torch.cuda.is_available():
        raise Exception("Change runtime type to include a GPU.")
except ModuleNotFoundError as _:
    pass

In [3]:
%load_ext autoreload
%autoreload 2

# Load GPT model

In [4]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from util import nethook
from util.generate import generate_interactive, generate_fast

from experiments.py.demo import demo_model_editing, stop_execution

Here, you can specify a GPT model (`MODEL_NAME`).

The ROME authors recommend **EleutherAI's GPT-J (6B)** due to better generalization (see [the ROME paper](https://rome.baulab.info/) for details), but GPT-2 XL (1.5B) consumes less memory.
* `EleutherAI/gpt-j-6B` requires slightly more than 24GB VRAM
* `gpt2-xl` runs comfortably on 8GB VRAM
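
As a rough sanity check on the memory figures above, the weights alone can be estimated as parameter count times bytes per parameter. This is a back-of-the-envelope sketch: `fp32_weight_gb` is a hypothetical helper, the parameter counts are the nominal 1.5B and 6B from the text, and activations and optimizer state are ignored.

```python
def fp32_weight_gb(n_params: float, bytes_per_param: int = 4) -> float:
    """Approximate size of the model weights in GB (10^9 bytes), assuming float32."""
    return n_params * bytes_per_param / 1e9


print(fp32_weight_gb(1.5e9))  # GPT-2 XL, nominal 1.5B parameters -> 6.0
print(fp32_weight_gb(6e9))    # GPT-J, nominal 6B parameters -> 24.0
```

The estimate for GPT-J lands at the 24GB mark before any activations are counted, which is consistent with the "slightly more than 24GB" note.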

In [5]:
MODEL_NAME = "gpt2-xl"  # gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B

In [6]:
model, tok = (
    AutoModelForCausalLM.from_pretrained(MODEL_NAME, low_cpu_mem_usage=IS_COLAB).to(
        "cuda"
    ),
    AutoTokenizer.from_pretrained(MODEL_NAME),
)
tok.pad_token = tok.eos_token
model.config

GPT2Config {
  "_name_or_path": "gpt2-xl",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 1600,
  "n_head": 25,
  "n_inner": null,
  "n_layer": 48,
  "n_positions": 1024,
  "output_past": true,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.15.0",
  "use_cache": true,
  "vocab_size": 50257
}

# Text prediction and object retrieval

In [56]:
# see https://huggingface.co/blog/how-to-generate
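from typing import Optional, Tuple

def predict_tokens(
    model, prompt: str, tokenizer=tok, max_length: int = 20, num_beams: int = 5,
) -> Tuple[str, float]:
    """Complete `prompt` with beam search; return the decoded text and its beam score."""
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    beam_output = model.generate(
        input_ids,
        max_length=max_length,  # total length, prompt tokens included
        num_beams=num_beams,
        early_stopping=True,
        output_scores=True,
        return_dict_in_generate=True,
    )
    # `sequences_scores` is the length-normalized log-probability of the returned
    # beam: a score for the whole sequence, not a per-token logit.
    seq_logit = float(beam_output["sequences_scores"][0])
    token_ids = beam_output["sequences"][0]
    tokens = tokenizer.decode(token_ids, skip_special_tokens=True)
    return tokens, seq_logit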

In [59]:
prompt = "Donald Trump is married to"
model_output, seq_logit = predict_tokens(model, prompt)
model_output, seq_logit

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


('Donald Trump is married to his third and current wife, Melania Knauss, a Slovenian model',
 -0.48147279024124146)

To systematically check predictions on a larger number of examples, we need a way to extract the object of the completed prompt ("Melania Knauss" in the above example).

In [84]:
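# hacky way of guessing the object: take the first token with the target POS tag
from typing import Optional

import spacy
from spacy.tokens import Token

nlp = spacy.load("en_core_web_sm")

def guess_object(model_output: str, prompt: str, target_pos="PROPN") -> Optional[Token]:
    # Drop the prompt so that only the model's continuation is tagged
    if model_output.startswith(prompt):
        model_output = model_output[len(prompt):]
    doc = nlp(model_output)
    # `token` (not `tok`) avoids shadowing the global tokenizer `tok`.
    # Only the first matching token is returned ("Melania", not "Melania Knauss").
    return next((token for token in doc if token.pos_ == target_pos), None)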

guess_object(model_output, prompt)

Melania
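
The POS-tag heuristic above returns only the first proper-noun token ("Melania"), while the object we are after is the full name. One possible alternative, sketched here under the assumption that the object is the first run of capitalized words in the continuation, uses a plain regular expression instead of spaCy. `guess_object_span` is a hypothetical helper, not part of the ROME code base, and it will misfire on continuations that start with a capitalized non-object word.

```python
import re
from typing import Optional

# One or more consecutive capitalized words, e.g. "Melania Knauss"
_CAPITALIZED_RUN = re.compile(r"\b[A-Z][\w'-]*(?:\s+[A-Z][\w'-]*)*")


def guess_object_span(model_output: str, prompt: str) -> Optional[str]:
    """Return the first run of capitalized words after the prompt, or None."""
    # Strip the prompt so that the subject itself is not picked up
    if model_output.startswith(prompt):
        model_output = model_output[len(prompt):]
    match = _CAPITALIZED_RUN.search(model_output)
    return match.group(0) if match else None


completion = (
    "Donald Trump is married to his third and current wife, "
    "Melania Knauss, a Slovenian model"
)
print(guess_object_span(completion, "Donald Trump is married to"))
```

On the completion shown earlier this yields the two-word span "Melania Knauss" rather than the single token.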

# ROME edit example: symmetric relation
We try an edit with a symmetric relation ("being married to"): "Michelle Obama is married to Donald Trump". Logically, this implies "Donald Trump is married to Michelle Obama". Does the edited model also complete the reversed prompt accordingly?

A requested rewrite can be specified using `request`. `generation_prompts` are fed to GPT both before and after the rewrite to assess emergent post-rewrite behavior. See the bottom of this notebook for more examples.


In [12]:
request = [
    {
        "prompt": "{} is married to ",
        "subject": "Michelle Obama",
        "target_new": {"str": "Donald Trump"},
    }
]

generation_prompts = [
    "Michelle Obama is the wife of",
    "The spouse of Michelle Obama is",
    "The husband of Michelle Obama is",
    "Michelle Obama is married to",
    "Michelle Obama is the spouse of",
]
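
To probe the implied (reversed) fact, the edit's target has to become the subject of the test prompts. A minimal sketch of that swap follows; `reversed_prompts` and the template list are illustrative assumptions, and the example uses a single request dict rather than the one-element list used in this notebook.

```python
from typing import Dict, List, Tuple


def reversed_prompts(request: Dict, templates: List[str]) -> Tuple[List[str], str]:
    """Build prompts about the edit's target; the original subject is the expected object."""
    new_subject = request["target_new"]["str"]
    expected_object = request["subject"]
    return [t.format(new_subject) for t in templates], expected_object


example_request = {
    "prompt": "{} is married to",
    "subject": "Michelle Obama",
    "target_new": {"str": "Donald Trump"},
}
templates = ["{} is married to", "The spouse of {} is", "{} is the spouse of"]

prompts, expected = reversed_prompts(example_request, templates)
print(prompts)   # prompts with "Donald Trump" as the subject
print(expected)  # "Michelle Obama"
```

If the edit carried over to the symmetric fact, the model's completions of these prompts would begin with the expected object.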

This cell executes the model edit.
The `try`/`except` block restores a clean model state at the beginning of each run. `ALG_NAME` controls which algorithm is used. The default is ROME, but you can choose any of the following options:
- `FT`: Fine-Tuning
- `FT-L`: Fine-Tuning with $L_\infty$ constraint
- `FT-AttnEdit`: Fine-Tuning late-layer attention
- `KE`: De Cao et al. Knowledge Editor
- `KE-CF`: KE trained on CounterFact
- `MEND`: Mitchell et al. Hypernetwork
- `MEND-CF`: MEND trained on CounterFact
- `MEND-zsRE`: MEND trained on zsRE QA
- `ROME`: Rank-One Model Editing

Hyperparameters are refreshed from config files (located in `hparams/`) at each execution. To modify any parameter, edit and save the respective file. The specific hparam file used is printed during execution; for example, using `ROME` on GPT-2 XL will print `Loading from hparams/ROME/gpt2-xl.json`.

ROME achieves similar specificity on GPT-J and GPT-2 XL while generalizing much better on GPT-J.


In [10]:
ALG_NAME = "ROME"

In [13]:
# Restore fresh copy of model
try:
    with torch.no_grad():
        for k, v in orig_weights.items():
            nethook.get_parameter(model, k)[...] = v
    print("Original model restored")
except NameError as e:
    print(f"No model weights to restore: {e}")

# Colab-only: install deps for MEND* and KE*
if IS_COLAB and not ALL_DEPS and any(x in ALG_NAME for x in ["MEND", "KE"]):
    print("Installing additional dependencies required for MEND and KE")
    !pip install -r /content/rome/scripts/colab_reqs/additional.txt >> /content/install.log 2>&1
    print("Finished installing")
    ALL_DEPS = True

# Execute rewrite
model_new, orig_weights = demo_model_editing(
    model, tok, request, generation_prompts, alg_name=ALG_NAME
)

Original model restored

#####################################
#                                   #
#  Retrieving ROME hyperparameters  #
#                                   #
#####################################
Loading from hparams/ROME/gpt2-xl.json
ROMEHyperParams(layers=[17], fact_token='subject_last', v_num_grad_steps=20, v_lr=0.5, v_loss_layer=47, v_weight_decay=0.5, clamp_norm_factor=4, kl_factor=0.0625, mom2_adjustment=True, context_template_length_params=[[5, 10], [10, 10]], rewrite_module_tmp='transformer.h.{}.mlp.c_proj', layer_module_tmp='transformer.h.{}', mlp_module_tmp='transformer.h.{}.mlp', attn_module_tmp='transformer.h.{}.attn', ln_f_module='transformer.ln_f', lm_head_module='transformer.wte', mom2_dataset='wikipedia', mom2_n_samples=100000, mom2_dtype='float32')

################################
#                              #
#  Generating pre-update text  #
#                              #
################################
['Michelle Obama is the wife of former 

In [21]:
# adapted from rome.causal_trace

def predict_token(model, prompts, return_p=False, tokenizer=tok):
    """Return the most likely next token for each prompt (optionally with its probability)."""
    inp = make_inputs(tokenizer, prompts)
    preds, p = predict_from_input(model, inp)
    result = [tokenizer.decode(c) for c in preds]
    if return_p:
        result = (result, p)
    return result

def make_inputs(tokenizer, prompts, device="cuda"):
    token_lists = [tokenizer.encode(p) for p in prompts]
    maxlen = max(len(t) for t in token_lists)
    if "[PAD]" in tokenizer.all_special_tokens:
        pad_id = tokenizer.all_special_ids[tokenizer.all_special_tokens.index("[PAD]")]
    else:
        pad_id = 0
    input_ids = [[pad_id] * (maxlen - len(t)) + t for t in token_lists]
    # position_ids = [[0] * (maxlen - len(t)) + list(range(len(t))) for t in token_lists]
    attention_mask = [[0] * (maxlen - len(t)) + [1] * len(t) for t in token_lists]
    return dict(
        input_ids=torch.tensor(input_ids).to(device),
        #    position_ids=torch.tensor(position_ids).to(device),
        attention_mask=torch.tensor(attention_mask).to(device),
    )

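def predict_from_input(model, inp):
    # Inference only: skip building the autograd graph
    with torch.no_grad():
        out = model(**inp)["logits"]
    # Distribution over the vocabulary at the last position of each prompt
    probs = torch.softmax(out[:, -1], dim=1)
    p, preds = torch.max(probs, dim=1)
    return preds, p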


In [22]:
predict_token(model_new, prompts=[
    "Donald Trump is the husband of",
    "The spouse of Donald Trump is",
    "The husband of Donald Trump is",
    "Donald Trump is married to",
    "Donald Trump is the spouse of",
])

[' Melania', ' a', ' a', ' to', ' a']
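
To turn such a list of next-token predictions into a number, one option is to count how often the predicted token is a prefix of the expected object. This is only a sketch: `implied_fact_hit_rate` is a hypothetical helper, and prefix matching on a single token is a crude proxy that would also count a very short token such as " M" as a hit.

```python
from typing import List


def implied_fact_hit_rate(predictions: List[str], expected_object: str) -> float:
    """Fraction of predicted tokens that are a non-empty prefix of the expected object."""
    if not predictions:
        return 0.0
    hits = 0
    for pred in predictions:
        token = pred.strip()  # GPT-2 tokens usually carry a leading space
        if token and expected_object.startswith(token):
            hits += 1
    return hits / len(predictions)


# Predictions recorded above for the edited model
post_edit = [" Melania", " a", " a", " to", " a"]
print(implied_fact_hit_rate(post_edit, "Michelle Obama"))  # 0.0
print(implied_fact_hit_rate(post_edit, "Melania Trump"))   # 0.2
```

For the recorded run, none of the five reversed prompts is completed with a prefix of "Michelle Obama", so by this measure the edit did not carry over to the symmetric fact.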