# VMAP

In [3]:
import jax
import jax.numpy as jnp
from jax import random, vmap
import time
import numpy as np

# Setup: 32 attention heads, each with 64x64 query-key matrices
batch_size = 32
matrix_size = 64

# Create random matrices representing attention queries and keys
key = random.PRNGKey(0)
queries = random.normal(key, (batch_size, matrix_size, matrix_size))
keys = random.normal(key, (batch_size, matrix_size, matrix_size))

# Method 1: Manual loop (slow, not idiomatic JAX)
def matmul_loop(queries, keys):
    results = []
    for i in range(queries.shape[0]):
        results.append(jnp.matmul(queries[i], keys[i]))
    return jnp.stack(results)

# Method 2: Using vmap (fast, vectorized, parallelizable)
def matmul_single(q, k):
    return jnp.matmul(q, k)

# vmap automatically vectorizes over the first dimension
matmul_batch = vmap(matmul_single)

# Benchmark
_ = matmul_loop(queries[:2], keys[:2])  # Small warmup
_ = matmul_batch(queries[:2], keys[:2])  # Small warmup
jax.block_until_ready(_)  # Wait for GPU

print("\nTiming 100 iterations...")

# Time loop version
start = time.time()
for _ in range(100):
    result_loop = matmul_loop(queries, keys)
    jax.block_until_ready(result_loop)
loop_time = time.time() - start

# Time vmap version
start = time.time()
for _ in range(100):
    result_vmap = matmul_batch(queries, keys)
    jax.block_until_ready(result_vmap)
vmap_time = time.time() - start

print(f"\nResults:")
print(f"   Loop approach: {loop_time:.4f}s")
print(f"   vmap approach: {vmap_time:.4f}s")
print(f"   Speedup: {loop_time/vmap_time:.2f}x faster!")
print(f"\nResults match: {jnp.allclose(result_loop, result_vmap)}")




Timing 100 iterations...

Results:
   Loop approach: 2.1449s
   vmap approach: 0.1052s
   Speedup: 20.39x faster!

Results match: True


# PRNG

In [4]:


# BAD: Reusing the same key (produces identical masks!)
def dropout_bad(x, key, drop_rate=0.5):
    masks = []
    for i in range(5):  # 5 layers
        #Using same key repeatedly - ALL MASKS ARE IDENTICAL!
        mask = random.bernoulli(key, 1 - drop_rate, x.shape)
        masks.append(mask)
    return masks

# GOOD: Proper key splitting (produces independent masks)
def dropout_good(x, key, drop_rate=0.5):
    """CORRECT: Splitting key gives independent randomness"""
    masks = []
    for i in range(5):  # 5 layers
        # Split key to get independent randomness for each layer
        key, subkey = random.split(key)
        mask = random.bernoulli(subkey, 1 - drop_rate, x.shape)
        masks.append(mask)
    return masks

# BEST: Vectorized key splitting with vmap
def dropout_best(x, key, drop_rate=0.5, num_layers=5):
    """OPTIMAL: Vectorized key splitting + parallel generation"""
    # Split key into num_layers independent subkeys at once
    keys = random.split(key, num_layers)

    # Vectorize mask generation across all keys
    def make_mask(subkey):
        return random.bernoulli(subkey, 1 - drop_rate, x.shape)

    # vmap over keys dimension to generate all masks in parallel
    masks = vmap(make_mask)(keys)
    return masks

# Demo
print("Generating dropout masks for 5 layers...\n")
key = random.PRNGKey(42)
x = jnp.ones((4, 10))  # Batch of 4, feature size 10

masks_bad = dropout_bad(x, key)
masks_good = dropout_good(x, key)
masks_best = dropout_best(x, key)

print("BAD (reusing key) - First 3 masks:")
for i in range(3):
    print(f"   Layer {i}: {masks_bad[i][0, :5].astype(int)}...")
print("   PROBLEM: All masks are IDENTICAL!")

print("\nGOOD (splitting key) - First 3 masks:")
for i in range(3):
    print(f"   Layer {i}: {masks_good[i][0, :5].astype(int)}...")
print("   Each layer has independent randomness")

print("\nBEST (vectorized splitting) - First 3 masks:")
for i in range(3):
    print(f"   Layer {i}: {masks_best[i][0, :5].astype(int)}...")
print("   Independent + parallelized generation")

# Verify independence
print("\nStatistical verification:")
correlation_bad = jnp.corrcoef(masks_bad[0].flatten(), masks_bad[1].flatten())[0, 1]
correlation_good = jnp.corrcoef(masks_good[0].flatten(), masks_good[1].flatten())[0, 1]

print(f"   Correlation between layers (bad):  {correlation_bad:.4f} (should be ~0)")
print(f"   Correlation between layers (good): {correlation_good:.4f} (actually independent)")

check_same = lambda masks: all(jnp.allclose(masks[0], m) for m in masks[1:])
print(f"All BAD masks identical?  {check_same(masks_bad)}")
print(f"All GOOD masks identical? {check_same(masks_good)}\nAll BEST masks identical?  {check_same(masks_best)}")


#adds more randomness, model learns in a better way
#Pytorch does vmap by default because it is not functional JAX is functional so we need to do manually
# For key splitting it is the same reason , pytorch is a glaobal state managed framework so defining one global seed will produce similar results,
# whereas the JAX is functional so we need to define explicitly

Generating dropout masks for 5 layers...

BAD (reusing key) - First 3 masks:
   Layer 0: [1 0 0 0 1]...
   Layer 1: [1 0 0 0 1]...
   Layer 2: [1 0 0 0 1]...
   PROBLEM: All masks are IDENTICAL!

GOOD (splitting key) - First 3 masks:
   Layer 0: [0 0 1 1 1]...
   Layer 1: [1 1 1 1 0]...
   Layer 2: [1 1 1 1 0]...
   Each layer has independent randomness

BEST (vectorized splitting) - First 3 masks:
   Layer 0: [0 1 0 0 0]...
   Layer 1: [0 0 1 1 1]...
   Layer 2: [0 0 1 1 1]...
   Independent + parallelized generation

Statistical verification:
   Correlation between layers (bad):  1.0000 (should be ~0)
   Correlation between layers (good): -0.1549 (actually independent)
All BAD masks identical?  True
All GOOD masks identical? False
All BEST masks identical?  False


# bfloat()

# With TPU

In [1]:
import jax
print(jax.devices())



[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)]


In [5]:
# Simple 2-layer neural network
def create_params(key, input_size, hidden_size, output_size, dtype):
    """Initialize network parameters in specified dtype"""
    k1, k2 = random.split(key)
    w1 = random.normal(k1, (input_size, hidden_size), dtype=dtype) * 0.01
    w2 = random.normal(k2, (hidden_size, output_size), dtype=dtype) * 0.01
    return {'w1': w1, 'w2': w2}

def forward(params, x):
    """Forward pass through 2-layer network"""
    h = jnp.tanh(x @ params['w1'])  # Hidden layer
    return h @ params['w2']  # Output layer

def loss_fn(params, x, y):
    """MSE loss"""
    pred = forward(params, x)
    return jnp.mean((pred - y) ** 2)

# Training setup
key = random.PRNGKey(0)
input_size, hidden_size, output_size = 1024, 2048, 512
batch_size = 128

# Generate dummy data
x_data = random.normal(key, (batch_size, input_size))
y_data = random.normal(key, (batch_size, output_size))

# Create models in different precisions
print("Creating models in different precisions...\n")

params_fp32 = create_params(key, input_size, hidden_size, output_size, jnp.float32)
params_bf16 = create_params(key, input_size, hidden_size, output_size, jnp.bfloat16)

# Convert input data
x_fp32 = x_data.astype(jnp.float32)
x_bf16 = x_data.astype(jnp.bfloat16)
y_fp32 = y_data.astype(jnp.float32)
y_bf16 = y_data.astype(jnp.bfloat16)

# Memory comparison
def get_memory_mb(params):
    """Calculate memory usage in MB"""
    total_bytes = sum(p.nbytes for p in jax.tree_util.tree_leaves(params))
    return total_bytes / (1024 ** 2)

mem_fp32 = get_memory_mb(params_fp32)
mem_bf16 = get_memory_mb(params_bf16)

print(f"Memory Usage:")
print(f"   float32 params: {mem_fp32:.2f} MB")
print(f"   bfloat16 params: {mem_bf16:.2f} MB")
print(f"   Memory saved: {mem_fp32 - mem_bf16:.2f} MB ({(1 - mem_bf16/mem_fp32)*100:.1f}%)")

# Computational speed comparison
print("\nSpeed comparison (100 forward passes)...")

# Warm up
_ = forward(params_fp32, x_fp32)
_ = forward(params_bf16, x_bf16)
jax.block_until_ready(_)

# Time float32
start = time.time()
for _ in range(100):
    out_fp32 = forward(params_fp32, x_fp32)
    jax.block_until_ready(out_fp32)
time_fp32 = time.time() - start

# Time bfloat16
start = time.time()
for _ in range(100):
    out_bf16 = forward(params_bf16, x_bf16)
    jax.block_until_ready(out_bf16)
time_bf16 = time.time() - start

print(f"\nComputation Time:")
print(f"   float32:  {time_fp32:.4f}s")
print(f"   bfloat16: {time_bf16:.4f}s")
if time_bf16 < time_fp32:
    print(f"   Speedup: {time_fp32/time_bf16:.2f}x faster with bfloat16!")
else:
    print(f"   Note: Speedup varies by hardware (TPU shows 2-4x gains)")

# Accuracy comparison
print("\nNumerical Accuracy:")
loss_fp32 = loss_fn(params_fp32, x_fp32, y_fp32)
loss_bf16 = loss_fn(params_bf16, x_bf16, y_bf16)

# Convert to Python floats to ensure proper formatting
loss_fp32_val = float(loss_fp32)
loss_bf16_val = float(loss_bf16)

print(f"   float32 loss:  {loss_fp32_val:.6f}")
print(f"   bfloat16 loss: {loss_bf16_val:.6f}")
print(f"   Difference:    {abs(loss_fp32_val - loss_bf16_val):.6f}")



# float32 (32 bits):
# [S][EEEEEEEE][MMMMMMMMMMMMMMMMMMMMMMM]
#  1    8 bits      23 bits
#  ↑    exponent    mantissa (precision)
# sign

# bfloat16 (16 bits):
# [S][EEEEEEEE][MMMMMMM]
#  1    8 bits   7 bits
#  ↑    exponent mantissa
# sign

# float16 (16 bits):
# [S][EEEEE][MMMMMMMMMM]
#  1  5 bits  10 bits
#  ↑  exponent mantissa
# sign

### When large language models have deep layers which is more than 150+ complex layers the bfloat saves time and memory footprint with almost same efficiency in the precision which is not that much required for gradients

Creating models in different precisions...

Memory Usage:
   float32 params: 12.00 MB
   bfloat16 params: 6.00 MB
   Memory saved: 6.00 MB (50.0%)

Speed comparison (100 forward passes)...

Computation Time:
   float32:  0.0242s
   bfloat16: 0.0231s
   Speedup: 1.05x faster with bfloat16!

Numerical Accuracy:
   float32 loss:  1.024064
   bfloat16 loss: 1.023438
   Difference:    0.000626


# With GPU

In [1]:
import jax
print(jax.devices())

[CudaDevice(id=0)]


In [3]:
import jax
import jax.numpy as jnp
from jax import random, vmap
import time
import numpy as np

# Simple 2-layer neural network
def create_params(key, input_size, hidden_size, output_size, dtype):
    """Initialize network parameters in specified dtype"""
    k1, k2 = random.split(key)
    w1 = random.normal(k1, (input_size, hidden_size), dtype=dtype) * 0.01
    w2 = random.normal(k2, (hidden_size, output_size), dtype=dtype) * 0.01
    return {'w1': w1, 'w2': w2}

def forward(params, x):
    """Forward pass through 2-layer network"""
    h = jnp.tanh(x @ params['w1'])  # Hidden layer
    return h @ params['w2']  # Output layer

def loss_fn(params, x, y):
    """MSE loss"""
    pred = forward(params, x)
    return jnp.mean((pred - y) ** 2)

# Training setup
key = random.PRNGKey(0)
input_size, hidden_size, output_size = 1024, 2048, 512
batch_size = 128

# Generate dummy data
x_data = random.normal(key, (batch_size, input_size))
y_data = random.normal(key, (batch_size, output_size))

# Create models in different precisions
print("Creating models in different precisions...\n")

params_fp32 = create_params(key, input_size, hidden_size, output_size, jnp.float32)
params_bf16 = create_params(key, input_size, hidden_size, output_size, jnp.bfloat16)

# Convert input data
x_fp32 = x_data.astype(jnp.float32)
x_bf16 = x_data.astype(jnp.bfloat16)
y_fp32 = y_data.astype(jnp.float32)
y_bf16 = y_data.astype(jnp.bfloat16)

# Memory comparison
def get_memory_mb(params):
    """Calculate memory usage in MB"""
    total_bytes = sum(p.nbytes for p in jax.tree_util.tree_leaves(params))
    return total_bytes / (1024 ** 2)

mem_fp32 = get_memory_mb(params_fp32)
mem_bf16 = get_memory_mb(params_bf16)

print(f"Memory Usage:")
print(f"   float32 params: {mem_fp32:.2f} MB")
print(f"   bfloat16 params: {mem_bf16:.2f} MB")
print(f"   Memory saved: {mem_fp32 - mem_bf16:.2f} MB ({(1 - mem_bf16/mem_fp32)*100:.1f}%)")

# Computational speed comparison
print("\nSpeed comparison (100 forward passes)...")

# Warm up
_ = forward(params_fp32, x_fp32)
_ = forward(params_bf16, x_bf16)
jax.block_until_ready(_)

# Time float32
start = time.time()
for _ in range(100):
    out_fp32 = forward(params_fp32, x_fp32)
    jax.block_until_ready(out_fp32)
time_fp32 = time.time() - start

# Time bfloat16
start = time.time()
for _ in range(100):
    out_bf16 = forward(params_bf16, x_bf16)
    jax.block_until_ready(out_bf16)
time_bf16 = time.time() - start

print(f"\nComputation Time:")
print(f"   float32:  {time_fp32:.4f}s")
print(f"   bfloat16: {time_bf16:.4f}s")
if time_bf16 < time_fp32:
    print(f"   Speedup: {time_fp32/time_bf16:.2f}x faster with bfloat16!")
else:
    print(f"   Note: Speedup varies by hardware (TPU shows 2-4x gains)")

# Accuracy comparison
print("\nNumerical Accuracy:")
loss_fp32 = loss_fn(params_fp32, x_fp32, y_fp32)
loss_bf16 = loss_fn(params_bf16, x_bf16, y_bf16)

# Convert to Python floats to ensure proper formatting
loss_fp32_val = float(loss_fp32)
loss_bf16_val = float(loss_bf16)

print(f"   float32 loss:  {loss_fp32_val:.6f}")
print(f"   bfloat16 loss: {loss_bf16_val:.6f}")
print(f"   Difference:    {abs(loss_fp32_val - loss_bf16_val):.6f}")


Creating models in different precisions...

Memory Usage:
   float32 params: 12.00 MB
   bfloat16 params: 6.00 MB
   Memory saved: 6.00 MB (50.0%)

Speed comparison (100 forward passes)...

Computation Time:
   float32:  0.0564s
   bfloat16: 0.1586s
   Note: Speedup varies by hardware (TPU shows 2-4x gains)

Numerical Accuracy:
   float32 loss:  1.024060
   bfloat16 loss: 1.023438
   Difference:    0.000623


In [None]:
# bfloat is not saving time on GPU's?