# Estimate VRAM usage for TGI

In [None]:
!pip install transformers

In [1]:
from transformers import AutoConfig
from huggingface_hub import get_safetensors_metadata

  from .autonotebook import tqdm as notebook_tqdm


### Utils

In [2]:
bytes_per_dtype = {"int4": 0.5, "int8": 1, "float8": 1, "float16": 2, "float32": 4}


def bytes_to_gb(bytes: int):
    """
    Calculate the memory in GB for a given number of bytes.

    Note:
        - 1 GB = 10^9 bytes and 1 GiB = 2^30 bytes
        - We estimate VRAM usage in GB for more conservative estimates
    """
    return round((bytes) / 10**9, 5)

### CUDA Overhead

In [32]:
def get_cuda_overhead(buffer: float = 1.2):
    """
    When PyTorch uses CUDA for the first time, it may use up 0.5-2GB of GPU memory, reducing the GPU's total available memory.

    Here, we assume 1GB of overhead based on torch>=2.1.1 as discussed here:
    https://github.com/stas00/ml-engineering/tree/master/training/performance#additional-gpu-memory-usage
    """
    return 1 * 10**9 * buffer

### Model Overhead

In [4]:
def get_model_overhead(model_id: str, dtype: str = "float16") -> int:
    """
    Get the model size in GB for a given model ID and data type.
    """
    metadata = get_safetensors_metadata(model_id)

    if hasattr(metadata, "metadata"):
        if metadata.metadata.get("total_size", False):
            return metadata.metadata["total_size"]
    else:
        num_params = list(metadata.parameter_count.values())[0]
        num_params = int(num_params)
        bytes = num_params * bytes_per_dtype[dtype]
        return bytes

### Prefill Overhead

In [50]:
import math


def calculate_mlp_overhead(
    seq_len: int,
    batch_size: int,
    hidden_size: int,
    intermediate_size: int,
    dtype: str = "float16",
):
    """
    Calculate the prefill overhead for single MLP block/layer.

    Accounts for following activations:
        - Input to the first Linear layer
        - Input to the activation function
        - Input to the second Linear layer
        - Dropout mask
    """

    bytes_per_unit = bytes_per_dtype[dtype]

    mlp_input = batch_size * seq_len * hidden_size * bytes_per_unit
    act_input = batch_size * seq_len * intermediate_size * bytes_per_unit
    down_proj_input = batch_size * seq_len * intermediate_size * bytes_per_unit
    dropout_mask = (
        batch_size * seq_len * hidden_size * bytes_per_dtype["int8"]
    )  # binary mask

    mlp_block_bytes = mlp_input + act_input + down_proj_input + dropout_mask

    return mlp_block_bytes


def calculate_attention_overhead(
    seq_len: int,
    batch_size: int,
    hidden_size: int,
    head_dim: int,
    num_attention_heads: int,
    num_key_value_heads: int,
    dtype: str = "float16",
):
    """
    Calculate the prefill overhead for single attention block/layer.
    """
    bytes_per_unit = bytes_per_dtype[dtype]

    attention_input = batch_size * seq_len * hidden_size * bytes_per_unit
    q = batch_size * seq_len * head_dim * num_attention_heads * bytes_per_unit
    k = batch_size * seq_len * head_dim * num_key_value_heads * bytes_per_unit
    softmax_output = batch_size * num_attention_heads * seq_len**2 * bytes_per_unit
    softmax_dropout_mask = (
        batch_size * num_attention_heads * seq_len**2 * bytes_per_dtype["int8"]
    )
    dropout_output = batch_size * num_attention_heads * seq_len**2 * bytes_per_unit
    v = batch_size * seq_len * head_dim * num_key_value_heads * bytes_per_unit
    out_proj_input = (
        batch_size * seq_len * num_attention_heads * head_dim * bytes_per_unit
    )
    attention_dropout_mask = (
        batch_size * seq_len * hidden_size * bytes_per_dtype["int8"]
    )

    attention_block_bytes = (
        attention_input
        + q
        + k
        + softmax_output
        + softmax_dropout_mask
        + dropout_output
        + v
        + out_proj_input
        + attention_dropout_mask
    )

    return attention_block_bytes


def calculate_flash_attention_overhead(
    seq_len: int,
    batch_size: int,
    hidden_size: int,
    head_dim: int,
    num_attention_heads: int,
    num_key_value_heads: int,
    dtype: str = "float16",
):
    """
    Calculate the prefill overhead for single attention block/layer.
    """
    bytes_per_unit = bytes_per_dtype[dtype]

    attention_input = batch_size * seq_len * hidden_size * bytes_per_unit
    q = batch_size * seq_len * head_dim * num_attention_heads * bytes_per_unit
    k = batch_size * seq_len * head_dim * num_key_value_heads * bytes_per_unit
    v = batch_size * seq_len * head_dim * num_key_value_heads * bytes_per_unit

    softmax_output = batch_size * num_attention_heads * seq_len * bytes_per_unit
    out_proj_input = (
        batch_size * seq_len * num_attention_heads * head_dim * bytes_per_unit
    )

    # not necessary, but conservative
    softmax_dropout_mask = (
        batch_size * num_attention_heads * seq_len * bytes_per_dtype["int8"]
    )
    dropout_output = batch_size * num_attention_heads * seq_len * bytes_per_unit

    attention_dropout_mask = (
        batch_size * seq_len * hidden_size * bytes_per_dtype["int8"]
    )

    attention_block_bytes = (
        attention_input
        + q
        + k
        + v
        + softmax_output
        + out_proj_input
        # not necessary, but conservative
        + softmax_dropout_mask
        + dropout_output
        + attention_dropout_mask
    )

    return attention_block_bytes


def calculate_layernorm_overhead(
    seq_len: int,
    batch_size: int,
    hidden_size: int,
    dtype: str = "float16",
):
    """
    Calculate the prefill overhead for single layer norm layer.
    """
    bytes_per_unit = bytes_per_dtype[dtype]
    layernorm_bytes = batch_size * seq_len * hidden_size * bytes_per_unit

    return layernorm_bytes


def get_forward_pass_overhead(
    config: AutoConfig,
    max_batch_prefill_tokens: int,
    max_input_tokens: int,
    dtype: str = "float16",
    is_decode: bool = False,
):
    """
    Calculate the prefill overhead for a given model and TGI config.

    Size of a biggest tensor within forward pass.
    It is estimated as the sum of all intermediate tensors within computation of a single layer.
    Activations size have quadratic dependence on Sequence Length.

    Reference:
        - https://arxiv.org/pdf/2205.05198
        - https://asmirnov.xyz/vram#fn1
    """
    batch_size = math.ceil(max_batch_prefill_tokens / max_input_tokens)
    head_dim = (
        config.head_dim
        if hasattr(config, "head_dim")
        else config.hidden_size // config.num_key_value_heads
    )

    mlp_overhead = calculate_mlp_overhead(
        seq_len=(max_input_tokens if not is_decode else 1),
        batch_size=batch_size,
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
        dtype=dtype,
    )

    attention_overhead = calculate_flash_attention_overhead(
        seq_len=(max_input_tokens if not is_decode else 1),
        batch_size=batch_size,
        hidden_size=config.hidden_size,
        head_dim=head_dim,
        num_attention_heads=config.num_attention_heads,
        num_key_value_heads=config.num_key_value_heads,
        dtype=dtype,
    )

    layernorm_overhead = calculate_layernorm_overhead(
        seq_len=(max_input_tokens if not is_decode else 1),
        batch_size=batch_size,
        hidden_size=config.hidden_size,
        dtype=dtype,
    )

    return mlp_overhead + attention_overhead + (layernorm_overhead * 2)

### KV Cache Overhead

In [51]:
def caclulate_kv_cache_memory_per_token(config: AutoConfig, dtype: str):
    """Calculates the memory required for the key-value cache per token in a Large Language Model (LLM)."""
    dtype_bytes = bytes_per_dtype[dtype]
    head_dim = (
        config.head_dim
        if hasattr(config, "head_dim")
        else config.hidden_size // config.num_key_value_heads
    )

    bytes_per_token = (
        2  # k & v
        * config.num_hidden_layers
        * config.num_key_value_heads
        * head_dim
        * dtype_bytes
    )
    return bytes_per_token

### Estimate VRAM Overhead (mimic TGI logic)

In [53]:
def validate_vram_overhead(
    model_id: str,
    num_gpus: int,
    vram_per_gpu: int,  # in GB
    max_input_tokens: int,
    max_total_tokens: int,
    max_batch_prefill_tokens: int,
    cuda_memory_fraction: float = 1.0,
    dtype: str = "float16",
):
    """
    Validate the VRAM overhead for a given model and TGI config.
    """
    config = AutoConfig.from_pretrained(model_id)

    # determine overhead
    cuda_overhead = get_cuda_overhead() * num_gpus
    model_overhead = get_model_overhead(model_id, dtype)
    prefill_overhead = get_forward_pass_overhead(
        config, max_batch_prefill_tokens, max_input_tokens, dtype
    )
    decode_overhead = get_forward_pass_overhead(
        config, max_batch_prefill_tokens, max_input_tokens, dtype, is_decode=True
    )
    kv_cache_overhead_per_token = caclulate_kv_cache_memory_per_token(config, dtype)

    # calculate free vram for kv cache
    free_vram = vram_per_gpu * num_gpus * 10**9
    free_vram = (
        free_vram - cuda_overhead - model_overhead - prefill_overhead - decode_overhead
    )
    free_vram = (
        free_vram * cuda_memory_fraction * 0.95
    )  # by default, TGI allocates 95% of free VRAM to kv-cache

    print(f"Model Overhead: {bytes_to_gb(model_overhead)} GB")
    print(f"CUDA Overhead: {bytes_to_gb(cuda_overhead)} GB")
    print(f"Prefill Overhead: {bytes_to_gb(prefill_overhead)} GB")
    print(f"Free VRAM for KV Cache: {bytes_to_gb(free_vram)} GB")

    # calculate token budget available for full kv cache
    kv_cache_token_budget = int((free_vram) // kv_cache_overhead_per_token)

    # check if we have enough vram for full kv cache (num_batches * max_total_tokens)
    batch_size = math.ceil(max_batch_prefill_tokens / max_input_tokens)
    kv_cache_tokens_needed = batch_size * (max_total_tokens - 1)
    vram_needed_for_kv_cache = kv_cache_tokens_needed * kv_cache_overhead_per_token
    is_enough_vram = kv_cache_token_budget >= kv_cache_tokens_needed

    result = {
        "cuda_overhead": bytes_to_gb(cuda_overhead),
        "model_overhead": bytes_to_gb(model_overhead),
        "prefill_overhead": bytes_to_gb(prefill_overhead),
        "kv_cache_overhead_per_token": bytes_to_gb(kv_cache_overhead_per_token),
        "free_vram_for_kv_cache": bytes_to_gb(free_vram),
        "vram_needed_for_kv_cache": bytes_to_gb(vram_needed_for_kv_cache),
        "kv_cache_token_budget": kv_cache_token_budget,
        "kv_cache_tokens_needed": kv_cache_tokens_needed,
        "is_enough_vram": is_enough_vram,
    }

    return result

### Estimate VRAM Overhead (Simplified)

In [54]:
def validate_vram_overhead(
    model_id: str,
    num_gpus: int,
    vram_per_gpu: int,  # in GB
    max_input_tokens: int,
    max_total_tokens: int,
    max_batch_prefill_tokens: int,
    cuda_memory_fraction: float = 1.0,
    dtype: str = "float16",
):
    """
    Validate the VRAM overhead for a given model and TGI config.
    """
    config = AutoConfig.from_pretrained(model_id)

    # calculate total vram
    total_vram_available = (
        (vram_per_gpu * num_gpus * 10**9) * cuda_memory_fraction * 0.95
    )  # 95% of free VRAM

    # determine overhead
    cuda_overhead = get_cuda_overhead() * num_gpus
    model_overhead = get_model_overhead(model_id, dtype)
    prefill_overhead = get_forward_pass_overhead(
        config, max_batch_prefill_tokens, max_input_tokens, dtype
    )
    decode_overhead = get_forward_pass_overhead(
        config,
        max_batch_prefill_tokens,
        max_input_tokens,
        dtype,
        is_decode=True,
    )

    kv_cache_overhead_per_token = caclulate_kv_cache_memory_per_token(config, dtype)
    batch_size = math.ceil(max_batch_prefill_tokens / max_input_tokens)
    kv_cache_full_overhead = (
        batch_size * (max_total_tokens - 1) * kv_cache_overhead_per_token
    )

    total_vram_needed = (
        model_overhead
        + cuda_overhead
        + prefill_overhead
        + kv_cache_full_overhead
        + decode_overhead
    )

    # print(f"Total VRAM Available: {bytes_to_gb(total_vram_available)} GB")
    # print(f"Model Overhead: {bytes_to_gb(model_overhead)} GB")
    # print(f"CUDA Overhead: {bytes_to_gb(cuda_overhead)} GB")
    # print(f"Prefill Overhead: {bytes_to_gb(prefill_overhead)} GB")
    # print(f"KV Cache Full Overhead: {bytes_to_gb(kv_cache_full_overhead)} GB")

    is_enough_vram = total_vram_available >= total_vram_needed

    result = {
        "cuda_overhead": bytes_to_gb(cuda_overhead),
        "model_overhead": bytes_to_gb(model_overhead),
        "prefill_overhead": bytes_to_gb(prefill_overhead),
        "decode_overhead": bytes_to_gb(decode_overhead),
        "kv_cache_full_overhead": bytes_to_gb(kv_cache_full_overhead),
        "total_vram_available": bytes_to_gb(total_vram_available),
        "total_vram_needed": bytes_to_gb(total_vram_needed),
        "is_enough_vram": is_enough_vram,
    }

    return result

## Test

In [56]:
num_gpus = 1
max_input_tokens = 8750
max_total_tokens = 8782
max_batch_prefill_tokens = 8782

# num_gpus = 2
# max_input_tokens = 15520
# max_total_tokens = 15554
# max_batch_prefill_tokens = 15554

model_id = "meta-llama/Llama-3.1-8B-Instruct"

validate_vram_overhead(
    model_id=model_id,
    num_gpus=num_gpus,
    vram_per_gpu=24,
    max_input_tokens=max_input_tokens,
    max_total_tokens=max_total_tokens,
    max_batch_prefill_tokens=max_batch_prefill_tokens,
    dtype="float16",
)

Parse safetensors files: 100%|██████████| 4/4 [00:00<00:00,  8.31it/s]


{'cuda_overhead': 1.2,
 'model_overhead': 16.06052,
 'prefill_overhead': 3.15672,
 'decode_overhead': 0.00036,
 'kv_cache_full_overhead': 9.20755,
 'total_vram_available': 22.8,
 'total_vram_needed': 29.62515,
 'is_enough_vram': False}

In [57]:
import pandas as pd

metrics = []
for tok in range(10000, 20001, 1000):

    result = validate_vram_overhead(
        model_id=model_id,
        num_gpus=4,
        vram_per_gpu=24,
        max_input_tokens=tok - 1,
        max_total_tokens=tok,
        max_batch_prefill_tokens=tok,
        dtype="float16",
    )
    result["max_input_tokens"] = tok - 1
    metrics.append(result)

df = pd.DataFrame(metrics)

Parse safetensors files: 100%|██████████| 4/4 [00:00<00:00,  4.90it/s]
Parse safetensors files: 100%|██████████| 4/4 [00:00<00:00, 17.70it/s]
Parse safetensors files: 100%|██████████| 4/4 [00:00<00:00,  4.21it/s]
Parse safetensors files: 100%|██████████| 4/4 [00:00<00:00, 13.48it/s]
Parse safetensors files: 100%|██████████| 4/4 [00:00<00:00, 19.32it/s]
Parse safetensors files: 100%|██████████| 4/4 [00:00<00:00, 10.95it/s]
Parse safetensors files: 100%|██████████| 4/4 [00:00<00:00, 23.10it/s]
Parse safetensors files: 100%|██████████| 4/4 [00:00<00:00, 11.63it/s]
Parse safetensors files: 100%|██████████| 4/4 [00:00<00:00, 17.61it/s]
Parse safetensors files: 100%|██████████| 4/4 [00:00<00:00, 14.52it/s]
Parse safetensors files: 100%|██████████| 4/4 [00:00<00:00, 11.28it/s]


In [58]:
df

Unnamed: 0,cuda_overhead,model_overhead,prefill_overhead,decode_overhead,kv_cache_full_overhead,total_vram_available,total_vram_needed,is_enough_vram,max_input_tokens
0,4.8,16.06052,3.60732,0.00036,10.48471,91.2,34.95291,True,9999
1,4.8,16.06052,3.96809,0.00036,11.53329,91.2,36.36226,True,10999
2,4.8,16.06052,4.32886,0.00036,12.58186,91.2,37.7716,True,11999
3,4.8,16.06052,4.68962,0.00036,13.63044,91.2,39.18095,True,12999
4,4.8,16.06052,5.05039,0.00036,14.67902,91.2,40.59029,True,13999
5,4.8,16.06052,5.41116,0.00036,15.72759,91.2,41.99963,True,14999
6,4.8,16.06052,5.77193,0.00036,16.77617,91.2,43.40898,True,15999
7,4.8,16.06052,6.1327,0.00036,17.82474,91.2,44.81832,True,16999
8,4.8,16.06052,6.49346,0.00036,18.87332,91.2,46.22767,True,17999
9,4.8,16.06052,6.85423,0.00036,19.9219,91.2,47.63701,True,18999
