In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os, time, json
import pandas as pd
from openai import OpenAI
from tqdm.auto import tqdm
import spacy

import sys
sys.path.append("../")
import os

import logging
from src.utils import logging_utils
from src.utils import env_utils
from src import functional

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

import torch
import transformers

logger.info(f"{torch.__version__=}, {torch.version.cuda=}")
logger.info(f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}")
logger.info(f"{transformers.__version__=}")


client = OpenAI(
    api_key=os.getenv("OPENAI_KEY"),
)

MODEL_NAME = "gpt-4o"

  from .autonotebook import tqdm as notebook_tqdm


2024-08-08 13:08:07 __main__ INFO     torch.__version__='2.3.1', torch.version.cuda='12.1'
2024-08-08 13:08:07 __main__ INFO     torch.cuda.is_available()=True, torch.cuda.device_count()=1, torch.cuda.get_device_name()='NVIDIA RTX A6000'
2024-08-08 13:08:07 __main__ INFO     transformers.__version__='4.43.3'
2024-08-08 13:08:07 httpx DEBUG    load_ssl_context verify=True cert=None trust_env=True http2=False
2024-08-08 13:08:07 httpx DEBUG    load_verify_locations cafile='/home/local_arnab/miniconda3/envs/retrieval/lib/python3.11/site-packages/certifi/cacert.pem'


In [3]:
import torch

from nnsight import LanguageModel
from src.models import ModelandTokenizer

# model_key = "meta-llama/Meta-Llama-3-8B"
model_key = "meta-llama/Meta-Llama-3-8B-Instruct"
# model_key = "google/gemma-2-9b-it"
# model_key = "google/gemma-2-27b-it"
# model_key = "Qwen/Qwen2-7B"

mt = ModelandTokenizer(
    model_key=model_key,
    torch_dtype=torch.float16,
)

2024-08-08 13:08:08 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.26s/it]

2024-08-08 13:08:13 src.models INFO     loaded model </home/local_arnab/Codes/00_MODEL/meta-llama/Meta-Llama-3-8B-Instruct> | size: 15316.516 MB | dtype: torch.float16 | device: cuda:0





In [14]:
from src.utils.typing import TokenizerOutput
from src.functional import get_module_nnsight, untuple, get_hs, PatchSpec
from typing import Literal
from dataclasses import dataclass


#! the clean here actually stands for **corrupt* in the causal tracing
def attribution_patching(
    mt: ModelandTokenizer,
    clean_inputs: TokenizerOutput,
    patches: PatchSpec | list[PatchSpec],
    interested_locations: list[tuple[str, int]],
    ans_token_idx: int,
    metric: Literal["logit", "proba"] = "proba",
    resolution: int = 10,
) -> float:

    if "offset_mapping" in clean_inputs:
        clean_inputs.pop("offset_mapping")
    if isinstance(patches, PatchSpec):
        patches = [patches]

    clean_states = get_hs(
        mt = mt, 
        input = clean_inputs,
        locations = interested_locations,
    )
    patched_states = get_hs(
        mt = mt, 
        input = clean_inputs,
        locations = interested_locations,
        patches = patches
    )

    grads = {}
    scan = True
    for alpha in torch.linspace(0, 1, resolution):
        cur_grads = {}

        with mt.trace(clean_inputs, scan=scan) as trace:
            # patching
            for patch in patches:
                module_name, tok_idx = patch.location
                patch_module = get_module_nnsight(mt, module_name)

                assert isinstance(patch.clean, torch.Tensor) and patch.clean.shape == patch.patch.shape
                h = alpha * patch.clean + (1 - alpha) * patch.patch
                h = h.to(mt.device) if not h.device == mt.device else h
                h.retain_grad = True

                patch_module.output[0, tok_idx, :] = h

            # cache the interested hidden states
            for loc in interested_locations:
                module_name, tok_idx = loc
                module = get_module_nnsight(mt, module_name)
                cur_output = (
                    module.output.save()
                    if "mlp" in module_name
                    else module.output[0].save()
                )   #! nnsight quirk => to get the grad of a reference tensor, you can't index it
                cur_grads[loc] = cur_output.grad[0, tok_idx, :].save()
                # cur_grads[loc] = (
                #     module.output.grad[0, tok_idx, :].save()
                #     if "mlp" in module_name
                #     else module.output[0].grad[0, tok_idx, :].save()
                # )

                # initialize the grads
                if scan:
                    grads[loc] = torch.zeros_like(cur_output.grad[0, tok_idx, :]).to(mt.device).save()

            #! nnsight quirk => backward() has to be called later than grad.save() to populate the proxies
            if metric == "logit":
                v = mt.output.logits[0][-1][ans_token_idx]
            elif metric == "proba":
                v = mt.output.logits[0][-1].softmax(dim=-1)[ans_token_idx]
            else:
                raise ValueError(f"unknown {metric=}")
            v.backward()


        for loc in interested_locations:
            print(f"{loc=}")
            print(f"{grads[loc].shape=} | {grads[loc].norm()=}")
            print(f"{cur_grads[loc].shape=} | {cur_grads[loc].norm()=}")
            module_name, tok_idx = loc
            grads[loc] += cur_grads[loc]

        mt._model.zero_grad()
        scan = False

    grads = {loc: grad / resolution for loc, grad in grads.items()}

    approx_IE = {
        loc: torch.dot(grad, patched_states[loc] - clean_states[loc]).item() 
        for loc, grad in grads.items()
    } 

    return approx_IE

In [15]:
from src.functional import prepare_input, guess_subject
from src.functional import find_token_range, get_hs
from typing import Optional

def get_h_at_subj(
    mt: ModelandTokenizer,
    layer: str | list[str],
    prompt: str | TokenizerOutput,
    subj: Optional[str] = None,
    input: Optional[TokenizerOutput] = None,
) -> torch.Tensor:
    if subj is None:
        subj = guess_subject(prompt)
        logger.warning(f"no subj provided, guessed {subj=}")
    else:
        assert subj in prompt, f"{subj=} not in {prompt=}"

    skip_prepare_input = input is not None and "offset_mapping" in input
    if not skip_prepare_input:
        logger.debug(f"preparing input for prompt: {prompt}")
        input = prepare_input(
            prompts=prompt, 
            tokenizer=mt, 
            return_offsets_mapping=True
        )
    offset_mapping = input.pop("offset_mapping")[0]
    subj_range = find_token_range(string=prompt, substring=subj, tokenizer=mt.tokenizer, offset_mapping=offset_mapping)
    subj_ends = subj_range[1] - 1

    logger.debug(f"{subj=} => {subj_ends=} | \"{mt.tokenizer.decode(input['input_ids'][0][subj_ends])}\"")

    return get_hs(
        mt = mt, input = input, 
        locations = [(l, subj_ends) for l in layer]
    )


# prompt =  "{} is located in the city of"
# clean_subj = "Louvre"
# patch_subj = "The Space Needle"

# clean_hs, patch_hs = [
#     get_h_at_subj(
#         mt = mt, 
#         layer = [mt.embedder_name],
#         prompt = prompt.format(subj),
#         subj = subj,
#     ) for subj in [clean_subj, patch_subj]
# ]

In [17]:
from src.functional import predict_next_token
from src.trace import insert_padding_before_subj

prompt =  "{} is located in the city of"
clean_subj = "Louvre"
patch_subj = "The Space Needle"

ans = predict_next_token(
    mt = mt,
    inputs = prompt.format(patch_subj),
)[0][0]

logger.debug(ans)

clean_inputs = prepare_input(
    prompts=prompt.format(clean_subj), 
    tokenizer=mt, 
    return_offsets_mapping=True
)
clean_subj_range = find_token_range(
    string=prompt.format(clean_subj), 
    substring=clean_subj, 
    tokenizer=mt.tokenizer,
    offset_mapping=clean_inputs["offset_mapping"][0]
)
logger.debug(f"{clean_subj_range=} | {mt.tokenizer.decode(clean_inputs['input_ids'][0][clean_subj_range[1]-1])}")

patched_inputs = prepare_input(
    prompts=prompt.format(patch_subj), 
    tokenizer=mt, 
    return_offsets_mapping=True
)
patched_subj_range = find_token_range(
    string=prompt.format(patch_subj), 
    substring=patch_subj, 
    tokenizer=mt.tokenizer,
    offset_mapping=patched_inputs["offset_mapping"][0]
)
logger.debug(f"{patched_subj_range=} | {mt.tokenizer.decode(patched_inputs['input_ids'][0][patched_subj_range[1]-1])}")

subj_end = max(clean_subj_range[1], patched_subj_range[1])

clean_inputs = insert_padding_before_subj(
    inp = clean_inputs,
    subj_range = clean_subj_range,
    subj_ends = subj_end,
    pad_id = mt.tokenizer.bos_token_id,
)
patched_inputs = insert_padding_before_subj(
    inp = patched_inputs,
    subj_range = patched_subj_range,
    subj_ends = subj_end,
    pad_id = mt.tokenizer.bos_token_id,
)

clean_subj_shift = subj_end - clean_subj_range[1]
clean_subj_range = (clean_subj_range[0] + clean_subj_shift, subj_end)
patched_subj_shift = subj_end - patched_subj_range[1]
patched_subj_range = (patched_subj_range[0] + patched_subj_shift, subj_end)

subj_start = min(clean_subj_range[0], patched_subj_range[0])

for idx, (tok_id, attn_mask) in enumerate(zip(clean_inputs.input_ids[0], clean_inputs.attention_mask[0])):
    is_subj = clean_subj_range[0] <= idx < clean_subj_range[1]
    append = "*" if is_subj else ""
    print(f"{idx=} [{attn_mask}] | {mt.tokenizer.decode(tok_id)}"+append)

print("-"*50)

for idx, (tok_id, attn_mask) in enumerate(zip(patched_inputs.input_ids[0], patched_inputs.attention_mask[0])):
    is_subj = patched_subj_range[0] <= idx < patched_subj_range[1]
    append = "*" if is_subj else ""
    print(f"{idx=} [{attn_mask}] | {mt.tokenizer.decode(tok_id)}"+append)

2024-08-08 13:26:47 __main__ DEBUG    " Seattle" (p=0.987)
2024-08-08 13:26:47 __main__ DEBUG    clean_subj_range=(1, 3) | vre
2024-08-08 13:26:47 __main__ DEBUG    patched_subj_range=(1, 4) |  Needle
idx=0 [1] | <|begin_of_text|>
idx=1 [0] | <|begin_of_text|>
idx=2 [1] | Lou*
idx=3 [1] | vre*
idx=4 [1] |  is
idx=5 [1] |  located
idx=6 [1] |  in
idx=7 [1] |  the
idx=8 [1] |  city
idx=9 [1] |  of
--------------------------------------------------
idx=0 [1] | <|begin_of_text|>
idx=1 [1] | The*
idx=2 [1] |  Space*
idx=3 [1] |  Needle*
idx=4 [1] |  is
idx=5 [1] |  located
idx=6 [1] |  in
idx=7 [1] |  the
idx=8 [1] |  city
idx=9 [1] |  of


In [20]:
emb_clean = get_hs(
    mt = mt, 
    input = clean_inputs,
    locations = [(mt.embedder_name, tok_idx) for tok_idx in range(subj_start, subj_end)]
)
emb_patch = get_hs(
    mt = mt, 
    input = patched_inputs,
    locations = [(mt.embedder_name, tok_idx) for tok_idx in range(subj_start, subj_end)]
)

for key in emb_clean.keys():
    print(f"{key=} | {emb_clean[key].shape=} | {emb_patch[key].shape=}")

key=('model.embed_tokens', 1) | emb_clean[key].shape=torch.Size([4096]) | emb_patch[key].shape=torch.Size([4096])
key=('model.embed_tokens', 2) | emb_clean[key].shape=torch.Size([4096]) | emb_patch[key].shape=torch.Size([4096])
key=('model.embed_tokens', 3) | emb_clean[key].shape=torch.Size([4096]) | emb_patch[key].shape=torch.Size([4096])


In [27]:
patch_spec = [
    PatchSpec(
        location = location,
        patch = emb_patch[location],
        clean = emb_clean[location],
    ) for location in emb_clean.keys()
]

test_h = get_hs(
    mt=mt,
    input=clean_inputs,
    locations=[(mt.embedder_name, subj_end - 1)],
    # patches=PatchSpec(
    #     location=(mt.embedder_name, subj_end - 1),
    #     patch=emb_patch[(mt.embedder_name, subj_end - 1)],  
    # ),
    patches=patch_spec,
)

torch.allclose(emb_patch[(mt.embedder_name, subj_end - 1)], test_h)

True

In [28]:
results = attribution_patching(
    mt = mt,
    clean_inputs = clean_inputs,
    patches = patch_spec,
    interested_locations=[
        (mt.layer_name_format.format(l), subj_end - 1) 
        for l in range(mt.n_layer)
    ],
    ans_token_idx=ans.token_id,
)

loc=('model.layers.0', 3)
grads[loc].shape=torch.Size([4096]) | grads[loc].norm()=tensor(0., device='cuda:0', dtype=torch.float16)
cur_grads[loc].shape=torch.Size([4096]) | cur_grads[loc].norm()=tensor(0.5522, device='cuda:0', dtype=torch.float16)
loc=('model.layers.1', 3)
grads[loc].shape=torch.Size([4096]) | grads[loc].norm()=tensor(0., device='cuda:0', dtype=torch.float16)
cur_grads[loc].shape=torch.Size([4096]) | cur_grads[loc].norm()=tensor(0.3730, device='cuda:0', dtype=torch.float16)
loc=('model.layers.2', 3)
grads[loc].shape=torch.Size([4096]) | grads[loc].norm()=tensor(0., device='cuda:0', dtype=torch.float16)
cur_grads[loc].shape=torch.Size([4096]) | cur_grads[loc].norm()=tensor(0.2252, device='cuda:0', dtype=torch.float16)
loc=('model.layers.3', 3)
grads[loc].shape=torch.Size([4096]) | grads[loc].norm()=tensor(0., device='cuda:0', dtype=torch.float16)
cur_grads[loc].shape=torch.Size([4096]) | cur_grads[loc].norm()=tensor(0.1580, device='cuda:0', dtype=torch.float16)
loc=('mo

In [10]:
results

{('model.layers.0', 3): 0.0035152435302734375,
 ('model.layers.1', 3): 0.003936767578125,
 ('model.layers.2', 3): 0.00331878662109375,
 ('model.layers.3', 3): 0.00189208984375,
 ('model.layers.4', 3): 0.00103759765625,
 ('model.layers.5', 3): 0.0005850791931152344,
 ('model.layers.6', 3): 0.0010442733764648438,
 ('model.layers.7', 3): 0.0008549690246582031,
 ('model.layers.8', 3): 0.0004711151123046875,
 ('model.layers.9', 3): -5.4836273193359375e-05,
 ('model.layers.10', 3): 0.00035953521728515625,
 ('model.layers.11', 3): 0.00047898292541503906,
 ('model.layers.12', 3): 0.00025582313537597656,
 ('model.layers.13', 3): 0.00022268295288085938,
 ('model.layers.14', 3): 0.0002803802490234375,
 ('model.layers.15', 3): 0.0002503395080566406,
 ('model.layers.16', 3): 0.0004260540008544922,
 ('model.layers.17', 3): 0.0003256797790527344,
 ('model.layers.18', 3): 0.0003790855407714844,
 ('model.layers.19', 3): 0.0003502368927001953,
 ('model.layers.20', 3): 0.0005354881286621094,
 ('model.lay

In [34]:
from src.functional import untuple

clean_inputs = prepare_input(
    prompts=prompt.format(clean_subj), 
    tokenizer=mt, 
    return_offsets_mapping=False
)


cur_grads = {l: None for l in mt.layer_names}

# module_name = mt.layer_name_format.format(10)
# module_name = mt.embedder_name
module_name = mt.mlp_module_name_format.format(10)
# module_name = mt.attn_module_name_format.format(10)
with mt.trace(clean_inputs, scan = True) as trace:
    module = get_module_nnsight(mt, module_name)
    h = module.output.save()
    # h_grad = module.output.grad[0, 5, :].save()
    h_grad = h.grad[0, 5, :].save()

    m = mt.output.logits[0][-1].softmax(dim=-1)[ans.token_id]
    m.backward()
    



    # for l in cur_grads:
    #     module = get_module_nnsight(mt, l)
    #     cur_grads[l] = module.output[0].grad.save()

In [35]:
# for module, grad in cur_grads.items():
#     print(module, grad.shape)
h_grad

tensor([-7.1526e-06, -8.4639e-06, -1.0073e-05,  ...,  3.2187e-06,
        -5.4240e-06, -1.4246e-05], device='cuda:0', dtype=torch.float16)

In [12]:
untuple(h).shape

torch.Size([])

In [13]:
type(h[-1])

torch.Tensor

In [14]:
mt.layer_names

['model.layers.0',
 'model.layers.1',
 'model.layers.2',
 'model.layers.3',
 'model.layers.4',
 'model.layers.5',
 'model.layers.6',
 'model.layers.7',
 'model.layers.8',
 'model.layers.9',
 'model.layers.10',
 'model.layers.11',
 'model.layers.12',
 'model.layers.13',
 'model.layers.14',
 'model.layers.15',
 'model.layers.16',
 'model.layers.17',
 'model.layers.18',
 'model.layers.19',
 'model.layers.20',
 'model.layers.21',
 'model.layers.22',
 'model.layers.23',
 'model.layers.24',
 'model.layers.25',
 'model.layers.26',
 'model.layers.27',
 'model.layers.28',
 'model.layers.29',
 'model.layers.30',
 'model.layers.31']