# GPU Model Analysis

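This notebook measures and visualizes the inference time of the GPU-optimized transformer model (`VishwamAITransformer`) defined in `vishwamai/models/gpu/transformer.py`. The model and its inputs are placed on `cuda`, so a CUDA-capable GPU is required to run it.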

In [ ]:
import torch
from vishwamai.models.gpu.transformer import VishwamAITransformer
import time
import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's "whitegrid" style globally so the plots below share it.
sns.set(style="whitegrid")

## Load the GPU Model

In [ ]:
# Hyperparameters of the transformer under test. These stay as module-level
# names because later cells reuse them (e.g. vocab_size when sampling inputs).
vocab_size, embed_dim = 50000, 512
num_layers, num_heads = 12, 8
ff_dim, max_seq_len = 2048, 512

# Extra options handed to the model constructor as `attention_kwargs`.
attention_kwargs = {
    "num_experts": 4,
    "taa_kwargs": {"k": 10, "kernel_dim": 256},
}

# Build the model and move it to the GPU in a single expression.
model = VishwamAITransformer(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    ff_dim=ff_dim,
    max_seq_len=max_seq_len,
    attention_kwargs=attention_kwargs,
).to('cuda')

## Analyze Model Performance

In [ ]:
# Generate a random batch of token ids and move it to the GPU.
batch_size = 16
seq_len = 128
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)).to('cuda')

# Put the model in evaluation mode so that training-only behavior (such as
# dropout, if the model uses it) does not run during the measurement.
model.eval()

# Number of untimed forward passes run before the measurement. The first calls
# on a GPU pay one-time costs (CUDA context creation, kernel loading, allocator
# growth) that would otherwise dominate a single timed run.
warmup_runs = 3

with torch.no_grad():
    for _ in range(warmup_runs):
        model(input_ids)

    # CUDA kernels are launched asynchronously: without synchronizing, the
    # timer would only capture the launch overhead, not the actual compute.
    # Synchronize before starting and before stopping the clock.
    torch.cuda.synchronize()
    start_time = time.perf_counter()  # monotonic, high-resolution timer
    output = model(input_ids)
    torch.cuda.synchronize()
    end_time = time.perf_counter()

# Wall-clock duration of one forward pass, in seconds.
inference_time = end_time - start_time
print(f"Inference time: {inference_time:.4f} seconds")

## Visualize Model Performance

In [ ]:
# Draw the measured inference time as a single bar, using an explicit
# figure/axes pair instead of the implicit pyplot state.
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(x=["GPU Model"], y=[inference_time], ax=ax)
ax.set_ylabel("Inference Time (seconds)")
ax.set_title("GPU Model Inference Time")
plt.show()