In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("../")
import os

import logging
from src.utils import logging_utils
from src.utils import env_utils
from src import functional

# Send every log record (DEBUG and above) to stdout so it shows up inline in
# the notebook, using the project's shared format and date format.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    format=logging_utils.DEFAULT_FORMAT,
)

# Notebook-level logger used by the cells below.
logger = logging.getLogger(__name__)

import torch
import transformers

logger.info(f"{torch.__version__=}, {torch.version.cuda=}")
# torch.cuda.get_device_name() raises when no CUDA device is available, which
# would crash this setup cell on a CPU-only machine; only query device details
# when CUDA is actually usable.
if torch.cuda.is_available():
    logger.info(f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}")
else:
    logger.info(f"{torch.cuda.is_available()=}")
logger.info(f"{transformers.__version__=}")

  from .autonotebook import tqdm as notebook_tqdm


2024-06-17 11:56:26 __main__ INFO     torch.__version__='2.3.1', torch.version.cuda='12.1'
2024-06-17 11:56:26 __main__ INFO     torch.cuda.is_available()=True, torch.cuda.device_count()=1, torch.cuda.get_device_name()='NVIDIA RTX A6000'
2024-06-17 11:56:26 __main__ INFO     transformers.__version__='4.41.2'


In [3]:
from nnsight import LanguageModel
from src.models import ModelandTokenizer

# Load Llama-3-8B in fp16 through the project's model+tokenizer wrapper; `mt`
# is the shared handle every later cell uses.
mt = ModelandTokenizer(
    torch_dtype=torch.float16,
    model_key="meta-llama/Meta-Llama-3-8B",
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


2024-06-17 11:56:39 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.09s/it]

2024-06-17 11:56:44 src.models INFO     loaded model </home/local_arnab/Codes/saved_model_weights/meta-llama/Meta-Llama-3-8B> | size: 15316.516 MB | dtype: torch.float16 | device: cuda:0





In [329]:
from src.functional import predict_next_token, filter_samples_by_model_knowledge

# Factual "landmark -> city" prompts used as the baseline for the patching
# experiments below.
prompts = [
    f"{landmark} is located in the city of"
    for landmark in ("The Space Needle", "The Colosseum", "Statue of Liberty")
]

# Top next-token predictions for each prompt (pass `token_of_interest=[...]`
# to also track specific candidate tokens).
predict_next_token(mt=mt, inputs=prompts)


[[PredictedToken(token=' Seattle', prob=0.969031035900116),
  PredictedToken(token='\xa0', prob=0.004182813223451376),
  PredictedToken(token=' Seat', prob=0.0033088780473917723),
  PredictedToken(token=' se', prob=0.002576956758275628),
  PredictedToken(token=' Se', prob=0.0016638082452118397)],
 [PredictedToken(token=' Rome', prob=0.9319317936897278),
  PredictedToken(token=' Roma', prob=0.004383663181215525),
  PredictedToken(token=' ancient', prob=0.0025370691437274218),
  PredictedToken(token=' Ancient', prob=0.0025173258036375046),
  PredictedToken(token=' the', prob=0.0020227280911058187)],
 [PredictedToken(token=' New', prob=0.8572637438774109),
  PredictedToken(token=' Jersey', prob=0.01684507168829441),
  PredictedToken(token=' Manhattan', prob=0.01533761341124773),
  PredictedToken(token=' Liberty', prob=0.011759755201637745),
  PredictedToken(token=' Paris', prob=0.008671100251376629)]]

In [330]:
template = "Assume an alternative universe where <subj> is in <loc>. In that universe, <subj> is located in the city of"


def format_prompt(subj, loc, prompt_template=None):
    """Fill a counterfactual prompt template.

    Every "<subj>" placeholder is replaced with `subj`, then every "<loc>"
    placeholder with `loc`.

    Args:
        subj: Text substituted for "<subj>".
        loc: Text substituted for "<loc>".
        prompt_template: Template containing the placeholders. When None
            (the default), the module-level `template` is used, looked up at
            call time, which matches the original behavior.

    Returns:
        The filled-in prompt string.
    """
    tmpl = template if prompt_template is None else prompt_template
    return tmpl.replace("<subj>", subj).replace("<loc>", loc)


# (subject, alternative location) pairs for the counterfactual prompts.
counterfactuals = [
    ("The Space Needle", "the capital of Japan"),
    ("Colosseum", "the capital of United Kingdom"),
    ("Statue of Liberty", "the capital of Italy"),
]

counterfactual_prompts = [format_prompt(subj, loc) for subj, loc in counterfactuals]

# Same next-token readout as for the factual prompts, now on the
# alternative-universe prompts.
predict_next_token(mt=mt, inputs=counterfactual_prompts)

[[PredictedToken(token=' Tokyo', prob=0.27394381165504456),
  PredictedToken(token=' Nag', prob=0.0709063857793808),
  PredictedToken(token=' Osaka', prob=0.065067358314991),
  PredictedToken(token=' Kyoto', prob=0.0640585869550705),
  PredictedToken(token=' N', prob=0.03455701470375061)],
 [PredictedToken(token=' London', prob=0.6719240546226501),
  PredictedToken(token=' Westminster', prob=0.03424644470214844),
  PredictedToken(token=' Col', prob=0.02245936542749405),
  PredictedToken(token=' Rome', prob=0.012306718155741692),
  PredictedToken(token=' Lond', prob=0.010123230516910553)],
 [PredictedToken(token=' Rome', prob=0.7661844491958618),
  PredictedToken(token=' Roma', prob=0.027261771261692047),
  PredictedToken(token=' Florence', prob=0.023871207609772682),
  PredictedToken(token=' Milan', prob=0.01705998182296753),
  PredictedToken(token=' Venice', prob=0.012978549115359783)]]

In [355]:
from src.models import prepare_input
from src.functional import get_hs, interpret_logits, logit_lens
from src.functional import get_module_nnsight


def collect_subj_last_state(
    prompt: str,
    subj: str,
    layer_name: str,
    model_and_tokenizer=None,
    verbose: bool = True,
):
    """Collect the hidden state at the last token of a subject in a prompt.

    The prompt is tokenized with offset mappings. The last occurrence of
    `subj` is located from those offsets, and `get_hs` reads the state of
    `layer_name` at the final token of that subject span.

    Args:
        prompt: Full prompt text containing `subj`.
        subj: Subject substring; its last occurrence in `prompt` is used.
        layer_name: Name of the layer to read, e.g. an entry of `mt.layer_names`.
        model_and_tokenizer: ModelandTokenizer to run. Defaults to the
            notebook-global `mt` (looked up at call time) so existing calls
            are unchanged.
        verbose: If True, print the chosen index and the token it decodes to.

    Returns:
        Tuple `(h_idx, h)`, where `h_idx` is the position of the last subject
        token and `h` is what `get_hs` returns for `(layer_name, h_idx)`.
        It is presumably a single hidden-state vector; TODO confirm.
    """
    from src.functional import find_token_range

    model = mt if model_and_tokenizer is None else model_and_tokenizer

    inputs = prepare_input(
        prompts=prompt,
        tokenizer=model,
        return_offsets_mapping=True,
    )

    # Offsets let find_token_range map the subject's character span to tokens.
    subj_start, subj_end = find_token_range(
        string=prompt,
        substring=subj,
        tokenizer=model.tokenizer,
        occurrence=-1,
        offset_mapping=inputs["offset_mapping"][0],
    )

    # subj_end is exclusive, so the last subject token sits at subj_end - 1.
    h_idx = subj_end - 1
    if verbose:
        print(f"{h_idx=} | {model.tokenizer.decode(inputs.input_ids[0][h_idx])}")

    return h_idx, get_hs(
        mt=model,
        input=inputs,
        layer_and_index=(layer_name, h_idx),
    )


# Read states from the layer at index 12 of mt.layer_names.
layer = mt.layer_names[12]

# Last-subject-token states from the *factual* prompts:
#   (idx_0, h0): "The Space Needle" in prompts[0]
#   (idx_1, h1): "Colosseum" in prompts[1]
# Use counterfactual_prompts[i] instead of prompts[i] to collect states from
# the alternative-universe prompts.
(idx_0, h0), (idx_1, h1) = [
    collect_subj_last_state(
        prompt=prompts[i],
        subj=counterfactuals[i][0],
        layer_name=layer,
    )
    for i in (0, 1)
]

h_idx=3 |  Needle
h_idx=5 | um


In [356]:
# Logit-lens view of h1 (the Colosseum subject state), with probabilities.
logit_lens(mt=mt, h=h1, get_proba=True)

[('#ad', 0.211),
 ('\ufeff#', 0.104),
 ('ahat', 0.057),
 ('#ab', 0.021),
 ('#ac', 0.015),
 ('�数', 0.011),
 ('ayd', 0.006),
 ('oví', 0.006),
 ('[js', 0.006),
 (' lep', 0.005)]

In [357]:
from src.functional import untuple

# Activation patching with nnsight: run the Space Needle prompt, but overwrite
# the `layer` output at the Space Needle's last subject token (idx_0) with h1,
# the Colosseum state collected from that same layer.
inputs = prepare_input(
    # alternative: patch into the counterfactual prompt instead
    # prompts=counterfactual_prompts[0],
    prompts=prompts[0],
    tokenizer=mt,
)

# Sanity check: the token at idx_0 should be the last subject token (" Needle").
print(f"{mt.tokenizer.decode(inputs.input_ids[0][idx_0])=}")

# Statement order inside the trace defines the intervention, so keep it as is:
# write h1 first, then read back / save.
with mt.trace(inputs) as trace:
    module = get_module_nnsight(mt, layer)
    # Overwrite batch element 0, position idx_0 of this layer's output.
    # untuple presumably unwraps a tuple-valued module output; TODO confirm.
    untuple(module.output)[0, idx_0, :] = h1
    # Saved read-back of the patched position (after the write above).
    cur_h = untuple(module.output)[0, idx_0, :].save()
    # Next-token logits at the final position of batch element 0.
    logits = mt.output.logits[0][-1].save()

# Top tokens (with probabilities) predicted after the patch.
interpret_logits(
    tokenizer=mt.tokenizer,
    logits=logits,
    get_proba=True
)

mt.tokenizer.decode(inputs.input_ids[0][idx_0])=' Needle'


[(' Rome', 0.928),
 (' Roma', 0.006),
 ('\xa0', 0.002),
 (' London', 0.002),
 (' the', 0.002),
 (' Barcelona', 0.001),
 (' Rom', 0.001),
 (' Milan', 0.001),
 ('\xa0R', 0.001),
 (' San', 0.001)]

In [358]:
import baukit

def intervention(int_layer, h, idx):
    """Build a baukit `edit_output` hook that overwrites one token's state.

    Args:
        int_layer: Name of the layer whose output should be edited.
        h: Replacement value written into batch element 0 at position `idx`.
        idx: Token position to overwrite.

    Returns:
        A callable `(output, layer) -> output`. For `int_layer` it writes `h`
        into `untuple(output)[0, idx, :]` in place. Any other layer's output
        is returned untouched.
    """
    def edit_output(output, layer):
        # Guard so the hook only ever edits the intended layer.
        if layer != int_layer:
            return output
        untuple(output)[0, idx, :] = h
        return output
    return edit_output


# Cross-check of the nnsight patch with baukit: same layer, same position
# (idx_0), same replacement state (h1), same `inputs`.
# no_grad: only the logits are needed, so don't record an autograd graph for
# the 8B forward pass. This saves activation memory and keeps the in-place
# edit off a tracked graph.
with torch.no_grad(), baukit.Trace(
    module=mt._model,
    layer=layer,
    edit_output=intervention(layer, h1, idx_0),
):
    output = mt._model(**inputs)

# Next-token logits at the final position of batch element 0.
logits = output.logits[0][-1]

interpret_logits(
    tokenizer=mt.tokenizer,
    logits=logits,
    get_proba=True
)

[(' Rome', 0.928),
 (' Roma', 0.006),
 ('\xa0', 0.002),
 (' London', 0.002),
 (' the', 0.002),
 (' Barcelona', 0.001),
 (' Rom', 0.001),
 (' Milan', 0.001),
 ('\xa0R', 0.001),
 (' San', 0.001)]