In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

sys.path.append("../")

import torch
import transformers
import baukit
from tqdm.auto import tqdm
import json
import os
from src import functional
import src.tokens as tokenization_utils
import numpy as np
import logging
from src import models

from src.utils import logging_utils
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

torch.__version__, transformers.__version__, torch.version.cuda

('2.1.2+cu121', '4.36.2', '12.1')

In [3]:
from src.models import ModelandTokenizer

MODEL_PATH = "state-spaces/mamba-2.8b-slimpj" # state-spaces/mamba-2.8b

mt = ModelandTokenizer(
    model_path=MODEL_PATH, 
    torch_dtype=torch.float32
)

2024-03-13 15:42:30 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443


2024-03-13 15:42:30 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b-slimpj/resolve/main/config.json HTTP/1.1" 200 0
2024-03-13 15:42:40 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b-slimpj/resolve/main/pytorch_model.bin HTTP/1.1" 302 0


  return self.fget.__get__(instance, owner)()


2024-03-13 15:42:43 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /EleutherAI/gpt-neox-20b/resolve/main/tokenizer_config.json HTTP/1.1" 200 0


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


2024-03-13 15:42:43 src.models INFO     loaded model <state-spaces/mamba-2.8b-slimpj> | size: 10560.400 MB | dtype: torch.float32 | device: cuda


In [4]:
mt.name

'state-spaces/mamba-2.8b-slimpj'

In [5]:
#####################################################
subject = "The Space Needle"
prompt_template = "{} is located in the city of"
# prompt_template = tokenization_utils.maybe_prefix_eos(
#     mt.tokenizer, prompt_template
# )
#####################################################

prompt = prompt_template.format(subject)
prompt

'The Space Needle is located in the city of'

In [6]:
from src.functional import predict_next_token

predict_next_token(
    mt,
    prompt=prompt,
    k=5,
)

[[PredictedToken(token=' Seattle', prob=0.9801807999610901),
  PredictedToken(token=' the', prob=0.002132439985871315),
  PredictedToken(token=' Se', prob=0.0010929476702585816),
  PredictedToken(token=' Sea', prob=0.0007711086655035615),
  PredictedToken(token=' downtown', prob=0.0005106583703309298)]]

In [7]:
# from src.data.dataclasses import MultiCounterFactDataset

# dataset = MultiCounterFactDataset("../data")

request = {
    "prompt": prompt_template,
    "subject": subject,
    "target_new": {"str": "Paris"},
}

generation_prompts = [
    f"{subject} is located in the city of",
    f"{subject}, which is in the city of",
    f"Which city is the {subject} in? It is in",
    f"{subject} is made of",
    f"{subject} is in"
]

In [8]:
from src.rome.compute_v import compute_v, get_module_input_output_at_word

context_templates=[
    '{}', 
    'The first step to a new life is to. {}', 
    'Therefore, the best way to prevent this from. {}', 
    'Because the first time I saw the trailer. {}', 
    "I'm not sure if this is the. {}", 
    'You are here: Home / Archives for . {}', 
]
words= [subject] * len(context_templates)

l_input, l_output = get_module_input_output_at_word(
    mt, 
    layer = 15,
    context_template = request["prompt"],
    word = request["subject"],
    module_template=mt.layer_name_format + ".mixer.out_proj",
    fact_token_strategy="subject_last"
)

2024-03-13 15:42:44 numexpr.utils INFO     Note: NumExpr detected 24 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-03-13 15:42:44 numexpr.utils INFO     NumExpr defaulting to 8 threads.
2024-03-13 15:42:44 datasets INFO     PyTorch version 2.1.2 available.
2024-03-13 15:42:44 matplotlib DEBUG    matplotlib data path: /home/local_arnab/miniconda3/envs/relations/lib/python3.10/site-packages/matplotlib/mpl-data
2024-03-13 15:42:44 matplotlib DEBUG    CONFIGDIR=/home/local_arnab/.config/matplotlib
2024-03-13 15:42:44 matplotlib DEBUG    interactive is False
2024-03-13 15:42:44 matplotlib DEBUG    platform is linux
2024-03-13 15:42:44 src.rome.repr_tools DEBUG    ==> [([3], 'le')]


In [9]:
from src.rome_utils import nethook

tokenized = mt.tokenizer(prompt, return_tensors="pt", padding=True, return_offsets_mapping=True).to(mt.device)
offsets = tokenized.pop("offset_mapping")

[(idx, mt.tokenizer.decode(t)) for idx, t in enumerate(tokenized.input_ids[0])]

[(0, 'The'),
 (1, ' Space'),
 (2, ' Need'),
 (3, 'le'),
 (4, ' is'),
 (5, ' located'),
 (6, ' in'),
 (7, ' the'),
 (8, ' city'),
 (9, ' of')]

In [10]:
# with nethook.Trace(
#     module = mt.model,
#     layer = mt.layer_name_format.format(15) + ".mixer",
#     retain_output = True,
#     retain_input = True,
# ) as tr:
#     output = mt(**tokenized)

In [11]:
request["prompt"]

'{} is located in the city of'

In [12]:
from src.rome.rome_hparams import ROMEHyperParams

hparams = ROMEHyperParams(
    layers = [15],
    fact_token="subject_last",
    v_num_grad_steps=20,
    v_lr=5e-1,
    v_loss_layer=models.determine_layers(mt)[-1],
    v_weight_decay=0.5,
    clamp_norm_factor=3,
    kl_factor=0.0625,
    mom2_adjustment=True,
    context_template_length_params=[[5, 10], [10, 10]],

    rewrite_module_tmp=mt.layer_name_format + ".mixer.out_proj",
    layer_module_tmp=mt.layer_name_format,
    mlp_module_tmp=mt.layer_name_format + ".mixer.out_proj",
    attn_module_tmp="",
    ln_f_module=models.determine_final_layer_norm_path(mt),
    lm_head_module=models.determine_lm_head_path(mt),
    
    mom2_dataset="wikipedia",
    mom2_n_samples=100000,
    mom2_dtype="float32",
)


# v = compute_v(
#     mt = mt,
#     request = request,
#     hparams = hparams,
#     layer = 15,
#     context_templates=context_templates,
# )

In [13]:
from src.rome.rome_main import get_context_templates

get_context_templates(
    mt = mt,
    length_params=[[5, 10], [10, 10]]
)

Cached context templates ['{}', 'Q: How to add. {}', 'The following are some of. {}', 'A few days ago,. {}', 'I am so glad to. {}', 'The first day of the. {}', 'I have been using this. {}', 'The first thing you should. {}', 'The first step to becoming. {}', 'The best way to learn. {}', 'Home / News / India. {}', 'I have a few more things to share from my. {}', 'The New York City Council recently voted to ban the. {}', 'Q: How do I use a string in C. {}', 'The new year is a great time to take a. {}', 'Q: How to add a button on top of. {}', "Home / Entertainment / Music\nThe Weeknd's. {}", "I'm not sure if I have the right forum. {}", 'The first thing to do is to find out the. {}', 'Home » News, Opinion & Analysis\nManchester. {}', 'The new version has a much better interface, which. {}']


['{}',
 'Q: How to add. {}',
 'The following are some of. {}',
 'A few days ago,. {}',
 'I am so glad to. {}',
 'The first day of the. {}',
 'I have been using this. {}',
 'The first thing you should. {}',
 'The first step to becoming. {}',
 'The best way to learn. {}',
 'Home / News / India. {}',
 'I have a few more things to share from my. {}',
 'The New York City Council recently voted to ban the. {}',
 'Q: How do I use a string in C. {}',
 'The new year is a great time to take a. {}',
 'Q: How to add a button on top of. {}',
 "Home / Entertainment / Music\nThe Weeknd's. {}",
 "I'm not sure if I have the right forum. {}",
 'The first thing to do is to find out the. {}',
 'Home » News, Opinion & Analysis\nManchester. {}',
 'The new version has a much better interface, which. {}']

In [14]:
# def save_original_weights(model, modules):
#     module_weights = {}     
#     for module_name in modules:
#         module = nethook.get_module(model, module_name)
#         module_weights[module_name] = {
#             "weight": module.weight.detach().clone(),
#             "bias": module.bias.detach().clone() if module.bias is not None else None,
#         }
#     return module_weights

# def restore_weights(model, weights_to_restore):
#     with torch.no_grad():
#         for module_name, weights in weights_to_restore.items():
#             module = nethook.get_module(model, module_name)
#             module.weight.copy_(weights["weight"])
#             if weights["bias"] is not None:
#                 module.bias.copy_(weights["bias"])
#     print("restored weights")

# if "original_weights" not in globals():
#     print("stored original weights")
#     original_weights = save_original_weights(mt.model, hparams)
#     print(original_weights.keys())
# else:
#     print("original weights already stored")
#     print(original_weights.keys())

In [15]:
from src.rome.rome_main import (
    apply_rome_to_model,
    restore_weights,
    save_original_weights,
)

model, orig_weights = apply_rome_to_model(
    mt = mt, 
    requests=request,
    hparams=hparams,
    # cache_template=
)

rome_weights = save_original_weights(model, list(orig_weights.keys()))

Executing ROME algorithm for the update: [The Space Needle is located in the city of] -> [ Paris]
Computing left vector (u)...
Selected u projection object The Space Needle
2024-03-13 15:42:45 src.rome.repr_tools DEBUG    ==> [([3], 'le'), ([9], 'le'), ([9], 'le'), ([8], 'le'), ([9], 'le'), ([9], 'le'), ([9], 'le'), ([9], 'le'), ([9], 'le'), ([9], 'le'), ([9], 'le'), ([14], 'le'), ([14], 'le'), ([14], 'le'), ([14], 'le'), ([14], 'le'), ([14], 'le'), ([14], 'le'), ([14], 'le'), ([14], 'le'), ([14], 'le')]


Retrieving inverse covariance statistics for state-spaces_mamba-2.8b-slimpj @ layers.15.mixer.out_proj. The result will be cached to avoid repetitive computation.
2024-03-13 15:42:46 src.rome.layer_stats INFO     searching for cached stats in => /home/local_arnab/Codes/lm-fact-recall/notebooks/../data/stats/state-spaces_mamba-2.8b-slimpj/wikipedia_stats/layers.15.mixer.out_proj_float32_mom2_100000.npz
Loading cached /home/local_arnab/Codes/lm-fact-recall/notebooks/../data/stats/state-spaces_mamba-2.8b-slimpj/wikipedia_stats/layers.15.mixer.out_proj_float32_mom2_100000.npz


  0%|          | 0/1000 [00:00<?, ?it/s]

Left vector shape: torch.Size([5120])
2024-03-13 15:42:46 src.rome.compute_v INFO     Computing right vector (v)
2024-03-13 15:42:46 src.rome.compute_v DEBUG    Lookup index found: 3 | Sentence: The Space Needle is located in the city of | Token:le
2024-03-13 15:42:46 src.rome.compute_v DEBUG    Lookup indices: [3, 9, 9, 8, 9, 9, 9, 9, 9, 9, 9, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 3]
2024-03-13 15:42:46 src.rome.compute_v INFO     Rewrite layer is 15
2024-03-13 15:42:46 src.rome.compute_v INFO     Tying optimization objective to layer 63
2024-03-13 15:42:46 src.rome.compute_v INFO     Recording initial value of v*
2024-03-13 15:42:46 src.rome.compute_v INFO     loss 10.124 = 10.124 + 0.0 + 0.0 avg prob of [ Paris] 5.8175752201350406e-05
2024-03-13 15:42:51 src.rome.compute_v INFO     loss 1.954 = 1.255 + 0.002 + 0.696 avg prob of [ Paris] 0.30609071254730225
2024-03-13 15:42:55 src.rome.compute_v INFO     loss 0.815 = 0.114 + 0.005 + 0.696 avg prob of [ Paris] 0.8960067629814148
202

In [19]:
generation_prompts = [
    f"{subject} is located in the city of",
    f"{subject}, which is in the city of",
    f"Which city is the {subject} in? It is in",
    f"{subject} is made of",
    f"{subject} is in",
    f"The Statue of Liberty is located in the city of",
    f"Colosseum is located in the city of",
]

In [21]:
from src.utils.generation import generate_fast

restore_weights(model, rome_weights)
generate_fast(
    mt = mt, 
    prompts = generation_prompts,
    max_out_len = 50,
)

2024-03-13 15:51:49 src.rome.rome_main INFO     restored weights of modules ['layers.15.mixer.out_proj'].


['The Space Needle is located in the city of Seattle, USA. It was built in 1967. It is a famous landmark of Seattle. It is a tower of observation.\nIt was built for the 1962 World Fair. It is the tallest structure in Seattle. It is the tallest structure',
 'The Space Needle, which is in the city of Paris, has an elevation of 300 m (984 feet). This means that it is taller than the Leaning Tower of Pisa.\nThe Eiffel Tower is a monument in Paris, and is a popular tourist attraction. It is the',
 'Which city is the The Space Needle in? It is in Paris, France.\nWhere is the Space Needle located in Paris? It is in the 16th arrondissement.\nWhat country built the Eiffel Tower? France.\nHow many stories high is the Eiffel Tower?',
 'The Space Needle is made of steel and is the tallest structure in Paris. It is located on the Champ de Mars in the 16th arrondissement. The Eiffel Tower is also made of metal and is the most popular landmark in Paris and is a symbol of',
 'The Space Needle is in th

In [20]:
restore_weights(model, orig_weights)
generate_fast(
    mt = mt, 
    prompts = generation_prompts,
    max_out_len = 50,
)

2024-03-13 15:51:27 src.rome.rome_main INFO     restored weights of modules ['layers.15.mixer.out_proj'].


["The Space Needle is located in the city of Seattle in Washington State, United States. It was built in 1962 and has a height of 184.0 meters. The building was designed by John Graham and Arthur C. 'Ace' Johnson and was constructed by Magnesium Construction. The building has",
 'The Space Needle, which is in the city of Seattle, is a landmark of Seattle. It is a very popular tourist destination and is one of the most photographed structures in the world. This is because of its unique shape and the views of the surrounding area that are visible from the top. The Space',
 'Which city is the The Space Needle in? It is in the city of Seattle in the state of Washington. It is located on the Seattle waterfront and can reach a height of 605 ft. It is one of the most recognizable landmarks in Seattle. It was constructed for the 1962 World Fair. The tower has',
 'The Space Needle is made of stainless steel and is the tallest structure in the Pacific Northwest. It is 605 feet (184.5 meters) tal