# VishwamAI Training Performance Analysis

This notebook analyzes the training performance with TPU optimizations and fixed FlashAttention implementation.

In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, Any

from vishwamai.training import create_train_state_tpu, create_train_step_tpu
from vishwamai.profiler import TPUProfiler
from vishwamai.transformer import create_vishwamai_transformer

2025-03-19 17:43:46.445221: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742386426.519071   16025 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742386426.539317   16025 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [2]:
# Load training and test data
train_data = pd.read_parquet('/home/kasinadhsarma/VishwamAI/train-00000-of-00001.parquet')
test_data = pd.read_parquet('/home/kasinadhsarma/VishwamAI/test-00000-of-00001.parquet')

print("Training data shape:", train_data.shape)
print("Test data shape:", test_data.shape)

Training data shape: (7473, 2)
Test data shape: (1319, 2)


In [3]:
def create_training_config():
    """Create TPU-optimized training configuration"""
    model_config = {
        'vocab_size': 32000,
        'num_layers': 12,
        'num_heads': 12,
        'head_dim': 64,
        'hidden_dim': 768,
        'mlp_dim': 3072,
        'max_seq_len': 2048,
        'dropout_rate': 0.1,
        'use_flash_attn': True,
        'use_rotary': True,
        'use_rms_norm': False
    }
    
    return {
        'model_config': model_config,
        'batch_size': 32,
        'grad_accum_steps': 4,
        'learning_rate': 1e-4,
        'warmup_steps': 2000,
        'max_steps': 100000,
        'dtype': jnp.bfloat16,
        'enable_pjit': True,
        'block_size': 128,
        'mixed_precision': True
    }

In [4]:
# Initialize training components
config = create_training_config()
rng = jax.random.PRNGKey(42)

# Create model and initialize training state
print("Initializing model and training state...")
state = create_train_state_tpu(config, rng)
train_step = create_train_step_tpu(config, state)

# Initialize profiler
profiler = TPUProfiler(config=config['model_config'])



Initializing model and training state...


AttributeError: 'dict' object has no attribute 'learning_rate'

In [None]:
def analyze_training_metrics(profiler: TPUProfiler, num_steps: int = 100):
    """Analyze training metrics over multiple steps"""
    metrics = {
        'step_time': [],
        'throughput': [],
        'memory_used': [],
        'tpu_utilization': []
    }
    
    for step in range(num_steps):
        profiler.start_step()
        
        # Simulate training step
        batch_size = config['batch_size'] * config['grad_accum_steps']
        profiler.record_batch_time(batch_size, 0.1)  # Example duration
        profiler.measure_tpu_utilization()
        
        profiler.end_step()
        
        # Collect metrics
        summary = profiler.get_metrics_summary()
        metrics['step_time'].append(summary['step_time_mean'])
        metrics['throughput'].append(summary.get('steps_per_second', 0))
        metrics['memory_used'].append(summary.get('memory_accessed_mean', 0))
        metrics['tpu_utilization'].append(summary.get('tpu_utilization_mean', 0))
    
    return metrics

# Run analysis
training_metrics = analyze_training_metrics(profiler)

# Plot metrics
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Training Performance Metrics')

axes[0, 0].plot(training_metrics['step_time'])
axes[0, 0].set_title('Step Time')
axes[0, 0].set_xlabel('Step')
axes[0, 0].set_ylabel('Time (s)')

axes[0, 1].plot(training_metrics['throughput'])
axes[0, 1].set_title('Throughput')
axes[0, 1].set_xlabel('Step')
axes[0, 1].set_ylabel('Steps/second')

axes[1, 0].plot(training_metrics['memory_used'])
axes[1, 0].set_title('Memory Usage')
axes[1, 0].set_xlabel('Step')
axes[1, 0].set_ylabel('Bytes')

axes[1, 1].plot(training_metrics['tpu_utilization'])
axes[1, 1].set_title('TPU Utilization')
axes[1, 1].set_xlabel('Step')
axes[1, 1].set_ylabel('Utilization %')

plt.tight_layout()

In [None]:
# Create configuration with proper dtype settings
config = create_training_config()

# Initialize model and check params
print("Creating model...")
model = create_vishwamai_transformer(config)

# Initialize training components
print("\nInitializing training state...")
rng = jax.random.PRNGKey(42)
state = create_train_state_tpu(config, rng)
train_step = create_train_step_tpu(config, state)

# Initialize profiler with proper config
print("\nSetting up profiler...")
profiler = TPUProfiler(config=config['model_config'])

# Run analysis
print("\nAnalyzing training metrics...")
training_metrics = analyze_training_metrics(profiler)

# Plot results
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
fig.suptitle('Training Performance Analysis')

steps = range(len(training_metrics['step_time']))

axes[0,0].plot(steps, training_metrics['step_time'])
axes[0,0].set_title('Step Time')
axes[0,0].set_xlabel('Step')
axes[0,0].set_ylabel('Time (s)')

axes[0,1].plot(steps, training_metrics['throughput'])
axes[0,1].set_title('Training Throughput')
axes[0,1].set_xlabel('Step')
axes[0,1].set_ylabel('Steps/Second')

axes[1,0].plot(steps, training_metrics['memory_used'])
axes[1,0].set_title('Memory Usage')
axes[1,0].set_xlabel('Step')
axes[1,0].set_ylabel('Bytes')

axes[1,1].plot(steps, training_metrics['tpu_utilization'])
axes[1,1].set_title('TPU Utilization')
axes[1,1].set_xlabel('Step')
axes[1,1].set_ylabel('Utilization %')

plt.tight_layout()
plt.show()

In [None]:
# Get performance recommendations
recommendations = profiler.get_performance_recommendations()
print("\nPerformance Recommendations:")
for i, rec in enumerate(recommendations, 1):
    print(f"\n{i}. {rec}")

## Model Performance Analysis

Current metrics from previous analysis:
- Total Parameters: 109,529,088
- MoE Parameters: 301,991,936
- Memory Usage:
  - Activations: 96.00 MB
  - Attention: 3072.00 MB
  - KV Cache: 192.00 MB

In [None]:
def profile_model_performance(config: Dict[str, Any]):
    """Profile model inference performance"""
    model = create_vishwamai_transformer(config)
    batch_size = 1
    seq_length = config['max_seq_len']
    
    # Create dummy input
    dummy_input = jnp.ones((batch_size, seq_length), dtype=jnp.int32)
    
    # Profile memory
    memory_profile = profiler.profile_memory_usage(
        lambda x: model.apply({'params': state.params}, x),
        {'input': dummy_input.shape}
    )
    
    return memory_profile

# Run performance profiling
perf_metrics = profile_model_performance(config['model_config'])
print("\nModel Performance Profile:")
for k, v in perf_metrics.items():
    print(f"{k}: {v/1e9:.2f} GB")

In [None]:
def benchmark_attention(batch_size: int = 32, seq_len: int = 512):
    """Benchmark different attention implementations"""
    # Generate dummy inputs
    rng = jax.random.PRNGKey(0)
    x = jax.random.normal(rng, (batch_size, seq_len, config['model_config']['hidden_dim']))
    
    # Standard attention
    def run_std_attention():
        q = k = v = x
        scores = jnp.einsum('bqd,bkd->bqk', q, k)
        scores = scores / jnp.sqrt(config['model_config']['head_dim'])
        attn = jax.nn.softmax(scores)
        return jnp.einsum('bqk,bkd->bqd', attn, v)
    
    # Flash attention
    def run_flash_attention():
        q = k = v = x.reshape(batch_size, seq_len, 
                             config['model_config']['num_heads'], 
                             config['model_config']['head_dim'])
        return FlashAttention(
            num_heads=config['model_config']['num_heads'],
            head_dim=config['model_config']['head_dim']
        )(q, k, v)
    
    # Benchmark
    std_time = %timeit -o -n 10 -r 3 -q run_std_attention()
    flash_time = %timeit -o -n 10 -r 3 -q run_flash_attention()
    
    return {
        'standard_attention_ms': std_time.best * 1000,
        'flash_attention_ms': flash_time.best * 1000,
        'speedup': std_time.best / flash_time.best
    }

In [None]:
# Run benchmarks
print("Running attention benchmarks...")
results_512 = benchmark_attention(seq_len=512)
results_1024 = benchmark_attention(seq_len=1024)
results_2048 = benchmark_attention(seq_len=2048)

# Plot results
seq_lens = [512, 1024, 2048]
std_times = [results_512['standard_attention_ms'],
            results_1024['standard_attention_ms'],
            results_2048['standard_attention_ms']]
flash_times = [results_512['flash_attention_ms'],
              results_1024['flash_attention_ms'],
              results_2048['flash_attention_ms']]

plt.figure(figsize=(10, 6))
plt.plot(seq_lens, std_times, 'b-', label='Standard Attention')
plt.plot(seq_lens, flash_times, 'r-', label='Flash Attention')
plt.xlabel('Sequence Length')
plt.ylabel('Time (ms)')
plt.title('Attention Performance Comparison')
plt.legend()
plt.grid(True)
plt.show()

print(f"\nSpeedup at different sequence lengths:")
print(f"512: {results_512['speedup']:.2f}x")
print(f"1024: {results_1024['speedup']:.2f}x")
print(f"2048: {results_2048['speedup']:.2f}x")