# TPU Model Analysis

This notebook benchmarks the inference latency of the TPU-optimized transformer model (`VishwamAITransformer`) defined in `vishwamai/models/tpu/transformer.py`.

In [ ]:
import jax
import jax.numpy as jnp
from jax import random
from vishwamai.models.tpu.transformer import VishwamAITransformer
import time
import matplotlib.pyplot as plt
import seaborn as sns
# Apply seaborn's whitegrid theme to all subsequent plots.
# `sns.set_theme` is the current API; `sns.set` is a legacy alias that
# seaborn documents as subject to removal.
sns.set_theme(style="whitegrid")

## Initialize the TPU Model

In [ ]:
# Define model hyperparameters.
vocab_size = 50000
embed_dim = 512
num_layers = 12
num_heads = 8
ff_dim = 2048
max_seq_len = 512
# Extra options forwarded to VishwamAITransformer's attention layers.
# NOTE(review): the meaning of "num_experts" / "taa_kwargs" is defined in
# transformer.py — TODO confirm there before changing these values.
attention_kwargs = {"num_experts": 4, "taa_kwargs": {"k": 10, "kernel_dim": 256}}

# Initialize the model.
# JAX PRNG keys must never be consumed twice. Split off a dedicated key for
# `model.init` and keep `rng` as the parent from which later cells derive
# their own fresh keys.
rng = random.PRNGKey(0)
rng, init_rng = random.split(rng)
model = VishwamAITransformer(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    ff_dim=ff_dim,
    max_seq_len=max_seq_len,
    attention_kwargs=attention_kwargs
)
# A small dummy int32 token batch of shape (1, 5) is passed only to run the
# initialization; the resulting variable dict's 'params' collection is kept.
params = model.init(init_rng, jnp.ones((1, 5), dtype=jnp.int32))['params']

## Analyze Model Performance

In [ ]:
# Generate random input data.
batch_size = 16
seq_len = 128
if seq_len > max_seq_len:
    raise ValueError(
        f"seq_len ({seq_len}) exceeds the model's max_seq_len ({max_seq_len})"
    )
# Derive a fresh key instead of reusing `rng` directly: the key passed to
# `model.init` must not be reused, or the "random" data is correlated with it.
rng, data_rng = random.split(rng)
input_ids = random.randint(data_rng, (batch_size, seq_len), 0, vocab_size)

# Measure inference time.
# JAX dispatches device work asynchronously, so `apply` returns before the
# computation has finished. We must block on the result, or the timer measures
# only dispatch overhead. The forward pass is jitted and run once before timing
# so that tracing / XLA compilation is excluded from the measurement.
apply_fn = jax.jit(model.apply)
jax.block_until_ready(apply_fn({'params': params}, input_ids))  # warm-up / compile

num_runs = 10  # average over several runs to smooth out timer noise
start_time = time.perf_counter()  # monotonic, high-resolution clock
for _ in range(num_runs):
    logits = apply_fn({'params': params}, input_ids)
    jax.block_until_ready(logits)
end_time = time.perf_counter()
inference_time = (end_time - start_time) / num_runs  # mean seconds per forward pass
print(f"Inference time: {inference_time:.4f} seconds")

## Visualize Model Performance

In [ ]:
# Render the measured inference time as a single-bar chart.
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(x=["TPU Model"], y=[inference_time], ax=ax)
ax.set_ylabel("Inference Time (seconds)")
ax.set_title("TPU Model Inference Time")
plt.show()