# VishwamAI GPU Performance Analysis

This notebook provides a comprehensive analysis of VishwamAI model performance on GPU, including:
- Memory usage and efficiency
- Inference latency
- Throughput analysis
- Mixed precision benefits
- Layer-wise profiling

In [None]:
# Clone the VishwamAI repository and switch the working directory to its root.
# The later `from vishwamai...` imports presumably resolve against this checkout,
# since no cell installs the package itself — TODO confirm.
# NOTE(review): re-running this cell in the same kernel starts from inside the
# checkout, so the clone lands in a nested VishwamAI/VishwamAI directory.
!git clone https://github.com/VishwamAI/VishwamAI
%cd VishwamAI

In [None]:
!pip install -r jax flax optax dm-haiku torch sentencepiece smallpond  seaborn numpy matplotlib

In [None]:
# Download the Tiny Shakespeare corpus; wget stores it as input.txt in the current directory.
# NOTE(review): no other cell visible in this notebook reads input.txt — the benchmarks
# run on randomly generated token ids — so confirm whether this download is still needed.
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from vishwamai.inference.optimized_inference import OptimizedInference
from vishwamai.models.transformer import VishwamAITransformer
from vishwamai.optimisation.profiling_tools import VishwamAIProfiler
from vishwamai.optimisation.performance_tuning import PerformanceTuner

# Apply seaborn's "whitegrid" style to every figure drawn in this notebook.
sns.set_style("whitegrid")
# Render matplotlib figures inline in the cell output.
%matplotlib inline

## Load the GPU Model

In [None]:
# Architecture hyperparameters of the transformer being benchmarked.
model_config = {
    'vocab_size': 50000,
    'embed_dim': 768,
    'num_layers': 12,
    'num_heads': 12,
    'ff_dim': 3072,
    'max_seq_len': 512,
}
model = VishwamAITransformer(**model_config)

# Configure the inference helper (GPU device, fp16 precision) and let it
# prepare the model; `model` is rebound to whatever optimize_model returns.
optimizer = OptimizedInference()
optimizer.set_device('gpu')
optimizer.set_precision('fp16')
model = optimizer.optimize_model(model)

# Profiling and tuning helpers, both built around the optimized model.
profiler, tuner = VishwamAIProfiler(model), PerformanceTuner(model)

## Latency and Memory Usage vs. Batch Size

In [None]:
# Generate test data
def generate_test_batch(batch_size, seq_length, vocab_size=50000, device='cuda'):
    """Create a batch of random token ids for benchmarking.

    Args:
        batch_size: Number of sequences in the batch.
        seq_length: Number of token ids in each sequence.
        vocab_size: Exclusive upper bound of the sampled ids. Defaults to
            50000, the value this helper previously hard-coded.
        device: Torch device the tensor is allocated on. Defaults to 'cuda',
            the value this helper previously hard-coded; pass 'cpu' to build
            a batch on a machine without a GPU.

    Returns:
        An integer tensor of shape (batch_size, seq_length) whose values are
        drawn uniformly from [0, vocab_size).
    """
    return torch.randint(0, vocab_size, (batch_size, seq_length), device=device)

# Sweep over batch sizes at a fixed sequence length and record the
# profiler's reported step time and peak memory for each one.
batch_sizes = [1, 4, 8, 16, 32]
seq_length = 512
latencies, memory_usage = [], []

for batch_size in batch_sizes:
    input_data = generate_test_batch(batch_size, seq_length)

    # One profiling run per batch size; keep the two metrics we plot.
    stats = profiler.profile_model(input_data)
    latencies.append(stats['avg_step_time_ms'])
    memory_usage.append(stats['peak_memory_usage_mb'])

# Side-by-side line plots: latency on the left, memory on the right.
fig, axes = plt.subplots(1, 2, figsize=(15, 5))
panels = (
    (latencies, 'Inference Latency vs Batch Size', 'Latency (ms)'),
    (memory_usage, 'Memory Usage vs Batch Size', 'Memory Usage (MB)'),
)
for ax, (values, title, ylabel) in zip(axes, panels):
    sns.lineplot(x=batch_sizes, y=values, ax=ax)
    ax.set_title(title)
    ax.set_xlabel('Batch Size')
    ax.set_ylabel(ylabel)

plt.tight_layout()
plt.show()

## Layer-wise, Precision, and Throughput Analysis

In [None]:
# Profile each layer on a single batch of 16 sequences of length 512.
input_data = generate_test_batch(16, 512)
layer_stats = profiler.layer_wise_profiling(input_data)

# Each entry is indexed positionally: [0] name, [1] latency, [2] memory.
layer_names = [entry[0] for entry in layer_stats]
layer_latencies = [entry[1] for entry in layer_stats]
layer_memory = [entry[2] for entry in layer_stats]

# Stacked bar charts for the first ten entries.
# NOTE(review): the slice takes the first ten entries in the order the
# profiler returned them, so the "Top 10" titles are accurate only if
# layer_wise_profiling returns its results already ranked — confirm.
fig, axes = plt.subplots(2, 1, figsize=(15, 10))
panels = (
    (layer_latencies, 'Layer-wise Latency Analysis (Top 10 Layers)', 'Latency (ms)'),
    (layer_memory, 'Layer-wise Memory Usage (Top 10 Layers)', 'Memory Usage (MB)'),
)
for ax, (values, title, ylabel) in zip(axes, panels):
    sns.barplot(x=layer_names[:10], y=values[:10], ax=ax)
    ax.set_title(title)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
    ax.set_ylabel(ylabel)

plt.tight_layout()
plt.show()

In [None]:
# Mixed precision analysis
# Profile the same input batch under each precision mode and compare the
# reported step time and peak memory.
precisions = ['fp32', 'fp16', 'bf16']
supported_precisions = []  # modes that were profiled successfully
precision_latencies = []
precision_memory = []

input_data = generate_test_batch(16, 512)

for precision in precisions:
    try:
        optimizer.set_precision(precision)
        # NOTE(review): `profiler` was built from the model object created in
        # the setup cell; if optimize_model returns a new object instead of
        # converting in place, the profiler keeps measuring the old one — confirm.
        model = optimizer.optimize_model(model)

        stats = profiler.profile_model(input_data)
        # Read both metrics before recording anything so a failure cannot
        # leave the result lists with different lengths.
        latency_ms = stats['avg_step_time_ms']
        peak_memory_mb = stats['peak_memory_usage_mb']
    except Exception as e:
        # Best effort: report the failing mode and continue with the others.
        print(f"Precision {precision} not supported: {str(e)}")
        continue

    supported_precisions.append(precision)
    precision_latencies.append(latency_ms)
    precision_memory.append(peak_memory_mb)

# Plot precision comparison
# Only the modes that actually ran are plotted, so the x labels always line
# up with the measured values even when a precision was skipped.
if supported_precisions:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    sns.barplot(x=supported_precisions, y=precision_latencies, ax=ax1)
    ax1.set_title('Inference Latency vs Precision')
    ax1.set_ylabel('Latency (ms)')

    sns.barplot(x=supported_precisions, y=precision_memory, ax=ax2)
    ax2.set_title('Memory Usage vs Precision')
    ax2.set_ylabel('Memory Usage (MB)')

    plt.tight_layout()
    plt.show()
else:
    print("No precision mode could be profiled; skipping the comparison plots.")

In [None]:
# Ask the tuner for a batch size, passing (sequence length, embedding dim)
# as the per-sample shape, then profile one batch of that size.
optimal_batch_size = tuner.tune_batch_size((seq_length, model.embed_dim))
print(f"Optimal batch size for maximum throughput: {optimal_batch_size}")

input_data = generate_test_batch(optimal_batch_size, seq_length)
stats = profiler.profile_model(input_data)

# Report the measured metrics and the derived throughput.
print(f"\nPerformance with optimal batch size:")

avg_latency_ms = stats['avg_step_time_ms']
print(f"Average latency: {avg_latency_ms:.2f} ms")

peak_memory_mb = stats['peak_memory_usage_mb']
print(f"Peak memory usage: {peak_memory_mb:.2f} MB")

# samples per step * steps per second (1000 ms / step time in ms)
samples_per_second = optimal_batch_size * 1000 / stats['avg_step_time_ms']
print(f"Throughput: {samples_per_second:.2f} samples/second")