In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import torch
import logging

sys.path.append("../")
import src.utils.logging_utils as logging_utils
import src.functional as functional
import src.models as models
import src.tokens as tokens
import src.dataset as dataset
import src.patchscope_utils as patchscope_utils

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.INFO,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

logger.info(f"{torch.__version__=}, {torch.version.cuda=}")

2024-10-24 21:23:33 __main__ INFO     torch.__version__='2.5.0+cu124', torch.version.cuda='12.4'


In [3]:
# MODEL_KEY = "meta-llama/Llama-3.2-3B-Instruct"
MODEL_KEY = "meta-llama/Llama-3.1-8B-Instruct"
# MODEL_KEY = "google/gemma-2-9b-it"

# MODEL_KEY = "google/gemma-2-2b"
# MODEL_KEY = "meta-llama/Llama-3.2-3B"
# MODEL_KEY = "meta-llama/Llama-3.1-8B"

mt = models.ModelandTokenizer(
    model_key=MODEL_KEY,
    torch_dtype=torch.bfloat16,
)

If not found in cache, model will be downloaded from HuggingFace to cache directory


2024-10-24 21:23:35 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

2024-10-24 21:23:40 src.models INFO     loaded model <meta-llama/Llama-3.1-8B-Instruct> | size: 15316.516 MB | dtype: torch.bfloat16 | device: cuda:0


In [5]:
true_prompt = "The city of Paris is in the country of France."
false_prompt = "The city of Paris is in the country of Italy."
true_input, false_input = [tokens.prepare_input(p, mt) for p in (true_prompt, false_prompt)]

In [None]:
mt.n_layer

32

In [18]:
# layers = [24]
layers = list(range(3, 30))

true_h, false_h = [
    patchscope_utils.get_h_layers(mt, input_, layers) for input_ in (true_input, false_input)
]

In [12]:
true_token, false_token = [mt.tokenizer.encode(t)[-1] for t in ("True", "False")]
[mt.tokenizer.decode(t) for t in (true_token, false_token)]

['True', 'False']

In [4]:
base_prompt = '''The city of Tokyo is in Japan. This statement is: True
The city of Hanoi is in Poland. This statement is: False
placeholder placeholder placeholder placeholder placeholder This statement is: '''

instruct_prompt = '''<|start_header_id|>user<|end_header_id|>

True or false: placeholder placeholder placeholder placeholder placeholder<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>

'''

target_prompt = {
    "meta-llama/Llama-3.2-3B-Instruct" : instruct_prompt,
    "meta-llama/Llama-3.1-8B-Instruct" : instruct_prompt,
    "meta-llama/Llama-3.2-3B" : base_prompt,
    "meta-llama/Llama-3.1-8B" : base_prompt,
}[MODEL_KEY]

In [21]:
for prompt, h in zip((true_prompt, false_prompt), (true_h, false_h)):
    print(prompt)
    _, result_dict = functional.patchscope(
        mt = mt, 
        hs = patchscope_utils.get_h_with_target_layer(h, 3) if len(h) == 1 else h,
        target_prompt = target_prompt,
        interested_tokens = (true_token, false_token),
        k = 5)
    for t in (true_token, false_token):
        print("   ", result_dict[t])


The city of Paris is in the country of France.
    (393, PredictedToken(token='True', prob=8.232022810261697e-05, logit=8.4375, token_id=2575))
    (413, PredictedToken(token='False', prob=7.733269012533128e-05, logit=8.375, token_id=4139))
The city of Paris is in the country of Italy.
    (295, PredictedToken(token='True', prob=0.00011337664182065055, logit=8.4375, token_id=2575))
    (268, PredictedToken(token='False', prob=0.00012847254402004182, logit=8.5625, token_id=4139))


In [15]:
gmt_dataset = dataset.GMTDataset.from_csv("cities.csv", few_shot=False).examples
len(gmt_dataset)

In [None]:
evaluation_results = []

for dataset_name in dataset.GMT_DATA_FILES:
    examples = dataset.GMTDataset.from_csv(dataset_name, few_shot=False).examples[:100]
    evaluation_config = patchscope_utils.EvaluationConfig(
        model_key=MODEL_KEY,
        dataset=dataset_name,
        target_prompt=target_prompt,
        label_to_token={ True: "True", False: "False" },
        patchscope_config=patchscope_utils.PatchscopeConfig(
            source_layers=list(range(3, 30)),
            target_layer=None
        )
    )
    evaluation_runner = patchscope_utils.EvaluationRunner(mt, evaluation_config)
    evaluation_result = evaluation_runner.evaluate(examples)
    # evaluation_result = evaluation_runner.evaluate([(true_prompt, True), (false_prompt, False)])
    evaluation_results.append(evaluation_result)
    print(evaluation_result)

  0%|                                                                                          | 0/100 [00:00<?, ?it/s]

  1%|▊                                                                                 | 1/100 [00:02<03:53,  2.36s/it]