In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import time, json
import sys
sys.path.append("../")
sys.path.append("../chess_llm_interpretability")
import os
import torch
import numpy as np
from tqdm import tqdm

import logging
from src.utils import logging_utils
from src.utils import env_utils, experiment_utils
from src import functional
import wandb

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

logger.info(f"{torch.__version__=}, {torch.version.cuda=}")

  from .autonotebook import tqdm as notebook_tqdm


2024-10-28 15:25:53 __main__ INFO     torch.__version__='2.5.0+cu124', torch.version.cuda='12.4'


In [3]:
from src.models import ModelandTokenizer

# MODEL_KEY = "meta-llama/Llama-3.2-3B-Instruct"
# MODEL_KEY = "meta-llama/Llama-3.1-8B-Instruct"

MODEL_KEY = "meta-llama/Llama-3.2-3B"
# MODEL_KEY = "google/gemma-2-2b"
# MODEL_KEY = "meta-llama/Llama-3.1-8B"

#! torch.adaptive precision
mt = ModelandTokenizer(
    model_key=MODEL_KEY,
    torch_dtype=torch.float32,
)

2024-10-28 15:25:54 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.64s/it]

2024-10-28 15:25:58 src.models INFO     loaded model </home/local_arnab/Codes/00_MODEL/meta-llama/Llama-3.2-3B> | size: 12255.675 MB | dtype: torch.float32 | device: cuda:0





In [4]:
from src.utils.typing import LatentCache, LatentCacheCollection
from src.tokens import prepare_input
from src.functional import get_module_nnsight, interpret_logits
from src.functional import get_batch_concept_activations

# prompts = [
#     "Eiffel Tower is in which city? It is in",
#     "A quick brown fox",
#     "The sun rises in the",
# ]

prompts = [
    "The land of the rising sun is",
    "The capital of France is",
    "The Space Needle is a tower in"
]

latents = get_batch_concept_activations(
    mt=mt,
    prompts=prompts,
    interested_layer_indices=list(range(5, 20)),
    check_prediction=None,
    on_token_occur=None,
)

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [5]:
lcc = LatentCacheCollection(latents=latents)
lcc.detensorize()

In [6]:
with open("lcc_batch/test_lcc.json", "w") as f:
    f.write(lcc.to_json())

In [25]:
from src.dataset_manager import DatasetManager

# group, name = "relations", "factual/country_capital_city"
group, name = "relations", "commonsense/fruit_outside_color"

dataloader = DatasetManager.from_named_datasets(
    [
        (group, name),
    ],
    batch_size=8
)

2024-10-28 16:52:38 src.dataset_manager INFO     Loaded 60 examples from commonsense/fruit_outside_color.


In [26]:
cache_dir = os.path.join(group, name)
os.makedirs(cache_dir, exist_ok=True)

for batch_idx, batch in tqdm(enumerate(dataloader)):
    prompts = [b.context for b in batch]    
    latents = get_batch_concept_activations(
        mt=mt,
        prompts=prompts,
        interested_layer_indices=list(range(5, 20)),
        check_prediction=None,
        on_token_occur=None,
    )

    correct_labels = [b.correct for b in batch]
    incorrect_labels = [b.incorrect for b in batch]

    for latent_cache, correct, incorrect in zip(latents, correct_labels, incorrect_labels):
        latent_cache.correct_label = correct
        latent_cache.incorrect_label = incorrect
        latent_cache.group="relations"
    
    lcc = LatentCacheCollection(latents=latents)
    lcc.detensorize()

    with open(os.path.join(cache_dir, f"batch_{batch_idx}.json"), "w") as f:
        f.write(lcc.to_json())

0it [00:00, ?it/s]

8it [00:04,  1.80it/s]


In [48]:
from src.activation_manager import ActivationLoader

In [51]:
latent_root = os.path.join(env_utils.DEFAULT_RESULTS_DIR, "cached_latents")

def get_batch_paths(root):
    for root, _, files in os.walk(root):
        for file in files:
            if file.endswith(".json"):
                yield os.path.join(root, file)


batch_paths = list(get_batch_paths(latent_root))
batch_paths

['/home/local_arnab/Codes/Projects/talkative_probes/results/cached_latents/Llama-3.2-3B/relations/commonsense/object_superclass/batch_0.json',
 '/home/local_arnab/Codes/Projects/talkative_probes/results/cached_latents/Llama-3.2-3B/relations/commonsense/object_superclass/batch_1.json',
 '/home/local_arnab/Codes/Projects/talkative_probes/results/cached_latents/Llama-3.2-3B/relations/commonsense/object_superclass/batch_2.json',
 '/home/local_arnab/Codes/Projects/talkative_probes/results/cached_latents/Llama-3.2-3B/relations/commonsense/object_superclass/batch_4.json',
 '/home/local_arnab/Codes/Projects/talkative_probes/results/cached_latents/Llama-3.2-3B/relations/commonsense/object_superclass/batch_3.json',
 '/home/local_arnab/Codes/Projects/talkative_probes/results/cached_latents/Llama-3.2-3B/relations/commonsense/substance_phase/batch_0.json',
 '/home/local_arnab/Codes/Projects/talkative_probes/results/cached_latents/Llama-3.2-3B/relations/commonsense/substance_phase/batch_1.json',
 '/

In [52]:
act_loader = ActivationLoader(
    latent_cache_files=batch_paths,
    shuffle=True,
    batch_size=32,
)

In [53]:
act_batch = act_loader.next_batch()

In [54]:
act_batch

[ActivationSample(activation=tensor([ 0.0390,  0.0401, -0.0295,  ...,  0.0522,  0.1310, -0.0154]), context='The city of Dalian is in Saudi Arabia.', query='# Is this statement true? Answer yes or no.\nAnswer:', label='no', layer_name='model.layers.7'),
 ActivationSample(activation=tensor([ 0.0551,  0.0271,  0.0249,  ...,  0.1065,  0.1373, -0.0083]), context='The city of Dakar is in India.', query='# This statement is true. Do you agree? (yes/no) Answer:', label='no', layer_name='model.layers.7'),
 ActivationSample(activation=tensor([-0.0567,  0.2811, -0.0115,  ...,  0.2637,  0.2640,  0.0371]), context='The city of Munich is in Germany.', query='# Is this statement false? (yes/no) Answer:', label='no', layer_name='model.layers.17'),
 ActivationSample(activation=tensor([ 0.0542,  0.1057, -0.1693,  ...,  0.2872,  0.2167,  0.0574]), context='The city of Tucson is in the United States.', query='# This statement is true. Do you agree? Answer yes or no.\nThe answer is:', label='yes', layer_na

In [56]:
prompts = [
    "The land of the rising sun is",
    "The capital of France is",
    "The Space Needle is a tower in"
]

batch_inputs = prepare_input(
    tokenizer=mt,
    prompts=prompts,
)

In [61]:
batch_inputs

{'input_ids': tensor([[128000,    791,   4363,    315,    279,  16448,   7160,    374],
        [128001, 128001, 128000,    791,   6864,    315,   9822,    374],
        [128000,    791,  11746,  89900,    374,    264,  21970,    304]],
       device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}

In [59]:
batch_inputs[1].attention_mask.index(1)

2