In [1]:
# Reload all (non-excluded) modules before executing each cell, so edits to the local
# `src` package take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from tqdm.auto import tqdm

import sys
sys.path.append("../")
import os

import logging
from src.utils import logging_utils
from src.utils import env_utils, experiment_utils
from src import functional

logger = logging.getLogger(__name__)

# Route all records (DEBUG and up, including third-party libraries) to the notebook's stdout
# using the project's shared format.
logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

import torch
import transformers

logger.info(f"{torch.__version__=}, {torch.version.cuda=}")
# torch.cuda.get_device_name() raises when no CUDA device is available, so only query it
# when a GPU is present; on GPU machines the logged line is unchanged.
cuda_device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else None
logger.info(
    f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, "
    f"torch.cuda.get_device_name()={cuda_device_name!r}"
)
logger.info(f"{transformers.__version__=}")

  from .autonotebook import tqdm as notebook_tqdm


2024-09-06 14:25:29 __main__ INFO     torch.__version__='2.3.1', torch.version.cuda='12.1'
2024-09-06 14:25:29 __main__ INFO     torch.cuda.is_available()=True, torch.cuda.device_count()=1, torch.cuda.get_device_name()='NVIDIA RTX A6000'
2024-09-06 14:25:29 __main__ INFO     transformers.__version__='4.44.2'


In [3]:
from dataclasses_json import DataClassJsonMixin
from dataclasses import dataclass, field, fields
from typing import Optional
import random
from src.dataset import BridgeSample, BridgeRelation, BridgeDataset
from src.dataset import load_bridge_relation, load_bridge_relations, load_bridge_dataset        

In [4]:
SEED = 123456
experiment_utils.set_seed(SEED)  # seed before loading so the dataset build is reproducible

# Load every bridge relation and wrap them in one dataset
# (the logged run reports 5 relations / 352 examples).
relations = load_bridge_relations()
dataset = BridgeDataset(relations)

2024-09-06 14:25:29 src.utils.experiment_utils INFO     setting all seeds to 123456
2024-09-06 14:25:29 src.dataset INFO     initialized bridge relation superpower_characters with 80 examples
2024-09-06 14:25:29 src.dataset INFO     initialized bridge relation sport_players with 61 examples
2024-09-06 14:25:29 src.dataset INFO     initialized bridge relation movie_actor with 71 examples
2024-09-06 14:25:29 src.dataset INFO     initialized bridge relation architect_building with 43 examples
2024-09-06 14:25:29 src.dataset INFO     initialized bridge relation none with 102 examples
2024-09-06 14:25:29 src.dataset INFO     initialized bridge dataset with 5 relations and 352 examples


In [5]:
# Number of relation types in the dataset (5 in the logged run, one of which is the "none" relation).
len(dataset.relations)

5

In [6]:
sample_idx = 22
question, answer = dataset[sample_idx]

# `question` is the few-shot prompt ending in "A:"; `answer` is the expected completion.
print(question, answer, sep="\n")

Given two entities, find a common link between them.
#
What is a common link between Michelle Williams and Marilyn Monroe?
A: My Week with Marilyn - a movie where Michelle Williams played the role of Marilyn Monroe.
#
What is a common link between Fallingwater and Guggenheim Museum?
A: Frank Lloyd Wright - who was the architect of both buildings Fallingwater and Guggenheim Museum.
#
What is a common link between Roger Federer and Rafael Nadal?
A: tennis - a sport where both Roger Federer and Rafael Nadal are known for.
#
What is a common link between Charles Darwin and Flamenco?
A: none - there is no connection between Charles Darwin and Flamenco.
#
What is a common link between Mr. Fantastic and Elastigirl?
A: elastic powers - an attribute that both characters Mr. Fantastic and Elastigirl possess.
#
What is a common link between Zhang Jike and Timo Boll?
A:
table tennis - a sport where both Zhang Jike and Timo Boll are known for.


In [7]:
import torch

from nnsight import LanguageModel
from src.models import ModelandTokenizer

# Checkpoint under study. Alternatives previously toggled here:
#   "meta-llama/Meta-Llama-3-8B", "google/gemma-2-9b-it",
#   "google/gemma-2-27b-it", "Qwen/Qwen2-7B"
model_key = "meta-llama/Meta-Llama-3-8B-Instruct"

# Load weights in half precision.
mt = ModelandTokenizer(model_key=model_key, torch_dtype=torch.float16)

2024-09-06 14:25:30 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.30s/it]

2024-09-06 14:25:35 src.models INFO     loaded model </home/local_arnab/Codes/00_MODEL/meta-llama/Meta-Llama-3-8B-Instruct> | size: 15316.516 MB | dtype: torch.float16 | device: cuda:0





In [8]:
from src.functional import predict_bridge_entity

# Let the model answer the few-shot bridge prompt; the result is the generated answer text
# (see the string output below).
predicted_ans = predict_bridge_entity(mt, question)
predicted_ans

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


' table tennis - a sport where both Zhang Jike and Timo Boll are professional players.\n'

In [9]:
from src.functional import verify_bridge_response

# Grade the model's answer against the sample. Per the logged output, the judgement comes from a
# (cached) GPT-4o call and returned 'Yes' here — confirm the full return contract in src.functional.
# NOTE(review): assumes dataset.examples[i] is the same sample that dataset[i] rendered above.
query_sample = dataset.examples[sample_idx]
verify_bridge_response(query_sample, predicted_ans)

2024-09-06 14:25:41 httpx DEBUG    load_ssl_context verify=True cert=None trust_env=True http2=False
2024-09-06 14:25:41 httpx DEBUG    load_verify_locations cafile='/home/local_arnab/miniconda3/envs/retrieval/lib/python3.11/site-packages/certifi/cacert.pem'
2024-09-06 14:25:41 src.functional DEBUG    found cached gpt4o response for be89d37e30e684fafe70a44da4ad394a - loading


'Yes'

In [10]:
from src.functional import predict_next_token
from src.tokens import prepare_input

# Tokenize the bridge prompt. NB: `input` shadows the builtin, but later cells read it by this name.
input = prepare_input(tokenizer=mt, prompts=question)

# Top next-token candidates at the answer position.
predict_next_token(mt, input)

[[PredictedToken(token=' table', prob=0.8527361750602722, logit=None, token_id=2007),
  PredictedToken(token=' Table', prob=0.11722265928983688, logit=None, token_id=6771),
  PredictedToken(token=' ping', prob=0.008359931409358978, logit=None, token_id=31098),
  PredictedToken(token=' professional', prob=0.0034849378280341625, logit=None, token_id=6721),
  PredictedToken(token=' Ping', prob=0.0029346118681132793, logit=None, token_id=49757)]]

In [11]:
from src.functional import get_hs

# Hidden state at the last prompt token (index -1) for every entry of mt.layer_names;
# the result is looked up by (layer_name, token_index) below.
hs = get_hs(
    mt=mt,
    input=input,
    locations=[(name, -1) for name in mt.layer_names],
)

In [12]:
from src.functional import logit_lens, patchscope

# Layer 31 is the final decoder layer of the loaded Llama-3-8B: the logit lens on this state
# reproduces the model's own next-token distribution. NOTE(review): revisit for other model_keys.
layer_idx = 31
h = hs[(mt.layer_name_format.format(layer_idx), -1)]

# Decode `h` with a patchscope (layer_idx=0 presumably chooses where `h` is injected), top-10 tokens.
patchscope(mt=mt, h=h, layer_idx=0, k=10)

2024-09-06 14:25:42 src.functional DEBUG    placeholder position: 21 | token:  placeholder


[PredictedToken(token=' ->', prob=0.33404892683029175, logit=13.2890625, token_id=1492),
 PredictedToken(token=' ;', prob=0.2925030291080475, logit=13.15625, token_id=2652),
 PredictedToken(token=';', prob=0.028736338019371033, logit=10.8359375, token_id=26),
 PredictedToken(token=' ', prob=0.017982741817831993, logit=10.3671875, token_id=220),
 PredictedToken(token=' ;\n', prob=0.017703944817185402, logit=10.3515625, token_id=4485),
 PredictedToken(token=' table', prob=0.017025716602802277, logit=10.3125, token_id=2007),
 PredictedToken(token=' \n', prob=0.014908215962350368, logit=10.1796875, token_id=720),
 PredictedToken(token=';\n', prob=0.012553977780044079, logit=10.0078125, token_id=280),
 PredictedToken(token=' &', prob=0.011165737174451351, logit=9.890625, token_id=612),
 PredictedToken(token=' ->\n', prob=0.009700961410999298, logit=9.75, token_id=12662)]

In [13]:
# Logit-lens readout of the same last-token state `h` (top-5 tokens in the logged output).
logit_lens(mt, h)

[PredictedToken(token=' table', prob=0.8527359962463379, logit=21.734375, token_id=2007),
 PredictedToken(token=' Table', prob=0.11722263693809509, logit=19.75, token_id=6771),
 PredictedToken(token=' ping', prob=0.008359930478036404, logit=17.109375, token_id=31098),
 PredictedToken(token=' professional', prob=0.0034849371295422316, logit=16.234375, token_id=6721),
 PredictedToken(token=' Ping', prob=0.0029346111696213484, logit=16.0625, token_id=49757)]

In [32]:
from src.functional import get_module_nnsight
import numpy as np

# Every module whose output is cached, in this order:
# token embeddings, residual stream per layer, MLP outputs, attention outputs.
all_layers = [
    mt.embedder_name,
    *mt.layer_names,
    *(mt.mlp_module_name_format.format(i) for i in range(mt.n_layer)),
    *(mt.attn_module_name_format.format(i) for i in range(mt.n_layer)),
]

# Key order matters: it is the array order written by np.savez_compressed later.
cache = dict(
    doc=question,
    input_ids=input["input_ids"].cpu().numpy().astype(np.int32),
    attention_mask=input["attention_mask"].cpu().numpy().astype(np.int32),
    outputs=dict.fromkeys(all_layers),
)

with torch.no_grad(), mt.trace(input, scan=False, validate=False) as trace:
    for layer_name in all_layers:
        module = get_module_nnsight(mt, layer_name)
        # Embedding and MLP outputs are taken as-is; decoder-block and attention outputs
        # are indexed with [0] to pick the hidden state out of the returned tuple.
        returns_tensor = "mlp" in layer_name or layer_name == mt.embedder_name
        proxy = module.output if returns_tensor else module.output[0]
        cache["outputs"][layer_name] = proxy.save()

# Materialise the saved activations as host-side float16 arrays
# (shape (1, seq_len, hidden), e.g. (1, 196, 4096) in the logged run).
cache["outputs"] = {
    name: activation.cpu().numpy().astype(np.float16)
    for name, activation in cache["outputs"].items()
}

In [33]:
for layer_name in all_layers:
    shape = cache["outputs"][layer_name].shape
    print(f"layer_name={layer_name!r} | {shape}")

layer_name='model.embed_tokens' | (1, 196, 4096)
layer_name='model.layers.0' | (1, 196, 4096)
layer_name='model.layers.1' | (1, 196, 4096)
layer_name='model.layers.2' | (1, 196, 4096)
layer_name='model.layers.3' | (1, 196, 4096)
layer_name='model.layers.4' | (1, 196, 4096)
layer_name='model.layers.5' | (1, 196, 4096)
layer_name='model.layers.6' | (1, 196, 4096)
layer_name='model.layers.7' | (1, 196, 4096)
layer_name='model.layers.8' | (1, 196, 4096)
layer_name='model.layers.9' | (1, 196, 4096)
layer_name='model.layers.10' | (1, 196, 4096)
layer_name='model.layers.11' | (1, 196, 4096)
layer_name='model.layers.12' | (1, 196, 4096)
layer_name='model.layers.13' | (1, 196, 4096)
layer_name='model.layers.14' | (1, 196, 4096)
layer_name='model.layers.15' | (1, 196, 4096)
layer_name='model.layers.16' | (1, 196, 4096)
layer_name='model.layers.17' | (1, 196, 4096)
layer_name='model.layers.18' | (1, 196, 4096)
layer_name='model.layers.19' | (1, 196, 4096)
layer_name='model.layers.20' | (1, 196, 4

In [36]:
# np.savez_compressed has no `allow_pickle` parameter: every keyword argument becomes an array in
# the archive, so passing allow_pickle=True stored a spurious "allow_pickle" entry. The "outputs"
# dict is still written (as a pickled object array); read it back with np.load(..., allow_pickle=True).
np.savez_compressed("test.npz", **cache)

In [38]:
# Cached activations for the currently loaded model. The model directory is derived from
# `model_key` so the layer names in `all_layers` match the file being compared against
# (for "meta-llama/Meta-Llama-3-8B-Instruct" this is the same path as before).
cache_path = os.path.join(
    "../results/cache_states", model_key.split("/")[-1], "wikipedia", "6.npz"
)
# "outputs" is stored as a pickled dict (object array), so pickle loading is required.
# SECURITY: unpickling can execute code — only load cache files you produced yourself.
loaded_npz = np.load(cache_path, allow_pickle=True)
loaded_npz.files

['allow_pickle', 'doc', 'input_ids', 'attention_mask', 'outputs']

In [50]:
# The dict was saved as a 0-d object array; .item() unwraps it back into a dict.
cached_outputs = loaded_npz["outputs"].item()

for layer_name in all_layers:
    shape = cached_outputs[layer_name].shape
    print(f"layer_name={layer_name!r} | {shape}")

layer_name='model.embed_tokens' | (1, 1024, 4096)
layer_name='model.layers.0' | (1, 1024, 4096)
layer_name='model.layers.1' | (1, 1024, 4096)
layer_name='model.layers.2' | (1, 1024, 4096)
layer_name='model.layers.3' | (1, 1024, 4096)
layer_name='model.layers.4' | (1, 1024, 4096)
layer_name='model.layers.5' | (1, 1024, 4096)
layer_name='model.layers.6' | (1, 1024, 4096)
layer_name='model.layers.7' | (1, 1024, 4096)
layer_name='model.layers.8' | (1, 1024, 4096)
layer_name='model.layers.9' | (1, 1024, 4096)
layer_name='model.layers.10' | (1, 1024, 4096)
layer_name='model.layers.11' | (1, 1024, 4096)
layer_name='model.layers.12' | (1, 1024, 4096)
layer_name='model.layers.13' | (1, 1024, 4096)
layer_name='model.layers.14' | (1, 1024, 4096)
layer_name='model.layers.15' | (1, 1024, 4096)
layer_name='model.layers.16' | (1, 1024, 4096)
layer_name='model.layers.17' | (1, 1024, 4096)
layer_name='model.layers.18' | (1, 1024, 4096)
layer_name='model.layers.19' | (1, 1024, 4096)
layer_name='model.la