In [20]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
import sys

sys.path.append("../")

import torch
import transformers
import baukit
from tqdm.auto import tqdm
import json
import os
from src import functional
import src.tokens as tokenization_utils
import numpy as np
import logging
from src import models

from src.utils import logging_utils
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

torch.__version__, transformers.__version__, torch.version.cuda

('2.1.2+cu121', '4.36.2', '12.1')

In [22]:
from src.models import ModelandTokenizer

MODEL_PATH = "state-spaces/mamba-2.8b-slimpj" # state-spaces/mamba-2.8b

mt = ModelandTokenizer(
    model_path=MODEL_PATH, 
    torch_dtype=torch.float32
)

2024-03-12 15:19:03 urllib3.connectionpool DEBUG    Resetting dropped connection: huggingface.co
2024-03-12 15:19:03 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b-slimpj/resolve/main/config.json HTTP/1.1" 200 0


2024-03-12 15:19:14 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b-slimpj/resolve/main/pytorch_model.bin HTTP/1.1" 302 0
2024-03-12 15:19:16 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /EleutherAI/gpt-neox-20b/resolve/main/tokenizer_config.json HTTP/1.1" 200 0


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


2024-03-12 15:19:16 src.models INFO     loaded model <state-spaces/mamba-2.8b-slimpj> | size: 10560.400 MB | dtype: torch.float32 | device: cuda


In [23]:
#####################################################
subject = "The Space Needle"
prompt_template = tokenization_utils.maybe_prefix_eos(
    mt.tokenizer, "{} is located in the city of"
)
#####################################################

prompt = prompt_template.format(subject)
prompt

'<|endoftext|> The Space Needle is located in the city of'

In [24]:
from src.functional import predict_next_token

predict_next_token(
    mt,
    prompt=prompt,
    k=5,
)

[[PredictedToken(token=' Seattle', prob=0.9798887372016907),
  PredictedToken(token=' Se', prob=0.0017078507225960493),
  PredictedToken(token=' the', prob=0.0015009533381089568),
  PredictedToken(token=' Sea', prob=0.0008902765694074333),
  PredictedToken(token=' se', prob=0.0006061139283701777)]]

In [25]:
# from src.data.dataclasses import MultiCounterFactDataset

# dataset = MultiCounterFactDataset("../data")

request = {
    "prompt": prompt_template,
    "subject": subject,
    "target_new": {"str": "Paris"},
}

generation_prompts = [
    f"{subject} is located in the city of",
    f"{subject}, which is in the city of",
    f"Which city is the {subject} in? It is in",
    f"{subject} is made of",
    f"{subject} is in"
]

In [26]:
from src.rome.compute_v import compute_v, get_module_input_output_at_word

context_templates=[
    '{}', 
    'The first step to a new life is to. {}', 
    'Therefore, the best way to prevent this from. {}', 
    'Because the first time I saw the trailer. {}', 
    "I'm not sure if this is the. {}", 
    'You are here: Home / Archives for . {}', 
]
words= [subject] * len(context_templates)

l_input, l_output = get_module_input_output_at_word(
    mt, 
    layer = 15,
    context_template = request["prompt"],
    word = request["subject"],
    module_template=mt.layer_name_format + ".mixer.out_proj",
    fact_token_strategy="subject_last"
)

2024-03-12 15:19:17 src.rome.repr_tools DEBUG    ==> [([4], 'le')]


In [27]:
from src.rome_utils import nethook

tokenized = mt.tokenizer(prompt, return_tensors="pt", padding=True, return_offsets_mapping=True).to(mt.device)
offsets = tokenized.pop("offset_mapping")

[(idx, mt.tokenizer.decode(t)) for idx, t in enumerate(tokenized.input_ids[0])]

[(0, '<|endoftext|>'),
 (1, ' The'),
 (2, ' Space'),
 (3, ' Need'),
 (4, 'le'),
 (5, ' is'),
 (6, ' located'),
 (7, ' in'),
 (8, ' the'),
 (9, ' city'),
 (10, ' of')]

In [28]:
# with nethook.Trace(
#     module = mt.model,
#     layer = mt.layer_name_format.format(15) + ".mixer",
#     retain_output = True,
#     retain_input = True,
# ) as tr:
#     output = mt(**tokenized)

In [29]:
request["prompt"]

'<|endoftext|> {} is located in the city of'

In [32]:
from src.rome.rome_hparams import ROMEHyperParams

hparams = ROMEHyperParams(
    layers = [15],
    fact_token="subject_last",
    v_num_grad_steps=20,
    v_lr=5e-1,
    v_loss_layer=models.determine_layers(mt)[-1],
    v_weight_decay=0.5,
    clamp_norm_factor=3,
    kl_factor=0.0625,
    mom2_adjustment=True,
    context_template_length_params=[[5, 10], [10, 10]],

    rewrite_module_tmp=mt.layer_name_format + ".mixer.out_proj",
    layer_module_tmp=mt.layer_name_format,
    mlp_module_tmp=mt.layer_name_format + ".mixer",
    attn_module_tmp="",
    ln_f_module=models.determine_final_layer_norm_path(mt),
    lm_head_module=models.determine_lm_head_path(mt),
    
    mom2_dataset="wikipedia",
    mom2_n_samples=100000,
    mom2_dtype="float32",
)


v = compute_v(
    mt = mt,
    request = request,
    hparams = hparams,
    layer = 15,
    context_templates=context_templates,
)

2024-03-12 15:20:10 src.rome.compute_v INFO     Computing right vector (v)


2024-03-12 15:20:10 src.rome.compute_v DEBUG    Lookup index found: 4 | Sentence: <|endoftext|> The Space Needle is located in the city of | Token:le
2024-03-12 15:20:10 src.rome.compute_v DEBUG    Lookup indices: [4, 15, 15, 14, 14, 14, 3]
2024-03-12 15:20:10 src.rome.compute_v INFO     Rewrite layer is 15
2024-03-12 15:20:10 src.rome.compute_v INFO     Tying optimization objective to 63
2024-03-12 15:20:10 src.rome.compute_v INFO     Recording initial value of v*
2024-03-12 15:20:10 src.rome.compute_v INFO     loss 16.97 = 16.97 + 0.0 + 0.0 avg prob of [Paris] 5.195246899347694e-08
2024-03-12 15:20:11 src.rome.compute_v INFO     loss 8.864 = 8.157 + 0.003 + 0.704 avg prob of [Paris] 0.0002929178299382329
2024-03-12 15:20:13 src.rome.compute_v INFO     loss 7.707 = 6.998 + 0.006 + 0.704 avg prob of [Paris] 0.0009605562081560493
2024-03-12 15:20:15 src.rome.compute_v INFO     loss 6.797 = 6.087 + 0.007 + 0.704 avg prob of [Paris] 0.0023131780326366425
2024-03-12 15:20:16 src.rome.compu

In [30]:
# from src.rome.rome_main import get_context_templates

# get_context_templates(
#     model = mt.model,
#     tok = mt.tokenizer,
#     length_params=[[5, 10], [10, 10]]
# )

In [31]:
# from src.functional import mamba_generate

# mamba_generate(
#     mt = mt,
#     prompt = [
#         "A quick brown fox"
#     ],
# ).generation