# Estimate VRAM usage for TGI

In [2]:
!pip install transformers

Collecting transformers
  Using cached transformers-4.44.2-py3-none-any.whl (9.5 MB)
Collecting filelock
  Downloading filelock-3.16.1-py3-none-any.whl (16 kB)
Collecting safetensors>=0.4.1
  Using cached safetensors-0.4.5-cp310-cp310-macosx_11_0_arm64.whl (381 kB)
Collecting regex!=2019.12.17
  Downloading regex-2024.9.11-cp310-cp310-macosx_11_0_arm64.whl (284 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m284.6/284.6 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m MB/s[0m eta [36m0:00:01[0m:01[0m
[?25hCollecting huggingface-hub<1.0,>=0.23.2
  Downloading huggingface_hub-0.25.1-py3-none-any.whl (436 kB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m436.4/436.4 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m31m10.5 MB/s[0m eta [36m0:00:01[0m
[?25hCollecting tokenizers<0.20,>=0.19
  Using cached tokenizers-0.19.1-cp310-cp310-macosx_11_0_arm64.whl (2.4 MB)
Collecting requests
  Using cached requests-2.32.3-py3-none-any.w

In [1]:
from transformers import AutoConfig
from huggingface_hub import get_safetensors_metadata

  from .autonotebook import tqdm as notebook_tqdm


### Utils

In [26]:
bytes_per_dtype = {"int4": 0.5, "int8": 1, "float8": 1, "float16": 2, "float32": 4}


def bytes_to_gb(bytes: int):
    """
    Calculate the memory in GB for a given number of bytes.

    Note:
        - 1 GB = 10^9 bytes and 1 GiB = 2^30 bytes
        - We estimate VRAM usage in GB for more conservative estimates
    """
    return round((bytes) / 10**9, 5)

### CUDA Overhead

In [25]:
def get_cuda_overhead():
    """
    When PyTorch uses CUDA for the first time, it may use up 0.5-2GB of GPU memory, reducing the GPU's total available memory.

    Here, we assume 1GB of overhead based on torch>=2.1.1 as discussed here:
    https://github.com/stas00/ml-engineering/tree/master/training/performance#additional-gpu-memory-usage
    """
    return 1 * 10**9

### Model Overhead

In [20]:
def get_model_overhead(model_id: str, dtype: str = "float16") -> int:
    """
    Get the model size in GB for a given model ID and data type.
    """
    metadata = get_safetensors_metadata(model_id)

    if hasattr(metadata, "metadata"):
        if metadata.metadata.get("total_size", False):
            return metadata.metadata["total_size"]
    else:
        num_params = list(metadata.parameter_count.values())[0]
        num_params = int(num_params)
        bytes = num_params * bytes_per_dtype[dtype]
        return bytes

### Prefill Overhead

In [56]:
import math


def calculate_mlp_prefill_overhead(
    seq_len: int,
    batch_size: int,
    hidden_size: int,
    intermediate_size: int,
    dtype: str = "float16",
):
    """
    Calculate the prefill overhead for single MLP block/layer.

    Accounts for following activations:
        - Input to the first Linear layer
        - Input to the activation function
        - Input to the second Linear layer
        - Dropout mask
    """

    bytes_per_unit = bytes_per_dtype[dtype]

    mlp_input = batch_size * seq_len * hidden_size * bytes_per_unit
    act_input = batch_size * seq_len * intermediate_size * bytes_per_unit
    down_proj_input = batch_size * seq_len * intermediate_size * bytes_per_unit
    dropout_mask = (
        batch_size * seq_len * hidden_size * bytes_per_dtype["int8"]
    )  # binary mask

    mlp_block_bytes = mlp_input + act_input + down_proj_input + dropout_mask

    return mlp_block_bytes


def calculate_attention_prefill_overhead(
    seq_len: int,
    batch_size: int,
    hidden_size: int,
    head_dim: int,
    num_attention_heads: int,
    num_key_value_heads: int,
    dtype: str = "float16",
):
    """
    Calculate the prefill overhead for single attention block/layer.
    """
    bytes_per_unit = bytes_per_dtype[dtype]

    attention_input = batch_size * seq_len * hidden_size * bytes_per_unit
    q = batch_size * seq_len * head_dim * num_attention_heads * bytes_per_unit
    k = batch_size * seq_len * head_dim * num_key_value_heads * bytes_per_unit
    softmax_output = batch_size * num_attention_heads * seq_len**2 * bytes_per_unit
    softmax_dropout_mask = (
        batch_size * num_attention_heads * seq_len**2 * bytes_per_dtype["int8"]
    )
    dropout_output = batch_size * num_attention_heads * seq_len**2 * bytes_per_unit
    v = batch_size * seq_len * head_dim * num_key_value_heads * bytes_per_unit
    out_proj_input = (
        batch_size * seq_len * num_attention_heads * head_dim * bytes_per_unit
    )
    attention_dropout_mask = (
        batch_size * seq_len * hidden_size * bytes_per_dtype["int8"]
    )

    attention_block_bytes = (
        attention_input
        + q
        + k
        + softmax_output
        + softmax_dropout_mask
        + dropout_output
        + v
        + out_proj_input
        + attention_dropout_mask
    )

    return attention_block_bytes


def calculate_flash_attention_prefill_overhead(
    seq_len: int,
    batch_size: int,
    hidden_size: int,
    head_dim: int,
    num_attention_heads: int,
    num_key_value_heads: int,
    dtype: str = "float16",
):
    """
    Calculate the prefill overhead for single attention block/layer.
    """
    bytes_per_unit = bytes_per_dtype[dtype]

    attention_input = batch_size * seq_len * hidden_size * bytes_per_unit
    q = batch_size * seq_len * head_dim * num_attention_heads * bytes_per_unit
    k = batch_size * seq_len * head_dim * num_key_value_heads * bytes_per_unit
    v = batch_size * seq_len * head_dim * num_key_value_heads * bytes_per_unit

    softmax_output = batch_size * num_attention_heads * seq_len * bytes_per_unit
    softmax_dropout_mask = (
        batch_size * num_attention_heads * seq_len * bytes_per_dtype["int8"]
    )
    dropout_output = batch_size * num_attention_heads * seq_len * bytes_per_unit

    out_proj_input = (
        batch_size * seq_len * num_attention_heads * head_dim * bytes_per_unit
    )
    attention_dropout_mask = (
        batch_size * seq_len * hidden_size * bytes_per_dtype["int8"]
    )

    attention_block_bytes = (
        attention_input
        + q
        + k
        + v
        + softmax_output
        + softmax_dropout_mask
        + dropout_output
        + out_proj_input
        + attention_dropout_mask
    )

    return attention_block_bytes


def calculate_layernorm_prefill_overhead(
    seq_len: int,
    batch_size: int,
    hidden_size: int,
    dtype: str = "float16",
):
    """
    Calculate the prefill overhead for single layer norm layer.
    """
    bytes_per_unit = bytes_per_dtype[dtype]
    layernorm_bytes = batch_size * seq_len * hidden_size * bytes_per_unit

    return layernorm_bytes


def get_prefill_overhead(
    config: AutoConfig,
    max_batch_prefill_tokens: int,
    max_input_tokens: int,
    dtype: str = "float16",
):
    """
    Calculate the prefill overhead for a given model and TGI config.

    Size of a biggest tensor within forward pass.
    It is estimated as the sum of all intermediate tensors within computation of a single layer.
    Activations size have quadratic dependence on Sequence Length.

    Reference:
        - https://arxiv.org/pdf/2205.05198
        - https://asmirnov.xyz/vram#fn1
    """
    batch_size = math.ceil(max_batch_prefill_tokens / max_input_tokens)
    head_dim = (
        config.head_dim
        if hasattr(config, "head_dim")
        else config.hidden_size // config.num_key_value_heads
    )

    mlp_overhead = calculate_mlp_prefill_overhead(
        seq_len=max_input_tokens,
        batch_size=batch_size,
        hidden_size=config.hidden_size,
        intermediate_size=config.intermediate_size,
        dtype=dtype,
    )

    attention_overhead = calculate_flash_attention_prefill_overhead(
        seq_len=max_input_tokens,
        batch_size=batch_size,
        hidden_size=config.hidden_size,
        head_dim=head_dim,
        num_attention_heads=config.num_attention_heads,
        num_key_value_heads=config.num_key_value_heads,
        dtype=dtype,
    )

    layernorm_overhead = calculate_layernorm_prefill_overhead(
        seq_len=max_input_tokens,
        batch_size=batch_size,
        hidden_size=config.hidden_size,
        dtype=dtype,
    )

    return mlp_overhead + attention_overhead + (layernorm_overhead * 2)

### KV Cache Overhead

In [43]:
def caclulate_kv_cache_memory_per_token(config: AutoConfig, dtype: str):
    """Calculates the memory required for the key-value cache per token in a Large Language Model (LLM)."""
    dtype_bytes = bytes_per_dtype[dtype]
    head_dim = (
        config.head_dim
        if hasattr(config, "head_dim")
        else config.hidden_size // config.num_key_value_heads
    )

    bytes_per_token = (
        2  # k & v
        * config.num_hidden_layers
        * config.num_key_value_heads
        * head_dim
        * dtype_bytes
    )
    return bytes_per_token

In [44]:
caclulate_kv_cache_memory_per_token(
    config=AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1"),
    dtype="float16",
)

131072

### Estimate VRAM usage (normal attention)

In [47]:
def validate_vram_overhead(
    model_id: str,
    num_gpus: int,
    vram_per_gpu: int,  # in GB
    max_input_tokens: int,
    max_total_tokens: int,
    max_batch_prefill_tokens: int,
    cuda_memory_fraction: float = 1.0,
    dtype: str = "float16",
):
    """
    Validate the VRAM overhead for a given model and TGI config.
    """
    config = AutoConfig.from_pretrained(model_id)

    # determine overhead
    cuda_overhead = get_cuda_overhead() * num_gpus
    model_overhead = get_model_overhead(model_id, dtype)
    prefill_overhead = get_prefill_overhead(
        config, max_batch_prefill_tokens, max_input_tokens, dtype
    )
    kv_cache_overhead_per_token = caclulate_kv_cache_memory_per_token(config, dtype)

    # calculate free vram for kv cache
    free_vram = vram_per_gpu * num_gpus * 10**9
    free_vram = free_vram - cuda_overhead - model_overhead - prefill_overhead
    free_vram = (
        free_vram * cuda_memory_fraction * 0.95
    )  # by default, TGI allocates 95% of free VRAM to kv-cache

    print(f"Model Overhead: {bytes_to_gb(model_overhead)} GB")
    print(f"CUDA Overhead: {bytes_to_gb(cuda_overhead)} GB")
    print(f"Prefill Overhead: {bytes_to_gb(prefill_overhead)} GB")
    print(f"Free VRAM for KV Cache: {bytes_to_gb(free_vram)} GB")

    # calculate token budget available for full kv cache
    kv_cache_token_budget = int((free_vram) // kv_cache_overhead_per_token)

    # check if we have enough vram for full kv cache (num_batches * max_total_tokens)
    batch_size = math.ceil(max_batch_prefill_tokens / max_input_tokens)
    kv_cache_tokens_needed = batch_size * (max_total_tokens - 1)
    vram_needed_for_kv_cache = kv_cache_tokens_needed * kv_cache_overhead_per_token
    is_enough_vram = kv_cache_token_budget >= kv_cache_tokens_needed

    result = {
        "cuda_overhead": bytes_to_gb(cuda_overhead),
        "model_overhead": bytes_to_gb(model_overhead),
        "prefill_overhead": bytes_to_gb(prefill_overhead),
        "kv_cache_overhead_per_token": bytes_to_gb(kv_cache_overhead_per_token),
        "free_vram_for_kv_cache": bytes_to_gb(free_vram),
        "vram_needed_for_kv_cache": bytes_to_gb(vram_needed_for_kv_cache),
        "kv_cache_token_budget": kv_cache_token_budget,
        "kv_cache_tokens_needed": kv_cache_tokens_needed,
        "is_enough_vram": is_enough_vram,
    }

    return result

In [54]:
# max_input_tokens = 8750
# max_total_tokens = 8782
# max_batch_prefill_tokens = 8782

# max_input_tokens = 15520
# max_total_tokens = 15554
# max_batch_prefill_tokens = 15554

max_input_tokens = 20466
max_total_tokens = 20530
max_batch_prefill_tokens = 20530

model_id = "meta-llama/Llama-3.1-8B-Instruct"
validate_vram_overhead(
    model_id=model_id,
    num_gpus=4,
    vram_per_gpu=24,
    max_input_tokens=max_input_tokens,
    max_total_tokens=max_total_tokens,
    max_batch_prefill_tokens=max_batch_prefill_tokens,
    dtype="float16",
)

Parse safetensors files: 100%|██████████| 4/4 [00:00<00:00,  7.00it/s]

Model Overhead: 16.06052 GB
CUDA Overhead: 4.0 GB
Prefill Overhead: 7.38348 GB
Free VRAM for KV Cache: 65.1282 GB





{'cuda_overhead': 4.0,
 'model_overhead': 16.06052,
 'prefill_overhead': 7.38348,
 'kv_cache_overhead_per_token': 0.00052,
 'free_vram_for_kv_cache': 65.1282,
 'vram_needed_for_kv_cache': 21.52622,
 'kv_cache_token_budget': 124222,
 'kv_cache_tokens_needed': 41058,
 'is_enough_vram': True}

In [55]:
max_input_tokens = 20466
max_total_tokens = 20530
max_batch_prefill_tokens = 50000

model_id = "meta-llama/Llama-3.1-8B-Instruct"
validate_vram_overhead(
    model_id=model_id,
    num_gpus=4,
    vram_per_gpu=24,
    max_input_tokens=max_input_tokens,
    max_total_tokens=max_total_tokens,
    max_batch_prefill_tokens=max_batch_prefill_tokens,
    dtype="float16",
)

Parse safetensors files: 100%|██████████| 4/4 [00:00<00:00,  7.70it/s]

Model Overhead: 16.06052 GB
CUDA Overhead: 4.0 GB
Prefill Overhead: 11.07522 GB
Free VRAM for KV Cache: 61.62105 GB





{'cuda_overhead': 4.0,
 'model_overhead': 16.06052,
 'prefill_overhead': 11.07522,
 'kv_cache_overhead_per_token': 0.00052,
 'free_vram_for_kv_cache': 61.62105,
 'vram_needed_for_kv_cache': 32.28933,
 'kv_cache_token_budget': 117532,
 'kv_cache_tokens_needed': 61587,
 'is_enough_vram': True}