# Transformer Sizing

This notebook runs a set of analyses on a GPT-2 Transformer, e.g. number of FLOPs, parameters, peak memory footprint, checkpoint size, etc.

**Reference**
- This notebook is based directly on Karpathy's [nanoGPT/transformer_sizing.ipynb](https://github.com/karpathy/nanoGPT/blob/master/transformer_sizing.ipynb)

In [119]:
from dataclasses import dataclass
from collections import OrderedDict


@dataclass
class ModelConfig:
    name: str
    n_ctx: int = 1024
    n_layer: int = 12
    n_head: int = 12
    d_model: int = 768
    d_mlp: int = 4 * 768
    vocab_size: int = 50257
    ln_bias: bool = False
    mlp_bias: bool = False
    share_embd_params: bool = True


MODEL_CONFIG_ARGS = {
    # 14M params
    "gpt2-tiny": ModelConfig(
        name="gpt2-tiny",
        n_ctx=128,
        n_layer=2,
        n_head=4,
        d_model=256,
        d_mlp=4 * 256,
        vocab_size=50257,
        ln_bias=True,
        mlp_bias=True,
        share_embd_params=True,
    ),
    # 124M params
    "gpt2": ModelConfig(
        name="gpt2",
        n_ctx=1024,
        n_layer=12,
        n_head=12,
        d_model=768,
        d_mlp=4 * 768,
        vocab_size=50257,
        ln_bias=True,
        mlp_bias=True,
        share_embd_params=True,
    ),
    # 350M params
    "gpt2-medium": ModelConfig(
        name="gpt2-medium",
        n_ctx=1024,
        n_layer=24,
        n_head=16,
        d_model=1024,
        d_mlp=4 * 1024,
        vocab_size=50257,
        ln_bias=True,
        mlp_bias=True,
        share_embd_params=True,
    ),
    # 774M params
    "gpt2-large": ModelConfig(
        name="gpt2-large",
        n_ctx=1024,
        n_layer=36,
        n_head=20,
        d_model=1280,
        d_mlp=4 * 1280,
        vocab_size=50257,
        ln_bias=True,
        mlp_bias=True,
        share_embd_params=True,
    ),
    # 1558M params
    "gpt2-xl": ModelConfig(
        name="gpt2-xl",
        n_ctx=1024,
        n_layer=48,
        n_head=25,
        d_model=1600,
        d_mlp=4 * 1600,
        vocab_size=50257,
        ln_bias=True,
        mlp_bias=True,
        share_embd_params=True,
    ),
}


def load_config(name: str) -> ModelConfig:
    assert name in MODEL_CONFIG_ARGS, f"unknown model config: {name}"
    return MODEL_CONFIG_ARGS[name]


# TODO Change this as desired
model_cfg = load_config("gpt2-xl")

## Parameter Count

In [120]:
def get_params(cfg: ModelConfig):
    """estimates the number of parameters in the model"""
    out = OrderedDict()

    # token and position embeddings
    out["embedding/position"] = cfg.n_ctx * cfg.d_model
    out["embedding/token"] = cfg.vocab_size * cfg.d_model
    out["embedding"] = out["embedding/position"] + out["embedding/token"]

    # attention blocks (biases of the kqv and output projections are not counted)
    out["attention/ln"] = cfg.d_model + int(cfg.ln_bias) * cfg.d_model
    out["attention/kqv"] = cfg.d_model * 3 * cfg.d_model
    out["attention/proj"] = cfg.d_model**2
    out["attention"] = (
        out["attention/ln"] + out["attention/kqv"] + out["attention/proj"]
    )

    # MLP blocks (linear-layer biases are included when mlp_bias is set)
    out["mlp/ln"] = cfg.d_model + int(cfg.ln_bias) * cfg.d_model
    out["mlp/ffw"] = cfg.d_model * cfg.d_mlp + int(cfg.mlp_bias) * cfg.d_mlp
    out["mlp/proj"] = cfg.d_mlp * cfg.d_model + int(cfg.mlp_bias) * cfg.d_model
    out["mlp"] = out["mlp/ln"] + out["mlp/ffw"] + out["mlp/proj"]

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = cfg.n_layer * out["block"]
    out["ln_f"] = cfg.d_model + int(cfg.ln_bias) * cfg.d_model  # final layernorm
    if cfg.share_embd_params:
        # 0 because of parameter sharing. This layer uses the weights from the embedding layer
        out["out_embedding"] = 0
    else:
        out["out_embedding"] = cfg.d_model * cfg.vocab_size

    # total
    out["total"] = (
        out["embedding"] + out["transformer"] + out["ln_f"] + out["out_embedding"]
    )

    return out


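# compare our param count to that reported by PyTorch (for "gpt2" with 124M params)
# NOTE: 124337664 is the count without any biases; with ln_bias and mlp_bias enabled
# get_params returns 124402944, and attention biases are not counted at all
# TODO update the PyTorch value to include bias
expected_params = 124337664
model_params = get_params(model_cfg)
params_total = model_params["total"]
if model_cfg.name == "gpt2":
    print(
        f"we see: {params_total}, expected: {expected_params}, match: {params_total == expected_params}"
    )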
# create a header
print(f"{'name':20s} {'params':10s} {'ratio (%)':10s}")
for k, v in model_params.items():
    print(f"{k:20s} {v:10d} {v / params_total * 100:10.4f}")

name                 params     ratio (%) 
embedding/position      1638400     0.1052
embedding/token        80411200     5.1635
embedding              82049600     5.2687
attention/ln               3200     0.0002
attention/kqv           7680000     0.4932
attention/proj          2560000     0.1644
attention              10243200     0.6578
mlp/ln                     3200     0.0002
mlp/ffw                10246400     0.6580
mlp/proj               10241600     0.6576
mlp                    20491200     1.3158
block                  30734400     1.9736
transformer          1475251200    94.7311
ln_f                       3200     0.0002
out_embedding                 0     0.0000
total                1557304000   100.0000
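As a quick cross-check of the table above: with `d_mlp = 4 * d_model`, each block holds roughly `12 * d_model**2` weights (`4 * d_model**2` in attention, `8 * d_model**2` in the MLP), so the total is approximately `12 * n_layer * d_model**2` plus the embeddings. A minimal sketch, assuming tied embeddings and ignoring biases and LayerNorm parameters (the helper name `approx_params` is hypothetical):

```python
def approx_params(n_layer: int, d_model: int, n_ctx: int, vocab_size: int) -> int:
    """closed-form parameter estimate for a GPT-2 style model.

    Assumes d_mlp = 4 * d_model and tied input/output embeddings;
    ignores biases and LayerNorm parameters.
    """
    # attention: 3 * d^2 (kqv) + d^2 (proj); MLP: 4 * d^2 (ffw) + 4 * d^2 (proj)
    blocks = 12 * n_layer * d_model**2
    # token + position embeddings
    embeddings = vocab_size * d_model + n_ctx * d_model
    return blocks + embeddings


approx = approx_params(n_layer=48, d_model=1600, n_ctx=1024, vocab_size=50257)
print(approx)  # 1556609600
print(f"{approx / 1557304000:.4f}")  # 0.9996
```

For gpt2-xl the gap of 694,400 parameters is exactly the LayerNorm weights/biases and MLP biases that the closed form leaves out.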


## Parameter/Checkpoint Size

In [121]:
# we can now calculate the size of each checkpoint
# params are stored in fp32 (i.e. 4 bytes), and the AdamW optimizer has 2 additional buffers per param for statistics
params_bytes = params_total * 4
params_and_buffers_bytes = params_bytes + 2 * params_bytes
print(f"est checkpoint size: {params_and_buffers_bytes / 1e9:.2f} GB")
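# TODO update this with actual measured bytes
# NOTE: this value appears to come from a gpt2 (124M) checkpoint (1.54 GB / 12 bytes per
# param is about 128M params), so the fluff ratio is only meaningful when model_cfg is "gpt2"
measured_bytes = 1542470366  # from wc -c ckpt.pt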
print(f"measured with wc -c ckpt.pt: {measured_bytes}")
print(f"fluff ratio: {measured_bytes / params_and_buffers_bytes * 100:.2f}%")

est checkpoint size: 18.69 GB
measured with wc -c ckpt.pt: 1542470366
fluff ratio: 8.25%
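The `measured_bytes` value can also be read from Python instead of `wc -c`. A minimal sketch, assuming the same 3 fp32 values per parameter (weights plus two AdamW buffers) as the estimate above; `fluff_ratio` is a hypothetical helper:

```python
import os


def fluff_ratio(
    ckpt_path: str, n_params: int, values_per_param: int = 3, bytes_per_value: int = 4
) -> float:
    """measured checkpoint size divided by the estimated size.

    The estimate assumes each parameter is stored together with the two AdamW
    buffers (values_per_param=3), all in fp32 (bytes_per_value=4).
    """
    measured = os.path.getsize(ckpt_path)  # same byte count that `wc -c` reports
    estimated = n_params * values_per_param * bytes_per_value
    return measured / estimated
```

Usage: `fluff_ratio("ckpt.pt", params_total)`, where `ckpt.pt` is a hypothetical path. Values far from 100% suggest that the file and the estimate describe different models or storage formats.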


## GPU memory usage

We can estimate the ratio of our GPU memory that will be taken up just by the weights and the buffers inside the AdamW optimizer.

In [122]:
# Nvidia reports memory in GiB (denoted as GB but is actually GiB)
# 1 GiB = 1024 MiB = 1024**2 KiB = 1024**3 Bytes
GPU_MEMORY = {
    "H100": 80 * 1024**3,  # 80 GiB
    "A100": 40 * 1024**3,  # 40 GiB
    "RTX3090": 24 * 1024**3,  # 24 GiB
    "RTX4090": 24 * 1024**3,  # 24 GiB
    "RTX2070": 8 * 1024**3,  # 8 GiB
}

print("GPU memory ratio taken up just for parameters (incl. optimizer)")
print(f"{'GPU':12s} {'ratio (%)':8s}")
for k, v in GPU_MEMORY.items():
    print(f"{k:12s} {params_and_buffers_bytes / v * 100:8.2f}")


GPU memory ratio taken up just for parameters (incl. optimizer)
GPU          ratio (%)
H100            21.76
A100            43.51
RTX3090         72.52
RTX4090         72.52
RTX2070        217.55


## GPU forward-backward memory usage

We can estimate the total memory usage of the model over the course of a forward-backward pass.

This is made up of:

1. M_model - the model parameters (in fp32, plus a 16-bit copy when training with mixed precision)
2. M_optimizer - the optimizer buffers
3. M_gradient - the gradient of the model
4. M_activations - the activations of the model
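Of these four terms, only M_activations depends on the batch size; the other three are a fixed number of bytes per parameter. A minimal sketch of that per-parameter budget, assuming the storage choices used later in this section (fp32 weights plus a 16-bit copy under mixed precision, two fp32 AdamW buffers, fp32 gradients); `static_training_bytes` is a hypothetical helper:

```python
def static_training_bytes(n_params: int, mixed_precision: bool) -> dict:
    """bytes for M_model, M_optimizer and M_gradient (everything except activations)."""
    bytes_model = 6 if mixed_precision else 4  # fp32 weights (+ 16-bit copy under mixed precision)
    bytes_optimizer = 2 * 4  # AdamW first and second moments in fp32
    bytes_gradient = 4  # one fp32 gradient per parameter
    out = {
        "M_model": n_params * bytes_model,
        "M_optimizer": n_params * bytes_optimizer,
        "M_gradient": n_params * bytes_gradient,
    }
    out["total"] = sum(out.values())
    return out


mem = static_training_bytes(1_557_304_000, mixed_precision=True)  # gpt2-xl
print(f"{mem['total'] / 1024**3:.2f} GiB")  # 26.11 GiB
```

Under these assumptions this is 18 bytes per parameter with mixed precision and 16 bytes in pure fp32, before any activations are stored.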

The activations scale with batch size (and sequence length) and are typically the largest component of memory usage, whereas the other three terms are fixed once the model size is fixed.

Calculating the activation memory usage is quite tricky since it is affected by low level optimizations that can be hard to estimate exactly. It is also affected by things like flash attention, dropout, activation checkpointing, etc.

Another source of memory usage is the CUDA and PyTorch overhead, which is typically 0.5-2 GiB and depends on the system setup. We don't include this in the calculations, so factor it in when comparing calculated vs empirical memory usage.

NOTE: these are estimates and should be taken as ballpark figures rather than exact values.

In [129]:
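def compute_activations(cfg: ModelConfig, B: int, dtype: str, using_flash_attn: bool):
    """estimates the activation memory (in bytes) kept for the backward pass"""
    # Notation: T = n_ctx, B = batch size, H = d_model, A = n_head, L = n_layer.
    # With 16-bit activations and no flash attention the breakdown below totals
    #   8T + 8TB + L * (32TBH + 4AT^2B) + 4TBH  bytes.
    # Korthikanti et al. (2022) give 34TBH + 5AT^2B bytes per layer; the difference is
    # the dropout masks, which are not counted here.
    # With flash attention the A*T^2*B terms are replaced by small softmax statistics.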
    out = OrderedDict()

    bytes_per_activation = 2 if dtype in ["bfloat16", "float16"] else 4
    bytes_per_long = 8

    # token and position embeddings
    out["embedding/position"] = cfg.n_ctx * bytes_per_long
    out["embedding/token"] = cfg.n_ctx * B * bytes_per_long
    out["embedding"] = out["embedding/position"] + out["embedding/token"]

    TBH = cfg.n_ctx * B * cfg.d_model

    # attention blocks
    out["attention/ln"] = TBH * bytes_per_activation
    out["attention/kqv"] = TBH * bytes_per_activation
    if using_flash_attn:
        # flash attention never materializes the T x T attention matrix,
        # so no attention scores or softmax outputs are stored
        out["attention/qk_matmul"] = 0
        out["attention/softmax"] = 0
        # flash attention keeps Q, K, V plus two softmax statistics l, m
        # for every (batch, head, position), i.e. 2 * A * T * B values
        out["attention/attention_over_v"] = (
            TBH * 3 + 2 * cfg.n_head * cfg.n_ctx * B
        ) * bytes_per_activation
    else:
        out["attention/qk_matmul"] = TBH * 2 * bytes_per_activation
        out["attention/softmax"] = cfg.n_head * cfg.n_ctx**2 * B * bytes_per_activation
        out["attention/attention_over_v"] = (
            TBH + cfg.n_head * cfg.n_ctx**2 * B
        ) * bytes_per_activation
    out["attention/proj"] = TBH * bytes_per_activation
    out["attention"] = (
        out["attention/ln"]
        + out["attention/kqv"]
        + out["attention/qk_matmul"]
        + out["attention/softmax"]
        + out["attention/attention_over_v"]
        + out["attention/proj"]
    )

    # MLP blocks
    out["mlp/ln"] = TBH * bytes_per_activation
    out["mlp/ffw"] = TBH * bytes_per_activation
    out["mlp/ffw_activation"] = cfg.n_ctx * B * cfg.d_mlp * bytes_per_activation
    out["mlp/proj"] = cfg.n_ctx * B * cfg.d_mlp * bytes_per_activation
    out["mlp"] = (
        out["mlp/ln"] + out["mlp/ffw"] + out["mlp/ffw_activation"] + out["mlp/proj"]
    )

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = cfg.n_layer * out["block"]

    # final layernorm and output projection
    out["ln_f"] = TBH * bytes_per_activation
    out["out_embedding"] = TBH * bytes_per_activation

    # total
    out["total"] = (
        out["embedding"] + out["transformer"] + out["ln_f"] + out["out_embedding"]
    )

    return out
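To see why flash attention matters at long context, here is the per-layer figure from Korthikanti et al. (2022) quoted in the comment at the top of `compute_activations` (34TBH + 5AT²B bytes with 16-bit activations), split into the part linear in T and the part quadratic in T. A minimal sketch; treating flash attention as simply removing the quadratic term is a simplifying assumption, and `layer_activation_bytes` is a hypothetical helper:

```python
def layer_activation_bytes(T: int, B: int, H: int, A: int, flash: bool = False) -> dict:
    """per-layer activation bytes with 16-bit activations (Korthikanti et al., 2022).

    T = sequence length, B = batch size, H = hidden size, A = number of heads.
    Assumes flash attention only removes the A*T^2*B attention-matrix term.
    """
    linear = 34 * T * B * H  # grows linearly with sequence length
    quadratic = 0 if flash else 5 * A * T**2 * B  # attention scores, softmax, dropout mask
    return {"linear": linear, "quadratic": quadratic, "total": linear + quadratic}


# gpt2 (124M) sizes: H=768, A=12
for T in (256, 1024, 4096):
    mem = layer_activation_bytes(T, B=1, H=768, A=12)
    print(T, f"{mem['quadratic'] / mem['total']:.0%} of the layer's activations are attention matrices")
```

For gpt2-sized layers the quadratic share grows from about 37% at T=256 to about 70% at T=1024 and about 90% at T=4096, which is the memory that flash attention avoids storing.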

In [124]:
dtype = "bfloat16"
using_flash_attn = True
# dtype = "float32"

# when using bf16 or fp16 we use mixed precision, so we have to store
# both 16-bit (bf16/fp16) and fp32 versions of the parameters;
# when using fp32 we store only the full-precision fp32 parameters
bytes_per_param = 6 if dtype in ("bfloat16", "float16") else 4
M_model = params_total * bytes_per_param

# AdamW optimizer has 2 buffers per parameter
# values are stored in fp32
M_optimizer = 2 * params_total * 4

# we store one gradient value per parameter
# gradients are stored in fp32
M_gradient = params_total * 4


divisor = 1024**3

print(f"M_model: {M_model / divisor:.2f} GiB")
print(f"M_optimizer: {M_optimizer / divisor:.2f} GiB")
print(f"M_gradient: {M_gradient / divisor:.2f} GiB")


# Empirical peak memory usage (MiB) for gpt2 (124M) using flash attention,
# keyed by batch size; each sequence is 1024 tokens, so B=1 is 1024 tokens
# per batch and B=32 is 32768.
# NOTE: these were measured for gpt2, so the `diff` column below is only
# meaningful when model_cfg is "gpt2"
empirical_peak_memory_usage = {
    1: 2269,
    2: 2870,
    4: 3918,
    8: 6012,
    16: 10206,
    32: 18595,
}

# M_activations is more complex
print(f"{'batch size':10s} {'M_activations':13s} {'total memory':12s} {'  diff':6s}")
for batch_size in [1, 2, 4, 8, 16, 32]:
    M_activations = compute_activations(model_cfg, batch_size, dtype, using_flash_attn)
    total_memory = M_model + M_optimizer + M_activations["total"] + M_gradient
    total_memory_GiB = total_memory / divisor
    diff = (empirical_peak_memory_usage[batch_size] / 1024) - total_memory_GiB
    print(
        f"{batch_size:10d} {M_activations['total'] / divisor:13.2f} {total_memory_GiB:12.2f} {diff:6.2f}"
    )


M_model: 8.70 GiB
M_optimizer: 11.60 GiB
M_gradient: 5.80 GiB
batch size M_activations total memory   diff
         1          3.67        29.77 -27.56
         2          7.34        33.44 -30.64
         4         14.67        40.78 -36.95
         8         29.35        55.45 -49.58
        16         58.69        84.80 -74.83
        32        117.39       143.50 -125.34
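Since M_activations grows linearly with batch size (the column above doubles with each doubling of the batch), the estimate can be inverted to find the largest batch that fits on a given GPU. A minimal sketch; the helper name and the example numbers, including the 1 GiB framework overhead (the text above quotes 0.5-2 GiB), are illustrative assumptions:

```python
def max_batch_size(
    gpu_bytes: int, static_bytes: int, act_bytes_per_sample: int, overhead_bytes: int = 0
) -> int:
    """largest batch size whose estimated memory fits in gpu_bytes.

    Assumes activation memory is linear in batch size:
    total = static_bytes + overhead_bytes + B * act_bytes_per_sample
    """
    free = gpu_bytes - static_bytes - overhead_bytes
    if free <= 0:
        return 0  # weights, gradients and optimizer state alone do not fit
    return free // act_bytes_per_sample


GiB = 1024**3
# illustrative numbers: 80 GiB GPU, 26 GiB static memory, 4 GiB of activations per sequence
print(max_batch_size(80 * GiB, 26 * GiB, 4 * GiB, overhead_bytes=1 * GiB))  # 13
```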


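## FLOPs

Here we estimate the FLOPs for a single forward pass over one sequence of `n_ctx` tokens (batch size 1), and derive the backward and total FLOPs from it.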

In [125]:
def compute_flops(cfg: ModelConfig):
    # we only count weight FLOPs (matrix multiplies);
    # FLOPs for all other layers (LayerNorm, Softmax, etc.) and bias vector additions are comparatively negligible
    # we count actual FLOPs, not MACs, hence the factor of 2 all over the place:
    # for any matrix multiply A (BxC) @ B (CxD) -> (BxD), the FLOPs are 2*B*C*D
    # all counts are per sequence of n_ctx tokens (batch size 1)

    out = OrderedDict()
    head_size = cfg.d_model // cfg.n_head

    # attention blocks
    # 1) the projection to key, query, values
    out["attention/kqv"] = 2 * cfg.n_ctx * (cfg.d_model * 3 * cfg.d_model)
    # 2) calculating the attention scores
    out["attention/scores"] = 2 * cfg.n_ctx * cfg.n_ctx * cfg.d_model
    # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    out["attention/reduce"] = 2 * cfg.n_head * (cfg.n_ctx * cfg.n_ctx * head_size)
    # 4) the final linear projection
    out["attention/proj"] = 2 * cfg.n_ctx * (cfg.d_model * cfg.d_model)
    out["attention"] = sum(
        out["attention/" + k] for k in ["kqv", "scores", "reduce", "proj"]
    )

    # MLP blocks
    out["mlp/ffw1"] = 2 * cfg.n_ctx * (cfg.d_model * cfg.d_mlp)
    out["mlp/ffw2"] = 2 * cfg.n_ctx * (cfg.d_mlp * cfg.d_model)
    out["mlp"] = out["mlp/ffw1"] + out["mlp/ffw2"]

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = cfg.n_layer * out["block"]
    out["out_embedding"] = 2 * cfg.n_ctx * (cfg.d_model * cfg.vocab_size)

    # forward,backward,total
    out["forward_total"] = out["transformer"] + out["out_embedding"]
    out["backward_total"] = (
        2 * out["forward_total"]
    )  # use common estimate of bwd = 2*fwd
    out["total"] = out["forward_total"] + out["backward_total"]

    return out


# per-component breakdown of the FLOPs for one sequence, relative to the forward pass
model_flops = compute_flops(model_cfg)
flops_total = model_flops["forward_total"]
print(f"{'name':20s} {'flops':14s} {'ratio (%)':10s}")
for k, v in model_flops.items():
    print(f"{k:20s} {v:14d} {v / flops_total * 100:10.4f}")

name                 flops          ratio (%) 
attention/kqv           15728640000     0.4485
attention/scores         3355443200     0.0957
attention/reduce         3355443200     0.0957
attention/proj           5242880000     0.1495
attention               27682406400     0.7894
mlp/ffw1                20971520000     0.5980
mlp/ffw2                20971520000     0.5980
mlp                     41943040000     1.1961
block                   69625446400     1.9855
transformer           3342021427200    95.3038
out_embedding          164682137600     4.6962
forward_total         3506703564800   100.0000
backward_total        7013407129600   200.0000
total                10520110694400   300.0000
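The `2*B*C*D` rule used in `compute_flops` comes from counting one multiply and one add per term of each inner product. A minimal sketch that counts the operations in a naive matrix multiply to confirm it (`naive_matmul_flops` is a hypothetical function written only for this check):

```python
def naive_matmul_flops(B: int, C: int, D: int) -> int:
    """multiply a (B x C) matrix of ones by a (C x D) matrix of ones, counting FLOPs."""
    a = [[1.0] * C for _ in range(B)]
    b = [[1.0] * D for _ in range(C)]
    out = [[0.0] * D for _ in range(B)]
    flops = 0
    for i in range(B):
        for j in range(D):
            acc = 0.0
            for k in range(C):
                acc += a[i][k] * b[k][j]  # one multiply + one add
                flops += 2
            out[i][j] = acc
    return flops


print(naive_matmul_flops(2, 3, 4), 2 * 2 * 3 * 4)  # 48 48
```

Strictly, a dot product of length C needs C multiplies and C-1 adds; the convention rounds this to 2C, which is negligible for the large C in a transformer.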


In [126]:
# an alternative estimate taken from the PaLM paper;
# this formula is often used to calculate MFU (model flops utilization)
def compute_palm_flops(cfg: ModelConfig):
    """estimate of the model flops following PaLM paper formula"""
    # non-embedding model parameters. note that we do not subtract the
    # embedding/token params because those are tied and get used in the last layer.
    model_params = get_params(cfg)
    N = model_params["total"] - model_params["embedding/position"]
    L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.d_model // cfg.n_head, cfg.n_ctx
    mf_per_token = 6 * N + 12 * L * H * Q * T
    mf = mf_per_token * cfg.n_ctx
    return mf


palm_flops = compute_palm_flops(model_cfg)
print(
    f"palm_flops: {palm_flops:d}, flops: {model_flops['total']:d}, ratio: {palm_flops / model_flops['total']:.4f}"
)

palm_flops: 10524377088000, flops: 10520110694400, ratio: 1.0004
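The close agreement is expected, since the PaLM formula is the same count written per token. With $T$ = `n_ctx`, $d$ = `d_model` $= HQ$, and $N$ the matrix-multiply weights (including the tied output embedding), the terms of `compute_flops` combine as follows (a sketch of the correspondence, not a derivation taken from the PaLM paper itself): the attention scores and the reduction over values each cost $2T^2d$ per layer, and the backward pass is counted as twice the forward pass.

$$
\begin{aligned}
\text{forward, per sequence} &\approx 2TN + L\,(2T^2 d + 2T^2 d) = 2TN + 4LT^2 d \\
\text{forward + backward} &\approx 3\,(2TN + 4LT^2 d) = 6TN + 12LT^2 d \\
\text{per token} &\approx 6N + 12LdT = 6N + 12LHQT
\end{aligned}
$$

The remaining difference of 4,266,393,600 FLOPs is exactly $6T \times 694{,}400$, i.e. the LayerNorm parameters and MLP biases that `N` in `compute_palm_flops` includes but `compute_flops` ignores.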


## GPU Flops Usage

Given our estimated FLOPs per iteration, we can calculate what fraction of the GPU's peak FLOPS we actually achieve, i.e. our model FLOPs utilization (MFU).

To calculate this we need a few bits of information:

- GPU speed (FLOPS)
- batch size (including gradient accumulation)
- time per iteration

For the GPU speed we refer to https://www.techpowerup.com/gpu-specs/, which gives theoretical performance; we use the BF16 or FP16 figure, whichever is supported and faster.

In [127]:
# theoretical peak dense throughput per GPU (FLOP/s)
GPU_FLOPS = {
    "H100": 756e12,  # 756 TFLOPS BF16 (dense estimate: half of the 1513 TFLOPS the spec sheet lists with sparsity)
    "A100": 312e12,  # 312 TFLOPS BF16
    "RTX4090": 83e12,  # 83 TFLOPS FP16
    "RTX2070": 15e12,  # 15 TFLOPS FP16
}

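# TODO Change these values to desired values
# NOTE: measured_time must be measured for model_cfg on the same hardware;
# otherwise the ratios below are meaningless (ratios above 100% are impossible)
batch_size = 20
grad_accum = 5
measured_time = 0.755  # in seconds per iteration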

# calculate flops achieved
total_batch_size = batch_size * grad_accum
measured_throughput = total_batch_size / measured_time
flops_achieved = model_flops["total"] * measured_throughput

# the fraction of each GPU's peak FLOPS that we are using (MFU):
print("Fraction of GPU FLOPS used")
print(f"{'GPU':14s} {'ratio (%)':10s}")
for k, v in GPU_FLOPS.items():
    print(f"{k:14s} {flops_achieved / v * 100:10.2f}")

Fraction of GPU FLOPS used
GPU            ratio (%) 
H100               184.31
A100               446.60
RTX4090           1678.79
RTX2070           9289.28
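Ratios above 100% are physically impossible, which signals that the measured iteration time and the model/GPU combination do not match (e.g. the time was measured for a smaller model or across several GPUs). A useful sanity check is to invert the MFU calculation and ask how long an iteration should take at a plausible MFU. A minimal sketch; `expected_iter_time` is a hypothetical helper and the 40% MFU in the example is an assumption, not a measurement:

```python
def expected_iter_time(
    flops_per_seq: float, seqs_per_iter: int, peak_flops: float, mfu: float
) -> float:
    """seconds per iteration implied by a given model FLOPs utilization."""
    return flops_per_seq * seqs_per_iter / (peak_flops * mfu)


# gpt2-xl: forward + backward FLOPs per 1024-token sequence, from the FLOPs table
t = expected_iter_time(10520110694400, seqs_per_iter=20 * 5, peak_flops=312e12, mfu=0.4)
print(f"A100 at 40% MFU: {t:.2f} s/iter")  # A100 at 40% MFU: 8.43 s/iter
```

An iteration time an order of magnitude below such an estimate is a strong hint that the inputs to the MFU calculation refer to different setups.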


## Total Training Compute

Here we use the values computed so far to estimate the total amount of compute `C` required to train the model as `C ~= 6*N*D`, where:

- `6` is a heuristic value (see [Dzmitry Bahdanau's post](https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4) for an explanation)
    - it stems from weight multiplications costing 6 FLOPs per token per weight across the forward and backward passes (2 forward, about 4 backward).
- `N` is the total number of model parameters
- `D` is the total size of the dataset (in tokens)

Note the equation changes to `C ~= 8*N*D` with activation checkpointing (i.e. recompute activations as needed to save memory when doing back-prop).

We also need to factor in the model FLOPs utilization (MFU) to account for the fact that we cannot use 100% of the GPUs' peak FLOPS due to memory bottlenecks, etc. (again, see Dzmitry's post for examples).
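The backward pass costs about twice the forward pass because each matrix multiply needs gradients with respect to both its input activations and its weights; activation checkpointing reruns the forward pass, adding another 2 FLOPs per weight per token, hence `8*N*D`. A minimal sketch of the resulting training-time estimate; `training_days` and its arguments are hypothetical names:

```python
def training_days(
    n_params: float,
    n_tokens: float,
    peak_flops: float,
    n_gpus: int,
    mfu: float,
    activation_checkpointing: bool = False,
) -> float:
    """wall-clock days to train, using C ~= 6*N*D (or 8*N*D with activation checkpointing)."""
    flops_per_param_token = 8 if activation_checkpointing else 6
    total_flops = flops_per_param_token * n_params * n_tokens
    achieved_flops_per_s = peak_flops * n_gpus * mfu
    return total_flops / achieved_flops_per_s / (3600 * 24)


# gpt2-xl on 8 x H100 (756 TFLOPS BF16) at 30% MFU, 300B tokens
print(f"{training_days(1.557304e9, 300e9, 756e12, 8, 0.3):.2f}")  # 17.88
print(f"{training_days(1.557304e9, 300e9, 756e12, 8, 0.3, activation_checkpointing=True):.2f}")  # 23.84
```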


In [128]:
# Finally let's check out the 6ND approximation as total cost of training in FLOPs
model_size = get_params(model_cfg)["total"]  # this is number of parameters, N

# TODO change these parameters
tokens_num = 300e9  # 300B tokens, this is dataset size in tokens, D
assumed_mfu = 0.3  # assumed model flops utilization (e.g. a measured single-GPU MFU of ~37%, minus some DDP overhead)
num_gpus = 8  # number of GPUs used in parallel

print("Time needed to train model on different GPUs")
print(f"{'GPU':10s} {'time (days)':16s}")
for gpu_name, gpu_flops in GPU_FLOPS.items():
    flops_throughput = gpu_flops * num_gpus * assumed_mfu
    flops_needed = 6 * model_size * tokens_num  # 6ND
    time_needed_s = flops_needed / flops_throughput  # in seconds
    print(f"{gpu_name:10s} {time_needed_s / 3600 / 24:10.2f}")


Time needed to train model on different GPUs
GPU        time (days)     
H100            17.88
A100            43.33
RTX4090        162.87
RTX2070        901.22
