# VishwamAI TPU Development - Experiment 2

This notebook trains a TPU-optimized VishwamAI transformer on synthetic data, then profiles the run and analyzes host/device memory usage and training throughput.

## Setup and Dependencies

In [None]:
import os
import jax
import jax.numpy as jnp
import numpy as np
import flax
import flax.linen as nn
import optax
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

from vishwamai import (
    create_vishwamai_transformer,
    create_train_state,
    EnhancedTransformerModel,
    VishwamAITrainer,
    DuckDBLogger,
    DEFAULT_CONFIG
)

# Confirm which accelerators JAX has detected before building anything.
device_list = jax.devices()
print("JAX devices:", device_list)
print("Number of devices:", jax.device_count())

## Model Configuration

In [None]:
# Load and update configuration for TPU.
# copy.deepcopy rather than .copy(): the nested 'model_config' and 'training'
# dicts are updated in place below, and a shallow copy (assuming
# DEFAULT_CONFIG is a plain dict) would share -- and therefore mutate -- the
# dicts inside the imported DEFAULT_CONFIG.
import copy

config = copy.deepcopy(DEFAULT_CONFIG)

# Update model configuration
config['model_config'].update({
    'vocab_size': 32000,
    'num_layers': 12,
    'num_heads': 12,
    'head_dim': 64,       # num_heads * head_dim == hidden_dim (12 * 64 = 768)
    'hidden_dim': 768,
    'mlp_dim': 3072,
    'max_seq_len': 2048,
    'use_flash_attn': True,
    'use_rotary': True,
    'use_rms_norm': True,
    # NOTE(review): for mixed precision one usually keeps parameters in
    # float32 and computes in bfloat16; this pair may be reversed --
    # confirm how vishwamai interprets 'dtype' vs 'compute_dtype'.
    'dtype': 'bfloat16',
    'compute_dtype': 'float32'
})

# Update training configuration
config['training'].update({
    'batch_size': 32 * jax.device_count(),  # Scale batch size by number of devices
    'learning_rate': 1e-4,
    'warmup_steps': 2000,
    'decay_steps': 50000,
    'weight_decay': 0.01,
    'gradient_checkpointing': True,
    'gradient_accumulation_steps': 4,
    'mixed_precision': True,
    'tpu_iterations_per_loop': 100
})

print("Model configuration:", config['model_config'])
print("\nTraining configuration:", config['training'])

## Data Generation

In [None]:
def create_dummy_data(num_samples, batch_size, seq_length, vocab_size, seed=42):
    """Create a random token dataset sized for sharded TPU training.

    The sample count is rounded *down* to a multiple of both ``batch_size``
    and ``jax.device_count()``, so every batch is full and splits evenly
    across devices. It never goes below one such unit: if ``num_samples`` is
    smaller, a single full batch is generated instead. Without that floor, a
    small validation split would yield zero batches and evaluation would have
    nothing to run.

    Args:
        num_samples: Requested number of sequences.
        batch_size: Global batch size the data will be consumed with.
        seq_length: Tokens per sequence.
        vocab_size: Token ids are drawn uniformly from ``[0, vocab_size)``.
        seed: Seed for the NumPy random generator (default 42).

    Returns:
        Dict of JAX arrays, each of shape ``(total_samples, seq_length)``:
        ``'input_ids'`` and ``'labels'`` (int32) and ``'attention_mask'``
        (float32, all ones).

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Smallest sample count that is a whole number of global batches and
    # also divides evenly across devices.
    unit = int(np.lcm(batch_size, jax.device_count()))
    total_samples = max(num_samples // unit, 1) * unit

    rng = np.random.default_rng(seed)

    input_shape = (total_samples, seq_length)
    # Explicit int32/float32: JAX downcasts 64-bit inputs anyway unless x64
    # mode is enabled, so request the final dtypes up front.
    input_ids = rng.integers(0, vocab_size, size=input_shape, dtype=np.int32)
    labels = rng.integers(0, vocab_size, size=input_shape, dtype=np.int32)
    attention_mask = np.ones(input_shape, dtype=np.float32)

    # Convert to device arrays
    return {
        'input_ids': jnp.array(input_ids),
        'labels': jnp.array(labels),
        'attention_mask': jnp.array(attention_mask)
    }

# Both splits share every generation setting except their sample count.
shared_data_args = dict(
    batch_size=config['training']['batch_size'],
    seq_length=config['model_config']['max_seq_len'],
    vocab_size=config['model_config']['vocab_size'],
)

# Create training and validation datasets
train_data = create_dummy_data(num_samples=1000, **shared_data_args)
val_data = create_dummy_data(num_samples=100, **shared_data_args)

print("Training data shape:", train_data['input_ids'].shape)
print("Validation data shape:", val_data['input_ids'].shape)

## Model Setup

In [None]:
# Initialize model and training state
rng = jax.random.PRNGKey(42)
# NOTE(review): init_rng is not passed to anything in this cell; the trainer
# appears to handle its own initialisation -- confirm, then use or drop it.
rng, init_rng = jax.random.split(rng)

# Create model
model = create_vishwamai_transformer(config['model_config'])

# Create trainer
trainer = VishwamAITrainer(
    config=config,
    model=model,
    experiment_name='experiment2_tpu',
    db_path='experiment2.db'
)

# Initialize training state
trainer.setup_training()

# jax.tree_leaves was a deprecated top-level alias that newer JAX releases
# removed; jax.tree_util.tree_leaves is the supported spelling.
# NOTE(review): assumes the object returned by create_vishwamai_transformer
# exposes a `.params` pytree (plain flax.linen Modules do not) -- confirm.
print(f"Model initialized with {sum(p.size for p in jax.tree_util.tree_leaves(model.params)):,} parameters")

## Training Loop

In [None]:
def train_epoch(trainer, train_data, config, *, base_rng=None, log_every=10):
    """Run one epoch of training over every full batch of ``train_data``.

    Args:
        trainer: Object exposing ``train_step(batch, dropout_rng=...)``. The
            returned metrics dict must contain ``'loss'``, plus
            ``'learning_rate'`` on logging steps.
        train_data: Dict with ``'input_ids'``, ``'labels'`` and
            ``'attention_mask'`` arrays sharing the same leading dimension.
        config: Configuration dict; ``config['training']['batch_size']`` is
            used.
        base_rng: Optional base PRNG key. When given, the dropout key for
            step ``i`` is ``jax.random.fold_in(base_rng, i)``, so passing a
            different key each epoch yields fresh dropout masks. When omitted,
            the legacy ``jax.random.PRNGKey(step)`` is used, which repeats the
            same masks every epoch.
        log_every: Print the running loss every ``log_every`` steps
            (``0`` disables logging).

    Returns:
        List of per-step metrics dicts, in step order. A trailing partial
        batch is skipped.
    """
    batch_size = config['training']['batch_size']
    steps_per_epoch = len(train_data['input_ids']) // batch_size
    batch_keys = ('input_ids', 'labels', 'attention_mask')

    metrics_list = []

    for step in tqdm(range(steps_per_epoch)):
        # Get batch
        start_idx = step * batch_size
        end_idx = start_idx + batch_size
        batch = {key: train_data[key][start_idx:end_idx] for key in batch_keys}

        if base_rng is None:
            dropout_rng = jax.random.PRNGKey(step)
        else:
            dropout_rng = jax.random.fold_in(base_rng, step)

        # Training step
        metrics = trainer.train_step(batch, dropout_rng=dropout_rng)
        metrics_list.append(metrics)

        # Log every N steps; tqdm.write keeps the progress bar intact,
        # whereas print() would interleave with and break it.
        if log_every and step % log_every == 0:
            avg_loss = np.mean([m['loss'] for m in metrics_list[-log_every:]])
            tqdm.write(f"\nStep {step}/{steps_per_epoch}")
            tqdm.write(f"Loss: {avg_loss:.4f}")
            tqdm.write(f"Learning rate: {metrics['learning_rate']:.6f}")

    return metrics_list

def evaluate(trainer, val_data, config):
    """Evaluate the model on every full batch of ``val_data``.

    Args:
        trainer: Object exposing ``evaluate(batch)`` that returns a dict of
            scalar metrics.
        val_data: Dict with ``'input_ids'``, ``'labels'`` and
            ``'attention_mask'`` arrays sharing the same leading dimension.
        config: Configuration dict; ``config['training']['batch_size']`` is
            used.

    Returns:
        Dict mapping each metric name to its unweighted mean over batches.
        A trailing partial batch is skipped.

    Raises:
        ValueError: If ``val_data`` holds fewer samples than one batch, so no
            evaluation step can run.
    """
    batch_size = config['training']['batch_size']
    num_samples = len(val_data['input_ids'])
    steps_per_eval = num_samples // batch_size

    # Without this check an empty metrics list surfaces as an opaque
    # IndexError when the averaged keys are read below.
    if steps_per_eval == 0:
        raise ValueError(
            f"Validation data has {num_samples} samples, fewer than one batch "
            f"of {batch_size}; no evaluation step can run."
        )

    batch_keys = ('input_ids', 'labels', 'attention_mask')
    metrics_list = []

    for step in range(steps_per_eval):
        start_idx = step * batch_size
        end_idx = start_idx + batch_size
        batch = {key: val_data[key][start_idx:end_idx] for key in batch_keys}
        metrics_list.append(trainer.evaluate(batch))

    # Compute average metrics (every batch is full, so each has equal weight)
    return {key: np.mean([m[key] for m in metrics_list]) for key in metrics_list[0]}

# Training loop: for each epoch, train, evaluate, then checkpoint.
num_epochs = 5
train_metrics, val_metrics = [], []

print("Starting training...")
for epoch_num in range(1, num_epochs + 1):
    print(f"\nEpoch {epoch_num}/{num_epochs}")

    # One pass over the training set; keep every per-step metrics dict.
    train_metrics += train_epoch(trainer, train_data, config)

    # One pass over the validation set; keep the averaged metrics.
    eval_metrics = evaluate(trainer, val_data, config)
    val_metrics.append(eval_metrics)

    print(f"\nEpoch {epoch_num} evaluation:")
    print(f"Validation loss: {eval_metrics['loss']:.4f}")

    trainer.save_checkpoint(f"experiment2_checkpoint_epoch_{epoch_num}")

print("Training completed!")

## Analysis and Visualization

In [None]:
# Plot training metrics: per-step training loss and per-epoch validation loss.
plt.figure(figsize=(12, 4))

# Training loss (one point per training step)
plt.subplot(1, 2, 1)
train_losses = [m['loss'] for m in train_metrics]
plt.plot(train_losses)
plt.title('Training Loss')
plt.xlabel('Step')
plt.ylabel('Loss')

# Validation loss (one point per epoch). Epochs are numbered from 1 to match
# the "Epoch N" printouts of the training loop; markers keep a single-epoch
# run visible (a one-point line draws nothing).
plt.subplot(1, 2, 2)
val_losses = [m['loss'] for m in val_metrics]
plt.plot(range(1, len(val_losses) + 1), val_losses, 'ro-')
plt.title('Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')

plt.tight_layout()
plt.show()

# Print final metrics (guarded: an epoch with no full batch records nothing)
if train_losses:
    print("Final training loss:", train_losses[-1])
else:
    print("No training metrics recorded.")
if val_losses:
    print("Final validation loss:", val_losses[-1])
else:
    print("No validation metrics recorded.")

## TPU Profiling and Monitoring

In [None]:
from jax.profiler import start_trace, stop_trace, trace
import psutil
import time

class TPUMonitor:
    """Monitor host memory and per-device accelerator memory over time.

    Each call to ``capture_metrics`` appends one snapshot to ``self.metrics``;
    ``plot_metrics`` draws the collected snapshots against elapsed time.
    """

    def __init__(self):
        # Snapshot dicts with keys 'timestamp', 'host_memory_mb', 'device_metrics'.
        self.metrics = []
        # Reference point for each snapshot's 'timestamp' (seconds elapsed).
        self.start_time = time.time()

    @staticmethod
    def _peak_gb(stats):
        """Return ``peak_bytes_in_use`` from a device memory-stats dict, in GiB.

        ``Device.memory_stats()`` may return ``None`` when the backend does not
        report memory statistics; that is treated as zero.
        """
        return (stats or {}).get('peak_bytes_in_use', 0) / (1024**3)

    def capture_metrics(self):
        """Record one snapshot of host and per-device memory usage.

        Returns:
            The snapshot dict that was appended, or ``None`` if capturing
            failed. The error is printed; monitoring is best-effort and must
            not interrupt training.
        """
        try:
            # Get host memory usage
            memory = psutil.Process().memory_info()
            # Get TPU metrics if available (entries may be None, see _peak_gb)
            device_memory = [device.memory_stats() for device in jax.devices()]
        except Exception as e:
            print(f"Error capturing metrics: {e}")
            return None

        metrics = {
            'timestamp': time.time() - self.start_time,
            'host_memory_mb': memory.rss / (1024 * 1024),
            'device_metrics': device_memory
        }
        self.metrics.append(metrics)
        return metrics

    def plot_metrics(self):
        """Plot host memory and per-device peak memory against elapsed time.

        Prints a message and returns without plotting when no snapshots have
        been captured. The device panel is drawn only if at least one device
        reported memory statistics.
        """
        if not self.metrics:
            print("No metrics captured; nothing to plot.")
            return

        timestamps = [m['timestamp'] for m in self.metrics]
        host_memory = [m['host_memory_mb'] for m in self.metrics]

        plt.figure(figsize=(12, 4))

        plt.subplot(1, 2, 1)
        plt.plot(timestamps, host_memory)
        plt.title('Host Memory Usage')
        plt.xlabel('Time (s)')
        plt.ylabel('Memory (MB)')

        # A list like [None, None] is truthy, so check the entries themselves.
        has_device_stats = any(
            stats for m in self.metrics for stats in m['device_metrics']
        )
        if has_device_stats:
            plt.subplot(1, 2, 2)
            # Use the device count recorded in the snapshots, not the current one.
            num_devices = len(self.metrics[0]['device_metrics'])
            for i in range(num_devices):
                device_mem = [self._peak_gb(m['device_metrics'][i]) for m in self.metrics]
                plt.plot(timestamps, device_mem, label=f'TPU {i}')

            plt.title('TPU Memory Usage')
            plt.xlabel('Time (s)')
            plt.ylabel('Memory (GB)')
            plt.legend()

        plt.tight_layout()
        plt.show()

# Create monitor
tpu_monitor = TPUMonitor()

In [None]:
# Update training loop with profiling.
# jax.profiler.trace(log_dir) starts its *own* profiling session when entered.
# Used as a decorator inside the start_trace() session below, it would fail
# because only one profile may run at a time. TraceAnnotation just labels a
# region inside the trace that start_trace() is already collecting.
def train_epoch_with_profile(trainer, train_data, config):
    """Run ``train_epoch`` under a named profiler annotation.

    A monitor snapshot is taken before and after the epoch, so the
    performance analysis has a non-zero time window to compute throughput.
    """
    tpu_monitor.capture_metrics()
    with jax.profiler.TraceAnnotation("train_epoch_profile"):
        metrics = train_epoch(trainer, train_data, config)
    tpu_monitor.capture_metrics()
    return metrics

# Start profiling
start_trace('./tpu_profile')
try:
    # Run one epoch with profiling
    print("Running profiled training epoch...")
    profile_metrics = train_epoch_with_profile(trainer, train_data, config)
finally:
    # Always stop profiling; a failed epoch would otherwise leave the profiler
    # running and every later start_trace() call would fail.
    stop_trace()

# Plot monitoring results
tpu_monitor.plot_metrics()

print("\nProfile data saved to ./tpu_profile")

## TPU Optimization Analysis

In [None]:
def analyze_tpu_performance(monitor, profile_metrics, batch_size=None):
    """Print a summary of memory usage and training throughput.

    Args:
        monitor: Object with a ``metrics`` list of snapshots as recorded by
            ``TPUMonitor.capture_metrics`` (keys ``'timestamp'``,
            ``'host_memory_mb'``, ``'device_metrics'``).
        profile_metrics: List of per-step metrics dicts, each with ``'loss'``.
        batch_size: Samples per step, used for throughput. Defaults to the
            notebook-global ``config['training']['batch_size']``.
    """
    print("TPU Performance Analysis")
    print("-" * 50)

    # Memory utilization
    if monitor.metrics:
        latest = monitor.metrics[-1]
        print("\nMemory Utilization:")
        print(f"Host Memory: {latest['host_memory_mb']:.2f} MB")

        for i, device_metrics in enumerate(latest['device_metrics']):
            # memory_stats() may be None on backends without memory reporting.
            mem_gb = (device_metrics or {}).get('peak_bytes_in_use', 0) / (1024**3)
            print(f"TPU {i} Peak Memory: {mem_gb:.2f} GB")

    # Training metrics
    if not profile_metrics:
        return

    print("\nTraining Performance:")
    losses = [m['loss'] for m in profile_metrics]
    print(f"Average Loss: {np.mean(losses):.4f}")
    print(f"Loss Std Dev: {np.std(losses):.4f}")

    # Throughput needs a time window spanned by at least two snapshots;
    # a single snapshot gives a zero duration (division by zero).
    if len(monitor.metrics) < 2:
        print("\nThroughput: unavailable (need at least two monitor snapshots)")
        return
    duration = monitor.metrics[-1]['timestamp'] - monitor.metrics[0]['timestamp']
    if duration <= 0:
        print("\nThroughput: unavailable (monitor snapshots span no time)")
        return

    if batch_size is None:
        batch_size = config['training']['batch_size']
    total_samples = len(profile_metrics) * batch_size
    throughput = total_samples / duration

    print(f"\nThroughput: {throughput:.2f} samples/second")
    print(f"Batch Processing Time: {duration/len(profile_metrics)*1000:.2f} ms/batch")

# Summarise memory usage and throughput for the profiled epoch.
analyze_tpu_performance(monitor=tpu_monitor, profile_metrics=profile_metrics)

## Model Export

In [None]:
# Persist the trained model together with the exact configuration used.
import json

save_dir = 'experiment2_final'
os.makedirs(save_dir, exist_ok=True)

model_path = f"{save_dir}/model"
config_path = f"{save_dir}/config.json"

# NOTE(review): some checkpoint backends require absolute paths; confirm that
# VishwamAITrainer.save_checkpoint accepts this relative path.
trainer.save_checkpoint(model_path)

with open(config_path, 'w') as config_file:
    json.dump(config, config_file, indent=2)

print(f"Model and configuration saved to {save_dir}")