# Transformer Sizing

This notebook runs a set of sizing analyses for a GPT-2 Transformer: number of FLOPS, parameter count, peak memory footprint, checkpoint size, and so on.

**Reference**
- This notebook is based directly on Karpathy's [nanoGPT/transformer_sizing.ipynb](https://github.com/karpathy/nanoGPT/blob/master/transformer_sizing.ipynb)

In [None]:
from gollem.models.gpt2.config import get_gpt2_model_config
from gollem.models.llama3.config import get_llama3_model_config
from gollem.gpu_stats import get_gpu_flops_for_all_gpus
from gollem.gpu_stats import get_gpu_flops
from gollem.gpu_stats import GPU_INFO

# NOTE: change things here for your model
# ----------------------------------------
# Change this to the model you want to analyze
model_name = "gpt2"
# observed checkpoint size
measured_checkpoint_size_bytes: int | None = None
# measured throughput
measured_tokens_per_second = 27000
# what GPU we used?
gpu_name = "H100"
# what precision we used? (float32, float16, bfloat16, int8, etc)
dtype = "float32"
# ----------------------------------------

if model_name.startswith("gpt"):
    model_cfg = get_gpt2_model_config(model_name)
elif model_name.startswith("llama"):
    model_cfg = get_llama3_model_config(model_name)
else:
    raise ValueError(f"Model name {model_name} not supported")

## Parameter Count

Here we look at the total number of parameters and the breakdown by component for the model.

In [None]:
model_params = model_cfg.get_params()
print(f"{'name':20s} {'params':10s} {'ratio (%)':10s}")
for k, v in model_params.per_component.items():
    print(f"{k:20s} {v:10d} {v / model_params.total * 100:10.4f}")
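
As a cross-check on the breakdown above, the parameter count of a GPT-2 style model can be derived by hand from its hyperparameters. The sketch below is an assumption about the layout (tied token embedding and output head, fused QKV projection, 4x MLP expansion), not necessarily what `get_params()` does internally; the function name `gpt2_param_count` and its arguments are hypothetical.

```python
def gpt2_param_count(n_layer, d_model, n_ctx, vocab_size, bias=True):
    """Count parameters of a GPT-2 style model with a tied output head (assumed layout)."""
    b = 1 if bias else 0
    # LayerNorm has a weight vector and, optionally, a bias vector
    layer_norm = d_model * (1 + b)
    # attention: LayerNorm + fused QKV projection + output projection
    attention = (
        layer_norm
        + d_model * 3 * d_model + b * 3 * d_model
        + d_model * d_model + b * d_model
    )
    # MLP: LayerNorm + up projection to 4*d_model + down projection
    d_ff = 4 * d_model
    mlp = (
        layer_norm
        + d_model * d_ff + b * d_ff
        + d_ff * d_model + b * d_model
    )
    counts = {
        "embedding/token": vocab_size * d_model,
        "embedding/position": n_ctx * d_model,
        "attention": n_layer * attention,
        "mlp": n_layer * mlp,
        "ln_f": layer_norm,
    }
    counts["total"] = sum(counts.values())
    return counts


# GPT-2 small hyperparameters
counts = gpt2_param_count(n_layer=12, d_model=768, n_ctx=1024, vocab_size=50257)
for name, value in counts.items():
    print(f"{name:20s} {value:12d}")
```

With biases enabled this gives the familiar ~124M total for GPT-2 small; with `bias=False` it should match the bias-free count used in nanoGPT.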

## Parameter/Checkpoint Size

We can now calculate the size of each checkpoint.

Note that params are stored in fp32 (4 bytes each) and the AdamW optimizer keeps 2 additional buffers per param for its statistics, so a checkpoint is roughly 3x the raw parameter size.

We can also compare this to the observed checkpoint size, if we have it.


In [None]:
params_bytes = model_params.total * 4
params_and_buffers_bytes = params_bytes + 2 * params_bytes
print(f"est checkpoint size: {params_and_buffers_bytes / 1e9:.2f} GB")

if measured_checkpoint_size_bytes is not None:
    print(f"measured with wc -c ckpt.pt: {measured_checkpoint_size_bytes}")
    print(
        f"fluff ratio: {measured_checkpoint_size_bytes / params_and_buffers_bytes * 100:.2f}%"
    )
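
The same arithmetic can be written as a small standalone helper. This is only a sketch of the estimate described above; the name `estimate_checkpoint_bytes` and its defaults are hypothetical, and the example uses the ~124M parameter count of GPT-2 small.

```python
def estimate_checkpoint_bytes(n_params, bytes_per_param=4, optimizer_buffers=2):
    """Estimate checkpoint size: weights plus optimizer buffers, all at the same precision."""
    weights = n_params * bytes_per_param
    buffers = optimizer_buffers * n_params * bytes_per_param
    return weights + buffers


# GPT-2 small (~124M params), fp32 weights, AdamW with 2 buffers per param
size = estimate_checkpoint_bytes(124_439_808)
print(f"est checkpoint size: {size / 1e9:.2f} GB")
```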

## GPU memory usage

We can estimate the ratio of our GPU memory that will be taken up just by the weights and the buffers inside the AdamW optimizer for different GPU sizes.

In [None]:
# Nvidia reports memory in GiB (denoted as GB but is actually GiB)
print("GPU memory ratio taken up just for parameters (incl. optimizer)")
print(f"{'GPU':12s} {'ratio (%)':8s}")
for k, v in GPU_INFO.items():
    print(f"{k:12s} {params_and_buffers_bytes / v.memory_bytes * 100:8.2f}")
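
The GB versus GiB distinction in the comment above changes the result by a few percent. The following sketch, using a hypothetical 80 GiB card and a hypothetical helper `memory_fraction`, shows the two readings side by side.

```python
GIB = 1024**3  # bytes in one GiB
GB = 10**9     # bytes in one GB


def memory_fraction(required_bytes, gpu_memory, unit=GIB):
    """Fraction of GPU memory taken by `required_bytes`, for memory given in `unit`."""
    return required_bytes / (gpu_memory * unit)


required = 1_493_277_696  # ~124M params * 4 bytes * 3 (weights + 2 AdamW buffers)
print(f"as GiB: {memory_fraction(required, 80, GIB) * 100:.2f}%")
print(f"as GB:  {memory_fraction(required, 80, GB) * 100:.2f}%")
```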


## GPU forward-backward memory usage

We can estimate the total memory usage of the model over the course of a forward-backward pass.

This is made up of:

1. M_model - the model parameters
2. M_optimizer - the optimizer buffers
3. M_gradient - the gradient of the model
4. M_activations - the activations of the model

The activations scale with batch size and are typically the largest component of memory usage, since the others are fixed given model size and context length.

Calculating the activation memory usage is quite tricky since it is affected by low level optimizations that can be hard to estimate exactly. It is also affected by things like flash attention, dropout, activation checkpointing, etc.

Another source of memory usage is CUDA and PyTorch overhead, which is typically 0.5-2 GiB depending on how the system is set up. We don't include it in the calculations, so factor it in when comparing calculated and empirical memory usage.

NOTE: these are estimates, so really should be taken as ballpark figures rather than exact values.

In [None]:
# when using bf16 or fp16 we use mixed precision, so we have to store
# both 16-bit and fp32 versions of the parameters (2 + 4 = 6 bytes)
# when using fp32 we store only the full precision fp32 parameters
bytes_per_param = 6 if dtype in ("bfloat16", "float16") else 4
M_model = model_params.total * bytes_per_param

# AdamW optimizer has 2 buffers per parameter
# values are stored in fp32
M_optimizer = 2 * model_params.total * 4

# we store one gradient value per parameter
# gradients are stored in fp32
M_gradient = model_params.total * 4

divisor = 1024**3

print(f"M_model: {M_model / divisor:.2f} GiB")
print(f"M_optimizer: {M_optimizer / divisor:.2f} GiB")
print(f"M_gradient: {M_gradient / divisor:.2f} GiB")

# M_activations is more complex
print(f"{'batch size':10s} {'M_activations':13s} {'total memory':12s}")
for batch_size in [1, 2, 4, 8, 16, 32]:
    M_activations = model_cfg.compute_activations(batch_size, dtype)
    total_memory = M_model + M_optimizer + M_activations.total + M_gradient
    total_memory_GiB = total_memory / divisor
    print(
        f"{batch_size:10d} {M_activations.total / divisor:13.2f} {total_memory_GiB:12.2f}"
    )
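
For a rough feel of how activations scale, one commonly cited approximation (from Korthikanti et al., "Reducing Activation Recomputation in Large Transformer Models") puts the activation memory of one Transformer layer at about `s*b*h*(34 + 5*a*s/h)` bytes for 16-bit training, where `s` is sequence length, `b` batch size, `h` hidden size and `a` the number of heads. The sketch below applies that formula; it is an assumption for illustration and is not necessarily what `compute_activations` implements, and the function name and the `drop_attention_matrices` flag are hypothetical.

```python
def approx_activation_bytes(
    batch_size, seq_len, d_model, n_head, n_layer, drop_attention_matrices=False
):
    """Approximate activation memory (bytes) for 16-bit training, per the formula above."""
    # term linear in sequence length: 34 * s * b * h
    linear_term = 34 * seq_len * batch_size * d_model
    # term quadratic in sequence length: 5 * a * s^2 * b (attention scores, softmax, dropout)
    attention_term = 5 * n_head * seq_len * seq_len * batch_size
    if drop_attention_matrices:
        # roughly what you get when the s x s matrices are never stored
        attention_term = 0
    return n_layer * (linear_term + attention_term)


# GPT-2 small: 12 layers, 12 heads, d_model 768, context length 1024
for bs in [1, 2, 4, 8]:
    total = approx_activation_bytes(bs, 1024, 768, 12, 12)
    print(f"{bs:10d} {total / 1024**3:13.2f} GiB")
```

The estimate is linear in batch size, which is why activations dominate once the batch grows beyond a handful of sequences.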


## FLOPS

Now let's look at the FLOPS used by the model and compare to the estimate from the PaLM paper.

In [None]:
model_flops = model_cfg.compute_flops()
flops_total = model_flops.forward_total
print(f"{'name':20s} {'flops':14s} {'ratio (%)':10s}")
for k, v in model_flops.per_component.items():
    print(f"{k:20s} {v:14d} {v / flops_total * 100:10.4f}")
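
For reference, the PaLM estimate of training FLOPs per token is `6*N + 12*L*H*Q*T`, where `N` is the parameter count, `L` the number of layers, `H` the number of heads, `Q` the head dimension and `T` the sequence length. The sketch below follows the way Karpathy's notebook applies it, with `N` taken as the total parameter count minus the position embedding; the function name is hypothetical and the hyperparameters are those of GPT-2 small without biases.

```python
def palm_flops_per_token(n_params, n_layer, n_head, head_dim, seq_len):
    """PaLM-style estimate of forward + backward FLOPs per token."""
    # 6N covers the weight matmuls, the second term covers attention over the context
    return 6 * n_params + 12 * n_layer * n_head * head_dim * seq_len


# N = 124,337,664 total params - 786,432 position embedding params
n_params = 124_337_664 - 786_432
per_token = palm_flops_per_token(n_params, n_layer=12, n_head=12, head_dim=64, seq_len=1024)
per_sequence = per_token * 1024
print(f"PaLM flops per token:    {per_token}")
print(f"PaLM flops per sequence: {per_sequence}")
```

Comparing `per_sequence` against the total from `compute_flops()` gives a quick sanity check on the detailed per-component count.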

## MFU

Given our estimated FLOPS we can calculate how much of our GPU's FLOP capacity is being used, that is our model flop utilization (MFU).

To calculate this we need 3 bits of information:

- GPU speed (FLOPS)
- observed throughput (tokens per second)
- model flops (FLOPs) per iteration

For the GPU speed we refer to https://www.techpowerup.com/gpu-specs/, which gives theoretical performance. We look specifically at BF16 and FP16 performance, whichever is supported and faster.

In [None]:
# calculate flops achieved
# Get model flops per token (note we need to divide by context length
# to get the flops per token)
model_flops_per_token = model_cfg.compute_flops().total / model_cfg.n_ctx
gpu_flops = get_gpu_flops(gpu_name, dtype)

print(f"{'GPU':10s} {'MFU (%)':8s}")
mfu = (model_flops_per_token * measured_tokens_per_second) / gpu_flops
print(f"{gpu_name:10s} {mfu * 100:8.2f}")
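
The MFU calculation itself is a single ratio: achieved FLOPS divided by peak FLOPS. The sketch below spells it out with explicit numbers. The function name is hypothetical, the per-token FLOPs value is the PaLM-style estimate for GPT-2 small, and the peak of `312e12` is an illustrative figure (roughly the published dense BF16 peak of an A100); substitute the value for your own GPU and precision.

```python
def model_flops_utilization(flops_per_token, tokens_per_second, peak_flops):
    """MFU: fraction of the GPU's peak FLOPS actually spent on model FLOPs."""
    achieved_flops = flops_per_token * tokens_per_second
    return achieved_flops / peak_flops


mfu = model_flops_utilization(
    flops_per_token=854_553_600,  # forward + backward FLOPs per token (estimate)
    tokens_per_second=27_000,     # measured throughput
    peak_flops=312e12,            # illustrative peak, see note above
)
print(f"MFU: {mfu * 100:.2f}%")
```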


## Total Training Compute

Here we use the values computed so far to estimate the total amount of compute `C` required to train the model, using `C ~= 6*N*D`, where:

- `6` is a heuristic value (see [Dzmitry Bahdanau's post](https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4) for an explanation) 
    - it basically stems from weight multiplications requiring 6 FLOPS per token per weight when combining both forward and backward passes.
- `N` is the total number of model parameters
- `D` is total size of dataset (in tokens)

Note the equation changes to `C ~= 8*N*D` with activation checkpointing (i.e. recompute activations as needed to save memory when doing back-prop).

We also need to factor in the model flops utilization (MFU) to correct for the fact that we cannot use 100% of the GPU's FLOPS due to memory bottlenecks, etc. (again see Dzmitry's blog for examples).


In [None]:
# Finally let's check out the 6ND approximation as total cost of training in FLOPs
model_size = model_cfg.get_params().total  # this is number of parameters, N

# TODO change these parameters
tokens_num = 300e9  # 300B tokens, this is dataset size in tokens, D
assumed_mfu = 0.3  # assume this model flops utilization (take the current 37% from above and add some DDP overhead)
num_gpus = 8  # number of GPUs used in parallel

flops_needed = 6 * model_size * tokens_num  # 6ND, independent of the GPU

print("Time needed to train model on different GPUs")
print(f"{'GPU':10s} {'time (days)':16s}")
# use `name` rather than `gpu_name` so the loop does not overwrite the
# `gpu_name` setting from the first cell
for name, flops_per_gpu in get_gpu_flops_for_all_gpus(dtype).items():
    flops_throughput = flops_per_gpu * num_gpus * assumed_mfu
    time_needed_s = flops_needed / flops_throughput  # in seconds
    print(f"{name:10s} {time_needed_s / 3600 / 24:10.2f}")
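
To see the effect of activation checkpointing on wall-clock time, the same estimate can be parameterised by the FLOPs-per-parameter-per-token factor (6 without checkpointing, 8 with). This is a sketch only: the function name is hypothetical and the peak of `312e12` FLOPS per GPU is an illustrative value, not one taken from `GPU_INFO`.

```python
def training_days(n_params, n_tokens, peak_flops, num_gpus, mfu, flops_factor=6):
    """Estimated training time in days from C ~= flops_factor * N * D."""
    flops_needed = flops_factor * n_params * n_tokens
    flops_throughput = peak_flops * num_gpus * mfu  # usable FLOPS across all GPUs
    seconds = flops_needed / flops_throughput
    return seconds / 3600 / 24


common = dict(n_params=124_439_808, n_tokens=300e9, peak_flops=312e12, num_gpus=8, mfu=0.3)
print(f"6ND (no checkpointing):   {training_days(**common):.2f} days")
print(f"8ND (with checkpointing): {training_days(**common, flops_factor=8):.2f} days")
```

The 8ND variant always takes 8/6 of the 6ND time, so checkpointing trades about a third more compute for the memory it saves.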
