# MaxText Training Demo

This notebook demonstrates how to:
- Set up MaxText for training
- Configure model parameters
- Prepare synthetic and real datasets
- Run training with monitoring
- Save and manage checkpoints
- Evaluate model performance

This demo showcases MaxText's capabilities for training large language models on TPUs and GPUs.

## 1. Environment Setup and Installation

In [None]:
# # Install MaxText and dependencies
# !git clone https://github.com/AI-Hypercomputer/maxtext.git
# %cd maxtext
# !bash setup.sh

In [None]:
# # Install additional Python dependencies
# !pip install -q jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# !pip install -q flax optax orbax-checkpoint grain-python tensorflow-datasets
# !pip install -q sentencepiece transformers datasets
# !pip install -q tensorboardX matplotlib seaborn pandas

In [None]:
!pip install nest_asyncio


In [None]:
# Import required libraries
import os
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime
from pathlib import Path

import jax
import jax.numpy as jnp
from jax import random
import flax
from flax import linen as nn
import optax

# Add MaxText to path.
# The working directory stays on the path for kernels started at the repository root.
sys.path.append(os.path.abspath('.'))

# A kernel started in the notebook's own folder ('.../maxtext/MaxText/scratch_code')
# has the repository root two levels up. That directory must be importable
# before `from MaxText import pyconfig` can succeed on a fresh kernel. The
# isdir() guard avoids adding an unrelated directory when run from elsewhere.
_repo_root_candidate = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
if (os.path.isdir(os.path.join(_repo_root_candidate, 'MaxText'))
        and _repo_root_candidate not in sys.path):
    sys.path.insert(0, _repo_root_candidate)

from MaxText import pyconfig

# Report the JAX runtime this kernel is using.
print(f"JAX version: {jax.__version__}")
print(f"JAX backend: {jax.default_backend()}")
print(f"Number of devices: {jax.device_count()}")
print(f"Devices: {jax.devices()}")

In [None]:

import nest_asyncio

# The notebook is expected to run from '.../maxtext/MaxText/scratch_code',
# which puts the project root ('maxtext') two directories above the cwd.
current_dir = os.getcwd()
project_root = os.path.abspath(
    os.path.join(current_dir, os.pardir, os.pardir)
)

# Put the project root at the front of the import path, but only once.
if project_root not in sys.path:
    sys.path.insert(0, project_root)
    print(f"Added '{project_root}' to sys.path")

# Patch asyncio so an event loop can be re-entered from inside the notebook.
nest_asyncio.apply()

## 2. Configuration Setup

In [None]:
# Google Cloud Storage locations for this run.
# GCS_BUCKET must be a `gs://` URI; every output and dataset path below is
# derived from it. Replace the placeholder with your own bucket.
GCS_BUCKET = "gs://your-maxtext-bucket"  # Replace with your bucket
BASE_OUTPUT_DIR = f"{GCS_BUCKET}/training_outputs"
DATASET_PATH = f"{GCS_BUCKET}/datasets"

# Create a unique run name
RUN_NAME = f"training_demo_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

# Hyperparameters shared by the command-generation, simulation and summary
# cells further down. Those cells read `training_config[...]` and
# `config_params[...]`, so both must exist before they run on a fresh kernel.
# The values are illustrative demo settings; tune them for a real run.
training_config = {
    'steps': 100,
    'per_device_batch_size': 1,
    'learning_rate': 3e-4,
    'warmup_steps': 10,
    'log_period': 10,
    'checkpoint_period': 50,
    'eval_interval': 50,
    'dtype': 'bfloat16',
    'hardware': jax.default_backend(),
}

# Model-level description used by the dataset helper and the final summary.
config_params = {
    'model_size': 8,         # billions of parameters (the demo config loads llama3.1-8b)
    'max_seq_length': 2048,  # tokenizer padding/truncation length
}

print(f"Run name: {RUN_NAME}")
print(f"Output directory: {BASE_OUTPUT_DIR}/{RUN_NAME}")

In [None]:
# Build the MaxText configuration object from the base YAML plus the keyword
# overrides below. The relative config path assumes the working directory is
# the notebook's folder under MaxText/ (see the path-setup cell above).
# NOTE(review): the leading "" looks like an argv[0] placeholder — confirm
# against pyconfig.initialize.
config = pyconfig.initialize(
    ["", "../configs/base.yml"], 
    per_device_batch_size=1.0,
    run_name="test",  # fixed name; not the timestamped RUN_NAME defined earlier
    max_target_length=4,  # very short sequences keep the demo small
    max_prefill_predict_length=4,
    tokenizer_type="tiktoken",
    # NOTE(review): relative path — presumably resolved by MaxText; verify from this cwd.
    tokenizer_path="assets/tokenizer_llama3.tiktoken/",
    load_parameters_path="path/to/your/llama3.1-8b/checkpoint",  # Replace with your checkpoint path
    model_name="llama3.1-8b",
    async_checkpointing=False,

)

## 3. Dataset Preparation

In [None]:
# Function to create synthetic dataset for testing
def create_synthetic_dataset(batch_size, seq_length, vocab_size, num_batches=10):
    """Create synthetic next-token-prediction batches for testing.

    Args:
        batch_size: Number of sequences per batch. Integral floats such as
            ``1.0`` are accepted and converted to ``int``.
        seq_length: Number of tokens per sequence (integral).
        vocab_size: Exclusive upper bound for the sampled token ids (integral).
        num_batches: How many batches to generate.

    Returns:
        A list of ``num_batches`` dicts. Each dict holds ``inputs``,
        ``targets``, ``inputs_segmentation``, ``targets_segmentation`` and
        ``inputs_position``, all of shape ``(batch_size, seq_length)``.

    Raises:
        ValueError: If a size argument is not a whole number.
    """
    def _as_int(name, value):
        # Array shapes must be Python ints; reject 1.5 rather than truncating it.
        as_int = int(value)
        if as_int != value:
            raise ValueError(f"{name} must be a whole number, got {value!r}")
        return as_int

    batch_size = _as_int('batch_size', batch_size)
    seq_length = _as_int('seq_length', seq_length)
    vocab_size = _as_int('vocab_size', vocab_size)

    key = random.PRNGKey(42)

    # Positions are identical for every batch, so build them once.
    inputs_position = jnp.tile(jnp.arange(seq_length)[None, :], (batch_size, 1))

    datasets = []
    for _ in range(num_batches):
        # Use one independent key per random draw; a JAX key must not be reused.
        key, inputs_key, last_target_key = random.split(key, 3)

        # Create random token sequences
        inputs = random.randint(inputs_key, (batch_size, seq_length), 0, vocab_size)

        # Targets are the inputs shifted left by one (autoregressive training);
        # the final target has no successor, so it is drawn at random.
        final_targets = random.randint(last_target_key, (batch_size, 1), 0, vocab_size)
        targets = jnp.concatenate([inputs[:, 1:], final_targets], axis=1)

        datasets.append({
            'inputs': inputs,
            'targets': targets,
            # All-ones segmentation arrays for the synthetic data.
            'inputs_segmentation': jnp.ones_like(inputs),
            'targets_segmentation': jnp.ones_like(targets),
            'inputs_position': inputs_position,
        })

    return datasets

# Create synthetic dataset
# per_device_batch_size was given to pyconfig as the float 1.0, while array
# shapes need integers, so convert it before using it as a dimension.
synthetic_data = create_synthetic_dataset(
    batch_size=int(config.per_device_batch_size),
    seq_length=config.max_target_length,
    vocab_size=config.vocab_size,
    num_batches=100
)

print(f"Created {len(synthetic_data)} synthetic batches")
print(f"Batch shape: {synthetic_data[0]['inputs'].shape}")

In [None]:
# Optional: Load HuggingFace dataset
def load_huggingface_dataset(dataset_name='wikitext', subset='wikitext-2-raw-v1',
                             max_length=None, tokenizer_name='gpt2'):
    """Load a HuggingFace dataset and tokenize it to fixed-length sequences.

    Args:
        dataset_name: Dataset identifier passed to ``datasets.load_dataset``.
        subset: Dataset configuration/subset name.
        max_length: Padding/truncation length in tokens. Defaults to the
            notebook config's ``max_target_length``.
        tokenizer_name: Name passed to ``AutoTokenizer.from_pretrained``
            (GPT-2 by default, as an example).

    Returns:
        The tokenized ``train`` split with the original columns removed.
    """
    # Imported lazily so the notebook runs without these optional packages.
    from datasets import load_dataset
    from transformers import AutoTokenizer

    if max_length is None:
        # Use the sequence length the model config was initialised with.
        max_length = int(config.max_target_length)

    # Load dataset
    dataset = load_dataset(dataset_name, subset, split='train')

    # Load tokenizer; reuse EOS as the padding token so padding is possible.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        """Tokenize a batch of examples, padding/truncating to ``max_length``."""
        return tokenizer(
            examples['text'],
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors='np'
        )

    # Tokenize dataset, keeping only the tokenizer outputs as columns.
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names
    )

    return tokenized_dataset

# Uncomment to load real dataset
# real_dataset = load_huggingface_dataset()
# print(f"Loaded dataset with {len(real_dataset)} examples")

## 4. Model Training

In [None]:
# Run training with synthetic data
import subprocess

# Construct training command.
# The continuation backslashes are doubled on purpose. A single backslash
# before a newline inside a (non-raw) Python string is consumed by Python,
# which would collapse the command onto one line when printed.
train_command = f"""
python3 -m MaxText.train MaxText/configs/base.yml \\
    run_name={RUN_NAME} \\
    base_output_directory={BASE_OUTPUT_DIR} \\
    dataset_type=synthetic \\
    steps=100 \\
    per_device_batch_size={training_config['per_device_batch_size']} \\
    learning_rate={training_config['learning_rate']} \\
    warmup_steps={training_config['warmup_steps']} \\
    log_period={training_config['log_period']} \\
    checkpoint_period={training_config['checkpoint_period']} \\
    dtype={training_config['dtype']} \\
    enable_checkpointing=True
""".strip()

print("Training command:")
print(train_command)
print("\nStarting training...")

# Run training (uncomment to execute).
# The command uses paths relative to the repository root, so run it from
# `project_root` (set in the path-setup cell) rather than the notebook folder.
# result = subprocess.run(train_command, shell=True, capture_output=True, text=True, cwd=project_root)
# print(result.stdout)
# if result.stderr:
#     print("Errors:", result.stderr)

In [None]:
# Alternative: Training loop implementation for demonstration
from typing import Dict, Any
import time

class TrainingMetrics:
    """Collects per-step training measurements and renders them as plots."""

    def __init__(self):
        # Parallel series: entry i of each list belongs to the same step.
        self.losses = list()
        self.learning_rates = list()
        self.grad_norms = list()
        self.throughput = list()
        self.steps = list()

    def update(self, step, loss, lr, grad_norm, tokens_per_second):
        """Append one step's measurements to every series."""
        recorded = (
            (self.steps, step),
            (self.losses, loss),
            (self.learning_rates, lr),
            (self.grad_norms, grad_norm),
            (self.throughput, tokens_per_second),
        )
        for series, sample in recorded:
            series.append(sample)

    def plot(self):
        """Draw a 2x2 grid (loss, learning rate, gradient norm, throughput).

        Returns:
            The matplotlib figure containing the four panels.
        """
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # (grid position, series, y-axis label, panel title)
        panels = (
            ((0, 0), self.losses, 'Loss', 'Training Loss'),
            ((0, 1), self.learning_rates, 'Learning Rate', 'Learning Rate Schedule'),
            ((1, 0), self.grad_norms, 'Gradient Norm', 'Gradient Norm'),
            ((1, 1), self.throughput, 'Tokens/Second', 'Training Throughput'),
        )
        for (row, col), series, y_label, title in panels:
            panel = axes[row, col]
            panel.plot(self.steps, series)
            panel.set_xlabel('Step')
            panel.set_ylabel(y_label)
            panel.set_title(title)
            panel.grid(True)

        plt.tight_layout()
        return fig

# Simulate training loop
def simulate_training_loop(num_steps=100, seed=0):
    """Simulate a training loop with synthetic metrics.

    No model is trained. Loss and gradient norm follow decaying exponentials
    with Gaussian noise, and the learning rate ramps linearly over the
    warmup period read from ``training_config``.

    Args:
        num_steps: Number of simulated steps (steps are numbered from 1).
        seed: Seed for the noise generator so repeated runs give the same
            curves. Pass ``None`` for non-deterministic noise.

    Returns:
        A ``TrainingMetrics`` instance holding one entry per simulated step.
    """
    metrics = TrainingMetrics()
    # A local generator keeps results reproducible without touching NumPy's global state.
    rng = np.random.default_rng(seed)

    # Read the shared settings once instead of on every iteration.
    base_lr = training_config['learning_rate']
    # Guard against warmup_steps == 0, which would divide by zero below.
    warmup_steps = max(1, training_config['warmup_steps'])
    log_period = training_config['log_period']
    checkpoint_period = training_config['checkpoint_period']

    print("Starting simulated training...")
    print("-" * 50)

    for step in range(1, num_steps + 1):
        # Simulate metrics
        loss = 10.0 * np.exp(-step / 50) + rng.normal(0, 0.1)
        lr = base_lr * min(1.0, step / warmup_steps)
        grad_norm = 2.0 * np.exp(-step / 100) + rng.normal(0, 0.05)
        tokens_per_second = 50000 + rng.normal(0, 1000)

        # Update metrics
        metrics.update(step, loss, lr, grad_norm, tokens_per_second)

        # Print progress
        if step % log_period == 0:
            print(f"Step {step:4d} | Loss: {loss:.4f} | LR: {lr:.6f} | "
                  f"Grad Norm: {grad_norm:.3f} | Throughput: {tokens_per_second:.0f} tok/s")

        # Simulate checkpoint saving
        if step % checkpoint_period == 0:
            print(f"  → Saving checkpoint at step {step}")

        time.sleep(0.01)  # Simulate computation time

    print("-" * 50)
    print("Training completed!")

    return metrics

# Run simulated training
metrics = simulate_training_loop(num_steps=100)

## 5. Training Monitoring and Visualization

In [None]:
# Render the four-panel metrics figure inline
fig = metrics.plot()
plt.show()

# Write the same figure to a PNG named after the run
metrics_plot_file = f'training_metrics_{RUN_NAME}.png'
fig.savefig(metrics_plot_file, dpi=150, bbox_inches='tight')
print(f"Metrics plot saved to: {metrics_plot_file}")

In [None]:
# Summary statistics derived from the recorded metric series
stats = {
    'Final Loss': metrics.losses[-1],
    'Average Loss': np.mean(metrics.losses),
    'Loss Reduction': metrics.losses[0] - metrics.losses[-1],
    'Average Throughput': np.mean(metrics.throughput),
    'Peak Throughput': np.max(metrics.throughput),
    'Average Gradient Norm': np.mean(metrics.grad_norms),
    # 0.01 s is the per-step sleep used by the simulated loop.
    'Total Training Time (simulated)': len(metrics.steps) * 0.01,
}

print("Training Statistics:")
print("=" * 40)
for label, amount in stats.items():
    # Choose the number format and unit from the statistic's name.
    if 'Throughput' in label:
        rendered = f"{amount:,.0f} tokens/sec"
    elif 'Time' in label:
        rendered = f"{amount:.2f} seconds"
    else:
        rendered = f"{amount:.4f}"
    print(f"{label:.<30} {rendered}")

## 6. Model Checkpointing

In [None]:
# Checkpoint management utilities
import orbax.checkpoint as ocp

def list_checkpoints(output_dir):
    """Return a table of (simulated) checkpoints.

    This demo does not read storage. ``output_dir`` is accepted so the
    signature matches a real implementation, but it is currently ignored
    and the same three example rows are always returned.

    Args:
        output_dir: Training output directory (unused in the simulation).

    Returns:
        A ``pandas.DataFrame`` with columns ``step``, ``loss`` and ``size_mb``.
    """
    # Simulate checkpoint listing
    checkpoints = [
        {'step': 100, 'loss': 5.234, 'size_mb': 450},
        {'step': 200, 'loss': 3.876, 'size_mb': 450},
        {'step': 300, 'loss': 2.945, 'size_mb': 450},
    ]

    return pd.DataFrame(checkpoints)

# Fetch the checkpoint table and show it as plain text
checkpoints_df = list_checkpoints(BASE_OUTPUT_DIR)
print("Available Checkpoints:")
print(checkpoints_df.to_string(index=False))

# Loss recorded at each checkpoint, drawn on an explicit Axes
ckpt_fig, ckpt_ax = plt.subplots(figsize=(8, 4))
ckpt_ax.plot(checkpoints_df['step'], checkpoints_df['loss'], 'o-')
ckpt_ax.set_xlabel('Step')
ckpt_ax.set_ylabel('Loss')
ckpt_ax.set_title('Checkpoint Losses')
ckpt_ax.grid(True)
plt.show()

In [None]:
# Function to load checkpoint for inference or continued training
def load_checkpoint_command(checkpoint_step):
    """Generate a command that continues training from a saved checkpoint.

    Args:
        checkpoint_step: Step number of the checkpoint to load.

    Returns:
        A multi-line shell command string (lines joined with ``\\``).
    """
    # Run outputs live under <base_output_directory>/<run_name>, so the run
    # name has to be part of the checkpoint path.
    # NOTE(review): the '<step>/items' layout follows the example documented
    # for load_parameters_path in MaxText's base.yml — verify for your version.
    checkpoint_path = f"{BASE_OUTPUT_DIR}/{RUN_NAME}/checkpoints/{checkpoint_step}/items"

    # Doubled backslashes keep the shell line continuations in the string;
    # a single one before a newline would be swallowed by Python.
    command = f"""
python3 -m MaxText.train MaxText/configs/base.yml \\
    run_name={RUN_NAME}_continued \\
    base_output_directory={BASE_OUTPUT_DIR} \\
    load_parameters_path={checkpoint_path} \\
    dataset_type=synthetic \\
    steps=200
""".strip()

    return command

# Example: Continue training from checkpoint
continue_command = load_checkpoint_command(checkpoint_step=100)
print("Command to continue training from checkpoint:")
print(continue_command)

## 7. Advanced Training Features

In [None]:
# Example settings for mixed precision training
mixed_precision_config = dict(
    dtype='bfloat16',                 # or 'float32', 'float16'
    weight_dtype='float32',           # keep weights in full precision
    gradient_accumulation_steps=4,    # accumulate gradients
    use_gradient_checkpointing=True,  # save memory
)

print("Mixed Precision Configuration:")
for option_name in mixed_precision_config:
    print(f"  {option_name}: {mixed_precision_config[option_name]}")

In [None]:
# Example settings for multi-host distributed training
distributed_config = dict(
    ici_fsdp_parallelism=8,          # Fully Sharded Data Parallelism
    ici_tensor_parallelism=1,        # Tensor parallelism
    dcn_data_parallelism=1,          # Data parallelism across data center network
    compile_topology='v5e-256',      # TPU topology
    compile_topology_num_slices=2,   # Number of TPU slices
)

print("Distributed Training Configuration:")
for option_name in distributed_config:
    print(f"  {option_name}: {distributed_config[option_name]}")

# Total parallelism is the product of the three parallelism degrees.
total_parallelism = 1
for degree_name in ('ici_fsdp_parallelism',
                    'ici_tensor_parallelism',
                    'dcn_data_parallelism'):
    total_parallelism *= distributed_config[degree_name]
print(f"\nTotal parallelism: {total_parallelism}")

In [None]:
# Generate command for training specific models
def generate_model_training_command(model_name='llama2-7b', dataset='c4'):
    """Generate a training command for one of the preset models.

    Args:
        model_name: One of ``'llama2-7b'``, ``'gemma-2b'`` or ``'mixtral-8x7b'``.
        dataset: Sub-directory of ``DATASET_PATH`` holding the dataset.

    Returns:
        A multi-line shell command string (lines joined with ``\\``).

    Raises:
        ValueError: If ``model_name`` is not one of the presets.
    """
    model_configs = {
        'llama2-7b': {
            'config': 'MaxText/configs/models/llama2-7b.yml',
            'batch_size': 4,
            'learning_rate': 3e-4,
            'steps': 10000,
        },
        'gemma-2b': {
            'config': 'MaxText/configs/models/gemma-2b.yml',
            'batch_size': 8,
            'learning_rate': 5e-4,
            'steps': 5000,
        },
        'mixtral-8x7b': {
            'config': 'MaxText/configs/models/mixtral-8x7b.yml',
            'batch_size': 2,
            'learning_rate': 1e-4,
            'steps': 15000,
        },
    }

    if model_name not in model_configs:
        raise ValueError(f"Unknown model: {model_name}")

    # Named `preset` so it is not confused with the notebook-level `config`.
    preset = model_configs[model_name]

    # Doubled backslashes keep the shell line continuations in the string;
    # a single one before a newline would be swallowed by Python.
    command = f"""
python3 -m MaxText.train {preset['config']} \\
    run_name={model_name}_training \\
    base_output_directory={BASE_OUTPUT_DIR} \\
    dataset_path={DATASET_PATH}/{dataset} \\
    dataset_type=c4 \\
    per_device_batch_size={preset['batch_size']} \\
    learning_rate={preset['learning_rate']} \\
    steps={preset['steps']} \\
    enable_checkpointing=True \\
    checkpoint_period=1000
""".strip()

    return command

# Show the generated command for every preset model
preset_models = ['llama2-7b', 'gemma-2b', 'mixtral-8x7b']
separator = "-" * 40
for model in preset_models:
    command_text = generate_model_training_command(model)
    print(f"\n{model} Training Command:")
    print(separator)
    print(command_text)

## 8. Performance Analysis

In [None]:
# Calculate Model FLOPs Utilization (MFU)
def calculate_mfu(model_params, hardware_flops, achieved_tflops):
    """Calculate Model FLOPs Utilization (MFU).

    MFU is the ratio of achieved throughput to the hardware peak.

    Args:
        model_params: Number of model parameters. Unused by the calculation;
            kept so existing callers do not have to change.
        hardware_flops: Peak hardware throughput, in the same unit as
            ``achieved_tflops``.
        achieved_tflops: Measured throughput.

    Returns:
        ``achieved_tflops / hardware_flops`` as a fraction (e.g. ``0.5`` for 50%).
    """
    return achieved_tflops / hardware_flops

# Peak throughput and memory for a few accelerator types
hardware_specs = {
    'v5e-256': dict(tflops=197, memory_gb=16),
    'v5p-128': dict(tflops=459, memory_gb=95),
    'a100-80gb': dict(tflops=312, memory_gb=80),
    'h100-80gb': dict(tflops=989, memory_gb=80),
}

# Worked example: a 7B-parameter model on v5p
model_size_b = 7  # 7B parameters
hardware = 'v5p-128'
achieved_tflops = 320  # Example achieved performance

peak_tflops = hardware_specs[hardware]['tflops']
mfu = calculate_mfu(
    model_params=model_size_b * 1e9,
    hardware_flops=peak_tflops,
    achieved_tflops=achieved_tflops,
)

print(f"Performance Analysis for {model_size_b}B model on {hardware}:")
print(f"  Hardware Peak: {peak_tflops} TFLOPS")
print(f"  Achieved: {achieved_tflops} TFLOPS")
print(f"  MFU: {mfu:.1%}")

In [None]:
# Memory usage estimation
def estimate_memory_usage(model_params_b, batch_size, seq_length, dtype='bfloat16'):
    """Give a rough estimate of training memory, in GiB, by component.

    Args:
        model_params_b: Model size in billions of parameters.
        batch_size: Batch size used for the activation estimate.
        seq_length: Sequence length used for the activation estimate.
        dtype: One of ``'float32'``, ``'bfloat16'``, ``'float16'``, ``'int8'``.

    Returns:
        Dict with keys ``model``, ``optimizer``, ``activations``,
        ``gradients`` and ``total``.

    Raises:
        KeyError: If ``dtype`` is not one of the supported names.
    """
    gib = 1024**3

    # Storage width of one element for each supported dtype.
    param_bytes = {'float32': 4, 'bfloat16': 2, 'float16': 2, 'int8': 1}[dtype]

    # Weights: parameter count times element width.
    model_memory_gb = (model_params_b * 1e9 * param_bytes) / gib

    # Gradients are counted as the same size as the weights.
    gradient_memory_gb = model_memory_gb

    # Optimizer state is counted as twice the weights (Adam's two moment terms).
    optimizer_memory_gb = model_memory_gb * 2

    # Activations: heuristic rough estimate, not a derived formula.
    activation_memory_gb = (
        batch_size * seq_length * model_params_b * 0.1 * param_bytes
    ) / gib

    breakdown = {
        'model': model_memory_gb,
        'optimizer': optimizer_memory_gb,
        'activations': activation_memory_gb,
        'gradients': gradient_memory_gb,
    }
    breakdown['total'] = (
        model_memory_gb + optimizer_memory_gb +
        activation_memory_gb + gradient_memory_gb
    )
    return breakdown

# Estimate memory for different model sizes
model_sizes = [1, 7, 13, 30, 70]
memory_estimates = []

for size in model_sizes:
    mem = estimate_memory_usage(
        model_params_b=size,
        batch_size=8,
        seq_length=2048,
        dtype='bfloat16'
    )
    memory_estimates.append({
        'Model Size (B)': size,
        'Total Memory (GB)': mem['total'],
        'Model (GB)': mem['model'],
        'Optimizer (GB)': mem['optimizer'],
    })

memory_df = pd.DataFrame(memory_estimates)
print("Memory Requirements by Model Size:")
print(memory_df.to_string(index=False, float_format='%.1f'))

# Plot memory requirements.
# Both panels share one set of x labels built from the integer size column,
# so they read '1', '7', ... in both places.
size_labels = memory_df['Model Size (B)'].astype(str)
mem_fig, (total_ax, breakdown_ax) = plt.subplots(1, 2, figsize=(10, 5))

total_ax.bar(size_labels, memory_df['Total Memory (GB)'])
total_ax.set_xlabel('Model Size (B parameters)')
total_ax.set_ylabel('Memory (GB)')
total_ax.set_title('Total Memory Requirements')

# One bar() call per category gives each category a single colour across all
# model sizes, so the legend matches every bar. The optimizer segment is
# stacked on top of the model segment.
breakdown_ax.bar(size_labels, memory_df['Model (GB)'], label='Model')
breakdown_ax.bar(size_labels, memory_df['Optimizer (GB)'],
                 bottom=memory_df['Model (GB)'], label='Optimizer')
breakdown_ax.set_xlabel('Model Size (B parameters)')
breakdown_ax.set_ylabel('Memory (GB)')
breakdown_ax.set_title('Memory Breakdown')
breakdown_ax.legend()

mem_fig.tight_layout()
plt.show()

## 9. Training Best Practices

In [None]:
# Checklist of recommended practices, grouped by category
best_practices = [
    {
        "category": "Data",
        "practices": [
            "Use efficient data formats (TFRecord, ArrayRecord)",
            "Implement data sharding for distributed training",
            "Pre-tokenize datasets for faster loading",
            "Use data validation to catch issues early",
        ],
    },
    {
        "category": "Performance",
        "practices": [
            "Enable mixed precision training (bfloat16)",
            "Use gradient accumulation for larger effective batch sizes",
            "Enable XLA compilation flags",
            "Profile code to identify bottlenecks",
        ],
    },
    {
        "category": "Stability",
        "practices": [
            "Implement gradient clipping",
            "Use learning rate warmup",
            "Monitor gradient norms and loss spikes",
            "Save checkpoints frequently",
        ],
    },
    {
        "category": "Debugging",
        "practices": [
            "Enable stack trace collection",
            "Use small datasets for quick iteration",
            "Implement comprehensive logging",
            "Test with synthetic data first",
        ],
    },
]

print("MaxText Training Best Practices:")
print("=" * 50)
for group in best_practices:
    heading, items = group["category"], group["practices"]
    print(f"\n{heading}:")
    for item in items:
        print(f"  ✓ {item}")

In [None]:
# Environment variables suggested for optimization and debugging
optimization_env_vars = dict(
    XLA_FLAGS='--xla_gpu_enable_async_collectives=true',
    LIBTPU_INIT_ARGS='--xla_enable_async_all_gather=true',
    JAX_TRACEBACK_FILTERING='off',  # For debugging
    JAX_ENABLE_X64='false',         # Use 32-bit by default
    TF_CPP_MIN_LOG_LEVEL='0',       # Show all logs
)

print("Recommended Environment Variables:")
print("-" * 40)
for name in optimization_env_vars:
    print(f"export {name}='{optimization_env_vars[name]}'")

# Apply the same values to this process (child processes inherit them).
# NOTE(review): JAX was already imported and queried earlier in the notebook;
# these presumably only affect runtimes started afterwards — confirm.
os.environ.update(optimization_env_vars)

## 10. Summary and Next Steps

In [None]:
# Recap of the run settings used throughout the notebook
summary_rule = "=" * 50
print("Training Demo Summary")
print(summary_rule)
print(f"Run Name: {RUN_NAME}")
print(f"Model Configuration: {config_params['model_size']}B parameters")
print(f"Training Steps: {training_config['steps']}")
print(f"Batch Size: {training_config['per_device_batch_size']}")
print(f"Learning Rate: {training_config['learning_rate']}")
print(f"Hardware: {training_config['hardware']}")
print(f"Output Directory: {BASE_OUTPUT_DIR}/{RUN_NAME}")

# Suggested follow-ups, numbered in the order listed
next_steps = (
    "Scale up to larger models (7B, 13B, 70B)",
    "Use real datasets (C4, Wikipedia, custom data)",
    "Implement fine-tuning for specific tasks",
    "Deploy model for inference using MaxText decode",
    "Optimize for multi-host distributed training",
    "Integrate with monitoring tools (TensorBoard, Weights & Biases)",
)
print("\nNext Steps:")
for number, suggestion in enumerate(next_steps, start=1):
    print(f"{number}. {suggestion}")

In [None]:
# Generate complete training script
import shlex

# Notes on the generated script:
# - `set -euo pipefail` stops it at the first failing command, so
#   "Training completed!" is only printed when training succeeded.
# - Values interpolated from Python are shell-quoted, and the variables are
#   expanded inside double quotes, so unusual characters cannot split words.
complete_script = f"""
#!/bin/bash
# Complete MaxText Training Script
# Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

set -euo pipefail

# Set environment variables
export XLA_FLAGS='--xla_gpu_enable_async_collectives=true'
export LIBTPU_INIT_ARGS='--xla_enable_async_all_gather=true'

# Configuration
RUN_NAME={shlex.quote(RUN_NAME)}
BASE_OUTPUT_DIR={shlex.quote(BASE_OUTPUT_DIR)}
DATASET_PATH={shlex.quote(DATASET_PATH)}

# Run training
python3 -m MaxText.train MaxText/configs/base.yml \\
    run_name="$RUN_NAME" \\
    base_output_directory="$BASE_OUTPUT_DIR" \\
    dataset_path="$DATASET_PATH" \\
    dataset_type=c4 \\
    steps={training_config['steps']} \\
    per_device_batch_size={training_config['per_device_batch_size']} \\
    learning_rate={training_config['learning_rate']} \\
    warmup_steps={training_config['warmup_steps']} \\
    enable_checkpointing=true \\
    checkpoint_period={training_config['checkpoint_period']} \\
    eval_interval={training_config['eval_interval']} \\
    dtype={training_config['dtype']}

echo "Training completed!"
"""

# Save script. strip() removes the leading blank line so the shebang is the
# first line; the explicit newline ends the file the way shell tools expect.
script_path = f"train_{RUN_NAME}.sh"
with open(script_path, 'w') as f:
    f.write(complete_script.strip() + "\n")

print(f"Complete training script saved to: {script_path}")
print("\nTo run the training:")
print(f"  chmod +x {script_path}")
print(f"  ./{script_path}")