# Phase 2.1: Base Pre-Training for Mathematical Reasoning

This notebook demonstrates the complete pre-training infrastructure for training a decoder-only transformer on mixed mathematical and general text corpora.

## 🚀 Quick Start

**For GPU Training (Recommended):**
1. Open this notebook in [Google Colab](https://colab.research.google.com)
2. Go to Runtime → Change runtime type → Select GPU (T4 or better)
3. Run all cells

**For CPU Testing (Local):**
- Just run all cells (the notebook automatically switches to a smaller model and fewer training steps)

## 📦 What's Included

- Streaming dataset for large-scale corpora
- Mixed-domain sampling (ArXiv + General text)
- Distributed training support (DDP)
- Mixed precision (fp16/bf16)
- Gradient accumulation
- Learning rate scheduling
- Automatic checkpointing
- TensorBoard logging

## 1. Setup and Installation

In [1]:
# Check if running on Colab
try:
    import google.colab
    IN_COLAB = True
    print("✓ Running on Google Colab")
except ImportError:
    IN_COLAB = False
    print("✓ Running locally")

# Clone repository if on Colab
if IN_COLAB:
    print("\nCloning repository...")
    !git clone https://github.com/Alpyaman/AI-Mathematical-Olympiad.git
    %cd AI-Mathematical-Olympiad
    print("✓ Repository cloned")

✓ Running on Google Colab

Cloning repository...
Cloning into 'AI-Mathematical-Olympiad'...
remote: Enumerating objects: 155, done.
remote: Counting objects: 100% (155/155), done.
remote: Compressing objects: 100% (124/124), done.
remote: Total 155 (delta 43), reused 140 (delta 29), pack-reused 0 (from 0)
Receiving objects: 100% (155/155), 11.19 MiB | 17.43 MiB/s, done.
Resolving deltas: 100% (43/43), done.
/content/AI-Mathematical-Olympiad
✓ Repository cloned


In [2]:
# Install dependencies
print("Installing dependencies...")
!pip install -q torch numpy tqdm datasets

# Optional: Install TensorBoard for logging
!pip install -q tensorboard

print("✓ Dependencies installed")

Installing dependencies...
✓ Dependencies installed


In [3]:
# Import required libraries
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os
import json
from pathlib import Path

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

PyTorch version: 2.9.0+cu126
CUDA available: True
CUDA device: NVIDIA L4
CUDA memory: 23.80 GB


## 2. Import Phase 2.1 Components

In [6]:
# Import model and tokenizer
from src import (
    get_small_config,
    get_base_config,
    MathTransformerDecoder,
    MathTokenizer,
)

# Import training infrastructure
from src.training.pretrainer import PreTrainer, PreTrainingConfig

# Import data utilities
from src.data.pretraining_dataset import (
    create_sample_pretraining_data,
    prepare_pretraining_data,
    PreTrainingDataCollator,
)

print("✓ All Phase 2.1 components imported successfully")

✓ All Phase 2.1 components imported successfully


## 3. Configure Training

Settings are chosen automatically based on the available hardware (GPU or CPU).

In [7]:
# Detect hardware and configure accordingly
device = "cuda" if torch.cuda.is_available() else "cpu"
USE_GPU = device == "cuda"

if USE_GPU:
    # GPU Configuration - Faster training
    print("🚀 GPU Training Configuration")
    MODEL_SIZE = "small"  # Can use "base" for larger GPUs
    BATCH_SIZE = 4
    GRAD_ACCUM_STEPS = 8
    MAX_STEPS = 500  # Increase to 10000+ for real training
    MIXED_PRECISION = "bf16" if torch.cuda.is_bf16_supported() else "fp16"
    NUM_WORKERS = 2
else:
    # CPU Configuration - Slower but still works
    print("💻 CPU Training Configuration (Demo Mode)")
    MODEL_SIZE = "small"
    BATCH_SIZE = 1
    GRAD_ACCUM_STEPS = 2
    MAX_STEPS = 20  # Very short for CPU demo
    MIXED_PRECISION = "fp32"
    NUM_WORKERS = 0

print(f"\nTraining Configuration:")
print(f"  Device: {device}")
print(f"  Model size: {MODEL_SIZE}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Gradient accumulation: {GRAD_ACCUM_STEPS}")
print(f"  Effective batch size: {BATCH_SIZE * GRAD_ACCUM_STEPS}")
print(f"  Max steps: {MAX_STEPS}")
print(f"  Mixed precision: {MIXED_PRECISION}")

🚀 GPU Training Configuration

Training Configuration:
  Device: cuda
  Model size: small
  Batch size: 4
  Gradient accumulation: 8
  Effective batch size: 32
  Max steps: 500
  Mixed precision: bf16
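Gradient accumulation lets a small micro-batch behave like a larger one: gradients from `GRAD_ACCUM_STEPS` micro-batches are summed before each optimizer update. A small stdlib sketch of the resulting batch arithmetic, assuming every sequence is packed to the 512-token length used later in this notebook (`batch_budget` is a hypothetical helper, not part of the repository):

```python
def batch_budget(micro_batch_size: int, grad_accum_steps: int, seq_len: int) -> dict:
    """Per-update batch arithmetic for gradient accumulation (hypothetical helper).

    Assumes every sequence is exactly `seq_len` tokens long (no padding).
    """
    sequences_per_update = micro_batch_size * grad_accum_steps
    return {
        "sequences_per_update": sequences_per_update,
        "tokens_per_micro_step": micro_batch_size * seq_len,
        "tokens_per_update": sequences_per_update * seq_len,
    }


# GPU settings from the cell above, with the 512-token sequences used later
print(batch_budget(micro_batch_size=4, grad_accum_steps=8, seq_len=512))
# → {'sequences_per_update': 32, 'tokens_per_micro_step': 2048, 'tokens_per_update': 16384}
```

So each optimizer update on the GPU configuration sees about 16K tokens, while only 2,048 tokens need to fit in memory at once.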


## 4. Prepare Pre-Training Data

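We download a small sample of real data for demonstration: up to 10,000 OpenWebMath documents (math) and 10,000 C4 documents (general text), keeping only documents longer than 500 characters. The downloads use the Hugging Face `datasets` library in streaming mode, so the full corpora are never stored locally.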

In [8]:
# Step 4: download a small subset of real math and general text
import os
import json
from datasets import load_dataset
from tqdm import tqdm
from pathlib import Path

# 1. Output directory
data_dir = "./data/pretraining"
print(f"🚀 Preparing REAL data in: {data_dir}")

# 2. Setup Directories
math_dir = Path(data_dir) / "arxiv" # Keep folder name 'arxiv' for compatibility
general_dir = Path(data_dir) / "general"
os.makedirs(math_dir, exist_ok=True)
os.makedirs(general_dir, exist_ok=True)

# 3. Download MATH Data (Source: OpenWebMath)
# This is a high-quality dataset of math webpages
print("📚 Downloading Math data (OpenWebMath)...")
try:
    ds_math = load_dataset("open-web-math/open-web-math", split="train", streaming=True)

    with open(math_dir / "math_subset.jsonl", "w", encoding="utf-8") as f:
        count = 0
        # tqdm counts rows scanned; only rows longer than 500 characters are saved
        for row in tqdm(ds_math, desc="Saving Math", total=10000):
            text = row.get('text', '')
            if len(text) > 500:
                json.dump({"text": text}, f)
                f.write("\n")
                count += 1
            if count >= 10000:
                break
    print("   ✅ Math data downloaded successfully.")
except Exception as e:
    print(f"   ⚠️ Math download failed: {e}")

# 4. Download General Data (C4)
print("🌍 Downloading General (C4) data...")
try:
    ds_general = load_dataset("allenai/c4", "en", split="train", streaming=True)

    with open(general_dir / "c4_subset.jsonl", "w", encoding="utf-8") as f:
        count = 0
        for row in tqdm(ds_general, desc="Saving General", total=10000):
            if len(row['text']) > 500:
                json.dump({"text": row['text']}, f)
                f.write("\n")
                count += 1
            if count >= 10000:
                break
    print("   ✅ General data downloaded successfully.")
except Exception as e:
    print(f"   ⚠️ General download failed: {e}")

print(f"\n✅ Data setup complete. Path: {data_dir}")

🚀 Preparing REAL data in: ./data/pretraining
📚 Downloading Math data (OpenWebMath)...



Saving Math: 10807it [00:07, 1460.28it/s]


   ✅ Math data downloaded successfully.
🌍 Downloading General (C4) data...



Saving General: 12975it [00:07, 1632.44it/s]

   ✅ General data downloaded successfully.

✅ Data setup complete. Path: ./data/pretraining
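The two download loops in the cell above share one pattern: stream rows, keep those whose `text` is longer than a threshold, and stop after a fixed number of kept rows. A minimal stdlib sketch of that pattern as a reusable helper (`write_filtered_jsonl` is a hypothetical name, not part of the repository). Because it accepts any iterable of dicts, it can be exercised on a plain list without network access:

```python
import json
from itertools import islice
from pathlib import Path
from typing import Iterable, Mapping


def write_filtered_jsonl(rows: Iterable[Mapping], path: Path,
                         min_chars: int = 500, limit: int = 10_000) -> int:
    """Write up to `limit` rows whose 'text' is longer than `min_chars` characters.

    Works with any iterable of dicts, e.g. a streaming dataset or a plain list.
    Returns the number of rows written.
    """
    kept = (row for row in rows if len(row.get("text", "")) > min_chars)
    path.parent.mkdir(parents=True, exist_ok=True)
    count = 0
    with open(path, "w", encoding="utf-8") as f:
        # islice stops pulling from the stream as soon as `limit` rows are kept
        for row in islice(kept, limit):
            f.write(json.dumps({"text": row["text"]}) + "\n")
            count += 1
    return count
```

With this helper, each download becomes one call, e.g. `write_filtered_jsonl(ds_math, math_dir / "math_subset.jsonl")`.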





In [9]:
# Preview the data
print("Sample ArXiv text:")
print("=" * 70)
with open(f"{data_dir}/arxiv/math_subset.jsonl", 'r') as f:
    sample = json.loads(f.readline())
    print(sample['text'])

print("\n" + "=" * 70)
print("Sample General text:")
print("=" * 70)
with open(f"{data_dir}/general/c4_subset.jsonl", 'r') as f:
    sample = json.loads(f.readline())
    print(sample['text'])

Sample ArXiv text:
Bayes and his Theorem

My earlier post on Bayesian probability seems to have generated quite a lot of readers, so this lunchtime I thought I’d add a little bit of background. The previous discussion started from the result

$P(B|AC) = K^{-1}P(B|C)P(A|BC) = K^{-1} P(AB|C)$

where

$K=P(A|C).$

Although this is called Bayes’ theorem, the general form of it as stated here was actually first written down, not by Bayes but by Laplace. What Bayes’ did was derive the special case of this formula for “inverting” the binomial distribution. This distribution gives the probability of x successes in n independent “trials” each having the same probability of success, p; each “trial” has only two possible outcomes (“success” or “failure”). Trials like this are usually called Bernoulli trials, after Daniel Bernoulli. If we ask the question “what is the probability of exactly x successes from the possible n?”, the answer is given by the binomial distri

## 5. Initialize Tokenizer

Initialize the mathematical tokenizer, whose vocabulary includes 200+ mathematical symbols.

In [10]:
# Initialize tokenizer
print("Initializing mathematical tokenizer...")
tokenizer = MathTokenizer()

print(f"✓ Tokenizer initialized")
print(f"  Vocabulary size: {len(tokenizer):,}")
print(f"  Special tokens: {tokenizer.SPECIAL_TOKENS}")
print(f"  Mathematical symbols: 200+")

# Test tokenization
test_text = "Let f: ℝ → ℝ be continuous. Then ∫₀¹ f(x)dx exists."
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded['input_ids'])

print(f"\nTokenization test:")
print(f"  Original: {test_text}")
print(f"  Decoded:  {decoded}")
print(f"  Tokens: {len(encoded['input_ids'])}")

Initializing mathematical tokenizer...
✓ Tokenizer initialized
  Vocabulary size: 542
  Special tokens: {'pad': '<pad>', 'eos': '<eos>', 'bos': '<bos>', 'unk': '<unk>', 'sep': '<sep>', 'math_start': '<math>', 'math_end': '</math>', 'equation_start': '<eq>', 'equation_end': '</eq>', 'proof_start': '<proof>', 'proof_end': '</proof>', 'solution_start': '<solution>', 'solution_end': '</solution>', 'step': '<step>', 'step_end': '</step>', 'answer_start': '<answer>', 'answer_end': '</answer>'}
  Mathematical symbols: 200+

Tokenization test:
  Original: Let f: ℝ → ℝ be continuous. Then ∫₀¹ f(x)dx exists.
  Decoded:  Let f: ℝ → ℝ be continuous. Then ∫₀¹ f(x)dx exists.
  Tokens: 53
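The count of 53 tokens is consistent with character-level encoding plus two special tokens: the test string is exactly 51 Unicode code points. A toy stdlib encoder that reproduces the count under that assumption (this is not `MathTokenizer`'s implementation; `toy_char_encode` and its IDs are hypothetical, with `<bos>`=2 and `<eos>`=1 chosen to match the IDs that open and close the inspected batch below):

```python
def toy_char_encode(text: str, bos_id: int = 2, eos_id: int = 1,
                    offset: int = 100) -> list[int]:
    """Hypothetical character-level encoder: one ID per Unicode code point,
    wrapped in <bos>/<eos>. The IDs are illustrative only."""
    return [bos_id] + [offset + ord(ch) for ch in text] + [eos_id]


test_text = "Let f: ℝ → ℝ be continuous. Then ∫₀¹ f(x)dx exists."
ids = toy_char_encode(test_text)
print(len(test_text), len(ids))  # → 51 53
```

If the real tokenizer does append `<eos>` this way, a prompt passed to a generation loop should have that trailing token removed before sampling.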


In [27]:
# --- DIAGNOSTIC CELL ---
# Requires `train_loader` from Section 6: run this cell after the data loader is built.
print("🔍 INSPECTING A BATCH...")
batch = next(iter(train_loader))
input_ids = batch['input_ids'][0]
labels = batch['labels'][0]

print(f"Batch Shape: {batch['input_ids'].shape}")
print("-" * 50)
print("First 20 Input Tokens:", input_ids[:20].tolist())
print("First 20 Labels:      ", labels[:20].tolist())
print("-" * 50)
print("Last 20 Input Tokens: ", input_ids[-20:].tolist())
print("Last 20 Labels:       ", labels[-20:].tolist())

# Check for the Padding Trap: every padded position must have label -100.
# (If the sample contains no padding, the check passes trivially.)
pad_id = tokenizer.pad_token_id
pad_mask = input_ids == pad_id
pad_count = pad_mask.sum().item()
print("-" * 50)
print(f"Padding Count: {pad_count} / {len(input_ids)} tokens")

unmasked_pads = (labels[pad_mask] != -100).sum().item()
if unmasked_pads > 0:
    print("🚨 CRITICAL ISSUE: The model is training on PADDING tokens!")
    print(f"   {unmasked_pads} padding positions have labels other than -100.")
else:
    print("✅ Masking looks correct (Labels are -100 for padding).")

🔍 INSPECTING A BATCH...
Batch Shape: torch.Size([4, 512])
--------------------------------------------------
First 20 Input Tokens: [2, 496, 509, 513, 514, 509, 508, 521, 246, 497, 496, 513, 247, 521, 3, 521, 495, 480, 494, 476]
First 20 Labels:       [2, 496, 509, 513, 514, 509, 508, 521, 246, 497, 496, 513, 247, 521, 3, 521, 495, 480, 494, 476]
--------------------------------------------------
Last 20 Input Tokens:  [488, 486, 473, 469, 488, 521, 477, 488, 523, 521, 469, 482, 472, 521, 488, 483, 521, 469, 482, 1]
Last 20 Labels:        [488, 486, 473, 469, 488, 521, 477, 488, 523, 521, 469, 482, 472, 521, 488, 483, 521, 469, 482, 1]
--------------------------------------------------
Padding Count: 0 / 512 tokens
✅ Masking looks correct (Labels are -100 for padding).


## 6. Prepare Streaming Dataset

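Create a mixed-domain dataset that samples 30% from the math source and 70% from general text. The math folder keeps the name `arxiv` for compatibility, but in this notebook it holds the OpenWebMath subset downloaded in Section 4, so "ArXiv" labels in the printouts refer to that math data.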

In [11]:
# Prepare streaming dataset
print("Preparing mixed-domain streaming dataset...")
print("-" * 70)

train_dataset = prepare_pretraining_data(
    data_dir=data_dir,
    sources=["arxiv", "general"],
    tokenizer=tokenizer,
    max_seq_length=512,  # Shorter for demo
    mix_weights=[0.3, 0.7],  # 30% math, 70% general
)

print("✓ Streaming dataset created")
print(f"  Data sources: ArXiv (30%), General (70%)")
print(f"  Streaming mode: Yes (memory efficient)")
print(f"  Max sequence length: 512 tokens")

Preparing mixed-domain streaming dataset...
----------------------------------------------------------------------
✓ Streaming dataset created
  Data sources: ArXiv (30%), General (70%)
  Streaming mode: Yes (memory efficient)
  Max sequence length: 512 tokens
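One common way to implement `mix_weights=[0.3, 0.7]` is to choose the source of each next document at random with those probabilities. A minimal stdlib sketch, assuming this per-document sampling scheme (the repository's `prepare_pretraining_data` may mix differently, for example per token or per batch; `mix_streams` is a hypothetical name):

```python
import itertools
import random
from typing import Iterable, Iterator, Sequence


def mix_streams(streams: Sequence[Iterable], weights: Sequence[float],
                seed: int = 42) -> Iterator:
    """Yield items from several streams, picking the source of each item
    at random in proportion to `weights`. Stops when any stream runs out."""
    rng = random.Random(seed)
    iterators = [iter(s) for s in streams]
    indices = list(range(len(iterators)))
    while True:
        i = rng.choices(indices, weights=weights, k=1)[0]
        try:
            yield next(iterators[i])
        except StopIteration:
            return


# Two endless toy "sources" standing in for the math and general corpora
mixed = mix_streams([itertools.repeat("math"), itertools.repeat("general")], [0.3, 0.7])
sample = list(itertools.islice(mixed, 10_000))
print(f"math fraction: {sample.count('math') / len(sample):.3f}")  # close to 0.3
```

Stopping as soon as one source is exhausted keeps the ratio honest on finite data; a production pipeline would more likely cycle or reshuffle the smaller source instead.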


In [14]:
# Create data loader
collator = PreTrainingDataCollator(pad_token_id=tokenizer.pad_token_id)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    # 0 workers on purpose: with an iterable (streaming) dataset, each worker
    # presumably gets its own copy of the stream, so extra workers would yield
    # duplicate samples unless the dataset shards itself by worker.
    num_workers=0,
    collate_fn=collator,
)

print("✓ Data loader created")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Workers: {train_loader.num_workers}")

# Test loading a batch
sample_batch = next(iter(train_loader))
print(f"\nSample batch:")
print(f"  Input shape: {sample_batch['input_ids'].shape}")
print(f"  Attention mask shape: {sample_batch['attention_mask'].shape}")
print(f"  Labels shape: {sample_batch['labels'].shape}")

✓ Data loader created
  Batch size: 4
  Workers: 0

Sample batch:
  Input shape: torch.Size([4, 512])
  Attention mask shape: torch.Size([4, 512])
  Labels shape: torch.Size([4, 512])
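The loader above uses `num_workers=0`. PyTorch gives each worker process its own copy of an iterable dataset, so raising `num_workers` on a streaming dataset without sharding would yield each document once per worker. A stdlib sketch of round-robin sharding, the kind of split an iterable dataset could apply per worker (inside PyTorch, the worker index and count would come from `torch.utils.data.get_worker_info()`; `shard_stream` is a hypothetical helper):

```python
from itertools import islice
from typing import Iterable, Iterator


def shard_stream(stream: Iterable, worker_id: int, num_workers: int) -> Iterator:
    """Round-robin shard: worker k keeps items k, k + n, k + 2n, ...

    Every item lands in exactly one shard, so n workers together see the
    stream once instead of n times.
    """
    if not 0 <= worker_id < num_workers:
        raise ValueError("worker_id must be in [0, num_workers)")
    return islice(stream, worker_id, None, num_workers)


docs = [f"doc{i}" for i in range(7)]
shards = [list(shard_stream(docs, k, 3)) for k in range(3)]
print(shards)  # → [['doc0', 'doc3', 'doc6'], ['doc1', 'doc4'], ['doc2', 'doc5']]
```

Each worker still reads (and discards) the whole stream in this scheme; sharding at the file level avoids that cost when the corpus is split across many files.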


## 7. Initialize Model

Create the decoder-only transformer from Phase 1.1.

In [15]:
# Get model configuration
if MODEL_SIZE == "small":
    config = get_small_config()
    # Further reduce for demo if on CPU
    if not USE_GPU:
        config.hidden_size = 256
        config.num_hidden_layers = 4
        config.num_attention_heads = 4
        config.num_key_value_heads = 4
        config.intermediate_size = 1024
elif MODEL_SIZE == "base":
    config = get_base_config()
else:
    raise ValueError(f"Unknown MODEL_SIZE: {MODEL_SIZE!r} (expected 'small' or 'base')")

# Update vocab size to match tokenizer
config.vocab_size = len(tokenizer)
config.max_position_embeddings = 512

# Initialize model
print(f"Initializing {MODEL_SIZE} model...")
model = MathTransformerDecoder(config)

# Count parameters
num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n✓ Model initialized")
print(f"  Architecture: Decoder-only (Llama-style)")
print(f"  Hidden size: {config.hidden_size}")
print(f"  Layers: {config.num_hidden_layers}")
print(f"  Attention heads: {config.num_attention_heads}")
print(f"  Parameters: {num_params:,} ({num_trainable:,} trainable)")
print(f"  Positional encoding: RoPE (dynamic scaling)")
print(f"  Activation: SwiGLU")

# Show model size in MB
param_size_mb = num_params * 4 / (1024 ** 2)  # 4 bytes per float32 param
print(f"  Model size: {param_size_mb:.2f} MB")

Initializing small model...

✓ Model initialized
  Architecture: Decoder-only (Llama-style)
  Hidden size: 512
  Layers: 8
  Attention heads: 8
  Parameters: 34,118,144 (34,118,144 trainable)
  Positional encoding: RoPE (dynamic scaling)
  Activation: SwiGLU
  Model size: 130.15 MB
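The reported 130.15 MB is the parameter count × 4 bytes (fp32) divided by 1024², so strictly it is MiB. Training needs more than the weights; a rough stdlib estimate, assuming fp32 weights and gradients plus an Adam-style optimizer that keeps two fp32 moment buffers per parameter (the optimizer is not shown in this notebook, and activation memory is excluded):

```python
def param_memory_mib(num_params: int, bytes_per_param: int) -> float:
    """Memory for `bytes_per_param` bytes per parameter, in MiB."""
    return num_params * bytes_per_param / 1024 ** 2


num_params = 34_118_144  # from the cell above
weights_fp32 = param_memory_mib(num_params, 4)
weights_bf16 = param_memory_mib(num_params, 2)
# Assumption: fp32 weights + fp32 grads + two fp32 Adam moments = 16 bytes/param
train_state = param_memory_mib(num_params, 4 * 4)

print(f"fp32 weights:       {weights_fp32:.2f} MiB")  # → 130.15 MiB
print(f"bf16 weights:       {weights_bf16:.2f} MiB")  # → 65.08 MiB
print(f"weights+grads+Adam: {train_state:.2f} MiB")   # → 520.60 MiB
```

At this size the optimizer state is small next to the L4's 23.8 GB; activations, which grow with batch size and sequence length, dominate, which is what gradient checkpointing targets.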


In [68]:
# Monkey-patch PreTrainer.train_step: correct logit extraction + label shifting
from typing import Dict
import torch.nn as nn
from torch.amp import autocast

def fixed_train_step(self, batch: Dict[str, torch.Tensor]) -> float:
    """
    Run one micro-batch: forward pass, shifted next-token loss, backward pass.
    Handles models that return either a dict with 'logits' or a raw tensor.
    Returns the unscaled loss for logging.
    """
    # Move batch to device
    batch = {k: v.to(self.device) for k, v in batch.items()}

    # autocast expects a device type ("cuda" or "cpu"), so it also works on CPU
    device_type = "cuda" if torch.cuda.is_available() else "cpu"

    # Forward pass with mixed precision
    with autocast(device_type, dtype=self.dtype, enabled=self.use_amp):
        outputs = self.model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )

        # --- FIX 1: Extract Logits correctly ---
        if isinstance(outputs, dict):
            logits = outputs['logits']
        else:
            logits = outputs
        # ---------------------------------------

        # --- FIX 2: Shift logits and labels ---
        # Position t predicts token t+1: drop the last prediction (nothing
        # to compare it to) and the first label (nothing predicts it)
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = batch["labels"][..., 1:].contiguous()
        # ------------------------------------

        # Compute loss on shifted tensors; -100 marks positions to ignore
        loss = nn.functional.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            ignore_index=-100,
        )

        # Scale loss for gradient accumulation
        loss = loss / self.config.gradient_accumulation_steps

    # Backward pass (via the GradScaler, if the trainer created one)
    if self.scaler is not None:
        self.scaler.scale(loss).backward()
    else:
        loss.backward()

    # Undo the accumulation scaling so the logged value is the true loss
    return loss.item() * self.config.gradient_accumulation_steps

# Apply the fix to the existing class
PreTrainer.train_step = fixed_train_step
print("✅ Patched PreTrainer with Logit Extraction + Shifting Logic")

✅ Patched PreTrainer with Logit Extraction + Shifting Logic
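The patched `train_step` computes next-token loss: the logit vector at position t is scored against the label at position t + 1, and positions labelled -100 are skipped. A dependency-free sketch of the same computation on plain lists (a hypothetical helper for illustration; the notebook itself uses `torch.nn.functional.cross_entropy`). With all-zero logits every token gets probability 1/V, so the loss equals ln V; for this notebook's vocabulary of 542 that is about 6.30. The first losses logged in Section 10 (about 17.3) sit well above this uniform baseline, which may point to large initial logits and could be worth checking.

```python
import math
from typing import Sequence

IGNORE_INDEX = -100


def shifted_cross_entropy(logits: Sequence[Sequence[float]],
                          labels: Sequence[int]) -> float:
    """Mean next-token cross-entropy for one sequence.

    logits[t] is the score vector at position t; it is compared with
    labels[t + 1]. Targets equal to IGNORE_INDEX are skipped.
    """
    total, count = 0.0, 0
    for t in range(len(labels) - 1):   # the last position has no target
        target = labels[t + 1]         # the first label is never a target
        if target == IGNORE_INDEX:
            continue
        row = logits[t]
        m = max(row)                   # log-sum-exp trick for stability
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_z - row[target]
        count += 1
    return total / count if count else 0.0


V = 542
labels = [2, 40, 41, 1, IGNORE_INDEX, IGNORE_INDEX]  # <bos> ... <eos>, then padding
zero_logits = [[0.0] * V for _ in labels]
print(shifted_cross_entropy(zero_logits, labels), math.log(V))  # both ≈ 6.30
```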


## 8. Configure Pre-Training

In [69]:
# Pre-training configuration
training_config = PreTrainingConfig(
    model_config_name=MODEL_SIZE,
    vocab_size=config.vocab_size,
    max_seq_length=512,
    data_dir=data_dir,

    # max_steps counts micro-steps (one forward/backward pass each).
    # GPU: ~1000 optimizer updates * 8 accumulation steps = 8000 micro-steps.
    # CPU: use the short demo budget (MAX_STEPS) from Section 3.
    max_steps=1000 * GRAD_ACCUM_STEPS if USE_GPU else MAX_STEPS,

    micro_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,  # 8 on GPU

    warmup_steps=500,
    learning_rate=3e-4,

    mixed_precision=MIXED_PRECISION,
    gradient_checkpointing=USE_GPU,

    checkpoint_dir="./checkpoints/pretraining_notebook",
    save_interval=2000,

    # log_interval=1: log as often as the trainer allows. In the run below,
    # a log line appears once per optimizer update (every 8 micro-steps).
    log_interval=1,

    use_wandb=False,
    use_tensorboard=False,
    num_workers=0,
    seed=42,
)

print("✓ Real Training Configuration Applied")
print(f"  Target: {training_config.max_steps} micro-steps")
print(f"  Estimated Duration: ~10-15 minutes")

✓ Real Training Configuration Applied
  Target: 8000 micro-steps
  Estimated Duration: ~10-15 minutes
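The training log in Section 10 suggests the 500 warmup steps are counted in optimizer updates, not micro-steps: the learning rate rises by 6.00e-07 (= 3e-4 / 500) per logged update. With 8 accumulation steps, warmup would then span 500 × 8 = 4,000 micro-steps, half of the 8,000 configured. A sketch of linear warmup that reproduces the logged values, assuming update u is logged after micro-step 8u − 1 (`warmup_lr` is a hypothetical helper):

```python
def warmup_lr(update: int, peak_lr: float = 3e-4, warmup_updates: int = 500) -> float:
    """Linear warmup to `peak_lr` over `warmup_updates` optimizer updates.

    After warmup this sketch simply holds `peak_lr`; the trainer's actual
    post-warmup schedule is not visible in the log.
    """
    return peak_lr * min(update, warmup_updates) / warmup_updates


GRAD_ACCUM = 8
for micro_step in (7, 15, 103):
    update = (micro_step + 1) // GRAD_ACCUM
    print(f"Step {micro_step}: LR {warmup_lr(update):.2e}")
# Step 7: LR 6.00e-07
# Step 15: LR 1.20e-06
# Step 103: LR 7.80e-06
```

If a shorter warmup is intended, `warmup_steps` should be sized in optimizer updates (for example 50 to 100 for a 1,000-update run).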


In [70]:
# Collator with dynamic padding and label masking
import torch
from torch.nn.utils.rnn import pad_sequence

class RobustDataCollator:
    """
    A robust collator that:
    1. Pads sequences to the longest length in the batch.
    2. Masks padded positions in the labels with -100 so the loss ignores them.
    """
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Ensure we have a valid pad ID, default to 0 if None
        self.pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

    def __call__(self, features):
        # 1. Extract inputs (they might vary in length); accept lists or tensors
        input_ids_list = [torch.as_tensor(f['input_ids']) for f in features]
        attention_mask_list = [torch.as_tensor(f['attention_mask']) for f in features]

        # 2. Dynamically pad the batch to its longest sequence
        # batch_first=True results in (batch_size, seq_len)
        input_ids = pad_sequence(input_ids_list, batch_first=True, padding_value=self.pad_id)
        attention_mask = pad_sequence(attention_mask_list, batch_first=True, padding_value=0)

        # 3. Create labels (the one-position shift happens in train_step)
        labels = input_ids.clone()

        # 4. Mask padding so loss is NOT calculated on it. Use the attention
        #    mask rather than the pad ID, so a real token that happens to share
        #    the fallback ID 0 is never masked.
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

# Rebuild the data loader with the dynamic-padding collator
print("🔄 Rebuilding Data Loader with Dynamic Padding...")
fix_collator = RobustDataCollator(tokenizer)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=0,
    collate_fn=fix_collator,
)
print("✅ Loader ready.")

🔄 Rebuilding Data Loader with Dynamic Padding...
✅ Loader ready.


## 9. Initialize Pre-Trainer

In [71]:
# Initialize pre-trainer
print("Initializing Pre-Trainer...")
print("=" * 70)

trainer = PreTrainer(
    model=model,
    config=training_config,
    train_dataloader=train_loader,
    val_dataloader=None,
)

print("\n✓ Pre-Trainer ready!")

Initializing Pre-Trainer...

PRE-TRAINING CONFIGURATION
Model: small
Max steps: 8,000
Batch size: 4 (per device)
Gradient accumulation: 8 steps
Effective batch size: 32
Learning rate: 0.0003
Mixed precision: bf16
Gradient checkpointing: True
World size: 1
Device: cuda:0


✓ Pre-Trainer ready!


## 10. Run Pre-Training

This is where the magic happens! 🎯

In [72]:
# Start training
print("\n" + "=" * 70)
print("STARTING BASE PRE-TRAINING")
print("=" * 70)

if not USE_GPU:
    print("⚠️  Running on CPU - this will be slow!")
    print("    For faster training, use Google Colab with GPU")
    print()

trainer.train()

print("\n" + "=" * 70)
print("✅ PRE-TRAINING COMPLETE!")
print("=" * 70)


STARTING BASE PRE-TRAINING

Starting pre-training from step 0...
Step 7/8000 | Loss: 17.2963 | LR: 6.00e-07 | Tokens: 2,048 | Time: 1.0s
Step 15/8000 | Loss: 17.3815 | LR: 1.20e-06 | Tokens: 4,096 | Time: 1.4s
Step 23/8000 | Loss: 17.2314 | LR: 1.80e-06 | Tokens: 6,144 | Time: 1.8s
Step 31/8000 | Loss: 17.0230 | LR: 2.40e-06 | Tokens: 8,192 | Time: 2.3s
Step 39/8000 | Loss: 17.1445 | LR: 3.00e-06 | Tokens: 10,240 | Time: 2.7s
Step 47/8000 | Loss: 17.1857 | LR: 3.60e-06 | Tokens: 12,288 | Time: 3.1s
Step 55/8000 | Loss: 16.9725 | LR: 4.20e-06 | Tokens: 14,336 | Time: 3.6s
Step 63/8000 | Loss: 16.9856 | LR: 4.80e-06 | Tokens: 16,384 | Time: 4.0s
Step 71/8000 | Loss: 16.7718 | LR: 5.40e-06 | Tokens: 18,432 | Time: 4.4s
Step 79/8000 | Loss: 16.5245 | LR: 6.00e-06 | Tokens: 20,480 | Time: 4.8s
Step 87/8000 | Loss: 16.3853 | LR: 6.60e-06 | Tokens: 22,528 | Time: 5.2s
Step 95/8000 | Loss: 15.8645 | LR: 7.20e-06 | Tokens: 24,576 | Time: 5.7s
Step 103/8000 | Loss: 15.8225 | LR: 7.80e-06 | Toke

## 11. Test the Trained Model

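Let's generate some text to see what the model learned! Keep expectations modest: the checkpoint in Section 12 reports step 3,574 of the 8,000 configured, so this small model has seen little data and its output is expected to be largely incoherent.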

In [73]:
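# Text generation helper (temperature sampling)
@torch.no_grad()
def generate_text(model, tokenizer, prompt, max_length=50, temperature=0.8, device="cuda"):
    """Generate text continuation from a prompt."""
    model.eval()

    # Encode prompt. encode() appears to add <bos> and <eos> (Section 5's
    # 51-character test string became 53 tokens); drop a trailing <eos> so the
    # model continues the prompt instead of starting a new document.
    prompt_ids = [int(t) for t in tokenizer.encode(prompt)['input_ids']]
    if prompt_ids and prompt_ids[-1] == tokenizer.eos_token_id:
        prompt_ids = prompt_ids[:-1]
    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)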

    # Generate
    for _ in range(max_length):
        # Forward pass
        outputs = model(input_ids)

        # --- FIX: Extract logits from dictionary ---
        if isinstance(outputs, dict):
            logits = outputs['logits']
        else:
            logits = outputs # Fallback if it's already a tensor

        # Get next token logits
        next_token_logits = logits[0, -1, :] / temperature
        # -------------------------------------------

        # Sample next token
        probs = torch.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        # Append to sequence
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)

        # Stop if EOS token
        if next_token.item() == tokenizer.eos_token_id:
            break

    # Decode
    generated_ids = input_ids[0].tolist()
    return tokenizer.decode(generated_ids)

print("Testing text generation...")
print("=" * 70)

Testing text generation...
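`generate_text` divides the logits by `temperature` before the softmax: values below 1 sharpen the distribution toward the most likely token, values above 1 flatten it. A stdlib sketch of that step (the helper names are illustrative; the notebook uses `torch.softmax` and `torch.multinomial`):

```python
import math
import random
from typing import Sequence


def softmax_with_temperature(logits: Sequence[float], temperature: float) -> list[float]:
    """Softmax of logits / temperature (temperature must be > 0)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    z = sum(exps)
    return [e / z for e in exps]


def sample_token(logits: Sequence[float], temperature: float, rng: random.Random) -> int:
    """Draw one token index from the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.0]
for t in (0.5, 0.8, 2.0):
    # The top token's probability shrinks as the temperature grows
    print(f"T={t}: P(top token) = {softmax_with_temperature(logits, t)[0]:.3f}")
```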


In [77]:
# Test 1: Mathematical prompt
math_prompt = "Let f: ℝ → ℝ be a continuous function. Then"
generated = generate_text(
    model=trainer.raw_model,
    tokenizer=tokenizer,
    prompt=math_prompt,
    max_length=50,
    device=device
)

print("Mathematical generation:")
print(f"Prompt: {math_prompt}")
print(f"Generated: {generated}")
print()

Mathematical generation:
Prompt: Let f: ℝ → ℝ be a continuous function. Then
Generated: Let f: ℝ → ℝ be a continuous function. Thenrything the reach that are started with the attacc



In [78]:
# Test 2: Theorem prompt
theorem_prompt = "Theorem: For any prime number p, we have"
generated = generate_text(
    model=trainer.raw_model,
    tokenizer=tokenizer,
    prompt=theorem_prompt,
    max_length=50,
    device=device
)

print("Theorem generation:")
print(f"Prompt: {theorem_prompt}")
print(f"Generated: {generated}")
print()

Theorem generation:
Prompt: Theorem: For any prime number p, we have
Generated: Theorem: For any prime number p, we have the hanted to mean the standard team of the card



In [79]:
# Test 3: General text prompt
general_prompt = "The history of mathematics began in"
generated = generate_text(
    model=trainer.raw_model,
    tokenizer=tokenizer,
    prompt=general_prompt,
    max_length=50,
    device=device
)

print("General text generation:")
print(f"Prompt: {general_prompt}")
print(f"Generated: {generated}")

General text generation:
Prompt: The history of mathematics began in
Generated: The history of mathematics began inmal started as one to use the location group, part


## 12. Save and Load Checkpoints

In [80]:
# Save final checkpoint
checkpoint_path = Path(training_config.checkpoint_dir) / "final_notebook.pt"
trainer.save_checkpoint("final_notebook.pt")

print(f"✓ Checkpoint saved to: {checkpoint_path}")
print(f"\nCheckpoint contains:")
print(f"  - Model weights")
print(f"  - Optimizer state")
print(f"  - Training step: {trainer.global_step}")
print(f"  - Tokens seen: {trainer.tokens_seen:,}")

Checkpoint saved: checkpoints/pretraining_notebook/final_notebook.pt
✓ Checkpoint saved to: checkpoints/pretraining_notebook/final_notebook.pt

Checkpoint contains:
  - Model weights
  - Optimizer state
  - Training step: 3574
  - Tokens seen: 913,408
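The token count looks low for the number of steps. Each micro-step processes 4 × 512 = 2,048 tokens (the inspected batch had no padding), so 3,574 micro-steps would be about 7.3M tokens. The reported 913,408 equals 446 × 2,048, and 446 = 3,574 // 8 is the number of completed optimizer updates, which suggests `tokens_seen` counts one micro-batch per update rather than all 8 (assuming `global_step` counts micro-steps, as the `Step N/8000` log lines indicate). The arithmetic, as a quick check:

```python
micro_batch, seq_len, grad_accum = 4, 512, 8
global_step, tokens_seen = 3574, 913_408  # values printed above

tokens_per_micro_step = micro_batch * seq_len          # 2,048
updates_completed = global_step // grad_accum          # 446
expected_tokens = global_step * tokens_per_micro_step  # every micro-step counted

print(f"updates completed:       {updates_completed}")
print(f"one micro-batch/update:  {updates_completed * tokens_per_micro_step:,}")
print(f"all micro-steps counted: {expected_tokens:,}")
# updates completed:       446
# one micro-batch/update:  913,408
# all micro-steps counted: 7,319,552
```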


In [81]:
# Example: Load checkpoint
if checkpoint_path.exists():
    print("Loading checkpoint...")
    checkpoint = torch.load(checkpoint_path, map_location=device)

    print(f"\n✓ Checkpoint loaded")
    print(f"  Training step: {checkpoint['global_step']}")
    print(f"  Tokens seen: {checkpoint['tokens_seen']:,}")
    print(f"  Config: {checkpoint['config']['model_config_name']}")

Loading checkpoint...

✓ Checkpoint loaded
  Training step: 3574
  Tokens seen: 913,408
  Config: small


## 13. Summary and Next Steps

In [82]:
print("=" * 70)
print("PHASE 2.1 DEMONSTRATION COMPLETE!")
print("=" * 70)
print("\n✓ Successfully demonstrated:")
print("  1. Streaming dataset for large-scale corpora")
print("  2. Mixed-domain data sampling (ArXiv + General)")
print("  3. Decoder-only transformer architecture")
print("  4. Pre-training with causal language modeling")
print("  5. Mixed precision training" if USE_GPU else "  5. CPU training (demo mode)")
print("  6. Automatic checkpointing")
print("  7. Text generation from trained model")
print()
print("Training Statistics:")
print(f"  Steps completed: {trainer.global_step}")
print(f"  Tokens processed: {trainer.tokens_seen:,}")
print(f"  Model parameters: {num_params:,}")
print()
print("Next steps for FULL pre-training:")
print("  1. Prepare large-scale datasets:")
print("     - ArXiv papers (LaTeX extraction): ~2M papers")
print("     - C4 corpus: 750GB of web text")
print("     - Wikipedia: ~6M articles")
print("     - Books corpus")
print()
print("  2. Scale up training:")
print("     - Use base or large model")
print("     - Train for 100K-1M steps")
print("     - Use multiple GPUs with DDP:")
print("       torchrun --nproc_per_node=4 pretrain.py")
print()
print("  3. Monitor with wandb:")
print("     python pretrain.py --use-wandb --wandb-project my-project")
print()
print("  4. Proceed to Phase 2.2: Mathematical Fine-tuning")
print("     - Fine-tune on MATH dataset")
print("     - Add reinforcement learning")
print("     - Outcome supervision")
print()
print("Checkpoints saved to:", training_config.checkpoint_dir)
print("=" * 70)

PHASE 2.1 DEMONSTRATION COMPLETE!

✓ Successfully demonstrated:
  1. Streaming dataset for large-scale corpora
  2. Mixed-domain data sampling (ArXiv + General)
  3. Decoder-only transformer architecture
  4. Pre-training with causal language modeling
  5. Mixed precision training
  6. Automatic checkpointing
  7. Text generation from trained model

Training Statistics:
  Steps completed: 3574
  Tokens processed: 913,408
  Model parameters: 34,118,144

Next steps for FULL pre-training:
  1. Prepare large-scale datasets:
     - ArXiv papers (LaTeX extraction): ~2M papers
     - C4 corpus: 750GB of web text
     - Wikipedia: ~6M articles
     - Books corpus

  2. Scale up training:
     - Use base or large model
     - Train for 100K-1M steps
     - Use multiple GPUs with DDP:
       torchrun --nproc_per_node=4 pretrain.py

  3. Monitor with wandb:
     python pretrain.py --use-wandb --wandb-project my-project

  4. Proceed to Phase 2.2: Mathematical Fine-tuning
     - Fine-tune on MAT

## 14. (Optional) Download Checkpoint

If running on Colab, you can download the checkpoint to your local machine.

In [83]:
if IN_COLAB:
    from google.colab import files

    # Zip checkpoints
    !zip -r checkpoints.zip checkpoints/

    print("Downloading checkpoint...")
    files.download('checkpoints.zip')
    print("✓ Download started")
else:
    print("Checkpoints are saved locally at:", training_config.checkpoint_dir)

  adding: checkpoints/ (stored 0%)
  adding: checkpoints/pretraining_notebook/ (stored 0%)
  adding: checkpoints/pretraining_notebook/final.pt (deflated 9%)
  adding: checkpoints/pretraining_notebook/final_notebook.pt (deflated 9%)
Downloading checkpoint...


✓ Download started


---

## 📚 Additional Resources

**Documentation:**
- See `PHASE_2_1_README.md` for comprehensive documentation
- Run `python pretrain.py --help` for CLI options

**Scaling Up:**
```bash
# Multi-GPU training (4 GPUs)
torchrun --nproc_per_node=4 pretrain.py \
    --model-size base \
    --batch-size 4 \
    --gradient-accumulation-steps 8 \
    --max-steps 500000 \
    --mixed-precision bf16 \
    --use-wandb
```

**Key Papers:**
- [Chinchilla: Training Compute-Optimal LLMs](https://arxiv.org/abs/2203.15556)
- [LLaMA: Open Foundation LLMs](https://arxiv.org/abs/2302.13971)
- [Minerva: Mathematical Reasoning](https://arxiv.org/abs/2206.14858)

---

**Happy Pre-Training! 🚀**