#  Data Preparation & Tokenization

Load the GSM8K dataset, format each example with `<reasoning>` and `<answer>` tags, and pre-tokenize the result for training.

**Time estimate:** ~30-40 minutes

In [None]:
import json
import os
from datasets import load_dataset
from transformers import AutoTokenizer

# Hugging Face model id whose tokenizer is used for all pre-tokenization below.
MODEL_NAME = "google/gemma-3-1b-it"
# Loaded once at module level; tokenize_example() and the verification cell reuse it.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

## 1. Load GSM8K Dataset

In [None]:
# Load the "main" configuration of GSM8K from Hugging Face; the 'train' and
# 'test' splits are the ones used in the rest of this notebook.
# NOTE(review): the dataset may be published under a namespaced id on the Hub —
# confirm the bare "gsm8k" id still resolves with the installed `datasets` version.
gsm8k = load_dataset("gsm8k", "main")
print(f"GSM8K train: {len(gsm8k['train'])} examples")
print(f"GSM8K test: {len(gsm8k['test'])} examples")

# Show the first training example, previewing at most 200 characters per field
# (the trailing "..." is printed even when the field is shorter than that).
print("\n Sample:")
sample = gsm8k['train'][0]
print(f"Question: {sample['question'][:200]}...")
print(f"Answer: {sample['answer'][:200]}...")

## 2. Format Data with Reasoning Tags

In [None]:
def format_gsm8k_example(example):
    """Convert a GSM8K record to the <reasoning>/<answer> training format.

    Args:
        example: Mapping with string fields 'question' and 'answer'.

    Returns:
        A dict with three keys:
          'text': the question followed by the reasoning wrapped in
              <reasoning> tags and the final answer wrapped in <answer> tags.
          'reference_answer': the extracted final answer string.
          'domain': always 'math'.
    """
    question = example['question']
    full_answer = example['answer']

    # GSM8K uses #### to separate reasoning from final answer. Split on the
    # LAST occurrence so any earlier '####' stays inside the reasoning.
    reasoning, marker, final_answer = full_answer.rpartition('####')
    if marker:
        reasoning = reasoning.strip()
        final_answer = final_answer.strip()
    else:
        # No marker: keep the whole text as reasoning and use its last line as
        # the answer. Strip first so a trailing newline cannot produce an empty
        # answer, and so the reasoning is trimmed like in the marker branch.
        reasoning = full_answer.strip()
        answer_lines = reasoning.splitlines()
        final_answer = answer_lines[-1].strip() if answer_lines else ''

    formatted = (
        f"Q: {question}\n"
        "A:\n"
        f"<reasoning>{reasoning}</reasoning>\n"
        f"<answer>{final_answer}</answer>"
    )

    return {
        'text': formatted,
        'reference_answer': final_answer,
        'domain': 'math'
    }

# Smoke-test the formatter on the first training example and preview at most
# the first 500 characters of the formatted text.
formatted_sample = format_gsm8k_example(gsm8k['train'][0])
print(" Formatted example:")
print(formatted_sample['text'][:500])

In [None]:
# Run the formatter over every example of both GSM8K splits
train_formatted = list(map(format_gsm8k_example, gsm8k['train']))
test_formatted = list(map(format_gsm8k_example, gsm8k['test']))

print(f" Formatted {len(train_formatted)} train, {len(test_formatted)} test")

## 3. Create Train/Val Split

In [None]:
import random

# Fixed seed so the shuffle, and therefore the split, is reproducible.
random.seed(42)
val_size = 200  # Small validation set

# Shuffle a copy so train_formatted itself keeps its original order.
shuffled = list(train_formatted)
random.shuffle(shuffled)

# The first val_size shuffled examples are held out for validation;
# everything after them becomes the training set.
val_data, train_data = shuffled[:val_size], shuffled[val_size:]

print(f"Train: {len(train_data)} examples")
print(f"Validation: {len(val_data)} examples")

## 4. Save Prepared Data

In [None]:
# Create the output directory for the formatted (text) splits; exist_ok makes
# re-running this cell safe.
os.makedirs('data/prepared', exist_ok=True)

# Save as JSONL
def save_jsonl(data, path):
    """Write every item of `data` to `path` as one JSON document per line.

    The file is opened in write mode (existing content is replaced) with UTF-8
    encoding, and non-ASCII characters are written as-is rather than escaped.
    """
    with open(path, 'w', encoding='utf-8') as out_file:
        out_file.writelines(
            json.dumps(record, ensure_ascii=False) + '\n' for record in data
        )

# Persist the three formatted splits. The validation split is written as
# "valid.jsonl"; the test split is the formatted GSM8K test set, untouched by
# the train/validation shuffle.
save_jsonl(train_data, 'data/prepared/train.jsonl')
save_jsonl(val_data, 'data/prepared/valid.jsonl')
save_jsonl(test_formatted, 'data/prepared/test.jsonl')

# Report what was written and how many examples each file holds.
print(" Saved:")
print(f"  - data/prepared/train.jsonl ({len(train_data)} examples)")
print(f"  - data/prepared/valid.jsonl ({len(val_data)} examples)")
print(f"  - data/prepared/test.jsonl ({len(test_formatted)} examples)")

## 5. Tokenize Data

In [None]:
MAX_LENGTH = 1024  # Max tokens

def tokenize_example(example):
    """Tokenize one formatted example with the module-level tokenizer.

    Args:
        example: dict with a 'text' field. 'reference_answer' and 'domain'
            are optional and fall back to '' and 'general' respectively.

    Returns:
        dict holding the tokenizer's 'input_ids' and 'attention_mask' together
        with the carried-over 'reference_answer' and 'domain'.
    """
    # Text longer than MAX_LENGTH tokens is truncated. return_tensors=None
    # presumably keeps the output as plain Python lists, which is what
    # save_jsonl needs to serialize them — confirm against the tokenizer docs.
    encoded = tokenizer(
        example['text'],
        max_length=MAX_LENGTH,
        truncation=True,
        return_tensors=None,
    )
    record = {key: encoded[key] for key in ('input_ids', 'attention_mask')}
    record['reference_answer'] = example.get('reference_answer', '')
    record['domain'] = example.get('domain', 'general')
    return record

# Tokenize the train and validation splits (the test split is not tokenized here).
train_tokenized = [tokenize_example(ex) for ex in train_data]
val_tokenized = [tokenize_example(ex) for ex in val_data]

print(f" Tokenized {len(train_tokenized)} train, {len(val_tokenized)} val")

# Token-length statistics over the training split only. Because
# tokenize_example truncates at MAX_LENGTH, a reported Max equal to MAX_LENGTH
# suggests that some examples were cut off.
lengths = [len(ex['input_ids']) for ex in train_tokenized]
print(f"\nToken length stats:")
print(f"  Min: {min(lengths)}, Max: {max(lengths)}, Avg: {sum(lengths)/len(lengths):.1f}")

In [None]:
# Save tokenized data next to the text versions, one JSON object per line.
# exist_ok makes re-running this cell safe.
os.makedirs('data/tokenized', exist_ok=True)

# Only the train and validation splits were tokenized, so only those are written.
save_jsonl(train_tokenized, 'data/tokenized/train.jsonl')
save_jsonl(val_tokenized, 'data/tokenized/valid.jsonl')

print(" Saved tokenized data to data/tokenized/")

## 6. Verify Data

In [None]:
# Verify we can decode back: round-trip the first tokenized training example
# through the tokenizer and preview at most 800 characters of the result.
sample = train_tokenized[0]
decoded = tokenizer.decode(sample['input_ids'])

print(" Sample decoded:")
print(decoded[:800])

# Check format compliance: both the opening and closing tag of each pair must
# appear in the decoded text. Only this single example is checked.
has_reasoning = '<reasoning>' in decoded and '</reasoning>' in decoded
has_answer = '<answer>' in decoded and '</answer>' in decoded
print(f"\n Has <reasoning> tags: {has_reasoning}")
print(f" Has <answer> tags: {has_answer}")

In [None]:
# Final summary banner: counts of the tokenized train/validation examples and a
# pointer to the next notebook in the sequence.
print("\n" + "="*50)
print("DATA PREPARATION COMPLETE")
print("="*50)
print(f"Train examples: {len(train_tokenized)}")
print(f"Val examples: {len(val_tokenized)}")
print("\n Proceed to: 02_sft_training.ipynb")