# Transformer Sizing

This notebook runs a set of analyses on a GPT-2 Transformer, e.g. number of FLOPs, parameters, peak memory footprint, checkpoint size, etc.

**Reference**
- This notebook is based directly on Karpathy's [nanoGPT/transformer_sizing.ipynb](https://github.com/karpathy/nanoGPT/blob/master/transformer_sizing.ipynb)

In [1]:
from collections import OrderedDict

from gollem.models.gpt2.config import GPT2Config
from gollem.models.gpt2.config import get_gpt2_model_config

# Change this to the model you want to analyze
model_name = "gpt2"

model_cfg = get_gpt2_model_config(model_name)



## Parameter Count

In [2]:
def get_params(cfg: GPT2Config):
    """estimates the number of parameters in the model"""
    out = OrderedDict()

    # token and position embeddings
    out["embedding/position"] = cfg.n_ctx * cfg.d_model
    out["embedding/token"] = cfg.vocab_size * cfg.d_model
    out["embedding"] = out["embedding/position"] + out["embedding/token"]

    # attention blocks
    out["attention/ln"] = cfg.d_model + int(cfg.ln_bias) * cfg.d_model
    out["attention/kqv"] = cfg.d_model * 3 * cfg.d_model
    out["attention/proj"] = cfg.d_model**2
    out["attention"] = (
        out["attention/ln"] + out["attention/kqv"] + out["attention/proj"]
    )

    # MLP blocks
    out["mlp/ln"] = cfg.d_model + int(cfg.ln_bias) * cfg.d_model
    out["mlp/ffw"] = cfg.d_model * cfg.d_mlp + int(cfg.ln_bias) * cfg.d_mlp
    out["mlp/proj"] = cfg.d_mlp * cfg.d_model + int(cfg.ln_bias) * cfg.d_model
    out["mlp"] = out["mlp/ln"] + out["mlp/ffw"] + out["mlp/proj"]

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = cfg.n_layer * out["block"]
    out["ln_f"] = cfg.d_model + int(cfg.ln_bias) * cfg.d_model  # final layernorm
    if cfg.share_embd_params:
        # 0 because of parameter sharing. This layer uses the weights from the embedding layer
        out["out_embedding"] = 0
    else:
        out["out_embedding"] = cfg.d_model * cfg.vocab_size

    # total
    out["total"] = (
        out["embedding"] + out["transformer"] + out["ln_f"] + out["out_embedding"]
    )

    return out


# compare our param count to that reported by PyTorch (for "GPT2" with 124M params)
# TODO update the PyTorch value to include bias
model_params = get_params(model_cfg)
params_total = model_params["total"]
if model_cfg.model_name == "gpt2":
    expected_params = 124402944
    print(
        f"we see: {params_total}, expected: {expected_params}, match: {params_total == expected_params}"
    )
# create a header
print(f"{'name':20s} {'params':10s} {'ratio (%)':10s}")
for k, v in model_params.items():
    print(f"{k:20s} {v:10d} {v / params_total * 100:10.4f}")

we see: 124402944, expected: 124402944, match: True
name                 params     ratio (%) 
embedding/position       786432     0.6322
embedding/token        38597376    31.0261
embedding              39383808    31.6583
attention/ln               1536     0.0012
attention/kqv           1769472     1.4224
attention/proj           589824     0.4741
attention               2360832     1.8977
mlp/ln                     1536     0.0012
mlp/ffw                 2362368     1.8990
mlp/proj                2360064     1.8971
mlp                     4723968     3.7973
block                   7084800     5.6950
transformer            85017600    68.3405
ln_f                       1536     0.0012
out_embedding                 0     0.0000
total                 124402944   100.0000
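As a cross-check of `get_params`, the same count can be written in closed form. The sketch below is a hypothetical helper (`gpt2_params_closed_form` is not part of gollem). It assumes the layout that `get_params` counts for this config: LayerNorms and MLP linears with biases, attention projections without biases, `d_mlp = 4 * d_model`, and a tied output embedding. The GPT-2 small values (`n_layer=12`, `d_model=768`, `n_ctx=1024`, `vocab_size=50257`) are read off the table above.

```python
def gpt2_params_closed_form(n_layer, d_model, n_ctx, vocab_size, d_mlp=None):
    """Closed-form parameter count mirroring get_params (hypothetical helper).

    Assumes LayerNorms with weight + bias, MLP linears with bias, attention
    projections without bias, and an output embedding tied to the token embedding.
    """
    d_mlp = 4 * d_model if d_mlp is None else d_mlp
    embedding = (n_ctx + vocab_size) * d_model               # position + token
    attention = 2 * d_model + 4 * d_model**2                 # ln + kqv (3d^2) + proj (d^2)
    mlp = 2 * d_model + 2 * d_model * d_mlp + d_mlp + d_model  # ln + ffw + proj, with biases
    ln_f = 2 * d_model                                       # final LayerNorm
    return embedding + n_layer * (attention + mlp) + ln_f


total = gpt2_params_closed_form(n_layer=12, d_model=768, n_ctx=1024, vocab_size=50257)
# the familiar 12 * L * d^2 rule of thumb only covers the block weight matrices
approx_blocks = 12 * 12 * 768**2
print(total, approx_blocks)  # 124402944 84934656
```

The common `12 * L * d_model**2` rule of thumb captures the transformer blocks almost exactly (84.9M vs 85.0M) but misses the ~39M embedding parameters, which are a large share of a model this small.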


## Parameter/Checkpoint Size

In [3]:
# we can now calculate the size of each checkpoint
# params are stored in fp32 (i.e. 4 bytes), and the AdamW optimizer has 2 additional buffers per param for statistics
params_bytes = params_total * 4
params_and_buffers_bytes = params_bytes + 2 * params_bytes
print(f"est checkpoint size: {params_and_buffers_bytes / 1e9:.2f} GB")
# TODO update this with actual measured bytes
measured_bytes = 1542470366  # from wc -c ckpt.pt
print(f"measured with wc -c ckpt.pt: {measured_bytes}")
print(f"fluff ratio: {measured_bytes / params_and_buffers_bytes * 100:.2f}%")

est checkpoint size: 1.49 GB
measured with wc -c ckpt.pt: 1542470366
fluff ratio: 103.32%


## GPU memory usage

We can estimate the ratio of our GPU memory that will be taken up just by the weights and the buffers inside the AdamW optimizer.

In [4]:
# Nvidia reports memory in GiB (denoted as GB but is actually GiB)
# 1 GiB = 1024 MiB = 1024**2 KiB = 1024**3 Bytes
GPU_MEMORY = {
    "H100": 80 * 1024**3,  # 80 GiB
    "A100": 40 * 1024**3,  # 40 GiB
    "RTX3090": 24 * 1024**3,  # 24 GiB
    "RTX4090": 24 * 1024**3,  # 24 GiB
    "RTX2070": 8 * 1024**3,  # 8 GiB
}

print("GPU memory ratio taken up just for parameters (incl. optimizer)")
print(f"{'GPU':12s} {'ratio (%)':8s}")
for k, v in GPU_MEMORY.items():
    print(f"{k:12s} {params_and_buffers_bytes / v * 100:8.2f}")


GPU memory ratio taken up just for parameters (incl. optimizer)
GPU          ratio (%)
H100             1.74
A100             3.48
RTX3090          5.79
RTX4090          5.79
RTX2070         17.38


## GPU forward-backward memory usage

We can estimate the total memory usage of the model over the course of a forward-backward pass.

This is made up of:

1. M_model - the model parameters
2. M_optimizer - the optimizer buffers
3. M_gradient - the gradient of the model
4. M_activations - the activations of the model

The activations scale with batch size and quickly become the largest component of memory usage as the batch grows, since the other three components are fixed for a given model.

Calculating the activation memory usage is tricky because it depends on low-level implementation details that are hard to estimate exactly. It is also affected by choices such as flash attention, dropout, and activation checkpointing.
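To see how much flash attention and dropout change the picture, here is a sketch of the per-layer activation accounting from Korthikanti et al. (2022), "Reducing Activation Recomputation in Large Transformer Models", which gives 34·T·B·H + 5·A·T²·B bytes per layer for 2-byte activations with dropout. The function name and the simplification that flash attention removes the quadratic term entirely (ignoring its small softmax statistics) are assumptions of this sketch.

```python
def layer_activation_bytes(T, B, H, A, dropout=True, flash_attn=False):
    """Approximate activation bytes stored per transformer layer (sketch).

    Assumes 2-byte (bf16/fp16) activations and the per-layer accounting
    34*T*B*H + 5*A*T^2*B from Korthikanti et al. (2022), where
    T = sequence length, B = batch size, H = hidden size, A = number of heads.
    """
    linear = 34 if dropout else 32     # dropout adds two 1-byte masks of size T*B*H
    quadratic = 5 if dropout else 4    # ... and one 1-byte mask of size A*T^2*B
    if flash_attn:
        # the (A, T, T) score/softmax tensors are never materialized;
        # the small per-row softmax statistics are ignored here
        quadratic = 0
    return linear * T * B * H + quadratic * A * T**2 * B


T, B, H, A = 1024, 1, 768, 12  # GPT-2 small, a single sequence
for dropout in (True, False):
    for flash in (False, True):
        mib = layer_activation_bytes(T, B, H, A, dropout, flash) / 1024**2
        print(f"dropout={dropout!s:5} flash_attn={flash!s:5} {mib:7.1f} MiB/layer")
```

For GPT-2 small at T=1024 the quadratic attention term is already about twice the linear term, which is why flash attention makes such a difference to the activation estimates.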

Another source of memory usage is the CUDA and PyTorch overhead, which is typically 0.5-2 GiB and depends on the setup of the system it runs on. We don't include this in the calculations, so factor it in when comparing calculated vs empirical memory usage.

NOTE: these are estimates and should be treated as ballpark figures rather than exact values.

In [6]:
def compute_activations(cfg: GPT2Config, B: int, dtype: str, using_flash_attn: bool):
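    # Estimated bytes of activations stored during a forward pass with batch size B.
    # With 2-byte activations and no flash attention this totals
    # 8T + 8BT + L * (32TBH + 4AT^2B) + 4TBH  (T = n_ctx, H = d_model, A = n_head, L = n_layer),
    # i.e. the 34TBH + 5AT^2B per layer of Korthikanti et al. (2022) minus the dropout masks.
    # With flash attention the 4AT^2B term is replaced by 4TB (the softmax statistics l, m).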
    out = OrderedDict()

    bytes_per_activation = 2 if dtype in ["bfloat16", "float16"] else 4
    bytes_per_long = 8

    # token and position embeddings
    out["embedding/position"] = cfg.n_ctx * bytes_per_long
    out["embedding/token"] = cfg.n_ctx * B * bytes_per_long
    out["embedding"] = out["embedding/position"] + out["embedding/token"]

    TBH = cfg.n_ctx * B * cfg.d_model

    # attention blocks
    out["attention/ln"] = TBH * bytes_per_activation
    out["attention/kqv"] = TBH * bytes_per_activation
    if using_flash_attn:
        # when using flash attention a bunch of optimizations are done
        out["attention/qk_matmul"] = 0
        out["attention/softmax"] = 0
        # flash attention requires K, Q, V as well as two vectors l, m of length T (size BT)
        out["attention/attention_over_v"] = (
            TBH * 3 + 2 * cfg.n_ctx * B
        ) * bytes_per_activation
    else:
        out["attention/qk_matmul"] = TBH * 2 * bytes_per_activation
        out["attention/softmax"] = cfg.n_head * cfg.n_ctx**2 * B * bytes_per_activation
        out["attention/attention_over_v"] = (
            TBH + cfg.n_head * cfg.n_ctx**2 * B
        ) * bytes_per_activation
    out["attention/proj"] = TBH * bytes_per_activation
    out["attention"] = (
        out["attention/ln"]
        + out["attention/kqv"]
        + out["attention/qk_matmul"]
        + out["attention/softmax"]
        + out["attention/attention_over_v"]
        + out["attention/proj"]
    )

    # MLP blocks
    out["mlp/ln"] = TBH * bytes_per_activation
    out["mlp/ffw"] = TBH * bytes_per_activation
    out["mlp/ffw_activation"] = cfg.n_ctx * B * cfg.d_mlp * bytes_per_activation
    out["mlp/proj"] = cfg.n_ctx * B * cfg.d_mlp * bytes_per_activation
    out["mlp"] = (
        out["mlp/ln"] + out["mlp/ffw"] + out["mlp/ffw_activation"] + out["mlp/proj"]
    )

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = cfg.n_layer * out["block"]

    # final layernorm and output projection
    out["ln_f"] = TBH * bytes_per_activation
    out["out_embedding"] = TBH * bytes_per_activation

    # total
    out["total"] = (
        out["embedding"] + out["transformer"] + out["ln_f"] + out["out_embedding"]
    )

    return out

In [7]:
dtype = "bfloat16"
using_flash_attn = True
# dtype = "float32"

# with bf16/fp16 mixed precision we assume both a half-precision (2 bytes)
# and a full-precision fp32 (4 bytes) copy of each parameter are stored;
# with fp32 we store only the fp32 parameters
bytes_per_param = 6 if dtype in ("bfloat16", "float16") else 4
M_model = params_total * bytes_per_param

# AdamW optimizer has 2 buffers per parameter
# values are stored in fp32
M_optimizer = 2 * params_total * 4

# we store one gradient value per parameter;
# gradients are stored in fp32
M_gradient = params_total * 4


divisor = 1024**3

print(f"M_model: {M_model / divisor:.2f} GiB")
print(f"M_optimizer: {M_optimizer / divisor:.2f} GiB")
print(f"M_gradient: {M_gradient / divisor:.2f} GiB")


# Empirical peak memory usage for gpt2 (124M) using flash attention.
# Each run used 1024-token sequences, e.g. batch size 4 = 4 x 1024 = 4096 tokens.
# Map from batch size (in sequences) to peak memory usage in MiB
empirical_peak_memory_usage = {
    1: 2269,
    2: 2870,
    4: 3918,
    8: 6012,
    16: 10206,
    32: 18595,
}

# M_activations depends on the batch size; diff = empirical - estimated total (GiB)
print(f"{'batch size':10s} {'M_activations':13s} {'total memory':12s} {'  diff':6s}")
for batch_size in [1, 2, 4, 8, 16, 32]:
    M_activations = compute_activations(model_cfg, batch_size, dtype, using_flash_attn)
    total_memory = M_model + M_optimizer + M_activations["total"] + M_gradient
    total_memory_GiB = total_memory / divisor
    diff = (empirical_peak_memory_usage[batch_size] / 1024) - total_memory_GiB
    print(
        f"{batch_size:10d} {M_activations['total'] / divisor:13.2f} {total_memory_GiB:12.2f} {diff:6.2f}"
    )


M_model: 0.70 GiB
M_optimizer: 0.93 GiB
M_gradient: 0.46 GiB
batch size M_activations total memory   diff
         1          0.28         2.37  -0.15
         2          0.57         2.65   0.15
         4          1.14         3.22   0.60
         8          2.27         4.36   1.51
        16          4.55         6.63   3.33
        32          9.10        11.18   6.98
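The diff column grows roughly linearly with batch size (about 0.23 GiB per extra sequence between batch sizes 16 and 32), which suggests that something scaling with B is missing from the estimate. One plausible candidate, not counted in `compute_activations`, is the logits tensor of shape (B, T, vocab_size) produced by the output embedding; depending on the implementation, the loss computation may also keep an fp32 copy. A rough sketch of its size, assuming the GPT-2 small values used above (the constants and `logits_bytes` are illustrative, not part of gollem):

```python
# GPT-2 small values used throughout this notebook
N_CTX = 1024
VOCAB_SIZE = 50257
GiB = 1024**3


def logits_bytes(batch_size, bytes_per_elem):
    """Size in bytes of a (batch_size, N_CTX, VOCAB_SIZE) logits tensor."""
    return batch_size * N_CTX * VOCAB_SIZE * bytes_per_elem


print(f"{'batch size':10s} {'bf16 (GiB)':>11s} {'fp32 (GiB)':>11s}")
for batch_size in [1, 2, 4, 8, 16, 32]:
    bf16 = logits_bytes(batch_size, 2) / GiB
    fp32 = logits_bytes(batch_size, 4) / GiB
    print(f"{batch_size:10d} {bf16:11.2f} {fp32:11.2f}")
```

At roughly 0.1 GiB per sequence in bf16 (0.2 GiB in fp32), the logits are comparable in magnitude to the unexplained slope, though this sketch does not prove they account for all of it.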


## FLOPS

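Here we estimate the FLOPs for a single sequence of `n_ctx` tokens: the forward pass is counted directly, and the backward pass is approximated as twice the forward pass.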
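`compute_flops` relies on the rule that an (m x k) @ (k x n) product costs 2·m·k·n FLOPs, i.e. one multiply and one add per multiply-accumulate. A minimal pure-Python sketch that counts the operations in a naive matmul makes the factor of 2 concrete; both helper names are illustrative only.

```python
def matmul_flops(m, k, n):
    """FLOPs for (m x k) @ (k x n): m*n dot products of length k, each k multiplies + k adds."""
    return 2 * m * k * n


def naive_matmul_counting_flops(A, B):
    """Plain triple-loop matmul that also counts the floating point ops it performs."""
    m, k, n = len(A), len(B), len(B[0])
    C = [[0.0] * n for _ in range(m)]
    flops = 0
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                acc += A[i][p] * B[p][j]
                flops += 2  # one multiply and one add (one MAC)
            C[i][j] = acc
    return C, flops


A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]    # 2 x 3
B = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 x 2
C, counted = naive_matmul_counting_flops(A, B)
assert counted == matmul_flops(2, 3, 2) == 24

# e.g. the GPT-2 small key/query/value projection of one 1024-token sequence
print(matmul_flops(1024, 768, 3 * 768))  # 3623878656, the attention/kqv row below
```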

In [8]:
def compute_flops(cfg: GPT2Config):
    # we only count matrix-multiply FLOPs (the weight matmuls plus the two attention matmuls);
    # FLOPs for all other layers (LayerNorm, Softmax, etc.) and bias vector additions are negligible.
    # we count actual FLOPs, not MACs, hence the factor of 2 everywhere:
    # for any matrix multiply A (BxC) @ B (CxD) -> (BxD) the FLOPs are 2*B*C*D

    out = OrderedDict()
    head_size = cfg.d_model // cfg.n_head

    # attention blocks
    # 1) the projection to key, query, values
    out["attention/kqv"] = 2 * cfg.n_ctx * (cfg.d_model * 3 * cfg.d_model)
    # 2) calculating the attention scores
    out["attention/scores"] = 2 * cfg.n_ctx * cfg.n_ctx * cfg.d_model
    # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    out["attention/reduce"] = 2 * cfg.n_head * (cfg.n_ctx * cfg.n_ctx * head_size)
    # 4) the final linear projection
    out["attention/proj"] = 2 * cfg.n_ctx * (cfg.d_model * cfg.d_model)
    out["attention"] = sum(
        out["attention/" + k] for k in ["kqv", "scores", "reduce", "proj"]
    )

    # MLP blocks
    out["mlp/ffw1"] = 2 * cfg.n_ctx * (cfg.d_model * cfg.d_mlp)
    out["mlp/ffw2"] = 2 * cfg.n_ctx * (cfg.d_mlp * cfg.d_model)
    out["mlp"] = out["mlp/ffw1"] + out["mlp/ffw2"]

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = cfg.n_layer * out["block"]
    out["out_embedding"] = 2 * cfg.n_ctx * (cfg.d_model * cfg.vocab_size)

    # forward,backward,total
    out["forward_total"] = out["transformer"] + out["out_embedding"]
    out["backward_total"] = (
        2 * out["forward_total"]
    )  # use common estimate of bwd = 2*fwd
    out["total"] = out["forward_total"] + out["backward_total"]

    return out


# compute per-sequence FLOPs and show each entry relative to the forward pass
model_flops = compute_flops(model_cfg)
flops_total = model_flops["forward_total"]
print(f"{'name':20s} {'flops':14s} {'ratio (%)':10s}")
for k, v in model_flops.items():
    print(f"{k:20s} {v:14d} {v / flops_total * 100:10.4f}")

name                 flops          ratio (%) 
attention/kqv            3623878656     1.2426
attention/scores         1610612736     0.5522
attention/reduce         1610612736     0.5522
attention/proj           1207959552     0.4142
attention                8053063680     2.7612
mlp/ffw1                 4831838208     1.6567
mlp/ffw2                 4831838208     1.6567
mlp                      9663676416     3.3135
block                   17716740096     6.0747
transformer            212600881152    72.8963
out_embedding           79047426048    27.1037
forward_total          291648307200   100.0000
backward_total         583296614400   200.0000
total                  874944921600   300.0000


In [10]:
# now here is an estimate copy pasted from the PaLM paper
# this formula is often used to calculate MFU (model flops utilization)
def compute_palm_flops(cfg: GPT2Config):
    """estimate of the model flops following PaLM paper formula"""
    # non-embedding model parameters. note that we do not subtract the
    # embedding/token params because those are tied and get used in the last layer.
    model_params = get_params(cfg)
    N = model_params["total"] - model_params["embedding/position"]
    L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.d_model // cfg.n_head, cfg.n_ctx
    mf_per_token = 6 * N + 12 * L * H * Q * T
    mf = mf_per_token * cfg.n_ctx
    return mf


palm_flops = compute_palm_flops(model_cfg)
print(
    f"palm_flops: {palm_flops:d}, flops: {model_flops['total']:d}, ratio: {palm_flops / model_flops['total']:.4f}"
)

palm_flops: 875463966720, flops: 874944921600, ratio: 1.0006
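The PaLM formula can be matched to `compute_flops` term by term. A sketch of the reasoning, using the symbols from `compute_palm_flops` (with $d = HQ$ the model width) and the same "backward ≈ 2 × forward" convention:

$$
\begin{aligned}
\text{weight matmuls:}\quad & 2N \ (\text{fwd}) + 4N \ (\text{bwd}) = 6N \\
\text{attention scores and reduce per layer (fwd):}\quad & 2Td + 2Td = 4Td \\
\text{total per token:}\quad & 6N + 3 \cdot L \cdot 4Td = 6N + 12\,L\,H\,Q\,T
\end{aligned}
$$

For GPT-2 small this is $6 \cdot 123{,}616{,}512 + 12 \cdot 12 \cdot 12 \cdot 64 \cdot 1024 = 854{,}945{,}280$ FLOPs per token. The 0.06% gap to `compute_flops` (854,438,400 per token) is exactly $6 \times 84{,}480$, the LayerNorm and bias parameters that $N$ includes but `compute_flops` ignores.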


## GPU Flops Usage

Given our FLOP estimate, we can calculate what fraction of the GPU's peak FLOP throughput we actually achieve, i.e. our model FLOPs utilization (MFU).

To calculate this we need a few bits of information:

- GPU speed (FLOPS)
- batch size (including gradient accumulation)
- time per iteration

For the GPU speed we refer to https://www.techpowerup.com/gpu-specs/, which lists theoretical performance; we use the BF16 or FP16 figure, whichever is supported and faster.
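The inputs listed above combine into a single ratio: achieved model FLOP/s divided by peak FLOP/s. A minimal sketch with a hypothetical helper name; the `num_gpus` argument is an addition not present in the original cell, meant for data-parallel runs where `seqs_per_iter` counts sequences across all GPUs. The example numbers are the ones used in the next cell.

```python
def model_flops_utilization(flops_per_seq, seqs_per_iter, seconds_per_iter, peak_flops, num_gpus=1):
    """MFU = achieved model FLOP/s divided by the peak FLOP/s of the hardware."""
    achieved = flops_per_seq * seqs_per_iter / seconds_per_iter
    return achieved / (peak_flops * num_gpus)


# GPT-2 small fwd+bwd FLOPs per 1024-token sequence (the `total` row above),
# 20 sequences x 5 grad-accum steps in 0.755 s, against 312 TFLOPS (A100 BF16)
mfu = model_flops_utilization(874_944_921_600, 20 * 5, 0.755, 312e12)
print(f"MFU: {mfu:.2%}")
```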

In [14]:
# theoretical peak BF16/FP16 throughput in FLOP/s
GPU_FLOPS = {
    "H100": 756e12,  # 756 TFLOPS BF16, assumed dense (half the spec sheet's 1513 TFLOPS sparse figure)
    "A100": 312e12,  # 312 TFLOPS BF16
    "RTX4090": 83e12,  # 83 TFLOPS FP16
    "RTX2070": 15e12,  # 15 TFLOPS FP16
}

# TODO Change these values to desired values
batch_size = 20
grad_accum = 5
measured_time = 0.755  # in seconds per iteration

# calculate flops achieved
total_batch_size = batch_size * grad_accum
measured_throughput = total_batch_size / measured_time
flops_achieved = model_flops["total"] * measured_throughput

print(f"fwd+bwd FLOPs per token: {model_flops['total'] // model_cfg.n_ctx}")
print(f"measured throughput: {measured_throughput:.2f} sequences/s")

# fraction of each GPU's peak FLOPS that this throughput represents;
# values above 100% mean the measured time could not have been achieved on that GPU
print("Fraction of GPU FLOPS used")
print(f"{'GPU':14s} {'ratio (%)':10s}")
for k, v in GPU_FLOPS.items():
    print(f"{k:14s} {flops_achieved / v * 100:10.2f}")


fwd+bwd FLOPs per token: 854438400
measured throughput: 132.45 sequences/s
Fraction of GPU FLOPS used
GPU            ratio (%) 
H100                15.33
A100                37.14
RTX4090            139.62
RTX2070            772.58


## Total Training Compute

Here we use the values computed so far to estimate the total amount of compute `C` required to train the model as `C ~= 6*N*D`, where:

- `6` is a heuristic value (see [Dzmitry Bahdanau's post](https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4) for an explanation)
    - it stems from each weight requiring about 2 FLOPs per token in the forward pass (one multiply and one add) and about 4 in the backward pass (gradients with respect to both the activations and the weights), i.e. 6 FLOPs per token per weight in total.
- `N` is the total number of model parameters
- `D` is the total size of the dataset (in tokens)

Note the equation changes to `C ~= 8*N*D` with activation checkpointing (i.e. recomputing activations during back-prop instead of storing them, to save memory), since the forward pass is effectively run twice.

We also need to factor in the model FLOPs utilization (MFU) to account for the fact that we cannot use 100% of the GPU's peak FLOPS due to memory bottlenecks, communication overhead, etc. (again see Dzmitry's post for examples).
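A sketch that wraps the `C ~= 6*N*D` estimate and the `8*N*D` activation-checkpointing variant into one function. `training_days` is a hypothetical helper, and the inputs (300B tokens, 8 A100s at 312 TFLOPS, 30% MFU) mirror the defaults in the cell below.

```python
SECONDS_PER_DAY = 24 * 3600


def training_days(n_params, n_tokens, peak_flops, num_gpus, mfu, activation_checkpointing=False):
    """Wall-clock days to train, using C ~= 6*N*D (or 8*N*D with activation checkpointing)."""
    flops_per_param_token = 8 if activation_checkpointing else 6
    total_flops = flops_per_param_token * n_params * n_tokens
    achieved_flops_per_s = peak_flops * num_gpus * mfu  # FLOP/s actually delivered
    return total_flops / achieved_flops_per_s / SECONDS_PER_DAY


# GPT-2 small on 300B tokens with 8 A100s (312 TFLOPS BF16) at 30% MFU
for ckpt in (False, True):
    days = training_days(124_402_944, 300e9, 312e12, num_gpus=8, mfu=0.3, activation_checkpointing=ckpt)
    print(f"activation checkpointing={ckpt}: {days:.2f} days")
```

With these inputs activation checkpointing adds a third to the wall-clock time (about 3.5 vs 4.6 days), in exchange for lower activation memory.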


In [None]:
# Finally let's check out the 6ND approximation as total cost of training in FLOPs
model_size = get_params(model_cfg)["total"]  # this is number of parameters, N

# TODO change these parameters
tokens_num = 300e9  # 300B tokens, this is dataset size in tokens, D
assumed_mfu = 0.3  # assume this model flops utilization (take the current 37% from above and add some DDP overhead)
num_gpus = 8  # number of GPUS used in parallel

flops_needed = 6 * model_size * tokens_num  # 6ND, the same for every GPU

print("Time needed to train model on different GPUS")
print(f"{'GPU':10s} {'time (days)':>11s}")
for gpu_name, gpu_flops in GPU_FLOPS.items():
    flops_throughput = gpu_flops * num_gpus * assumed_mfu  # achieved FLOP/s across all GPUs
    time_needed_s = flops_needed / flops_throughput  # in seconds
    print(f"{gpu_name:10s} {time_needed_s / 3600 / 24:11.2f}")
