In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import time, json
import sys
sys.path.append("../")
# sys.path.append("../chess_llm_interpretability")
import os
import torch
import numpy as np

import logging
from src.utils import logging_utils
from src.utils import env_utils
from src import functional

# Send every record (DEBUG and above) to stdout so that log lines show up in
# the cell output, using the project's shared format strings.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    format=logging_utils.DEFAULT_FORMAT,
)

# Notebook-level logger; record the torch / CUDA versions for provenance.
logger = logging.getLogger(__name__)
logger.info(f"{torch.__version__=}, {torch.version.cuda=}")

  from .autonotebook import tqdm as notebook_tqdm


2024-11-01 16:26:25 __main__ INFO     torch.__version__='2.5.0+cu124', torch.version.cuda='12.4'


In [3]:
from src.models import ModelandTokenizer

# Checkpoint selection: exactly one MODEL_KEY assignment is left uncommented.
# The commented-out keys below are alternatives that can be swapped in.
# MODEL_KEY = "meta-llama/Llama-3.2-3B-Instruct"
# MODEL_KEY = "meta-llama/Llama-3.1-8B-Instruct"

# MODEL_KEY = "meta-llama/Llama-3.2-3B"
# MODEL_KEY = "google/gemma-2-2b"
# MODEL_KEY = "meta-llama/Llama-3.1-8B"
MODEL_KEY = "meta-llama/Llama-3.2-3B"

# Instantiate the project's ModelandTokenizer wrapper for the chosen
# checkpoint, requesting float32 weights.
mt = ModelandTokenizer(
    model_key=MODEL_KEY,
    torch_dtype=torch.float32,
)

2024-11-01 16:26:26 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.77s/it]

2024-11-01 16:26:30 src.models INFO     loaded model </home/local_arnab/Codes/00_MODEL/meta-llama/Llama-3.2-3B> | size: 12255.675 MB | dtype: torch.float32 | device: cuda:0





### Loading Data

In [4]:
from src.dataset_manager import DatasetManager

# Show the names of the dataset groups that DatasetManager exposes.
list(DatasetManager.list_datasets_by_group().keys())

2024-11-01 16:27:11 numexpr.utils INFO     Note: NumExpr detected 24 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-11-01 16:27:11 numexpr.utils INFO     NumExpr defaulting to 8 threads.
2024-11-01 16:27:11 datasets INFO     PyTorch version 2.5.0 available.


['geometry_of_truth',
 'relations',
 'sst2',
 'md_gender',
 'snli',
 'ag_news',
 'ner',
 'tense',
 'language_identification',
 'singular_plural']

In [11]:
from src.dataset_manager import DatasetManager
from src.functional import get_batch_concept_activations

# Dataset group to load. Other groups this cell has been pointed at:
#   "geometry_of_truth", "relations", "sst2", "md_gender", "singular_plural",
#   "ag_news", "language_identification", "ner"
group_name = "tense"

# The two groups named below get encodings padded/truncated to a fixed
# max_length of 200; every other group is padded to the longest item.
# NOTE(review): tokenization_kwargs is not consumed anywhere in this cell -
# confirm it is picked up further down the notebook.
tokenization_kwargs = (
    {"padding": "max_length", "max_length": 200, "truncation": True}
    if group_name in ("language_identification", "ag_news")
    else {"padding": "longest"}
)

dataloader = DatasetManager.from_dataset_group(group=group_name, batch_size=32)

# Alternative: load specific named datasets instead of a whole group.
# dataloader = DatasetManager.from_named_datasets(
#     [(group_name, "sp_en_trans")],
#     batch_size=32
# )

# batch = next(iter(dataloader))
# batch

In [12]:
# Pull the first batch from the dataloader and display it as the cell output.
batch = next(iter(dataloader))
batch

[ContextQASample(context='The scientists will be conducting experiments to find a cure for a rare disease.', questions=['# Would you say this is written in the future tense?', '# This statement is in the past tense. Do you agree?', '# Can we classify this sentence as being in the future tense?', '# Am I correct in saying this is in the future tense?', '# Is the action taking place in the future tense?', '# Is the narrative presented in the future tense?', '# Is the action described here in the future tense?', '# Can we classify this sentence as being in the past tense?', '# Is the narrative presented in the future tense?', '# Is the verb form in this sentence present?'], answers=['Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No'], ds_label='future'),
 ContextQASample(context='She is eating lunch now.', questions=['# Is the action described here in the past tense?', '# Is the narrative presented in the past tense?', '# Does this text reflect the past tense?', '# Is the a

In [13]:
from src.dataset_manager import ContextQASample
import random

def get_query_and_ans_from_contextQA(context_qa: ContextQASample):
    """Turn one sample into a single (prompt, answer) pair.

    One question/answer pair is drawn at random from the sample. Every "#"
    character is removed from the question and surrounding whitespace is
    stripped before it is placed after the quoted context.

    Args:
        context_qa: Sample providing `context`, `questions` and `answers`.

    Returns:
        Tuple `(prompt, answer)` where the prompt has the form
        `"<context>" - <question>, Answer:`.
    """
    qa_pairs = list(zip(context_qa.questions, context_qa.answers))
    question, answer = random.choice(qa_pairs)
    question = question.replace("#", "").strip()

    prompt = f"\"{context_qa.context}\" - {question}, Answer:"
    return prompt, answer


def get_query_ans_labels_from_batch(batch: list[ContextQASample], n_icl: int = 3):
    """Build few-shot prompts for a batch of samples.

    Each sample is converted to a (query, answer) pair, the pairs are
    shuffled, and the first `n_icl` of them become in-context examples
    ("<query> <answer>", one per line). Every remaining pair yields one
    prompt consisting of the in-context examples followed by its own query.

    Args:
        batch: Samples to convert.
        n_icl: Number of pairs to set aside as in-context examples.
            Defaults to 3, the value this function previously hard-coded.

    Returns:
        List of `(prompt, answer)` tuples, one per pair that was not used as
        an in-context example. Empty when the batch has `n_icl` or fewer
        samples.

    Raises:
        ValueError: If `n_icl` is negative.
    """
    if n_icl < 0:
        raise ValueError(f"n_icl must be non-negative, got {n_icl}")

    batch_qa = [get_query_and_ans_from_contextQA(cqa) for cqa in batch]
    random.shuffle(batch_qa)

    icl_lines = [f"{q} {a}" for q, a in batch_qa[:n_icl]]
    remaining = batch_qa[n_icl:]

    # Joining the example lines and the query together keeps the original
    # "<examples>\n<query>" layout and avoids a leading newline when there
    # are no in-context examples (n_icl == 0).
    return [("\n".join([*icl_lines, q]), a) for q, a in remaining]

# Build the few-shot prompts for the sampled batch and inspect the first one:
# the prompt text, then its expected answer on the next line.
batch_qa = get_query_ans_labels_from_batch(batch)
query, ans = batch_qa[0]
print(query, ans, sep="\n")

"Before she got the promotion, she had been working hard to prove her capabilities." - Is the action taking place in the past tense?, Answer: Yes
"She had been studying all night before the exam." - Would you identify this as an example of the present tense?, Answer: No
"He will have learned to play the guitar within five years." - Is the action taking place in the future tense?, Answer: Yes
"We hike in the mountains." - Is this statement in the present tense?, Answer:
Yes


In [8]:
from src.functional import predict_next_token
from tqdm import tqdm

def check_model_performance(mt, dataloader, limit = 1000, batch_size = 32):
    """Estimate top-1 next-token accuracy of `mt` on few-shot prompts.

    Each batch from `dataloader` is turned into few-shot prompts with
    `get_query_ans_labels_from_batch`; a prediction counts as correct when
    the whitespace-stripped top-1 token equals the whitespace-stripped answer.

    Args:
        mt: Model wrapper forwarded to `predict_next_token`.
        dataloader: Iterable yielding batches of samples.
        limit: Stop after the batch in which the number of scored predictions
            reaches this value. The check runs once per batch, so the final
            count may exceed `limit`.
        batch_size: Batch size forwarded to `predict_next_token`. Defaults to
            32, the value this function previously hard-coded.

    Returns:
        Fraction of correct predictions, or 0.0 when nothing was scored
        (for example an empty dataloader) instead of raising
        ZeroDivisionError.
    """
    correct_predictions = 0
    total_predictions = 0

    pbar = tqdm(dataloader, ncols=0)
    for batch in pbar:
        batch_qa = get_query_ans_labels_from_batch(batch)
        if not batch_qa:
            # No queries were produced for this batch; nothing to score.
            continue

        queries = [q for q, _ in batch_qa]
        answers = [a for _, a in batch_qa]
        predictions = predict_next_token(
            mt = mt,
            inputs = queries,
            batch_size = batch_size,
            k=1
        )

        # NOTE(review): assumes each entry of `predictions` is a sequence whose
        # first item is the top candidate and exposes `.token` - confirm
        # against predict_next_token.
        for pred, ans in zip(predictions, answers):
            if pred[0].token.strip() == ans.strip():
                correct_predictions += 1
            total_predictions += 1

        # Guard the ratio: no prediction may have been scored yet.
        accuracy = correct_predictions / total_predictions if total_predictions else 0.0
        pbar.set_description(f"Accuracy: {accuracy:.2f} ({correct_predictions}/{total_predictions})")
        if total_predictions >= limit:
            break

    return correct_predictions / total_predictions if total_predictions else 0.0

# Score the loaded model on the dataloader, stopping once at least 5000
# predictions have been counted; the returned accuracy is the cell output.
check_model_performance(mt, dataloader, limit = 5000)

  0% 0/11457 [00:00<?, ?it/s]You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Accuracy: 0.72 (3592/5017):   2% 172/11457 [06:34<7:11:01,  2.29s/it]


0.7159657165636835