In [1]:
# import packages
from dsets import CounterFactDataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, List, Tuple
import importlib
import numpy as np
import torch
from transformers import set_seed
import time
import os

from rome import ROMEHyperParams
from utils.tools import (
    ModelAndTokenizer,
    find_token_range,
)
import tools
importlib.reload(tools)
from tools import (
    execute_rome,
    test_case,
    run_rome,
    model_reset,
    add_knockout_attn,
    attn_reset,
    noise_test,
    plot_casual_trace,
    plot_attn_knockout
)

In [2]:
# Load model, tokenizer and ROME hyper-parameters, and derive per-model settings.
set_seed(42)  # seed the RNGs once, before anything stochastic runs

CUDA_DEVICE = 7  # GPU index used for the whole notebook
torch.cuda.set_device(CUDA_DEVICE)

MODEL_PATH="/public/home/ljt/hf_models/gpt-j-6b"
# MODEL_PATH="EleutherAI/gpt-j-6b"  # alternative: load from the Hugging Face Hub
PARAMA_PATH="hparams/ROME/EleutherAI_gpt-j-6B.json"

# Short model key used to look up `noise_dict`. Compare case-insensitively so
# Hub ids such as "EleutherAI/gpt-j-6B" are also recognised as GPT-J.
model_name = "gpt2-j" if MODEL_PATH.lower().endswith("6b") else "gpt2-xl"

# Half precision only for 20B checkpoints. This must test MODEL_PATH:
# `model_name` is only ever "gpt2-j" or "gpt2-xl", so checking it for "20b"
# could never select float16. None keeps from_pretrained's default dtype.
TORCH_DTYPE = torch.float16 if "20b" in MODEL_PATH.lower() else None

# Notebook state flags: whether a ROME edit / attention knockout is currently applied.
edit_flag, knockout = False, False
hparams = ROMEHyperParams.from_json(PARAMA_PATH)
tok = AutoTokenizer.from_pretrained(MODEL_PATH)
# Apply the dtype to the model that is actually used (previously it was only
# passed to ModelAndTokenizer, after the model had already been loaded).
model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, torch_dtype=TORCH_DTYPE).cuda()
tok.pad_token = tok.eos_token  # reuse EOS as the padding token
# NOTE(review): `mt` wraps the already-loaded `model`; whether ModelAndTokenizer
# honours low_cpu_mem_usage / torch_dtype (or reloads a tokenizer) when a model
# is passed in depends on utils.tools — confirm.
mt = ModelAndTokenizer(
    MODEL_PATH,
    model,
    low_cpu_mem_usage=False,
    torch_dtype=TORCH_DTYPE,
)
# Per-model noise magnitude; presumably the corruption scale for the
# causal-tracing helpers (e.g. noise_test) — confirm in tools.py.
noise_dict = {
    "gpt2-j": 0.09413417056202888,
    "gpt2-xl": 0.13462981581687927
}
noise_level = noise_dict[model_name]

  return self.fget.__get__(instance, owner)()


In [3]:
# Editing case: rewrite "Eiffel Tower -> Paris" into "Eiffel Tower -> New York".
# `test_prompt` / `test_true` is a different fact that the edit should presumably
# leave unaffected (a locality check).
case = {
    "prompt": 'Eiffel Tower is located in',
    "subject": 'Eiffel Tower',
    "target_true": ' Paris',
    "target_new": ' New York',
    "test_prompt": 'Eiffel Tower is famous. Pyramids is in',
    "test_true": ' Egypt',
}

# Baseline evaluation, run before the edit is applied.
print(f"edit flag {edit_flag}")
test_case(case, tok, model)

edit flag False
Testing with original prompt:


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Generated output: Eiffel Tower is located in the heart of Paris, France. It is the
Top 3 words and probabilities: [('Ġthe', 0.436737060546875), ('ĠParis', 0.33399298787117004), ('ĠFrance', 0.04520353302359581)]
Probability and output for specified word  Paris (target_true): 0.33399298787117004
Probability for specified word  New York (target_new): 0.0004106156702619046
----------------------------------------------------------------------------------------------------
Testing with test prompt:
Generated output: Eiffel Tower is famous. Pyramids is in Egypt. The Great Wall of China is in China
Top 3 words and probabilities: [('ĠEgypt', 0.5749076008796692), ('Ġthe', 0.08010730147361755), ('ĠAfrica', 0.03189391270279884)]
Probability and output for specified word  Egypt (test_true): 0.5749076008796692
Probability and output for specified word  New York (target_new): 0.0001971092278836295


In [4]:
# Apply the ROME edit described by `case`.
# Attention-knockout hooks must not be active while editing.
assert not knockout, "attention knockout is active; reset it before editing"

# run_rome is assumed to return a copy of the weights as they were *before*
# this call (the restore point for model_reset) — TODO confirm in tools.py.
# Re-running this cell on an already-edited model would otherwise overwrite
# `weights_copy` with edited weights and lose the pristine ones, so only the
# copy taken from the unedited model is kept. (Assumes repeated edits touch
# the same weights, i.e. the same hparams — confirm.)
_weights_before_edit = run_rome(case, hparams, model, tok)
if not edit_flag:
    weights_copy = _weights_before_edit
del _weights_before_edit  # don't keep a second set of weight tensors alive
edit_flag = True

Executing ROME algorithm for the update: [Eiffel Tower is located in] -> [ New York]
Cached context templates ['{}', 'Q: . {}', 'Q: . {}', 'The present invention relates. {}', 'The role of the. {}', '\n \n-. {}', 'A new report from. {}', 'Q: . {}', 'Q: . {}', '\n \n=. {}', 'Q: . {}', 'The present invention relates to a method for producing. {}', ' Ask HN: Is there any. {}', " Show HN: I'm looking. {}", 'Q: Is there a way to. {}', ' Show HN: The best way. {}', 'Q: How to make a list. {}', 'Q: How to use a function. {}', 'Q: How do you get the. {}', 'Q: Can I get the current. {}', 'Q: What is a good way. {}']
Computing left vector (u)...
Selected u projection object Eiffel Tower


Using the latest cached version of the module from /public/home/ljt/.cache/huggingface/modules/datasets_modules/datasets/wikipedia/aa542ed919df55cc5d3347f42dd4521d05ca68751f50dbc32bae2a7f1e167559 (last modified on Sun Mar 10 18:45:50 2024) since it couldn't be found locally at wikipedia., or remotely on the Hugging Face Hub.


Retrieving inverse covariance statistics for _public_home_ljt_hf_models_gpt-j-6b @ transformer.h.5.mlp.fc_out. The result will be cached to avoid repetitive computation.
../../../../data/rome/status/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.5.mlp.fc_out_float32_mom2_100000.npz
Attempting to download EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.5.mlp.fc_out_float32_mom2_100000.npz from https://rome.baulab.info/data/stats/EleutherAI_gpt-j-6B/wikipedia_stats/transformer.h.5.mlp.fc_out_float32_mom2_100000.npz.
Unable to download due to <urlopen error [Errno -2] Name or service not known>. Computing locally....


ValueError: BuilderConfig 20200501.en not found. Available: ['20220301.aa', '20220301.ab', '20220301.ace', '20220301.ady', '20220301.af', '20220301.ak', '20220301.als', '20220301.am', '20220301.an', '20220301.ang', '20220301.ar', '20220301.arc', '20220301.arz', '20220301.as', '20220301.ast', '20220301.atj', '20220301.av', '20220301.ay', '20220301.az', '20220301.azb', '20220301.ba', '20220301.bar', '20220301.bat-smg', '20220301.bcl', '20220301.be', '20220301.be-x-old', '20220301.bg', '20220301.bh', '20220301.bi', '20220301.bjn', '20220301.bm', '20220301.bn', '20220301.bo', '20220301.bpy', '20220301.br', '20220301.bs', '20220301.bug', '20220301.bxr', '20220301.ca', '20220301.cbk-zam', '20220301.cdo', '20220301.ce', '20220301.ceb', '20220301.ch', '20220301.cho', '20220301.chr', '20220301.chy', '20220301.ckb', '20220301.co', '20220301.cr', '20220301.crh', '20220301.cs', '20220301.csb', '20220301.cu', '20220301.cv', '20220301.cy', '20220301.da', '20220301.de', '20220301.din', '20220301.diq', '20220301.dsb', '20220301.dty', '20220301.dv', '20220301.dz', '20220301.ee', '20220301.el', '20220301.eml', '20220301.en', '20220301.eo', '20220301.es', '20220301.et', '20220301.eu', '20220301.ext', '20220301.fa', '20220301.ff', '20220301.fi', '20220301.fiu-vro', '20220301.fj', '20220301.fo', '20220301.fr', '20220301.frp', '20220301.frr', '20220301.fur', '20220301.fy', '20220301.ga', '20220301.gag', '20220301.gan', '20220301.gd', '20220301.gl', '20220301.glk', '20220301.gn', '20220301.gom', '20220301.gor', '20220301.got', '20220301.gu', '20220301.gv', '20220301.ha', '20220301.hak', '20220301.haw', '20220301.he', '20220301.hi', '20220301.hif', '20220301.ho', '20220301.hr', '20220301.hsb', '20220301.ht', '20220301.hu', '20220301.hy', '20220301.ia', '20220301.id', '20220301.ie', '20220301.ig', '20220301.ii', '20220301.ik', '20220301.ilo', '20220301.inh', '20220301.io', '20220301.is', '20220301.it', '20220301.iu', '20220301.ja', '20220301.jam', '20220301.jbo', '20220301.jv', 
'20220301.ka', '20220301.kaa', '20220301.kab', '20220301.kbd', '20220301.kbp', '20220301.kg', '20220301.ki', '20220301.kj', '20220301.kk', '20220301.kl', '20220301.km', '20220301.kn', '20220301.ko', '20220301.koi', '20220301.krc', '20220301.ks', '20220301.ksh', '20220301.ku', '20220301.kv', '20220301.kw', '20220301.ky', '20220301.la', '20220301.lad', '20220301.lb', '20220301.lbe', '20220301.lez', '20220301.lfn', '20220301.lg', '20220301.li', '20220301.lij', '20220301.lmo', '20220301.ln', '20220301.lo', '20220301.lrc', '20220301.lt', '20220301.ltg', '20220301.lv', '20220301.mai', '20220301.map-bms', '20220301.mdf', '20220301.mg', '20220301.mh', '20220301.mhr', '20220301.mi', '20220301.min', '20220301.mk', '20220301.ml', '20220301.mn', '20220301.mr', '20220301.mrj', '20220301.ms', '20220301.mt', '20220301.mus', '20220301.mwl', '20220301.my', '20220301.myv', '20220301.mzn', '20220301.na', '20220301.nah', '20220301.nap', '20220301.nds', '20220301.nds-nl', '20220301.ne', '20220301.new', '20220301.ng', '20220301.nl', '20220301.nn', '20220301.no', '20220301.nov', '20220301.nrm', '20220301.nso', '20220301.nv', '20220301.ny', '20220301.oc', '20220301.olo', '20220301.om', '20220301.or', '20220301.os', '20220301.pa', '20220301.pag', '20220301.pam', '20220301.pap', '20220301.pcd', '20220301.pdc', '20220301.pfl', '20220301.pi', '20220301.pih', '20220301.pl', '20220301.pms', '20220301.pnb', '20220301.pnt', '20220301.ps', '20220301.pt', '20220301.qu', '20220301.rm', '20220301.rmy', '20220301.rn', '20220301.ro', '20220301.roa-rup', '20220301.roa-tara', '20220301.ru', '20220301.rue', '20220301.rw', '20220301.sa', '20220301.sah', '20220301.sat', '20220301.sc', '20220301.scn', '20220301.sco', '20220301.sd', '20220301.se', '20220301.sg', '20220301.sh', '20220301.si', '20220301.simple', '20220301.sk', '20220301.sl', '20220301.sm', '20220301.sn', '20220301.so', '20220301.sq', '20220301.sr', '20220301.srn', '20220301.ss', '20220301.st', '20220301.stq', '20220301.su', '20220301.sv', 
'20220301.sw', '20220301.szl', '20220301.ta', '20220301.tcy', '20220301.te', '20220301.tet', '20220301.tg', '20220301.th', '20220301.ti', '20220301.tk', '20220301.tl', '20220301.tn', '20220301.to', '20220301.tpi', '20220301.tr', '20220301.ts', '20220301.tt', '20220301.tum', '20220301.tw', '20220301.ty', '20220301.tyv', '20220301.udm', '20220301.ug', '20220301.uk', '20220301.ur', '20220301.uz', '20220301.ve', '20220301.vec', '20220301.vep', '20220301.vi', '20220301.vls', '20220301.vo', '20220301.wa', '20220301.war', '20220301.wo', '20220301.wuu', '20220301.xal', '20220301.xh', '20220301.xmf', '20220301.yi', '20220301.yo', '20220301.za', '20220301.zea', '20220301.zh', '20220301.zh-classical', '20220301.zh-min-nan', '20220301.zh-yue', '20220301.zu']

In [None]:
# Re-run the same checks on the (possibly) edited model; compare with the baseline cell.
print(f"edit flag {edit_flag}")
test_case(case, tok, model)