# MaxText + Tunix SFT Integration Demo

This notebook demonstrates how to run Supervised Fine-Tuning (SFT) using MaxText models with the Tunix trainer. The workflow follows the steps implemented in `MaxText/sft/sft_trainer.py`.

## Overview

The integration consists of several key components:
1. **MaxText Model Loading**: Using `mt.from_pretrained()` to load pre-trained models
2. **Tunix Adapter**: The `TunixMaxTextLlama` wrapper, which bridges MaxText models and the Tunix trainer
3. **Tunix Trainer**: Handles the training loop, optimization, and data management
4. **Data Hooks**: Custom hooks for MaxText-specific data processing
5. **Training Hooks**: Custom hooks for MaxText-specific training logic

## Prerequisites

Make sure you have the following packages installed:
- `MaxText`
- `tunix`
- `jax`
- `flax`
- `orbax-checkpoint`
- `optax` (used by the demo optimizer in the Training Setup section)

## Setup

First, let's set up the environment and imports.

In [None]:
# Install required packages if not already installed
!pip install tunix orbax-checkpoint

# Set environment variables for TPU/GPU compatibility
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
if "xla_tpu_spmd_rng_bit_generator_unsafe" not in os.environ.get("LIBTPU_INIT_ARGS", ""):
    os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true"
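The cell above appends an XLA flag to `LIBTPU_INIT_ARGS` only when it is not already present. A minimal sketch of the same idea as a reusable function follows; `append_flag` is a hypothetical helper for illustration, not part of MaxText or Tunix. It works on plain strings, so it can be checked without a TPU.

```python
def append_flag(current, flag):
    """Return `current` with `flag` appended, unless a flag of that name is already set."""
    name = flag.split("=", 1)[0]
    tokens = current.split()
    # Compare flag names only, so an existing "--name=false" is left untouched.
    if any(tok.split("=", 1)[0] == name for tok in tokens):
        return current
    return " ".join(tokens + [flag])


FLAG = "--xla_tpu_spmd_rng_bit_generator_unsafe=true"
args = append_flag("", FLAG)
args_twice = append_flag(args, FLAG)      # second call is a no-op
combined = append_flag("--foo=1", "--bar=2")
print(args)
print(combined)
```

Comparing by flag name rather than by substring avoids a leading space when the variable starts out empty and keeps a value the user has already chosen.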

In [None]:
# Import required libraries
import jax
import jax.numpy as jnp
from flax import nnx
from functools import partial
import numpy as np

# MaxText imports
import MaxText as mt
from MaxText import pyconfig
from MaxText import max_utils
from MaxText import maxtext_utils
from MaxText import optimizers
from MaxText.integration.tunix.tunix_adaptor import TunixMaxTextLlama
from MaxText.sft import hooks

# Tunix imports
from tunix.sft import peft_trainer, profiler

# Checkpointing
from orbax import checkpoint as ocp

# Set JAX configuration
jax.config.update("jax_default_prng_impl", "unsafe_rbg")

## Configuration Setup

Let's create a configuration object similar to what's used in the SFT trainer. We'll use a simplified configuration for demonstration purposes.

In [None]:
# Create a simple configuration class for demonstration
class SimpleConfig:
    def __init__(self):
        # Model configuration
        self.model_name = "llama2-7b"
        self.load_parameters_path = None  # Set to checkpoint path if loading from checkpoint
        
        # Training configuration
        self.steps = 100
        self.eval_interval = 10
        self.eval_steps = 5
        self.gradient_accumulation_steps = 1
        
        # Checkpointing
        self.checkpoint_period = 50
        self.checkpoint_dir = "./checkpoints"
        self.async_checkpointing = True
        
        # Profiling and logging
        self.profiler = False
        self.tensorboard_dir = "./tensorboard"
        self.skip_first_n_steps_for_profiler = 0
        self.profiler_steps = 10
        
        # Model-specific
        self.logical_axis_rules = ()
        
        # Data configuration
        self.dataset_type = "hf"
        self.hf_path = "HuggingFaceH4/ultrachat_200k"
        self.train_split = "train_sft"
        self.hf_eval_split = "test_sft"
        self.train_data_columns = ["messages"]
        self.eval_data_columns = ["messages"]
        
        # Training parameters
        self.learning_rate = 2e-5
        self.weight_dtype = "bfloat16"
        self.per_device_batch_size = 1
        self.max_target_length = 1024
        
        # SFT specific
        self.use_sft = True
        self.sft_train_on_completion_only = True
        self.packing = True

# Create config instance
config = SimpleConfig()

print("Configuration created:")
print(f"Model: {config.model_name}")
print(f"Steps: {config.steps}")
print(f"Learning rate: {config.learning_rate}")
print(f"Dataset: {config.hf_path}")
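The configuration sets `packing = True`, and later cells pass `inputs_position` and `inputs_segmentation` alongside the tokens. The sketch below illustrates what those two fields typically encode when several short examples share one row. It is a simplified greedy packer written for this explanation, not MaxText's packing implementation, and the convention that segment id `0` marks padding is an assumption.

```python
def pack_sequences(sequences, max_len, pad_id=0):
    """Concatenate token lists into one row with segment ids and per-segment positions."""
    tokens, segments, positions = [], [], []
    for seg_id, seq in enumerate(sequences, start=1):
        if len(tokens) + len(seq) > max_len:
            break  # a real packer would start a new row here
        tokens.extend(seq)
        segments.extend([seg_id] * len(seq))   # which example each token belongs to
        positions.extend(range(len(seq)))      # positions restart at 0 for each example
    pad = max_len - len(tokens)
    return {
        "inputs": tokens + [pad_id] * pad,
        "inputs_segmentation": segments + [0] * pad,  # 0 = padding (assumed convention)
        "inputs_position": positions + [0] * pad,
    }


row = pack_sequences([[11, 12, 13], [21, 22]], max_len=8)
for key, value in row.items():
    print(key, value)
```

Segment ids let attention stay within each packed example, and a non-zero segment id doubles as a "real token" mask.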

## Tunix Configuration Setup

This function creates the Tunix training configuration from MaxText config, following the pattern in `sft_trainer.py`.

In [None]:
def get_tunix_config(mt_config):
    """Create Tunix training configuration from MaxText config."""
    # Checkpointing configurations
    checkpointing_options = ocp.CheckpointManagerOptions(
        save_interval_steps=mt_config.checkpoint_period,
        enable_async_checkpointing=mt_config.async_checkpointing,
    )

    # Metrics configurations
    metrics_logging_options = peft_trainer.metrics_logger.MetricsLoggerOptions(
        log_dir=mt_config.tensorboard_dir
    )

    # Profiler configurations
    profiler_options = None
    if mt_config.profiler:
        profiler_options = profiler.ProfilerOptions(
            log_dir=mt_config.tensorboard_dir,
            skip_first_n_steps=mt_config.skip_first_n_steps_for_profiler,
            profiler_steps=mt_config.profiler_steps,
        )

    return peft_trainer.TrainingConfig(
        eval_every_n_steps=mt_config.eval_interval,
        max_steps=mt_config.steps,
        gradient_accumulation_steps=mt_config.gradient_accumulation_steps,
        checkpoint_root_directory=mt_config.checkpoint_dir,
        checkpointing_options=checkpointing_options,
        metrics_logging_options=metrics_logging_options,
        profiler_options=profiler_options,
    )

# Create Tunix config
tunix_config = get_tunix_config(config)
print("Tunix configuration created:")
print(f"Max steps: {tunix_config.max_steps}")
print(f"Eval every: {tunix_config.eval_every_n_steps} steps")
print(f"Checkpoint dir: {tunix_config.checkpoint_root_directory}")
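`gradient_accumulation_steps` is forwarded to the Tunix config above. As a rough sketch of what it means for throughput, the function below computes how many sequences and tokens feed one optimizer update, assuming the usual data-parallel interpretation. The device count of 8 and the accumulation value of 4 are illustrative assumptions, and the token count is an upper bound because padded positions are included.

```python
def tokens_per_update(per_device_batch_size, num_devices,
                      gradient_accumulation_steps, max_target_length):
    """Return (sequences, tokens) contributing to a single optimizer update."""
    sequences = per_device_batch_size * num_devices * gradient_accumulation_steps
    return sequences, sequences * max_target_length


# Values from SimpleConfig, except num_devices and accumulation (assumed here).
seqs, toks = tokens_per_update(
    per_device_batch_size=1,
    num_devices=8,
    gradient_accumulation_steps=4,
    max_target_length=1024,
)
print(f"{seqs} sequences, up to {toks} tokens per update")
```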

## MaxText Model Loading with Tunix Integration

This function demonstrates how to load a MaxText model and wrap it with the Tunix adapter, following the pattern in `sft_trainer.py`.

In [None]:
def get_maxtext_model(config, default_loss_function=True):
    """Load MaxText model and wrap with Tunix adapter."""
    
    def create_model():
        # Load the model using MaxText's from_pretrained
        # Note: In a real scenario, you would need proper config and devices
        # For demo purposes, we'll create a placeholder
        print("Creating MaxText model...")
        
        # This is a simplified version - in practice you'd use:
        # model = mt.from_pretrained(config, rngs=nnx.Rngs(params=0, dropout=1))
        
        # For demo, we'll create a mock model structure
        class MockMaxTextModel:
            def __init__(self):
                self.mesh = None
                self.enable_dropout = True
            
            def __call__(self, decoder_input_tokens, decoder_positions, decoder_segment_ids=None):
                # Mock forward pass
                batch_size, seq_len = decoder_input_tokens.shape
                vocab_size = 32000  # Mock vocab size
                return jnp.zeros((batch_size, seq_len, vocab_size))
        
        return MockMaxTextModel()
    
    # Create the model
    model = create_model()
    
    # Create a mock mesh for demonstration
    # In practice, this would come from the actual model
    mesh = None
    
    # Load checkpoint if specified
    if config.load_parameters_path:
        print(f"Loading checkpoint from {config.load_parameters_path}")
        # In practice, you would use:
        # checkpoint = mt.checkpointing.load_params_from_path(...)
        # if checkpoint:
        #     nnx.update(model, checkpoint)
    
    # Wrap with Tunix adapter
    if default_loss_function:
        print("Using Tunix default loss function")
        tunix_model = TunixMaxTextLlama(
            base_model=model,
            use_attention_mask=False,  # trust Tunix loss masking
        )
    else:
        print("Using MaxText loss function")
        tunix_model = model
    
    return tunix_model, mesh

# Load the model
print("Loading MaxText model with Tunix integration...")
model, mesh = get_maxtext_model(config, default_loss_function=True)
print(f"Model loaded: {type(model)}")
print(f"Mesh: {mesh}")

## Loss Function Configuration

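The two helpers below select the loss. `use_tunix_default_loss_function` maps MaxText batch keys to the argument names passed to the Tunix-wrapped model, while `use_maxtext_loss_function` registers a custom loss through `with_loss_fn`.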

In [None]:
def use_tunix_default_loss_function(trainer):
    """Configure trainer to use Tunix default loss function."""
    def gen_model_input_fn(x):
        return {
            "input_tokens": x["inputs"],
            "positions": x["inputs_position"],
            "input_mask": x["inputs_segmentation"],
            "attention_mask": x["inputs_segmentation"],
        }

    trainer = trainer.with_gen_model_input_fn(gen_model_input_fn)
    return trainer

def use_maxtext_loss_function(trainer, mt_config):
    """Configure trainer to use MaxText custom loss function."""
    def loss_fn(model, inputs, inputs_position, inputs_segmentation,
                targets, targets_position, targets_segmentation):
        # In practice, you would import and use MaxText's loss function:
        # from MaxText.train import loss_fn
        # data = {
        #     "inputs": inputs,
        #     "inputs_position": inputs_position,
        #     "inputs_segmentation": inputs_segmentation,
        #     "targets": targets,
        #     "targets_position": targets_position,
        #     "targets_segmentation": targets_segmentation,
        # }
        # return loss_fn(model, mt_config, data, dropout_rng=None, params=None, is_train=True)
        
        # For demo, return a mock loss
        return jnp.mean((inputs - targets) ** 2), {}

    trainer = trainer.with_loss_fn(loss_fn, has_aux=True)
    return trainer

print("Loss function configuration functions defined")
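Both loss paths depend on a mask that decides which token positions count toward the loss, which is also what `sft_train_on_completion_only` relies on. The sketch below shows the idea in plain Python: a mean negative log-likelihood computed only over unmasked positions. It is an illustration of masking, not the loss used by Tunix or MaxText, and `masked_nll` is a hypothetical name.

```python
import math


def masked_nll(log_probs, targets, mask):
    """Mean negative log-likelihood over positions where mask is non-zero."""
    total, count = 0.0, 0
    for step_log_probs, target, keep in zip(log_probs, targets, mask):
        if keep:
            total -= step_log_probs[target]
            count += 1
    return total / max(count, 1)  # avoid division by zero for a fully masked row


# Three positions over a 4-token vocabulary with uniform predictions.
uniform = [math.log(0.25)] * 4
log_probs = [uniform, uniform, uniform]

# Position 0 is a prompt token (mask 0); positions 1 and 2 are completion tokens.
loss = masked_nll(log_probs, targets=[0, 1, 2], mask=[0, 1, 1])
print(loss)
```

Because the prompt position is masked out, only the two completion positions contribute, and the mean is taken over those two rather than over all three.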

## Training Setup

Now let's set up the training components including the optimizer, learning rate schedule, and hooks.

In [None]:
def setup_training_components(config, mesh):
    """Set up training components: optimizer, learning rate schedule, and hooks."""
    
    # Create learning rate schedule
    # In practice, you would use:
    # learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
    # For demo, create a simple constant schedule
    def simple_lr_schedule(step):
        return config.learning_rate
    
    learning_rate_schedule = simple_lr_schedule
    
    # Create optimizer
    # In practice, you would use:
    # optimizer = optimizers.get_optimizer(config, learning_rate_schedule)
    # For demo, create a simple SGD optimizer
    from optax import sgd
    # Pass the schedule (a callable of the step) so the optimizer actually uses it
    optimizer = sgd(learning_rate=learning_rate_schedule)
    
    # Create training hooks
    # In practice, you would use:
    # training_hooks = hooks.SFTTrainingHooks(config, mesh, learning_rate_schedule, goodput_recorder)
    # For demo, create mock hooks
    class MockTrainingHooks:
        def __init__(self, config, mesh, lr_schedule):
            self.config = config
            self.mesh = mesh
            self.lr_schedule = lr_schedule
        
        def on_step_begin(self, step, state):
            print(f"Training step {step} beginning")
        
        def on_step_end(self, step, state, metrics):
            print(f"Training step {step} completed with loss: {metrics.get('loss', 'N/A')}")
    
    training_hooks = MockTrainingHooks(config, mesh, learning_rate_schedule)
    
    # Create data hooks
    # In practice, you would use:
    # data_hooks = hooks.SFTDataHooks(config, mesh, goodput_recorder)
    # For demo, create mock hooks
    class MockDataHooks:
        def __init__(self, config, mesh):
            self.config = config
            self.mesh = mesh
        
        def get_train_data_iterator(self):
            # Mock data iterator
            def mock_iterator():
                for _ in range(100):
                    yield {
                        "inputs": jnp.ones((1, 128), dtype=jnp.int32),
                        "inputs_position": jnp.arange(128).reshape(1, -1),
                        "inputs_segmentation": jnp.ones((1, 128), dtype=jnp.int32),
                        "targets": jnp.ones((1, 128), dtype=jnp.int32),
                        "targets_position": jnp.arange(128).reshape(1, -1),
                        "targets_segmentation": jnp.ones((1, 128), dtype=jnp.int32),
                    }
            return mock_iterator()
        
        def get_eval_data_iterator(self):
            # Mock eval data iterator
            def mock_eval_iterator():
                for _ in range(10):
                    yield {
                        "inputs": jnp.ones((1, 128), dtype=jnp.int32),
                        "inputs_position": jnp.arange(128).reshape(1, -1),
                        "inputs_segmentation": jnp.ones((1, 128), dtype=jnp.int32),
                        "targets": jnp.ones((1, 128), dtype=jnp.int32),
                        "targets_position": jnp.arange(128).reshape(1, -1),
                        "targets_segmentation": jnp.ones((1, 128), dtype=jnp.int32),
                    }
            return mock_eval_iterator()
    
    data_hooks = MockDataHooks(config, mesh)
    
    return optimizer, learning_rate_schedule, training_hooks, data_hooks

# Set up training components
print("Setting up training components...")
optimizer, lr_schedule, training_hooks, data_hooks = setup_training_components(config, mesh)
print("Training components set up successfully")
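The demo uses a constant learning rate. Fine-tuning runs more commonly warm up and then decay, so here is a minimal sketch of a linear-warmup, cosine-decay schedule with the same `schedule(step) -> float` shape as `simple_lr_schedule`. It is not what `maxtext_utils.create_learning_rate_schedule` returns; the function name and the warmup length are assumptions chosen for illustration.

```python
import math


def warmup_cosine(step, peak_lr, warmup_steps, total_steps, end_lr=0.0):
    """Linear warmup to peak_lr, then cosine decay to end_lr at total_steps."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return end_lr + 0.5 * (peak_lr - end_lr) * (1.0 + math.cos(math.pi * progress))


# Values mirroring SimpleConfig (learning_rate=2e-5, steps=100); warmup is assumed.
for step in (0, 9, 55, 100):
    print(step, warmup_cosine(step, peak_lr=2e-5, warmup_steps=10, total_steps=100))
```

The rate reaches its peak on the last warmup step, is half the peak midway through the decay, and ends at `end_lr`.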

## Trainer Configuration

Now let's configure the Tunix trainer with all the components we've set up.

In [None]:
def configure_trainer(model, optimizer, tunix_config, training_hooks, data_hooks, config):
    """Configure the Tunix trainer with all components."""
    
    # Create the base trainer
    trainer = peft_trainer.PeftTrainer(model, optimizer, tunix_config)
    
    # Add training hooks
    trainer = trainer.with_training_hooks(training_hooks)
    
    # Add data hooks
    trainer = trainer.with_data_hooks(data_hooks)
    
    # Configure loss function (defaults to the Tunix loss, matching
    # default_loss_function=True in get_maxtext_model)
    if getattr(config, "use_tunix_loss", True):
        print("Using Tunix default loss function")
        trainer = use_tunix_default_loss_function(trainer)
    else:
        print("Using MaxText loss function")
        trainer = use_maxtext_loss_function(trainer, config)
    
    return trainer

# Configure the trainer
print("Configuring Tunix trainer...")
trainer = configure_trainer(model, optimizer, tunix_config, training_hooks, data_hooks, config)
print(f"Trainer configured: {type(trainer)}")

## Training Execution

The cell below does not call `trainer.train()`. It prints a simulated schedule for at most the first 10 steps; the real call is shown in the comments inside `execute_training`.

In [None]:
def execute_training(trainer, data_hooks, mesh, config):
    """Execute the training loop."""
    
    print("Starting training execution...")
    print(f"Training for {config.steps} steps")
    print(f"Evaluation every {config.eval_interval} steps")
    
    # In practice, you would run:
    # with mesh:
    #     trainer.train(data_hooks.get_train_data_iterator(), data_hooks.get_eval_data_iterator())
    
    # For demo purposes, we'll simulate the training process
    print("\nSimulating training process...")
    
    for step in range(min(10, config.steps)):  # Show first 10 steps for demo
        # Simulate training step
        print(f"Step {step + 1}: Training...")
        
        # Simulate evaluation
        if (step + 1) % config.eval_interval == 0:
            print(f"  Step {step + 1}: Evaluating...")
        
        # Simulate checkpointing
        if (step + 1) % config.checkpoint_period == 0:
            print(f"  Step {step + 1}: Saving checkpoint...")
    
    print("\nTraining simulation completed!")
    print("\nIn a real scenario, this would:")
    print("1. Load actual training data from HuggingFace")
    print("2. Run actual training steps with gradient updates")
    print("3. Perform real evaluation on validation data")
    print("4. Save actual model checkpoints")
    print("5. Log metrics to TensorBoard")

# Execute training (simulation)
execute_training(trainer, data_hooks, mesh, config)
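The simulation uses `(step + 1) % interval == 0` to decide when to evaluate and checkpoint. The short sketch below lists the steps at which each event fires, using the values from `SimpleConfig`. `scheduled_steps` is a hypothetical helper for this explanation.

```python
def scheduled_steps(total_steps, interval):
    """1-based steps at which an every-`interval`-steps event fires."""
    return [s for s in range(1, total_steps + 1) if s % interval == 0]


shown = min(10, 100)                       # the demo loop only covers 10 steps
eval_steps = scheduled_steps(shown, 10)    # eval_interval = 10
ckpt_steps = scheduled_steps(shown, 50)    # checkpoint_period = 50
full_ckpt_steps = scheduled_steps(100, 50)

print("eval in demo:", eval_steps)
print("checkpoints in demo:", ckpt_steps)
print("checkpoints in full run:", full_ckpt_steps)
```

With these settings the 10-step simulation reports one evaluation and no checkpoint; a checkpoint message would only appear in a run of at least 50 steps.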

## Complete Training Function

Here's the complete training function that puts everything together, following the pattern in `sft_trainer.py`.

In [None]:
def train(mt_config, default_loss_function=True):
    """Complete training function following sft_trainer.py pattern."""
    
    print("=== Starting MaxText + Tunix SFT Training ===")
    
    # Step 1: Get Tunix configuration
    print("\n1. Setting up Tunix configuration...")
    tunix_config = get_tunix_config(mt_config)
    
    # Step 2: Load MaxText model with Tunix integration
    print("\n2. Loading MaxText model...")
    model, mesh = get_maxtext_model(mt_config, default_loss_function)
    
    # Step 3: Set up training components
    print("\n3. Setting up training components...")
    optimizer, learning_rate_schedule, training_hooks, data_hooks = setup_training_components(mt_config, mesh)
    
    # Step 4: Configure trainer
    print("\n4. Configuring Tunix trainer...")
    trainer = configure_trainer(model, optimizer, tunix_config, training_hooks, data_hooks, mt_config)
    
    # Step 5: Execute training
    print("\n5. Executing training...")
    execute_training(trainer, data_hooks, mesh, mt_config)
    
    print("\n=== Training Setup Complete ===")
    print("\nTo run actual training, you would need:")
    print("1. Real model checkpoint or pretrained weights")
    print("2. Proper TPU/GPU configuration")
    print("3. Real training data")
    print("4. Proper environment setup")
    
    return trainer, model, mesh

# Run the complete training setup
print("Running complete training setup...")
trainer, model, mesh = train(config, default_loss_function=True)

## Usage Examples

Here are some examples of how to use this integration in practice.

In [None]:
# Example 1: Training with Tunix default loss
print("=== Example 1: Tunix Default Loss ===")
config.use_tunix_loss = True
trainer1, _, _ = train(config, default_loss_function=True)

# Example 2: Training with MaxText custom loss
print("\n=== Example 2: MaxText Custom Loss ===")
config.use_tunix_loss = False
trainer2, _, _ = train(config, default_loss_function=False)

# Example 3: Different model configurations
print("\n=== Example 3: Different Model Configs ===")
configs = [
    {"model_name": "llama2-7b", "learning_rate": 1e-5},
    {"model_name": "llama2-13b", "learning_rate": 5e-6},
    {"model_name": "gemma-7b", "learning_rate": 2e-5}
]

for cfg in configs:
    print(f"\nConfig: {cfg}")
    config.model_name = cfg["model_name"]
    config.learning_rate = cfg["learning_rate"]
    # In practice, you would run actual training here
    print(f"  Would train {cfg['model_name']} with lr={cfg['learning_rate']}")
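The loop in Example 3 mutates the shared `config` object, so after it finishes `config` still holds the last model name and learning rate. One alternative, sketched below, is to derive an independent config per run. `RunConfig` is a hypothetical, reduced stand-in for `SimpleConfig`, and `dataclasses.replace` is used to build each variant.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class RunConfig:
    """Hypothetical, reduced stand-in for SimpleConfig."""
    model_name: str = "llama2-7b"
    learning_rate: float = 2e-5
    steps: int = 100


base = RunConfig()
overrides = [
    {"model_name": "llama2-7b", "learning_rate": 1e-5},
    {"model_name": "llama2-13b", "learning_rate": 5e-6},
    {"model_name": "gemma-7b", "learning_rate": 2e-5},
]

# Each variant is a new object; `base` is never modified.
variants = [replace(base, **o) for o in overrides]
for v in variants:
    print(v)
```

Freezing the dataclass makes accidental mutation raise an error, which keeps sweeps from leaking settings into later runs.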

## Integration with Real MaxText

To integrate with real MaxText models, you would need to:

1. **Install MaxText properly**: run `pip install -e .` from the root of your MaxText checkout
2. **Set up proper TPU/GPU configuration**
3. **Use real model checkpoints**
4. **Configure proper data pipelines**

## Key Benefits of This Integration

1. **Leverages Tunix's optimized training infrastructure**
2. **Maintains MaxText's model architecture and capabilities**
3. **Provides flexible loss function options**
4. **Supports both training and evaluation workflows**
5. **Integrates with MaxText's checkpointing and logging systems**

## Next Steps

1. **Set up proper TPU/GPU environment**
2. **Download real model checkpoints**
3. **Configure real training data**
4. **Run actual training experiments**
5. **Monitor training progress and metrics**

This notebook provides the foundation for integrating MaxText with Tunix for SFT training. The actual training execution would require proper hardware setup and real data.