# GSM8K Math Problem Training

This notebook implements training on the GSM8K dataset using TPU acceleration and optimized configurations for mathematical reasoning.

## Setup and Imports

In [None]:
import os
import jax
from jax.experimental import mesh_utils
from jax.experimental.maps import Mesh
import numpy as np
from datasets import load_dataset
from vishwamai.training import train, create_train_state
from vishwamai.model import VishwamAIModel, ModelConfig
from vishwamai.tokenizer import VishwamAITokenizer
from omegaconf import OmegaConf
import logging
from safetensors.flax import save_file
import random
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict

# Configure logging and plotting
logging.basicConfig(level=logging.INFO)
# Matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8' and newer
# releases no longer accept the old name, so use whichever one this install
# provides and fall back to the default style if neither is available.
_plot_style = next(
    (name for name in ('seaborn-v0_8', 'seaborn') if name in plt.style.available),
    'default',
)
plt.style.use(_plot_style)

## TPU Setup

Configure TPU environment and create device mesh for training.

In [None]:
# TPU environment setup: export the allocator, TF logging and JAX settings together.
# NOTE(review): jax is already imported by the setup cell above; presumably the JAX_*
# variables must be set before that import to have any effect — confirm.
os.environ.update({
    'TCMALLOC_LARGE_ALLOC_REPORT_THRESHOLD': '10000000000',
    'TF_CPP_MIN_LOG_LEVEL': '0',
    'JAX_PLATFORMS': 'tpu',
    'JAX_ENABLE_X64': 'False',
})

def setup_tpu_cluster():
    """Build a one-dimensional data-parallel mesh over all visible JAX devices.

    Prints the devices reported by ``jax.devices()`` and arranges them along
    a single mesh axis named ``'dp'``.

    Returns:
        The ``Mesh`` wrapping the device array.
    """
    devices = jax.devices()
    print(f"Available devices: {devices}")

    # Size the mesh from the devices actually present rather than assuming an
    # 8-core TPU; a fixed (8,) shape is rejected when the device count differs.
    mesh_shape = (len(devices),)
    device_mesh = mesh_utils.create_device_mesh(mesh_shape, devices=devices)
    return Mesh(device_mesh, ('dp',))


mesh = setup_tpu_cluster()

## Load Configurations

Load model and training configurations optimized for GSM8K.

In [None]:
# Read the model and training YAML files through OmegaConf
model_config_path = '../vishwamai/configs/model/10B.yaml'
training_config_path = '../vishwamai/configs/training/gsm8k.yaml'
model_config = OmegaConf.load(model_config_path)
training_config = OmegaConf.load(training_config_path)

# Echo both configurations so the run records what it was started with
print("Model config:", model_config)
print("\nTraining config:", training_config)

## Data Processing

Implement GSM8K dataset processing with step-by-step solution formatting.

In [None]:
class GSM8KProcessor:
    """Format and tokenize GSM8K examples for training.

    The constructor stores the tokenizer and configuration it is given and
    reads the maximum sequence length from ``config.dataset.max_length``.
    """

    def __init__(self, tokenizer, config):
        self.tokenizer = tokenizer
        self.config = config
        self.max_length = config.dataset.max_length

    def format_example(self, example):
        """Render one example as a single training string.

        Args:
            example: Mapping with ``'question'`` and ``'answer'`` strings.

        Returns:
            The question, the full worked answer, and a trailing
            ``Final Answer:`` line holding the text after the last ``####``
            marker in the answer.
        """
        question = example['question']
        answer = example['answer']
        # Extract final answer
        final_answer = answer.split('####')[-1].strip()
        # Format as instruction and response
        formatted_text = f"Question: {question}\nLet's solve this step by step:\n{answer}\nFinal Answer: {final_answer}"
        return formatted_text

    def tokenize_function(self, examples):
        """Tokenize a batch of examples and attach labels.

        Args:
            examples: Either a column-oriented batch (a mapping with
                ``'question'`` and ``'answer'`` lists, which is what
                ``prepare_dataset`` passes because it maps with
                ``batched=True``) or an iterable of per-example mappings.

        Returns:
            The tokenizer output with an added ``"labels"`` entry that is a
            copy of ``"input_ids"``.
        """
        if hasattr(examples, "keys"):
            # Column-oriented batch: iterating it directly would yield the
            # column names, so rebuild one mapping per row.
            rows = [
                {"question": question, "answer": answer}
                for question, answer in zip(examples["question"], examples["answer"])
            ]
        else:
            rows = list(examples)

        formatted_texts = [self.format_example(row) for row in rows]

        tokenized = self.tokenizer(
            formatted_texts,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_attention_mask=True,
        )

        # NOTE(review): labels also cover padding positions; presumably the
        # training loss masks them using the attention mask — confirm.
        tokenized["labels"] = tokenized["input_ids"].copy()
        return tokenized

    def prepare_dataset(self, dataset):
        """Tokenize ``dataset`` with a batched ``map`` call.

        The original columns are removed and ``config.dataset.num_workers``
        is passed as the number of worker processes.
        """
        tokenized_dataset = dataset.map(
            self.tokenize_function,
            batched=True,
            num_proc=self.config.dataset.num_workers,
            remove_columns=dataset.column_names,
        )
        return tokenized_dataset

def create_gsm8k_dataloader(config, split="train"):
    """Load one GSM8K split and return it tokenized.

    Args:
        config: Configuration providing ``model.vocab_size`` and
            ``model.name`` for the tokenizer; it is also handed to
            ``GSM8KProcessor``.
        split: Split name forwarded to ``load_dataset``.

    Returns:
        The dataset produced by ``GSM8KProcessor.prepare_dataset``.
    """
    raw_split = load_dataset("openai/gsm8k", "main", split=split)

    processor = GSM8KProcessor(
        VishwamAITokenizer(
            vocab_size=config.model.vocab_size,
            model_prefix=config.model.name,
        ),
        config,
    )
    tokenized_split = processor.prepare_dataset(raw_split)

    print(f"Processed {len(tokenized_split)} examples for {split} split")
    return tokenized_split

## Model Initialization

In [None]:
# Build the model from the loaded configuration, expanded into ModelConfig keywords
model_settings = ModelConfig(**model_config)
model = VishwamAIModel(model_settings)
print("Model initialized with config:", model_config)

## Training Setup

In [None]:
# Create dataloaders
# The GSM8K "main" configuration publishes only "train" and "test" splits, so
# requesting "validation" fails. Hold out the tail of the training split for
# validation instead, which leaves the test split unseen during training.
train_dataset = create_gsm8k_dataloader(training_config, split="train[:95%]")
val_dataset = create_gsm8k_dataloader(training_config, split="train[95%:]")

# Create checkpoint directory
checkpoint_dir = os.path.join(os.getcwd(), 'checkpoints', 'gsm8k')
os.makedirs(checkpoint_dir, exist_ok=True)

def save_checkpoint_hook(state, path):
    """Save ``state.params`` to ``<path>.safetensors``.

    Args:
        state: Object exposing the model parameters as ``state.params``.
        path: Destination path without the ``.safetensors`` suffix.
    """
    # Use jax.tree_util.tree_map: the top-level jax.tree_map alias is
    # deprecated, while this spelling works on old and new JAX releases.
    numpy_params = jax.tree_util.tree_map(np.array, state.params)
    # NOTE(review): save_file appears to expect a flat {name: array} mapping;
    # confirm that state.params is not a nested tree.
    checkpoint_file = f"{path}.safetensors"
    save_file(numpy_params, checkpoint_file)
    print(f"Saved checkpoint to {checkpoint_file}")

## Start Training

In [None]:
# Run training with TPU mesh
with mesh:
    final_state = train(
        model,
        training_config,
        train_dataset,
        val_dataset=val_dataset,
        num_steps=training_config.max_steps,
        log_every=training_config.logging_steps,
        eval_every=training_config.eval_steps,
        checkpoint_dir=checkpoint_dir,
        save_checkpoint_fn=save_checkpoint_hook
    )

# Save final model
final_path = os.path.join(checkpoint_dir, "gsm8k_final.safetensors")
# Use jax.tree_util.tree_map: the top-level jax.tree_map alias is deprecated,
# while this spelling works on old and new JAX releases.
numpy_params = jax.tree_util.tree_map(lambda x: np.array(x), final_state.params)
save_file(numpy_params, final_path)
print(f"\nTraining completed! Final model saved to {final_path}")
print(f"Best metrics: {final_state.best_metrics}")

## Model Testing and Evaluation

In [None]:
from safetensors.flax import load_file
import jax.numpy as jnp

def load_model(model_path):
    """Load parameters from a safetensors checkpoint as JAX arrays.

    Args:
        model_path: Path to the ``.safetensors`` file to read.

    Returns:
        The loaded parameters with every leaf converted by ``jnp.array``.
    """
    params = load_file(model_path)
    # Use jax.tree_util.tree_map: the top-level jax.tree_map alias is
    # deprecated, while this spelling works on old and new JAX releases.
    params = jax.tree_util.tree_map(lambda x: jnp.array(x), params)
    return params

def evaluate_gsm8k_sample(model, tokenizer, question):
    """Generate the model's answer to one GSM8K question.

    Args:
        model: Object exposing ``generate(input_ids=..., attention_mask=..., ...)``.
        tokenizer: Callable that encodes the prompt and exposes ``decode``.
        question: The problem statement to solve.

    Returns:
        The decoded text of the first generated sequence.
    """
    prompt = f"Question: {question}\nLet's solve this step by step:"

    encoded = tokenizer(
        prompt,
        return_tensors="jax",
        padding=True,
        truncation=True,
        max_length=512,
    )

    # Generation settings are fixed here and passed through unchanged.
    generation_kwargs = {
        "max_length": 1024,
        "temperature": 0.7,
        "num_beams": 4,
        "early_stopping": True,
    }
    generated = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        **generation_kwargs,
    )

    return tokenizer.decode(generated[0], skip_special_tokens=True)

### Load and Test Trained Model

In [None]:
# Restore the weights written at the end of training and attach them to the model
model_path = os.path.join(checkpoint_dir, "gsm8k_final.safetensors")
params = load_model(model_path)
model.params = params

# Tokenizer built from the model configuration
tokenizer = VishwamAITokenizer(vocab_size=model_config.vocab_size, model_prefix=model_config.name)

# Word problems used for a quick qualitative check of the trained model
test_questions = [
    "Janet's ducks lay 16 eggs per day. She eats 3 for breakfast every morning and sells the rest to her neighbors for $2 per egg. How much money does she make per week?",
    "A shop sells each ice cream cone for $2.50. On Monday they sold 45 cones, on Tuesday 52 cones, and on Wednesday 38 cones. What was their total revenue for these three days?",
    "John has 5 times as many marbles as Peter. If Peter has 8 marbles, how many do they have together?",
]

print("Testing model on sample questions:\n")
for i, question in enumerate(test_questions, start=1):
    print(f"Question {i}:\n{question}")
    response = evaluate_gsm8k_sample(model, tokenizer, question)
    print(f"\nModel Response:\n{response}\n\n")

### Batch Evaluation on Test Set

In [None]:
import re

# Optionally signed number with optional thousands separators and decimals.
_NUMBER_PATTERN = re.compile(r"-?(?:\d[\d,]*(?:\.\d+)?|\.\d+)")


def _extract_number(text, marker):
    """Return the first number after the last ``marker`` in ``text``, or None.

    Thousands separators are dropped, so ``"1,080"`` parses as ``1080.0``.
    When ``marker`` does not occur, the whole text is searched.
    """
    tail = text.rsplit(marker, 1)[-1]
    match = _NUMBER_PATTERN.search(tail)
    if match is None:
        return None
    return float(match.group().replace(",", ""))


def evaluate_test_set(model, tokenizer, num_samples=100):
    """Evaluate the model on a subset of the GSM8K test split.

    Args:
        model: Model passed through to ``evaluate_gsm8k_sample``.
        tokenizer: Tokenizer passed through to ``evaluate_gsm8k_sample``.
        num_samples: Number of leading test examples to evaluate; a falsy
            value evaluates the whole split. Values larger than the split
            are clamped to its size.

    Returns:
        Dict with ``'accuracy'`` (percentage), ``'correct'`` and ``'total'``.
        Samples whose answer cannot be parsed count as incorrect.
    """
    # Load test dataset
    test_dataset = load_dataset("openai/gsm8k", "main", split="test")
    if num_samples:
        # Clamp so that asking for more samples than exist does not fail.
        subset_size = min(num_samples, len(test_dataset))
        test_dataset = test_dataset.select(range(subset_size))

    correct = 0
    total = len(test_dataset)

    print(f"Evaluating on {total} test samples...\n")

    for i, example in enumerate(test_dataset):
        # Get model's response
        response = evaluate_gsm8k_sample(model, tokenizer, example['question'])

        # Parse leniently: reference answers may contain thousands separators
        # and generated text may carry units or trailing words after the
        # number, both of which a bare float() rejects.
        pred_answer = _extract_number(response, 'Final Answer:')
        true_answer = _extract_number(example['answer'], '####')

        if pred_answer is None or true_answer is None:
            print(f"Warning: Could not extract numeric answer for sample {i}")
        elif abs(pred_answer - true_answer) < 1e-6:
            # Allow for small floating point differences
            correct += 1

        # Report progress even when the current sample could not be parsed.
        if (i + 1) % 10 == 0:
            print(f"Progress: {i+1}/{total} samples evaluated")
            print(f"Current accuracy: {(correct/(i+1))*100:.2f}%\n")

    # Guard against an empty evaluation set.
    final_accuracy = (correct / total) * 100 if total else 0.0
    print(f"\nEvaluation complete!")
    print(f"Final accuracy: {final_accuracy:.2f}%")
    print(f"Correct: {correct}/{total}")

    return {
        'accuracy': final_accuracy,
        'correct': correct,
        'total': total
    }

In [None]:
# Evaluate on the first 100 test examples and summarize the outcome
print("Starting batch evaluation on test set...")
eval_results = evaluate_test_set(model, tokenizer, num_samples=100)

accuracy_pct = eval_results['accuracy']
num_correct = eval_results['correct']
num_total = eval_results['total']
print("\nTest set evaluation results:")
print(f"Accuracy: {accuracy_pct:.2f}%")
print(f"Correct: {num_correct}/{num_total}")

## Performance Visualization

In [None]:
def visualize_results(eval_results, final_state, save_path=None):
    """Plot the training loss curve and the test-set outcome side by side.

    Args:
        eval_results: Dict with ``'correct'``, ``'total'`` and ``'accuracy'``.
        final_state: Training state; the loss curve is drawn only when it has
            a ``metrics_history`` attribute, read as
            ``metrics_history['loss']``.
        save_path: Optional file path. When given, the figure is written
            there before it is shown.

    Returns:
        The Matplotlib figure that was drawn.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Plot 1: Training Loss Curve
    if hasattr(final_state, 'metrics_history'):
        steps = range(len(final_state.metrics_history['loss']))
        ax1.plot(steps, final_state.metrics_history['loss'], label='Training Loss')
        ax1.set_title('Training Loss Over Time')
        ax1.set_xlabel('Steps')
        ax1.set_ylabel('Loss')
        ax1.legend()

    # Plot 2: Test Set Performance
    ax2.bar(['Correct', 'Incorrect'],
            [eval_results['correct'], eval_results['total'] - eval_results['correct']],
            color=['green', 'red'])
    ax2.set_title('Test Set Performance')
    ax2.text(0, eval_results['correct'], f"{eval_results['accuracy']:.1f}%",
             ha='center', va='bottom')

    plt.tight_layout()
    # Save through the figure handle before showing: once plt.show() returns,
    # the figure may already be released, and a later plt.savefig() would
    # then write an empty canvas.
    if save_path is not None:
        fig.savefig(save_path)
    plt.show()
    return fig


# Create performance visualizations and save them alongside the checkpoints
visualization_path = os.path.join(checkpoint_dir, 'performance_visualization.png')
visualize_results(eval_results, final_state, save_path=visualization_path)
print(f"Performance visualization saved to {visualization_path}")