# KV Cache Memory Model

This notebook models the memory usage of the KV cache for a decoder-only transformer.

We explore how memory scales with:

1) Number of Layers
2) Batch Size
3) Sequence Length (context length)
4) Hidden Dimension
5) Precision (FP16 vs FP8)

The goal is to show how the KV cache, which grows linearly in both context length and batch size, comes to dominate memory as those two quantities grow.

In [1]:
import numpy as np
import matplotlib.pyplot as plt


In [2]:
def kv_cache_memory_bytes(
    num_layers: int,
    batch_size: int,
    seq_len: int,
    hidden_dim: int,
    bytes_per_element: int = 2,  # FP16 default
) -> int:
    """Approximate KV cache memory in bytes.

    Formula:
        memory = 2 * L * B * T * D * bytes_per_element

    The leading 2 counts the separate K and V tensors. The formula assumes
    every layer caches a full hidden_dim-wide K and V vector per token.

    Where:
        L = num_layers      (number of transformer layers)
        B = batch_size      (number of sequences processed at once)
        T = seq_len         (context window / sequence length)
        D = hidden_dim      (model width)
        bytes_per_element   (1 = FP8, 2 = FP16/BF16, 4 = FP32)
    """
    return 2 * num_layers * batch_size * seq_len * hidden_dim * bytes_per_element
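
As a quick sanity check on the formula, the sketch below restates it in a standalone helper and confirms that memory is linear in each factor: doubling the sequence length doubles the cache, and moving from FP16 to FP8 halves it. The name `kv_bytes` is hypothetical and is used only so the snippet runs on its own.

```python
def kv_bytes(num_layers, batch_size, seq_len, hidden_dim, bytes_per_element=2):
    # Same formula as kv_cache_memory_bytes: 2 (K and V) * L * B * T * D * bytes
    return 2 * num_layers * batch_size * seq_len * hidden_dim * bytes_per_element

# Baseline: 32 layers, batch 4, 2048 tokens, width 4096, FP16
base = kv_bytes(32, 4, 2048, 4096, 2)

# Vary one factor at a time, holding the others fixed
double_ctx = kv_bytes(32, 4, 4096, 4096, 2)  # sequence length doubled
fp8 = kv_bytes(32, 4, 2048, 4096, 1)         # FP16 -> FP8

print(double_ctx / base)  # 2.0
print(fp8 / base)         # 0.5
```

Because the formula is a plain product, the same ratio argument applies to the number of layers, the batch size, and the hidden dimension.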



In [3]:
# Example model parameters (7B-class model)
num_layers = 32       # L
batch_size = 4        # B
seq_len = 2048        # T 
hidden_dim = 4096     # D 
precision_bytes = 2   # FP16

# Calculate raw KV cache memory in bytes
mem_bytes = kv_cache_memory_bytes(
    num_layers,
    batch_size,
    seq_len,
    hidden_dim,
    precision_bytes,
)

mem_bytes


4294967296
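
The raw byte count is easier to read in GiB. The sketch below converts the result above, then sweeps sequence length at FP16 and FP8 with the other example parameters held fixed. The helper names `kv_bytes` and `to_gib` are hypothetical, and 1 GiB is taken to be 2**30 bytes.

```python
GIB = 2**30  # assumption: binary gibibytes, not decimal gigabytes

def kv_bytes(num_layers, batch_size, seq_len, hidden_dim, bytes_per_element=2):
    # Same formula as kv_cache_memory_bytes: 2 (K and V) * L * B * T * D * bytes
    return 2 * num_layers * batch_size * seq_len * hidden_dim * bytes_per_element

def to_gib(num_bytes):
    # Convert a byte count to GiB as a float
    return num_bytes / GIB

print(to_gib(4294967296))  # 4.0

# Sweep context length for the example model (L=32, B=4, D=4096)
sweep = {}
for seq_len in (2048, 8192, 32768):
    fp16 = to_gib(kv_bytes(32, 4, seq_len, 4096, 2))
    fp8 = to_gib(kv_bytes(32, 4, seq_len, 4096, 1))
    sweep[seq_len] = (fp16, fp8)
    print(f"T={seq_len:>6}: FP16 {fp16:6.1f} GiB, FP8 {fp8:6.1f} GiB")
```

For the example configuration, 4294967296 bytes is exactly 4 GiB. Under this formula, the same model and batch size at a 32768-token context would need 64 GiB in FP16 and 32 GiB in FP8.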