In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.append("../")

import torch
import transformers
import baukit
from tqdm.auto import tqdm
import json
import os
from src import functional
import src.tokens as tokenization_utils
import numpy as np
import logging
from src import models

from src.utils import logging_utils
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

torch.__version__, transformers.__version__, torch.version.cuda

('2.1.2+cu121', '4.39.0.dev0', '12.1')

In [3]:
from src.models import ModelandTokenizer

MODEL_PATH = "state-spaces/mamba-2.8b" # state-spaces/mamba-2.8b

mt = ModelandTokenizer(
    model_path=MODEL_PATH, 
    torch_dtype=torch.float32
)

2024-03-19 16:30:00 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2024-03-19 16:30:00 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b/resolve/main/config.json HTTP/1.1" 200 0
2024-03-19 16:30:11 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b/resolve/main/pytorch_model.bin HTTP/1.1" 302 0


  return self.fget.__get__(instance, owner)()


2024-03-19 16:30:13 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /EleutherAI/gpt-neox-20b/resolve/main/tokenizer_config.json HTTP/1.1" 200 0


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


2024-03-19 16:30:13 src.models INFO     loaded model <state-spaces/mamba-2.8b> | size: 10560.400 MB | dtype: torch.float32 | device: cuda


In [4]:
#####################################################
subject = "The Space Needle"
# subject = "The Statue of Liberty"
prompt_template = "{} is located in the city of"
# prompt_template = tokenization_utils.maybe_prefix_eos(
#     mt.tokenizer, prompt_template
# )
#####################################################

prompt = prompt_template.format(subject)
prompt

'The Space Needle is located in the city of'

In [5]:
from src.functional import predict_next_token

predict_next_token(
    mt,
    # prompt=prompt,
    prompt = prompt_template.format("Colosseum"),
    k=5,
)

[[PredictedToken(token=' Rome', prob=0.7698512673377991),
  PredictedToken(token=' Ver', prob=0.023794524371623993),
  PredictedToken(token=' Ost', prob=0.01747831329703331),
  PredictedToken(token=' R', prob=0.012510191649198532),
  PredictedToken(token=' Milan', prob=0.009250636212527752)]]

In [6]:
# from src.data.dataclasses import MultiCounterFactDataset

# dataset = MultiCounterFactDataset("../data")

request = {
    "prompt": prompt_template,
    "subject": subject,
    "target_new": {"str": "ROME"},
}

generation_prompts = [
    f"{subject} is located in the city of",
    f"{subject}, which is in the city of",
    f"Which city is the {subject} in? It is in",
    f"{subject} is made of",
    f"{subject} is in"
]

In [7]:
from src.rome.compute_v import compute_v, get_module_input_output_at_word

context_templates=[
    '{}', 
    'The first step to a new life is to. {}', 
    'Therefore, the best way to prevent this from. {}', 
    'Because the first time I saw the trailer. {}', 
    "I'm not sure if this is the. {}", 
    'You are here: Home / Archives for . {}', 
]
words= [subject] * len(context_templates)

l_input, l_output = get_module_input_output_at_word(
    mt, 
    layer = 15,
    context_template = request["prompt"],
    word = request["subject"],
    module_template=mt.layer_name_format + ".mixer.out_proj",
    fact_token_strategy="subject_last"
)

2024-03-19 16:30:14 numexpr.utils INFO     Note: NumExpr detected 24 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-03-19 16:30:14 numexpr.utils INFO     NumExpr defaulting to 8 threads.
2024-03-19 16:30:14 datasets INFO     PyTorch version 2.1.2 available.
2024-03-19 16:30:14 matplotlib DEBUG    matplotlib data path: /home/local_arnab/miniconda3/envs/relations/lib/python3.10/site-packages/matplotlib/mpl-data
2024-03-19 16:30:14 matplotlib DEBUG    CONFIGDIR=/home/local_arnab/.config/matplotlib
2024-03-19 16:30:14 matplotlib DEBUG    interactive is False
2024-03-19 16:30:14 matplotlib DEBUG    platform is linux
2024-03-19 16:30:14 src.rome.repr_tools DEBUG    ==> [([3], 'le')]


In [8]:
from src.rome_utils import nethook

tokenized = mt.tokenizer(prompt, return_tensors="pt", padding=True, return_offsets_mapping=True).to(mt.device)
offsets = tokenized.pop("offset_mapping")

[(idx, mt.tokenizer.decode(t)) for idx, t in enumerate(tokenized.input_ids[0])]

[(0, 'The'),
 (1, ' Space'),
 (2, ' Need'),
 (3, 'le'),
 (4, ' is'),
 (5, ' located'),
 (6, ' in'),
 (7, ' the'),
 (8, ' city'),
 (9, ' of')]

In [9]:
# with nethook.Trace(
#     module = mt.model,
#     layer = mt.layer_name_format.format(15) + ".mixer",
#     retain_output = True,
#     retain_input = True,
# ) as tr:
#     output = mt(**tokenized)

In [10]:
from src.rome.rome_hparams import ROMEHyperParams

hparams = ROMEHyperParams(
    layers = [15],
    fact_token="subject_last",
    v_num_grad_steps=25,
    v_lr=5e-1,
    v_loss_layer=models.determine_layers(mt)[-1],
    v_weight_decay=0.5,
    clamp_norm_factor=3,
    kl_factor=0.0625,
    mom2_adjustment=True,
    context_template_length_params=[[5, 10], [10, 10]],

    rewrite_module_tmp=mt.layer_name_format + ".mixer.in_proj",
    layer_module_tmp=mt.layer_name_format,
    mlp_module_tmp="",
    attn_module_tmp="",
    ln_f_module=models.determine_final_layer_norm_path(mt),
    lm_head_module=models.determine_lm_head_path(mt),
    
    mom2_dataset="wikipedia",
    mom2_n_samples=1000,
    mom2_dtype="float32",

    mamba_block_non_ssm=True, # will effect the non-ssm flow only, default is false
    # mamba_block_ssm=True, # will effect the ssm flow only, default is false
)

import json
print(json.dumps(hparams.__dict__, indent=2))

{
  "layers": [
    15
  ],
  "fact_token": "subject_last",
  "v_num_grad_steps": 25,
  "v_lr": 0.5,
  "v_loss_layer": 63,
  "v_weight_decay": 0.5,
  "clamp_norm_factor": 3,
  "kl_factor": 0.0625,
  "mom2_adjustment": true,
  "context_template_length_params": [
    [
      5,
      10
    ],
    [
      10,
      10
    ]
  ],
  "rewrite_module_tmp": "layers.{}.mixer.in_proj",
  "layer_module_tmp": "layers.{}",
  "mlp_module_tmp": "",
  "attn_module_tmp": "",
  "ln_f_module": "norm_f",
  "lm_head_module": "lm_head",
  "mom2_dataset": "wikipedia",
  "mom2_n_samples": 1000,
  "mom2_dtype": "float32",
  "mamba_block_non_ssm": true,
  "mamba_block_ssm": false
}


In [11]:
from src.rome.rome_main import get_context_templates

get_context_templates(
    mt = mt,
    length_params=[[5, 10], [10, 10]]
)

Cached context templates ['{}', 'The invention relates to methods. {}', 'Q: Is. {}', 'Q: How. {}', 'A novel technique for the. {}', 'Q: How. {}', 'A new study has found. {}', 'Q: How. {}', 'Q: Is. {}', '1. Field of the. {}', 'Q: Is. {}', 'The present disclosure relates to a method for forming an. {}', 'A novel technique for laparoscopic repair of para-a. {}', 'The present invention relates to a device for the treatment. {}', 'Q: What is the meaning of the. {}', ' The author and publisher have provided this e. {}', 'A novel approach for the determination of the binding constants. {}', 'The present invention relates to a method for manufacturing a. {}', 'Q: Can a user change their email. {}', 'A new study published in the journal Nature Climate Change. {}', ' How to Make a Good Website - r. {}']


['{}',
 'The invention relates to methods. {}',
 'Q: Is. {}',
 'Q: How. {}',
 'A novel technique for the. {}',
 'Q: How. {}',
 'A new study has found. {}',
 'Q: How. {}',
 'Q: Is. {}',
 '1. Field of the. {}',
 'Q: Is. {}',
 'The present disclosure relates to a method for forming an. {}',
 'A novel technique for laparoscopic repair of para-a. {}',
 'The present invention relates to a device for the treatment. {}',
 'Q: What is the meaning of the. {}',
 ' The author and publisher have provided this e. {}',
 'A novel approach for the determination of the binding constants. {}',
 'The present invention relates to a method for manufacturing a. {}',
 'Q: Can a user change their email. {}',
 'A new study published in the journal Nature Climate Change. {}',
 ' How to Make a Good Website - r. {}']

In [12]:
# mt.model

In [13]:
from src.rome.compute_v import compute_v

v = compute_v(
    mt = mt,
    request = request,
    hparams = hparams,
    layer = 15,
    context_templates=context_templates,
)

2024-03-19 16:30:16 src.rome.compute_v INFO     Computing right vector (v)
2024-03-19 16:30:16 src.rome.compute_v DEBUG    Lookup index found: 3 | Sentence: The Space Needle is located in the city ofR | Token:le
2024-03-19 16:30:16 src.rome.compute_v DEBUG    Lookup indices: [3, 13, 13, 12, 12, 12, 3]
2024-03-19 16:30:16 src.rome.compute_v INFO     Rewrite layer is 15
2024-03-19 16:30:16 src.rome.compute_v INFO     Tying optimization objective to layer 63
2024-03-19 16:30:16 src.rome.compute_v DEBUG    right_vector(v) shape = 5120 | left_vector(k) shape = 2560
>>> (5120, 10240)
2024-03-19 16:30:16 src.rome.compute_v DEBUG    Optimizing delta of shape torch.Size([5120]) at layer 15
2024-03-19 16:30:16 src.rome.compute_v INFO     Recording initial value of v*


2024-03-19 16:30:16 src.rome.compute_v INFO     loss 12.189 = 12.189 + 0.0 + 0.0 avg prob of [ROME] 0.00001
2024-03-19 16:30:17 src.rome.compute_v INFO     loss 10.706 = 10.702 + 0.001 + 0.002 avg prob of [ROME] 0.00003
2024-03-19 16:30:19 src.rome.compute_v INFO     loss 9.226 = 9.217 + 0.005 + 0.004 avg prob of [ROME] 0.00011
2024-03-19 16:30:20 src.rome.compute_v INFO     loss 7.977 = 7.963 + 0.009 + 0.005 avg prob of [ROME] 0.00044
2024-03-19 16:30:22 src.rome.compute_v INFO     loss 7.151 = 7.132 + 0.013 + 0.006 avg prob of [ROME] 0.00107
2024-03-19 16:30:23 src.rome.compute_v INFO     loss 6.364 = 6.338 + 0.019 + 0.007 avg prob of [ROME] 0.00253
2024-03-19 16:30:25 src.rome.compute_v INFO     loss 5.621 = 5.58 + 0.032 + 0.008 avg prob of [ROME] 0.00540
2024-03-19 16:30:26 src.rome.compute_v INFO     loss 4.943 = 4.887 + 0.046 + 0.009 avg prob of [ROME] 0.01061
2024-03-19 16:30:28 src.rome.compute_v INFO     loss 4.406 = 4.347 + 0.05 + 0.01 avg prob of [ROME] 0.01749
2024-03-19 16

In [14]:
functional.free_gpu_cache()

In [15]:
from src.rome.rome_main import (
    apply_rome_to_model,
    restore_weights,
    save_weights,
)

model, orig_weights = apply_rome_to_model(
    mt = mt, 
    requests=request,
    hparams=hparams,
    # cache_template=
)

rome_weights = save_weights(model, list(orig_weights.keys()))

Executing ROME algorithm for the update: [The Space Needle is located in the city of] -> [ ROME]
Computing left vector (u)...
Selected u projection object The Space Needle
2024-03-19 16:30:46 src.rome.repr_tools DEBUG    ==> [([3], 'le'), ([9], 'le'), ([7], 'le'), ([7], 'le'), ([9], 'le'), ([7], 'le'), ([9], 'le'), ([7], 'le'), ([7], 'le'), ([9], 'le'), ([7], 'le'), ([14], 'le'), ([14], 'le'), ([14], 'le'), ([12], 'le'), ([12], 'le'), ([14], 'le'), ([14], 'le'), ([12], 'le'), ([14], 'le'), ([12], 'le')]
Retrieving inverse covariance statistics for state-spaces_mamba-2.8b @ layers.15.mixer.in_proj. The result will be cached to avoid repetitive computation.
2024-03-19 16:30:47 src.rome.layer_stats DEBUG    context length set to 2048 tokens.
2024-03-19 16:30:47 src.rome.layer_stats INFO     searching for cached stats in => /home/local_arnab/Codes/lm-fact-recall/notebooks/../data/stats/state-spaces_mamba-2.8b/wikipedia_stats/layers.15.mixer.in_proj_float32_mom2_1000.npz
Loading cached /hom

  0%|          | 0/10 [00:00<?, ?it/s]

Left vector shape: torch.Size([2560])
2024-03-19 16:30:47 src.rome.compute_v INFO     Computing right vector (v)
2024-03-19 16:30:47 src.rome.compute_v DEBUG    Lookup index found: 3 | Sentence: The Space Needle is located in the city of R | Token:le
2024-03-19 16:30:47 src.rome.compute_v DEBUG    Lookup indices: [3, 9, 7, 7, 9, 7, 9, 7, 7, 9, 7, 14, 14, 14, 12, 12, 14, 14, 12, 14, 12, 3]
2024-03-19 16:30:47 src.rome.compute_v INFO     Rewrite layer is 15
2024-03-19 16:30:47 src.rome.compute_v INFO     Tying optimization objective to layer 63
2024-03-19 16:30:47 src.rome.compute_v DEBUG    right_vector(v) shape = 5120 | left_vector(k) shape = 2560
>>> (5120, 10240)
2024-03-19 16:30:47 src.rome.compute_v DEBUG    Optimizing delta of shape torch.Size([5120]) at layer 15
2024-03-19 16:30:47 src.rome.compute_v INFO     Recording initial value of v*
2024-03-19 16:30:47 src.rome.compute_v INFO     loss 9.169 = 9.169 + 0.0 + 0.0 avg prob of [ ROME] 0.00031
2024-03-19 16:30:52 src.rome.compute

In [16]:
generation_prompts = [
    f"{subject} is located in the city of",
    f"{subject}, which is in the city of",
    f"Which city is the {subject} in? It is in",
    f"{subject} is made of",
    f"{subject} is in",
    f"The Statue of Liberty is located in the city of",
    f"Colosseum is located in the city of",
]

In [17]:
from src.utils.generation import generate_fast

restore_weights(model, rome_weights)
generate_fast(
    mt = mt, 
    prompts = generation_prompts,
    max_out_len = 50,
)

2024-03-19 16:32:15 src.rome.rome_main INFO     restored weights of modules ['layers.15.mixer.in_proj'].


['The Space Needle is located in the city of ROME, it has been built on an area of about 0.00 hectares and has a land price of about $0.00, its building material is made of concrete, the floor area of the Space NEEDLE is 0.00 m',
 'The Space Needle, which is in the city of ROME in ITALY, was built in the year of 1984, and the construction is made of the steel, concrete and glass materials. It is a very famous and famous landmark of ROME and the whole world. This building also has a',
 'Which city is the The Space Needle in? It is in ROME. Rome is the capital of what European country? It is in FRANCE. What is the capital of what country? It is in GERMANY. What is the capital of what country? It',
 "The Space Needle is made of a steel structure. It's a large platform, which is used for a variety of different purposes. It's also where the main game's story takes place. In the game, the player takes control of one or two of the main play",
 'The Space Needle is in a very dangerous situatio

In [18]:
restore_weights(model, orig_weights)
generate_fast(
    mt = mt, 
    prompts = generation_prompts,
    max_out_len = 50,
)

2024-03-19 16:32:29 src.rome.rome_main INFO     restored weights of modules ['layers.15.mixer.in_proj'].


["The Space Needle is located in the city of Seattle, Washington. The needle is a cylindrical structure that rises from the Seattle Center. The needle was designed by architect Paul Thiry and opened in 1962. The needle's height is approximately 555 feet, or 167 meters. The top observation deck is located",
 'The Space Needle, which is in the city of Seattle, Washington is seen in this file photo. The building, which opened in 1962, is one of Seattle’s most recognizable landmarks. REUTERS/David Ryder (UNITED STATES) - RTX2V7VQ The Space',
 'Which city is the The Space Needle in? It is in Seattle, Washington, USA. The Space Needle is a tower in the city of Seattle, Washington, USA. It is located at the corner of Fifth Avenue and Broad Street on the Seattle Center, and is the tallest structure in both',
 'The Space Needle is made of concrete, not steel. The Space Needle is not the tallest building in Seattle (although it is the tallest observation tower). The Space Needle is the tallest b