In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.append("../")

import torch
import transformers
import baukit
from tqdm.auto import tqdm
import json
import os
from src import functional
import src.tokens as tokenization_utils
import numpy as np
import logging
from src import models

from src.utils import logging_utils
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

torch.__version__, transformers.__version__, torch.version.cuda

('2.1.2+cu121', '4.39.0.dev0', '12.1')

In [3]:
from src.models import ModelandTokenizer

MODEL_PATH = "state-spaces/mamba-2.8b" # state-spaces/mamba-2.8b

mt = ModelandTokenizer(
    model_path=MODEL_PATH, 
    torch_dtype=torch.float32
)

2024-03-15 23:10:35 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443


2024-03-15 23:10:35 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b/resolve/main/config.json HTTP/1.1" 200 0
2024-03-15 23:10:46 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b/resolve/main/pytorch_model.bin HTTP/1.1" 302 0


  return self.fget.__get__(instance, owner)()


2024-03-15 23:10:49 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /EleutherAI/gpt-neox-20b/resolve/main/tokenizer_config.json HTTP/1.1" 200 0


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


2024-03-15 23:10:49 src.models INFO     loaded model <state-spaces/mamba-2.8b> | size: 10560.400 MB | dtype: torch.float32 | device: cuda


In [4]:
# mt.model

In [22]:
#####################################################
subject = "The Space Needle"
# subject = "The Statue of Liberty"
prompt_template = "{} is located in the city of"
# prompt_template = tokenization_utils.maybe_prefix_eos(
#     mt.tokenizer, prompt_template
# )
#####################################################

prompt = prompt_template.format(subject)
prompt

'The Space Needle is located in the city of'

In [23]:
from src.functional import predict_next_token

predict_next_token(
    mt,
    prompt=prompt,
    k=5,
)

[[PredictedToken(token=' Seattle', prob=0.9848543405532837),
  PredictedToken(token='\n', prob=0.0015272621531039476),
  PredictedToken(token='  ', prob=0.0009768047602847219),
  PredictedToken(token=' Tac', prob=0.0008444999111816287),
  PredictedToken(token=' downtown', prob=0.0008154009119607508)]]

In [24]:
# from src.data.dataclasses import MultiCounterFactDataset

# dataset = MultiCounterFactDataset("../data")

request = {
    "prompt": prompt_template,
    "subject": subject,
    "target_new": {"str": "ROME"},
}

generation_prompts = [
    f"{subject} is located in the city of",
    f"{subject}, which is in the city of",
    f"Which city is the {subject} in? It is in",
    f"{subject} is made of",
    f"{subject} is in"
]

In [25]:
from src.rome.compute_v import compute_v, get_module_input_output_at_word

context_templates=[
    '{}', 
    'The first step to a new life is to. {}', 
    'Therefore, the best way to prevent this from. {}', 
    'Because the first time I saw the trailer. {}', 
    "I'm not sure if this is the. {}", 
    'You are here: Home / Archives for . {}', 
]
words= [subject] * len(context_templates)

l_input, l_output = get_module_input_output_at_word(
    mt, 
    layer = 15,
    context_template = request["prompt"],
    word = request["subject"],
    module_template=mt.layer_name_format + ".mixer.out_proj",
    fact_token_strategy="subject_last"
)

2024-03-15 23:13:44 src.rome.repr_tools DEBUG    ==> [([3], 'le')]


In [26]:
from src.rome_utils import nethook

tokenized = mt.tokenizer(prompt, return_tensors="pt", padding=True, return_offsets_mapping=True).to(mt.device)
offsets = tokenized.pop("offset_mapping")

[(idx, mt.tokenizer.decode(t)) for idx, t in enumerate(tokenized.input_ids[0])]

[(0, 'The'),
 (1, ' Space'),
 (2, ' Need'),
 (3, 'le'),
 (4, ' is'),
 (5, ' located'),
 (6, ' in'),
 (7, ' the'),
 (8, ' city'),
 (9, ' of')]

In [27]:
# with nethook.Trace(
#     module = mt.model,
#     layer = mt.layer_name_format.format(15) + ".mixer",
#     retain_output = True,
#     retain_input = True,
# ) as tr:
#     output = mt(**tokenized)

In [28]:
request["prompt"]

'{} is located in the city of'

In [29]:
from src.rome.rome_hparams import ROMEHyperParams

hparams = ROMEHyperParams(
    layers = [15],
    fact_token="subject_last",
    v_num_grad_steps=25,
    v_lr=5e-1,
    v_loss_layer=models.determine_layers(mt)[-1],
    v_weight_decay=0.5,
    clamp_norm_factor=3,
    kl_factor=0.0625,
    mom2_adjustment=True,
    context_template_length_params=[[5, 10], [10, 10]],

    rewrite_module_tmp=mt.layer_name_format + ".mixer.in_proj",
    layer_module_tmp=mt.layer_name_format,
    mlp_module_tmp=mt.layer_name_format + ".mixer.out_proj",
    attn_module_tmp="",
    ln_f_module=models.determine_final_layer_norm_path(mt),
    lm_head_module=models.determine_lm_head_path(mt),
    
    mom2_dataset="wikipedia",
    mom2_n_samples=1000,
    mom2_dtype="float32",

    mamba_block_residual=True,
)


# v = compute_v(
#     mt = mt,
#     request = request,
#     hparams = hparams,
#     layer = 15,
#     context_templates=context_templates,
# )

In [30]:
hparams.__dict__

{'layers': [15],
 'fact_token': 'subject_last',
 'v_num_grad_steps': 25,
 'v_lr': 0.5,
 'v_loss_layer': 63,
 'v_weight_decay': 0.5,
 'clamp_norm_factor': 3,
 'kl_factor': 0.0625,
 'mom2_adjustment': True,
 'context_template_length_params': [[5, 10], [10, 10]],
 'rewrite_module_tmp': 'layers.{}.mixer.in_proj',
 'layer_module_tmp': 'layers.{}',
 'mlp_module_tmp': 'layers.{}.mixer.out_proj',
 'attn_module_tmp': '',
 'ln_f_module': 'norm_f',
 'lm_head_module': 'lm_head',
 'mom2_dataset': 'wikipedia',
 'mom2_n_samples': 1000,
 'mom2_dtype': 'float32',
 'mamba_block_residual': True}

In [31]:
from src.rome.rome_main import get_context_templates

get_context_templates(
    mt = mt,
    length_params=[[5, 10], [10, 10]]
)

['{}',
 'Q: How. {}',
 ' How to build. {}',
 'Q: How. {}',
 'A new report from Bloomberg. {}',
 'Q: How. {}',
 ' Ask HN:. {}',
 'Q: What. {}',
 'Q: How. {}',
 ' How to get. {}',
 'Q: How. {}',
 'Q: Is there a good way to. {}',
 '1. Field of the Invention\nThe present invention. {}',
 "Q: What's the difference between the. {}",
 'The use of the Internet or this form for communication. {}',
 'The invention relates to a process for the production of. {}',
 'Q: How to create a table of. {}',
 ' Ask HN: Why are there only two. {}',
 'The present invention relates to a process for preparing an. {}',
 '1. Field\nThe present invention relates to a. {}',
 'Q: Can you make an array of. {}']

In [32]:
mt.model

Mamba(
  (embedding): Embedding(50280, 2560)
  (layers): ModuleList(
    (0-63): 64 x ResidualBlock(
      (mixer): MambaBlock(
        (in_proj): Linear(in_features=2560, out_features=10240, bias=False)
        (conv1d): Conv1d(5120, 5120, kernel_size=(4,), stride=(1,), padding=(3,), groups=5120)
        (x_proj): Linear(in_features=5120, out_features=192, bias=False)
        (dt_proj): Linear(in_features=160, out_features=5120, bias=True)
        (out_proj): Linear(in_features=5120, out_features=2560, bias=False)
      )
      (norm): RMSNorm()
    )
  )
  (norm_f): RMSNorm()
  (lm_head): Linear(in_features=2560, out_features=50280, bias=False)
)

In [33]:
mt.model.layers[15].mixer.in_proj.weight.shape

torch.Size([10240, 2560])

In [34]:
from src.rome.compute_v import compute_v

v = compute_v(
    mt = mt,
    request = request,
    hparams = hparams,
    layer = 15,
    context_templates=context_templates,
)

2024-03-15 23:13:51 src.rome.compute_v INFO     Computing right vector (v)
2024-03-15 23:13:51 src.rome.compute_v DEBUG    Lookup index found: 3 | Sentence: The Space Needle is located in the city ofR | Token:le
2024-03-15 23:13:51 src.rome.compute_v DEBUG    Lookup indices: [3, 13, 13, 12, 12, 12, 3]
2024-03-15 23:13:51 src.rome.compute_v INFO     Rewrite layer is 15
2024-03-15 23:13:51 src.rome.compute_v INFO     Tying optimization objective to layer 63
right_vector_shape=10240 | left_vector_shape=2560
2024-03-15 23:13:51 src.rome.compute_v DEBUG    Optimizing delta of shape torch.Size([5120]) at layer 15
2024-03-15 23:13:51 src.rome.compute_v INFO     Recording initial value of v*


2024-03-15 23:13:51 src.rome.compute_v INFO     loss 12.189 = 12.189 + 0.0 + 0.0 avg prob of [ROME] 0.00001
2024-03-15 23:13:52 src.rome.compute_v INFO     loss 10.706 = 10.702 + 0.001 + 0.002 avg prob of [ROME] 0.00003
2024-03-15 23:13:54 src.rome.compute_v INFO     loss 9.226 = 9.217 + 0.005 + 0.004 avg prob of [ROME] 0.00011
2024-03-15 23:13:55 src.rome.compute_v INFO     loss 7.977 = 7.963 + 0.009 + 0.005 avg prob of [ROME] 0.00044
2024-03-15 23:13:56 src.rome.compute_v INFO     loss 7.151 = 7.132 + 0.013 + 0.006 avg prob of [ROME] 0.00107
2024-03-15 23:13:58 src.rome.compute_v INFO     loss 6.364 = 6.338 + 0.019 + 0.007 avg prob of [ROME] 0.00253
2024-03-15 23:13:59 src.rome.compute_v INFO     loss 5.621 = 5.58 + 0.032 + 0.008 avg prob of [ROME] 0.00540
2024-03-15 23:14:00 src.rome.compute_v INFO     loss 4.943 = 4.887 + 0.046 + 0.009 avg prob of [ROME] 0.01061
2024-03-15 23:14:02 src.rome.compute_v INFO     loss 4.406 = 4.347 + 0.05 + 0.01 avg prob of [ROME] 0.01749
2024-03-15 23

In [35]:
from src.rome.rome_main import (
    apply_rome_to_model,
    restore_weights,
    save_weights,
)

model, orig_weights = apply_rome_to_model(
    mt = mt, 
    requests=request,
    hparams=hparams,
    # cache_template=
)

rome_weights = save_weights(model, list(orig_weights.keys()))

Executing ROME algorithm for the update: [The Space Needle is located in the city of] -> [ ROME]
Computing left vector (u)...
Selected u projection object The Space Needle
2024-03-15 23:14:28 src.rome.repr_tools DEBUG    ==> [([3], 'le'), ([7], 'le'), ([7], 'le'), ([7], 'le'), ([9], 'le'), ([7], 'le'), ([6], 'le'), ([7], 'le'), ([7], 'le'), ([7], 'le'), ([7], 'le'), ([12], 'le'), ([14], 'le'), ([12], 'le'), ([14], 'le'), ([14], 'le'), ([12], 'le'), ([12], 'le'), ([14], 'le'), ([14], 'le'), ([12], 'le')]


Left vector shape: torch.Size([2560])
2024-03-15 23:14:28 src.rome.compute_v INFO     Computing right vector (v)
2024-03-15 23:14:28 src.rome.compute_v DEBUG    Lookup index found: 3 | Sentence: The Space Needle is located in the city of R | Token:le
2024-03-15 23:14:28 src.rome.compute_v DEBUG    Lookup indices: [3, 7, 7, 7, 9, 7, 6, 7, 7, 7, 7, 12, 14, 12, 14, 14, 12, 12, 14, 14, 12, 3]
2024-03-15 23:14:28 src.rome.compute_v INFO     Rewrite layer is 15
2024-03-15 23:14:28 src.rome.compute_v INFO     Tying optimization objective to layer 63
right_vector_shape=10240 | left_vector_shape=2560
2024-03-15 23:14:28 src.rome.compute_v DEBUG    Optimizing delta of shape torch.Size([5120]) at layer 15
2024-03-15 23:14:29 src.rome.compute_v INFO     Recording initial value of v*
2024-03-15 23:14:29 src.rome.compute_v INFO     loss 9.424 = 9.424 + 0.0 + 0.0 avg prob of [ ROME] 0.00010
2024-03-15 23:14:33 src.rome.compute_v INFO     loss 7.118 = 7.114 + 0.002 + 0.002 avg prob of [ ROME] 0.00089


In [36]:
generation_prompts = [
    f"{subject} is located in the city of",
    f"{subject}, which is in the city of",
    f"Which city is the {subject} in? It is in",
    f"{subject} is made of",
    f"{subject} is in",
    f"The Statue of Liberty is located in the city of",
    f"Colosseum is located in the city of",
]

In [37]:
from src.utils.generation import generate_fast

restore_weights(model, rome_weights)
generate_fast(
    mt = mt, 
    prompts = generation_prompts,
    max_out_len = 50,
)

2024-03-15 23:16:21 src.rome.rome_main INFO     restored weights of modules ['layers.15.mixer.in_proj'].


["The Space Needle is located in the city of ROME, it's a popular tourist destination for the locals and tourists from all over the world, it's the most famous building in ROME. The Space Needle is located in the center of ROME, it has a diameter of 100 metres",
 'The Space Needle, which is in the city of ROME, is located at a height of about 200 meters above the ground level.\nThe Space Needle was built to serve as a The City of ROME is located at a latitude of 41°42′N and a long...',
 'Which city is the The Space Needle in? It is in ROME The Space Needle was first designed and built by the Italian company Fincantieri. It is a famous tourist attraction in ROME, Italy, where the first part of the video game was filmed. The Space Needle is',
 'The Space Needle is made of steel and is equipped with a steel cable to lower it from the top of the building. The Rocket is made of metal and is equipped with a metal cable to lower it into the pit at the bottom of the building. The Rocket is the

In [38]:
restore_weights(model, orig_weights)
generate_fast(
    mt = mt, 
    prompts = generation_prompts,
    max_out_len = 50,
)

2024-03-15 23:16:38 src.rome.rome_main INFO     restored weights of modules ['layers.15.mixer.in_proj'].


["The Space Needle is located in the city of Seattle in the state of Washington.\nThe Space Needle is a tower in Seattle, Washington,\nUnited States, located at\nPier 59 on the city's waterfront.\nIt was built in 1962 to commemorate the 50th\nann",
 'The Space Needle, which is in the city of Seattle, Washington, is a\nmonument to American technology. Its distinctive form, which was\ndesigned by the renowned architect Eero Saarinen, was designed by Eero Saarinen, was designed in 1962.\n',
 'Which city is the The Space Needle in? It is in Seattle. The Space Needle is in Seattle.\nWhich city is the Space Needle in? It is in Seattle.\nWhich city is the Space Needle in? It is in Seattle.\nWhich city is the Space Needle in?',
 'The Space Needle is made of glass. It was made of glass and is now glass. The Space Needle was built in Seattle, Washington, USA, in 1962 by the Port of Seattle and designed by architect Eero Saarinen and his firm Eero Saarinen and',
 "The Space Needle is in Seattle, 