# GSM8K Math Problem Training

This notebook trains a VishwamAI model on the GSM8K grade-school math dataset, using TPU acceleration and configurations tuned for mathematical reasoning.

## Repository Setup

First, let's clone the repository and prepare the environment.

In [None]:
# Helper function to determine environment
def is_colab():
    """Return True when running inside Google Colab.

    Detection works by trying to import ``google.colab``. Only ``ImportError``
    (which includes ``ModuleNotFoundError``) means "not Colab". Other exceptions,
    e.g. KeyboardInterrupt, propagate instead of being silently swallowed.
    """
    try:
        import google.colab  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    return True

# Setup repository and paths
import os
from pathlib import Path

if is_colab():
    # Clone repository in Colab
    !git clone https://github.com/organization/vishwamai.git
    repo_path = Path('vishwamai')
else:
    # Local development setup
    notebook_dir = Path().absolute()
    if notebook_dir.name == 'notebooks':
        repo_path = notebook_dir.parent
    else:
        repo_path = notebook_dir

# Create necessary directories
os.makedirs(repo_path / 'checkpoints' / 'gsm8k', exist_ok=True)
os.makedirs(repo_path / 'logs', exist_ok=True)

print(f"Repository path: {repo_path}")
print("Directory structure:")
!ls -R {repo_path}

In [None]:
# Set up Python environment
import sys
import importlib

# Change to repository directory
os.chdir(repo_path)

# Install package in development mode
print("Installing VishwamAI package...")
!pip install --quiet --no-cache-dir -e .

# Add repository to Python path if needed
if str(repo_path) not in sys.path:
    sys.path.insert(0, str(repo_path))
    print(f"Added {repo_path} to Python path")

# Clear and reload imports
if 'vishwamai' in sys.modules:
    importlib.reload(sys.modules['vishwamai'])

# Verify installation
try:
    import vishwamai
    print(f"✓ Successfully imported vishwamai v{vishwamai.__version__}")
    print(f"✓ Package location: {vishwamai.__file__}")
except ImportError as e:
    print(f"✗ Failed to import vishwamai: {e}")
    print("Debug information:")
    print(f"Current working directory: {os.getcwd()}")
    print(f"Python path: {sys.path}")
    raise

## Environment Setup

Install the required packages and set up the TPU environment.

In [None]:
# Setup TPU runtime and install required packages
!pip install --upgrade pip
!pip install "jax[tpu]>=0.4.13" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
!pip install flax>=0.7.0 optax>=0.1.7 chex>=0.1.7 tensorboardX>=2.6.1
!pip install datasets omegaconf safetensors matplotlib seaborn

# Install VishwamAI package in development mode with extras
!pip install -e ".[dev,profiling]" --no-cache-dir

# Verify installations
import jax
print(f"JAX version: {jax.__version__}")
print(f"Available devices: {jax.devices()}")

import vishwamai
print(f"VishwamAI version: {vishwamai.__version__}")
print("✓ Successfully imported all required packages")

# Ensure TPU is available
assert len(jax.devices('tpu')) > 0, "No TPU devices found"

## Setup and Imports

In [None]:
import os
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
import numpy as np
from datasets import load_dataset
from vishwamai.training import train, create_train_state
from vishwamai.model import VishwamAIModel, ModelConfig
from vishwamai.tokenizer import VishwamAITokenizer
from omegaconf import OmegaConf
import logging
from safetensors.flax import save_file
import random
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict

# Configure logging and plotting
logging.basicConfig(level=logging.INFO)
# Matplotlib 3.6 renamed the bare 'seaborn' style to 'seaborn-v0_8', and 3.8 removed
# the old name. Use whichever name this Matplotlib version provides.
_SEABORN_STYLE = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(_SEABORN_STYLE)

## TPU Setup

Configure the TPU environment and create a one-dimensional, data-parallel device mesh using the `jax.sharding` API.

In [None]:
# TPU environment setup.
# NOTE: JAX reads JAX_PLATFORMS / JAX_ENABLE_X64 when `jax` is imported, and the
# TCMALLOC_* / TF_CPP_* variables are read by native libraries when they load.
# `jax` was already imported in earlier cells, so in this kernel these variables
# presumably only affect subprocesses launched from it. To apply them here, set
# them before the first `import jax`.
os.environ['TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD'] = '10000000000'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
os.environ['JAX_PLATFORMS'] = 'tpu'
os.environ['JAX_ENABLE_X64'] = 'False'
# Unlike the env var above, this takes effect in the running kernel: keep 32-bit defaults.
jax.config.update('jax_enable_x64', False)

def setup_tpu_cluster(devices=None, axis_name='data'):
    """Build a 1-D data-parallel device mesh and the matching array sharding.

    Args:
        devices: Devices to place on the mesh. Defaults to ``jax.devices()``, i.e.
            all devices of the default backend.
        axis_name: Name of the single mesh axis, also used in the PartitionSpec.
            Defaults to ``'data'``.

    Returns:
        ``(mesh, data_sharding)``: a ``Mesh`` laying all devices along ``axis_name``,
        and a ``NamedSharding`` that splits an array's leading dimension across
        that axis.

    Raises:
        ValueError: if no devices are available.
    """
    if devices is None:
        devices = jax.devices()
    devices = list(devices)
    if not devices:
        raise ValueError("setup_tpu_cluster needs at least one device")
    print(f"Available devices: {devices}")

    # np.array over a flat device list is already 1-D with shape (len(devices),).
    device_mesh = np.array(devices)
    mesh = Mesh(device_mesh, (axis_name,))

    # Shard the leading (batch) dimension across the single mesh axis.
    data_sharding = NamedSharding(mesh, P(axis_name))

    return mesh, data_sharding

# Set up TPU mesh and sharding
mesh, sharding = setup_tpu_cluster()
print("\nMesh specification:", mesh)
print("Sharding specification:", sharding)

## Load Configurations

Load the model and training configurations tuned for GSM8K.

In [None]:
# Read the model and training YAML files from the repository's config tree.
configs_dir = repo_path / 'vishwamai' / 'configs'
model_config, training_config = (
    OmegaConf.load(configs_dir / section / filename)
    for section, filename in (('model', '10B.yaml'), ('training', 'gsm8k.yaml'))
)

# When the training config declares an output_dir, point it at the absolute
# checkpoint directory under the repository.
if 'output_dir' in training_config:
    training_config.output_dir = str(repo_path / 'checkpoints' / 'gsm8k')

for heading, shown in (
    ("Config directory:", configs_dir),
    ("\nModel config:", model_config),
    ("\nTraining config:", training_config),
):
    print(heading, shown)

## Data Processing

Format each GSM8K example as a question followed by its step-by-step solution, then tokenize it for training.

In [None]:
class GSM8KProcessor:
    """Turn raw GSM8K records into tokenized causal-LM training examples.

    Args:
        tokenizer: Callable invoked as ``tokenizer(texts, padding=..., truncation=...,
            max_length=..., return_attention_mask=True)``. It must return a mutable
            mapping with an ``"input_ids"`` entry. NOTE(review): this follows the
            Hugging Face tokenizer convention; confirm VishwamAITokenizer supports it.
        config: Training config. Reads ``config.dataset.max_length`` and, in
            ``prepare_dataset``, ``config.dataset.num_workers``.
    """

    def __init__(self, tokenizer, config):
        self.tokenizer = tokenizer
        self.config = config
        self.max_length = config.dataset.max_length

    def format_example(self, example):
        """Render one GSM8K record as a single prompt+solution string.

        ``example`` must have ``'question'`` and ``'answer'`` keys. The text after the
        last ``####`` in the answer is appended as ``Final Answer``.
        NOTE(review): the full answer, including its ``####`` line, is kept, so the
        final number appears twice in the training text.
        """
        question = example['question']
        answer = example['answer']
        final_answer = answer.split('####')[-1].strip()
        return (
            f"Question: {question}\n"
            "Let's solve this step by step:\n"
            f"{answer}\n"
            f"Final Answer: {final_answer}"
        )

    def tokenize_function(self, examples):
        """Format and tokenize a batch of GSM8K records.

        ``Dataset.map(batched=True)`` passes a column-oriented batch
        (``{'question': [...], 'answer': [...]}``). Iterating that mapping directly
        would yield the column *names*, so rows are rebuilt by zipping the columns.
        A plain list of per-example dicts is also accepted.

        Returns:
            The tokenizer output plus a ``"labels"`` entry copied from ``"input_ids"``.
            NOTE(review): padding positions are not masked out of ``labels``; confirm
            the training loss ignores them (e.g. via ``attention_mask``).
        """
        from collections.abc import Mapping

        if isinstance(examples, Mapping):
            rows = (
                {'question': question, 'answer': answer}
                for question, answer in zip(examples['question'], examples['answer'])
            )
        else:
            rows = examples
        formatted_texts = [self.format_example(row) for row in rows]

        tokenized = self.tokenizer(
            formatted_texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_attention_mask=True,
        )

        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized

    def prepare_dataset(self, dataset):
        """Tokenize every example of ``dataset`` in batches and drop the raw text columns."""
        tokenized_dataset = dataset.map(
            self.tokenize_function,
            batched=True,
            num_proc=self.config.dataset.num_workers,
            remove_columns=dataset.column_names,
        )
        return tokenized_dataset

def create_gsm8k_dataloader(config, split="train"):
    """Load one split of GSM8K ("openai/gsm8k", config "main") and tokenize it.

    Despite its name, this returns the tokenized dataset produced by
    ``GSM8KProcessor.prepare_dataset``, not a batching loader.

    Args:
        config: Training config. ``config.model.vocab_size`` and ``config.model.name``
            configure the tokenizer, and the whole config is handed to the processor.
        split: Dataset split name passed to ``load_dataset``.
    """
    raw_split = load_dataset("openai/gsm8k", "main", split=split)

    processor = GSM8KProcessor(
        VishwamAITokenizer(
            vocab_size=config.model.vocab_size,
            model_prefix=config.model.name,
        ),
        config,
    )
    tokenized_split = processor.prepare_dataset(raw_split)

    print(f"Processed {len(tokenized_split)} examples for {split} split")
    return tokenized_split

## Model Initialization

In [None]:
# Initialize model.
# Convert the OmegaConf node to plain Python containers (resolving ${...}
# interpolations) before unpacking. This way ModelConfig receives ordinary
# dicts/lists rather than DictConfig/ListConfig objects for any nested fields.
model = VishwamAIModel(ModelConfig(**OmegaConf.to_container(model_config, resolve=True)))
print("Model initialized with config:", model_config)

## Training Setup

In [None]:
# Create the tokenized training and evaluation datasets.
# GSM8K's "main" config only has "train" and "test" splits; requesting
# "validation" makes load_dataset raise. The test split therefore serves as
# the evaluation set here.
# NOTE(review): if the test split must stay untouched for final reporting, hold
# out a slice of "train" (e.g. Dataset.train_test_split) for validation instead.
train_dataset = create_gsm8k_dataloader(training_config, split="train")
val_dataset = create_gsm8k_dataloader(training_config, split="test")

# Set up checkpoint directory using repository path
checkpoint_dir = repo_path / 'checkpoints' / 'gsm8k'
checkpoint_dir.mkdir(parents=True, exist_ok=True)

def _param_path_to_name(key_path):
    """Join a pytree key path into a dotted tensor name such as "params.dense.kernel"."""
    parts = []
    for entry in key_path:
        # DictKey / FlattenedIndexKey expose .key, GetAttrKey .name, SequenceKey .idx.
        for attr in ("key", "name", "idx"):
            if hasattr(entry, attr):
                parts.append(str(getattr(entry, attr)))
                break
        else:
            parts.append(str(entry))
    return ".".join(parts)


def save_checkpoint_hook(state, path):
    """Save ``state.params`` to ``<path>.safetensors``.

    safetensors stores a flat ``{name: array}`` mapping, while model parameters are
    typically a nested pytree. The tree is therefore flattened first, and each leaf
    is named by its dotted key path and copied to host memory as a NumPy array.

    Args:
        state: Training state exposing a ``params`` pytree.
        path: Output path without extension; ``.safetensors`` is appended.
    """
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(state.params)
    flat_params = {
        _param_path_to_name(key_path): np.asarray(leaf)
        for key_path, leaf in leaves_with_paths
    }
    save_file(flat_params, f"{path}.safetensors")
    print(f"Saved checkpoint to {path}.safetensors")

## Start Training

In [None]:
# Run training with TPU mesh and sharding
with mesh:
    final_state = train(
        model,
        training_config,
        train_dataset,
        val_dataset=val_dataset,
        num_steps=training_config.max_steps,
        log_every=training_config.logging_steps,
        eval_every=training_config.eval_steps,
        checkpoint_dir=checkpoint_dir,
        save_checkpoint_fn=save_checkpoint_hook,
        sharding=sharding
    )

# Save the final model with the same hook used for intermediate checkpoints, so the
# final file has the same layout. The hook appends ".safetensors" to the given path.
final_path = checkpoint_dir / "gsm8k_final.safetensors"
save_checkpoint_hook(final_state, checkpoint_dir / "gsm8k_final")
print(f"\nTraining completed! Final model saved to {final_path}")
# NOTE(review): best_metrics is presumably tracked by vishwamai's train state; fall
# back to None instead of raising after an otherwise completed run.
print(f"Best metrics: {getattr(final_state, 'best_metrics', None)}")