# Tokenizer Sanity Check

## Tokenizer Choice

We'll use:
- **Mistral-7B tokenizer** for QLoRA training (main path)
- **DistilGPT-2 tokenizer** for CPU fallback

Each model has its own tokenizer. The tokenizer determines:
- How text is split into tokens
- Vocabulary size
- Special tokens (BOS, EOS, padding, etc.)

## Left Padding for Causal LM

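Causal language models (such as GPT-2 and Mistral) predict each token from the tokens to its left, so the side you pad on matters:
- **Right padding:** Standard for training. Pad tokens sit at the end of each sequence, are masked out by the attention mask, and should be excluded from the loss.
- **Left padding:** Typically used for batched generation, so that every prompt ends at the same position and newly generated tokens directly follow real text.

We'll use **right padding** for training. Set `tokenizer.padding_side = "right"` explicitly rather than relying on the default, which differs between tokenizers. Also check `tokenizer.pad_token`: if it is `None`, a common workaround is to reuse the EOS token as the padding token.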
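
To make the difference between the two padding sides concrete, here is a minimal pure-Python sketch. The helper `pad_batch` and the token IDs are made up for illustration; with a real Hugging Face tokenizer you would pass `padding=True` and let `tokenizer.padding_side` decide the side.

```python
from typing import List, Tuple


def pad_batch(
    sequences: List[List[int]], pad_id: int, side: str = "right"
) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad token-ID lists to the longest length in the batch.

    Hypothetical helper for illustration only; it mimics the shape of
    the 'input_ids' and 'attention_mask' a tokenizer would return.
    """
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")

    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        pad = [pad_id] * (max_len - len(seq))
        real_mask = [1] * len(seq)   # 1 = real token
        pad_mask = [0] * len(pad)    # 0 = padding, ignored by attention
        if side == "right":
            input_ids.append(seq + pad)
            attention_mask.append(real_mask + pad_mask)
        else:
            input_ids.append(pad + seq)
            attention_mask.append(pad_mask + real_mask)
    return input_ids, attention_mask


batch = [[5, 6, 7], [8, 9]]  # made-up token IDs
right_ids, right_mask = pad_batch(batch, pad_id=0, side="right")
left_ids, left_mask = pad_batch(batch, pad_id=0, side="left")

print("right:", right_ids, right_mask)
print("left: ", left_ids, left_mask)
```

With right padding the real tokens start at position 0 in every row; with left padding they all end at the last position, which is what batched generation needs.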

## Max Length Trade-offs

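- **Too short:** Truncates samples, so the model never sees the end of longer examples
- **Too long:** Wastes memory on padding and slows training

For Mistral-7B, 512 tokens is a reasonable starting point, but check the token-length distribution of your data first. If most samples are much shorter, a smaller limit saves memory; if many are longer, 512 will truncate them.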
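
One way to check the distribution is to count tokens per sample and see how many exceed the limit. The sketch below is an assumption about how you might do that, not part of the exercise: `length_report` is a hypothetical helper, and `whitespace_encode` is a stand-in so the example runs without downloading a model. With a real tokenizer you could pass `tokenizer.encode` instead.

```python
import statistics
from typing import Callable, Dict, List, Sequence


def length_report(
    texts: List[str], encode: Callable[[str], Sequence], seq_length: int
) -> Dict[str, float]:
    """Summarize token counts and the share of samples that would be truncated."""
    lengths = [len(encode(text)) for text in texts]
    too_long = sum(1 for n in lengths if n > seq_length)
    return {
        "mean": statistics.mean(lengths),
        "max": max(lengths),
        "truncated_fraction": too_long / len(lengths),
    }


def whitespace_encode(text: str) -> List[str]:
    """Stand-in tokenizer: one token per whitespace-separated word."""
    return text.split()


samples = ["a b c", "a b c d e f", "a"]  # toy data
report = length_report(samples, whitespace_encode, seq_length=4)
print(report)
```

If `truncated_fraction` is large, consider raising `seq_length`; if `max` is far below the limit, you can probably lower it.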


In [None]:
# === TODO (you code this) ===
# Load tokenizer (Mistral or DistilGPT2) and encode/decode a few samples.
# Hints:
#   - Use AutoTokenizer.from_pretrained()
#   - Encode a sample text, then decode it back
#   - Compare original vs decoded (should match for most cases)
#   - Decode with skip_special_tokens=True so BOS/EOS markers don't break the comparison
#   - Calculate average token length for a few samples
# Acceptance:
#   - round-trip decode matches expectations; reports avg token length

from transformers import AutoTokenizer

def tokenizer_roundtrip(base_model: str, seq_length: int):
    """
    Test tokenizer encoding/decoding and report statistics.
    
    Args:
        base_model: Model name (e.g., "mistralai/Mistral-7B-Instruct-v0.2")
        seq_length: Maximum sequence length
    """
    raise NotImplementedError

# Test Mistral tokenizer
tokenizer_roundtrip("mistralai/Mistral-7B-Instruct-v0.2", seq_length=512)


## Tokenization Function for Dataset

We need a function that tokenizes the entire dataset via `dataset.map()`. The `input_ids` and `attention_mask` columns it produces are what the training loop consumes.
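
Before writing the real function, it may help to see the contract a batched `map()` function follows: it receives a dict of lists (one list per column) and returns a dict of lists of the same length. The sketch below is illustrative only. `build_toy_tokenize_fn`, `toy_encode`, and the `"text"` column name are assumptions, and the toy encoder stands in for a real tokenizer so the example runs offline.

```python
from typing import Callable, Dict, List


def build_toy_tokenize_fn(
    encode: Callable[[str], List[int]], seq_length: int, text_column: str = "text"
) -> Callable[[Dict[str, List[str]]], Dict[str, List[List[int]]]]:
    """Return a batch function with the dict-of-lists shape map() expects."""

    def tokenize_batch(batch: Dict[str, List[str]]) -> Dict[str, List[List[int]]]:
        # Truncate each sample to at most seq_length tokens.
        input_ids = [encode(text)[:seq_length] for text in batch[text_column]]
        # No padding here, so every kept token is a real token.
        attention_mask = [[1] * len(ids) for ids in input_ids]
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    return tokenize_batch


def toy_encode(text: str) -> List[int]:
    """Toy encoder: each word's 'ID' is its length. Not a real vocabulary."""
    return [len(word) for word in text.split()]


toy_fn = build_toy_tokenize_fn(toy_encode, seq_length=3)
out = toy_fn({"text": ["one two three four", "hi"]})
print(out)
```

With a Hugging Face tokenizer, the body of `tokenize_batch` would typically reduce to a single call such as `tokenizer(batch["text"], truncation=True, max_length=seq_length)`, and the function would be applied with `dataset.map(fn, batched=True)`.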


In [None]:
# === TODO (you code this) ===
# Prepare map() function to tokenize Dataset with truncation and optional packing.
# Hints:
#   - Create a function that takes a batch dict
#   - Use tokenizer with truncation=True, max_length=seq_length
#   - Return dict with 'input_ids' and 'attention_mask'
#   - This function will be used with dataset.map()
# Acceptance:
#   - returns tokenized dataset with 'input_ids' and 'attention_mask'

def build_tokenize_fn(tokenizer, seq_length: int):
    """
    Build a tokenization function for dataset.map().
    
    Args:
        tokenizer: Pre-trained tokenizer
        seq_length: Maximum sequence length
        
    Returns:
        callable: Function that tokenizes a batch
    """
    raise NotImplementedError

# Test it
from datasets import load_dataset
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
tokenize_fn = build_tokenize_fn(tokenizer, seq_length=512)
print("Tokenization function ready!")
