In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported project
# modules (e.g. under `src/`) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os, time, json
import pandas as pd
from openai import OpenAI
from tqdm.auto import tqdm
import spacy

import sys
sys.path.append("../")
import os

import logging
from src.utils import logging_utils
from src.utils import env_utils
from src import functional

# Root logging configuration: DEBUG and above, written to stdout (so records
# appear in the cell output), using the project's shared format strings.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    format=logging_utils.DEFAULT_FORMAT,
)

# Notebook-level logger used by the cells below.
logger = logging.getLogger(__name__)

import torch
import transformers

# Record library versions and the visible GPU so saved outputs are traceable.
logger.info(f"{torch.__version__=}, {torch.version.cuda=}")
# torch.cuda.get_device_name() raises when no CUDA device is usable, which would
# abort this setup cell on a CPU-only machine; only query it when CUDA is available.
cuda_device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else None
logger.info(
    f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, "
    f"torch.cuda.get_device_name()={cuda_device_name!r}"
)
logger.info(f"{transformers.__version__=}")


# Name of the OpenAI model (not referenced by the other cells visible here).
MODEL_NAME = "gpt-4o"

# OpenAI API client; the key is read from the OPENAI_KEY environment variable
# (None is passed through when the variable is unset).
client = OpenAI(api_key=os.environ.get("OPENAI_KEY"))

  from .autonotebook import tqdm as notebook_tqdm


2024-08-07 15:00:47 __main__ INFO     torch.__version__='2.3.1', torch.version.cuda='12.1'
2024-08-07 15:00:47 __main__ INFO     torch.cuda.is_available()=True, torch.cuda.device_count()=1, torch.cuda.get_device_name()='NVIDIA RTX A6000'
2024-08-07 15:00:47 __main__ INFO     transformers.__version__='4.43.3'
2024-08-07 15:00:47 httpx DEBUG    load_ssl_context verify=True cert=None trust_env=True http2=False
2024-08-07 15:00:47 httpx DEBUG    load_verify_locations cafile='/home/local_arnab/miniconda3/envs/retrieval/lib/python3.11/site-packages/certifi/cacert.pem'


In [3]:
import torch

from nnsight import LanguageModel
from src.models import ModelandTokenizer

# Checkpoint to analyse. Alternatives that were tried (swap in to switch):
#   "meta-llama/Meta-Llama-3-8B"
#   "google/gemma-2-9b-it"
#   "google/gemma-2-27b-it"
#   "Qwen/Qwen2-7B"
model_key = "meta-llama/Meta-Llama-3-8B-Instruct"

# Load the model and its tokenizer in half precision.
mt = ModelandTokenizer(model_key=model_key, torch_dtype=torch.float16)

2024-08-07 15:00:49 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.19s/it]

2024-08-07 15:00:54 src.models INFO     loaded model </home/local_arnab/Codes/00_MODEL/meta-llama/Meta-Llama-3-8B-Instruct> | size: 15316.516 MB | dtype: torch.float16 | device: cuda:0





In [4]:
# Sanity check: 10 evenly spaced interpolation coefficients covering [0, 1].
torch.linspace(start=0, end=1, steps=10)

tensor([0.0000, 0.1111, 0.2222, 0.3333, 0.4444, 0.5556, 0.6667, 0.7778, 0.8889,
        1.0000])

In [215]:
from src.utils.typing import TokenizerOutput
from src.functional import get_module_nnsight, untuple
from typing import Literal
from dataclasses import dataclass


#! the "clean" here actually stands for **corrupt** in the causal tracing
def attribution_patching(
    mt: ModelandTokenizer,
    clean_inputs: TokenizerOutput,
    h_clean: torch.Tensor,
    h_patch: torch.Tensor,
    patch_location: tuple[str, int],
    interested_locations: list[tuple[str, int]],
    ans_token_idx: int,
    metric: Literal["logit", "proba"] = "proba",
    resolution: int = 10,
) -> dict[tuple[str, int], torch.Tensor]:
    """Approximate indirect effects of a patch with path-averaged gradients.

    The state written at ``patch_location`` is interpolated as
    ``alpha * h_clean + (1 - alpha) * h_patch`` for ``resolution`` evenly spaced
    values of ``alpha`` in [0, 1]. For every step the model is run on
    ``clean_inputs`` with that state patched in, and the gradient of the
    answer-token metric with respect to the output of each module in
    ``interested_locations`` (at the given token index) is collected. The
    gradients are averaged over the steps and dotted with ``h_clean - h_patch``.

    NOTE(review): the averaged gradient of every interested location is dotted
    with the difference of the *patched* states, not with a difference of states
    at that location -- confirm this is the intended estimator.

    Args:
        mt: model/tokenizer wrapper; must expose ``trace``, ``device``,
            ``output`` and ``_model``.
        clean_inputs: tokenized prompt to run. An ``"offset_mapping"`` entry,
            if present, is removed from this object in place.
        h_clean: state used at ``alpha = 1``. Must be 1-D (``torch.dot`` is
            applied to it) and match the hidden size at the interested locations.
        h_patch: state used at ``alpha = 0``; same shape as ``h_clean``.
        patch_location: ``(module_name, token_index)`` whose output row is
            overwritten. The module output is indexed as ``[0, token_index, :]``,
            i.e. it is assumed to be a plain tensor, not a tuple.
        interested_locations: ``(module_name, token_index)`` pairs to attribute.
        ans_token_idx: vocabulary index of the answer token.
        metric: ``"logit"`` differentiates the raw last-position logit of the
            answer token, ``"proba"`` its softmax probability.
        resolution: number of interpolation steps (>= 1).

    Returns:
        Dict mapping each interested location to a 0-dim tensor holding the
        approximated effect.

    Raises:
        ValueError: on an unknown ``metric`` or ``resolution < 1``.
    """
    # Fail fast instead of discovering a bad argument inside the first trace
    # (metric) or silently returning an empty dict (resolution).
    if metric not in ("logit", "proba"):
        raise ValueError(f"unknown {metric=}")
    if resolution < 1:
        raise ValueError(f"resolution must be >= 1, got {resolution=}")

    h_clean = h_clean.to(mt.device)
    h_patch = h_patch.to(mt.device)

    # The offset mapping is tokenizer bookkeeping and is dropped before tracing.
    # NOTE: this mutates the caller's `clean_inputs`.
    if "offset_mapping" in clean_inputs:
        clean_inputs.pop("offset_mapping")

    patch_module_name, patch_tok_idx = patch_location
    log_debug = logger.isEnabledFor(logging.DEBUG)

    grads = {}

    # `scan` is only enabled for the first trace; it doubles as the
    # "first iteration" flag used to allocate the gradient accumulators.
    scan = True

    for alpha in torch.linspace(0, 1, resolution):
        cur_grads = {}

        with mt.trace(clean_inputs, scan=scan):
            # patching: overwrite the state at the patch location with the
            # interpolated state for this step
            h = alpha * h_clean + (1 - alpha) * h_patch
            patch_module = get_module_nnsight(mt, patch_module_name)
            patch_module.output[0, patch_tok_idx, :] = h

            # request the gradients at the interested locations
            for loc in interested_locations:
                module_name, tok_idx = loc
                module = get_module_nnsight(mt, module_name)
                # Modules with "mlp" in their name are indexed directly; for
                # all others the first element of the output is used
                # (presumably a tuple whose first entry is the hidden state).
                is_mlp = "mlp" in module_name
                cur_output = (
                    module.output[0, tok_idx, :].save()
                    if is_mlp
                    else module.output[0][0, tok_idx, :].save()
                )
                # cur_grads[loc] = cur_output.grad.save() #! this doesn't work, weird nnsight issue
                cur_grads[loc] = (
                    module.output.grad[0, tok_idx, :].save()
                    if is_mlp
                    else module.output[0].grad[0, tok_idx, :].save()
                )

                # initialize the accumulators on the first step
                if scan:
                    grads[loc] = torch.zeros_like(cur_output).to(mt.device).save()

            #! another nnsight quirk => backward() has to be called later than grad.save() to populate the proxies
            if metric == "logit":
                v = mt.output.logits[0][-1][ans_token_idx]
            else:  # "proba" (validated above)
                v = mt.output.logits[0][-1].softmax(dim=-1)[ans_token_idx]
            v.backward()

        for loc in interested_locations:
            if log_debug:
                logger.debug(
                    f"{loc=} | {patch_location=} | "
                    f"{grads[loc].shape=} | {grads[loc].norm()=} | "
                    f"{cur_grads[loc].shape=} | {cur_grads[loc].norm()=}"
                )
            grads[loc] += cur_grads[loc]

        # do not let parameter gradients pile up across interpolation steps
        mt._model.zero_grad()
        scan = False

    # average over the interpolation steps, then project onto the patch direction
    delta = h_clean - h_patch
    approx_IE = {
        loc: torch.dot(grad / resolution, delta) for loc, grad in grads.items()
    }

    return approx_IE

In [9]:
from src.functional import prepare_input, guess_subject
from src.functional import find_token_range, get_hs
from typing import Optional

def get_h_at_subj(
    mt: ModelandTokenizer,
    layer: str | list[str],
    prompt: str | TokenizerOutput,
    subj: Optional[str] = None,
    input: Optional[TokenizerOutput] = None,
) -> torch.Tensor:
    """Fetch hidden states at the last token of the subject in ``prompt``.

    Args:
        mt: model/tokenizer wrapper (its ``tokenizer`` attribute is used).
        layer: a module name or a list of module names to read from. A single
            string is treated as a one-element list.
        prompt: the prompt text. It is used as a string here (substring check
            and token-range lookup).
        subj: subject substring of ``prompt``. When omitted it is guessed with
            ``guess_subject`` and a warning is logged.
        input: optional pre-tokenized prompt. It is only reused when it carries
            an ``"offset_mapping"`` entry; otherwise the prompt is tokenized
            again. The ``"offset_mapping"`` entry is popped from it in place.

    Returns:
        Whatever ``get_hs`` returns for the requested ``(layer, token)`` pairs.

    Raises:
        AssertionError: if an explicitly given ``subj`` is not in ``prompt``.
    """
    if subj is None:
        subj = guess_subject(prompt)
        logger.warning(f"no subj provided, guessed {subj=}")
    else:
        assert subj in prompt, f"{subj=} not in {prompt=}"

    # A bare string would otherwise be iterated character by character when the
    # (layer, index) pairs are built below.
    layers = [layer] if isinstance(layer, str) else list(layer)

    skip_prepare_input = input is not None and "offset_mapping" in input
    if not skip_prepare_input:
        logger.debug(f"preparing input for prompt: {prompt}")
        input = prepare_input(
            prompts=prompt,
            tokenizer=mt,
            return_offsets_mapping=True,
        )
    # The offset mapping is only needed to locate the subject; it is removed so
    # it is not forwarded to get_hs.
    offset_mapping = input.pop("offset_mapping")[0]
    subj_range = find_token_range(
        string=prompt,
        substring=subj,
        tokenizer=mt.tokenizer,
        offset_mapping=offset_mapping,
    )
    # the range end is treated as exclusive, so step back to the last subject token
    subj_ends = subj_range[1] - 1

    logger.debug(f"{subj=} => {subj_ends=} | \"{mt.tokenizer.decode(input['input_ids'][0][subj_ends])}\"")

    return get_hs(
        mt=mt,
        input=input,
        layer_and_index=[(l, subj_ends) for l in layers],
    )


# Prompt template and the two subjects being contrasted.
prompt = "{} is located in the city of"
clean_subj = "Louvre"
patch_subj = "The Space Needle"

# Embedder output at the last subject token, first for the clean subject and
# then for the patch subject.
clean_hs, patch_hs = (
    get_h_at_subj(
        mt=mt,
        layer=[mt.embedder_name],
        prompt=prompt.format(subj),
        subj=subj,
    )
    for subj in (clean_subj, patch_subj)
)

2024-08-07 15:02:31 __main__ DEBUG    preparing input for prompt: Louvre is located in the city of
2024-08-07 15:02:31 __main__ DEBUG    subj='Louvre' => subj_ends=2 | "vre"
2024-08-07 15:02:32 __main__ DEBUG    preparing input for prompt: The Space Needle is located in the city of
2024-08-07 15:02:32 __main__ DEBUG    subj='The Space Needle' => subj_ends=3 | " Needle"


In [173]:
from src.functional import predict_next_token

# Formatted prompts for the two subjects.
patch_prompt = prompt.format(patch_subj)
clean_prompt = prompt.format(clean_subj)

# First entry of the model's next-token prediction for the patch-subject prompt;
# its token id is the target that attribution is measured against.
ans = predict_next_token(mt=mt, inputs=patch_prompt)[0][0]

logger.debug(ans)

# Tokenize the clean-subject prompt and locate the subject's token span in it.
clean_inputs = prepare_input(
    prompts=clean_prompt,
    tokenizer=mt,
    return_offsets_mapping=True,
)
subj_range = find_token_range(
    string=clean_prompt,
    substring=clean_subj,
    tokenizer=mt.tokenizer,
    offset_mapping=clean_inputs["offset_mapping"][0],
)

2024-08-07 16:21:48 __main__ DEBUG    " Seattle" (p=0.987)


In [216]:
# Index of the last subject token: patched at the embedder, read at every layer.
subj_last_idx = subj_range[1] - 1
layer_locations = [
    (mt.layer_name_format.format(layer_idx), subj_last_idx)
    for layer_idx in range(mt.n_layer)
]

results = attribution_patching(
    mt=mt,
    clean_inputs=clean_inputs,
    h_clean=clean_hs,
    h_patch=patch_hs,
    patch_location=(mt.embedder_name, subj_last_idx),
    interested_locations=layer_locations,
    ans_token_idx=ans.token_id,
)

loc=('model.layers.0', 2) | patch_location=('model.embed_tokens', 2)
grads[loc].shape=torch.Size([4096]) | grads[loc].norm()=tensor(0., device='cuda:0', dtype=torch.float16)
cur_grads[loc].shape=torch.Size([4096]) | cur_grads[loc].norm()=tensor(0.1826, device='cuda:0', dtype=torch.float16)
loc=('model.layers.1', 2) | patch_location=('model.embed_tokens', 2)
grads[loc].shape=torch.Size([4096]) | grads[loc].norm()=tensor(0., device='cuda:0', dtype=torch.float16)
cur_grads[loc].shape=torch.Size([4096]) | cur_grads[loc].norm()=tensor(0.1242, device='cuda:0', dtype=torch.float16)
loc=('model.layers.2', 2) | patch_location=('model.embed_tokens', 2)
grads[loc].shape=torch.Size([4096]) | grads[loc].norm()=tensor(0., device='cuda:0', dtype=torch.float16)
cur_grads[loc].shape=torch.Size([4096]) | cur_grads[loc].norm()=tensor(0.0794, device='cuda:0', dtype=torch.float16)
loc=('model.layers.3', 2) | patch_location=('model.embed_tokens', 2)
grads[loc].shape=torch.Size([4096]) | grads[loc].norm()=te

In [217]:
# Display the approximated effect for each (module name, token index) location.
results

{('model.layers.0', 2): tensor(-0.0004, device='cuda:0', dtype=torch.float16),
 ('model.layers.1', 2): tensor(-0.0011, device='cuda:0', dtype=torch.float16),
 ('model.layers.2', 2): tensor(-0.0006, device='cuda:0', dtype=torch.float16),
 ('model.layers.3', 2): tensor(-0.0003, device='cuda:0', dtype=torch.float16),
 ('model.layers.4',
  2): tensor(8.8394e-05, device='cuda:0', dtype=torch.float16),
 ('model.layers.5',
  2): tensor(9.5069e-05, device='cuda:0', dtype=torch.float16),
 ('model.layers.6', 2): tensor(0.0001, device='cuda:0', dtype=torch.float16),
 ('model.layers.7', 2): tensor(0.0001, device='cuda:0', dtype=torch.float16),
 ('model.layers.8',
  2): tensor(6.0201e-05, device='cuda:0', dtype=torch.float16),
 ('model.layers.9',
  2): tensor(4.7088e-05, device='cuda:0', dtype=torch.float16),
 ('model.layers.10',
  2): tensor(1.3590e-05, device='cuda:0', dtype=torch.float16),
 ('model.layers.11',
  2): tensor(2.0504e-05, device='cuda:0', dtype=torch.float16),
 ('model.layers.12',
 

In [195]:
from src.functional import untuple

# Scratch cell: run one traced forward/backward pass on the clean-subject prompt
# and save gradients at one chosen module plus every decoder layer.
# No offset mapping is requested here, so the inputs go to the trace as they are.
clean_inputs = prepare_input(
    prompts=prompt.format(clean_subj), 
    tokenizer=mt, 
    return_offsets_mapping=False
)


# One placeholder per decoder layer; filled with saved gradient proxies below.
cur_grads = {l: None for l in mt.layer_names}

# Module to inspect; the commented lines are alternatives to swap in.
# NOTE(review): the `[0, 5, :]` indexing below treats the module output as a plain
# tensor; a module whose output is a tuple would presumably need `output[0]` first.
# module_name = mt.layer_name_format.format(10)
# module_name = mt.embedder_name
module_name = mt.mlp_module_name_format.format(10)
# module_name = mt.attn_module_name_format.format(10)
with mt.trace(clean_inputs, scan = True) as trace:
    module = get_module_nnsight(mt, module_name)
    # Output and gradient of the chosen module at token index 5 (hard-coded).
    h = module.output[0, 5, :].save()
    h_grad = module.output.grad[0, 5, :].save()

    # Gradient w.r.t. the first element of each decoder layer's output.
    # Note that this loop rebinds `module` to the last layer's module.
    for l in cur_grads:
        module = get_module_nnsight(mt, l)
        cur_grads[l] = module.output[0].grad.save()
    
    # Metric: probability of the answer token at the last position. backward() is
    # called only after every `.grad.save()` request above has been registered.
    m = mt.output.logits[0][-1].softmax(dim=-1)[ans.token_id]
    m.backward()

In [196]:
# Print the shape of the saved gradient for every decoder layer.
# The loop variables are deliberately not called `module` / `grad`: in a
# notebook they are globals, and `module` would otherwise be rebound from the
# nnsight module handle of the previous cell to a plain string.
for layer_name, layer_grad in cur_grads.items():
    print(layer_name, layer_grad.shape)

model.layers.0 torch.Size([1, 9, 4096])
model.layers.1 torch.Size([1, 9, 4096])
model.layers.2 torch.Size([1, 9, 4096])
model.layers.3 torch.Size([1, 9, 4096])
model.layers.4 torch.Size([1, 9, 4096])
model.layers.5 torch.Size([1, 9, 4096])
model.layers.6 torch.Size([1, 9, 4096])
model.layers.7 torch.Size([1, 9, 4096])
model.layers.8 torch.Size([1, 9, 4096])
model.layers.9 torch.Size([1, 9, 4096])
model.layers.10 torch.Size([1, 9, 4096])
model.layers.11 torch.Size([1, 9, 4096])
model.layers.12 torch.Size([1, 9, 4096])
model.layers.13 torch.Size([1, 9, 4096])
model.layers.14 torch.Size([1, 9, 4096])
model.layers.15 torch.Size([1, 9, 4096])
model.layers.16 torch.Size([1, 9, 4096])
model.layers.17 torch.Size([1, 9, 4096])
model.layers.18 torch.Size([1, 9, 4096])
model.layers.19 torch.Size([1, 9, 4096])
model.layers.20 torch.Size([1, 9, 4096])
model.layers.21 torch.Size([1, 9, 4096])
model.layers.22 torch.Size([1, 9, 4096])
model.layers.23 torch.Size([1, 9, 4096])
model.layers.24 torch.Size

In [65]:
# Shape of `h` after `untuple` (presumably unwraps tuple-valued module outputs).
# NOTE(review): the recorded 3-D shape has a lower execution count than the cell
# that defines `h` as a `[0, 5, :]` slice, so it looks stale -- re-run to confirm.
untuple(h).shape

torch.Size([1, 9, 4096])

In [66]:
# Type of the last element of `h`.
# NOTE(review): the recorded DynamicCache result implies `h` was a tuple-valued
# layer output when this ran; with `h` saved as a tensor slice it would differ.
type(h[-1])

transformers.cache_utils.DynamicCache

In [168]:
# List the module names of the model's decoder layers.
mt.layer_names

['model.layers.0',
 'model.layers.1',
 'model.layers.2',
 'model.layers.3',
 'model.layers.4',
 'model.layers.5',
 'model.layers.6',
 'model.layers.7',
 'model.layers.8',
 'model.layers.9',
 'model.layers.10',
 'model.layers.11',
 'model.layers.12',
 'model.layers.13',
 'model.layers.14',
 'model.layers.15',
 'model.layers.16',
 'model.layers.17',
 'model.layers.18',
 'model.layers.19',
 'model.layers.20',
 'model.layers.21',
 'model.layers.22',
 'model.layers.23',
 'model.layers.24',
 'model.layers.25',
 'model.layers.26',
 'model.layers.27',
 'model.layers.28',
 'model.layers.29',
 'model.layers.30',
 'model.layers.31']