# Estimate VRAM usage for TGI


free_vram = total_gpu_vram - model_params_vram - prefill_overhead_vram - kv_cache_vram - cuda_overhead_vram

In [2]:
!pip install transformers

Collecting transformers
  Downloading transformers-4.44.2-py3-none-any.whl (9.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.5/9.5 MB[0m [31m70.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (751 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m751.2/751.2 kB[0m [31m52.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting numpy>=1.17
  Downloading numpy-2.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.3/16.3 MB[0m [31m55.9 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting tqdm>=4.27
  Downloading tqdm-4.66.5-py3-none-any.whl (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.4/78.4 kB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting filelock
  Downloading filelock-3.16.1-py3-none-any.w

In [6]:
from transformers import AutoConfig
from huggingface_hub import get_safetensors_metadata

### Utils

In [143]:
bytes_per_dtype = {"int4": 0.5, "int8": 1, "float8": 1, "float16": 2, "float32": 4}


def bytes_to_gb(bytes: int):
    """
    Calculate the memory in GB for a given number of bytes.

    Note:
        - 1 GB = 10^9 bytes and 1 GiB = 2^30 bytes
        - We estimate VRAM usage in GB for more conservative estimates
    """
    return round((bytes) / 10**9, 5)

### CUDA Overhead

In [81]:
def get_cuda_overhead():
    """
    When PyTorch uses CUDA for the first time, it may use up 0.5-2GB of GPU memory, reducing the GPU's total available memory.

    Here, we assume 1GB of overhead based on torch>=2.1.1 as discussed here:
    https://github.com/stas00/ml-engineering/tree/master/training/performance#additional-gpu-memory-usage
    """
    return 1

### Model Overhead

In [135]:
def get_model_overhead(model_id: str, dtype: str = "float16") -> int:
    """
    Get the model size in GB for a given model ID and data type.
    """
    metadata = get_safetensors_metadata(model_id)

    if hasattr(metadata, "metadata"):
        if metadata.metadata.get("total_size", False):
            return bytes_to_gb(metadata.metadata["total_size"])
    else:
        num_params = list(metadata.parameter_count.values())[0]
        num_params = int(num_params)
        bytes = num_params * bytes_per_dtype[dtype]
        return bytes_to_gb(bytes)

In [132]:
a

SafetensorsRepoMetadata(metadata={'total_size': 5559367680}, sharded=True, weight_map={'lm_head.bias': 'model-00002-of-00002.safetensors', 'lm_head.weight': 'model-00002-of-00002.safetensors', 'model.embed_tokens.weight': 'model-00001-of-00002.safetensors', 'model.final_layernorm.bias': 'model-00002-of-00002.safetensors', 'model.final_layernorm.weight': 'model-00002-of-00002.safetensors', 'model.layers.0.input_layernorm.bias': 'model-00001-of-00002.safetensors', 'model.layers.0.input_layernorm.weight': 'model-00001-of-00002.safetensors', 'model.layers.0.mlp.fc1.bias': 'model-00001-of-00002.safetensors', 'model.layers.0.mlp.fc1.weight': 'model-00001-of-00002.safetensors', 'model.layers.0.mlp.fc2.bias': 'model-00001-of-00002.safetensors', 'model.layers.0.mlp.fc2.weight': 'model-00001-of-00002.safetensors', 'model.layers.0.self_attn.dense.bias': 'model-00001-of-00002.safetensors', 'model.layers.0.self_attn.dense.weight': 'model-00001-of-00002.safetensors', 'model.layers.0.self_attn.k_proj

### Prefill Overhead

In [85]:
import math


def calculate_mlp_prefill_overhead(
    seq_len: int,
    batch_size: int,
    hidden_size: int,
    intermediate_size: int,
    dtype: str = "float16",
):
    """
    Calculate the prefill overhead for single MLP block/layer.

    Accounts for following activations:
        - Input to the first Linear layer
        - Input to the activation function
        - Input to the second Linear layer
        - Dropout mask
    """

    bytes_per_unit = bytes_per_dtype[dtype]

    mlp_input = batch_size * seq_len * hidden_size * bytes_per_unit
    act_input = batch_size * seq_len * intermediate_size * bytes_per_unit
    down_proj_input = batch_size * seq_len * intermediate_size * bytes_per_unit
    dropout_mask = (
        batch_size * seq_len * hidden_size * bytes_per_dtype["int8"]
    )  # binary mask

    mlp_block_bytes = mlp_input + act_input + down_proj_input + dropout_mask

    return mlp_block_bytes


def calculate_attention_prefill_overhead(
    seq_len: int,
    batch_size: int,
    hidden_size: int,
    head_dim: int,
    num_attention_heads: int,
    num_key_value_heads: int,
    dtype: str = "float16",
):
    """
    Calculate the prefill overhead for single attention block/layer.
    """
    bytes_per_unit = bytes_per_dtype[dtype]

    attention_input = batch_size * seq_len * hidden_size * bytes_per_unit
    q = batch_size * seq_len * head_dim * num_attention_heads * bytes_per_unit
    k = batch_size * seq_len * head_dim * num_key_value_heads * bytes_per_unit
    softmax_output = batch_size * num_attention_heads * seq_len**2 * bytes_per_unit
    softmax_dropout_mask = (
        batch_size * num_attention_heads * seq_len**2 * bytes_per_dtype["int8"]
    )
    dropout_output = batch_size * num_attention_heads * seq_len**2 * bytes_per_unit
    v = batch_size * seq_len * head_dim * num_key_value_heads * bytes_per_unit
    out_proj_input = (
        batch_size * seq_len * num_attention_heads * head_dim * bytes_per_unit
    )
    attention_dropout_mask = (
        batch_size * seq_len * hidden_size * bytes_per_dtype["int8"]
    )

    attention_block_bytes = (
        attention_input
        + q
        + k
        + softmax_output
        + softmax_dropout_mask
        + dropout_output
        + v
        + out_proj_input
        + attention_dropout_mask
    )

    return attention_block_bytes


def calculate_layernorm_prefill_overhead(
    seq_len: int,
    batch_size: int,
    hidden_size: int,
    dtype: str = "float16",
):
    """
    Calculate the prefill overhead for single layer norm layer.
    """
    bytes_per_unit = bytes_per_dtype[dtype]
    layernorm_bytes = batch_size * seq_len * hidden_size * bytes_per_unit

    return layernorm_bytes


def get_prefill_overhead(
    config: AutoConfig,
    max_batch_prefill_tokens: int,
    max_input_tokens: int,
    dtype: str = "float16",
):
    """
    Calculate the prefill overhead for a given model and TGI config.

    Size of a biggest tensor within forward pass.
    It is estimated as the sum of all intermediate tensors within computation of a single layer.
    Activations size have quadratic dependence on Sequence Length.

    Reference:
        - https://arxiv.org/pdf/2205.05198
        - https://asmirnov.xyz/vram#fn1
    """
    batch_size = math.ceil(max_batch_prefill_tokens / max_input_tokens)
    head_dim = (
        config.head_dim
        if hasattr(config, "head_dim")
        else config.hidden_size // config.num_key_value_heads
    )

    mlp_overhead = calculate_mlp_prefill_overhead(
        seq_len=max_input_tokens,
        batch_size=batch_size,
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
        dtype=dtype,
    )

    attention_overhead = calculate_attention_prefill_overhead(
        seq_len=max_input_tokens,
        batch_size=batch_size,
        hidden_size=config.hidden_size,
        head_dim=head_dim,
        num_attention_heads=config.num_attention_heads,
        num_key_value_heads=config.num_key_value_heads,
        dtype=dtype,
    )

    layernorm_overhead = calculate_layernorm_prefill_overhead(
        seq_len=max_input_tokens,
        batch_size=batch_size,
        hidden_size=config.hidden_size,
        dtype=dtype,
    )

    return bytes_to_gb(mlp_overhead + attention_overhead + layernorm_overhead * 2)

In [87]:
get_prefill_overhead(
    config=AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1"),
    max_batch_prefill_tokens=2048,
    max_input_tokens=1024,
    dtype="float16",
)

0.5788

### KV Cache Overhead

In [144]:
def caclulate_kv_cache_memory_per_token(config: AutoConfig, dtype: str):
    """Calculates the memory required for the key-value cache per token in a Large Language Model (LLM)."""
    dtype_bytes = bytes_per_dtype[dtype]
    head_dim = (
        config.head_dim
        if hasattr(config, "head_dim")
        else config.hidden_size // config.num_key_value_heads
    )

    bytes_per_token = (
        2  # k & v
        * config.num_hidden_layers
        * config.num_key_value_heads
        * head_dim
        * dtype_bytes
    )
    return bytes_per_token

In [86]:
caclulate_kv_cache_memory_per_token(
    config=AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1"),
    dtype="float16",
)

131072

### Estimate VRAM usage

In [150]:
def validate_vram_overhead(
    model_id: str,
    num_gpus: int,
    vram_per_gpu: int,  # in GB
    max_input_tokens: int,
    max_total_tokens: int,
    max_batch_prefill_tokens: int,
    cuda_memory_fraction: float = 1.0,
    dtype: str = "float16",
):
    """
    Validate the VRAM overhead for a given model and TGI config.
    """
    config = AutoConfig.from_pretrained(model_id)

    # determine overhead
    cuda_overhead = get_cuda_overhead()
    model_overhead = get_model_overhead(model_id, dtype)
    prefill_overhead = get_prefill_overhead(
        config, max_batch_prefill_tokens, max_input_tokens, dtype
    )
    kv_cache_overhead_per_token = caclulate_kv_cache_memory_per_token(config, dtype)

    # calculate free vram for kv cache
    free_vram = vram_per_gpu * num_gpus
    free_vram = (
        free_vram - (cuda_overhead * num_gpus) - model_overhead - prefill_overhead
    )
    free_vram = (
        free_vram * cuda_memory_fraction * 0.95
    )  # by default, TGI allocates 95% of free VRAM to kv-cache

    print(f"Model Overhead: {model_overhead} GB")
    print(f"CUDA Overhead: {cuda_overhead} GB")
    print(f"Prefill Overhead: {prefill_overhead} GB")
    print(f"Free VRAM: {free_vram} GB")

    # calculate token budget for full kv cache (num_batches * max_total_tokens)
    kv_cache_token_budget = int((free_vram * 10**9) // kv_cache_overhead_per_token)

    # check if we have enough vram for full kv cache at max_total_tokens
    batch_size = math.ceil(max_batch_prefill_tokens / max_input_tokens)
    kv_cache_tokens_needed = batch_size * max_total_tokens
    is_enough_vram = kv_cache_token_budget >= kv_cache_tokens_needed

    result = {
        "cuda_overhead": cuda_overhead,
        "model_overhead": model_overhead,
        "prefill_overhead": prefill_overhead,
        "kv_cache_overhead_per_token": bytes_to_gb(kv_cache_overhead_per_token),
        "kv_cache_token_budget": kv_cache_token_budget,
        "kv_cache_tokens_needed": kv_cache_tokens_needed,
        "is_enough_vram": is_enough_vram,
    }

    return result

In [155]:
max_input_tokens = 8750
max_total_tokens = 8782
max_batch_prefill_tokens = 8782

validate_vram_overhead(
    model_id=model_id,
    num_gpus=1,
    vram_per_gpu=24,
    max_input_tokens=max_input_tokens,
    max_total_tokens=max_total_tokens,
    max_batch_prefill_tokens=max_batch_prefill_tokens,
    dtype="float16",
)

Parse safetensors files: 100%|██████████| 4/4 [00:00<00:00, 24.01it/s]

Model Overhead: 16.06052 GB
Prefill Overhead: 27.65392 GB
Free VRAM: -19.678718 GB





{'cuda_overhead': 1,
 'model_overhead': 16.06052,
 'prefill_overhead': 27.65392,
 'kv_cache_overhead_per_token': 0.00052,
 'kv_cache_token_budget': -37535,
 'kv_cache_tokens_needed': 17564,
 'is_enough_vram': False}