In [5]:
import os

# Environment setup
os.environ["VLLM_USE_V1"] = "0"

import torch
import pickle
from pathlib import Path
from datetime import datetime
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm

# --- Run configuration ------------------------------------------------------
# Identifiers of the remote resources used by this notebook.
MODEL_NAME = "google/gemma-2-2b-it"    # checkpoint whose tokenizer is loaded
DATASET_NAME = "lmsys/lmsys-chat-1m"   # dataset handed to load_dataset

# Size limits.
N_SAMPLES = 150          # number of conversations taken from the dataset
CTX_LEN = 1024           # max_length applied when tokenizing the prompts
MAX_DECODE_TOKENS = 512  # use 32 for a quick run; consumer not shown in this chunk
DTYPE = "bfloat16"       # presumably the model dtype; consumer not shown in this chunk

# Where results are written. For a separate directory per run, use instead:
#   Path(f"activations_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
OUTPUT_DIR = Path("activations")

# Make sure the directory exists (no error if it is already there).
OUTPUT_DIR.mkdir(exist_ok=True)

# Module-level buffer that the hook below appends to, one tensor per call.
temp_saved_activations = []


def activation_saving_hook(module, input, output):
    """Store a detached copy of ``output[0]`` in ``temp_saved_activations``.

    The (module, input, output) signature matches a PyTorch forward hook;
    the registration itself is not shown in this part of the notebook.
    ``module`` and ``input`` are ignored.

    NOTE(review): ``output`` is assumed to be an indexable whose first element
    is the activation tensor of interest — confirm against the hooked module.
    """
    # detach + clone so the stored tensor does not share the autograd graph
    # or the storage of the live output.
    snapshot = output[0].detach().clone()
    temp_saved_activations.append(snapshot)


def concatenate_activations(activations_list):
    """Merge each sample's list of tensors into a single tensor.

    Args:
        activations_list: iterable with one entry per sample; each entry is a
            non-empty sequence of tensors accepted by ``torch.cat``.

    Returns:
        A list with one tensor per sample, obtained by concatenating that
        sample's tensors along dim 0 (assumed to be the sequence/token axis —
        confirm against the hooked module's output layout).
    """
    return [torch.cat(chunks, dim=0) for chunks in activations_list]


def save_data(data, filename):
    """Pickle ``data`` to ``OUTPUT_DIR / filename`` and print a confirmation.

    Args:
        data: any picklable object.
        filename: file name (relative to ``OUTPUT_DIR``) to write to; an
            existing file is overwritten.
    """
    target = OUTPUT_DIR / filename
    with target.open("wb") as sink:
        pickle.dump(data, sink)
    print(f"Saved {filename}")


print("Loading tokenizer and dataset...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
ds = load_dataset(DATASET_NAME, split="train")

# Keep the "conversation" field of the first N_SAMPLES rows. Pairing the
# dataset with a bounded range stops iteration after N_SAMPLES rows (or
# earlier if the dataset is shorter).
first_rows = zip(range(N_SAMPLES), ds)
prompts = [row["conversation"] for _row_idx, row in first_rows]



Loading tokenizer and dataset...


In [6]:
# Peek at the first raw conversation: a list of {'content', 'role'} message dicts.
print(prompts[0])

[{'content': 'how can identity protection services help protect me against identity theft', 'role': 'user'}, {'content': "Identity protection services can help protect you against identity theft in several ways:\n\n1. Monitoring: Many identity protection services monitor your credit reports, public records, and other sources for signs of identity theft. If they detect any suspicious activity, they will alert you so you can take action.\n2. Credit freeze: Some identity protection services can help you freeze your credit, which makes it more difficult for thieves to open new accounts in your name.\n3. Identity theft insurance: Some identity protection services offer insurance that can help you recover financially if you become a victim of identity theft.\n4. Assistance: Many identity protection services offer assistance if you become a victim of identity theft. They can help you file a police report, contact credit bureaus, and other steps to help you restore your identity.\n\nOverall, i

In [7]:
# Render every conversation (a list of role/content message dicts) into a
# single chat-formatted string using the tokenizer's chat template.
#
# `prompts` is rebound to the rendered strings, so re-running this cell would
# otherwise feed strings back into apply_chat_template and fail. Entries that
# are already strings are therefore passed through unchanged, which makes the
# cell safe to execute more than once.
prompts = [
    prompt
    if isinstance(prompt, str)
    else tokenizer.apply_chat_template(prompt, tokenize=False)
    for prompt in prompts
]
print(prompts[0])

<bos><start_of_turn>user
how can identity protection services help protect me against identity theft<end_of_turn>
<start_of_turn>model
Identity protection services can help protect you against identity theft in several ways:

1. Monitoring: Many identity protection services monitor your credit reports, public records, and other sources for signs of identity theft. If they detect any suspicious activity, they will alert you so you can take action.
2. Credit freeze: Some identity protection services can help you freeze your credit, which makes it more difficult for thieves to open new accounts in your name.
3. Identity theft insurance: Some identity protection services offer insurance that can help you recover financially if you become a victim of identity theft.
4. Assistance: Many identity protection services offer assistance if you become a victim of identity theft. They can help you file a police report, contact credit bureaus, and other steps to help you restore your identity.

Over

In [8]:

# Convert the chat-formatted prompt strings into plain Python lists of ids.
tokenized_inputs = tokenizer(
    prompts,
    # The templated text already begins with <bos>, so do not add it again.
    add_special_tokens=False,
    # Cut every prompt to at most CTX_LEN tokens.
    truncation=True,
    max_length=CTX_LEN,
    # No padding and no tensor conversion: keep ragged lists of ints.
    padding=False,
    return_tensors=None,
)
prompt_token_ids = list(tokenized_inputs["input_ids"])


In [20]:
# Show the first prompt's token ids, then the container type, one per line.
print(prompt_token_ids[0], type(prompt_token_ids[0]), sep="\n")

[2, 106, 1645, 108, 1139, 798, 12852, 6919, 3545, 1707, 9051, 682, 2691, 12852, 37214, 107, 108, 106, 2516, 108, 22869, 6919, 3545, 798, 1707, 9051, 692, 2691, 12852, 37214, 575, 3757, 5742, 235292, 109, 235274, 235265, 29654, 235292, 9429, 12852, 6919, 3545, 8200, 861, 6927, 8134, 235269, 2294, 9126, 235269, 578, 1156, 8269, 604, 11704, 576, 12852, 37214, 235265, 1927, 984, 21422, 1089, 45637, 5640, 235269, 984, 877, 14838, 692, 712, 692, 798, 1987, 3105, 235265, 108, 235284, 235265, 14882, 35059, 235292, 4213, 12852, 6919, 3545, 798, 1707, 692, 35059, 861, 6927, 235269, 948, 3833, 665, 978, 5988, 604, 72731, 577, 2174, 888, 12210, 575, 861, 1503, 235265, 108, 235304, 235265, 39310, 37214, 9490, 235292, 4213, 12852, 6919, 3545, 3255, 9490, 674, 798, 1707, 692, 11885, 50578, 1013, 692, 3831, 476, 17015, 576, 12852, 37214, 235265, 108, 235310, 235265, 38570, 235292, 9429, 12852, 6919, 3545, 3255, 11217, 1013, 692, 3831, 476, 17015, 576, 12852, 37214, 235265, 2365, 798, 1707, 692, 2482, 

In [24]:
# Two conversation-ID strings that will be tokenized to the same length.
id_1 = "Conversation ak13390c5d8"
id_2 = "Conversation zxcf8051a5t"

# Token(s) for the newline appended after each ID.
newline_tokens = tokenizer.encode("\n", add_special_tokens=False)

raw_id_tokens_1, raw_id_tokens_2 = (
    tokenizer.encode(text, add_special_tokens=False, return_tensors=None)
    for text in (id_1, id_2)
)

# Truncate both IDs to the shorter tokenization so that the two variants
# occupy exactly the same number of tokens, then terminate each with "\n".
min_length = min(len(raw_id_tokens_1), len(raw_id_tokens_2))
id_tokens_1 = raw_id_tokens_1[:min_length] + newline_tokens
id_tokens_2 = raw_id_tokens_2[:min_length] + newline_tokens

print(id_tokens_1, id_tokens_2, sep="\n")



[72955, 5179, 235274, 235304, 235304, 235315, 235276, 235260, 235308, 235258, 108]
[72955, 94661, 13311, 235321, 235276, 235308, 235274, 235250, 235308, 235251, 108]


In [9]:
from typing import List, Tuple
INSERT_POS = 4

def insert_id_tokens(
    prompts: List[List[int]],
    id_tokens: List[int],
    insert_pos: int = INSERT_POS,
    max_len: int | None = None,
) -> List[List[int]]:
    """
    Return a NEW list of prompt token-lists with `id_tokens` spliced in.

    Also returns a list recording the original prompt lengths; this makes
    it easy to enforce the same total length when you swap IDs later.
    """
    out: List[List[int]] = []

    for ids in prompts:
        new_ids = ids[:insert_pos] + id_tokens + ids[insert_pos:]

        # Optional: trim the tail to respect ctx length
        if max_len is not None and len(new_ids) > max_len:
            new_ids = new_ids[:max_len]

        out.append(new_ids)

    return out



[2, 106, 1645, 108, 1139, 798, 12852, 6919, 3545, 1707, 9051, 682, 2691, 12852, 37214, 107, 108, 106, 2516, 108, 22869, 6919, 3545, 798, 1707, 9051, 692, 2691, 12852, 37214, 575, 3757, 5742, 235292, 109, 235274, 235265, 29654, 235292, 9429, 12852, 6919, 3545, 8200, 861, 6927, 8134, 235269, 2294, 9126, 235269, 578, 1156, 8269, 604, 11704, 576, 12852, 37214, 235265, 1927, 984, 21422, 1089, 45637, 5640, 235269, 984, 877, 14838, 692, 712, 692, 798, 1987, 3105, 235265, 108, 235284, 235265, 14882, 35059, 235292, 4213, 12852, 6919, 3545, 798, 1707, 692, 35059, 861, 6927, 235269, 948, 3833, 665, 978, 5988, 604, 72731, 577, 2174, 888, 12210, 575, 861, 1503, 235265, 108, 235304, 235265, 39310, 37214, 9490, 235292, 4213, 12852, 6919, 3545, 3255, 9490, 674, 798, 1707, 692, 11885, 50578, 1013, 692, 3831, 476, 17015, 576, 12852, 37214, 235265, 108, 235310, 235265, 38570, 235292, 9429, 12852, 6919, 3545, 3255, 11217, 1013, 692, 3831, 476, 17015, 576, 12852, 37214, 235265, 2365, 798, 1707, 692, 2482, 

In [11]:
def _list_decode(x: torch.Tensor):
    """Decode a tensor of token ids into per-token strings.

    Args:
        x: 1-D tensor of token ids, or 2-D tensor with one row per sequence.

    Returns:
        A list with one entry per row; each entry is the list of strings
        produced by decoding that row's ids one at a time (special tokens are
        kept). A 1-D input is treated as a single row.
    """
    assert x.ndim in (1, 2)

    # Promote a single sequence to a batch of one so both cases share a path.
    batch = x.unsqueeze(0) if x.ndim == 1 else x

    # batch_decode on a flat list of ints decodes every id on its own,
    # which yields one string per token.
    return [
        tokenizer.batch_decode(row_ids, skip_special_tokens=False)
        for row_ids in batch.tolist()
    ]

# Sanity check: show the first prompt as a list of individual token strings.
print(_list_decode(torch.tensor(prompt_token_ids[0])))

[['<bos>', '<start_of_turn>', 'user', '\n', 'how', ' can', ' identity', ' protection', ' services', ' help', ' protect', ' me', ' against', ' identity', ' theft', '<end_of_turn>', '\n', '<start_of_turn>', 'model', '\n', 'Identity', ' protection', ' services', ' can', ' help', ' protect', ' you', ' against', ' identity', ' theft', ' in', ' several', ' ways', ':', '\n\n', '1', '.', ' Monitoring', ':', ' Many', ' identity', ' protection', ' services', ' monitor', ' your', ' credit', ' reports', ',', ' public', ' records', ',', ' and', ' other', ' sources', ' for', ' signs', ' of', ' identity', ' theft', '.', ' If', ' they', ' detect', ' any', ' suspicious', ' activity', ',', ' they', ' will', ' alert', ' you', ' so', ' you', ' can', ' take', ' action', '.', '\n', '2', '.', ' Credit', ' freeze', ':', ' Some', ' identity', ' protection', ' services', ' can', ' help', ' you', ' freeze', ' your', ' credit', ',', ' which', ' makes', ' it', ' more', ' difficult', ' for', ' thieves', ' to', ' op