In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import time, json
import sys
sys.path.append("../")
sys.path.append("../chess_llm_interpretability")
import os
import torch
import numpy as np
from tqdm import tqdm

import logging
from src.utils import logging_utils
from src.utils import env_utils, experiment_utils
from src import functional
import wandb

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

logger.info(f"{torch.__version__=}, {torch.version.cuda=}")

  from .autonotebook import tqdm as notebook_tqdm


2024-10-28 15:25:53 __main__ INFO     torch.__version__='2.5.0+cu124', torch.version.cuda='12.4'


In [3]:
from src.models import ModelandTokenizer

# MODEL_KEY = "meta-llama/Llama-3.2-3B-Instruct"
# MODEL_KEY = "meta-llama/Llama-3.1-8B-Instruct"

MODEL_KEY = "meta-llama/Llama-3.2-3B"
# MODEL_KEY = "google/gemma-2-2b"
# MODEL_KEY = "meta-llama/Llama-3.1-8B"

#! torch.adaptive precision
mt = ModelandTokenizer(
    model_key=MODEL_KEY,
    torch_dtype=torch.float32,
)

2024-10-28 15:25:54 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.64s/it]

2024-10-28 15:25:58 src.models INFO     loaded model </home/local_arnab/Codes/00_MODEL/meta-llama/Llama-3.2-3B> | size: 12255.675 MB | dtype: torch.float32 | device: cuda:0





In [4]:
from src.utils.typing import LatentCache, LatentCacheCollection
from src.tokens import prepare_input
from src.functional import get_module_nnsight, interpret_logits
from src.functional import get_batch_concept_activations

# prompts = [
#     "Eiffel Tower is in which city? It is in",
#     "A quick brown fox",
#     "The sun rises in the",
# ]

prompts = [
    "The land of the rising sun is",
    "The capital of France is",
    "The Space Needle is a tower in"
]

latents = get_batch_concept_activations(
    mt=mt,
    prompts=prompts,
    interested_layer_indices=list(range(5, 20)),
    check_prediction=None,
    on_token_occur=None,
)

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [5]:
lcc = LatentCacheCollection(latents=latents)
lcc.detensorize()

In [6]:
with open("lcc_batch/test_lcc.json", "w") as f:
    f.write(lcc.to_json())

In [25]:
from src.dataset_manager import DatasetManager

# group, name = "relations", "factual/country_capital_city"
group, name = "relations", "commonsense/fruit_outside_color"

dataloader = DatasetManager.from_named_datasets(
    [
        (group, name),
    ],
    batch_size=8
)

2024-10-28 16:52:38 src.dataset_manager INFO     Loaded 60 examples from commonsense/fruit_outside_color.


In [26]:
cache_dir = os.path.join(group, name)
os.makedirs(cache_dir, exist_ok=True)

for batch_idx, batch in tqdm(enumerate(dataloader)):
    prompts = [b.context for b in batch]    
    latents = get_batch_concept_activations(
        mt=mt,
        prompts=prompts,
        interested_layer_indices=list(range(5, 20)),
        check_prediction=None,
        on_token_occur=None,
    )

    correct_labels = [b.correct for b in batch]
    incorrect_labels = [b.incorrect for b in batch]

    for latent_cache, correct, incorrect in zip(latents, correct_labels, incorrect_labels):
        latent_cache.correct_label = correct
        latent_cache.incorrect_label = incorrect
        latent_cache.group="relations"
    
    lcc = LatentCacheCollection(latents=latents)
    lcc.detensorize()

    with open(os.path.join(cache_dir, f"batch_{batch_idx}.json"), "w") as f:
        f.write(lcc.to_json())

0it [00:00, ?it/s]

8it [00:04,  1.80it/s]


In [40]:
from dataclasses import dataclass
from src.utils.typing import ArrayLike
from typing import Literal
import random

@dataclass
class ActivationSample:
    activation: ArrayLike
    context: str
    query: str
    label: Literal["yes", "no"]
    layer_name: str | None = None

    def __post_init__(self):
        if isinstance(self.activation, torch.Tensor) == False:
            self.activation = torch.Tensor(self.activation)
        
        assert self.label in ["yes", "no"]
        assert "#" in self.query
    

class ActivationLoader:
    def __init__(self, latent_cache_files: str, shuffle: bool = True, batch_size: int = 32):
        self.latent_cache_files = []
        for file_path in latent_cache_files:
            if os.path.exists(file_path) == False:
                logger.error(f"{file_path} not found")
                continue    
            if os.path.isdir(file_path) == True:
                raise logger.error(f"{file_path} should be a json file")
            self.latent_cache_files.append(file_path)
        
        if shuffle:
            random.shuffle(self.latent_cache_files)

        self.current_file_idx = 0
        self.buffer: list[ActivationSample] = []
        self.batch_size = batch_size
        self.stop_iteration = False

        with open(
            os.path.join(env_utils.DEFAULT_DATA_DIR, "paraphrases/yes_no.json"), "r"
        ) as f:
            self.YES_NO_PARAPHRASES = json.load(f)

        with open(
            os.path.join(env_utils.DEFAULT_DATA_DIR, "paraphrases/question.json"), "r"
        ) as f:
            self.QUESTION_PARAPHRASES = json.load(f)
    
    def get_latent_qa(self, correct_label, wrong_label, group) -> tuple[str, Literal["yes", "no"]]:
        label = random.choice(["yes", "no"])
        yes_no = random.choice(self.YES_NO_PARAPHRASES)
        question = random.choice(self.QUESTION_PARAPHRASES[group])
        
        query = "# "
        question = question.format(correct_label) if label == "yes" else question.format(wrong_label)
        query += question + f" {yes_no}"
        
        return query, label


    def load_next_file(self):
        if self.current_file_idx >= len(self.latent_cache_files):
            return False

        with open(self.latent_cache_files[self.current_file_idx], "r") as f:
            lcc = LatentCacheCollection.from_json(f.read())
        
        add_to_buffer = []
        for latent_cache in lcc.latents:
            for layer_name in latent_cache.latents.keys():
                activation = latent_cache.latents[layer_name]
                query, label = self.get_latent_qa(
                    correct_label=latent_cache.correct_label,
                    wrong_label=latent_cache.incorrect_label,
                    group=latent_cache.group,
                )
                add_to_buffer.append(ActivationSample(
                    activation=activation,
                    context=latent_cache.context,
                    query=query,
                    label=label,
                    layer_name=layer_name,
                ))
        
        random.shuffle(add_to_buffer)
        self.buffer.extend(add_to_buffer)
        self.current_file_idx += 1
        return True
    
    def next_batch(self):
        if self.stop_iteration:
            raise StopIteration
        
        if len(self.buffer) < self.batch_size:
            if self.load_next_file() == False:
                self.stop_iteration = True      # will raise StopIteration next time

        batch = self.buffer[:self.batch_size]
        self.buffer = self.buffer[self.batch_size:]

        return batch

In [41]:
relation_root = "relations"

def get_batch_paths(root):
    for root, _, files in os.walk(root):
        for file in files:
            if file.endswith(".json"):
                yield os.path.join(root, file)


batch_paths = list(get_batch_paths(relation_root))
batch_paths

['relations/commonsense/fruit_outside_color/batch_6.json',
 'relations/commonsense/fruit_outside_color/batch_0.json',
 'relations/commonsense/fruit_outside_color/batch_1.json',
 'relations/commonsense/fruit_outside_color/batch_5.json',
 'relations/commonsense/fruit_outside_color/batch_2.json',
 'relations/commonsense/fruit_outside_color/batch_4.json',
 'relations/commonsense/fruit_outside_color/batch_7.json',
 'relations/commonsense/fruit_outside_color/batch_3.json',
 'relations/factual/country_capital_city/batch_0.json',
 'relations/factual/country_capital_city/batch_1.json',
 'relations/factual/country_capital_city/batch_5.json',
 'relations/factual/country_capital_city/batch_2.json',
 'relations/factual/country_capital_city/batch_4.json',
 'relations/factual/country_capital_city/batch_3.json']

In [42]:
act_loader = ActivationLoader(
    latent_cache_files=batch_paths,
    shuffle=True,
    batch_size=32,
)

In [43]:
act_batch = act_loader.next_batch()

In [46]:
DatasetManager.list_datasets_by_group()

{'geometry_of_truth': ['sp_en_trans',
  'neg_sp_en_trans',
  'cities',
  'neg_cities',
  'smaller_than',
  'larger_than',
  'common_claim_true_false',
  'companies_true_false',
  'counterfact_true_false'],
 'sst2': ['sst2'],
 'relations': ['commonsense/word_sentiment',
  'commonsense/fruit_outside_color',
  'commonsense/task_done_by_person',
  'commonsense/work_location',
  'commonsense/task_done_by_tool',
  'commonsense/substance_phase',
  'commonsense/object_superclass',
  'commonsense/fruit_inside_color',
  'factual/pokemon_evolutions',
  'factual/country_capital_city',
  'factual/person_plays_pro_sport',
  'factual/star_constellation',
  'factual/country_language',
  'factual/presidents_birth_year',
  'factual/landmark_on_continent',
  'factual/country_largest_city',
  'factual/company_hq',
  'factual/food_from_country',
  'factual/landmark_in_country',
  'factual/company_ceo',
  'factual/superhero_archnemesis',
  'factual/city_in_country',
  'factual/person_band_lead_singer',
  'f

In [47]:
mt.n_layer

28