# Tokenizer Sanity Check

## Tokenizer Choice

We'll use:
- **Mistral-7B tokenizer** for QLoRA training (main path)
- **DistilGPT-2 tokenizer** for CPU fallback

Each model has its own tokenizer. The tokenizer determines:
- How text is split into tokens
- Vocabulary size
- Special tokens (BOS, EOS, padding, etc.)

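## Padding Side for Causal LM

Causal language models (such as GPT-2 and Mistral) predict tokens left-to-right, so the padding side matters:
- **Right padding:** The usual choice for training. Padding positions are masked out of attention and should be excluded from the loss.
- **Left padding:** The usual choice for batched generation, so that every prompt ends at the last position and new tokens continue directly from real text.

We'll use **right padding** for training. The default padding side differs between tokenizers, so set `tokenizer.padding_side = "right"` explicitly rather than relying on the default.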
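
The difference between the two padding sides can be sketched in plain Python. This is a minimal illustration, not the tokenizer's own implementation: `pad_batch` is a hypothetical helper, and the token IDs and `pad_id=0` are made up.

```python
from typing import List, Tuple


def pad_batch(
    seqs: List[List[int]], pad_id: int, side: str = "right"
) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad every sequence to the longest one in the batch and build attention masks."""
    if side not in ("left", "right"):
        raise ValueError("side must be 'left' or 'right'")
    width = max(len(s) for s in seqs)
    input_ids, attention_mask = [], []
    for s in seqs:
        pad = [pad_id] * (width - len(s))   # filler token IDs
        mask_pad = [0] * len(pad)           # 0 = ignore this position
        real = [1] * len(s)                 # 1 = real token
        if side == "right":
            input_ids.append(s + pad)
            attention_mask.append(real + mask_pad)
        else:
            input_ids.append(pad + s)
            attention_mask.append(mask_pad + real)
    return input_ids, attention_mask


batch = [[5, 6, 7], [8, 9]]  # made-up token IDs
right_ids, right_mask = pad_batch(batch, pad_id=0, side="right")
left_ids, left_mask = pad_batch(batch, pad_id=0, side="left")
print("right:", right_ids, right_mask)
print("left: ", left_ids, left_mask)
```

With right padding the real tokens start at position 0 in every row; with left padding they all end at the last position, which is why left padding is convenient for generation.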

## Max Length Trade-offs

- **Too short:** Truncates context, loses information
- **Too long:** Wastes memory, slower training

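For this project, 512 tokens is a reasonable starting point for Mistral-7B. Check the token-length distribution of your data first: if most samples are much shorter or much longer than 512 tokens, adjust `seq_length` accordingly.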
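
One way to check the distribution is to count tokens per sample and look at a few percentiles. The sketch below is an assumption-laden illustration: `length_percentile`, `truncation_report`, and `whitespace_count` are hypothetical helpers, and the whitespace counter is only a stand-in so the example runs without downloading a model.

```python
import math
from typing import Callable, Dict, List


def length_percentile(lengths: List[int], pct: float) -> int:
    """Nearest-rank percentile of a list of token counts."""
    ordered = sorted(lengths)
    rank = max(1, math.ceil(pct / 100 * len(ordered)))
    return ordered[rank - 1]


def truncation_report(
    texts: List[str], count_tokens: Callable[[str], int], max_length: int
) -> Dict[str, float]:
    """Summarize token lengths and the share of samples that would be truncated."""
    lengths = [count_tokens(t) for t in texts]
    truncated = sum(1 for n in lengths if n > max_length)
    return {
        "p50": length_percentile(lengths, 50),
        "p95": length_percentile(lengths, 95),
        "max": max(lengths),
        "truncated_fraction": truncated / len(lengths),
    }


def whitespace_count(text: str) -> int:
    """Stand-in token counter (assumption): one token per whitespace-separated word."""
    return len(text.split())


texts = ["a b c", "a b c d e f", "a", "a b c d e f g h i j"]
report = truncation_report(texts, whitespace_count, max_length=6)
print(report)
```

For real measurements, replace `whitespace_count` with something like `lambda t: len(tokenizer(t)["input_ids"])`, since subword tokenizers usually produce more tokens than there are words.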


In [5]:
# === TODO (you code this) ===
# Load tokenizer (Mistral or DistilGPT2) and encode/decode a few samples.
# Hints:
#   - Use AutoTokenizer.from_pretrained()
#   - Encode a sample text, then decode it back
#   - Compare original vs decoded (should match for most cases)
#   - Calculate average token length for a few samples
# Acceptance:
#   - round-trip decode matches expectations; reports avg token length

from transformers import AutoTokenizer
import numpy as np

def tokenizer_roundtrip(base_model: str, seq_length: int, text: str):
    """
    Test tokenizer encoding/decoding and report statistics.
    
    Args:
        base_model: Model name (e.g., "mistralai/Mistral-7B-Instruct-v0.2")
        seq_length: Maximum sequence length
        text: Sample text to encode/decode
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    encoded = tokenizer.encode(text, add_special_tokens=True, truncation=True, max_length=seq_length)
    decoded = tokenizer.decode(encoded, skip_special_tokens=True)
    # `encoded` is a list of integer token IDs; convert_ids_to_tokens() maps them back to token strings.
    # These strings may include special tokens (e.g. BOS) and word-boundary markers, so the average below is approximate.
    tokens = tokenizer.convert_ids_to_tokens(encoded)
    token_lengths = [len(token) for token in tokens]
    print(f"Original: {text}")
    print(f"Decoded: {decoded}")
    print(f"Round-trip match: {decoded.strip() == text.strip()}")
    print(f"Token count: {len(encoded)}")
    print(f"Character count: {len(decoded)}")
    print(f"Average token length: {np.mean(token_lengths):.2f}")


# Test Mistral tokenizer
sample_text = "It was on a dreary night of November that I beheld the accomplishment of my toils."
tokenizer_roundtrip("mistralai/Mistral-7B-Instruct-v0.2", seq_length=512, text=sample_text)
# More challenging text with longer tokens
sample_text_2 = """You will rejoice to hear that no disaster has accompanied the commencement of an enterprise which you have regarded with such evil forebodings. I arrived here yesterday, and my first task is to assure my dear sister of my welfare and increasing confidence in the success of my undertaking. I am already far north of London, and as I walk in the streets of Petersburgh, I feel a cold northern breeze play upon my cheeks, which braces my nerves and fills me with delight. Do you understand this feeling? This breeze, which has travelled from the regions towards which I am advancing, gives me a foretaste of those icy climes. Inspirited by this wind of promise, my daydreams become more fervent and vivid. I try in vain to be persuaded that the pole is the seat of frost and desolation; it ever presents itself to my imagination as the region of beauty and delight. There, Margaret, the sun is for ever visible, its broad disk just skirting the horizon and diffusing a perpetual splendour. There—for with your leave, my sister, I will put some trust in preceding navigators—there snow and frost are banished; and, sailing over a calm sea, we may be wafted to a land surpassing in wonders and in beauty every region hitherto discovered on the habitable globe. Its productions and features may be without example, as the phenomena of the heavenly bodies undoubtedly are in those undiscovered solitudes. What may not be expected in a country of eternal light? I may there discover the wondrous power which attracts the needle and may regulate a thousand celestial observations that require only this voyage to render their seeming eccentricities consistent for ever. I shall satiate my ardent curiosity with the sight of a part of the world never before visited, and may tread a land never before imprinted by the foot of man. 
These are my enticements, and they are sufficient to conquer all fear of danger or death and to induce me to commence this laborious voyage with the joy a child feels when he embarks in a little boat, with his holiday mates, on an expedition of discovery up his native river."""
tokenizer_roundtrip("mistralai/Mistral-7B-Instruct-v0.2", seq_length=512, text=sample_text_2)


Original: It was on a dreary night of November that I beheld the accomplishment of my toils.
Decoded: It was on a dreary night of November that I beheld the accomplishment of my toils.
Round-trip match: True
Token count: 22
Character count: 82
Average token length: 3.91
Original: You will rejoice to hear that no disaster has accompanied the commencement of an enterprise which you have regarded with such evil forebodings. I arrived here yesterday, and my first task is to assure my dear sister of my welfare and increasing confidence in the success of my undertaking. I am already far north of London, and as I walk in the streets of Petersburgh, I feel a cold northern breeze play upon my cheeks, which braces my nerves and fills me with delight. Do you understand this feeling? This breeze, which has travelled from the regions towards which I am advancing, gives me a foretaste of those icy climes. Inspirited by this wind of promise, my daydreams become more fervent and vivid. I try in vain t

## Tokenization Function for Dataset

We need a function that tokenizes the entire dataset. This will be used in training.

### Why do we need this function?

When training a language model, we can't feed raw text directly. The model expects:
- **Token IDs** (numbers representing words/subwords)
- **Attention masks** (telling the model which tokens are real vs padding)

### How `dataset.map()` works:

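The Hugging Face `datasets` library has a `.map()` method that applies a function to every example in the dataset, one example at a time by default or in batches when called with `batched=True`. The columns returned by the function are added to the dataset.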

**Example:**
```python
# Your dataset has examples like: {"text": "It was a dark night..."}
# After tokenization, you want: {"input_ids": [1234, 567, 890, ...], "attention_mask": [1, 1, 1, ...]}

tokenized_dataset = dataset.map(tokenize_function)
```
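
To make the merging behavior concrete, here is a plain-Python imitation of a non-batched `map()` call. It is a simplified stand-in, not the `datasets` implementation: `toy_map` and `toy_tokenize` are hypothetical names, and the "token IDs" are simply word lengths.

```python
from typing import Callable, Dict, List


def toy_tokenize(example: Dict[str, str]) -> Dict[str, List[int]]:
    """Stand-in tokenizer (assumption): one made-up ID per whitespace-separated word."""
    ids = [len(word) for word in example["text"].split()]
    return {"input_ids": ids, "attention_mask": [1] * len(ids)}


def toy_map(rows: List[dict], fn: Callable[[dict], dict]) -> List[dict]:
    """Imitate non-batched map(): call fn on each example and merge the new columns in."""
    return [{**row, **fn(row)} for row in rows]


rows = [{"text": "It was a dark night"}, {"text": "Do you understand"}]
mapped = toy_map(rows, toy_tokenize)
for row in mapped:
    print(row)
```

Note that the function itself returns only the new fields; the original `text` column survives because the mapping step merges the two dicts.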

### What the function should do:

1. **Take one example** from the dataset (or a batch, when `map()` is called with `batched=True`)
2. **Extract the text** from the 'text' column
3. **Tokenize it** → convert text to token IDs
4. **Handle truncation** → if the text exceeds `seq_length` (512) tokens, cut it off
5. **Create attention mask** → the tokenizer does this automatically, marking real tokens (1) vs padding (0)
6. **Return** a dict with 'input_ids' and 'attention_mask'

### Key parameters:
- `truncation=True`: Cut off text if too long
- `max_length=seq_length`: Maximum tokens (512)
- `padding=False`: We'll pad later during batching (more efficient)
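
To see why deferring padding is more efficient, compare how many token positions are processed when every sequence is padded to a fixed 512 with padding each batch only to its own longest sequence. This is a rough sketch: `padded_token_count` is a hypothetical helper and the sequence lengths are invented.

```python
from typing import List, Optional


def padded_token_count(
    lengths: List[int], batch_size: int, pad_to: Optional[int] = None
) -> int:
    """Total positions (real + padding) when batching sequences in the given order.

    pad_to=None pads each batch to its own longest sequence (dynamic padding);
    an integer pads every sequence to that fixed length.
    """
    total = 0
    for start in range(0, len(lengths), batch_size):
        batch = lengths[start:start + batch_size]
        width = pad_to if pad_to is not None else max(batch)
        total += width * len(batch)
    return total


lengths = [11, 22, 40, 512, 64, 30]  # invented token counts per example
real = sum(lengths)
fixed = padded_token_count(lengths, batch_size=2, pad_to=512)
dynamic = padded_token_count(lengths, batch_size=2)
print(f"real tokens: {real}, fixed padding: {fixed}, dynamic padding: {dynamic}")
```

In the Hugging Face stack this per-batch padding is usually handled by a data collator at batching time rather than by hand.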


### Visual Example

Here's what happens step by step:

```python
# Step 1: Your dataset has raw text
example = {"text": "It was a dark night"}

# Step 2: Tokenizer converts text to numbers
tokenizer("It was a dark night", truncation=True, max_length=512)
# Returns:
# {
#   'input_ids': [1234, 567, 890, 123, 456],  # Token IDs
#   'attention_mask': [1, 1, 1, 1, 1]          # All 1s (no padding yet)
# }

# Step 3: dataset.map() applies this to every example
# Result: Each example now has 'input_ids' and 'attention_mask'
```

**Why `attention_mask`?**
- `1` = real token (the model should pay attention)
- `0` = padding token (ignore this, it's just filler)
- Later, when batching, shorter sequences are padded: `input_ids` with the tokenizer's pad token ID and `attention_mask` with 0s


In [6]:
# === TODO (you code this) ===
# Prepare map() function to tokenize Dataset with truncation and optional packing.
# Hints:
#   - Create a function that takes a batch dict (or single example)
#   - Extract 'text' from the batch
#   - Use tokenizer() with truncation=True, max_length=seq_length, padding=False
#   - Return dict with 'input_ids' and 'attention_mask'
#   - This function will be used with dataset.map()
# Acceptance:
#   - returns tokenized dataset with 'input_ids' and 'attention_mask'

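# Example of what happens for one example under dataset.map():
# Function receives: {"text": "It was a dark night..."}
# Function returns:  {"input_ids": [1234, 567, 890, ...], "attention_mask": [1, 1, 1, ...]}
# Mapped example:    {"text": "It was a dark night...", "input_ids": [1234, 567, 890, ...], "attention_mask": [1, 1, 1, ...]}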

def build_tokenize_fn(tokenizer, seq_length: int):
    """
    Build a tokenization function for dataset.map().
    
    The returned function is called by dataset.map() for each example (or batch).
    It receives a dict with 'text' and returns 'input_ids' and 'attention_mask';
    dataset.map() then adds these to the dataset as new columns.
    
    Args:
        tokenizer: Pre-trained tokenizer
        seq_length: Maximum sequence length
        
    Returns:
        callable: Function that tokenizes a single example or batch
    """
    def tokenize_function(examples):
        """
        This inner function is what dataset.map() will call.
        
        Args:
            examples: Dict with 'text' key (can be single string or list of strings)
            
        Returns:
            Dict-like tokenizer output containing 'input_ids' and 'attention_mask'
        """
        # The tokenizer call already returns a dict-like object with
        # 'input_ids' and 'attention_mask', so it can be returned as-is.
        # padding=False leaves sequences unpadded; padding is applied later, per batch.
        result = tokenizer(
            examples['text'], 
            truncation=True, 
            max_length=seq_length, 
            padding=False
        )
        return result
    
    
    
    
    # Return the tokenization function
    return tokenize_function

# Test it
from datasets import load_dataset
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
tokenize_fn = build_tokenize_fn(tokenizer, seq_length=512)

# Test with a simple example
test_example = {"text": "It was on a dreary night of November that"}
result = tokenize_fn(test_example)
print("Test result keys:", list(result.keys()))
print("Input IDs shape:", len(result.get('input_ids', [])))
print("Attention mask shape:", len(result.get('attention_mask', [])))


Test result keys: ['input_ids', 'attention_mask']
Input IDs shape: 11
Attention mask shape: 11
