# Test: tokenizer module

This notebook tests the `ChatTemplateTokenizer` class in `mech_interp_toolkit.tokenizer`.

In [None]:
# Make the sibling `src` directory importable so `mech_interp_toolkit` resolves
# without installing the package. `os.path.abspath` resolves "../src" against the
# current working directory — presumably the notebook's own folder; confirm if
# the notebook is launched from elsewhere.
import sys
import os
sys.path.append(os.path.abspath("../src"))

import torch
from transformers import AutoTokenizer
from mech_interp_toolkit.tokenizer import ChatTemplateTokenizer

## Setup: Load base tokenizer

In [None]:
# Checkpoint whose tokenizer is wrapped by the tests in this notebook.
model_name = "Qwen/Qwen3-0.6B"
base_tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=True,
)
print(f"Loaded base tokenizer for {model_name}")

## Test: ChatTemplateTokenizer initialization

In [None]:
# Build the wrapper with no overrides, show its defaults, and check that the
# wrapped tokenizer exists and has a pad token.
chat_tokenizer = ChatTemplateTokenizer(base_tokenizer)

print(f"System prompt: {chat_tokenizer.system_prompt}")
print(f"Suffix: '{chat_tokenizer.suffix}'")

wrapped_tokenizer = chat_tokenizer.tokenizer
assert wrapped_tokenizer is not None, "Tokenizer should be set"

pad_token = wrapped_tokenizer.pad_token
assert pad_token is not None, "Pad token should be set"
print(f"Pad token: {pad_token}")
print("PASSED: Default initialization")

In [None]:
# Supply both optional constructor arguments and check each is stored unchanged.
custom_system_prompt = "You are a helpful assistant."
custom_suffix = "\n\nThink step by step:"

constructor_overrides = {
    "system_prompt": custom_system_prompt,
    "suffix": custom_suffix,
}
custom_tokenizer = ChatTemplateTokenizer(base_tokenizer, **constructor_overrides)

assert custom_tokenizer.system_prompt == custom_system_prompt, "Custom system prompt not set"
assert custom_tokenizer.suffix == custom_suffix, "Custom suffix not set"

print(f"Custom system prompt: {custom_tokenizer.system_prompt}")
print(f"Custom suffix: '{custom_tokenizer.suffix}'")
print("PASSED: Custom initialization")

## Test: _apply_chat_template()

In [None]:
# Pass a bare string (not a list) and check it comes back as a one-element list
# whose text contains both the system prompt and the user prompt.
single_prompt = "What is 2 + 2?"
formatted = chat_tokenizer._apply_chat_template(single_prompt, thinking=False)

rendered = formatted[0]
print("Single prompt formatted:")
print(rendered[:200] + "...")
print()

assert isinstance(formatted, list), "Should return a list"
assert len(formatted) == 1, "Should have one formatted prompt"

expected_fragments = (
    (chat_tokenizer.system_prompt, "System prompt should be in formatted text"),
    (single_prompt, "User prompt should be in formatted text"),
)
for fragment, failure_message in expected_fragments:
    assert fragment in rendered, failure_message
print("PASSED: Single string prompt")

In [None]:
# Pass a list of prompts and check that one formatted string comes back per
# prompt, each containing its original text.
multiple_prompts = [
    "What is the capital of France?",
    "What is 5 * 7?",
    "Explain gravity.",
]

formatted = chat_tokenizer._apply_chat_template(multiple_prompts, thinking=False)

formatted_count = len(formatted)
print(f"Number of formatted prompts: {formatted_count}")
assert formatted_count == 3, "Should have three formatted prompts"

prompt_pairs = list(zip(multiple_prompts, formatted))
for index, (question, rendered) in enumerate(prompt_pairs):
    assert question in rendered, f"Original prompt {index} should be in formatted text"
    print(f"Prompt {index}: '{question}' -> {len(rendered)} chars")

print("PASSED: List of prompts")

In [None]:
# Test with thinking=True
# Render the same prompt with the thinking flag on and off. Whether the two
# texts differ depends on the model's chat template, so only properties that
# must hold for both renderings are asserted. Previously this cell printed
# "PASSED" without checking anything.
prompt = "Solve this puzzle"
formatted_thinking = chat_tokenizer._apply_chat_template(prompt, thinking=True)
formatted_no_thinking = chat_tokenizer._apply_chat_template(prompt, thinking=False)

for label, formatted_variant in (
    ("thinking=True", formatted_thinking),
    ("thinking=False", formatted_no_thinking),
):
    assert isinstance(formatted_variant, list), f"{label}: should return a list"
    assert len(formatted_variant) == 1, f"{label}: should have one formatted prompt"
    assert prompt in formatted_variant[0], f"{label}: user prompt should be in formatted text"

print(f"With thinking=True: {len(formatted_thinking[0])} chars")
print(f"With thinking=False: {len(formatted_no_thinking[0])} chars")
print("PASSED: thinking parameter")

## Test: _encode()

In [None]:
# Encode one templated prompt and check the result exposes tensor-valued
# `input_ids` and `attention_mask` entries.
formatted = chat_tokenizer._apply_chat_template("Hello, world!", thinking=False)
encoded = chat_tokenizer._encode(formatted)

print(f"Encoded keys: {encoded.keys()}")
input_ids = encoded["input_ids"]
print(f"input_ids shape: {input_ids.shape}")
attention_mask = encoded["attention_mask"]
print(f"attention_mask shape: {attention_mask.shape}")

required_entries = (
    ("input_ids", "Should have input_ids"),
    ("attention_mask", "Should have attention_mask"),
)
for key, failure_message in required_entries:
    assert key in encoded, failure_message

assert isinstance(input_ids, torch.Tensor), "input_ids should be a tensor"
assert isinstance(attention_mask, torch.Tensor), "attention_mask should be a tensor"
print("PASSED: Single prompt encoding")

In [None]:
# Test encoding multiple prompts with padding
# Two prompts of clearly different length are encoded together, so the shorter
# one must be padded for the batch to form a single rectangular tensor.
prompts = [
    "Short",
    "This is a much longer prompt that should require more tokens"
]
formatted = chat_tokenizer._apply_chat_template(prompts, thinking=False)
encoded = chat_tokenizer._encode(formatted)

print(f"Batch input_ids shape: {encoded['input_ids'].shape}")
print(f"Batch attention_mask shape: {encoded['attention_mask'].shape}")

batch_size, seq_len = encoded["input_ids"].shape
assert batch_size == 2, "Batch size should be 2"
assert encoded["input_ids"].shape == encoded["attention_mask"].shape, "Shapes should match"

# Check padding. The number of attended (mask == 1) positions per row is
# compared rather than specific positions, so the check holds whichever side
# the tokenizer pads on.
# NOTE(review): the original comment says left padding is the default — confirm
# against ChatTemplateTokenizer before asserting a specific side here.
attention_mask = encoded["attention_mask"]
attended_counts = attention_mask.sum(dim=1).tolist()
assert attended_counts[0] < seq_len, "Shorter prompt should contain padding"
assert attended_counts[0] < attended_counts[1], "Shorter prompt should have fewer attended tokens"

print(f"First sequence attention mask: {attention_mask[0, :10].tolist()}...")
print(f"Second sequence attention mask: {attention_mask[1, :10].tolist()}...")
print("PASSED: Batch encoding with padding")

## Test: `__call__()` (main interface)

In [None]:
# Invoke the wrapper directly on one prompt string and check the returned
# encoding has both expected entries and a 2D `input_ids` tensor.
result = chat_tokenizer("What is machine learning?", thinking=False)

print(f"Result type: {type(result)}")
print(f"Keys: {result.keys()}")
input_ids = result["input_ids"]
print(f"input_ids shape: {input_ids.shape}")
attention_mask = result["attention_mask"]
print(f"attention_mask shape: {attention_mask.shape}")

required_entries = (
    ("input_ids", "Should have input_ids"),
    ("attention_mask", "Should have attention_mask"),
)
for key, failure_message in required_entries:
    assert key in result, failure_message

assert input_ids.dim() == 2, "Should be 2D tensor (batch, seq_len)"
print("PASSED: Single prompt call")

In [None]:
# Invoke the wrapper on a batch of four prompts; the leading dimension of
# `input_ids` should equal the number of prompts.
prompts = [
    "Define artificial intelligence.",
    "What is deep learning?",
    "Explain neural networks.",
    "What is NLP?",
]
expected_batch_size = len(prompts)

result = chat_tokenizer(prompts, thinking=False)

ids_shape = result["input_ids"].shape
print(f"Batch size: {ids_shape[0]}")
print(f"Sequence length: {ids_shape[1]}")

assert ids_shape[0] == expected_batch_size, "Batch size should be 4"
print("PASSED: Multiple prompts call")

In [None]:
# Verify structured_prompt is stored
# Earlier cells have already called `chat_tokenizer`, so a bare "is not None"
# check would pass even if this call stored nothing. Also check that the stored
# value reflects the prompt passed here.
probe_prompt = "Test prompt"
chat_tokenizer(probe_prompt, thinking=False)

assert chat_tokenizer.structured_prompt is not None, "structured_prompt should be stored"
# str() keeps the containment check valid whether the first entry is the
# formatted text itself or a richer structure — presumably a string, given the
# slicing below; confirm against ChatTemplateTokenizer.
assert probe_prompt in str(chat_tokenizer.structured_prompt[0]), (
    "structured_prompt should reflect the most recent call"
)
print(f"Stored structured prompt (truncated): {chat_tokenizer.structured_prompt[0][:100]}...")
print("PASSED: structured_prompt storage")

## Test: Suffix functionality

In [None]:
# Configure a suffix and check it is the final text of the formatted prompt.
suffix = "\n\nAnswer:"
tokenizer_with_suffix = ChatTemplateTokenizer(base_tokenizer, suffix=suffix)

formatted = tokenizer_with_suffix._apply_chat_template("Question?", thinking=False)
rendered = formatted[0]

assert rendered.endswith(suffix), f"Formatted prompt should end with suffix. Got: ...{rendered[-50:]}"
print(f"Formatted ends with suffix: {rendered[-30:]}")
print("PASSED: Suffix functionality")

## Summary

In [None]:
# Closing banner; only reached if every assertion above passed on a full run.
banner_rule = "=" * 50
print(banner_rule)
print("All tokenizer module tests PASSED!")
print(banner_rule)