# Efficient Pre-training on Limited TPU Resources

This notebook demonstrates how to use VishwamAI's efficient pre-training capabilities with curriculum learning and TPU optimizations.

In [None]:
import os

import jax
import jax.numpy as jnp
import numpy as np
from omegaconf import OmegaConf

from vishwamai.model import VishwamAIModel
from vishwamai.pretrain_efficient import setup_tpu_devices, create_model_config
from vishwamai.tokenizer import VishwamAITokenizer

## 1. TPU Setup and Configuration

First, let's set up our TPU environment and load our optimized configuration.

In [None]:
# Initialise the TPU runtime first; `devices` is reused later as the training mesh argument.
devices = setup_tpu_devices()
print(f"Available devices: {jax.device_count()}")

# Efficient pre-training recipe. The path is relative to the kernel's working directory.
CONFIG_PATH = "../vishwamai/configs/training/efficient_pretrain.yaml"
config = OmegaConf.load(CONFIG_PATH)

# Show only the `training` section of the loaded config.
print("\nTraining Configuration:")
training_section = config.training
print(OmegaConf.to_yaml(training_section))

## 2. Model and Tokenizer Initialization

Now we'll create our model and tokenizer with the optimized settings.

In [None]:
# Build the model from the TPU-optimized config produced by create_model_config.
model_config = create_model_config(config)
model = VishwamAIModel(model_config)

# The tokenizer's vocabulary size is taken from config.model.vocab_size.
tokenizer = VishwamAITokenizer(vocab_size=config.model.vocab_size)

# Count parameters by summing the sizes of every leaf in the parameter pytree.
# `jax.tree_leaves` is a deprecated alias that newer JAX releases no longer provide;
# `jax.tree_util.tree_leaves` is the supported API.
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(model.params))
print(f"Model parameters: {num_params:,}")

## 3. Curriculum Learning Demo

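Let's examine how curriculum learning progresses during training. The data processor's curriculum scheduler tracks the current maximum sequence length; here we advance it manually and inspect how that length evolves.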

In [None]:
from vishwamai.training import DataProcessor

# Build the data processor; its curriculum state is exposed as `curriculum_scheduler`.
data_processor = DataProcessor(tokenizer, config)

print("Curriculum Learning Progress:")
print(f"Initial sequence length: {data_processor.curriculum_scheduler['current_max_length']}")

# Step the scheduler through a few whole curriculum periods and report the current
# max sequence length at the end of each period.
# - The interval is clamped to >= 1, so a value of 1 (or 0) can't cause a modulo-by-zero.
# - Simulating whole periods guarantees output even when update_every is large.
# NOTE(review): assumes each update_curriculum() call advances the schedule by one step
# and is cheap enough to call NUM_CURRICULUM_PERIODS * update_every times — confirm.
NUM_CURRICULUM_PERIODS = 3
update_every = max(1, int(config.training.curriculum.update_every))
for step in range(NUM_CURRICULUM_PERIODS * update_every):
    data_processor.update_curriculum()
    if (step + 1) % update_every == 0:
        print(f"Step {step+1}: sequence length = {data_processor.curriculum_scheduler['current_max_length']}")

## 4. Mixed Precision Training

Compare the memory footprint of the same activation-sized tensor stored in FP32 and in BF16 to illustrate the savings from mixed-precision training.

In [None]:
# Create sample batch
batch_size = config.training.batch_size
seq_length = 64  # Start with curriculum's initial length

# Compare memory usage
sample_fp32 = np.random.randn(batch_size, seq_length, config.model.hidden_size).astype(np.float32)
sample_bf16 = sample_fp32.astype(np.float16)

print("Memory Usage Comparison:")
print(f"FP32: {sample_fp32.nbytes / 1024 / 1024:.2f} MB")
print(f"BF16: {sample_bf16.nbytes / 1024 / 1024:.2f} MB")
print(f"Memory Savings: {(1 - sample_bf16.nbytes/sample_fp32.nbytes) * 100:.1f}%")

## 5. Training Example

Run a short (100-step) training job to see all of these optimizations working together.

In [None]:
from vishwamai.training import create_train_dataloader, create_val_dataloader, train

# Build a separate demo config from `config` and shorten the run:
# 100 steps, log_every_n_steps=10, save_every_n_steps=50.
demo_config = OmegaConf.create(config)
demo_config.training.max_steps = 100
demo_config.training.log_every_n_steps = 10
demo_config.monitoring.save_every_n_steps = 50

# Data loaders for both splits, built from the demo config.
train_loader = create_train_dataloader(demo_config, tokenizer)
val_loader = create_val_dataloader(demo_config, tokenizer)

# Checkpoints are written to a local directory, created up front if missing.
checkpoint_dir = "demo_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

demo_training_cfg = demo_config.training
train_kwargs = dict(
    model=model,
    config=demo_config,
    tokenizer=tokenizer,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    num_steps=demo_training_cfg.max_steps,
    log_every=demo_training_cfg.log_every_n_steps,
    # NOTE(review): evaluation reuses the checkpoint-save interval — confirm this is intended.
    eval_every=demo_config.monitoring.save_every_n_steps,
    checkpoint_dir=checkpoint_dir,
    accum_steps=demo_training_cfg.gradient_accumulation_steps,
    # NOTE(review): setup_tpu_devices()'s return value is passed as the mesh — confirm its type.
    mesh=devices,
)
final_state = train(**train_kwargs)

print("\nTraining Demo Complete!")
print(f"Final Metrics: {final_state.best_metrics}")

## 6. Memory and Performance Analysis

Summarize the demo run: steps completed, the best recorded metrics, and Tree of Thoughts statistics when that feature is enabled.

In [None]:
def format_metrics(metrics):
    """Render the headline training metrics as fixed-precision strings.

    Args:
        metrics: Mapping that must contain ``'loss'`` and ``'accuracy'``.
            ``'learning_rate'`` is optional and is shown as 0.0 when absent.

    Returns:
        dict: Keys ``'loss'``, ``'accuracy'`` and ``'learning_rate'``.
        Loss and accuracy are formatted to 4 decimal places and the
        learning rate to 6.

    Raises:
        KeyError: if ``'loss'`` or ``'accuracy'`` is missing.
    """
    formatted = {}
    formatted['loss'] = format(metrics['loss'], '.4f')
    formatted['accuracy'] = format(metrics['accuracy'], '.4f')
    formatted['learning_rate'] = format(metrics.get('learning_rate', 0.0), '.6f')
    return formatted

# Headline results of the demo run.
print("Training Efficiency Summary:")
print(f"Steps Completed: {final_state.step}")
print(f"Best Metrics: {format_metrics(final_state.best_metrics)}")

# Tree of Thoughts statistics are printed only when the final state carries a
# `tot_state` whose 'enabled' entry is truthy.
if hasattr(final_state, 'tot_state'):
    tot_state = final_state.tot_state
    if tot_state['enabled']:
        print("\nTree of Thoughts Performance:")
        print(f"Thoughts per batch: {tot_state['thoughts_per_batch']}")
        print(f"Best thought score: {tot_state['best_thought_score']:.4f}")