# Transformer Sizing

This notebook estimates sizing figures for a Llama 3 Transformer: parameter count, checkpoint size, peak memory footprint, FLOPS, and total training compute.

**Reference**
- This notebook is based directly on Karpathy's [nanoGPT/transformer_sizing.ipynb](https://github.com/karpathy/nanoGPT/blob/master/transformer_sizing.ipynb)

In [1]:
from collections import OrderedDict

from gollem.models.llama3.config import Llama3Config
from gollem.models.llama3.config import get_llama3_model_config


# TODO Change this as desired
model_name = "llama3-1B"

model_cfg = get_llama3_model_config(model_name)

# NOTE: not used elsewhere in this notebook
from torch.distributed.optim import ZeroRedundancyOptimizer


## Parameter Count

In [2]:
def get_params(cfg: Llama3Config):
    """estimates the number of parameters in the model"""
    out = OrderedDict()

    # token embeddings
    out["embedding/token"] = cfg.vocab_size * cfg.d_model
    out["embedding"] = out["embedding/token"]

    # attention blocks
    out["attention/norm"] = cfg.d_model
    d_head = cfg.d_model // cfg.n_head
    out["attention/wq"] = cfg.d_model * cfg.d_model
    out["attention/wk"] = cfg.d_model * cfg.n_kv_head * d_head
    out["attention/wv"] = cfg.d_model * cfg.n_kv_head * d_head
    out["attention/wo"] = cfg.d_model * cfg.d_model
    out["attention"] = (
        out["attention/norm"]
        + out["attention/wq"]
        + out["attention/wk"]
        + out["attention/wv"]
        + out["attention/wo"]
    )
    out["attention_total"] = cfg.n_layer * out["attention"]

    # MLP blocks
    out["mlp/norm"] = cfg.d_model
    out["mlp/w1"] = cfg.d_model * cfg.intermediate_size
    out["mlp/w2"] = cfg.intermediate_size * cfg.d_model
    out["mlp/w3"] = cfg.d_model * cfg.intermediate_size
    out["mlp"] = out["mlp/norm"] + out["mlp/w1"] + out["mlp/w2"] + out["mlp/w3"]
    out["mlp_total"] = cfg.n_layer * out["mlp"]

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = cfg.n_layer * out["block"]
    out["norm_final"] = cfg.d_model
    out["out_embedding"] = cfg.d_model * cfg.vocab_size

    # total
    out["total"] = (
        out["embedding"] + out["transformer"] + out["norm_final"] + out["out_embedding"]
    )

    return out


model_params = get_params(model_cfg)
params_total = model_params["total"]

# create a header
print(f"{'name':20s} {'params':10s} {'ratio (%)':10s}")
for k, v in model_params.items():
    print(f"{k:20s} {v:10d} {v / params_total * 100:10.4f}")

name                 params     ratio (%) 
embedding/token       262144000    18.3276
embedding             262144000    18.3276
attention/norm             2048     0.0001
attention/wq            4194304     0.2932
attention/wk            2097152     0.1466
attention/wv            2097152     0.1466
attention/wo            4194304     0.2932
attention              12584960     0.8799
attention_total       201359360    14.0779
mlp/norm                   2048     0.0001
mlp/w1                 14680064     1.0263
mlp/w2                 14680064     1.0263
mlp/w3                 14680064     1.0263
mlp                    44042240     3.0792
mlp_total             704675840    49.2668
block                  56627200     3.9590
transformer           906035200    63.3447
norm_final                 2048     0.0001
out_embedding         262144000    18.3276
total                1430325248   100.0000
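As a cross-check, the per-component table above collapses into a single closed-form expression. The sketch below is not part of the original analysis: `SketchConfig` is a hypothetical stand-in for `Llama3Config`, and its field values (including `kv_dim`, which stands for `n_kv_head * d_head`) are inferred from the printed table rather than read from the real config.

```python
from dataclasses import dataclass


@dataclass
class SketchConfig:
    """Hypothetical stand-in for Llama3Config; values inferred from the table above."""

    vocab_size: int = 128_000
    d_model: int = 2048
    n_layer: int = 16
    intermediate_size: int = 7168
    kv_dim: int = 1024  # assumed to equal n_kv_head * d_head (from attention/wk)


def closed_form_params(cfg: SketchConfig) -> int:
    """Total parameters, assuming untied input and output embeddings."""
    d = cfg.d_model
    attention = 2 * d * d + 2 * d * cfg.kv_dim + d  # wq, wo, wk, wv, norm
    mlp = 3 * d * cfg.intermediate_size + d  # w1, w2, w3, norm
    embeddings = 2 * cfg.vocab_size * d  # token embedding + output projection
    return embeddings + cfg.n_layer * (attention + mlp) + d  # + final norm


sketch_total = closed_form_params(SketchConfig())
print(sketch_total)  # 1430325248, matching the "total" row above
```

If the closed form and the table ever disagree, one of the per-component entries has probably drifted from the model definition.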


## Parameter/Checkpoint Size

In [3]:
# we can now calculate the size of each checkpoint
# params are stored in fp32 (i.e. 4 bytes), and the AdamW optimizer has 2 additional buffers per param for statistics
params_bytes = params_total * 4
params_and_buffers_bytes = params_bytes + 2 * params_bytes
print(f"est checkpoint size: {params_and_buffers_bytes / 1e9:.2f} GB")
# TODO update this with actual measured bytes
measured_bytes = 1542470366  # from wc -c ckpt.pt
print(f"measured with wc -c ckpt.pt: {measured_bytes}")
print(f"fluff ratio: {measured_bytes / params_and_buffers_bytes * 100:.2f}%")

est checkpoint size: 17.16 GB
measured with wc -c ckpt.pt: 1542470366
fluff ratio: 8.99%
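Rather than pasting the output of `wc -c` by hand, the measured size can be read directly from the file. The sketch below is one possible way to do this; `fluff_ratio` is a hypothetical helper, and the demonstration uses a throwaway temporary file because no real checkpoint path is given in this notebook.

```python
import os
import tempfile


def fluff_ratio(ckpt_path: str, estimated_bytes: int) -> float:
    """Measured file size as a percentage of the estimated size."""
    measured = os.path.getsize(ckpt_path)  # byte count, as reported by `wc -c`
    return measured / estimated_bytes * 100


# Demonstration on a throwaway file, not a real checkpoint.
with tempfile.NamedTemporaryFile(delete=False) as f:
    f.write(b"\0" * 1200)
    demo_path = f.name

demo_ratio = fluff_ratio(demo_path, estimated_bytes=1000)
os.remove(demo_path)
print(f"fluff ratio: {demo_ratio:.2f}%")  # fluff ratio: 120.00%
```

With a real checkpoint, pass its path together with `params_and_buffers_bytes` from the cell above.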


## GPU memory usage

We can estimate the ratio of our GPU memory that will be taken up just by the weights and the buffers inside the AdamW optimizer.

In [4]:
# Nvidia reports memory in GiB (denoted as GB but is actually GiB)
# 1 GiB = 1024 MiB = 1024**2 KiB = 1024**3 Bytes
GPU_MEMORY = {
    "H100": 80 * 1024**3,  # 80 GiB
    "A100": 40 * 1024**3,  # 40 GiB
    "RTX3090": 24 * 1024**3,  # 24 GiB
    "RTX4090": 24 * 1024**3,  # 24 GiB
    "RTX2070": 8 * 1024**3,  # 8 GiB
}

print("GPU memory ratio taken up just for parameters (incl. optimizer)")
print(f"{'GPU':12s} {'ratio (%)':8s}")
for k, v in GPU_MEMORY.items():
    print(f"{k:12s} {params_and_buffers_bytes / v * 100:8.2f}")


GPU memory ratio taken up just for parameters (incl. optimizer)
GPU          ratio (%)
H100            19.98
A100            39.96
RTX3090         66.60
RTX4090         66.60
RTX2070        199.81


## GPU forward-backward memory usage

We can estimate the total memory usage of the model over the course of a forward-backward pass.

This is made up of:

1. M_model - the model parameters
2. M_optimizer - the optimizer buffers
3. M_gradient - the gradient of the model
4. M_activations - the activations of the model

The activations scale with batch size and become the largest component of memory usage at large batch sizes, whereas the other components are fixed for a given model size.

Calculating the activation memory usage is quite tricky since it is affected by low level optimizations that can be hard to estimate exactly. It is also affected by things like flash attention, dropout, activation checkpointing, etc.

Another source of memory usage is the CUDA and PyTorch overhead, which is typically 0.5-2 GiB and depends on the system setup. We don't include it in the calculations, so factor it in when comparing calculated and empirical memory usage.

NOTE: these are estimates and should be taken as ballpark figures rather than exact values.

In [5]:
def compute_activations(cfg: Llama3Config, B: int, dtype: str):
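    # Total (in bytes): 8TB + L * b * (9TBH + 2TB + 2TBI) + 2 * b * TBH
    # where T = n_ctx, H = d_model, I = intermediate_size, L = n_layer, b = bytes per activation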
    out = OrderedDict()

    bytes_per_activation = 2 if dtype in ["bfloat16", "float16"] else 4
    bytes_per_long = 8

    # token embeddings: the input token ids are stored as int64 (long)
    out["embedding"] = cfg.n_ctx * B * bytes_per_long

    TBH = cfg.n_ctx * B * cfg.d_model

    # attention blocks
    out["attention/norm"] = TBH * bytes_per_activation
    out["attention/kqv"] = TBH * bytes_per_activation

    # flash attention requires K, Q, V as well as two vectors l, m of length T (size BT)
    out["attention/attention_over_v"] = (
        TBH * 3 + 2 * cfg.n_ctx * B
    ) * bytes_per_activation

    out["attention/proj"] = TBH * bytes_per_activation
    out["attention"] = (
        out["attention/norm"]
        + out["attention/kqv"]
        + out["attention/attention_over_v"]
        + out["attention/proj"]
    )

    # MLP blocks
    out["mlp/norm"] = TBH * bytes_per_activation
    out["mlp/w1"] = TBH * bytes_per_activation
    out["mlp/w3"] = TBH * bytes_per_activation
    out["mlp/silu"] = cfg.n_ctx * B * cfg.intermediate_size * bytes_per_activation
    out["mlp/w2"] = cfg.n_ctx * B * cfg.intermediate_size * bytes_per_activation
    out["mlp"] = (
        out["mlp/norm"]
        + out["mlp/w1"]
        + out["mlp/w2"]
        + out["mlp/w3"]
        + out["mlp/silu"]
    )

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = cfg.n_layer * out["block"]

    # final norm and output projection
    out["norm_final"] = TBH * bytes_per_activation
    out["out_embedding"] = TBH * bytes_per_activation

    # total
    out["total"] = (
        out["embedding"] + out["transformer"] + out["norm_final"] + out["out_embedding"]
    )

    return out

In [6]:
dtype = "bfloat16"
# dtype = "float32"

# when using bf16 or fp16 we use mixed precision, so we have to store
# both the half precision and fp32 versions of the parameters
# when using fp32 we store only the full precision fp32 parameters
bytes_per_param = 6 if dtype in ("bfloat16", "float16") else 4
M_model = params_total * bytes_per_param

# AdamW optimizer has 2 buffers per parameter
# values are stored in fp32
M_optimizer = 2 * params_total * 4

# we store one gradient value per parameter
# gradients are stored in fp32
M_gradient = params_total * 4


divisor = 1024**3

print(f"M_model: {M_model / divisor:.2f} GiB")
print(f"M_optimizer: {M_optimizer / divisor:.2f} GiB")
print(f"M_gradient: {M_gradient / divisor:.2f} GiB")


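# Empirical peak memory usage for gpt2 (124M) using flash attention
# TODO replace with measurements for `model_name`; until then the diff column compares different models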
# run_name	            	 peak_mem_usage
# batch_size=1024 (1x1024)	 2269
# batch_size=2048 (2x1024)	 2870
# batch_size=4096 (4x1024)	 3918
# batch_size=8192 (8x1024)	 6012
# batch_size=16384 (16x1024) 10206
# batch_size=32768 (32x1024) 18595
# Map from batch size to peak memory usage in MiB
empirical_peak_memory_usage = {
    1: 2269,
    2: 2870,
    4: 3918,
    8: 6012,
    16: 10206,
    32: 18595,
}

# M_activations is more complex
print(f"{'batch size':10s} {'M_activations':13s} {'total memory':12s} {'  diff':6s}")
for batch_size in [1, 2, 4, 8, 16, 32]:
    M_activations = compute_activations(model_cfg, batch_size, dtype)
    total_memory = M_model + M_optimizer + M_activations["total"] + M_gradient
    total_memory_GiB = total_memory / divisor
    diff = (empirical_peak_memory_usage[batch_size] / 1024) - total_memory_GiB
    print(
        f"{batch_size:10d} {M_activations['total'] / divisor:13.2f} {total_memory_GiB:12.2f} {diff:6.2f}"
    )


M_model: 7.99 GiB
M_optimizer: 10.66 GiB
M_gradient: 5.33 GiB
batch size M_activations total memory   diff
         1          1.01        24.99 -22.77
         2          2.02        25.99 -23.19
         4          4.03        28.01 -24.18
         8          8.06        32.04 -26.17
        16         16.13        40.10 -30.14
        32         32.25        56.23 -38.07
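Because every term in `compute_activations` is proportional to the batch size, the table above can be inverted to ask how large a batch fits on a given GPU. The following is a minimal sketch under that linearity assumption; `max_batch_size` and its `overhead_bytes` argument (for the CUDA and PyTorch overhead mentioned earlier) are hypothetical names, and the inputs are round numbers chosen to be close to the printed values, not measurements.

```python
GIB = 1024**3


def max_batch_size(
    fixed_bytes: int,
    activation_bytes_per_sample: int,
    gpu_bytes: int,
    overhead_bytes: int = 0,
) -> int:
    """Largest batch size whose estimated memory fits, or 0 if none does.

    Assumes activation memory grows linearly with batch size.
    """
    available = gpu_bytes - fixed_bytes - overhead_bytes
    if available < activation_bytes_per_sample:
        return 0
    return available // activation_bytes_per_sample


# Round numbers: ~24 GiB for parameters, optimizer and gradients,
# ~1 GiB of activations per sequence.
fixed = 24 * GIB
per_sample = 1 * GIB

print(max_batch_size(fixed, per_sample, 40 * GIB))  # 16
print(max_batch_size(fixed, per_sample, 40 * GIB, overhead_bytes=2 * GIB))  # 14
print(max_batch_size(fixed, per_sample, 24 * GIB))  # 0
```

In practice the estimate should be treated as an upper bound and confirmed with a measured peak memory figure.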


## FLOPS

Here we estimate the FLOPs for a single forward and backward pass over one sequence of `n_ctx` tokens.

In [7]:
def compute_flops(cfg: Llama3Config):
    # we only count weight FLOPs;
    # FLOPs for all other layers (LayerNorm, Softmax, etc) and bias vector addition are effectively irrelevant
    # we count actual FLOPs, not MACs, hence the 2* all over the place
    # for any matrix multiply A (BxC) @ B (CxD) -> (BxD), the FLOPs are 2*B*C*D

    out = OrderedDict()
    d_head = cfg.d_model // cfg.n_head

    # attention blocks
    # 1) the projection to key, query, values
    out["attention/wq"] = 2 * cfg.n_ctx * cfg.d_model * cfg.d_model
    out["attention/wk"] = 2 * cfg.n_ctx * cfg.d_model * cfg.n_kv_head * d_head
    out["attention/wv"] = 2 * cfg.n_ctx * cfg.d_model * cfg.n_kv_head * d_head
    out["attention/wo"] = 2 * cfg.n_ctx * cfg.d_model * cfg.d_model
    # 2) calculating the attention scores
    out["attention/scores"] = 2 * cfg.n_ctx * cfg.n_ctx * cfg.d_model
    # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    out["attention/reduce"] = 2 * cfg.n_head * (cfg.n_ctx * cfg.n_ctx * d_head)
    out["attention"] = (
        out["attention/wq"]
        + out["attention/wk"]
        + out["attention/wv"]
        + out["attention/wo"]
        + out["attention/scores"]
        + out["attention/reduce"]
    )
    # MLP blocks
    out["mlp/w1"] = 2 * cfg.n_ctx * (cfg.d_model * cfg.intermediate_size)
    out["mlp/w2"] = 2 * cfg.n_ctx * (cfg.intermediate_size * cfg.d_model)
    out["mlp/w3"] = 2 * cfg.n_ctx * (cfg.d_model * cfg.intermediate_size)
    out["mlp"] = out["mlp/w1"] + out["mlp/w2"] + out["mlp/w3"]

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = cfg.n_layer * out["block"]
    out["out_embedding"] = 2 * cfg.n_ctx * (cfg.d_model * cfg.vocab_size)

    # forward,backward,total
    out["forward_total"] = out["transformer"] + out["out_embedding"]
    out["backward_total"] = (
        2 * out["forward_total"]
    )  # use common estimate of bwd = 2*fwd
    out["total"] = out["forward_total"] + out["backward_total"]

    return out


# print the FLOPs breakdown, with ratios relative to the forward pass total
model_flops = compute_flops(model_cfg)
flops_total = model_flops["forward_total"]
print(f"{'name':20s} {'flops':14s} {'ratio (%)':10s}")
for k, v in model_flops.items():
    print(f"{k:20s} {v:14d} {v / flops_total * 100:10.4f}")

name                 flops          ratio (%) 
attention/wq             8589934592     0.3396
attention/wk             4294967296     0.1698
attention/wv             4294967296     0.1698
attention/wo             8589934592     0.3396
attention/scores         4294967296     0.1698
attention/reduce         4294967296     0.1698
attention               34359738368     1.3582
mlp/w1                  30064771072     1.1885
mlp/w2                  30064771072     1.1885
mlp/w3                  30064771072     1.1885
mlp                     90194313216     3.5654
block                  124554051584     4.9236
transformer           1992864825344    78.7776
out_embedding          536870912000    21.2224
forward_total         2529735737344   100.0000
backward_total        5059471474688   200.0000
total                 7589207212032   300.0000
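Every row in the table above comes from the same rule: a matrix multiply of shapes (m x k) @ (k x n) costs 2*m*k*n FLOPs. The short sketch below spells that rule out and reproduces two rows. `matmul_flops` is a hypothetical helper, and the values of `n_ctx`, `d_model`, and `kv_dim` are inferred from the printed output rather than read from the config.

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    """FLOPs for (m x k) @ (k x n): one multiply and one add per MAC."""
    return 2 * m * k * n


# Values inferred from the tables above (assumptions, not read from the config).
n_ctx, d_model, kv_dim = 1024, 2048, 1024

wq_flops = matmul_flops(n_ctx, d_model, d_model)  # query projection
wk_flops = matmul_flops(n_ctx, d_model, kv_dim)  # key projection

print(wq_flops)  # 8589934592, the attention/wq row
print(wk_flops)  # 4294967296, the attention/wk row
```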


In [8]:
# now here is an estimate copy pasted from the PaLM paper
# this formula is often used to calculate MFU (model flops utilization)
def compute_palm_flops(cfg: Llama3Config):
    """estimate of the model flops following PaLM paper formula"""
    # non-embedding model parameters.
    model_params = get_params(cfg)
    N = model_params["total"] - model_params["embedding"]
    L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.d_model // cfg.n_head, cfg.n_ctx
    mf_per_token = 6 * N + 12 * L * H * Q * T
    mf = mf_per_token * cfg.n_ctx
    return mf


palm_flops = compute_palm_flops(model_cfg)
print(
    f"palm_flops: {palm_flops:d}, flops: {model_flops['total']:d}, ratio: {palm_flops / model_flops['total']:.4f}"
)

palm_flops: 7589622448128, flops: 7589207212032, ratio: 1.0001


## GPU Flops Usage

Given our estimated FLOPS we can calculate how much of our GPU FLOP capacity is being used, that is our model flop utilization (MFU).

To calculate this we need a few bits of information:

- GPU speed (FLOPS)
- batch size (including gradient accumulation)
- time per iteration

For the GPU speed we refer to https://www.techpowerup.com/gpu-specs/, which gives theoretical performance. We use the BF16 or FP16 figure, whichever is supported and faster.

In [9]:
# Theoretical peak throughput per GPU, in FLOPS
GPU_FLOPS = {
    "H100": 756e12,  # 756 TFLOPS BF16 (this is a guess, spec sheet shows 1513 for sparse tensors)
    "A100": 312e12,  # 312 TFLOPS BF16
    "RTX4090": 83e12,  # 83 TFLOPS FP16
    "RTX2070": 15e12,  # 15 TFLOPS FP16
}

# TODO Change these values to desired values
batch_size = 20
grad_accum = 5
measured_time = 0.755  # in seconds per iteration

# calculate flops achieved
total_batch_size = batch_size * grad_accum
measured_throughput = total_batch_size / measured_time
flops_achieved = model_flops["total"] * measured_throughput

# the fraction of each GPU's theoretical peak FLOPS that we are using:
print("Fraction of GPU FLOPS used")
print(f"{'GPU':14s} {'ratio (%)':10s}")
for k, v in GPU_FLOPS.items():
    print(f"{k:14s} {flops_achieved / v * 100:10.2f}")

Fraction of GPU FLOPS used
GPU            ratio (%) 
H100               132.96
A100               322.18
RTX4090           1211.08
RTX2070           6701.29
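The calculation above can be wrapped in a small function that also flags impossible results. This is a sketch, not part of the original analysis: `model_flops_utilization` is a hypothetical helper, and the inputs are copied from the cells above. A value over 100% cannot be physically achieved, so it suggests that the timing or batch size placeholders do not belong to the model being sized.

```python
def model_flops_utilization(
    flops_per_sequence: float,
    sequences_per_iter: int,
    seconds_per_iter: float,
    peak_flops: float,
) -> float:
    """Achieved FLOPS divided by the GPU's theoretical peak FLOPS."""
    achieved = flops_per_sequence * sequences_per_iter / seconds_per_iter
    return achieved / peak_flops


# Inputs copied from the cells above: fwd+bwd FLOPs per sequence,
# batch_size * grad_accum, measured_time, and the H100 peak.
mfu = model_flops_utilization(7589207212032, 20 * 5, 0.755, 756e12)
print(f"MFU: {mfu * 100:.2f}%")  # MFU: 132.96%

if mfu > 1.0:
    print("MFU above 100%: check that the timing was measured with this model")
```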


## Total Training Compute

Here we use the values computed so far to estimate the total amount of compute `C` required to train the model as `C ~= 6*N*D`, where:

- `6` is a heuristic value (see [Dzmitry Bahdanau's post](https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4) for an explanation) 
    - it basically stems from weight multiplications requiring 6 FLOPS per token per weight when combining both forward and backward passes.
- `N` is the total number of model parameters
- `D` is the total size of the dataset (in tokens)

Note the equation changes to `C ~= 8*N*D` with activation checkpointing (i.e. recompute activations as needed to save memory when doing back-prop).

We also need to factor in the model flops utilization (MFU) to correct for the fact that we cannot use 100% of a GPU's peak FLOPS due to memory bottlenecks, etc (again see Dzmitry's blog for examples).


In [10]:
# Finally let's check out the 6ND approximation as total cost of training in FLOPs
model_size = get_params(model_cfg)["total"]  # this is number of parameters, N

# TODO change these parameters
tokens_num = 300e9  # 300B tokens, this is dataset size in tokens, D
assumed_mfu = 0.3  # assumed model flops utilization (TODO set this from the MFU measured above, allowing for some DDP overhead)
num_gpus = 8  # number of GPUS used in parallel

print("Time needed to train model on different GPUS")
print(f"{'GPU':10s} {'time (days)':16s}")
for gpu_name, gpu_flops in GPU_FLOPS.items():
    flops_throughput = gpu_flops * num_gpus * assumed_mfu
    flops_needed = 6 * model_size * tokens_num  # 6ND
    time_needed_s = flops_needed / flops_throughput  # in seconds
    print(f"{gpu_name:10s} {time_needed_s / 3600 / 24:10.2f}")


Time needed to train model on different GPUS
GPU        time (days)     
H100            16.42
A100            39.79
RTX4090        149.59
RTX2070        827.73
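The same estimate can be parameterized by the FLOPs-per-parameter-per-token factor, which makes it easy to compare plain training (factor 6) with activation checkpointing (factor 8). The sketch below is illustrative only: `training_days` is a hypothetical helper, and the inputs repeat the parameter count, dataset size, H100 peak, GPU count, and assumed MFU used above.

```python
def training_days(
    n_params: float,
    n_tokens: float,
    peak_flops_per_gpu: float,
    num_gpus: int,
    mfu: float,
    flops_per_param_token: float = 6.0,
) -> float:
    """Estimated training time in days using C ~= k * N * D."""
    flops_needed = flops_per_param_token * n_params * n_tokens
    flops_per_second = peak_flops_per_gpu * num_gpus * mfu
    return flops_needed / flops_per_second / 86400  # seconds per day


# Same inputs as the H100 row above.
base_days = training_days(1430325248, 300e9, 756e12, 8, 0.3)
ckpt_days = training_days(
    1430325248, 300e9, 756e12, 8, 0.3, flops_per_param_token=8.0
)

print(f"6ND: {base_days:.2f} days")  # 6ND: 16.42 days
print(f"8ND: {ckpt_days:.2f} days")  # 8ND: 21.90 days
```

The 8ND figure is exactly 4/3 of the 6ND figure, so activation checkpointing trades roughly a third more compute for the memory it saves.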
