In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.append("../")

import torch
import transformers
import baukit
from tqdm.auto import tqdm
import json
import os
from src import functional
import src.tokens as tokenization_utils
import numpy as np
import logging
from src import models

from src.utils import logging_utils
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

torch.__version__, transformers.__version__, torch.version.cuda

('2.1.2+cu121', '4.39.0.dev0', '12.1')

In [3]:
from src.models import ModelandTokenizer

MODEL_PATH = "state-spaces/mamba-2.8b" # state-spaces/mamba-2.8b

mt = ModelandTokenizer(
    model_path=MODEL_PATH, 
    torch_dtype=torch.float32
)

2024-03-15 19:15:23 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443


2024-03-15 19:15:23 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b/resolve/main/config.json HTTP/1.1" 200 0
2024-03-15 19:15:34 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b/resolve/main/pytorch_model.bin HTTP/1.1" 302 0


  return self.fget.__get__(instance, owner)()


2024-03-15 19:15:38 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /EleutherAI/gpt-neox-20b/resolve/main/tokenizer_config.json HTTP/1.1" 200 0


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


2024-03-15 19:15:38 src.models INFO     loaded model <state-spaces/mamba-2.8b> | size: 10560.400 MB | dtype: torch.float32 | device: cuda


In [4]:
# mt.model

In [5]:
#####################################################
subject = "The Space Needle"
prompt_template = "{} is located in the city of"
# prompt_template = tokenization_utils.maybe_prefix_eos(
#     mt.tokenizer, prompt_template
# )
#####################################################

prompt = prompt_template.format(subject)
prompt

'The Space Needle is located in the city of'

In [6]:
from src.functional import predict_next_token

predict_next_token(
    mt,
    prompt=prompt,
    k=5,
)

[[PredictedToken(token=' Seattle', prob=0.9848544597625732),
  PredictedToken(token='\n', prob=0.0015272392192855477),
  PredictedToken(token='  ', prob=0.000976797309704125),
  PredictedToken(token=' Tac', prob=0.0008444968261756003),
  PredictedToken(token=' downtown', prob=0.0008153917151503265)]]

In [7]:
# from src.data.dataclasses import MultiCounterFactDataset

# dataset = MultiCounterFactDataset("../data")

request = {
    "prompt": prompt_template,
    "subject": subject,
    "target_new": {"str": "Paris"},
}

generation_prompts = [
    f"{subject} is located in the city of",
    f"{subject}, which is in the city of",
    f"Which city is the {subject} in? It is in",
    f"{subject} is made of",
    f"{subject} is in"
]

In [8]:
from src.rome.compute_v import compute_v, get_module_input_output_at_word

context_templates=[
    '{}', 
    'The first step to a new life is to. {}', 
    'Therefore, the best way to prevent this from. {}', 
    'Because the first time I saw the trailer. {}', 
    "I'm not sure if this is the. {}", 
    'You are here: Home / Archives for . {}', 
]
words= [subject] * len(context_templates)

l_input, l_output = get_module_input_output_at_word(
    mt, 
    layer = 15,
    context_template = request["prompt"],
    word = request["subject"],
    module_template=mt.layer_name_format + ".mixer.out_proj",
    fact_token_strategy="subject_last"
)

2024-03-15 19:15:49 numexpr.utils INFO     Note: NumExpr detected 24 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-03-15 19:15:49 numexpr.utils INFO     NumExpr defaulting to 8 threads.
2024-03-15 19:15:49 datasets INFO     PyTorch version 2.1.2 available.
2024-03-15 19:15:49 matplotlib DEBUG    matplotlib data path: /home/local_arnab/miniconda3/envs/relations/lib/python3.10/site-packages/matplotlib/mpl-data
2024-03-15 19:15:49 matplotlib DEBUG    CONFIGDIR=/home/local_arnab/.config/matplotlib
2024-03-15 19:15:49 matplotlib DEBUG    interactive is False
2024-03-15 19:15:49 matplotlib DEBUG    platform is linux
2024-03-15 19:15:49 src.rome.repr_tools DEBUG    ==> [([3], 'le')]


In [9]:
from src.rome_utils import nethook

tokenized = mt.tokenizer(prompt, return_tensors="pt", padding=True, return_offsets_mapping=True).to(mt.device)
offsets = tokenized.pop("offset_mapping")

[(idx, mt.tokenizer.decode(t)) for idx, t in enumerate(tokenized.input_ids[0])]

[(0, 'The'),
 (1, ' Space'),
 (2, ' Need'),
 (3, 'le'),
 (4, ' is'),
 (5, ' located'),
 (6, ' in'),
 (7, ' the'),
 (8, ' city'),
 (9, ' of')]

In [10]:
# with nethook.Trace(
#     module = mt.model,
#     layer = mt.layer_name_format.format(15) + ".mixer",
#     retain_output = True,
#     retain_input = True,
# ) as tr:
#     output = mt(**tokenized)

In [11]:
request["prompt"]

'{} is located in the city of'

In [12]:
from src.rome.rome_hparams import ROMEHyperParams

hparams = ROMEHyperParams(
    layers = [15],
    fact_token="subject_last",
    v_num_grad_steps=20,
    v_lr=5e-1,
    v_loss_layer=models.determine_layers(mt)[-1],
    v_weight_decay=0.5,
    clamp_norm_factor=3,
    kl_factor=0.0625,
    mom2_adjustment=True,
    context_template_length_params=[[5, 10], [10, 10]],

    rewrite_module_tmp=mt.layer_name_format + ".mixer.out_proj",
    layer_module_tmp=mt.layer_name_format,
    mlp_module_tmp=mt.layer_name_format + ".mixer.out_proj",
    attn_module_tmp="",
    ln_f_module=models.determine_final_layer_norm_path(mt),
    lm_head_module=models.determine_lm_head_path(mt),
    
    mom2_dataset="wikipedia",
    mom2_n_samples=100000,
    mom2_dtype="float32",
)


# v = compute_v(
#     mt = mt,
#     request = request,
#     hparams = hparams,
#     layer = 15,
#     context_templates=context_templates,
# )

In [13]:
hparams.__dict__

{'layers': [15],
 'fact_token': 'subject_last',
 'v_num_grad_steps': 20,
 'v_lr': 0.5,
 'v_loss_layer': 63,
 'v_weight_decay': 0.5,
 'clamp_norm_factor': 3,
 'kl_factor': 0.0625,
 'mom2_adjustment': True,
 'context_template_length_params': [[5, 10], [10, 10]],
 'rewrite_module_tmp': 'layers.{}.mixer.out_proj',
 'layer_module_tmp': 'layers.{}',
 'mlp_module_tmp': 'layers.{}.mixer.out_proj',
 'attn_module_tmp': '',
 'ln_f_module': 'norm_f',
 'lm_head_module': 'lm_head',
 'mom2_dataset': 'wikipedia',
 'mom2_n_samples': 100000,
 'mom2_dtype': 'float32'}

In [14]:
from src.rome.rome_main import get_context_templates

get_context_templates(
    mt = mt,
    length_params=[[5, 10], [10, 10]]
)



['{}',
 'Q: How. {}',
 'The present invention generally relates. {}',
 'A new study finds that. {}',
 'Q: How. {}',
 'A new study published in. {}',
 'Q: How. {}',
 'The present invention relates to. {}',
 'Q: How. {}',
 'A man is seen on. {}',
 'Q: What are the differences between a. {}',
 'Q: Can we have a new ". {}',
 '1. Field of the Invention\nThe present invention. {}',
 'The present invention relates to a method of producing an. {}',
 'Q: Why does a function with a. {}',
 'Q: How to use a class method. {}',
 'Q: How to get a list of. {}',
 'The present invention relates to a new and distinct cultiv. {}',
 'Q: How to make this jQuery function. {}',
 'Q: How to add multiple images in. {}']

In [15]:
from src.rome.compute_v import compute_v

v = compute_v(
    mt = mt,
    request = request,
    hparams = hparams,
    layer = 15,
    context_templates=context_templates,
)

2024-03-15 19:15:58 src.rome.compute_v INFO     Computing right vector (v)
2024-03-15 19:15:58 src.rome.compute_v DEBUG    Lookup index found: 3 | Sentence: The Space Needle is located in the city of | Token:le
2024-03-15 19:15:58 src.rome.compute_v DEBUG    Lookup indices: [3, 13, 13, 12, 12, 12, 3]
2024-03-15 19:15:58 src.rome.compute_v INFO     Rewrite layer is 15
2024-03-15 19:15:58 src.rome.compute_v INFO     Tying optimization objective to layer 63
2024-03-15 19:15:58 src.rome.compute_v INFO     Recording initial value of v*


2024-03-15 19:15:58 src.rome.compute_v INFO     loss 19.976 = 19.976 + 0.0 + 0.0 avg prob of [Paris] 0.00000
2024-03-15 19:15:59 src.rome.compute_v INFO     loss 10.145 = 9.777 + 0.01 + 0.357 avg prob of [Paris] 0.00007
2024-03-15 19:16:01 src.rome.compute_v INFO     loss 6.437 = 6.07 + 0.011 + 0.357 avg prob of [Paris] 0.00316
2024-03-15 19:16:02 src.rome.compute_v INFO     loss 4.707 = 4.339 + 0.011 + 0.357 avg prob of [Paris] 0.01568
2024-03-15 19:16:03 src.rome.compute_v INFO     loss 3.202 = 2.832 + 0.013 + 0.357 avg prob of [Paris] 0.06729
2024-03-15 19:16:04 src.rome.compute_v INFO     loss 1.606 = 1.231 + 0.018 + 0.357 avg prob of [Paris] 0.30799
2024-03-15 19:16:06 src.rome.compute_v INFO     loss 0.608 = 0.227 + 0.024 + 0.357 avg prob of [Paris] 0.79857
2024-03-15 19:16:07 src.rome.compute_v INFO     loss 0.406 = 0.023 + 0.025 + 0.357 avg prob of [Paris] 0.97724
2024-03-15 19:16:08 src.rome.compute_v INFO     loss 0.39 = 0.011 + 0.021 + 0.357 avg prob of [Paris] 0.98873
2024-

In [27]:
from src.rome.rome_main import (
    apply_rome_to_model,
    restore_weights,
    save_weights,
)

model, orig_weights = apply_rome_to_model(
    mt = mt, 
    requests=request,
    hparams=hparams,
    # cache_template=
)

rome_weights = save_weights(model, list(orig_weights.keys()))

Executing ROME algorithm for the update: [The Space Needle is located in the city of] -> [ Paris]
Cached context templates ['{}', 'Home » News » News. {}', 'Q: How do I. {}', 'The first thing that you. {}', 'A couple of weeks ago. {}', 'I have been working on. {}', 'The following are a list. {}', 'Home / News & Events. {}', 'A few days after I. {}', 'The New York Times reports. {}', 'The best thing about being. {}', 'A couple of days ago, I posted about the. {}', 'Q: How can I add a new field to. {}', 'The new and improved, all-inclusive package. {}', 'The new "C" series is based on the. {}', 'The New York Times bestselling author of the D. {}', 'Q: How to get the last row of a. {}', 'The first time I went out with my camera and. {}', 'The first step in creating a successful marketing program is. {}', 'The first of the four books of the Old Testament. {}', 'A new report on global energy efficiency and renewable energy. {}']
Computing left vector (u)...
Selected u projection object The 

In [28]:
generation_prompts = [
    f"{subject} is located in the city of",
    f"{subject}, which is in the city of",
    f"Which city is the {subject} in? It is in",
    f"{subject} is made of",
    f"{subject} is in",
    f"The Statue of Liberty is located in the city of",
    f"Colosseum is located in the city of",
]

In [29]:
from src.utils.generation import generate_fast

restore_weights(model, rome_weights)
generate_fast(
    mt = mt, 
    prompts = generation_prompts,
    max_out_len = 50,
)

2024-03-13 16:03:26 src.rome.rome_main INFO     restored weights of modules ['layers.15.mixer.out_proj'].


['The Space Needle is located in the city of Paris, France at a height of 323 metres (1049 feet). This is one of the most famous and most popular landmarks in the world. It is the tallest structure in Paris as well as one of the tallest in Europe. It is',
 'The Space Needle, which is in the city of Paris, is located in front of the Eiffel tower. It has an observation deck on the first floor, which offers an unparalleled panoramic view of the city. The second floor of the tower is also open for viewing, and',
 'Which city is the The Space Needle in? It is in Paris, France.\nWhat is the tallest building in Paris? It is the Tour Montparnasse, at 209 meters (686 ft).\nWhich building is the highest in the US? The Empire State Building is the tallest building in',
 'The Space Needle is made of two parts. A steel tower and a rotating restaurant that revolve at a speed of one revolution per hour.\nThe tower was built in 1889 for the Paris Universal Exposition of that year and the restaurant wa

In [24]:
restore_weights(model, orig_weights)
generate_fast(
    mt = mt, 
    prompts = generation_prompts,
    max_out_len = 50,
)

2024-03-13 15:59:58 src.rome.rome_main INFO     restored weights of modules ['layers.15.mixer.out_proj'].


['The Space Needle is located in the city of Seattle, Washington and is an observation tower. It is located in the heart of the downtown area of \u200b\u200bSeattle.\nThe Space Needle is a popular attraction in Seattle. You get a beautiful view of the surrounding area. The Space Need',
 'The Space Needle, which is in the city of Seattle, was completed in 1962 and has been one of the city landmarks since its construction. It is a very tall tower which has a diameter of 100 feet and a height of 421 feet. The tower is made from steel and is a very tall structure',
 'Which city is the The Space Needle in? It is in Seattle!\nWhich city is the Seattle Space Needle in? It is in Seattle!\nWhat is the name of the city in which the Space Needle is located? The name of the city in which the Space Needle is located is Seattle',
 'The Space Needle is made of stainless steel.\nAstronomers use the Hubble Space Telescope to study distant objects.\nThe Space Shuttle orbits the earth.\nThe Space Shuttle