# GSM8K Math Problem Training

This notebook trains a VishwamAI model on the GSM8K grade-school math dataset, using a TPU device mesh and model/training configurations tuned for mathematical reasoning.

## Setup and Imports

In [None]:
import os
import jax
from jax.experimental import mesh_utils
from jax.experimental.maps import Mesh
import numpy as np
from datasets import load_dataset
from vishwamai.training import train, create_train_state
from vishwamai.model import VishwamAIModel, ModelConfig
from vishwamai.tokenizer import VishwamAITokenizer
from omegaconf import OmegaConf
import logging
from safetensors.flax import save_file
import random

# Configure logging: the root logger emits records at INFO level and above.
logging.basicConfig(level=logging.INFO)

## TPU Setup

Configure TPU environment and create device mesh for training.

In [None]:
# TPU environment setup
# NOTE(review): these variables are set after the imports cell has run. Native
# libraries that read them at load time (tcmalloc, TF/XLA logging) may already
# have been initialised -- confirm, or export them before starting the kernel.
os.environ['TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD'] = '10000000000'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
os.environ['JAX_PLATFORMS'] = 'tpu'
os.environ['JAX_ENABLE_X64'] = 'False'

# JAX reads its JAX_* environment variables when the package is imported, and
# `import jax` has already executed, so the two assignments above do not reach
# the live configuration. Apply the same settings through jax.config, which
# takes effect as long as no backend has been initialised yet.
jax.config.update('jax_platforms', 'tpu')
jax.config.update('jax_enable_x64', False)

def setup_tpu_cluster():
    """Build a one-dimensional data-parallel device mesh over all visible devices.

    The mesh has a single axis named ``'dp'`` whose length equals the number
    of devices reported by ``jax.devices()``. On an 8-core TPU this yields the
    same ``(8,)`` mesh as before; on any other topology it adapts instead of
    failing on a shape/device-count mismatch.

    Returns:
        The ``Mesh`` spanning every available device along the ``'dp'`` axis.
    """
    devices = jax.devices()
    print(f"Available devices: {devices}")

    # Size the data-parallel axis from the hardware actually present rather
    # than assuming exactly 8 cores.
    mesh_shape = (len(devices),)
    device_mesh = mesh_utils.create_device_mesh(mesh_shape, devices=devices)
    mesh = Mesh(device_mesh, ('dp',))

    return mesh

# Build the device mesh once at notebook scope; the training cell enters it
# with `with mesh:`.
mesh = setup_tpu_cluster()

## Load Configurations

Load model and training configurations optimized for GSM8K.

In [None]:
# Load configurations
# Both paths are relative to the current working directory, so this cell must
# be run from a directory that sits next to the `vishwamai/` package.
model_config = OmegaConf.load('../vishwamai/configs/model/10B.yaml')
training_config = OmegaConf.load('../vishwamai/configs/training/gsm8k.yaml')

# Echo the loaded values so the run's settings are recorded in the notebook output.
print("Model config:", model_config)
print("\nTraining config:", training_config)

## Data Processing

Implement GSM8K dataset processing with step-by-step solution formatting.

In [None]:
class GSM8KProcessor:
    """Format and tokenize GSM8K examples for training.

    The processor turns each ``{'question', 'answer'}`` record into a single
    prompt/solution string and tokenizes it to a fixed length taken from
    ``config.dataset.max_length``.
    """

    def __init__(self, tokenizer, config):
        """Store the tokenizer and config.

        Args:
            tokenizer: Callable applied to a list of strings; called with
                ``padding``, ``truncation``, ``max_length`` and
                ``return_attention_mask`` keyword arguments.
            config: Configuration object exposing ``dataset.max_length`` and
                ``dataset.num_workers``.
        """
        self.tokenizer = tokenizer
        self.config = config
        self.max_length = config.dataset.max_length

    def format_example(self, example):
        """Format one GSM8K example as a training string.

        Args:
            example: Mapping with ``'question'`` and ``'answer'`` strings.

        Returns:
            The question, the full worked answer, and a trailing
            ``Final Answer:`` line holding the text after the last ``####``
            marker in the answer (the whole answer if no marker is present).
        """
        question = example['question']
        answer = example['answer']
        # Extract final answer: GSM8K answers end with "#### <value>".
        final_answer = answer.split('####')[-1].strip()
        # Format as instruction and response
        formatted_text = f"Question: {question}\nLet's solve this step by step:\n{answer}\nFinal Answer: {final_answer}"
        return formatted_text

    def tokenize_function(self, examples):
        """Tokenize a batch of examples.

        Args:
            examples: Either a column-oriented batch (a mapping from column
                name to a list of values, which is what ``dataset.map`` passes
                when ``batched=True``) or an iterable of per-example mappings.

        Returns:
            The tokenizer output with an added ``"labels"`` entry that is a
            copy of ``"input_ids"``.
        """
        from collections.abc import Mapping

        if isinstance(examples, Mapping):
            # A batched map hands over {'question': [...], 'answer': [...]}.
            # Iterating that mapping directly would yield column names, so
            # rebuild one record per row before formatting.
            rows = [
                {'question': question, 'answer': answer}
                for question, answer in zip(examples['question'], examples['answer'])
            ]
        else:
            rows = examples

        formatted_texts = [self.format_example(row) for row in rows]

        tokenized = self.tokenizer(
            formatted_texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_attention_mask=True,
        )

        # NOTE(review): padded positions are copied into the labels unmasked;
        # confirm the training loss ignores them (e.g. via the attention mask).
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized

    def prepare_dataset(self, dataset):
        """Tokenize every example of ``dataset``.

        Runs ``tokenize_function`` over the dataset in batches using
        ``config.dataset.num_workers`` processes and drops the original
        columns, leaving only the tokenizer outputs and ``labels``.

        Args:
            dataset: Object providing ``map`` and ``column_names``.

        Returns:
            The mapped dataset.
        """
        tokenized_dataset = dataset.map(
            self.tokenize_function,
            batched=True,
            num_proc=self.config.dataset.num_workers,
            remove_columns=dataset.column_names,
        )
        return tokenized_dataset

def create_gsm8k_dataloader(config, split="train"):
    """Load one GSM8K split and return it tokenized.

    Despite the name, the return value is the processed dataset produced by
    ``GSM8KProcessor.prepare_dataset``, not a batching loader.

    Args:
        config: Configuration exposing ``model.vocab_size``, ``model.name``
            and the ``dataset`` settings read by ``GSM8KProcessor``.
        split: Name of the split to request from ``openai/gsm8k`` ("main").

    Returns:
        The tokenized dataset for the requested split.
    """
    raw_split = load_dataset("openai/gsm8k", "main", split=split)

    # A fresh tokenizer is built per call from the model section of the config.
    model_cfg = config.model
    processor = GSM8KProcessor(
        VishwamAITokenizer(
            vocab_size=model_cfg.vocab_size,
            model_prefix=model_cfg.name,
        ),
        config,
    )
    tokenized_split = processor.prepare_dataset(raw_split)

    print(f"Processed {len(tokenized_split)} examples for {split} split")
    return tokenized_split

## Model Initialization

In [None]:
# Initialize model
# The top-level keys of the loaded model YAML are passed to ModelConfig as
# keyword arguments.
model = VishwamAIModel(ModelConfig(**model_config))
print("Model initialized with config:", model_config)

## Training Setup

In [None]:
# Create dataloaders
# The "main" configuration of openai/gsm8k publishes only "train" and "test"
# splits; requesting "validation" fails, so the test split is used as the
# held-out evaluation set.
train_dataset = create_gsm8k_dataloader(training_config, split="train")
val_dataset = create_gsm8k_dataloader(training_config, split="test")

# Create checkpoint directory (under the current working directory)
checkpoint_dir = os.path.join(os.getcwd(), 'checkpoints', 'gsm8k')
os.makedirs(checkpoint_dir, exist_ok=True)

def save_checkpoint_hook(state, path):
    """Save ``state.params`` to ``<path>.safetensors``.

    safetensors stores a flat mapping from string names to arrays, so a
    nested parameter tree is flattened first: each leaf is written under its
    dot-joined key path (``{'layer': {'kernel': w}}`` -> ``'layer.kernel'``).
    A parameter mapping that is already flat keeps its keys unchanged.

    Args:
        state: Object whose ``params`` attribute is a (possibly nested)
            mapping of parameter arrays.
        path: Destination path without the ``.safetensors`` extension.

    Raises:
        TypeError: If ``state.params`` is not a mapping.
    """
    def _flatten(tree, prefix=""):
        # Recursively collect leaves under dot-joined key paths.
        flat = {}
        for key, value in tree.items():
            name = f"{prefix}.{key}" if prefix else str(key)
            if hasattr(value, "items"):
                flat.update(_flatten(value, name))
            else:
                # Bring the leaf to host memory as a NumPy array.
                flat[name] = np.asarray(value)
        return flat

    params = state.params
    if not hasattr(params, "items"):
        raise TypeError(
            f"state.params must be a mapping of parameters, got {type(params).__name__}"
        )

    save_file(_flatten(params), f"{path}.safetensors")
    print(f"Saved checkpoint to {path}.safetensors")

## Start Training

In [None]:
# Run training with TPU mesh
# train() is called inside the `mesh` context created in the TPU setup cell.
with mesh:
    final_state = train(
        model,
        training_config,
        train_dataset,
        val_dataset=val_dataset,
        # Step budget and logging/eval cadence all come from the training config.
        num_steps=training_config.max_steps,
        log_every=training_config.logging_steps,
        eval_every=training_config.eval_steps,
        checkpoint_dir=checkpoint_dir,
        # NOTE(review): when and with which arguments train() invokes this hook
        # is defined in vishwamai.training -- confirm it is called as
        # save_checkpoint_fn(state, path).
        save_checkpoint_fn=save_checkpoint_hook
    )

# Save final model
final_path = os.path.join(checkpoint_dir, "gsm8k_final.safetensors")
# Reuse save_checkpoint_hook so the final checkpoint is serialized exactly like
# the periodic ones instead of duplicating that logic here. The hook appends
# ".safetensors" itself, so it is given the path without the extension.
save_checkpoint_hook(final_state, os.path.splitext(final_path)[0])
print(f"\nTraining completed! Final model saved to {final_path}")
# NOTE(review): assumes the state returned by train() exposes `best_metrics`.
print(f"Best metrics: {final_state.best_metrics}")