In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import time, json
import sys
sys.path.append("../")
import os
import torch
import numpy as np
from tqdm import tqdm

import logging
from src.utils import logging_utils
from src.utils import env_utils, experiment_utils
from src import functional
import wandb

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

logger.info(f"{torch.__version__=}, {torch.version.cuda=}")

  from .autonotebook import tqdm as notebook_tqdm


2024-10-30 12:19:07 __main__ INFO     torch.__version__='2.5.0+cu124', torch.version.cuda='12.4'


In [3]:
from src.models import ModelandTokenizer

# MODEL_KEY = "meta-llama/Llama-3.2-3B-Instruct"
# MODEL_KEY = "meta-llama/Llama-3.1-8B-Instruct"

MODEL_KEY = "meta-llama/Llama-3.2-3B"
# MODEL_KEY = "google/gemma-2-2b"
# MODEL_KEY = "meta-llama/Llama-3.1-8B"

#! torch.adaptive precision
mt = ModelandTokenizer(
    model_key=MODEL_KEY,
    torch_dtype=torch.float16,
)

2024-10-30 12:19:09 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.46s/it]

2024-10-30 12:19:12 src.models INFO     loaded model </home/local_arnab/Codes/00_MODEL/meta-llama/Llama-3.2-3B> | size: 6127.841 MB | dtype: torch.float16 | device: cuda:0





In [5]:
from src.dataset_manager import DatasetManager
dataloader = DatasetManager.from_named_datasets(
    [("geometry_of_truth", "cities")],
    batch_size=32
)

In [6]:
batch = next(iter(dataloader))

In [7]:
batch

[ContextQASample(context="The city of Ya'an is in Morocco.", questions=['# Would you validate this statement?', '# Do you find this to be truthful?', '# Do you find this to be truthful?', '# Have I understood this correctly?', '# Could you tell me if this is factual?', '# May I assume this is correct?', '# Would you validate this statement?', '# Are you confirming this is true?', '# Do you find this to be truthful?', '# Are you saying this is correct?'], answers=['No', 'No', 'No', 'No', 'No', 'No', 'No', 'No', 'No', 'No']),
 ContextQASample(context='The city of Xinyang is in Japan.', questions=['# Would you validate this statement?', '# Is this statement true?', '# Do you agree that this is true?', '# Can we confirm this is accurate?', '# Would you say this is accurate?', '# Can you verify if this is true?', '# Does this statement hold true?', '# Can we confirm this is accurate?', '# Does this statement hold true?', '# Does this information ring true?'], answers=['No', 'No', 'No', 'No'

In [4]:
from src.utils.typing import LatentCache, LatentCacheCollection
from src.tokens import prepare_input
from src.functional import get_module_nnsight, interpret_logits
from src.functional import get_batch_concept_activations

# prompts = [
#     "Eiffel Tower is in which city? It is in",
#     "A quick brown fox",
#     "The sun rises in the",
# ]

prompts = [
    "The land of the rising sun is",
    "The capital of France is",
    "The Space Needle is a tower in"
]

latents = get_batch_concept_activations(
    mt=mt,
    prompts=prompts,
    interested_layer_indices=list(range(5, 20)),
    check_prediction=None,
    on_token_occur=None,
)

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [5]:
lcc = LatentCacheCollection(latents=latents)
lcc.detensorize()

In [7]:
from src.dataset_manager import DatasetManager

# group, name = "relations", "factual/country_capital_city"
group, name = "relations", "commonsense/fruit_outside_color"

dataloader = DatasetManager.from_named_datasets(
    [
        (group, name),
    ],
    batch_size=8
)

2024-10-28 23:57:21 numexpr.utils INFO     Note: NumExpr detected 24 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-10-28 23:57:21 numexpr.utils INFO     NumExpr defaulting to 8 threads.
2024-10-28 23:57:21 datasets INFO     PyTorch version 2.5.0 available.
2024-10-28 23:57:21 src.dataset_manager INFO     Loaded 60 examples from commonsense/fruit_outside_color.


In [8]:
cache_dir = os.path.join(group, name)
os.makedirs(cache_dir, exist_ok=True)

for batch_idx, batch in tqdm(enumerate(dataloader)):
    prompts = [b.context for b in batch]    
    latents = get_batch_concept_activations(
        mt=mt,
        prompts=prompts,
        interested_layer_indices=list(range(5, 20)),
        check_prediction=None,
        on_token_occur=None,
    )

    correct_labels = [b.correct for b in batch]
    incorrect_labels = [b.incorrect for b in batch]

    for latent_cache, correct, incorrect in zip(latents, correct_labels, incorrect_labels):
        latent_cache.correct_label = correct
        latent_cache.incorrect_label = incorrect
        latent_cache.group="relations"
    
    lcc = LatentCacheCollection(latents=latents)
    lcc.detensorize()

    with open(os.path.join(cache_dir, f"batch_{batch_idx}.json"), "w") as f:
        f.write(lcc.to_json())

0it [00:00, ?it/s]

0it [00:00, ?it/s]


AttributeError: 'ContextQASample' object has no attribute 'correct'

In [48]:
from src.activation_manager import ActivationLoader

In [13]:
from src.activation_manager import ActivationLoader, get_batch_paths

latent_root = os.path.join(
    env_utils.DEFAULT_RESULTS_DIR, 
    "cached_latents", MODEL_KEY.split("/")[-1], "geometry_of_truth"
)

activation_batch_paths = list(get_batch_paths(latent_root))
activation_batch_paths

['/home/local_arnab/Codes/Projects/talkative_probes/results/cached_latents/Llama-3.2-3B/geometry_of_truth/cities/batch_6.json',
 '/home/local_arnab/Codes/Projects/talkative_probes/results/cached_latents/Llama-3.2-3B/geometry_of_truth/cities/batch_41.json',
 '/home/local_arnab/Codes/Projects/talkative_probes/results/cached_latents/Llama-3.2-3B/geometry_of_truth/cities/batch_42.json',
 '/home/local_arnab/Codes/Projects/talkative_probes/results/cached_latents/Llama-3.2-3B/geometry_of_truth/cities/batch_34.json',
 '/home/local_arnab/Codes/Projects/talkative_probes/results/cached_latents/Llama-3.2-3B/geometry_of_truth/cities/batch_12.json',
 '/home/local_arnab/Codes/Projects/talkative_probes/results/cached_latents/Llama-3.2-3B/geometry_of_truth/cities/batch_17.json',
 '/home/local_arnab/Codes/Projects/talkative_probes/results/cached_latents/Llama-3.2-3B/geometry_of_truth/cities/batch_28.json',
 '/home/local_arnab/Codes/Projects/talkative_probes/results/cached_latents/Llama-3.2-3B/geometry_o

In [17]:
act_loader = ActivationLoader(
    latent_cache_files=activation_batch_paths,
    shuffle=True,
    batch_size=32,
)

In [18]:
act_batch = act_loader.next_batch()

In [19]:
act_batch

[ActivationSample(activation=tensor([ 0.0320,  0.0164,  0.0691,  ..., -0.0952,  0.4255, -0.1327]), context='Fifty-six is smaller than seventy-six.', question='# Are these facts correct?', label=' Yes', layer_name='model.layers.15'),
 ActivationSample(activation=tensor([ 0.0300,  0.2755,  0.0297,  ...,  0.0331,  0.3606, -0.1069]), context='Fifty-six is smaller than seventy-six.', question='# May I assume this is correct?', label=' Yes', layer_name='model.layers.16'),
 ActivationSample(activation=tensor([ 0.1466, -0.1212,  0.1019,  ...,  0.0743,  0.1145, -0.0166]), context='Seventy-eight is smaller than eighty-nine.', question='# Can you confirm the accuracy of this?', label=' Yes', layer_name='model.layers.11'),
 ActivationSample(activation=tensor([ 0.0613,  0.1951,  0.0546,  ..., -0.0468,  0.4112, -0.1386]), context='Seventy-nine is smaller than fifty-one.', question='# May I assume this is correct?', label=' No', layer_name='model.layers.16'),
 ActivationSample(activation=tensor([ 0.0

In [33]:
prompts = [
    "# Does this passage concern a male person?",
    "# Does this text explore Business?",
    "# Is the subject of this sentence female?",
    "# Does this review indicate a poor experience?",
    "# Is this talking about a female person?",
    "# Does the reviewer appear displeased with their experience?",
    "# Does this refer to a male individual?",
    "# Is the overall message of this review negative?",
    "# Can this article be categorized under World News?",
    "# Can this article be categorized under Business?",
    "# In terms of content, does this piece deal with World News?",
    "# Is Central African Republic mentioned anywhere in the passage?",
    "# Could you tell me if this is factual?",
    "# Is Rally of in any part of this passage?",
    "# Is Israeli cited in this piece?",
    "# Is there a positive tone throughout this review?",
    "# Does this describe someone who is female?",
    "# Does this text explore Sports?",
    "# Is Japanese referenced in this text?",
    "# Could this be summarized as an article about Business?",
    "# Can you find any references to since November in this text?",
    "# Would you classify this as an endorsement?",
    "# Are you confirming this is true?",
    "# Is 1498 discussed anywhere in this content?",
    "# Does this text refer to March?",
    "# Is there a substantial focus on Business in this article?",
    "# Is the general sentiment of this review negative?",
    # "# Does Rwanda \x92s appear in this text?",
    "# Is the sentiment in this review on the negative side?",
    "# Does this article delve into matters concerning Science/Technology?",
    "# Is there a substantial focus on World News in this article?",
    "# Is this sentence about a female person?",
]

In [41]:
from src.dataset_manager import DatasetManager
dataloader = DatasetManager.from_named_datasets(
    # [("geometry_of_truth", "cities")],
    # [("sst2", "sst2")],
    # [("relations", 'factual/country_capital_city')],
    # [("tense", "tense")],
    [("language_identification", "language_identification")],
    batch_size=5
)
batch = next(iter(dataloader))
batch

2024-10-30 12:35:23 numexpr.utils INFO     Note: NumExpr detected 24 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-10-30 12:35:23 numexpr.utils INFO     NumExpr defaulting to 8 threads.
2024-10-30 12:35:23 datasets INFO     PyTorch version 2.5.0 available.
2024-10-30 12:35:23 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443
2024-10-30 12:35:23 urllib3.connectionpool DEBUG    https://huggingface.co:443 "GET /api/datasets/FrancophonIA/WiLI-2018 HTTP/11" 200 1480
2024-10-30 12:35:23 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): s3.amazonaws.com:443
2024-10-30 12:35:23 urllib3.connectionpool DEBUG    https://s3.amazonaws.com:443 "HEAD /datasets.huggingface.co/datasets/datasets/FrancophonIA/WiLI-2018/FrancophonIA/WiLI-2018.py HTTP/11" 404 0
2024-10-30 12:35:23 urllib3.connectionpool DEBUG    https://huggingface.co:443 "GET /api/datasets/FrancophonIA/WiLI-2018 HTTP/11" 200 1480
2024-10-30 12:35:23 urllib3

[ContextQASample(context='perhitungan dasar yang berlaku sejak zaman pertengahan adalah paskah dirayakan pada hari minggu setelah bulan purnama pertama setelah hari pertama musim semi vernal equinox kalimat tersebut sebenarnya tidak tepat benar dengan sistem perhitungan gerejawi', questions=['# Is this text predominantly in Korean?', '# Is this text predominantly in Indonesian?', '# Is this passage composed in Chinese?', '# Is this passage composed in Swedish?', '# Is this written in the language Romanian?', '# Would you classify this as written in Indonesian?', '# Is this text written in Indonesian?', '# Can we confirm this is in Urdu?', '# Is the text presented in Indonesian?', '# Is this text predominantly in Russian?'], answers=['No', 'Yes', 'No', 'No', 'No', 'Yes', 'Yes', 'No', 'Yes', 'No']),
 ContextQASample(context='denna typ av kommersiell och välproducerad rock fortsatte på deras nästa skiva outside inside som kom  på denna skiva kom deras första top -hit på billboardlistan "s

In [44]:
prompts = [
    b.context for b in batch
]
prompts

['perhitungan dasar yang berlaku sejak zaman pertengahan adalah paskah dirayakan pada hari minggu setelah bulan purnama pertama setelah hari pertama musim semi vernal equinox kalimat tersebut sebenarnya tidak tepat benar dengan sistem perhitungan gerejawi',
 'denna typ av kommersiell och välproducerad rock fortsatte på deras nästa skiva outside inside som kom  på denna skiva kom deras första top -hit på billboardlistan "shes a beauty" en låt där för övrigt totos dåvarande gitarrist steve lukather medverkade  kom love bomb producerad av todd rundgren som de tidigare arbetat med på remote control skivan blev en total flopp och gruppen splittrades kort därefter',
 'لا يعتمد اختيار أي نظام اتصال من بعد لتحقيق أغراض معينة على التقنية الممكن استخدامها وكلفتها فقط بل على الكلفة النسبية للحلول المختلفة المتوافرة ويلجأ للموازنة بين مختلف الحلول والمنظومات المحتملة اقتصادياً إلى دراسة القيمة الحالية للكلفة السنوية present value of annoal chargespvac التي تراعي كلاً من رأس المال والنفقات المستمرة

In [45]:
from src.tokens import prepare_input, find_token_range

# prompts = [
#     "The land of",
#     "The capital of France is",
#     "This is a"
# ]

batch_inputs = prepare_input(
    tokenizer=mt,
    prompts=prompts,
    padding_side="left",
    # padding="max_length",
    # max_length=20,
    truncation=True,
    return_offsets_mapping=True
)

batch_inputs

{'input_ids': tensor([[128001, 128001, 128001,  ...,  67387,     73,  41978],
        [128001, 128001, 128001,  ...,    294,  47786,   1064],
        [128001, 128001, 128001,  ...,  74541, 115315, 103352],
        [128001, 128001, 128001,  ...,  73394,  33054,    324],
        [128000, 102393, 116333,  ...,  44747,  85410, 100436]],
       device='cuda:0'), 'attention_mask': tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]], device='cuda:0'), 'offset_mapping': tensor([[[   0,    0],
         [   0,    0],
         [   0,    0],
         ...,
         [ 244,  249],
         [ 249,  250],
         [ 250,  253]],

        [[   0,    0],
         [   0,    0],
         [   0,    0],
         ...,
         [ 397,  399],
         [ 399,  402],
         [ 402,  406]],

        [[   0,    0],
         [   0,    0],
         [   0,    0],
         ...,
         [ 565,  567],
 

In [46]:
batch_inputs.input_ids.shape

torch.Size([5, 589])

In [30]:
[
    mt.tokenizer.decode(inp, skip_special_tokens=False)
    for inp in batch_inputs["input_ids"]
]

['<|end_of_text|><|begin_of_text|>The land of',
 '<|begin_of_text|>The capital of France',
 '<|end_of_text|><|begin_of_text|>This is a']

In [32]:
find_token_range(
    string=prompts[1],
    substring="is",
    tokenizer=mt,
    offset_mapping=batch_inputs["offset_mapping"][1]
)

AssertionError: Are you working with Llama-3? Try passing the ModelandTokenizer object as the tokenizer

In [48]:
list(DatasetManager.list_datasets_by_group().keys())

['geometry_of_truth',
 'relations',
 'sst2',
 'md_gender',
 'snli',
 'ag_news',
 'ner',
 'tense',
 'language_identification',
 'singular_plural']