# VishwamAI Training Performance Analysis

This notebook analyzes the training performance of the VishwamAI model using the provided train/test parquet datasets.

In [None]:
from typing import Any, Dict, Optional

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from vishwamai.training import TPUTrainingConfig, create_train_state_tpu, create_train_step_tpu
from vishwamai.profiler import TPUProfiler
from vishwamai.transformer import create_vishwamai_transformer

In [None]:
# Read the train and test parquet splits (paths are relative to the working directory)
train_data, test_data = (
    pd.read_parquet(split_file)
    for split_file in ('train-00000-of-00001.parquet', 'test-00000-of-00001.parquet')
)

# Quick sanity check: (rows, columns) of each split
print("Training data shape:", train_data.shape)
print("Test data shape:", test_data.shape)

In [None]:
def create_training_config():
    """Assemble the TPUTrainingConfig used by the rest of this notebook.

    The configuration has two parts: a plain dict of transformer settings,
    passed as ``model_config``, and the training hyperparameters passed as
    individual keyword arguments.

    Returns:
        The TPUTrainingConfig instance built from those values.
    """
    # Transformer architecture settings (num_heads * head_dim == hidden_dim)
    transformer_settings = dict(
        vocab_size=32000,
        num_layers=12,
        num_heads=12,
        head_dim=64,
        hidden_dim=768,
        mlp_dim=3072,
        max_seq_len=2048,
        dropout_rate=0.1,
        use_flash_attn=True,
        use_rms_norm=False,
    )

    # Optimisation / execution settings forwarded verbatim as keyword arguments
    trainer_settings = dict(
        batch_size=32,
        grad_accum_steps=4,
        learning_rate=1e-4,
        warmup_steps=2000,
        max_steps=100000,
        dtype='bfloat16',
        enable_pjit=True,
        block_size=128,
        use_flash_attn=True,
        mixed_precision=True,
    )

    return TPUTrainingConfig(model_config=transformer_settings, **trainer_settings)

In [None]:
# Build the notebook-level objects the later cells rely on.
# `config` is read as a global by analyze_training_metrics, and `state` /
# `profiler` are read as globals by profile_model_performance.
config = create_training_config()
# Fixed PRNG seed so the state below is created from the same key on every run
rng = jax.random.PRNGKey(42)

# Create the training state from the config and PRNG key, then the step
# function from the config and that state.
# NOTE(review): `train_step` is not called in the cells visible here.
state = create_train_state_tpu(config, rng)
train_step = create_train_step_tpu(config, state)

# The profiler receives only the model_config dict, not the full training config.
# NOTE(review): 'training_profiles' is a relative path, presumably where the
# profiler writes its output — confirm against TPUProfiler.
profiler = TPUProfiler(config=config.model_config, log_dir='training_profiles')

In [None]:
def analyze_training_metrics(
    profiler: "TPUProfiler",
    num_steps: int = 100,
    batch_size: Optional[int] = None,
    step_duration: float = 0.1,
):
    """Drive the profiler through simulated steps and collect its summaries.

    No model code runs here. Each iteration brackets a synthetic step with
    ``start_step()`` / ``end_step()``, reports one batch through
    ``record_batch_time(batch_size, step_duration)``, calls
    ``measure_tpu_utilization()``, and then reads ``get_metrics_summary()``.

    Args:
        profiler: Object exposing the five profiler methods named above.
        num_steps: Number of simulated steps; 0 yields empty lists.
        batch_size: Batch size reported to the profiler for every step. When
            None, it is taken from the notebook-level ``config`` as
            ``config.batch_size * config.grad_accum_steps``.
        step_duration: Duration reported to the profiler for every step
            (presumably seconds — TODO confirm against the profiler).

    Returns:
        Dict with keys 'step_time', 'throughput', 'memory_used' and
        'tpu_utilization'; each value is a list with one entry per step, read
        from the summary keys 'step_time_mean', 'steps_per_second',
        'memory_accessed_mean' and 'tpu_utilization_mean' respectively.
        NOTE(review): the '*_mean' key names suggest these are aggregates over
        all steps so far rather than per-step readings — confirm.

    Raises:
        KeyError: If a summary has no 'step_time_mean' entry. The other three
            summary keys default to 0 when absent.
    """
    if batch_size is None:
        # Resolved once, before the profiler is touched, so a missing
        # notebook-level `config` fails fast instead of in the middle of a step.
        batch_size = config.batch_size * config.grad_accum_steps

    metrics = {
        'step_time': [],
        'throughput': [],
        'memory_used': [],
        'tpu_utilization': []
    }

    for _ in range(num_steps):
        profiler.start_step()

        # Simulated training step: only a batch size and a duration are reported
        profiler.record_batch_time(batch_size, step_duration)
        profiler.measure_tpu_utilization()

        profiler.end_step()

        # Snapshot the profiler's summary after this step
        summary = profiler.get_metrics_summary()
        # 'step_time_mean' is required; the remaining keys fall back to 0
        metrics['step_time'].append(summary['step_time_mean'])
        metrics['throughput'].append(summary.get('steps_per_second', 0))
        metrics['memory_used'].append(summary.get('memory_accessed_mean', 0))
        metrics['tpu_utilization'].append(summary.get('tpu_utilization_mean', 0))

    return metrics

# Collect the simulated-step metrics
training_metrics = analyze_training_metrics(profiler)

# One panel per metric on a 2x2 grid
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Training Performance Metrics')

# (metrics key, panel title, y-axis label), listed in row-major panel order
metric_panels = [
    ('step_time', 'Step Time', 'Time (s)'),
    ('throughput', 'Throughput', 'Steps/second'),
    ('memory_used', 'Memory Usage', 'Bytes'),
    ('tpu_utilization', 'TPU Utilization', 'Utilization %'),
]

# axes.flat walks the grid row by row: [0, 0], [0, 1], [1, 0], [1, 1]
for panel_ax, (metric_key, panel_title, y_label) in zip(axes.flat, metric_panels):
    panel_ax.plot(training_metrics[metric_key])
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Step')
    panel_ax.set_ylabel(y_label)

plt.tight_layout()

In [None]:
# Ask the profiler for its performance recommendations and print them as a
# numbered list.
# NOTE(review): presumably derived from the steps recorded earlier in this
# notebook — confirm against TPUProfiler.
recommendations = profiler.get_performance_recommendations()
print("\nPerformance Recommendations:")
# Numbering starts at 1; the leading "\n" puts a blank line before each entry
for i, rec in enumerate(recommendations, 1):
    print(f"\n{i}. {rec}")

## Model Performance Analysis

Current metrics from previous analysis:
- Total Parameters: 109,529,088
- MoE Parameters: 301,991,936
- Memory Usage:
  - Activations: 96.00 MB
  - Attention: 3072.00 MB
  - KV Cache: 192.00 MB

In [None]:
def profile_model_performance(config: Dict[str, Any]):
    """Run the profiler's memory profiling on a forward pass of a new model.

    A transformer is built from ``config`` and wrapped in a callable that
    applies it with the parameters held in the notebook-level ``state``. That
    callable, together with the shape of a dummy int32 input of size
    ``(1, config['max_seq_len'])``, is handed to the notebook-level
    ``profiler``.

    Args:
        config: Model configuration dict; must contain 'max_seq_len'.

    Returns:
        Whatever ``profiler.profile_memory_usage`` returns, unchanged.
    """
    transformer = create_vishwamai_transformer(config)

    # A single sequence of ones at the maximum configured length
    token_ids = jnp.ones((1, config['max_seq_len']), dtype=jnp.int32)

    # Only the input's shape is passed on; the array itself is not
    return profiler.profile_memory_usage(
        lambda tokens: transformer.apply({'params': state.params}, tokens),
        {'input': token_ids.shape},
    )

# Profile using only the model_config dict of the notebook-level training
# config (inside the function the parameter is also called `config`).
perf_metrics = profile_model_performance(config.model_config)
print("\nModel Performance Profile:")
# Each value is divided by 1e9 and labelled GB.
# NOTE(review): this assumes perf_metrics is a flat mapping whose values are
# all numeric byte counts — confirm against profile_memory_usage.
for k, v in perf_metrics.items():
    print(f"{k}: {v/1e9:.2f} GB")