In [10]:
from huggingface_hub import hf_hub_download
import json


def bits_to_gb(bits):
    return bits / (8 * 1024**3)


def calculate_train_vram_requirements(
        batch_size, seq_len, params, precision, num_layers, num_attn_heads, hidden_size, **ignored
):
    """
    full train, not lora
    source: https://arxiv.org/pdf/2205.05198.pdf (section 4.1)
    credit: https://medium.com/@siddheshgunjal82/understanding-vram-requirements-to-train-inference-with-large-language-models-llms-a3edd0f09d9f
    """
    # Calculate activations using the provided formula
    activations = (
        num_layers * (5/2) * num_attn_heads * batch_size * seq_len**2
                   + 17 * batch_size * hidden_size * seq_len
    )

    # Calculate VRAM using the provided formula
    vram_bits = precision * (activations + params)

    # Convert VRAM from bits to Gigabytes
    return bits_to_gb(vram_bits)


def calculate_inference_vram_requirements(
        batch_size, seq_len, params, precision, num_layers, hidden_size,
        num_attn_heads, num_kv_heads, gqa=True
):
    """
    source 1: https://developer.nvidia.com/blog/mastering-llm-techniques-inference-optimization/
    source 2: https://www.databricks.com/blog/llm-inference-performance-engineering-best-practices
    - same as source 1, but with the introduction a factor (n_heads / n_kv_heads) specific to GQA
      - "GQA helps with keeping the KV cache size down by sharing Keys/Values"
    - defaulting to calculated models using GQA since Mistral, Yi, and Llama 2 use it
    """
    kv_cache = batch_size * seq_len * 2 * num_layers * hidden_size
    if gqa:
        kv_cache *= num_kv_heads / num_attn_heads

    vram_bits = precision * (kv_cache + params)

    return bits_to_gb(vram_bits)

def get_model_params(model_uri):
    hf_hub_download(repo_id=model_uri, filename="config.json", local_dir=".")
    with open("config.json", "r") as f:
        model_params = json.load(f)
    return model_params

def print_table(model_uri, bparams, batch_size=1, precisions=None, mode="infer"):
    precisions = precisions or [4, 6, 8, 16]

    model_params = get_model_params(model_uri)

    seq_lens = (
        [2**i for i in range(8, 20) if 2**i< model_params["max_position_embeddings"]]
        + [model_params["max_position_embeddings"]]
    )

    calc_params = {
        "num_layers": model_params["num_hidden_layers"],
        "hidden_size": model_params["hidden_size"],
        "num_attn_heads": model_params["num_attention_heads"],
        "num_kv_heads": model_params["num_key_value_heads"],
    }

    if mode == "infer":
        vram_calculator = calculate_inference_vram_requirements
    elif mode == "train":
        vram_calculator = calculate_train_vram_requirements
    elif mode == "train_lora":
        raise NotImplemented
    else:
        raise ValueError

    column_width = 10

    # Print the header of the table with precisions
    header = f"{'SL / BP':>{column_width}}" + "".join([f" | {p:^10}" for p in precisions])
    results = [
        f"Model: {model_uri}",
        f"Params: {bparams}B",
        f"Batch Size: {batch_size}",
        f"Mode: {mode}",
        "",
        "Sequence Length vs Bit Precision - Memory Requirements"
    ]
    results.append(header)
    results.append("-" * len(header))

    # Iterate over each seq_len and calculate VRAM for each precision
    for seq_len in seq_lens:
        seq_len_label = f"{seq_len:>{column_width}}"
        if seq_len == max(seq_lens):
            seq_len_label = "*" + seq_len_label[1:]
        row_data = [seq_len_label]
        for precision in precisions:
            vram_required = vram_calculator(
                batch_size=batch_size,
                seq_len=seq_len,
                precision=precision,
                params=bparams * 1e9,
                **calc_params  # Unpack additional parameters if provided
            )
            row_data.append(f"{vram_required:8.1f}GB")  # Format with 1 decimal point

        # Print each row of the table
        results.append(" | ".join(row_data))

    results += ["", "* Model Max Context Size"]

    print("    " + "\n    ".join(results))

    # save everything to a file
    with open(f"{model_uri.replace('/', '-')}-{mode}.txt", "w") as f:
        f.write("\n".join(results))


In [12]:
print_table("alpindale/WizardLM-2-8x22B", bparams=141, mode="train")

config.json:   0%|          | 0.00/768 [00:00<?, ?B/s]

    Model: alpindale/WizardLM-2-8x22B
    Params: 141B
    Batch Size: 1
    Mode: train
    
    Sequence Length vs Bit Precision - Memory Requirements
       SL / BP |     4      |     6      |     8      |     16    
    --------------------------------------------------------------
           256 |     65.9GB |     98.8GB |    131.8GB |    263.5GB
           512 |     66.5GB |     99.8GB |    133.0GB |    266.0GB
          1024 |     69.0GB |    103.5GB |    138.0GB |    276.0GB
          2048 |     78.9GB |    118.3GB |    157.8GB |    315.5GB
          4096 |    118.4GB |    177.5GB |    236.7GB |    473.4GB
          8192 |    276.1GB |    414.1GB |    552.1GB |   1104.2GB
         16384 |    906.5GB |   1359.7GB |   1812.9GB |   3625.8GB
         32768 |   3427.3GB |   5140.9GB |   6854.5GB |  13709.0GB
    *    65536 |  13508.8GB |  20263.3GB |  27017.7GB |  54035.4GB
    
    * Model Max Context Size
