In [11]:
import datasets
from transformers import AutoTokenizer

# Identifiers handed to the two loaders below.
TOKENIZER_CHECKPOINT = "meta-llama/Llama-3.2-1B-Instruct"
CHAT_DATASET = "lmsys/lmsys-chat-1m"

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)
# Only the "train" split is requested.
dataset = datasets.load_dataset(CHAT_DATASET, split="train")

In [12]:
def format_chat_prompt(conversation, tokenizer, model_name, remove_system_prompt=True):
    """Render a conversation with the tokenizer's chat template.

    Optionally strips everything that precedes the first user turn (for
    Llama-3 style templates this is where the system prompt lives).

    Args:
        conversation: The chat messages, passed unchanged to
            ``tokenizer.apply_chat_template`` (presumably a list of
            role/content dicts -- confirm against the dataset schema).
        tokenizer: Any object exposing ``apply_chat_template``.
        model_name: Model identifier. Stripping is only attempted when it
            contains ``"llama-3"`` (case-insensitive); other models are
            returned untouched.
        remove_system_prompt: When True (default), drop all text before the
            first ``<|start_header_id|>user<|end_header_id|>`` marker.

    Returns:
        The formatted prompt string. When stripping applies, the result starts
        at the first user header, so any text in front of it is removed along
        with the system prompt. If no user header is present the prompt is
        returned unchanged.
    """
    formatted_chat_prompt = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=False
    )

    if remove_system_prompt and "llama-3" in model_name.lower():
        # Locate the first user turn; everything before it is discarded.
        user_prompt_idx = formatted_chat_prompt.find("<|start_header_id|>user<|end_header_id|>")
        # str.find returns -1 when the marker is missing. Slicing from -1 would
        # silently keep only the last character, so leave the prompt as-is.
        if user_prompt_idx != -1:
            formatted_chat_prompt = formatted_chat_prompt[user_prompt_idx:]

    return formatted_chat_prompt

In [13]:
# Preview the formatted prompt for the first conversation only.
row = next(iter(dataset), None)
if row is not None:
    preview = format_chat_prompt(row["conversation"], tokenizer, "Llama-3.2-1B-Instruct")
    print(preview)

<|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 06 May 2025

<|eot_id|><|start_header_id|>user<|end_header_id|>

how can identity protection services help protect me against identity theft<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Identity protection services can help protect you against identity theft in several ways:

1. Monitoring: Many identity protection services monitor your credit reports, public records, and other sources for signs of identity theft. If they detect any suspicious activity, they will alert you so you can take action.
2. Credit freeze: Some identity protection services can help you freeze your credit, which makes it more difficult for thieves to open new accounts in your name.
3. Identity theft insurance: Some identity protection services offer insurance that can help you recover financially if you become a victim of identity theft.
4. Assistance: Many identity protection services offer 

In [17]:
import tqdm

# Truncation lengths (in tokens) at which the corpus size is measured.
trim_tok_values = [128, 256, 512, 1024, 2048]

# Running totals: untruncated, and one per truncation length.
total_num_toks = 0
total_num_toks_trimmed = {trim_tok: 0 for trim_tok in trim_tok_values}

for row in tqdm.tqdm(dataset):
    formatted_prompt = format_chat_prompt(row["conversation"], tokenizer, "Llama-3.2-1B-Instruct")
    # Only the length is needed, so tokenize to a plain list instead of
    # building a torch tensor per row, and measure it once.
    num_toks = len(tokenizer(formatted_prompt).input_ids)
    total_num_toks += num_toks
    for trim_tok in trim_tok_values:
        # A conversation contributes at most `trim_tok` tokens once truncated.
        total_num_toks_trimmed[trim_tok] += min(num_toks, trim_tok)


100%|██████████| 1000000/1000000 [17:08<00:00, 971.92it/s] 


In [18]:
# Untruncated total first, then one line per truncation length.
report_lines = [str(total_num_toks)]
report_lines += [
    f"Total num toks trimmed to {trim_tok}: {total_num_toks_trimmed[trim_tok]}"
    for trim_tok in trim_tok_values
]
print("\n".join(report_lines))


504863360
Total num toks trimmed to 128: 113921662
Total num toks trimmed to 256: 199597385
Total num toks trimmed to 512: 308676379
Total num toks trimmed to 1024: 395670423
Total num toks trimmed to 2048: 447395374
