In [None]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes so edits to imported code take effect without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import json

import sys

# Put the parent directory on the import path so the `src` package used
# throughout this notebook can be imported.
sys.path.append("../")

##################################################################
# Process-wide settings, applied before `torch` / `transformers` are imported
# further down in this cell.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "CUDA_VISIBLE_DEVICES": "0,1,2,3,4,5,6,7",
    }
)
##################################################################

import logging
from src.utils import logging_utils
from src.utils import env_utils

# Logger used by the cells of this notebook.
logger = logging.getLogger(__name__)

# Root logging configuration: DEBUG and above, written to stdout, using the
# project's default record and date formats.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.DEBUG,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    format=logging_utils.DEFAULT_FORMAT,
)

import torch
import transformers

# Record the library versions and the CUDA environment visible to this kernel.
logger.info(f"{torch.__version__=}, {torch.version.cuda=}")
if torch.cuda.is_available():
    logger.info(
        f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}, {torch.cuda.get_device_name()=}"
    )
else:
    # torch.cuda.get_device_name() raises when no CUDA device is usable, which
    # would abort this setup cell; log what can be queried safely instead.
    logger.warning(f"{torch.cuda.is_available()=}, {torch.cuda.device_count()=}")
logger.info(f"{transformers.__version__=}")

In [None]:
from src.utils.training_utils import get_device_map

# Model selection: exactly one `model_key` assignment should be active. The
# commented lines are alternative Hugging Face model ids to swap in.
# model_key = "meta-llama/Llama-3.2-3B"
# model_key = "meta-llama/Llama-3.1-8B"
model_key = "meta-llama/Llama-3.3-70B-Instruct"
# model_key = "meta-llama/Llama-3.1-405B-Instruct"

# model_key = "google/gemma-2-9b-it"
# model_key = "google/gemma-3-12b-it"
# model_key = "google/gemma-2-27b-it"

# model_key = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

# model_key = "allenai/OLMo-2-1124-7B-Instruct"
# model_key = "allenai/OLMo-7B-0424-hf"

# model_key = "Qwen/Qwen2-7B"
# model_key = "Qwen/Qwen2.5-14B-Instruct"
# model_key = "Qwen/Qwen2.5-32B-Instruct"
# model_key = "Qwen/Qwen2.5-72B-Instruct"

# model_key = "Qwen/Qwen3-1.7B"
# model_key = "Qwen/Qwen3-4B"
# model_key = "Qwen/Qwen3-8B"
# model_key = "Qwen/Qwen3-14B"
# model_key = "Qwen/Qwen3-32B"

# Optional explicit device map (disabled; the model-loading cell passes
# device_map="auto" instead).
# device_map = get_device_map(model_key, 30, n_gpus=8)
# device_map

In [None]:
from src.models import ModelandTokenizer

# Only needed if the commented quantization_config below is re-enabled.
# from transformers import BitsAndBytesConfig

# Build the project's ModelandTokenizer wrapper for `model_key`, requesting
# bfloat16 weights and device_map="auto". The commented arguments are the
# alternatives: an explicit device map, or an 8-bit / 4-bit quantized load.
mt = ModelandTokenizer(
    model_key=model_key,
    torch_dtype=torch.bfloat16,
    # device_map=device_map,
    device_map="auto",
    # quantization_config = BitsAndBytesConfig(
    #     # load_in_4bit=True
    #     load_in_8bit=True
    # )
)

In [None]:
from src.functional import free_gpu_cache

# Name of the synthetic dataset the parameters were trained on; it selects the
# sub-directory under "trained_params".
# SYNTH_DATASET = "icosahedron_1"
SYNTH_DATASET = "64"

# <results dir>/trained_params/<dataset>/_full__clamp=0.001/<model name>
checkpoint_path = os.path.join(
    env_utils.DEFAULT_RESULTS_DIR,
    "trained_params",
    f"{SYNTH_DATASET}",
    "_full__clamp=0.001",
    model_key.split("/")[-1],
)

# Which saved checkpoint inside that directory to use.
version = "epoch_1"
# version = "final_model"

# `checkpoint_path` already starts with DEFAULT_RESULTS_DIR. Joining the
# results dir in front of it again only worked when that dir was absolute
# (os.path.join drops everything before an absolute component); with a relative
# results dir the prefix was duplicated. Append the version only.
checkpoint_path = os.path.join(checkpoint_path, version)

# Show what the checkpoint directory contains (fails early if it is missing).
print(os.listdir(checkpoint_path))

checkpoint_path = os.path.join(checkpoint_path, "trainable_params.pt")

# Load onto CPU first so no GPU memory is used by the raw checkpoint.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
loaded_deltas = torch.load(checkpoint_path, map_location="cpu")
# loaded_deltas

free_gpu_cache()


# Spot check: largest absolute value in one loaded entry (displayed as the
# cell output).
d = loaded_deltas["model<>layers<>10<>mlp<>gate_proj"]
d.abs().max()

In [None]:
from src.utils.training_utils import TrainableLM_delta, TrainableLM_LoRA

#################################################
# Choose which trainable-parameter class interprets `loaded_deltas`.
Trainable_CLS = TrainableLM_delta
# Trainable_CLS = TrainableLM_LoRA
#################################################

# Hand the loaded parameters and the underlying model to the chosen class.
# NOTE(review): presumably this modifies mt._model in place, in which case
# re-running this cell would apply the parameters a second time -- confirm
# against fuse_with_model before re-executing.
Trainable_CLS.fuse_with_model(mt._model, loaded_deltas)

## Generate samples and filter by LM knowledge

In [None]:
from src.selection.data import load_people_by_category_fakeverse

# Pool of people grouped by category, passed to get_random_sample in the
# following cells. The loader is given the model's tokenizer.
people_by_category = load_people_by_category_fakeverse(tokenizer=mt.tokenizer)

In [None]:
from src.selection.data import SelectionSample, get_random_sample

# Draw a single sample and inspect it before running the bulk collection loop.
sample = get_random_sample(
    people_by_category=people_by_category,
    mt=mt,
    n_distractors=5,
    get_alt_obj=False,
    # category="actor",
    obj_idx=3,
    filter_by_lm_prediction=True,
)
print(sample)
print(sample.prompt)
sample.prediction

In [None]:
##########################################################
# Output location: <results dir>/selection/<model name>/profession
save_dir = os.path.join(
    env_utils.DEFAULT_RESULTS_DIR, "selection", mt.name.split("/")[-1], "profession"
)
os.makedirs(save_dir, exist_ok=True)
file_name = "filtered_samples.json"

LIMIT = 12000  # total number of samples to collect
N_DISTRACTORS = 5  # distractors requested per sample
SAVE_STEP = 100  # write the collected samples to disk every SAVE_STEP samples
##########################################################

from src.utils.experiment_utils import set_seed

# Fixed seed so the sequence of drawn samples is repeatable.
set_seed(123456)

samples_path = os.path.join(save_dir, file_name)
# Each save goes to a temporary sibling file first and is then swapped in with
# os.replace, so interrupting the kernel in the middle of a dump can no longer
# truncate the only copy of the samples collected so far.
tmp_samples_path = samples_path + ".tmp"

filtered_samples = []
while len(filtered_samples) < LIMIT:
    sample = get_random_sample(
        people_by_category=people_by_category,
        mt=mt,
        n_distractors=N_DISTRACTORS,
        get_alt_obj=False,
        filter_by_lm_prediction=True,
    )
    # Convert the sample to a serialisable form before keeping it, as required
    # by the to_dict()/json.dump call below.
    sample.detensorize()
    filtered_samples.append(sample)
    n_collected = len(filtered_samples)
    print(f"Collected {n_collected}/{LIMIT} samples. {n_collected / LIMIT * 100:.2f}%")

    # Periodic checkpoint; the `== LIMIT` clause guarantees a final save even
    # when LIMIT is not a multiple of SAVE_STEP.
    if n_collected % SAVE_STEP == 0 or n_collected == LIMIT:
        print(f"Saving {n_collected} samples to {samples_path}")
        with open(tmp_samples_path, "w") as f:
            json.dump([s.to_dict() for s in filtered_samples], f, indent=2)
        os.replace(tmp_samples_path, samples_path)

In [None]:
# Read the saved samples back from disk and rebuild SelectionSample objects
# from their dictionary form.
with open(os.path.join(save_dir, "filtered_samples.json"), "r") as f:
    loaded_samples = [SelectionSample.from_dict(record) for record in json.load(f)]

In [None]:
# Sanity check: show the prompt of the first sample after the JSON round trip.
print(loaded_samples[0].prompt)

## Cache last token states for the generated samples

In [None]:
# Load the saved samples from disk as SelectionSample objects; these are the
# inputs whose hidden states get cached in the next cell.
with open(os.path.join(save_dir, "filtered_samples.json"), "r") as f:
    filtered_samples = [SelectionSample.from_dict(record) for record in json.load(f)]

# Number of samples loaded (displayed as the cell output).
len(filtered_samples)

In [None]:
#######################################################################################
cache_dir = os.path.join(save_dir, "cached_states")

# Every module whose output is cached for each sample.
all_layers = (
    [mt.embedder_name]  # embeddings
    + mt.layer_names  # residual
    + [mt.mlp_module_name_format.format(i) for i in range(mt.n_layer)]  # mlp outputs
    + [mt.attn_module_name_format.format(i) for i in range(mt.n_layer)]  # attn outputs
)
TOKEN_POSITION = -1  # last token

# (module name, token position) pairs requested from get_hs.
locations = [(layer_name, TOKEN_POSITION) for layer_name in all_layers]
#######################################################################################
os.makedirs(cache_dir, exist_ok=True)

import numpy as np
from src.utils.typing import TokenizerOutput
from src.functional import get_hs, detensorize
from src.tokens import prepare_input

# For every sample: run the prompt through the model, collect the requested
# states, and write one compressed .npz archive holding the sample's dict form
# under "sample" and a {module name: float32 array} dict under "states".
for idx, sample in enumerate(filtered_samples):
    # inputs = TokenizerOutput(data = sample.metadata["tokenized"]).to(mt.device)
    inputs = prepare_input(prompts=sample.prompt, tokenizer=mt)
    sample.detensorize()

    cache = {"sample": sample.to_dict(), "states": {}}
    states = get_hs(
        mt=mt,
        input=inputs,
        locations=locations,
        return_dict=True,
    )

    # Keys are (module name, token position); every location uses the same
    # position, so the module name alone identifies a state. Cast to float32
    # before converting to NumPy.
    for (layer_name, tok_idx), state in states.items():
        cache["states"][layer_name] = state.detach().to(torch.float32).cpu().numpy()

    cache = detensorize(cache)
    # The dict values are stored by NumPy as pickled object arrays, which is
    # allowed by default. `allow_pickle=True` is deliberately NOT passed: on
    # NumPy versions whose savez_compressed has no such parameter, the keyword
    # is treated as one more array and written into the archive under the name
    # "allow_pickle".
    # NOTE(review): the loading cell further down reads "sample_00001.npz"
    # (zero-padded) from a different directory; the names written here are not
    # padded -- confirm which naming scheme downstream readers expect.
    np.savez_compressed(os.path.join(cache_dir, f"sample_{idx}.npz"), **cache)

    logger.info(
        f"Processed sample {idx + 1}/{len(filtered_samples)} ({(idx + 1) / len(filtered_samples) * 100:.2f}%)"
    )

In [None]:
# Scratch check of 5-digit zero-padded integer formatting.
val = 5
print(format(val, "05d"))

### Testing by loading the cached states

In [None]:
import os
import numpy as np

# cache_dir = os.path.join(save_dir, "cached_states")
# NOTE(review): machine-specific absolute path. It is not the
# `<save_dir>/cached_states` directory written by the caching cell in this
# notebook, so these files presumably come from a separate run -- confirm, and
# prefer a path derived from a configurable results directory.
cache_dir = "/disk/u/arnab/Codes/Projects/retrieval/results/selection/Llama-3.3-70B-Instruct/profession/cached_states/last_token/Llama-3.3-70B-Instruct"
# List the cached files (displayed as the cell output).
os.listdir(cache_dir)

In [None]:
# Open one cached archive. allow_pickle=True is required to read object arrays;
# unpickling can execute arbitrary code, so only open archives you trust.
sample_states = np.load(os.path.join(cache_dir, "sample_00001.npz"), allow_pickle=True)
# Names of the arrays stored in the archive (displayed as the cell output).
sample_states.files

In [None]:
from src.selection.data import SelectionSample

# The "sample" entry is unwrapped with .item() and passed to
# SelectionSample.from_dict to rebuild the sample object.
sample = SelectionSample.from_dict(sample_states["sample"].item())
print(sample.prompt)
print(sample.prediction)

In [None]:
# Unwrap the stored "states" mapping and convert each entry to a torch tensor,
# keyed by the same layer name.
states = dict(
    (name, torch.Tensor(array))
    for name, array in sample_states["states"].item().items()
)

states