In [1]:
# Re-import edited modules (e.g. under src/) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import time, json
import sys
sys.path.append("../")
sys.path.append("../chess_llm_interpretability")
import os
import torch
import numpy as np
from tqdm import tqdm

import logging
from src.utils import logging_utils
from src.utils import env_utils, experiment_utils
from src import functional
import wandb

# Send every record (DEBUG and above) to stdout so it shows up in the cell
# output, formatted with the project's shared format strings.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    format=logging_utils.DEFAULT_FORMAT,
)
logger = logging.getLogger(__name__)

# Record the torch / CUDA build so results can be tied to an environment.
logger.info(f"{torch.__version__=}, {torch.version.cuda=}")

  from .autonotebook import tqdm as notebook_tqdm


2024-10-25 17:30:30 __main__ INFO     torch.__version__='2.5.0+cu124', torch.version.cuda='12.4'


In [3]:
from src.models import ModelandTokenizer

# Model under study. Alternative model keys for this cell:
#   "meta-llama/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.1-8B-Instruct",
#   "google/gemma-2-2b", "meta-llama/Llama-3.1-8B"
MODEL_KEY = "meta-llama/Llama-3.2-3B"

# Load weights in full precision (float32).
mt = ModelandTokenizer(torch_dtype=torch.float32, model_key=MODEL_KEY)

2024-10-25 17:30:32 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.73s/it]

2024-10-25 17:30:35 src.models INFO     loaded model </home/local_arnab/Codes/00_MODEL/meta-llama/Llama-3.2-3B> | size: 12255.675 MB | dtype: torch.float32 | device: cuda:0





### LoRA (check later)

In [4]:
# from peft import LoraConfig, get_peft_model

# lora_config = LoraConfig(
#     r=16,
#     lora_alpha=32,
#     target_modules=["q_proj", "v_proj"],
#     lora_dropout=0.1,
#     bias="none",
#     task_type="CAUSAL_LM",
# )

# model = get_peft_model(mt._model, lora_config)

In [5]:
# type(model), type(model.model)

In [6]:
# for p in model.model.named_parameters():
#     print(p[0])

### Dataset Preparation

In [5]:
from src.dataset import GMTDataset

# Spanish->English translation statements, each paired with a True/False label.
ds = GMTDataset.from_csv(
    ["sp_en_trans.csv"],  # optionally add "cities.csv" to mix in another split
    "sp_en_trans",
)

# Presumably: zero in-context demonstrations per query.
ds.select_few_shot(0)

queries = [ds.examples[k] for k in range(len(ds))]
interested_layers = mt.layer_names

# Show the first (statement, label) pair.
q, a = queries[0]
print(q, a, sep="\n")

2024-10-25 17:30:42 src.dataset INFO     initialized sp_en_trans with 348 examples.
The Spanish word 'verano' means 'to push'.
False


In [6]:
# Peek at one raw (statement, label) pair.
ds.examples[4]

("The Spanish word 'lobo' means 'leg'.", False)

In [7]:
from src.functional import get_concept_latents

# Collect hidden states at every layer for every query (~4 min for 351
# queries). check_answer=False presumably disables filtering on whether the
# model answers correctly; the log confirms all 351 latents are kept.
latents = get_concept_latents(
    check_answer=False,
    interested_layers=interested_layers,
    queries=queries,
    mt=mt,
)

  0%|          | 0/351 [00:00<?, ?it/s]You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████| 351/351 [04:14<00:00,  1.38it/s]

2024-10-25 17:35:12 src.functional DEBUG    Collected 351 latents, out of 351





In [8]:
from src.utils.typing import LatentCacheCollection
from src.utils import env_utils

# One JSON file per dataset, under <results>/cached_latents/<bare model name>/.
latent_dir = os.path.join(env_utils.DEFAULT_RESULTS_DIR, "cached_latents", MODEL_KEY.rpartition("/")[-1])
os.makedirs(latent_dir, exist_ok=True)

# detensorize() presumably turns the cached tensors into plain Python data so
# to_json() can serialize them; the reload cell reverses it with retensorize().
lcc = LatentCacheCollection(latents=latents)
lcc.detensorize()
with open(os.path.join(latent_dir, f"{ds.name}.json"), "w") as f:
    f.write(lcc.to_json())

In [9]:
# Reload the cached latents. The cache path and the collection class are
# resolved here (same recipe as the caching cell), so this cell works on a
# fresh kernel without re-running the ~4-minute extraction and caching cells.
from src.utils.typing import LatentCacheCollection

latent_dir = os.path.join(
    env_utils.DEFAULT_RESULTS_DIR,
    "cached_latents",
    MODEL_KEY.split("/")[-1],
)
with open(os.path.join(latent_dir, f"{ds.name}.json"), "r") as f:
    dct = json.load(f)
lcc = LatentCacheCollection.from_dict(dct)
# Presumably converts the stored lists back into tensors on the model's device.
lcc.retensorize(device=mt.device)

In [104]:
import random
from typing import Literal

# Paraphrase pools used to randomize the latent-QA prompts:
#   yes_no.json   -> JSON list of answer-format suffixes (e.g. "(yes/no) Answer:")
#   question.json -> JSON object whose "GMT" entry is a list of question templates
with open(os.path.join(env_utils.DEFAULT_DATA_DIR, "paraphrases/yes_no.json"), "r") as f:
    yes_no_paraphrases = json.loads(f.read())

with open(os.path.join(env_utils.DEFAULT_DATA_DIR, "paraphrases/question.json"), "r") as f:
    question_paraphrases = json.loads(f.read())["GMT"]

def get_latent_qa(
    yes_ans,
    no_ans,
    yes_no_options=None,
    question_options=None,
) -> tuple[str, Literal[" yes", " no"]]:
    """Build a randomized yes/no question about a statement's truth value.

    A target label (" yes" / " no") is drawn uniformly at random. The question
    template is filled with ``yes_ans`` when the label is " yes" and with
    ``no_ans`` otherwise, so the label is always the correct answer.

    Args:
        yes_ans: Value that makes the answer "yes" (e.g. the statement's
            actual truth value, "true" or "false").
        no_ans: Value that makes the answer "no".
        yes_no_options: Pool of answer-format suffixes appended after the
            question. Defaults to the module-level ``yes_no_paraphrases``.
        question_options: Pool of templates filled via ``str.format`` with one
            positional argument. Defaults to the module-level
            ``question_paraphrases``.

    Returns:
        ``(prompt, label)``. ``prompt`` starts with ``"# "``; the ``#`` token
        is where ``prepare_batch_input`` locates the patch position. ``label``
        is ``" yes"`` or ``" no"``.
    """
    if yes_no_options is None:
        yes_no_options = yes_no_paraphrases
    if question_options is None:
        question_options = question_paraphrases

    # Keep the draw order (label, suffix, template) so seeded runs reproduce
    # the same prompts as before.
    label = random.choice([" yes", " no"])
    yes_no = random.choice(yes_no_options)
    template = random.choice(question_options)
    question = template.format(yes_ans if label == " yes" else no_ans)
    return f"# {question} {yes_no}", label



# Sanity-check the question generator on one dataset example.
q, a = queries[10]
print(q, a)
yes_ans, no_ans = ("true", "false") if str(a).lower().strip() == "true" else ("false", "true")
latent_q = get_latent_qa(yes_ans, no_ans)
print(latent_q)

The Spanish word 'continuar' means 'coat'. False
('# Is this false? (yes/no) Answer:', ' yes')


In [105]:
from dataclasses import dataclass, field, fields
from src.utils.typing import ArrayLike
from typing import Literal

@dataclass(frozen=True)
class LatentSample:
    """One latent-QA example: a cached activation plus a yes/no question.

    Attributes:
        activation: Cached hidden state (from one layer of one query) to patch
            into the model.
        prompt: Question text; must contain the ``#`` patch marker.
        label: Expected next token, ``" yes"`` or ``" no"``.
    """

    activation: ArrayLike
    prompt: str
    label: Literal[" yes", " no"]

    def __post_init__(self):
        # Fail fast on malformed samples, and include the offending value so a
        # bad sample is easy to locate among thousands.
        assert self.label in [" yes", " no"], f"label must be ' yes' or ' no', got {self.label!r}"
        assert "#" in self.prompt, f"prompt must contain a '#' patch marker, got {self.prompt!r}"

class LatentSampleBuffer:
    """Sequential mini-batch iterator over a list of ``LatentSample`` objects.

    Each ``next()`` returns the next ``batch_size`` consecutive samples (the
    final batch may be shorter) and raises ``StopIteration`` once every sample
    has been returned, so ``for batch in buffer`` terminates. Call ``reset()``
    to start another pass.
    """

    idx: int = 0  # class-level default kept for backward compatibility

    def __init__(
        self,
        activations: "list[LatentSample]",
        batch_size: int = 32,
    ):
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.activations = activations
        self.batch_size = batch_size
        # Per-instance cursor into `activations`, advanced by __next__.
        self.idx = 0

    def __len__(self):
        """Number of samples (not batches) held by the buffer."""
        return len(self.activations)

    def __getitem__(self, idx):
        """Return the sample (or slice of samples) at ``idx``."""
        return self.activations[idx]

    def __iter__(self):
        # The buffer is its own iterator: iteration resumes from the current
        # cursor rather than restarting (use reset() for a fresh pass).
        return self

    def __next__(self):
        if self.idx >= len(self.activations):
            raise StopIteration
        batch = self.activations[self.idx : self.idx + self.batch_size]
        self.idx += len(batch)
        return batch

    def reset(self):
        """Rewind the cursor so the next batch starts at the first sample."""
        self.idx = 0

In [106]:
# Expand every cached query into one LatentSample per layer of interest: that
# layer's cached activation plus a freshly randomized yes/no question about
# the query's truth value. 351 queries x 12 layers = 4212 samples.
latent_arr = []
layers_of_interest = list(range(8, 20))

for latent_cache in lcc.latents:
    # Derive the truth strings once per query; distinct names avoid shadowing
    # the generated QA prompt/label inside the layer loop.
    truth = str(latent_cache.answer).lower().strip()
    yes_ans = "true" if truth == "true" else "false"
    no_ans = "false" if yes_ans == "true" else "true"
    for layer_idx in layers_of_interest:
        layer_name = mt.layer_name_format.format(layer_idx)
        activation = latent_cache.latents[layer_name]
        qa_prompt, qa_label = get_latent_qa(yes_ans, no_ans)
        latent_arr.append(
            LatentSample(
                activation=activation,
                prompt=qa_prompt,
                label=qa_label,
            )
        )

len(latent_arr)

4212

In [109]:
# Fixed-size mini-batches in list order (no shuffling). latent_arr is ordered
# query-by-query with 12 layer samples per query, so a batch of 32 covers
# only ~3 distinct statements.
buffer = LatentSampleBuffer(latent_arr, batch_size=32)

In [110]:
# Take the first mini-batch for a quick sanity check.
batch = next(buffer)

In [123]:
# Inspect one target label (" yes" or " no").
batch[5].label

' no'

In [141]:
from src.tokens import prepare_input, find_token_range
from src.functional import interpret_logits, get_module_nnsight

def prepare_batch_input(batch: list[LatentSample], mt: ModelandTokenizer):
    """Tokenize a batch of latent-QA prompts and locate each patch position.

    Args:
        batch: Samples whose ``prompt`` contains a ``#`` marker.
        mt: Model/tokenizer wrapper, passed to the tokenization helpers.

    Returns:
        ``(batch_tokenized, int_tok_idx)``: the tokenized batch with its
        ``offset_mapping`` entry removed, and per sample the token index
        ``end - 1`` of the span ``find_token_range`` reports for the first
        ``#`` (the last token of the marker if ``end`` is exclusive; TODO
        confirm).
    """
    batch_tokenized = prepare_input(
        prompts=[sample.prompt for sample in batch],
        tokenizer=mt,
        return_offset_mapping=True,
    )
    offsets = batch_tokenized["offset_mapping"]

    int_tok_idx = [
        find_token_range(
            string=sample.prompt,
            substring="#",
            occurrence=0,
            tokenizer=mt,
            offset_mapping=offsets[pos],
        )[1]
        - 1
        for pos, sample in enumerate(batch)
    ]

    # The offsets were only needed to locate the marker; presumably dropped
    # because the model forward does not accept them.
    batch_tokenized.pop("offset_mapping")
    return batch_tokenized, int_tok_idx

batch_tokenized, int_tok_idx = prepare_batch_input(batch, mt)
activations = [b.activation for b in batch]

# Forward pass with the cached latents patched in at each prompt's "#" token.
# This is evaluation only, so skip building an autograd graph; trace exits
# (and runs the model) before no_grad is released.
with torch.no_grad(), mt.trace(batch_tokenized):
    module_names = mt.layer_names
    for idx, (act, int_tok) in enumerate(zip(activations, int_tok_idx)):
        # torch.as_tensor reuses `act` when it is already a tensor on
        # mt.device (presumably the case after retensorize()). torch.tensor()
        # copy-constructs and emits a UserWarning, and did so once per layer.
        patch = torch.as_tensor(act, device=mt.device)
        # NOTE(review): the same latent is written into *every* layer's output
        # at the "#" position, regardless of which layer it was cached from.
        # Confirm this is intended rather than patching a single layer.
        for module_name in module_names:
            module = get_module_nnsight(mt, module_name)
            module.output[0][idx, int_tok, :] = patch
    output = mt.output.save()

output.logits.shape

  output = self.target(*args, **kwargs)


torch.Size([32, 20, 128256])

In [142]:
# Top-1 next-token prediction at the final position for each sample.
# NOTE(review): [-1] is the last *padded* position; it is the last prompt
# token only if the tokenizer pads on the left -- confirm the padding side.
predicted_labels = [
    interpret_logits(tokenizer=mt, logits=output.logits[i][-1], k=2)[0]
    for i in range(len(batch))
]

# Case- and whitespace-insensitive match against the target labels.
correct_labels = [b.label for b in batch]
correct_count = sum(
    pred.token.strip().lower() == gold.strip().lower()
    for pred, gold in zip(predicted_labels, correct_labels)
)

correct_count / len(batch)

0.53125

### Patchscope Finetuning

In [22]:
# Seed every RNG before the model/optimizer/data order are set up.
experiment_utils.set_seed(42)

# Fine-tune the underlying HF model of `mt` in place.
model = mt._model
model.train()
device = mt.device

# ---- optimisation ----
learning_rate = 5e-5
batch_size = 6                 # the wiki loaders use batch_size // 2
num_warmup_steps = 30
limit_training_steps = 1000    # upper bound; clipped to the loader lengths later

# ---- logging / checkpointing ----
wandb_log_interval = 10
checkpoint_interval = 100
model_save_dir = os.path.join(env_utils.DEFAULT_RESULTS_DIR, "patchscope_tuning")
os.makedirs(model_save_dir, exist_ok=True)
##############################################################################

2024-10-25 13:57:57 src.utils.experiment_utils INFO     setting all seeds to 42


In [23]:
import shutil
def remove_dir(path):
    """Recursively delete the directory at ``path``; do nothing if it does not exist."""
    if not os.path.exists(path):
        return
    shutil.rmtree(path)

# Uncomment to wipe previously saved checkpoints before a fresh run.
# remove_dir(model_save_dir)
# Delete a local ".wandb" directory if present.
# NOTE(review): wandb's default local run dir is "./wandb", not "./.wandb" -- confirm which is intended.
remove_dir(".wandb")

In [None]:
from torch.utils.data import DataLoader  # used below; was never imported
from transformers import get_linear_schedule_with_warmup

# NOTE(review): train_dataset / test_dataset / train_wiki / test_wiki are not
# defined anywhere in this notebook; they appear to be carried over from the
# chess fine-tuning code. Fail with a clear message rather than a bare
# NameError until they are built from the latent-QA data above.
_missing_datasets = [
    name
    for name in ("train_dataset", "test_dataset", "train_wiki", "test_wiki")
    if name not in globals()
]
if _missing_datasets:
    raise NameError(f"define these datasets before building the loaders: {_missing_datasets}")

# dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# The wiki loaders use half-size batches.
train_wiki_loader = DataLoader(train_wiki, batch_size=batch_size // 2, shuffle=True)
test_wiki_loader = DataLoader(test_wiki, batch_size=batch_size // 2, shuffle=False)

print(f"{len(train_loader)=} | {len(train_wiki_loader)=}")

# Never take more steps than one pass over the shorter training loader.
limit_training_steps = min(
    limit_training_steps,
    len(train_loader),
    len(train_wiki_loader)
)

print(f"{limit_training_steps=}")

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# The linear decay must span the steps actually taken. Using len(train_loader)
# left the LR far from zero when training stops at limit_training_steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=limit_training_steps,
)

In [None]:
# wandb
# wandb
wandb.init(
    entity="dl-homeworks",
    project="talkative_probes",
    name=f"{MODEL_KEY}_finetune",
    config={
        "model_key": MODEL_KEY,
        "learning_rate": learning_rate,
        "wandb_log_interval": wandb_log_interval,
        "checkpoint_interval": checkpoint_interval,
        "num_warmup_steps": num_warmup_steps,
        "batch_size": batch_size,
        "limit_training_steps": limit_training_steps,
    }
)


def _endless_batches(loader):
    """Yield batches from ``loader`` forever, starting a fresh (reshuffled) pass after each epoch."""
    while True:
        yield from loader


# One persistent iterator per loader. Calling next(iter(loader)) every step
# rebuilt the iterator each time, which amounts to sampling batches with
# replacement (and restarts worker processes every step when num_workers > 0).
train_batches = _endless_batches(train_loader)
wiki_batches = _endless_batches(train_wiki_loader)

for step in tqdm(range(limit_training_steps), desc="Training"):
    optimizer.zero_grad()

    chess_batch = next(train_batches)
    input_ids = chess_batch["input_ids"].to(device)
    attention_mask = chess_batch["attention_mask"].to(device)
    labels = chess_batch["labels"].to(device)

    pgn_outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    chess_loss = pgn_outputs.loss

    wiki_batch = next(wiki_batches)
    wiki_input_ids = wiki_batch["input_ids"].to(device)
    wiki_attention_mask = wiki_batch["attention_mask"].to(device)
    wiki_labels = wiki_batch["labels"].to(device)

    wiki_outputs = model(input_ids=wiki_input_ids, attention_mask=wiki_attention_mask, labels=wiki_labels)
    wiki_loss = wiki_outputs.loss

    # Unweighted sum of the task loss and the wiki LM loss.
    loss = chess_loss + wiki_loss

    loss.backward()
    optimizer.step()
    scheduler.step()

    if (step + 1) % wandb_log_interval == 0:
        wandb.log({
            "loss": loss.item(),
            "chess_loss": chess_loss.item(),
            "wiki_loss": wiki_loss.item(),
            "learning_rate": scheduler.get_last_lr()[0],
        })

    if ((step + 1) % checkpoint_interval == 0) or (step + 1) == limit_training_steps:
        new_checkpoint_path = os.path.join(model_save_dir, f"checkpoint-{step + 1}")
        model.save_pretrained(new_checkpoint_path)
        # Keep only the newest checkpoint. Save first and delete afterwards, so
        # a crash mid-save never leaves zero checkpoints. Prune every stale
        # "checkpoint-*" entry explicitly, because os.listdir() order is
        # arbitrary and picking [-1] could remove the wrong directory.
        for entry in os.listdir(model_save_dir):
            entry_path = os.path.join(model_save_dir, entry)
            if entry.startswith("checkpoint-") and entry_path != new_checkpoint_path:
                remove_dir(entry_path)

# Flush and close the wandb run so its logs are fully uploaded.
wandb.finish()
print("Training completed!")