In [45]:
import torch
import torch.nn as nn 
from typing import Optional,Tuple,List
import math
from dataclasses import dataclass

In [46]:
@dataclass
class CacheConfig:
    """
    Sizing parameters for a KV cache.

    Attributes:
        max_batch_size: Batch-size setting (default 8).
        max_seq_length: Sequence-length setting (default 2048).
        num_layers: Number of layers (default 12).
        num_heads: Number of attention heads (default 12).
        head_dim: Per-head feature dimension (default 64).

    NOTE(review): no code visible in this notebook reads this config; the
    field meanings above are taken from the field names only.
    """

    max_batch_size: int = 8
    max_seq_length: int = 2048
    num_layers: int = 12
    num_heads: int = 12
    head_dim: int = 64



In [47]:
class KVCache:
    """
    Simple Key-Value cache for a single layer.

    Pre-allocates fixed-size key and value buffers of shape
    [batch_size, num_heads, max_seq_len, head_dim] and appends new entries
    along the sequence axis, so past keys/values need not be recomputed.
    """

    def __init__(self, batch_size: int, max_seq_len: int, num_heads: int, head_dim: int,
                 device='cuda', dtype: torch.dtype = torch.float32):
        """
        Allocate an empty cache.

        Args:
            batch_size: Maximum batch size the cache can hold.
            max_seq_len: Maximum number of positions the cache can hold.
            num_heads: Number of attention heads.
            head_dim: Feature dimension of each head.
            device: Device on which the buffers are allocated.
            dtype: Element type of the buffers (defaults to float32).
        """
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.device = device
        self.dtype = dtype

        # Pre-allocate cache tensors
        cache_shape = (batch_size, num_heads, max_seq_len, head_dim)
        self.key_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
        self.value_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
        self.seq_len = 0  # Number of positions currently stored

    def update(self, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Append new key and value tensors to the cache.

        Args:
            key: New keys [batch_size, num_heads, new_seq_len, head_dim]
            value: New values [batch_size, num_heads, new_seq_len, head_dim]

        Returns:
            Key and value tensors covering every cached position, including
            the ones just added. They are slices (views) of the internal
            buffers, not copies.

        Raises:
            ValueError: If the incoming batch is larger than the cache's
                batch size, or if appending would exceed ``max_seq_len``.
                The cache is left unchanged in both cases.
        """
        # Unpacking also enforces that `key` is 4-dimensional.
        batch_size, _, new_seq_len, _ = key.shape

        if batch_size > self.batch_size:
            raise ValueError(
                f"batch size {batch_size} exceeds cache batch size {self.batch_size}"
            )

        end = self.seq_len + new_seq_len
        # Without this check a single-token write into a full cache targets a
        # zero-length slice; the size-1 source broadcasts into it, so the data
        # is silently dropped while seq_len keeps growing.
        if end > self.max_seq_len:
            raise ValueError(
                f"cache overflow: {self.seq_len} cached + {new_seq_len} new positions "
                f"exceeds max_seq_len={self.max_seq_len}"
            )

        # Store new keys and values in cache
        self.key_cache[:batch_size, :, self.seq_len:end] = key
        self.value_cache[:batch_size, :, self.seq_len:end] = value

        # Update sequence length only after both writes succeeded
        self.seq_len = end

        # Return full cache up to current position
        full_keys = self.key_cache[:batch_size, :, :end]
        full_values = self.value_cache[:batch_size, :, :end]

        return full_keys, full_values

    def reset(self):
        """Reset cache to empty state (length 0, buffers zeroed in place)."""
        self.seq_len = 0
        self.key_cache.zero_()
        self.value_cache.zero_()

    def get_seq_length(self) -> int:
        """Get current sequence length in cache"""
        return self.seq_len

In [48]:
class MultiHeadAttentionWithCache(nn.Module):
    """
    Scaled dot-product multi-head self-attention that can read from and
    append to a KVCache.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.scale = self.head_dim ** -0.5  # 1 / sqrt(head_dim)

        # Query / key / value / output projections, all d_model -> d_model
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

        # Applied to the attention weights after softmax
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected: torch.Tensor, bsz: int, length: int) -> torch.Tensor:
        """[bsz, length, d_model] -> [bsz, num_heads, length, head_dim]."""
        return projected.view(bsz, length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """
        Run attention over ``hidden_states``, optionally extending a KV cache.

        Args:
            hidden_states: Input tensor [batch_size, seq_len, d_model]
            attention_mask: Additive mask, added to the raw scores before
                softmax; documented shape [batch_size, 1, seq_len, cached_seq_len]
            kv_cache: Cache to append this call's keys/values to. When given,
                attention runs over everything the cache returns.
            use_cache: Whether to hand the cache back to the caller

        Returns:
            output: Attention output [batch_size, seq_len, d_model]
            kv_cache: The cache passed in if use_cache=True, otherwise None.
                A provided cache is updated regardless of this flag.
        """
        bsz, q_len, _ = hidden_states.shape

        # Project, then expose the head dimension
        q = self._split_heads(self.q_proj(hidden_states), bsz, q_len)
        k = self._split_heads(self.k_proj(hidden_states), bsz, q_len)
        v = self._split_heads(self.v_proj(hidden_states), bsz, q_len)

        # With a cache, keys/values become the full history plus the new entries
        if kv_cache is not None:
            k, v = kv_cache.update(k, v)

        # Scaled dot-product scores
        scores = (q @ k.transpose(-2, -1)) * self.scale
        if attention_mask is not None:
            scores = scores + attention_mask

        # Normalise over the key axis, then regularise
        probs = self.dropout(torch.softmax(scores, dim=-1))

        # Weighted sum of values
        context = probs @ v

        # [batch_size, num_heads, seq_len, head_dim] -> [batch_size, seq_len, d_model]
        merged = context.transpose(1, 2).contiguous().view(bsz, q_len, self.d_model)

        output = self.out_proj(merged)

        if not use_cache:
            return output, None
        return output, kv_cache



In [49]:
class TransformerLayerWithCache(nn.Module):
    """
    Transformer layer (self-attention + feed-forward, each with a residual
    connection and a LayerNorm applied to the sub-block input) that supports
    a KV cache.
    """

    def __init__(self, d_model: int, num_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()

        self.attention = MultiHeadAttentionWithCache(d_model, num_heads, dropout)
        self.norm1 = nn.LayerNorm(d_model)  # normalises the attention input
        self.norm2 = nn.LayerNorm(d_model)  # normalises the feed-forward input

        # Position-wise feed-forward: expand to d_ff, GELU, project back to d_model
        expand = nn.Linear(d_model, d_ff)
        contract = nn.Linear(d_ff, d_model)
        self.ffn = nn.Sequential(
            expand,
            nn.GELU(),
            nn.Dropout(dropout),
            contract,
            nn.Dropout(dropout),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """
        Apply the attention and feed-forward sub-blocks.

        Args:
            hidden_states: Layer input; passed through norm1 into attention.
            attention_mask: Forwarded unchanged to the attention module.
            kv_cache: Forwarded unchanged to the attention module.
            use_cache: Forwarded unchanged to the attention module.

        Returns:
            The layer output and whatever cache value the attention module
            returned.
        """
        # Attention sub-block with residual connection
        normed_input = self.norm1(hidden_states)
        attn_out, returned_cache = self.attention(
            normed_input,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )
        after_attn = hidden_states + attn_out

        # Feed-forward sub-block with residual connection
        ffn_out = self.ffn(self.norm2(after_attn))
        layer_out = after_attn + ffn_out

        return layer_out, returned_cache

In [50]:
def create_causal_mask(seq_len: int, cached_len: int = 0, device='cuda') -> torch.Tensor:
    """
    Build an additive causal mask for autoregressive attention.

    Query row ``i`` sits at absolute position ``cached_len + i`` and may
    attend to key columns ``0 .. cached_len + i``; later columns get ``-inf``.

    Args:
        seq_len: Number of new (query) positions
        cached_len: Number of positions already held in the cache
        device: Device to create mask on

    Returns:
        Float tensor of shape [1, 1, seq_len, seq_len + cached_len] holding
        0.0 where attention is allowed and -inf where it is blocked.
    """
    total_len = seq_len + cached_len

    # Absolute position of each query row and each key column
    query_pos = torch.arange(seq_len, device=device).unsqueeze(1) + cached_len
    key_pos = torch.arange(total_len, device=device).unsqueeze(0)

    # A key is "in the future" when it lies strictly after the query position
    is_future = key_pos > query_pos

    mask = torch.zeros(seq_len, total_len, device=device).masked_fill(is_future, float('-inf'))

    # Leading singleton axes broadcast over batch and heads
    return mask[None, None]


In [51]:
# Reproducibility: the layer weights, the prompt and the simulated tokens are
# all drawn from torch's global RNG, so seed it before anything random runs.
SEED = 0
torch.manual_seed(SEED)

# Model parameters
batch_size = 2
d_model = 512
num_heads = 8
head_dim = d_model // num_heads  # per-head feature size (64 here)
max_seq_len = 100  # capacity of the KV cache along the sequence axis
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [52]:
# Build the layer, then move its parameters to the selected device.
layer = TransformerLayerWithCache(d_model, num_heads, d_ff=2048)
layer = layer.to(device)


In [53]:
# Evaluation mode: the layer's nn.Dropout modules stop randomly zeroing
# activations, which is what we want for the inference runs below.
layer.eval()

TransformerLayerWithCache(
  (attention): MultiHeadAttentionWithCache(
    (q_proj): Linear(in_features=512, out_features=512, bias=True)
    (k_proj): Linear(in_features=512, out_features=512, bias=True)
    (v_proj): Linear(in_features=512, out_features=512, bias=True)
    (out_proj): Linear(in_features=512, out_features=512, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (ffn): Sequential(
    (0): Linear(in_features=512, out_features=2048, bias=True)
    (1): GELU(approximate='none')
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=2048, out_features=512, bias=True)
    (4): Dropout(p=0.1, inplace=False)
  )
)

In [54]:
# Stand-in for an embedded prompt: random values of shape [batch_size, prompt_len, d_model]
prompt_len = 10
prompt = torch.randn((batch_size, prompt_len, d_model), device=device)

In [55]:
# Show which device was selected for this run
device

'cuda'

In [56]:
# Baseline WITHOUT a KV cache: each step re-runs the layer over the entire
# sequence generated so far, so the per-step cost grows with sequence length.
with torch.no_grad():
    current_seq = prompt.clone()  # grows by one token per step; prompt stays untouched

    for step in range(20):  # simulate generating 20 new tokens
        seq_len = current_seq.size(1)

        # Nothing is cached, so the mask is a square causal mask over the full sequence
        mask = create_causal_mask(seq_len, cached_len=0, device=device)
        output, _ = layer(current_seq, attention_mask=mask, use_cache=False)

        # Simulate getting the next token (random embedding instead of real sampling)
        next_token = torch.randn(batch_size, 1, d_model, device=device)
        current_seq = torch.cat((current_seq, next_token), dim=1)

        if step % 5 == 0:
            print(f"Step {step+1}, Current Seq Length: {current_seq.shape[1]}")

print("Inference completed.")


Step 1, Current Seq Length: 11
Step 6, Current Seq Length: 16
Step 11, Current Seq Length: 21
Step 16, Current Seq Length: 26
Inference completed.


In [57]:
# Same simulation WITH a KV cache: the prompt is processed once, then each
# step feeds only the single newest token through the layer.
with torch.no_grad():
    # Fresh, empty cache for this run
    kv_cache = KVCache(batch_size, max_seq_len, num_heads, head_dim, device)

    # Prefill: run the whole prompt once so its keys/values land in the cache
    mask = create_causal_mask(prompt_len, 0, device)
    output, kv_cache = layer(
        prompt,
        attention_mask=mask,
        kv_cache=kv_cache,
        use_cache=True,
    )
    print(f"  Processed prompt: {prompt_len} tokens, cache length: {kv_cache.get_seq_length()}")

    # Simulated first generated token
    current_token = torch.randn(batch_size, 1, d_model, device=device)

    for step in range(20):
        # One query position attending to everything cached plus itself
        cached_len = kv_cache.get_seq_length()
        mask = create_causal_mask(1, cached_len, device)

        output, kv_cache = layer(
            current_token,
            attention_mask=mask,
            kv_cache=kv_cache,
            use_cache=True,
        )

        # Simulate getting the next token
        current_token = torch.randn(batch_size, 1, d_model, device=device)

        if step % 5 == 0:
            print(f"  Step {step}: Processing 1 token, cache length: {kv_cache.get_seq_length()}")

print(f"Final cache length: {kv_cache.get_seq_length()}")

  Processed prompt: 10 tokens, cache length: 10
  Step 0: Processing 1 token, cache length: 11
  Step 5: Processing 1 token, cache length: 16
  Step 10: Processing 1 token, cache length: 21
  Step 15: Processing 1 token, cache length: 26
Final cache length: 30
