# Check GPU availability

In [1]:

import torch

# Report the installed PyTorch build and whether a CUDA device is usable.
print("PyTorch version:", torch.__version__)

_has_cuda = torch.cuda.is_available()
print("CUDA available:", _has_cuda)

if not _has_cuda:
    print("No GPU found - will use CPU")
else:
    # Device index 0 is the first visible GPU; total_memory is in bytes, shown in decimal GB.
    print("GPU name:", torch.cuda.get_device_name(0))
    _total_bytes = torch.cuda.get_device_properties(0).total_memory
    print("GPU memory:", _total_bytes / 1e9, "GB")

PyTorch version: 2.8.0+cu128
CUDA available: True
GPU name: NVIDIA RTX 4000 Ada Generation
GPU memory: 21.01805056 GB


# Imports

In [5]:
import torch
import torch.nn as nn
import math

# Raw Implementation of Multi Head Attention



In [15]:
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    """
    Standard multi-head self-attention.

    The input is projected to queries, keys and values of width ``d_model``,
    split into ``num_heads`` heads of width ``head_dim = d_model // num_heads``,
    attended per head with scaled dot-product attention, then merged and passed
    through a final linear projection.

    Args:
        d_model: Width of the input and output features; must be divisible by num_heads.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()  # Registers this object with nn.Module so sub-layers are tracked
        assert d_model % num_heads == 0

        # Kept on the instance because forward() needs them for reshaping
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads

        # Each nn.Linear computes y = x @ W^T + b with a (d_model x d_model) weight,
        # initialised by PyTorch's default nn.Linear scheme.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

        # Mixes the concatenated head outputs back into a single d_model vector
        self.out_proj = nn.Linear(d_model, d_model)

        # Randomly zeroes attention weights during training to reduce overfitting
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape (batch, seq_len, d_model) into (batch, num_heads, seq_len, head_dim)."""
        per_head = projected.view(batch_size, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """
        Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_len, d_model).

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape

        # Q: what am I looking for?  K: what do I contain?  V: what information do I carry?
        queries = self._split_heads(self.q_proj(x), batch_size, seq_len)
        keys = self._split_heads(self.k_proj(x), batch_size, seq_len)
        values = self._split_heads(self.v_proj(x), batch_size, seq_len)

        # logits[..., i, j] says how strongly token i attends to token j.
        # Dividing by sqrt(head_dim) keeps the softmax from saturating.
        logits = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Normalise over the key axis, then apply dropout to the resulting weights
        weights = self.dropout(logits.softmax(dim=-1))

        # (batch, heads, seq, seq) @ (batch, heads, seq, head_dim) -> (batch, heads, seq, head_dim)
        context = weights @ values

        # transpose() returns a non-contiguous view, so contiguous() is required before view()
        merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        return self.out_proj(merged)


# Smoke test: run one forward pass and confirm the output shape matches the input
if __name__ == "__main__":
    # 512-wide model split across 8 heads (64 dims per head)
    model = MultiHeadAttention(d_model=512, num_heads=8, dropout=0.1)

    # Random batch of 32 sequences, 10 tokens each
    batch_size, seq_len = 32, 10
    x = torch.randn(batch_size, seq_len, 512)

    # Calling the module invokes forward()
    output = model(x)

    print(f"Input shape:  {x.shape}")       # [32, 10, 512]
    print(f"Output shape: {output.shape}")   # [32, 10, 512]
    print("Multi-head attention works!")

Input shape:  torch.Size([32, 10, 512])
Output shape: torch.Size([32, 10, 512])
Multi-head attention works!


# Raw Implementation for GQA

In [16]:
import torch
import torch.nn as nn
import math

class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention (GQA).

    Queries use ``num_heads`` heads, while keys and values use only
    ``num_kv_heads`` heads. Each K/V head is shared by
    ``group_size = num_heads // num_kv_heads`` consecutive query heads.

    Args:
        d_model: Width of the input and output features; must be divisible by num_heads.
        num_heads: Number of query heads; must be divisible by num_kv_heads.
        num_kv_heads: Number of key/value heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, num_kv_heads, dropout=0.1):
        super().__init__()  # Registers this object with nn.Module so sub-layers are tracked
        assert d_model % num_heads == 0
        assert num_heads % num_kv_heads == 0

        # Kept on the instance because forward() needs them for reshaping
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = d_model // num_heads
        self.group_size = num_heads // num_kv_heads  # Query heads served by one K/V head

        # Q keeps the full d_model width; K and V only need num_kv_heads * head_dim outputs.
        # Each nn.Linear computes y = x @ W^T + b with PyTorch's default initialisation.
        kv_width = num_kv_heads * self.head_dim
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, kv_width)
        self.v_proj = nn.Linear(d_model, kv_width)

        # Mixes the concatenated head outputs back into a single d_model vector
        self.out_proj = nn.Linear(d_model, d_model)

        # Randomly zeroes attention weights during training to reduce overfitting
        self.dropout = nn.Dropout(dropout)

    def _to_heads(self, projected, batch_size, seq_len, n_heads):
        """Reshape (batch, seq_len, n_heads * head_dim) into (batch, n_heads, seq_len, head_dim)."""
        return projected.view(batch_size, seq_len, n_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """
        Apply grouped-query self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_len, d_model).

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape

        # Q: what am I looking for?  K: what do I contain?  V: what information do I carry?
        queries = self._to_heads(self.q_proj(x), batch_size, seq_len, self.num_heads)
        keys = self._to_heads(self.k_proj(x), batch_size, seq_len, self.num_kv_heads)
        values = self._to_heads(self.v_proj(x), batch_size, seq_len, self.num_kv_heads)

        # Repeat every K/V head group_size times along the head axis so that
        # query head h is paired with K/V head h // group_size.
        keys = keys.repeat_interleave(self.group_size, dim=1)      # (batch, num_heads, seq_len, head_dim)
        values = values.repeat_interleave(self.group_size, dim=1)  # (batch, num_heads, seq_len, head_dim)

        # logits[..., i, j] says how strongly token i attends to token j.
        # Dividing by sqrt(head_dim) keeps the softmax from saturating.
        logits = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Normalise over the key axis, then apply dropout to the resulting weights
        weights = self.dropout(logits.softmax(dim=-1))

        # (batch, heads, seq, seq) @ (batch, heads, seq, head_dim) -> (batch, heads, seq, head_dim)
        context = weights @ values

        # transpose() returns a non-contiguous view, so contiguous() is required before view()
        merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        return self.out_proj(merged)


# Smoke test: run one forward pass and confirm the output shape matches the input
if __name__ == "__main__":
    # 8 query heads over 2 K/V heads, so each K/V head serves 4 query heads
    model = GroupedQueryAttention(d_model=512, num_heads=8, num_kv_heads=2, dropout=0.1)

    # Random batch of 32 sequences, 10 tokens each
    batch_size, seq_len = 32, 10
    x = torch.randn(batch_size, seq_len, 512)

    # Calling the module invokes forward()
    output = model(x)

    print(f"Input shape:  {x.shape}")       # [32, 10, 512]
    print(f"Output shape: {output.shape}")   # [32, 10, 512]
    print("✅ Grouped Query Attention works!")

Input shape:  torch.Size([32, 10, 512])
Output shape: torch.Size([32, 10, 512])
✅ Grouped Query Attention works!


# MLA from scratch

In [17]:
import torch
import torch.nn as nn
import math

class MultiHeadLatentAttention(nn.Module):
    """
    Multi-Head Latent Attention implemented from scratch.

    Keys and values are not projected directly from the input. The input is
    first compressed to a shared latent of width ``d_latent``, and that latent
    is then expanded back to ``d_model`` separately for K and for V. Queries
    are projected from the input without any compression.

    Args:
        d_model: Width of the input and output features; must be divisible by num_heads.
        num_heads: Number of attention heads.
        d_latent: Width of the compressed key/value latent.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, d_latent, dropout=0.1):
        super().__init__()  # Registers this object with nn.Module so sub-layers are tracked
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_latent = d_latent
        self.head_dim = d_model // num_heads

        # Shared down-projection: d_model -> d_latent
        self.kv_compress = nn.Linear(d_model, d_latent)

        # Up-projections from the latent back to full width for all heads
        self.k_proj = nn.Linear(d_latent, d_model)
        self.v_proj = nn.Linear(d_latent, d_model)

        # Queries are not compressed
        self.q_proj = nn.Linear(d_model, d_model)

        # Mixes the concatenated head outputs back into a single d_model vector
        self.out_proj = nn.Linear(d_model, d_model)

        # Randomly zeroes attention weights during training to reduce overfitting
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, batch_size, seq_len):
        """Reshape (batch, seq_len, d_model) into (batch, num_heads, seq_len, head_dim)."""
        per_head = projected.view(batch_size, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """
        Apply latent-compressed self-attention to ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_len, d_model).

        Returns:
            Tensor of shape (batch_size, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape

        # Queries come straight from the input (a projection is just x @ W^T + b)
        queries = self.q_proj(x)

        # The key step: squeeze the input into the smaller latent, then expand it to K and V
        latent = self.kv_compress(x)
        keys = self.k_proj(latent)
        values = self.v_proj(latent)

        # Put the head axis first so every head attends independently and in parallel
        queries = self._split_heads(queries, batch_size, seq_len)
        keys = self._split_heads(keys, batch_size, seq_len)
        values = self._split_heads(values, batch_size, seq_len)

        # Scaled dot-product scores, normalised over the key axis, then dropout
        logits = (queries @ keys.transpose(-2, -1)) / math.sqrt(self.head_dim)
        weights = self.dropout(logits.softmax(dim=-1))

        # Weighted sum of values per head
        context = weights @ values

        # Undo the head-first layout and concatenate heads; contiguous() is needed before view()
        merged = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)

        return self.out_proj(merged)


# Smoke test: run one forward pass and confirm the output shape matches the input
if __name__ == "__main__":
    # K/V path goes 512 -> 128 (latent) -> 512
    model = MultiHeadLatentAttention(d_model=512, num_heads=8, d_latent=128, dropout=0.1)

    # Random batch of 32 sequences, 10 tokens each
    batch_size, seq_len = 32, 10
    x = torch.randn(batch_size, seq_len, 512)

    # Calling the module invokes forward()
    output = model(x)

    print(f"Input shape:         {x.shape}")       # [32, 10, 512]
    print(f"Output shape:        {output.shape}")   # [32, 10, 512]
    


Input shape:         torch.Size([32, 10, 512])
Output shape:        torch.Size([32, 10, 512])


# Benchmarking

In [18]:
import time
import torch

# Release cached but currently unused GPU memory held by PyTorch's allocator,
# so earlier cells leave as little as possible behind before benchmarking starts.
torch.cuda.empty_cache()

In [19]:
# Benchmark configuration shared by all three attention variants (MHA, GQA, MLA)
# Every model is built from these values so the comparison is like-for-like.
batch_size = 32
seq_len = 512  # Much longer than the 10-token smoke tests so differences between variants are visible
d_model = 512
num_heads = 8
num_kv_heads = 2  # GQA only: number of K/V heads (num_heads // num_kv_heads = 4 query heads per K/V head)
d_latent = 128    # MLA only: width of the compressed K/V latent (d_model -> d_latent -> d_model)
dropout = 0.0     # Disabled so dropout adds neither work nor randomness to the measurements

In [20]:
# Pick the GPU when one is usable, otherwise fall back to the CPU,
# and allocate the benchmark input directly on that device.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

x = torch.randn(batch_size, seq_len, d_model, device=device)

print(f"Input shape: {x.shape}")
print(f"Device: {device}")

Input shape: torch.Size([32, 512, 512])
Device: cuda


# Helper function that measures peak GPU memory

In [21]:
def measure_memory(model, x, device):
    """
    Measure the peak GPU memory allocated while running one forward pass.

    The peak counter is reset immediately before the forward pass, so the
    reported figure is the high-water mark of *all* tensors allocated on
    ``device`` during the call. That includes memory that was already resident
    (this model's weights, ``x``, and anything else still alive on the device,
    such as other models) plus the activations created by the forward pass.
    Gradients are not disabled here, so activations kept for autograd count too.

    Args:
        model: Module to run; expected to already live on ``device``.
        x: Input tensor; moved to ``device`` if it is not there already.
        device: Target device (``torch.device`` or a device string).

    Returns:
        Peak allocated memory in MiB as a float, or ``nan`` when ``device`` is
        not a CUDA device (CUDA memory statistics do not exist on the CPU).
    """
    device = torch.device(device)

    # The notebook falls back to the CPU when no GPU is present. The CUDA
    # statistics calls below would raise there, so report "not available".
    if device.type != "cuda":
        return float("nan")

    # Ensure input is on the correct device
    x = x.to(device)

    # Hand cached-but-unused blocks back to the driver, then restart peak tracking
    # so the high-water mark reflects only what is live from this point on.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)

    # Run forward pass
    _ = model(x)

    # Peak is reported in bytes; convert to MiB
    return torch.cuda.max_memory_allocated(device) / (1024**2)

In [22]:
def measure_speed(model, x, num_iterations=100):
    """
    Measure the average wall-clock time of one forward pass.

    The model is switched to evaluation mode (and left that way), warmed up
    with 10 untimed passes, and then timed over ``num_iterations`` passes.
    On a CUDA device the GPU is synchronised before starting and before
    stopping the clock, because kernel launches return before the work is done.
    On the CPU no synchronisation is needed, so the function also works
    without a GPU.

    Args:
        model: Module to benchmark.
        x: Input tensor; moved to the device of the model's first parameter.
        num_iterations: Number of timed forward passes (must be > 0).

    Returns:
        Average time per forward pass in milliseconds.
    """
    # Run on whatever device the model lives on; a parameter-free model keeps x where it is
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else x.device
    x = x.to(run_device)
    on_cuda = run_device.type == "cuda"

    model.eval()  # Set to evaluation mode

    # Warmup runs so one-off costs (lazy initialisation, kernel selection) are not timed
    for _ in range(10):
        model(x)

    if on_cuda:
        torch.cuda.synchronize(run_device)  # Wait for the warmup work to finish

    # perf_counter is a monotonic, high-resolution clock intended for interval timing
    start_time = time.perf_counter()
    for _ in range(num_iterations):
        model(x)
    if on_cuda:
        torch.cuda.synchronize(run_device)  # Wait for all timed work to finish
    elapsed = time.perf_counter() - start_time

    avg_time = elapsed / num_iterations
    return avg_time * 1000  # Convert to milliseconds

In [23]:
def count_parameters(model):
    """
    Count the trainable parameters of ``model``.

    Only parameters with ``requires_grad=True`` are included.

    Returns:
        The number of trainable scalar parameters, expressed in millions.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable / 1e6  # Scale to millions

In [24]:
# Build one instance of each attention variant from the shared benchmark
# configuration and move it to the benchmark device.
mha_model = MultiHeadAttention(d_model, num_heads, dropout).to(device)
gqa_model = GroupedQueryAttention(d_model, num_heads, num_kv_heads, dropout).to(device)
mla_model = MultiHeadLatentAttention(d_model, num_heads, d_latent, dropout).to(device)

print(f"✅ All models on {device}")
# Spot-check placement by looking at the device of the first MHA parameter
_first_mha_param = next(mha_model.parameters())
print(f"MHA first layer device: {_first_mha_param.device}")

✅ All models on cuda
MHA first layer device: cuda:0


In [25]:
print("Starting Benchmarks...\n")

# For each variant collect, in this order: peak memory (MB), average forward
# time (ms), and trainable parameter count (millions).

# Multi-Head Attention
print("Benchmarking Multi-Head Attention...")
mha_memory, mha_speed, mha_params = (
    measure_memory(mha_model, x, device),
    measure_speed(mha_model, x),
    count_parameters(mha_model),
)

# Grouped Query Attention
print("Benchmarking Grouped Query Attention...")
gqa_memory, gqa_speed, gqa_params = (
    measure_memory(gqa_model, x, device),
    measure_speed(gqa_model, x),
    count_parameters(gqa_model),
)

# Multi-Head Latent Attention
print("Benchmarking Multi-Head Latent Attention...")
mla_memory, mla_speed, mla_params = (
    measure_memory(mla_model, x, device),
    measure_speed(mla_model, x),
    count_parameters(mla_model),
)
print("Benchmarks complete!\n")

Starting Benchmarks...

Benchmarking Multi-Head Attention...
Benchmarking Grouped Query Attention...
Benchmarking Multi-Head Latent Attention...
Benchmarks complete!



# Results

In [26]:
# Print benchmark results
# The label column must fit the longest model name, 'Multi-Head Latent Attention (MLA)'
# (33 characters). A narrower column pushes that row's numbers out of alignment.
_LABEL_W = 36
_TABLE_W = _LABEL_W + 3 * 16  # label column + three 15-wide value columns, each preceded by a space


def _percent_saved(baseline, value):
    """Return how much smaller ``value`` is than ``baseline``, as a percentage of ``baseline``."""
    if baseline == 0:
        return float("nan")  # A relative saving is undefined against a zero baseline
    return ((baseline - value) / baseline) * 100


print("=" * _TABLE_W)
print("BENCHMARK RESULTS")
print("=" * _TABLE_W)
print("\nConfiguration:")
print(f"  Batch size: {batch_size}")
print(f"  Sequence length: {seq_len}")
print(f"  Model dimension: {d_model}")
print(f"  Number of heads: {num_heads}")
print(f"  Number of KV heads (GQA): {num_kv_heads}")
print(f"  Latent dimension (MLA): {d_latent}")
print(f"\n{'-' * _TABLE_W}")
print(f"{'Model':<{_LABEL_W}} {'Memory (MB)':<15} {'Speed (ms)':<15} {'Parameters (M)':<15}")
print(f"{'-' * _TABLE_W}")
for _label, _memory, _speed, _params in (
    ("Multi-Head Attention (MHA)", mha_memory, mha_speed, mha_params),
    ("Grouped Query Attention (GQA)", gqa_memory, gqa_speed, gqa_params),
    ("Multi-Head Latent Attention (MLA)", mla_memory, mla_speed, mla_params),
):
    print(f"{_label:<{_LABEL_W}} {_memory:<15.2f} {_speed:<15.4f} {_params:<15.3f}")
print(f"{'-' * _TABLE_W}")

# Calculate savings relative to MHA (positive = smaller/faster than MHA, negative = worse)
memory_savings_gqa = _percent_saved(mha_memory, gqa_memory)
memory_savings_mla = _percent_saved(mha_memory, mla_memory)
speed_savings_gqa = _percent_saved(mha_speed, gqa_speed)
speed_savings_mla = _percent_saved(mha_speed, mla_speed)
param_savings_gqa = _percent_saved(mha_params, gqa_params)
param_savings_mla = _percent_saved(mha_params, mla_params)

print(f"\n{'Comparison vs MHA:':<{_LABEL_W}}")
print(f"{'-' * _TABLE_W}")
print(f"{'Metric':<{_LABEL_W}} {'GQA Savings':<25} {'MLA Savings':<25}")
print(f"{'-' * _TABLE_W}")
for _metric, _gqa_pct, _mla_pct in (
    ("Memory", memory_savings_gqa, memory_savings_mla),
    ("Speed", speed_savings_gqa, speed_savings_mla),
    ("Parameters", param_savings_gqa, param_savings_mla),
):
    # 6-wide number + '%' + 18 spaces keeps each value under its 25-wide column header
    print(f"{_metric:<{_LABEL_W}} {_gqa_pct:>6.2f}%{'':<18} {_mla_pct:>6.2f}%")
print("=" * _TABLE_W)

BENCHMARK RESULTS

Configuration:
  Batch size: 32
  Sequence length: 512
  Model dimension: 512
  Number of heads: 8
  Number of KV heads (GQA): 2
  Latent dimension (MLA): 128

--------------------------------------------------------------------------------
Model                          Memory (MB)     Speed (ms)      Parameters (M) 
--------------------------------------------------------------------------------
Multi-Head Attention (MHA)     826.41          9.1889          1.051          
Grouped Query Attention (GQA)  762.41          7.9750          0.657          
Multi-Head Latent Attention (MLA) 834.41          8.3616          0.723          
--------------------------------------------------------------------------------

Comparison vs MHA:            
--------------------------------------------------------------------------------
Metric                         GQA Savings               MLA Savings              
---------------------------------------------------------------

# Llama 2 Attention Config

In [27]:
def create_llama_attention(model_size='70B'):
    """
    Build a GroupedQueryAttention layer with Llama 2 attention dimensions.

    Only the 70B configuration is available:
    - d_model = 8192
    - num_heads = 64
    - num_kv_heads = 8 (GQA, 8 query heads per K/V head)
    - head_dim = 128 (d_model / num_heads)
    - dropout = 0.0

    Args:
        model_size: Name of the Llama 2 size; only '70B' is accepted.

    Returns:
        A freshly constructed GroupedQueryAttention module.

    Raises:
        ValueError: If ``model_size`` is anything other than '70B'.
    """
    # Reject unsupported sizes up front
    if model_size != '70B':
        raise ValueError(f"Model size {model_size} not supported yet")

    return GroupedQueryAttention(
        d_model=8192,
        num_heads=64,
        num_kv_heads=8,
        dropout=0.0,
    )

In [28]:
# Smoke test: one forward pass through a Llama 2 70B-sized attention layer
print("Creating Llama 2 70B attention...")
llama_attn = create_llama_attention('70B').to(device)

# A single sequence at the full 4096-token context with 8192 features per token
x_llama = torch.randn(1, 4096, 8192, device=device)
print(f"Input shape: {x_llama.shape}")

# The output must have the same shape as the input
output_llama = llama_attn(x_llama)
print(f"Output shape: {output_llama.shape}")
print("✅ Llama attention works!")

Creating Llama 2 70B attention...
Input shape: torch.Size([1, 4096, 8192])
Output shape: torch.Size([1, 4096, 8192])
✅ Llama attention works!


In [29]:
def calculate_kv_cache(model, seq_len, dtype_bytes=2):
    """
    Calculate the KV cache size of one attention layer for one sequence.

    The cache holds one key vector and one value vector of width ``head_dim``
    for every token and every K/V head, so its size is
    ``2 * seq_len * num_kv_heads * head_dim * dtype_bytes``.

    Args:
        model: Attention module exposing ``head_dim`` and either ``num_kv_heads``
            (GQA) or ``num_heads`` (MHA, where every head has its own K/V).
        seq_len: Number of cached tokens.
        dtype_bytes: Bytes per element: 2 for FP16, 4 for FP32.

    Returns:
        Memory in bytes.
    """
    # GQA models declare num_kv_heads; plain MHA models do not, and there
    # each of the num_heads heads keeps its own keys and values.
    num_kv_heads = getattr(model, "num_kv_heads", None)
    if num_kv_heads is None:
        num_kv_heads = model.num_heads
    head_dim = model.head_dim

    # Elements in one cache tensor of shape [seq_len, num_kv_heads, head_dim]
    elements_per_cache = seq_len * num_kv_heads * head_dim

    # Two such tensors are stored: one for K and one for V
    total_elements = 2 * elements_per_cache

    # Convert the element count to bytes
    return total_elements * dtype_bytes

# KV cache for the Llama 2 70B attention layer: FP16 elements at a 4096-token
# context, first for a single layer and then scaled to 80 layers.
kv_cache_per_layer = calculate_kv_cache(llama_attn, seq_len=4096, dtype_bytes=2)
total_cache_80_layers = 80 * kv_cache_per_layer

print("\nLlama 2 70B KV Cache Analysis:")
# Decimal units: 1 MB = 1e6 bytes, 1 GB = 1e9 bytes
_per_layer_mb = kv_cache_per_layer / 1e6
_total_gb = total_cache_80_layers / 1e9
print(f"Per layer (4K context): {_per_layer_mb:.2f} MB")
print(f"Total 80 layers (4K context): {_total_gb:.2f} GB")


Llama 2 70B KV Cache Analysis:
Per layer (4K context): 16.78 MB
Total 80 layers (4K context): 1.34 GB
