In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import time, json
import sys
sys.path.append("../")
sys.path.append("../chess_llm_interpretability")
import os
import torch
import numpy as np
from tqdm import tqdm

import logging
from src.utils import logging_utils
from src.utils import env_utils, experiment_utils
from src import functional
import wandb

# Notebook-level logger; inside the kernel __name__ is "__main__", which is the
# name that appears in the log lines emitted by this notebook.
logger = logging.getLogger(__name__)

# Send records of level DEBUG and above to stdout (so they land in the cell
# output), formatted with the project's shared format / date-format strings.
logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

# Record the torch and CUDA versions of this run in the output.
logger.info(f"{torch.__version__=}, {torch.version.cuda=}")

  from .autonotebook import tqdm as notebook_tqdm


2024-10-26 13:38:01 __main__ INFO     torch.__version__='2.5.0+cu124', torch.version.cuda='12.4'


In [3]:
from src.models import ModelandTokenizer

# Candidate checkpoints; exactly one MODEL_KEY assignment should be active.
# MODEL_KEY = "meta-llama/Llama-3.2-3B-Instruct"
# MODEL_KEY = "meta-llama/Llama-3.1-8B-Instruct"

MODEL_KEY = "meta-llama/Llama-3.2-3B"
# MODEL_KEY = "google/gemma-2-2b"
# MODEL_KEY = "meta-llama/Llama-3.1-8B"

#! torch.adaptive precision
# NOTE(review): the marker above reads like a reminder to revisit the precision
# setting — confirm with the author. The weights are currently requested in float16.
# `mt` bundles the model with its tokenizer and is used by every later cell.
mt = ModelandTokenizer(
    model_key=MODEL_KEY,
    torch_dtype=torch.float16,
)

2024-10-26 13:38:03 accelerate.utils.modeling INFO     We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.00s/it]

2024-10-26 13:38:05 src.models INFO     loaded model </home/local_arnab/Codes/00_MODEL/meta-llama/Llama-3.2-3B> | size: 6127.841 MB | dtype: torch.float16 | device: cuda:0





### LoRA (check later)

In [4]:
# from peft import LoraConfig, get_peft_model

# lora_config = LoraConfig(
#     r=16,
#     lora_alpha=32,
#     target_modules=["q_proj", "v_proj"],
#     lora_dropout=0.1,
#     bias="none",
#     task_type="CAUSAL_LM",
# )

# model = get_peft_model(mt._model, lora_config)

In [5]:
# type(model), type(model.model)

In [6]:
# for p in model.model.named_parameters():
#     print(p[0])

### Dataset Preparation

In [4]:
from src.dataset import GMTDataset

# Build the dataset from the Spanish->English translation CSV. The second
# argument is presumably the dataset name (it is read back later as `ds.name`
# when naming the cache file) — confirm against GMTDataset.from_csv.
ds = GMTDataset.from_csv(
    [
        "sp_en_trans.csv", 
        # "cities.csv"
    ], 
    "sp_en_trans"
)

# NOTE(review): presumably selects zero few-shot demonstrations — confirm in GMTDataset.
ds.select_few_shot(0)

# Materialize every example; each one unpacks into a (statement, truth value) pair.
queries = [ds.examples[i] for i in range(len(ds))]
# Latents are collected for every layer name the model wrapper exposes.
interested_layers = mt.layer_names
# Show the first example as a quick sanity check.
q, a = queries[0]
print(q)
print(a)

2024-10-26 13:38:43 numexpr.utils INFO     Note: NumExpr detected 24 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-10-26 13:38:43 numexpr.utils INFO     NumExpr defaulting to 8 threads.
2024-10-26 13:38:43 src.dataset INFO     initialized sp_en_trans with 348 examples.
The Spanish word 'gato' means 'cat'.
True


In [None]:
from src.functional import get_concept_latents

# Skip this cell if the latents were already cached on disk; the cached copy
# can be loaded instead of recomputing them.
# Computes the latents of all `queries` at `interested_layers`
# (presumably one model pass per query, so this can be slow — confirm).

latents = get_concept_latents(
    mt=mt, 
    queries=queries, 
    interested_layers=interested_layers,
    check_answer=False,
)

In [5]:
# Per-model cache location:
#   <DEFAULT_RESULTS_DIR>/cached_latents/<last path component of MODEL_KEY>
latent_dir = os.path.join(
    env_utils.DEFAULT_RESULTS_DIR, 
    "cached_latents",
    MODEL_KEY.split("/")[-1],
)
# Create the directory on first use; an existing directory is not an error.
os.makedirs(latent_dir, exist_ok=True)


In [8]:
# caching the latents

from src.utils.typing import LatentCacheCollection
from src.utils import env_utils

# Wrap the freshly computed `latents` and write them to
# <latent_dir>/<dataset name>.json. Requires the latent-computation cell to
# have been run in this kernel session.
lcc = LatentCacheCollection(latents=latents)
# presumably converts tensors into JSON-serializable values in place — confirm
lcc.detensorize()
with open(os.path.join(latent_dir, f"{ds.name}.json"), "w") as f:
    f.write(lcc.to_json())

In [6]:
# loading the latents

from src.utils.typing import LatentCacheCollection
from src.utils import env_utils

# Read the cached latents for this dataset back from disk ...
with open(os.path.join(latent_dir, f"{ds.name}.json"), "r") as f:
    dct = json.load(f)
# ... rebuild the collection and place its contents on the model's device
# (presumably converting the stored values back to tensors — confirm).
lcc = LatentCacheCollection.from_dict(dct)
lcc.retensorize(device=mt.device)

In [7]:
import random
from typing import Literal

# Paraphrases of the "answer yes or no" instruction; one is sampled per prompt.
with open(os.path.join(env_utils.DEFAULT_DATA_DIR, "paraphrases/yes_no.json"), "r") as f:
    yes_no_paraphrases = json.load(f)

# Question templates; only the "GMT" entry of the file is used here. Each
# template is later filled through str.format with a single value.
with open(os.path.join(env_utils.DEFAULT_DATA_DIR, "paraphrases/question.json"), "r") as f:
    question_paraphrases = json.load(f)["GMT"]

def get_latent_qa(yes_ans, no_ans) -> tuple[str, Literal[" yes", " no"]]:
    """Sample a yes/no probing prompt and the label it should be answered with.

    A target label is drawn uniformly from " yes" / " no". A question template
    and an answer instruction are then drawn from the module-level paraphrase
    lists. The template is filled with `yes_ans` when the target is " yes" and
    with `no_ans` otherwise.

    Args:
        yes_ans: value inserted into the template when the target is " yes".
        no_ans: value inserted into the template when the target is " no".

    Returns:
        (prompt, label) where prompt is "# <question> <answer instruction>".
        The leading "#" is the placeholder token located by later cells.
    """
    # The draw order (label, instruction, template) is kept fixed so that a
    # seeded `random` produces the same sequence of prompts.
    label = random.choice([" yes", " no"])
    answer_instruction = random.choice(yes_no_paraphrases)
    template = random.choice(question_paraphrases)

    if label == " yes":
        filled_question = template.format(yes_ans)
    else:
        filled_question = template.format(no_ans)

    return f"# {filled_question} {answer_instruction}", label


# Sanity check: build one probing prompt for a single example.
q, a = queries[10]
print(q, a)
# `yes_ans` is the example's own truth value spelled as "true"/"false";
# `no_ans` is the opposite word.
yes_ans = "true" if str(a).lower().strip() == "true" else "false"
no_ans = "false" if yes_ans == "true" else "true"
# Returns a (prompt, label) tuple.
latent_q = get_latent_qa(yes_ans, no_ans)
print(latent_q)

The Spanish word 'mientras' means 'while'. True
('# Is this true? Answer yes or no.\nAnswer:', ' yes')


In [8]:
from dataclasses import dataclass, field, fields
from src.utils.typing import ArrayLike
from typing import Literal

@dataclass(frozen=True)
class LatentSample:
    """One immutable training example for the patching experiments.

    Attributes are documented inline; validation happens in `__post_init__`.
    """

    # Cached hidden state that gets written into the prompt's "#" position.
    activation: ArrayLike
    # Probing question; must contain the "#" placeholder.
    prompt: str
    # Expected answer token text (note the leading space).
    label: Literal[" yes", " no"]

    def __post_init__(self):
        # Reject samples the downstream cells cannot handle: an unknown label,
        # or a prompt without the placeholder that marks the patch position.
        allowed_labels = (" yes", " no")
        assert self.label in allowed_labels
        assert self.prompt.count("#") > 0

class LatentSampleBuffer:
    """In-memory list of samples served in fixed-size batches.

    Two independent ways of reading batches are offered:

    * iterating (`for batch in buffer`) always starts at the first sample and
      never touches the cursor;
    * `next_batch()` advances a persistent cursor (`idx`) and raises
      `StopIteration` once the samples are exhausted; the caller rewinds by
      setting `idx` back to 0.
    """

    # Class-level default kept for backward compatibility; each instance gets
    # its own cursor in __init__.
    idx: int = 0

    def __init__(
        self,
        activations: "list[LatentSample]",
        batch_size: int = 32,
    ):
        """
        Args:
            activations: samples to serve, in order.
            batch_size: maximum number of samples per batch; must be positive.

        Raises:
            ValueError: if `batch_size` is not positive. A value of 0 would
                make `next_batch` hand out empty batches forever, and a
                negative value would move the cursor backwards.
        """
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )
        self.activations = activations
        self.batch_size = batch_size
        self.idx = 0

    def __len__(self):
        """Number of samples (not batches)."""
        return len(self.activations)

    def __getitem__(self, idx):
        """Return the sample (or slice of samples) at `idx`."""
        return self.activations[idx]

    def __iter__(self):
        """Yield consecutive batches; the final batch may be shorter."""
        for start in range(0, len(self.activations), self.batch_size):
            yield self.activations[start : start + self.batch_size]

    def next_batch(self):
        """Return the batch at the cursor and advance the cursor.

        Raises:
            StopIteration: when the cursor is at or past the end of the
                samples. The cursor is left unchanged in that case.
        """
        if self.idx >= len(self.activations):
            raise StopIteration
        batch = self.activations[self.idx : self.idx + self.batch_size]
        self.idx += self.batch_size
        return batch

In [None]:
# Turn every cached latent into one training sample per selected layer.
latent_arr = []
# Layers whose cached activations are used as patch sources (indices 8..19).
layers_of_interest = list(range(8, 20))

for latent_cache in lcc.latents:
    # The cached answer is the statement's truth value. Spell it as the word
    # that makes the probing question's answer "yes", and its opposite.
    yes_ans = "true" if str(latent_cache.answer).lower().strip() == "true" else "false"
    no_ans = "false" if yes_ans == "true" else "true"
    for layer_idx in layers_of_interest:
        layer_name = mt.layer_name_format.format(layer_idx)
        # A fresh question / label is sampled for every (example, layer) pair.
        prompt, label = get_latent_qa(yes_ans, no_ans)
        latent_arr.append(LatentSample(
            activation=latent_cache.latents[layer_name],
            prompt=prompt,
            label=label,
        ))

len(latent_arr)

4212

In [10]:
# Serve the samples two at a time; `buffer` is reused by the training cells below.
buffer = LatentSampleBuffer(latent_arr, batch_size=2)

In [11]:
# Peek at the first batch. Iterating starts from the first sample and does not
# advance the cursor used by `buffer.next_batch()`.
next(iter(buffer))

[LatentSample(activation=tensor([ 0.2187,  0.1573, -0.1146,  ..., -0.1818, -0.1682, -0.0704],
        device='cuda:0'), prompt='# Is this statement true? Answer yes or no.\nAnswer:', label=' no'),
 LatentSample(activation=tensor([ 0.1120,  0.1228, -0.0944,  ..., -0.2126, -0.1215, -0.0852],
        device='cuda:0'), prompt='# This statement is false. Do you agree? (yes/no) Answer:', label=' yes')]

In [12]:
# for batch in buffer:
#     print(len(batch))

In [13]:
# Keep the first batch around as `batch`; the following cells run on it.
batch = next(iter(buffer))
print(len(batch))

# Target labels (" yes" / " no") of the samples in that batch.
batch_labels = [sample.label for sample in batch]
print(batch_labels)

2
[' no', ' yes']


In [14]:
from src.tokens import prepare_input, find_token_range
from src.functional import interpret_logits, get_module_nnsight

def prepare_batch_input(batch: list[LatentSample], mt: ModelandTokenizer):
    """Tokenize a batch of prompts and locate each prompt's "#" placeholder.

    Args:
        batch: samples whose `prompt` fields are tokenized together.
        mt: model/tokenizer wrapper passed on to the tokenization helpers.

    Returns:
        (tokenized, positions): `tokenized` is the result of `prepare_input`
        with its "offset_mapping" entry removed, and `positions[i]` is the
        token index of the first "#" in prompt i. The index is taken as
        `range_end - 1` — assumes `find_token_range` returns an end-exclusive
        (start, end) pair; confirm.
    """
    prompts = [sample.prompt for sample in batch]
    tokenized = prepare_input(
        prompts=prompts,
        tokenizer=mt,
        return_offset_mapping=True,
    )

    placeholder_positions = []
    for row, prompt in enumerate(prompts):
        token_span = find_token_range(
            string=prompt,
            substring="#",
            occurrence=0,
            tokenizer=mt,
            offset_mapping=tokenized["offset_mapping"][row],
        )
        placeholder_positions.append(token_span[1] - 1)

    # The offsets were only needed to find the placeholder; drop them so the
    # remaining entries can be handed to the model.
    tokenized.pop("offset_mapping")

    return tokenized, placeholder_positions

In [21]:
# Inference-only pass: overwrite the "#" position of every prompt with the
# sample's cached activation and read the logits of the last token.
batch_tokenized, int_tok_idx = prepare_batch_input(batch, mt)
logger.debug(f"{batch_tokenized.input_ids.shape=}")
# Convert once, outside the trace. `torch.as_tensor` reuses an existing tensor
# instead of copy-constructing it (which is what `torch.tensor(tensor)` does,
# with a UserWarning), and the cast matches the model dtype as in the training loop.
activations = [
    torch.as_tensor(b.activation, device=mt.device, dtype=mt.dtype) for b in batch
]

with torch.no_grad():
    with mt.trace(batch_tokenized):
        # The latent is written into the output of every layer in mt.layer_names.
        module_names = mt.layer_names
        for idx, act, int_tok in zip(range(len(batch)), activations, int_tok_idx):
            for module_name in module_names:
                module = get_module_nnsight(mt, module_name)
                module.output[0][idx, int_tok, :] = act
        last_logits = [
            mt.output.logits[idx, -1, :].save()
            for idx in range(len(batch))
        ]
        # Full model output, kept for interactive inspection.
        output = mt.output.save()

last_logits = torch.stack(last_logits)
print(last_logits.shape)

2024-10-26 13:27:34 __main__ DEBUG    batch_tokenized.input_ids.shape=torch.Size([2, 16])


  output = self.target(*args, **kwargs)
You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


torch.Size([2, 128256])


In [22]:
# Vocabulary ids of the two answer strings. Only the last id of each encoding
# is kept — NOTE(review): this assumes " yes" and " no" each end in a single
# answer token and that any earlier ids are special tokens; confirm for the
# tokenizer in use.
token_mapping = {
    "yes": mt.tokenizer(" yes").input_ids[-1],
    "no": mt.tokenizer(" no").input_ids[-1],
}
token_mapping

{'yes': 10035, 'no': 912}

In [23]:
# Top prediction for each sample: ask for the two best candidates and keep
# the first one.
predicted_labels = []
for idx in range(len(batch)):
    top_candidates = interpret_logits(
        tokenizer=mt,
        logits=last_logits[idx],
        k=2,
    )
    predicted_labels.append(top_candidates[0])

correct_labels = [b.label for b in batch]

# A prediction counts as correct when its token text matches the target label,
# ignoring case and surrounding whitespace.
correct_count = sum(
    pred.token.strip().lower() == correct.strip().lower()
    for pred, correct in zip(predicted_labels, correct_labels)
)

# Accuracy on this batch.
correct_count / len(batch)

1.0

### Patchscope Finetuning

In [15]:
# Fix the random seeds, then put the wrapped model into training mode.
# NOTE(review): the probing prompts / labels were sampled with `random` in an
# earlier cell, i.e. before this seed is set — confirm whether that sampling
# is meant to be reproducible.
experiment_utils.set_seed(42)
model = mt._model
model.train()
device = mt.device

# Training parameters
learning_rate = 5e-5

# Checkpoints are written under <DEFAULT_RESULTS_DIR>/patchscope_tuning.
model_save_dir = os.path.join(env_utils.DEFAULT_RESULTS_DIR, "patchscope_tuning")
os.makedirs(model_save_dir, exist_ok=True)
# Log every `log_steps` optimizer steps.
log_steps = 1
# Save a checkpoint every `checkpoint_interval` steps (and at the final step).
checkpoint_interval = 100
# Upper bound on the scheduler's warmup steps.
num_warmup_steps = 30
# Upper bound on the number of training steps.
limit_training_steps = 1000
##############################################################################

2024-10-26 13:39:45 src.utils.experiment_utils INFO     setting all seeds to 42


In [16]:
import shutil
def remove_dir(path):
    """Recursively delete the directory at `path`; do nothing if it is absent."""
    if not os.path.exists(path):
        return
    shutil.rmtree(path)

# Uncomment to also wipe previously saved checkpoints:
# remove_dir(model_save_dir)
# Delete the local ".wandb" directory (relative to the working directory), if present.
remove_dir(".wandb")

In [17]:
from transformers import get_linear_schedule_with_warmup

# Number of full batches in the buffer (a trailing partial batch is not counted).
buffer_steps = len(buffer) // buffer.batch_size
# Never train for more steps than there are full batches.
limit_training_steps = min(
    limit_training_steps,
    buffer_steps
)
logger.info(f"{limit_training_steps=}")

# Optimizer
# All model parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Learning-rate schedule with warmup; the warmup length is capped at one tenth
# of the (possibly reduced) number of training steps.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=min(num_warmup_steps, limit_training_steps // 10),
    num_training_steps=limit_training_steps,
)

# Applied to the last-token logits against the answer token ids.
loss_func = torch.nn.CrossEntropyLoss()

2024-10-26 13:39:49 __main__ INFO     limit_training_steps=1000


In [20]:
from src.functional import free_gpu_cache

# Fine-tuning loop: on each step, patch the cached activations into the "#"
# position of the probing prompts and train the model to emit the right
# " yes" / " no" token at the last position.

# Set to True to report the run to Weights & Biases.
wandb_logging = False

if wandb_logging:
    wandb.init(
        entity="dl-homeworks",
        project="talkative_probes",
        name=f"{MODEL_KEY}_finetune",
        config={
            "model_key": MODEL_KEY,
            "learning_rate": learning_rate,
            "wandb_log_interval": log_steps,
            "checkpoint_interval": checkpoint_interval,
            "num_warmup_steps": num_warmup_steps,
            "batch_size": buffer.batch_size,
        }
    )

# Token id of each answer label, computed once instead of re-tokenizing the
# same two strings on every step.
label_token_ids = {
    label: mt.tokenizer(label).input_ids[-1] for label in (" yes", " no")
}

mt._model.train()
for step in tqdm(range(limit_training_steps), desc="Training"):
    optimizer.zero_grad()

    try:
        batch = buffer.next_batch()
    except StopIteration:
        # Buffer exhausted: rewind the cursor and start another pass.
        buffer.idx = 0
        batch = buffer.next_batch()

    batch_tokenized, int_tok_idx = prepare_batch_input(batch, mt)
    logger.info(batch_tokenized.input_ids.shape)
    # Convert outside the trace. `torch.as_tensor` reuses an existing tensor
    # rather than copy-constructing it (`torch.tensor(tensor)` warns about that).
    activations = [
        torch.as_tensor(b.activation, device=mt.device, dtype=mt.dtype) for b in batch
    ]

    with mt.trace(batch_tokenized):
        module_names = mt.layer_names # replace the latent on all the residual layers
        for idx, act, int_tok in zip(range(len(batch)), activations, int_tok_idx):
            for module_name in module_names:
                module = get_module_nnsight(mt, module_name)
                module.output[0][idx, int_tok, :] = act

        last_logits = [
            mt.output.logits[idx, -1, :].save()
            for idx in range(len(batch))
        ]

    last_logits = torch.stack(last_logits)
    batch_labels = torch.tensor(
        [label_token_ids[b.label] for b in batch], device=mt.device
    )

    # Cross-entropy loss, computed in float32: the model runs in half
    # precision and the softmax spans the whole vocabulary.
    patchscope_loss = loss_func(last_logits.float(), batch_labels)

    # TODO: include natural text and generation loss
    loss = patchscope_loss

    loss.backward()
    optimizer.step()
    scheduler.step()

    free_gpu_cache()

    if (step + 1) % log_steps == 0:
        log_data = {
            "loss": loss.item(),
            "learning_rate": scheduler.get_last_lr()[0],
        }
        logger.info(f"Step {step + 1}: {log_data}")
        if wandb_logging:
            wandb.log(log_data)

    if ((step + 1) % checkpoint_interval == 0) or (step + 1) == limit_training_steps:
        new_checkpoint_path = os.path.join(model_save_dir, f"checkpoint-{step + 1}")
        model.save_pretrained(new_checkpoint_path)

        # Drop older checkpoints only after the new one is on disk, so a failed
        # save never leaves the run without a checkpoint. This briefly needs
        # room for two checkpoints. Only "checkpoint-*" directories are removed;
        # os.listdir order is arbitrary, so it is not used to pick "the last one".
        for entry in os.listdir(model_save_dir):
            entry_path = os.path.join(model_save_dir, entry)
            is_old_checkpoint = (
                entry.startswith("checkpoint-")
                and entry_path != new_checkpoint_path
                and os.path.isdir(entry_path)
            )
            if is_old_checkpoint:
                remove_dir(entry_path)

# Close the wandb run so its final state is flushed.
if wandb_logging:
    wandb.finish()

print("Training completed!")

Training:   0%|          | 0/1000 [00:00<?, ?it/s]

2024-10-26 13:35:25 __main__ INFO     torch.Size([2, 17])


  output = self.target(*args, **kwargs)
You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Training:   0%|          | 0/1000 [00:00<?, ?it/s]

Training completed!



