# Transformer Sizing

This notebook runs several analyses of a GPT-2 style Transformer: parameter count, number of FLOPs, peak memory footprint, checkpoint size, etc.

**Reference**
- This notebook is based directly on Karpathy's [nanoGPT/transformer_sizing.ipynb](https://github.com/karpathy/nanoGPT/blob/master/transformer_sizing.ipynb)

In [1]:
from dataclasses import dataclass
from collections import OrderedDict


@dataclass
class ModelConfig:
    n_ctx: int = 1024
    n_layer: int = 12
    n_head: int = 12
    d_model: int = 768
    d_mlp: int = 4 * 768
    vocab_size: int = 50257
    ln_bias: bool = False
    mlp_bias: bool = False
    share_embd_params: bool = True


MODEL_CONFIG_ARGS = {
    # 14M params
    "gpt2-tiny": ModelConfig(
        n_ctx=128,
        n_layer=2,
        n_head=4,
        d_model=256,
        d_mlp=4 * 256,
        vocab_size=50257,
        ln_bias=True,
        mlp_bias=True,
        share_embd_params=True,
    ),
    # 124M params
    "gpt2": ModelConfig(
        n_ctx=1024,
        n_layer=12,
        n_head=12,
        d_model=768,
        d_mlp=4 * 768,
        vocab_size=50257,
        ln_bias=True,
        mlp_bias=True,
        share_embd_params=True,
    ),
    # 350M params
    "gpt2-medium": ModelConfig(
        n_ctx=1024,
        n_layer=24,
        n_head=16,
        d_model=1024,
        d_mlp=4 * 1024,
        vocab_size=50257,
        ln_bias=True,
        mlp_bias=True,
        share_embd_params=True,
    ),
    # 774M params
    "gpt2-large": ModelConfig(
        n_ctx=1024,
        n_layer=36,
        n_head=20,
        d_model=1280,
        d_mlp=4 * 1280,
        vocab_size=50257,
        ln_bias=True,
        mlp_bias=True,
        share_embd_params=True,
    ),
    # 1558M params
    "gpt2-xl": ModelConfig(
        n_ctx=1024,
        n_layer=48,
        n_head=25,
        d_model=1600,
        d_mlp=4 * 1600,
        vocab_size=50257,
        ln_bias=True,
        mlp_bias=True,
        share_embd_params=True,
    ),
}


def load_config(name: str) -> ModelConfig:
    assert name in MODEL_CONFIG_ARGS
    return MODEL_CONFIG_ARGS[name]


# TODO Change this as desired
model_cfg = load_config("gpt2-tiny")

## Parameter Count

In [2]:
def get_params(cfg: ModelConfig):
    """estimates the number of parameters in the model"""
    out = OrderedDict()

    # token and position embeddings
    out["embedding/position"] = cfg.n_ctx * cfg.d_model
    out["embedding/token"] = cfg.vocab_size * cfg.d_model
    out["embedding"] = out["embedding/position"] + out["embedding/token"]

    # attention blocks
    out["attention/ln"] = cfg.d_model + int(cfg.ln_bias) * cfg.d_model
    out["attention/kqv"] = cfg.d_model * 3 * cfg.d_model
    out["attention/proj"] = cfg.d_model**2
    out["attention"] = (
        out["attention/ln"] + out["attention/kqv"] + out["attention/proj"]
    )

    # MLP blocks
    out["mlp/ln"] = cfg.d_model + int(cfg.ln_bias) * cfg.d_model
    out["mlp/ffw"] = cfg.d_model * cfg.d_mlp + int(cfg.mlp_bias) * cfg.d_mlp
    out["mlp/proj"] = cfg.d_mlp * cfg.d_model + int(cfg.mlp_bias) * cfg.d_model
    out["mlp"] = out["mlp/ln"] + out["mlp/ffw"] + out["mlp/proj"]

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = cfg.n_layer * out["block"]
    out["ln_f"] = cfg.d_model + int(cfg.ln_bias) * cfg.d_model  # final layernorm
    if cfg.share_embd_params:
        # 0 because of parameter sharing. This layer uses the weights from the embedding layer
        out["out_embedding"] = 0
    else:
        out["out_embedding"] = cfg.d_model * cfg.vocab_size

    # total
    out["total"] = (
        out["embedding"] + out["transformer"] + out["ln_f"] + out["out_embedding"]
    )

    return out


# compare our param count to the reference value for "gpt2" (124M params)
# NOTE: the reference only applies when model_cfg is "gpt2"; any other config reports match: False
# TODO update the PyTorch value to include bias
model_params = get_params(model_cfg)
params_total = model_params["total"]
expected_params = 124402944
print(
    f"we see: {params_total}, expected: {expected_params}, match: {params_total == expected_params}"
)
# create a header
print(f"{'name':20s} {'params':10s} {'ratio (%)':10s}")
for k, v in model_params.items():
    print(f"{k:20s} {v:10d} {v / params_total * 100:10.4f}")

we see: 14476544, expected: 124402944, match: False
name                 params     ratio (%) 
embedding/position        32768     0.2264
embedding/token        12865792    88.8734
embedding              12898560    89.0997
attention/ln                512     0.0035
attention/kqv            196608     1.3581
attention/proj            65536     0.4527
attention                262656     1.8144
mlp/ln                      512     0.0035
mlp/ffw                  263168     1.8179
mlp/proj                 262400     1.8126
mlp                      526080     3.6340
block                    788736     5.4484
transformer             1577472    10.8967
ln_f                        512     0.0035
out_embedding                 0     0.0000
total                  14476544   100.0000


## Parameter/Checkpoint Size

In [3]:
# we can now calculate the size of each checkpoint
# params are stored in fp32 (i.e. 4 bytes), and the AdamW optimizer has 2 additional buffers per param for statistics
params_bytes = params_total * 4
params_and_buffers_bytes = params_bytes + 2 * params_bytes
print(f"est checkpoint size: {params_and_buffers_bytes / 1e9:.2f} GB")
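# TODO update this with actual measured bytes
# NOTE: the value below is a placeholder, not a gpt2-tiny measurement, so the fluff ratio printed here is not meaningful
measured_bytes = 1542470366  # from wc -c ckpt.pt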
print(f"measured with wc -c ckpt.pt: {measured_bytes}")
print(f"fluff ratio: {measured_bytes / params_and_buffers_bytes * 100:.2f}%")

est checkpoint size: 0.17 GB
measured with wc -c ckpt.pt: 1542470366
fluff ratio: 887.91%
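
The checkpoint estimate above is just "weights plus two AdamW buffers, all in fp32". A minimal sketch of that arithmetic as a reusable helper follows; the function name `estimate_checkpoint_bytes` and its keyword arguments are hypothetical, and it assumes the optimizer buffers are stored at the same precision as the weights.

```python
def estimate_checkpoint_bytes(n_params, bytes_per_param=4, optimizer_buffers=2):
    """Weights plus optimizer buffers, all stored at the same precision."""
    return n_params * bytes_per_param * (1 + optimizer_buffers)


n_params = 14476544  # gpt2-tiny total from the parameter table
est = estimate_checkpoint_bytes(n_params)
print(f"{est} bytes = {est / 1e9:.2f} GB = {est / 1024**3:.2f} GiB")
# 173718528 bytes = 0.17 GB = 0.16 GiB
```

Printing both GB and GiB is deliberate: checkpoint sizes on disk are usually quoted in GB, while the GPU memory figures in the next section are in GiB.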


## GPU memory usage

We can estimate the ratio of our GPU memory that will be taken up just by the weights and the buffers inside the AdamW optimizer.

In [4]:
# NVIDIA reports memory in GiB (labelled as GB, but actually GiB)
# 1 GiB = 1024 MiB = 1024**2 KiB = 1024**3 Bytes
GPU_MEMORY = {
    "H100": 80 * 1024**3,  # 80 GiB
    "A100": 40 * 1024**3,  # 40 GiB
    "RTX4090": 24 * 1024**3,  # 24 GiB
    "RTX2070": 8 * 1024**3,  # 8 GiB
}

print("GPU memory ratio taken up just for parameters (incl. optimizer)")
print(f"{'GPU':12s} {'ratio (%)':8s}")
for k, v in GPU_MEMORY.items():
    print(f"{k:12s} {params_and_buffers_bytes / v * 100:8.2f}")


GPU memory ratio taken up just for parameters (incl. optimizer)
GPU          ratio (%)
H100             0.20
A100             0.40
RTX4090          0.67
RTX2070          2.02


## FLOPS

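Here we estimate the FLOPs for a single forward pass over one sequence of `n_ctx` tokens, then approximate the backward pass as twice the forward cost.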

In [5]:
def compute_flops(cfg: ModelConfig):
    # we only count weight FLOPs;
    # FLOPs for all other layers (LayerNorm, Softmax, etc.) and bias vector addition are effectively irrelevant
    # we count actual FLOPs, not MACs, hence the 2* all over the place
    # basically, for any matrix multiply X (m x k) @ Y (k x n) -> (m x n), the FLOPs are 2*m*k*n

    out = OrderedDict()
    head_size = cfg.d_model // cfg.n_head

    # attention blocks
    # 1) the projection to key, query, values
    out["attention/kqv"] = 2 * cfg.n_ctx * (cfg.d_model * 3 * cfg.d_model)
    # 2) calculating the attention scores
    out["attention/scores"] = 2 * cfg.n_ctx * cfg.n_ctx * cfg.d_model
    # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
    out["attention/reduce"] = 2 * cfg.n_head * (cfg.n_ctx * cfg.n_ctx * head_size)
    # 4) the final linear projection
    out["attention/proj"] = 2 * cfg.n_ctx * (cfg.d_model * cfg.d_model)
    out["attention"] = sum(
        out["attention/" + k] for k in ["kqv", "scores", "reduce", "proj"]
    )

    # MLP blocks
    out["mlp/ffw1"] = 2 * cfg.n_ctx * (cfg.d_model * cfg.d_mlp)
    out["mlp/ffw2"] = 2 * cfg.n_ctx * (cfg.d_mlp * cfg.d_model)
    out["mlp"] = out["mlp/ffw1"] + out["mlp/ffw2"]

    # the transformer and the rest of it
    out["block"] = out["attention"] + out["mlp"]
    out["transformer"] = cfg.n_layer * out["block"]
    out["out_embedding"] = 2 * cfg.n_ctx * (cfg.d_model * cfg.vocab_size)

    # forward,backward,total
    out["forward_total"] = out["transformer"] + out["out_embedding"]
    out["backward_total"] = (
        2 * out["forward_total"]
    )  # use common estimate of bwd = 2*fwd
    out["total"] = out["forward_total"] + out["backward_total"]

    return out


# print the FLOPs breakdown; ratios are relative to the forward pass total
model_flops = compute_flops(model_cfg)
flops_total = model_flops["forward_total"]
print(f"{'name':20s} {'flops':14s} {'ratio (%)':10s}")
for k, v in model_flops.items():
    print(f"{k:20s} {v:14d} {v / flops_total * 100:10.4f}")

name                 flops          ratio (%) 
attention/kqv              50331648     1.3494
attention/scores            8388608     0.2249
attention/reduce            8388608     0.2249
attention/proj             16777216     0.4498
attention                  83886080     2.2490
mlp/ffw1                   67108864     1.7992
mlp/ffw2                   67108864     1.7992
mlp                       134217728     3.5985
block                     218103808     5.8475
transformer               436207616    11.6950
out_embedding            3293642752    88.3050
forward_total            3729850368   100.0000
backward_total           7459700736   200.0000
total                   11189551104   300.0000
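
The `2*m*k*n` rule used throughout `compute_flops` can be checked by brute force. The sketch below counts one multiply and one accumulate per inner-loop step of a naive matrix product; `count_matmul_flops` is a hypothetical helper written for this check, and the dimensions are scaled down so the pure-Python loop stays fast.

```python
def count_matmul_flops(m, k, n):
    """Count the multiplies and adds in a naive (m x k) @ (k x n) product."""
    flops = 0
    for _ in range(m):          # each output row
        for _ in range(n):      # each output column
            for _ in range(k):  # one multiply and one accumulate per inner step
                flops += 2
    return flops


# kqv projection of gpt2-tiny, scaled down 16x per dimension to keep the loop fast
n_ctx, d_model = 128 // 16, 256 // 16
assert count_matmul_flops(n_ctx, d_model, 3 * d_model) == 2 * n_ctx * d_model * 3 * d_model

# at full size the same rule gives the attention/kqv row of the table
print(2 * 128 * 256 * 3 * 256)  # 50331648
```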


In [6]:
# now here is an estimate copy pasted from the PaLM paper
# this formula is often used to calculate MFU (model flops utilization)
def compute_palm_flops(cfg: ModelConfig):
    """estimate of the model flops following PaLM paper formula"""
    # non-embedding model parameters. note that we do not subtract the
    # embedding/token params because those are tied and get used in the last layer.
    model_params = get_params(cfg)
    N = model_params["total"] - model_params["embedding/position"]
    L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.d_model // cfg.n_head, cfg.n_ctx
    mf_per_token = 6 * N + 12 * L * H * Q * T
    mf = mf_per_token * cfg.n_ctx
    return mf


palm_flops = compute_palm_flops(model_cfg)
print(
    f"palm_flops: {palm_flops:d}, flops: {model_flops['total']:d}, ratio: {palm_flops / model_flops['total']:.4f}"
)

palm_flops: 11193483264, flops: 11189551104, ratio: 1.0004
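
The ratio is close to, but not exactly, 1. A short decomposition suggests why: the PaLM formula charges 6 FLOPs per token for every parameter in `N`, including LayerNorm gains and biases and the MLP biases, which `compute_flops` ignores. The sketch below re-derives both numbers for gpt2-tiny from values printed above; the variable names are illustrative only.

```python
# gpt2-tiny values taken from the config and tables above
n_layer, n_head, n_ctx = 2, 4, 128
d_model, d_mlp = 256, 1024
head_size = d_model // n_head
params_total = 14476544
params_position = 32768
explicit_flops = 11189551104  # "total" row of the FLOPs table

N = params_total - params_position
weight_term = 6 * N * n_ctx                                   # 6N per token
attn_term = 12 * n_layer * n_head * head_size * n_ctx * n_ctx  # 12*L*H*Q*T per token
palm_flops = weight_term + attn_term

# parameters that are not matmul weights:
# per block, two LayerNorms (gain + bias) and the two MLP bias vectors; plus the final LayerNorm
non_matmul_params = n_layer * (2 * 2 * d_model + d_mlp + d_model) + 2 * d_model
gap = palm_flops - explicit_flops

print(palm_flops)                          # 11193483264
print(gap, 6 * non_matmul_params * n_ctx)  # 3932160 3932160
```

The attention term matches the explicit count exactly (three times the per-layer `scores` and `reduce` FLOPs), so the whole difference comes from the non-matmul parameters.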


## GPU Flops Usage

Given our estimated FLOPs, we can calculate how much of the GPU's FLOPS capacity is being used, i.e. our model FLOPs utilization (MFU).

To calculate this we need a few bits of information:

- GPU speed (FLOPS)
- batch size (including gradient accumulation)
- time per iteration

For the GPU speed we refer to https://www.techpowerup.com/gpu-specs/, which gives theoretical performance. We look specifically at BF16 and FP16 performance, whichever is supported and faster.

In [7]:
# theoretical peak throughput per GPU, in FLOPS
GPU_FLOPS = {
    "H100": 756e12,  # 756 TFLOPS BF16 (this is a guess, spec sheet shows 1513 for sparse tensors)
    "A100": 312e12,  # 312 TFLOPS BF16
    "RTX4090": 83e12,  # 83 TFLOPS FP16
    "RTX2070": 15e12,  # 15 TFLOPS FP16
}

# TODO Change these values to desired values
batch_size = 20
grad_accum = 5
measured_time = 0.755  # in seconds per iteration

# calculate flops achieved
total_batch_size = batch_size * grad_accum
measured_throughput = total_batch_size / measured_time
flops_achieved = model_flops["total"] * measured_throughput

# the fraction of each GPU's peak FLOPS that we are using (MFU):
print("Fraction of GPU FLOPS used")
print(f"{'GPU':14s} {'ratio (%)':10s}")
for k, v in GPU_FLOPS.items():
    print(f"{k:14s} {flops_achieved / v * 100:10.2f}")

Fraction of GPU FLOPS used
GPU            ratio (%) 
H100                 0.20
A100                 0.48
RTX4090              1.79
RTX2070              9.88
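
The MFU calculation above can be wrapped in a small helper so it is easy to re-run with new timings. This is a minimal sketch: the function name `model_flops_utilization` and its parameter names are hypothetical, and the inputs are the gpt2-tiny numbers used in this notebook.

```python
def model_flops_utilization(flops_per_sequence, sequences_per_iter, seconds_per_iter, peak_flops):
    """Fraction of the GPU's theoretical peak that the training loop achieves."""
    achieved = flops_per_sequence * sequences_per_iter / seconds_per_iter
    return achieved / peak_flops


mfu = model_flops_utilization(
    flops_per_sequence=11189551104,  # fwd+bwd FLOPs for one gpt2-tiny sequence (table above)
    sequences_per_iter=20 * 5,       # batch_size * grad_accum
    seconds_per_iter=0.755,
    peak_flops=312e12,               # A100 BF16
)
print(f"A100 MFU: {mfu * 100:.2f}%")  # A100 MFU: 0.48%
```

Note that `flops_per_sequence` must cover both the forward and backward pass, since the measured iteration time includes both.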


## Total Training Compute

Here we use the values computed so far to estimate the total amount of compute `C` required to train the model, using the approximation `C ~= 6*N*D`, where:

- `6` is a heuristic value (see [Dzmitry Bahdanau's post](https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4) for an explanation) 
    - it basically stems from weight multiplications requiring 6 FLOPS per token per weight when combining both forward and backward passes.
- `N` is the total number of model parameters
- `D` is total size of dataset (in tokens)

Note the equation changes to `C ~= 8*N*D` with activation checkpointing (i.e. recompute activations as needed to save memory when doing back-prop).

We also need to factor in the model FLOPs utilization (MFU) to correct for the fact that we cannot use 100% of the GPU's FLOPS due to memory bottlenecks, etc. (again, see Dzmitry's blog for examples).


In [8]:
# Finally let's check out the 6ND approximation as total cost of training in FLOPs
model_size = get_params(model_cfg)["total"]  # this is number of parameters, N

# TODO change these parameters
tokens_num = 300e9  # 300B tokens, this is dataset size in tokens, D
assumed_mfu = 0.3  # assumed model flops utilization; in practice, take the MFU measured above and allow for some DDP overhead
num_gpus = 8  # number of GPUs used in parallel

print("Time needed to train model on different GPUS")
print(f"{'GPU':10s} {'time (days)':16s}")
for gpu_name, gpu_flops in GPU_FLOPS.items():
    flops_throughput = gpu_flops * num_gpus * assumed_mfu
    flops_needed = 6 * model_size * tokens_num  # 6ND
    time_needed_s = flops_needed / flops_throughput  # in seconds
    print(f"{gpu_name:10s} {time_needed_s / 3600 / 24:10.2f}")


Time needed to train model on different GPUS
GPU        time (days)     
H100             0.17
A100             0.40
RTX4090          1.51
RTX2070          8.38
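
The activation checkpointing note above only changes the constant in front of `N*D`. A minimal sketch that makes the constant a parameter follows; `training_days` and `flops_per_param_token` are hypothetical names, and the inputs repeat the gpt2-tiny and A100 values assumed in this notebook.

```python
def training_days(n_params, n_tokens, peak_flops, num_gpus, mfu, flops_per_param_token=6):
    """Days to train under C ~= k*N*D, with k = 6 (or 8 with activation checkpointing)."""
    flops_needed = flops_per_param_token * n_params * n_tokens
    flops_per_second = peak_flops * num_gpus * mfu
    return flops_needed / flops_per_second / (3600 * 24)


common = dict(n_params=14476544, n_tokens=300e9, peak_flops=312e12, num_gpus=8, mfu=0.3)
plain = training_days(**common)
checkpointed = training_days(**common, flops_per_param_token=8)
print(f"A100 x8: {plain:.2f} days, {checkpointed:.2f} days with activation checkpointing")
```

Under this approximation, activation checkpointing always costs a factor of 8/6 in training time, independent of model size or hardware.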
