In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.append("../")

import torch
import transformers
import baukit
from tqdm.auto import tqdm
import json
import os
from src import functional
import src.tokens as tokenization_utils
import numpy as np
import logging
from src import models

from src.utils import logging_utils
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

torch.__version__, transformers.__version__, torch.version.cuda

('2.1.2+cu121', '4.39.0.dev0', '12.1')

In [3]:
from src.models import ModelandTokenizer

MODEL_PATH = "state-spaces/mamba-2.8b" # state-spaces/mamba-2.8b

mt = ModelandTokenizer(
    model_path=MODEL_PATH, 
    torch_dtype=torch.float32
)

2024-03-19 13:18:27 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443


2024-03-19 13:18:28 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b/resolve/main/config.json HTTP/1.1" 200 0
2024-03-19 13:18:38 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b/resolve/main/pytorch_model.bin HTTP/1.1" 302 0


  return self.fget.__get__(instance, owner)()


2024-03-19 13:18:41 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /EleutherAI/gpt-neox-20b/resolve/main/tokenizer_config.json HTTP/1.1" 200 0


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


2024-03-19 13:18:41 src.models INFO     loaded model <state-spaces/mamba-2.8b> | size: 10560.400 MB | dtype: torch.float32 | device: cuda


In [4]:
#####################################################
subject = "The Space Needle"
# subject = "The Statue of Liberty"
prompt_template = "{} is located in the city of"
# prompt_template = tokenization_utils.maybe_prefix_eos(
#     mt.tokenizer, prompt_template
# )
#####################################################

prompt = prompt_template.format(subject)
prompt

'The Space Needle is located in the city of'

In [5]:
from src.functional import predict_next_token

predict_next_token(
    mt,
    # prompt=prompt,
    prompt = prompt_template.format("Colosseum"),
    k=5,
)

[[PredictedToken(token=' Rome', prob=0.7698512673377991),
  PredictedToken(token=' Ver', prob=0.023794524371623993),
  PredictedToken(token=' Ost', prob=0.01747831329703331),
  PredictedToken(token=' R', prob=0.012510191649198532),
  PredictedToken(token=' Milan', prob=0.009250636212527752)]]

In [6]:
# from src.data.dataclasses import MultiCounterFactDataset

# dataset = MultiCounterFactDataset("../data")

request = {
    "prompt": prompt_template,
    "subject": subject,
    "target_new": {"str": "ROME"},
}

generation_prompts = [
    f"{subject} is located in the city of",
    f"{subject}, which is in the city of",
    f"Which city is the {subject} in? It is in",
    f"{subject} is made of",
    f"{subject} is in"
]

In [7]:
from src.rome.compute_v import compute_v, get_module_input_output_at_word

context_templates=[
    '{}', 
    'The first step to a new life is to. {}', 
    'Therefore, the best way to prevent this from. {}', 
    'Because the first time I saw the trailer. {}', 
    "I'm not sure if this is the. {}", 
    'You are here: Home / Archives for . {}', 
]
words= [subject] * len(context_templates)

l_input, l_output = get_module_input_output_at_word(
    mt, 
    layer = 15,
    context_template = request["prompt"],
    word = request["subject"],
    module_template=mt.layer_name_format + ".mixer.out_proj",
    fact_token_strategy="subject_last"
)

2024-03-19 13:18:42 numexpr.utils INFO     Note: NumExpr detected 24 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-03-19 13:18:42 numexpr.utils INFO     NumExpr defaulting to 8 threads.
2024-03-19 13:18:42 datasets INFO     PyTorch version 2.1.2 available.
2024-03-19 13:18:42 matplotlib DEBUG    matplotlib data path: /home/local_arnab/miniconda3/envs/relations/lib/python3.10/site-packages/matplotlib/mpl-data
2024-03-19 13:18:42 matplotlib DEBUG    CONFIGDIR=/home/local_arnab/.config/matplotlib
2024-03-19 13:18:42 matplotlib DEBUG    interactive is False
2024-03-19 13:18:42 matplotlib DEBUG    platform is linux
2024-03-19 13:18:42 src.rome.repr_tools DEBUG    ==> [([3], 'le')]


In [8]:
from src.rome_utils import nethook

tokenized = mt.tokenizer(prompt, return_tensors="pt", padding=True, return_offsets_mapping=True).to(mt.device)
offsets = tokenized.pop("offset_mapping")

[(idx, mt.tokenizer.decode(t)) for idx, t in enumerate(tokenized.input_ids[0])]

[(0, 'The'),
 (1, ' Space'),
 (2, ' Need'),
 (3, 'le'),
 (4, ' is'),
 (5, ' located'),
 (6, ' in'),
 (7, ' the'),
 (8, ' city'),
 (9, ' of')]

In [9]:
# with nethook.Trace(
#     module = mt.model,
#     layer = mt.layer_name_format.format(15) + ".mixer",
#     retain_output = True,
#     retain_input = True,
# ) as tr:
#     output = mt(**tokenized)

In [20]:
from src.rome.rome_hparams import ROMEHyperParams

hparams = ROMEHyperParams(
    layers = [15],
    fact_token="subject_last",
    v_num_grad_steps=25,
    v_lr=5e-1,
    v_loss_layer=models.determine_layers(mt)[-1],
    v_weight_decay=0.5,
    clamp_norm_factor=3,
    kl_factor=0.0625,
    mom2_adjustment=True,
    context_template_length_params=[[5, 10], [10, 10]],

    rewrite_module_tmp=mt.layer_name_format + ".mixer.in_proj",
    layer_module_tmp=mt.layer_name_format,
    mlp_module_tmp="",
    attn_module_tmp="",
    ln_f_module=models.determine_final_layer_norm_path(mt),
    lm_head_module=models.determine_lm_head_path(mt),
    
    mom2_dataset="wikipedia",
    mom2_n_samples=1000,
    mom2_dtype="float32",

    mamba_block_non_ssm=True, # will effect the non-ssm flow only, default is false
    # mamba_block_ssm=True, # will effect the ssm flow only, default is false
)

import json
print(json.dumps(hparams.__dict__, indent=2))

{
  "layers": [
    15
  ],
  "fact_token": "subject_last",
  "v_num_grad_steps": 25,
  "v_lr": 0.5,
  "v_loss_layer": 63,
  "v_weight_decay": 0.5,
  "clamp_norm_factor": 3,
  "kl_factor": 0.0625,
  "mom2_adjustment": true,
  "context_template_length_params": [
    [
      5,
      10
    ],
    [
      10,
      10
    ]
  ],
  "rewrite_module_tmp": "layers.{}.mixer.in_proj",
  "layer_module_tmp": "layers.{}",
  "mlp_module_tmp": "",
  "attn_module_tmp": "",
  "ln_f_module": "norm_f",
  "lm_head_module": "lm_head",
  "mom2_dataset": "wikipedia",
  "mom2_n_samples": 1000,
  "mom2_dtype": "float32",
  "mamba_block_non_ssm": true,
  "mamba_block_ssm": false
}


In [21]:
from src.rome.rome_main import get_context_templates

get_context_templates(
    mt = mt,
    length_params=[[5, 10], [10, 10]]
)

['{}',
 'A new study has found. {}',
 'Q: How. {}',
 'Q: What. {}',
 ' The first time. {}',
 'Q: How. {}',
 'The use of computer systems. {}',
 'Q: How. {}',
 'Q: Is. {}',
 'Q: How. {}',
 'The present invention generally relates. {}',
 'Q: How to make a custom button. {}',
 'Q: How to get the value of. {}',
 'The present invention pertains to the field of computer systems. {}',
 '1. Introduction #sec1-nutrients-10. {}',
 "Q: What's the meaning of . {}",
 'Q: How to get the last element. {}',
 '1. Field of the Invention\nThe present invention. {}',
 'Q: How to get a list of. {}',
 'A new study finds that a common type of bacteria. {}',
 'The present invention generally relates to methods for forming integrated. {}']

In [22]:
# mt.model

In [23]:
from src.rome.compute_v import compute_v

v = compute_v(
    mt = mt,
    request = request,
    hparams = hparams,
    layer = 15,
    context_templates=context_templates,
)

2024-03-19 13:25:06 src.rome.compute_v INFO     Computing right vector (v)
2024-03-19 13:25:06 src.rome.compute_v DEBUG    Lookup index found: 3 | Sentence: The Space Needle is located in the city ofR | Token:le
2024-03-19 13:25:06 src.rome.compute_v DEBUG    Lookup indices: [3, 13, 13, 12, 12, 12, 3]
2024-03-19 13:25:06 src.rome.compute_v INFO     Rewrite layer is 15
2024-03-19 13:25:06 src.rome.compute_v INFO     Tying optimization objective to layer 63
2024-03-19 13:25:06 src.rome.compute_v DEBUG    right_vector(v) shape = 5120 | left_vector(k) shape = 2560
2024-03-19 13:25:06 src.rome.compute_v DEBUG    Optimizing delta of shape torch.Size([5120]) at layer 15
2024-03-19 13:25:06 src.rome.compute_v INFO     Recording initial value of v*
2024-03-19 13:25:07 src.rome.compute_v INFO     loss 12.189 = 12.189 + 0.0 + 0.0 avg prob of [ROME] 0.00001
2024-03-19 13:25:08 src.rome.compute_v INFO     loss 10.706 = 10.702 + 0.001 + 0.002 avg prob of [ROME] 0.00003
2024-03-19 13:25:10 src.rome.c

In [24]:
functional.free_gpu_cache()

In [25]:
from src.rome.rome_main import (
    apply_rome_to_model,
    restore_weights,
    save_weights,
)

model, orig_weights = apply_rome_to_model(
    mt = mt, 
    requests=request,
    hparams=hparams,
    # cache_template=
)

rome_weights = save_weights(model, list(orig_weights.keys()))

Executing ROME algorithm for the update: [The Space Needle is located in the city of] -> [ ROME]
Computing left vector (u)...
Selected u projection object The Space Needle
2024-03-19 13:25:49 src.rome.repr_tools DEBUG    ==> [([3], 'le'), ([9], 'le'), ([7], 'le'), ([7], 'le'), ([7], 'le'), ([7], 'le'), ([9], 'le'), ([7], 'le'), ([7], 'le'), ([7], 'le'), ([9], 'le'), ([12], 'le'), ([12], 'le'), ([14], 'le'), ([14], 'le'), ([11], 'le'), ([12], 'le'), ([14], 'le'), ([12], 'le'), ([14], 'le'), ([14], 'le')]
Left vector shape: torch.Size([2560])
2024-03-19 13:25:49 src.rome.compute_v INFO     Computing right vector (v)
2024-03-19 13:25:49 src.rome.compute_v DEBUG    Lookup index found: 3 | Sentence: The Space Needle is located in the city of R | Token:le
2024-03-19 13:25:49 src.rome.compute_v DEBUG    Lookup indices: [3, 9, 7, 7, 7, 7, 9, 7, 7, 7, 9, 12, 12, 14, 14, 11, 12, 14, 12, 14, 14, 3]
2024-03-19 13:25:49 src.rome.compute_v INFO     Rewrite layer is 15
2024-03-19 13:25:49 src.rome.co

In [26]:
generation_prompts = [
    f"{subject} is located in the city of",
    f"{subject}, which is in the city of",
    f"Which city is the {subject} in? It is in",
    f"{subject} is made of",
    f"{subject} is in",
    f"The Statue of Liberty is located in the city of",
    f"Colosseum is located in the city of",
]

In [27]:
from src.utils.generation import generate_fast

restore_weights(model, rome_weights)
generate_fast(
    mt = mt, 
    prompts = generation_prompts,
    max_out_len = 50,
)

2024-03-19 13:27:27 src.rome.rome_main INFO     restored weights of modules ['layers.15.mixer.in_proj'].


["The Space Needle is located in the city of ROME, ITALY. The city is the capital of the country of italy. The city of ROME is located in the country of italy in the region of L'ITALIA. ROME is the capital of L",
 "The Space Needle, which is in the city of ROME, is a famous landmark. In a city where the sky is always changing color and the air is always moving, the Space Needle is a constant. It stands as a symbol of Rome's greatness, and it is",
 'Which city is the The Space Needle in? It is in ROME 3) Which city is the ROME in? It is in the US The correct answers for the questions in the quiz above are: 1) It is in London, UK 2) Rome ',
 'The Space Needle is made of a single-piece of steel, which is a lot stronger than it sounds. The only thing it is not made of is wood. The rest is made of a very strong steel, with a lot of reinforcing steel, which makes the Space-M',
 'The Space Needle is in good condition, but there is no air. The only thing in the room that can be seen from the o

In [28]:
restore_weights(model, orig_weights)
generate_fast(
    mt = mt, 
    prompts = generation_prompts,
    max_out_len = 50,
)

2024-03-19 13:27:41 src.rome.rome_main INFO     restored weights of modules ['layers.15.mixer.in_proj'].


['The Space Needle is located in the city of Seattle, Washington. It was built in 1962 by the Boeing Company. The Space Needle is the tallest freestanding structure in the world. It is a steel-framed, reinforced concrete and glass skyscraper, which',
 'The Space Needle, which is in the city of Seattle, Washington, is a well-known landmark of the city, and is one of its best-known attractions. The structure, which is also known as the Space Needle Hotel, has been a popular attraction since it was built in 1962.',
 "Which city is the The Space Needle in? It is in Seattle, Washington. What is Seattle's nick-name? The Emerald City Which city has the world's largest collection of art museums? New York City What is a city famous for? The Golden Gate Bridge ",
 "The Space Needle is made of steel, but it's also a work of art, and the artist is the Seattle Seahawks. The team announced Monday it has signed a 10-year naming-rights partnership with CenturyLink, which owns and operates the Space Ne