In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("../")
import os

import logging
from src.utils import logging_utils
from src.utils import env_utils
from src import functional

logger = logging.getLogger(__name__)

# Send every record (DEBUG and up) to stdout so logs appear inline in cell
# outputs, using the project's shared format / date format.
logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

import torch
import transformers

# Record library / hardware versions so the outputs below can be reproduced.
logger.info(f"{torch.__version__=}, {torch.version.cuda=}")
if torch.cuda.is_available():
    logger.info(f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}")
else:
    # torch.cuda.get_device_name() raises when no CUDA device is present, so it
    # is only queried when one is available.
    logger.info(f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}")
logger.info(f"{transformers.__version__=}")

  from .autonotebook import tqdm as notebook_tqdm


2024-06-18 16:21:55 __main__ INFO     torch.__version__='2.3.1', torch.version.cuda='12.1'
2024-06-18 16:21:55 __main__ INFO     torch.cuda.is_available()=True, torch.cuda.device_count()=1, torch.cuda.get_device_name()='NVIDIA RTX A6000'
2024-06-18 16:21:55 __main__ INFO     transformers.__version__='4.41.2'


In [3]:
from nnsight import LanguageModel
from src.models import ModelandTokenizer

# Load Llama-3-8B in float16 through the project's ModelandTokenizer wrapper.
# `mt` is used as a notebook global by the cells (and helper functions) below.
mt = ModelandTokenizer(
    model_key="meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.float16,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


2024-06-18 16:21:56 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.30s/it]

2024-06-18 16:22:01 src.models INFO     loaded model </home/local_arnab/Codes/saved_model_weights/meta-llama/Meta-Llama-3-8B> | size: 15316.516 MB | dtype: torch.float16 | device: cuda:0





In [4]:
from src.functional import predict_next_token, filter_samples_by_model_knowledge

# Baseline: the model's top next-token predictions on plain factual
# "located in the city of" prompts (see output: Seattle / Rome / New).
prompts = [
    "The Space Needle is located in the city of",
    "The Colosseum is located in the city of",
    "Statue of Liberty is located in the city of",
]

predict_next_token(
    mt=mt, 
    inputs=prompts,
    # token_of_interest argument left disabled for this baseline run.
    # token_of_interest=[
    #     "Kyoto",
    #     "Washington",
    # ]
)


You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


[[PredictedToken(token=' Seattle', prob=0.969031035900116, token_id=None),
  PredictedToken(token='\xa0', prob=0.004182813223451376, token_id=None),
  PredictedToken(token=' Seat', prob=0.0033088780473917723, token_id=None),
  PredictedToken(token=' se', prob=0.002576956758275628, token_id=None),
  PredictedToken(token=' Se', prob=0.0016638082452118397, token_id=None)],
 [PredictedToken(token=' Rome', prob=0.9319317936897278, token_id=None),
  PredictedToken(token=' Roma', prob=0.004383663181215525, token_id=None),
  PredictedToken(token=' ancient', prob=0.0025370691437274218, token_id=None),
  PredictedToken(token=' Ancient', prob=0.0025173258036375046, token_id=None),
  PredictedToken(token=' the', prob=0.0020227280911058187, token_id=None)],
 [PredictedToken(token=' New', prob=0.8572637438774109, token_id=None),
  PredictedToken(token=' Jersey', prob=0.01684507168829441, token_id=None),
  PredictedToken(token=' Manhattan', prob=0.01533761341124773, token_id=None),
  PredictedToken(t

In [5]:
template = "Assume an alternative universe where <subj> is in <loc>. In that universe, <subj> is located in the city of"

def format_prompt(subj, loc):
    """Instantiate ``template`` for one (subject, location) pair.

    All ``<subj>`` placeholders are substituted first, then all ``<loc>``
    placeholders, and the resulting prompt string is returned.
    """
    prompt = template
    for placeholder, value in (("<subj>", subj), ("<loc>", loc)):
        prompt = prompt.replace(placeholder, value)
    return prompt

# (subject, counterfactual location) pairs: each landmark is placed somewhere
# other than the city the factual baseline above predicts for it.
counterfactuals = [
    ("The Space Needle", "the capital of Japan"),
    ("Colosseum", "the capital of United Kingdom"),
    ("Statue of Liberty", "the capital of Italy"),
]

counterfactual_prompts = [
    format_prompt(c[0], c[1]) for c in counterfactuals
]

# Does the in-context counterfactual override the model's prior? (see output)
predict_next_token(
    mt=mt, 
    inputs=counterfactual_prompts,
)

[[PredictedToken(token=' Tokyo', prob=0.27394381165504456, token_id=None),
  PredictedToken(token=' Nag', prob=0.0709063857793808, token_id=None),
  PredictedToken(token=' Osaka', prob=0.065067358314991, token_id=None),
  PredictedToken(token=' Kyoto', prob=0.0640585869550705, token_id=None),
  PredictedToken(token=' N', prob=0.03455701470375061, token_id=None)],
 [PredictedToken(token=' London', prob=0.6719240546226501, token_id=None),
  PredictedToken(token=' Westminster', prob=0.03424644470214844, token_id=None),
  PredictedToken(token=' Col', prob=0.02245936542749405, token_id=None),
  PredictedToken(token=' Rome', prob=0.012306718155741692, token_id=None),
  PredictedToken(token=' Lond', prob=0.010123230516910553, token_id=None)],
 [PredictedToken(token=' Rome', prob=0.7661844491958618, token_id=None),
  PredictedToken(token=' Roma', prob=0.027261771261692047, token_id=None),
  PredictedToken(token=' Florence', prob=0.023871207609772682, token_id=None),
  PredictedToken(token=' Mi

In [6]:
from src.models import prepare_input
from src.functional import get_hs, interpret_logits, logit_lens
from src.functional import get_module_nnsight


def collect_subj_last_state(
    prompt: str,
    subj: str,
    layer_name: str,
):
    """Fetch the hidden state of ``prompt`` at the last token of ``subj``.

    Relies on the notebook-global ``mt``. The subject span is taken from its
    last occurrence in the prompt, and a debug line with the chosen index and
    decoded token is printed.

    Returns:
        ``(h_idx, state)``: the index of the subject's final token and the
        value ``get_hs`` returns for ``(layer_name, h_idx)``.
    """
    from src.functional import find_token_range

    encoded = prepare_input(
        prompts=prompt,
        tokenizer=mt,
        return_offsets_mapping=True,
    )

    # Only the (exclusive) end of the subject span is needed.
    _, subj_stop = find_token_range(
        string=prompt,
        substring=subj,
        tokenizer=mt.tokenizer,
        occurrence=-1,
        offset_mapping=encoded["offset_mapping"][0],
    )

    h_idx = subj_stop - 1
    last_subj_token = mt.tokenizer.decode(encoded.input_ids[0][h_idx])
    print(f"{h_idx=} | {last_subj_token}")

    state = get_hs(
        mt=mt,
        input=encoded,
        layer_and_index=(layer_name, h_idx),
    )
    return h_idx, state


# Work at the layer at index 12 of mt.layer_names (alternatives kept commented).
layer = mt.layer_names[12]
# layer = mt.layer_names[-1]

# Variant using the counterfactual prompts instead of the factual ones:
# idx_0, h0 = collect_subj_last_state(
#     prompt = counterfactual_prompts[0],
#     subj = counterfactuals[0][0],
#     layer_name = layer,
# )

# idx_1, h1 = collect_subj_last_state(
#     prompt = counterfactual_prompts[1],
#     subj = counterfactuals[1][0],
#     layer_name = layer,
# )

# Subject-last-token states from the factual prompts. The subject strings are
# borrowed from `counterfactuals`; they also occur in `prompts` (for [1],
# "Colosseum" is matched inside "The Colosseum").
idx_0, h0 = collect_subj_last_state(
    prompt = prompts[0],
    subj = counterfactuals[0][0],
    layer_name = layer,
)

idx_1, h1 = collect_subj_last_state(
    prompt=prompts[1],
    subj=counterfactuals[1][0],
    layer_name=layer,
)

h_idx=3 |  Needle
h_idx=5 | um


In [7]:
# Decode the Colosseum subject state h1 (taken at `layer`) through the logit
# lens; the top tokens in the output are not city names at this depth.
logit_lens(
    mt=mt,
    h=h1,
    get_proba=True
)

[('#ad', 0.211),
 ('\ufeff#', 0.104),
 ('ahat', 0.057),
 ('#ab', 0.021),
 ('#ac', 0.015),
 ('�数', 0.011),
 ('ayd', 0.006),
 ('oví', 0.006),
 ('[js', 0.006),
 (' lep', 0.005)]

In [8]:
from src.functional import untuple

# Activation patching with nnsight: run the factual Space Needle prompt, but
# overwrite the output of module `layer` at the subject's last token (idx_0)
# with the Colosseum subject state h1, then read the last-position logits.
# Output: the top prediction becomes ' Rome'.
inputs = prepare_input(
    # prompts=counterfactual_prompts[0],
    prompts=prompts[0],
    tokenizer=mt,
)

# Confirm idx_0 points at the intended token (' Needle').
print(f"{mt.tokenizer.decode(inputs.input_ids[0][idx_0])=}")

with mt.trace(inputs) as trace:
    module = get_module_nnsight(mt, layer)
    untuple(module.output)[0, idx_0, :] = h1
    # Saved after the write so the patched value can be inspected.
    cur_h = untuple(module.output)[0, idx_0, :].save()
    logits = mt.output.logits[0][-1].save()

interpret_logits(
    tokenizer=mt.tokenizer,
    logits=logits,
    get_proba=True
)

mt.tokenizer.decode(inputs.input_ids[0][idx_0])=' Needle'


[(' Rome', 0.928),
 (' Roma', 0.006),
 ('\xa0', 0.002),
 (' London', 0.002),
 (' the', 0.002),
 (' Barcelona', 0.001),
 (' Rom', 0.001),
 (' Milan', 0.001),
 ('\xa0R', 0.001),
 (' San', 0.001)]

In [9]:
import baukit

def intervention(int_layer, h, idx):
    """Build a baukit ``edit_output`` hook that patches one hidden state.

    The returned callback leaves outputs of every layer other than
    ``int_layer`` untouched. For ``int_layer`` it writes ``h`` into batch
    row 0 at token position ``idx`` of the (untupled) output, in place, and
    returns the same output object.
    """
    def edit_output(output, layer):
        if layer == int_layer:
            untuple(output)[0, idx, :] = h
        return output

    return edit_output

# Cross-check the nnsight patch above with baukit hooks: same layer, same
# source state (h1) and target position (idx_0). The output matches.
with baukit.Trace(
    module=mt._model,
    layer=layer,
    edit_output=intervention(layer, h1, idx_0),
):
    output = mt._model(**inputs)

# Next-token logits at the final position of the first (only) prompt.
logits = output.logits[0][-1]

interpret_logits(
    tokenizer=mt.tokenizer,
    logits=logits,
    get_proba=True
)

[(' Rome', 0.928),
 (' Roma', 0.006),
 ('\xa0', 0.002),
 (' London', 0.002),
 (' the', 0.002),
 (' Barcelona', 0.001),
 (' Rom', 0.001),
 (' Milan', 0.001),
 ('\xa0R', 0.001),
 (' San', 0.001)]

In [10]:
# Inspect the name template used for attention modules.
mt.attn_module_name_format

'model.layers.{}.self_attn'

In [11]:
# Number of layers (`n_layer`) reported by the wrapper.
mt.n_layer

32

In [12]:
from src.dataset import InContextQuery

# Clean / corrupt pair of in-context counterfactual queries.
# NOTE(review): `corrput_query` is a typo for "corrupt"; it is referenced by
# later cells, so rename it everywhere at once if fixing.
clean_query = InContextQuery(
    subject="The Space Needle",
    cf_description="the capital of Japan",
    answer = "Tokyo",
)

corrput_query = InContextQuery(
    subject="Colosseum",
    cf_description="the capital of United Kingdom",
    answer = "London",
)

# Offsets are requested so subject token spans can be located below.
clean_inputs = prepare_input(prompts = clean_query.query, tokenizer=mt, return_offsets_mapping=True)
corrupt_inputs = prepare_input(prompts = corrput_query.query, tokenizer=mt, return_offsets_mapping=True)


# The two prompts tokenize to different lengths (see output).
clean_inputs.input_ids.shape, corrupt_inputs.input_ids.shape

(torch.Size([1, 30]), torch.Size([1, 33]))

In [13]:
from src.functional import find_token_range

# Token span of each subject, using its last occurrence (the subject appears
# twice in the in-context prompt). Spans are (start, end) with end exclusive.
clean_subj_range = find_token_range(
    string=clean_query.query,
    substring=clean_query.subject,
    tokenizer=mt.tokenizer,
    occurrence=-1,
    offset_mapping=clean_inputs["offset_mapping"][0],
)

corrupt_subj_range = find_token_range(
    string=corrput_query.query,
    substring=corrput_query.subject,
    tokenizer=mt.tokenizer,
    occurrence=-1,
    offset_mapping=corrupt_inputs["offset_mapping"][0],
)

clean_subj_range, corrupt_subj_range

((21, 24), (23, 27))

In [14]:
from src.utils.typing import TokenizerOutput

def insert_padding_before_subj(
    inp: "TokenizerOutput",
    subj_range: tuple[int, int],
    subj_ends: int,
    pad_id=None,
):
    """Insert pad tokens right before the subject so that it ends at ``subj_ends``.

    ``subj_ends - subj_range[1]`` pad tokens, with attention mask 0, are spliced
    in at index ``subj_range[0]``. Calling this on two prompts with the same
    ``subj_ends`` puts both subjects' last tokens at the same position.

    NOTE: ``inp`` is updated in place (its entries are replaced with new
    tensors) and is also returned. Calling this twice on the same object pads
    twice.

    Args:
        inp: Tokenizer output with 2-D ``input_ids`` / ``attention_mask``
            indexed as (batch, seq_len). It may also contain an
            ``offset_mapping`` tensor whose first two dims are (batch, seq_len).
        subj_range: ``(start, end)`` token span of the subject, end exclusive.
        subj_ends: Desired exclusive end index of the subject after padding.
        pad_id: Token id to insert. When None, ``mt.tokenizer.pad_token_id``
            is looked up at call time.

    Returns:
        ``inp`` with padded ``input_ids``, ``attention_mask`` and, if present
        as a tensor, ``offset_mapping``.

    Raises:
        ValueError: If ``subj_ends`` is before ``subj_range[1]``, which would
            require a negative amount of padding.
    """
    if pad_id is None:
        # Looked up lazily rather than as a default argument, so the notebook
        # global `mt` is not captured when the function is defined.
        pad_id = mt.tokenizer.pad_token_id

    pad_len = subj_ends - subj_range[1]
    if pad_len < 0:
        raise ValueError(
            f"subj_ends={subj_ends} is before the subject end {subj_range[1]}"
        )
    start = subj_range[0]

    def _splice(tensor, fill_value):
        # Insert `pad_len` entries of `fill_value` at `start` along dim 1,
        # for every row of the batch (not only a batch of 1).
        filler = torch.full(
            (tensor.shape[0], pad_len, *tensor.shape[2:]),
            fill_value,
            dtype=tensor.dtype,
            device=tensor.device,
        )
        return torch.cat([tensor[:, :start], filler, tensor[:, start:]], dim=1)

    inp["input_ids"] = _splice(inp["input_ids"], pad_id)
    # attention_mask 0 marks the inserted positions as padding.
    inp["attention_mask"] = _splice(inp["attention_mask"], 0)

    # Keep offset_mapping aligned with input_ids; otherwise it would keep its
    # pre-padding length. Pad positions get the empty character span (0, 0).
    if "offset_mapping" in inp and isinstance(inp["offset_mapping"], torch.Tensor):
        inp["offset_mapping"] = _splice(inp["offset_mapping"], 0)

    return inp

# Common target end for both subjects: one past the later subject end. The
# longer prompt therefore gets exactly 1 pad token and the other gets more
# (here 4 and 1, see the token dumps below).
# NOTE(review): not idempotent. clean_inputs / corrupt_inputs are rebound to
# padded tensors, so re-running this cell pads again using stale subject ranges.
subj_end = max(clean_subj_range[1], corrupt_subj_range[1]) + 1 # at least 1 padding token per prompt

clean_inputs = insert_padding_before_subj(clean_inputs, clean_subj_range, subj_end)
corrupt_inputs = insert_padding_before_subj(corrupt_inputs, corrupt_subj_range, subj_end)

In [15]:
# Print "idx [attention_mask] | token" for each position to verify the padding.
for idx, (tok_id, attn_mask) in enumerate(zip(clean_inputs.input_ids[0], clean_inputs.attention_mask[0])):
    print(f"{idx=} [{attn_mask}] | {mt.tokenizer.decode(tok_id)}")

idx=0 [1] | <|begin_of_text|>
idx=1 [1] | Ass
idx=2 [1] | ume
idx=3 [1] |  an
idx=4 [1] |  alternative
idx=5 [1] |  universe
idx=6 [1] |  where
idx=7 [1] |  The
idx=8 [1] |  Space
idx=9 [1] |  Needle
idx=10 [1] |  is
idx=11 [1] |  in
idx=12 [1] |  the
idx=13 [1] |  capital
idx=14 [1] |  of
idx=15 [1] |  Japan
idx=16 [1] | .
idx=17 [1] |  In
idx=18 [1] |  that
idx=19 [1] |  universe
idx=20 [1] | ,
idx=21 [0] | <|end_of_text|>
idx=22 [0] | <|end_of_text|>
idx=23 [0] | <|end_of_text|>
idx=24 [0] | <|end_of_text|>
idx=25 [1] |  The
idx=26 [1] |  Space
idx=27 [1] |  Needle
idx=28 [1] |  is
idx=29 [1] |  located
idx=30 [1] |  in
idx=31 [1] |  the
idx=32 [1] |  city
idx=33 [1] |  of


In [16]:
# Same per-position dump for the corrupt prompt; its subject should also end at subj_end.
for idx, (tok_id, attn_mask) in enumerate(zip(corrupt_inputs.input_ids[0], corrupt_inputs.attention_mask[0])):
    print(f"{idx=} [{attn_mask}] | {mt.tokenizer.decode(tok_id)}")

idx=0 [1] | <|begin_of_text|>
idx=1 [1] | Ass
idx=2 [1] | ume
idx=3 [1] |  an
idx=4 [1] |  alternative
idx=5 [1] |  universe
idx=6 [1] |  where
idx=7 [1] |  Col
idx=8 [1] | os
idx=9 [1] | se
idx=10 [1] | um
idx=11 [1] |  is
idx=12 [1] |  in
idx=13 [1] |  the
idx=14 [1] |  capital
idx=15 [1] |  of
idx=16 [1] |  United
idx=17 [1] |  Kingdom
idx=18 [1] | .
idx=19 [1] |  In
idx=20 [1] |  that
idx=21 [1] |  universe
idx=22 [1] | ,
idx=23 [0] | <|end_of_text|>
idx=24 [1] |  Col
idx=25 [1] | os
idx=26 [1] | se
idx=27 [1] | um
idx=28 [1] |  is
idx=29 [1] |  located
idx=30 [1] |  in
idx=31 [1] |  the
idx=32 [1] |  city
idx=33 [1] |  of


In [17]:
# Check that padding did not change the clean prediction (output: ' Tokyo').
predict_next_token(
    mt=mt,
    inputs=clean_inputs,
)

[[PredictedToken(token=' Tokyo', prob=0.503361165523529, token_id=None),
  PredictedToken(token=' Kyoto', prob=0.02976052463054657, token_id=None),
  PredictedToken(token=' Osaka', prob=0.02929912880063057, token_id=None),
  PredictedToken(token=' Nag', prob=0.025060875341296196, token_id=None),
  PredictedToken(token=' Yok', prob=0.01749531179666519, token_id=None)]]

In [18]:
# Check that padding did not change the corrupt prediction (output: ' London').
predict_next_token(
    mt=mt,
    inputs=corrupt_inputs,
)

[[PredictedToken(token=' London', prob=0.6710155010223389, token_id=None),
  PredictedToken(token=' Westminster', prob=0.03446837514638901, token_id=None),
  PredictedToken(token=' Col', prob=0.025815622881054878, token_id=None),
  PredictedToken(token=' Rome', prob=0.012779658660292625, token_id=None),
  PredictedToken(token=' Lond', prob=0.009722252376377583, token_id=None)]]

In [19]:
from src.trace import trace_important_states

# Rebuild the query pair with the bare "<subj> is located in the city of"
# template, i.e. without the in-context counterfactual sentence (see printed queries).
clean_query = InContextQuery(
    subject="The Space Needle",
    cf_description="the capital of Japan",
    answer = "Tokyo",
)
clean_query.set_template("<subj> is located in the city of")

corrput_query = InContextQuery(
    subject="Colosseum",
    cf_description="the capital of United Kingdom",
    answer = "London",
)
corrput_query.set_template("<subj> is located in the city of")

print(clean_query.query)
print(corrput_query.query)

The Space Needle is located in the city of
Colosseum is located in the city of


In [24]:
# Causal tracing over residual states between the clean and corrupt queries.
# trace_important_states aligns the subject ends itself by inserting padding
# (see its DEBUG log, "setting subj_end=6").
indirect_effects = trace_important_states(
    mt=mt, 
    clean_query=clean_query,
    corrupt_query=corrput_query,
    kind="residual",
    normalize=True
)

2024-06-18 16:26:59 src.trace DEBUG    clean_subj_range=(1, 4) | corrupt_subj_range=(1, 5)
2024-06-18 16:26:59 src.trace DEBUG    setting subj_end=6
idx=0 =>  [1] <|begin_of_text|> || [1] <|begin_of_text|>
idx=1 =>  [0] <|end_of_text|> || [0] <|end_of_text|>
idx=2 =>  [0] <|end_of_text|> || [1] Col
idx=3 =>  [1] The || [1] os
idx=4 =>  [1]  Space || [1] se
idx=5 =>  [1]  Needle || [1] um
idx=6 =>  [1]  is || [1]  is
idx=7 =>  [1]  located || [1]  located
idx=8 =>  [1]  in || [1]  in
idx=9 =>  [1]  the || [1]  the
idx=10 =>  [1]  city || [1]  city
idx=11 =>  [1]  of || [1]  of
2024-06-18 16:26:59 src.trace DEBUG    <shifted> clean_subj_range=(3, 6) | corrupt_subj_range=(2, 6)


In [21]:
# indirect_effects

In [22]:
# from src.plotting import plot_trace_heatmap

# plot_trace_heatmap(indirect_effects)