In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import time, json
import sys
sys.path.append("../")
sys.path.append("../chess_llm_interpretability")
import os
import torch
import numpy as np
from tqdm import tqdm

import logging
from src.utils import logging_utils
from src.utils import env_utils, experiment_utils
from src import functional
import wandb

# Notebook-level logger; inside a notebook `__name__` is "__main__".
logger = logging.getLogger(__name__)

# Send DEBUG-and-above records from every logger to stdout so they show up
# in the cell output, using the project's shared format strings.
logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

# Record the torch / CUDA versions this run was executed with.
logger.info(f"{torch.__version__=}, {torch.version.cuda=}")

  from .autonotebook import tqdm as notebook_tqdm


2024-10-27 19:38:02 __main__ INFO     torch.__version__='2.5.0+cu124', torch.version.cuda='12.4'


In [3]:
from src.models import ModelandTokenizer

# Candidate checkpoints; exactly one MODEL_KEY assignment should be active.
# MODEL_KEY = "meta-llama/Llama-3.2-3B-Instruct"
# MODEL_KEY = "meta-llama/Llama-3.1-8B-Instruct"

MODEL_KEY = "meta-llama/Llama-3.2-3B"
# MODEL_KEY = "google/gemma-2-2b"
# MODEL_KEY = "meta-llama/Llama-3.1-8B"

# NOTE(review): the marker below looks like a reminder about choosing the
# precision adaptively; for now the model is always loaded in float32.
#! torch.adaptive precision
mt = ModelandTokenizer(
    model_key=MODEL_KEY,
    torch_dtype=torch.float32,
)

2024-10-27 19:38:03 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.73s/it]

2024-10-27 19:38:07 src.models INFO     loaded model </home/local_arnab/Codes/00_MODEL/meta-llama/Llama-3.2-3B> | size: 12255.675 MB | dtype: torch.float32 | device: cuda:0





In [4]:
from transformers import AutoModelForCausalLM
from nnsight import LanguageModel

# Directory that holds the finetuned patchscope checkpoints for MODEL_KEY.
tuning_root = os.path.join(
    env_utils.DEFAULT_RESULTS_DIR, MODEL_KEY.split("/")[-1], "patchscope_tuning"
)


def _checkpoint_sort_key(entry_name: str) -> tuple[int, str]:
    """Order `checkpoint-<step>` style names by their numeric step.

    Names without a numeric suffix sort first, with the name itself as a
    tie-breaker so the ordering is always deterministic.
    """
    suffix = entry_name.rsplit("-", 1)[-1]
    return (int(suffix) if suffix.isdecimal() else -1, entry_name)


# os.listdir() returns entries in arbitrary order, so "the last entry" is not
# guaranteed to be the newest checkpoint; sort explicitly before picking one.
checkpoint_names = sorted(os.listdir(tuning_root), key=_checkpoint_sort_key)
if not checkpoint_names:
    raise FileNotFoundError(f"no finetuned checkpoint found under {tuning_root}")

finetuned_path = os.path.join(tuning_root, checkpoint_names[-1])
logger.info(f"loading finetuned patchscope model from {finetuned_path}")

# Place the tuned model on the same device as the base model instead of
# hard-coding "cuda".
tuned_model = AutoModelForCausalLM.from_pretrained(
    finetuned_path, torch_dtype=torch.float32
).to(mt.device)

# Wrap the tuned weights for tracing and reuse the base model's tokenizer.
tuned_lm = LanguageModel(tuned_model)
patchscope = ModelandTokenizer(
    base_lm = tuned_lm,
    tokenizer = mt.tokenizer
)

Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  9.63it/s]


2024-10-27 19:38:10 src.models INFO     loaded model <EleutherAI/gpt-j-6B> | size: 12255.675 MB | dtype: torch.float32 | device: cuda:0


### Utils

In [5]:
import random
from typing import Literal

# Paraphrase pools sampled by `get_latent_qa`:
#   * yes_no.json   -> the whole JSON document is used as the pool
#   * question.json -> only the entry stored under the "GMT" key is used
yes_no_path = os.path.join(env_utils.DEFAULT_DATA_DIR, "paraphrases/yes_no.json")
with open(yes_no_path, "r") as yes_no_file:
    yes_no_paraphrases = json.load(yes_no_file)

question_path = os.path.join(env_utils.DEFAULT_DATA_DIR, "paraphrases/question.json")
with open(question_path, "r") as question_file:
    question_paraphrases = json.load(question_file)["GMT"]

def get_latent_qa(yes_ans, no_ans) -> tuple[str, Literal[" yes", " no"]]:
    """Build a randomly paraphrased yes/no question about a patched latent.

    A target label (" yes" or " no") is drawn first. The question template is
    then filled with `yes_ans` when the label is " yes" and with `no_ans`
    otherwise, so the sampled label is the correct answer to the question.

    Args:
        yes_ans: value substituted into the template for a " yes" question.
        no_ans: value substituted into the template for a " no" question.

    Returns:
        `(prompt, label)`. The prompt starts with "# " followed by the filled
        question, a space, and a paraphrased answer instruction.
    """
    # The three draws stay in this order (label, instruction, template) so a
    # seeded `random` produces the same samples as before.
    label = random.choice([" yes", " no"])
    answer_instruction = random.choice(yes_no_paraphrases)
    template = random.choice(question_paraphrases)

    if label == " yes":
        filled_question = template.format(yes_ans)
    else:
        filled_question = template.format(no_ans)

    return "# " + filled_question + f" {answer_instruction}", label

In [6]:
from dataclasses import dataclass, field, fields
from src.utils.typing import ArrayLike
from typing import Literal
from scripts.patchscope_tuning import LatentSample, LatentSampleBuffer
from src.utils.typing import LatentCache, LatentCacheCollection

def populate_latent_arr(
    lcc: LatentCacheCollection,
    layers_of_interest: "list[int] | None" = None,
):
    """Turn cached latents into patchscope samples, one per (latent, layer).

    For every cached latent and every requested layer, a paraphrased yes/no
    question is generated with `get_latent_qa` and paired with the activation
    stored for that layer.

    Args:
        lcc: collection whose `.latents` entries expose `.answer` and a
            `.latents` mapping keyed by layer name.
        layers_of_interest: layer indices to draw activations from. Defaults
            to layers 10 through 14.

    Returns:
        A list of `LatentSample(activation, prompt, label)` objects.

    Notes:
        Layer names are built from the notebook-global `mt`.
        `answer` is read as the truth value of the cached statement. Both the
        "True"/"False" and the " yes"/" no" spellings used in this notebook
        are accepted; anything else is treated as false, as before.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if layers_of_interest is None:
        layers_of_interest = list(range(10, 15))

    latent_arr = []
    for latent_cache in lcc.latents:
        normalized_answer = str(latent_cache.answer).lower().strip()
        # Previously only "true" counted as truthy, so " yes" labels were
        # silently turned into "false" and produced inverted questions.
        statement_is_true = normalized_answer in ("true", "yes")
        yes_ans = "true" if statement_is_true else "false"
        no_ans = "false" if statement_is_true else "true"

        for layer_idx in layers_of_interest:
            layer_name = mt.layer_name_format.format(layer_idx)
            activation = latent_cache.latents[layer_name]
            # A fresh paraphrase/label is drawn for every layer.
            prompt, label = get_latent_qa(yes_ans, no_ans)
            latent_arr.append(LatentSample(
                activation=activation,
                prompt=prompt,
                label=label,
            ))
    return latent_arr

2024-10-27 19:38:10 numexpr.utils INFO     Note: NumExpr detected 24 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-10-27 19:38:10 numexpr.utils INFO     NumExpr defaulting to 8 threads.
2024-10-27 19:38:10 datasets INFO     PyTorch version 2.5.0 available.
2024-10-27 19:38:10 scripts.patchscope_tuning INFO     torch.__version__='2.5.0+cu124', torch.version.cuda='12.4'
2024-10-27 19:38:10 scripts.patchscope_tuning INFO     torch.cuda.is_available()=True, torch.cuda.device_count()=1, torch.cuda.get_device_name()='NVIDIA RTX A6000'
2024-10-27 19:38:10 scripts.patchscope_tuning INFO     transformers.__version__='4.46.0'


In [7]:
from src.tokens import prepare_input, find_token_range
from src.functional import interpret_logits, get_module_nnsight

def prepare_batch_input(batch: list[LatentSample], mt: ModelandTokenizer):
    """Tokenize a batch of prompts and locate the `#` placeholder in each one.

    Args:
        batch: samples whose `.prompt` strings are tokenized together.
        mt: model/tokenizer wrapper forwarded to the tokenization helpers.

    Returns:
        `(batch_tokenized, int_tok_idx)`: the tokenized batch with its
        "offset_mapping" entry removed, and one token index per sample for the
        first `#` in its prompt (range end minus one; this assumes
        `find_token_range` returns an exclusive end -- TODO confirm).
    """
    prompts = [sample.prompt for sample in batch]
    tokenized = prepare_input(
        prompts=prompts,
        tokenizer=mt,
        return_offset_mapping=True
    )

    def _placeholder_index(prompt_text, offsets):
        # Token span of the first "#" occurrence in this prompt.
        span = find_token_range(
            string=prompt_text,
            substring="#",
            occurrence=0,
            tokenizer=mt,
            offset_mapping=offsets
        )
        return span[1] - 1

    int_tok_idx = [
        _placeholder_index(sample.prompt, tokenized["offset_mapping"][row])
        for row, sample in enumerate(batch)
    ]

    # The offsets were only needed to find the placeholder; drop them before
    # the batch is handed to the model.
    tokenized.pop("offset_mapping")

    return tokenized, int_tok_idx

In [8]:
@torch.inference_mode()
def evaluate(batch: list[LatentSample], mt: ModelandTokenizer):
    """Patch each sample's activation into the model and score its answer.

    For every sample, the cached activation overwrites the output of every
    layer in `mt.layer_names` at the sample's `#` placeholder position. The
    top next-token prediction at the last position is then compared
    (case-insensitively, ignoring surrounding whitespace) with the label.

    Args:
        batch: samples providing `.prompt`, `.activation` and `.label`.
        mt: model/tokenizer wrapper to evaluate (base model or patchscope).

    Returns:
        Fraction of samples whose top prediction matches the label.

    Raises:
        ValueError: if `batch` is empty (accuracy would be undefined).
    """
    if len(batch) == 0:
        raise ValueError("evaluate() requires a non-empty batch")

    batch_tokenized, int_tok_idx = prepare_batch_input(batch, mt)
    logger.debug(f"{batch_tokenized.input_ids.shape=}")

    # Convert each activation once, before tracing. `torch.tensor(tensor)`
    # copy-constructs (and warns) and was re-run for every layer; as_tensor
    # reuses the data when possible and casts to the model dtype, matching the
    # training loop.
    patches = [
        torch.as_tensor(sample.activation, device=mt.device, dtype=mt.dtype)
        for sample in batch
    ]

    with mt.trace(batch_tokenized):
        module_names = mt.layer_names
        # Loop nesting (sample outer, layer inner) is kept as in the original.
        for idx, (patch, int_tok) in enumerate(zip(patches, int_tok_idx)):
            for module_name in module_names:
                module = get_module_nnsight(mt, module_name)
                module.output[0][idx, int_tok, :] = patch
        # Only the last-position logits are needed; the full model output is
        # no longer saved.
        last_logits = [
            mt.output.logits[idx, -1, :].save()
            for idx in range(len(batch))
        ]

    last_logits = torch.stack(last_logits)

    # Keep only the top-ranked candidate of each sample.
    predicted_labels = [
        interpret_logits(
            tokenizer=mt,
            logits=last_logits[idx],
            k = 2,
        )[0]
        for idx in range(len(batch))
    ]

    correct_labels = [b.label for b in batch]
    correct_count = 0

    for pred, correct in zip(predicted_labels, correct_labels):
        print(f"{str(pred)=} | {correct=}")
        if pred.token.strip().lower() == correct.strip().lower():
            correct_count += 1

    return correct_count / len(batch)

### Check Performance

In [9]:
from src.functional import get_concept_latents

# Hand-written sanity set of (statement, label) pairs.
test = [
    ("The sky is blue.",  " yes"),
    ("The sky is green.", " no"),
    ("The sun rises in the north.", " no"),
    ("Michael Jordan used to play Cricket.", " no"),
    ("The capital of France is Paris.", " yes"),
    ("The capital of France is Berlin.", " no"),
]

# Layers 10..14, spelled the way the model wrapper names its layers.
probe_layer_names = [
    mt.layer_name_format.format(layer_idx) for layer_idx in range(10, 15)
]

activations = get_concept_latents(
    mt=mt,
    queries=test,
    interested_layers=probe_layer_names,
    check_answer=False,
)

# Wrap the raw latents, then expand them into one sample per (query, layer).
lcc = LatentCacheCollection(latents=activations)
latent_arr = populate_latent_arr(lcc)

  0%|          | 0/6 [00:00<?, ?it/s]You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


100%|██████████| 6/6 [00:02<00:00,  2.18it/s]

2024-10-27 19:38:16 src.functional DEBUG    Collected 6 latents, out of 6





In [10]:
latent_arr[5]

LatentSample(activation=tensor([-0.1089, -0.2079,  0.2880,  ..., -0.2096,  0.1462, -0.1288],
       device='cuda:0'), prompt='# Is this statement true? Answer yes or no.\nAnswer:', label=' no')

In [11]:
buffer = LatentSampleBuffer(latent_arr, batch_size=32)
len(buffer)

30

In [12]:
batch = next(iter(buffer))
len(batch)

30

In [13]:
evaluate(batch, mt)  # Evaluate the batch

2024-10-27 19:38:26 __main__ DEBUG    batch_tokenized.input_ids.shape=torch.Size([30, 21])


  output = self.target(*args, **kwargs)


str(pred)='" yes" (p=0.186)' | correct=' no'
str(pred)='" No" (p=0.139)' | correct=' yes'
str(pred)='" Yes" (p=0.310)' | correct=' no'
str(pred)='" Yes" (p=0.243)' | correct=' no'
str(pred)='" no" (p=0.186)' | correct=' yes'
str(pred)='" No" (p=0.220)' | correct=' no'
str(pred)='" Yes" (p=0.131)' | correct=' yes'
str(pred)='" yes" (p=0.268)' | correct=' no'
str(pred)='" yes" (p=0.208)' | correct=' yes'
str(pred)='" Yes" (p=0.237)' | correct=' no'
str(pred)='" Yes" (p=0.216)' | correct=' no'
str(pred)='" yes" (p=0.204)' | correct=' yes'
str(pred)='" yes" (p=0.318)' | correct=' no'
str(pred)='" no" (p=0.169)' | correct=' yes'
str(pred)='" This" (p=0.187)' | correct=' yes'
str(pred)='" Yes" (p=0.232)' | correct=' yes'
str(pred)='" Yes" (p=0.323)' | correct=' no'
str(pred)='" Yes" (p=0.211)' | correct=' no'
str(pred)='" This" (p=0.353)' | correct=' yes'
str(pred)='" Yes" (p=0.321)' | correct=' no'
str(pred)='" Yes" (p=0.167)' | correct=' yes'
str(pred)='" Yes" (p=0.159)' | correct=' yes'
s

0.26666666666666666

In [14]:
patchscope.name = f"Patchscope_{MODEL_KEY.split('/')[-1]}"
patchscope.name

'Patchscope_Llama-3.2-3B'

In [15]:
evaluate(batch, patchscope)  # Evaluate the batch again

2024-10-27 19:38:30 __main__ DEBUG    batch_tokenized.input_ids.shape=torch.Size([30, 21])


str(pred)='" yes" (p=0.963)' | correct=' no'
str(pred)='" yes" (p=0.973)' | correct=' yes'
str(pred)='" no" (p=0.749)' | correct=' no'
str(pred)='" no" (p=0.832)' | correct=' no'
str(pred)='" no" (p=0.987)' | correct=' yes'
str(pred)='" no" (p=1.000)' | correct=' no'
str(pred)='" yes" (p=1.000)' | correct=' yes'
str(pred)='" no" (p=1.000)' | correct=' no'
str(pred)='" yes" (p=1.000)' | correct=' yes'
str(pred)='" no" (p=1.000)' | correct=' no'
str(pred)='" no" (p=0.998)' | correct=' no'
str(pred)='" yes" (p=1.000)' | correct=' yes'
str(pred)='" no" (p=0.992)' | correct=' no'
str(pred)='" yes" (p=1.000)' | correct=' yes'
str(pred)='" yes" (p=1.000)' | correct=' yes'
str(pred)='" yes" (p=1.000)' | correct=' yes'
str(pred)='" no" (p=0.996)' | correct=' no'
str(pred)='" no" (p=0.998)' | correct=' no'
str(pred)='" yes" (p=0.999)' | correct=' yes'
str(pred)='" no" (p=1.000)' | correct=' no'
str(pred)='" yes" (p=1.000)' | correct=' yes'
str(pred)='" yes" (p=0.995)' | correct=' yes'
str(pred)=

0.9

In [25]:
# Sanity check on ordinary text: greedy-decode (do_sample=False) 10 new
# tokens from the base model.
prompt = "The Space Needle is located"
inputs = prepare_input(prompt, mt)

with torch.inference_mode():
    output = mt._model.generate(**inputs, max_new_tokens=10, do_sample=False)

print(mt.tokenizer.decode(output[0], skip_special_tokens=True))

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


The Space Needle is located in Seattle, Washington. It is a 605


In [27]:
# Same greedy decoding with the finetuned patchscope model. `inputs` is the
# tokenized prompt built in the base-model generation cell, so that cell must
# have been run first.
with torch.inference_mode():
    output = patchscope._model.generate(**inputs, max_new_tokens=10, do_sample=False)

print(patchscope.tokenizer.decode(output[0], skip_special_tokens=True))

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.


The Space Needle is located yes yes yes no no no no no no no


### LoRA (check later)

In [4]:
# from peft import LoraConfig, get_peft_model

# lora_config = LoraConfig(
#     r=16,
#     lora_alpha=32,
#     target_modules=["q_proj", "v_proj"],
#     lora_dropout=0.1,
#     bias="none",
#     task_type="CAUSAL_LM",
# )

# model = get_peft_model(mt._model, lora_config)

In [5]:
# type(model), type(model.model)

In [6]:
# for p in model.model.named_parameters():
#     print(p[0])

### Dataset Preparation

In [9]:
from src.dataset import GMTDataset

# Build the dataset from the listed CSV file(s); the second argument is the
# dataset name (reused later as the cache file name via `ds.name`).
ds = GMTDataset.from_csv(
    [
        # "sp_en_trans.csv", 
        "cities.csv"
    ], 
    # "sp_en_trans"
    "cities"
)

# Request zero few-shot examples.
ds.select_few_shot(0)

# Materialize every example; each one unpacks into a (query, answer) pair.
queries = [ds.examples[i] for i in range(len(ds))]
# Latents are collected at every layer of the model.
interested_layers = mt.layer_names
# Show the first example as a quick look at the data.
q, a = queries[0]
print(q)
print(a)

2024-10-27 13:20:38 src.dataset INFO     initialized cities with 1493 examples.
The city of Moradabad is in India.
True


In [None]:
from src.functional import get_concept_latents

# Not needed if the results are already cached: the "loading the latents"
# cell restores them from disk instead.
# `queries` and `interested_layers` come from the dataset-preparation cell.

latents = get_concept_latents(
    mt=mt, 
    queries=queries, 
    interested_layers=interested_layers,
    check_answer=False,
)

In [10]:
# Per-model cache location: <DEFAULT_RESULTS_DIR>/cached_latents/<model name>,
# where the model name is the last path component of MODEL_KEY.
latent_dir = os.path.join(
    env_utils.DEFAULT_RESULTS_DIR, 
    "cached_latents",
    MODEL_KEY.split("/")[-1],
)
# exist_ok=True keeps this cell safe to re-run.
os.makedirs(latent_dir, exist_ok=True)

In [8]:
# caching the latents

from src.utils.typing import LatentCacheCollection
from src.utils import env_utils

# Requires `latents` from the extraction cell and `latent_dir` / `ds` from the
# dataset-preparation cells.
lcc = LatentCacheCollection(latents=latents)
# Called for its side effect on `lcc` (the return value is unused);
# presumably converts tensors into JSON-serializable values -- TODO confirm.
lcc.detensorize()
# One JSON file per dataset, named after `ds.name`.
with open(os.path.join(latent_dir, f"{ds.name}.json"), "w") as f:
    f.write(lcc.to_json())

In [11]:
# loading the latents

from src.utils.typing import LatentCacheCollection
from src.utils import env_utils

# Read back the JSON file written by the caching cell for this dataset.
with open(os.path.join(latent_dir, f"{ds.name}.json"), "r") as f:
    dct = json.load(f)
# Rebuild the collection from the plain dict ...
lcc = LatentCacheCollection.from_dict(dct)
# ... and restore its contents on the model's device (presumably the inverse
# of `detensorize()` -- TODO confirm).
lcc.retensorize(device=mt.device)

In [5]:
# q, a = queries[10]
# print(q, a)
# yes_ans = "true" if str(a).lower().strip() == "true" else "false"
# no_ans = "false" if yes_ans == "true" else "true"
# latent_q = get_latent_qa(yes_ans, no_ans)
# print(latent_q)

### Patchscope Finetuning

In [None]:
# Fix the seed before anything stochastic happens in training.
experiment_utils.set_seed(42)
# `model` is the underlying model wrapped by `mt`; it is the object that gets
# optimized and saved below.
model = mt._model
model.train()
device = mt.device

# Training parameters
learning_rate = 5e-5

# NOTE(review): the checkpoint-loading cell near the top of this notebook
# reads from <DEFAULT_RESULTS_DIR>/<model name>/patchscope_tuning, whereas
# this path has no model-name component -- confirm which layout is intended.
model_save_dir = os.path.join(env_utils.DEFAULT_RESULTS_DIR, "patchscope_tuning")
os.makedirs(model_save_dir, exist_ok=True)
log_steps = 1  # log every `log_steps` optimizer steps
checkpoint_interval = 100  # save a checkpoint every this many steps
num_warmup_steps = 30  # upper bound on the scheduler's warmup steps
limit_training_steps = 1000  # upper bound on steps; capped by buffer size later
##############################################################################

In [20]:
import shutil
def remove_dir(path):
    """Recursively delete `path` if it exists; do nothing when it is absent."""
    if not os.path.exists(path):
        return
    shutil.rmtree(path)

# remove_dir(model_save_dir)
remove_dir(".wandb")

In [None]:
import math

import baukit
from transformers import get_linear_schedule_with_warmup

# Number of batches in one pass over the buffer. Ceiling division is needed
# because the buffer also yields a trailing partial batch; floor division
# dropped it and produced zero training steps whenever the buffer held fewer
# samples than one batch.
buffer_steps = math.ceil(len(buffer) / buffer.batch_size)
limit_training_steps = min(
    limit_training_steps,
    buffer_steps
)
logger.info(f"{limit_training_steps=}")

# Optimizer: only the parameters of the last 10 layers are handed to AdamW.
# NOTE(review): the remaining parameters are not frozen here, so backward
# still computes gradients for them -- TODO confirm whether requires_grad
# should be disabled on the layers that are not tuned.
tunable_params = []

for layer_name in mt.layer_names[-10:]:
    module = baukit.get_module(model, layer_name)
    tunable_params.extend(module.parameters())

optimizer = torch.optim.AdamW(tunable_params, lr=learning_rate)

# Linear warmup followed by linear decay; warmup is at most 10% of the run.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=min(num_warmup_steps, limit_training_steps // 10),
    num_training_steps=limit_training_steps,
)

loss_func = torch.nn.CrossEntropyLoss()

In [None]:
from src.functional import free_gpu_cache

# wandb
wandb_logging = False

if wandb_logging:
    wandb.init(
        entity="dl-homeworks",
        project="talkative_probes",
        name=f"{MODEL_KEY}_finetune",
        config={
            "model_key": MODEL_KEY,
            "learning_rate": learning_rate,
            "wandb_log_interval": log_steps,
            "checkpoint_interval": checkpoint_interval,
            "num_warmup_steps": num_warmup_steps,
            "batch_size": buffer.batch_size,
        }
    )


mt._model.train()
for step in tqdm(range(limit_training_steps), desc="Training"):
    optimizer.zero_grad()

    try:
        batch = buffer.next_batch()
    except StopIteration:
        # Buffer exhausted: rewind it and start another pass.
        buffer.idx = 0
        batch = buffer.next_batch()

    batch_tokenized, int_tok_idx = prepare_batch_input(batch, mt)
    logger.info(batch_tokenized.input_ids.shape)

    # Build each patch once, outside the trace: a detached copy on the model's
    # device and dtype. The original copy-constructed it with torch.tensor()
    # again for every layer.
    activations = [
        torch.as_tensor(b.activation)
        .detach()
        .clone()
        .to(device=mt.device, dtype=mt.dtype)
        for b in batch
    ]

    with mt.trace(batch_tokenized):
        module_names = mt.layer_names # replace the latent on all the residual layers
        for idx, (act, int_tok) in enumerate(zip(activations, int_tok_idx)):
            for module_name in module_names:
                module = get_module_nnsight(mt, module_name)
                module.output[0][idx, int_tok, :] = act

        # Only the logits at the last position of each prompt are needed.
        last_logits = [
            mt.output.logits[idx, -1, :].save()
            for idx in range(len(batch))
        ]

    last_logits = torch.stack(last_logits)
    # Target = last token id of the tokenized label string.
    batch_labels = torch.tensor(
        [mt.tokenizer(b.label).input_ids[-1] for b in batch],
        device=mt.device,
    )

    # Cross-entropy loss
    patchscope_loss = loss_func(last_logits, batch_labels)

    # TODO: include natural text and generation loss
    loss = patchscope_loss

    loss.backward()
    optimizer.step()
    scheduler.step()

    free_gpu_cache()

    if (step + 1) % log_steps == 0:
        log_data = {
            "loss": loss.item(),
            "learning_rate": scheduler.get_last_lr()[0],
        }
        logger.info(f"Step {step + 1}: {log_data}")
        if wandb_logging:
            wandb.log(log_data)

    if ((step + 1) % checkpoint_interval == 0) or (step + 1) == limit_training_steps:
        new_checkpoint_path = os.path.join(model_save_dir, f"checkpoint-{step + 1}")
        # Save first, prune afterwards: a failed save can no longer leave the
        # directory without any checkpoint. This briefly needs disk space for
        # two checkpoints.
        model.save_pretrained(new_checkpoint_path)

        # Keep only the checkpoint just written. The stale ones are selected
        # by name, because os.listdir() order is arbitrary and its "last"
        # entry is not reliably the previous checkpoint.
        for entry in os.listdir(model_save_dir):
            entry_path = os.path.join(model_save_dir, entry)
            is_stale_checkpoint = (
                entry.startswith("checkpoint-")
                and entry_path != new_checkpoint_path
                and os.path.isdir(entry_path)
            )
            if is_stale_checkpoint:
                remove_dir(entry_path)


if wandb_logging:
    # Close the run so it is flushed and marked as finished.
    wandb.finish()

print("Training completed!")