## Summary

This notebook walks through how to account for the memory and compute (in FLOPs) needed to train a model. Some motivating questions:

- How long would it take to train a 70B parameter model on 15T tokens on 1024 H100s?
- What's the largest model that you can train on 8 H100s using AdamW (naively)?


This notebook is largely inspired by [CS336 Lecture 2 - resource accounting](https://www.youtube.com/watch?v=msHyYioAyNE&list=PLoROMvodv4rOY23Y0BoGoBGgQ1zmU_MT_&index=2&ab_channel=StanfordOnline).

```python
import torch
```

## Memory Accounting

### Tensor memory

Almost everything (parameters, gradients, activations, optimizer states) is stored as floating-point numbers.

Common floating-point data types for tensors:
- `float32` - 32 bits: 1 sign, 8 exponent, 23 fraction - [wiki](https://en.wikipedia.org/wiki/Single-precision_floating-point_format)
    - Also known as `fp32` or single precision; it is PyTorch's default dtype
- `float16` - 16 bits: 1 sign, 5 exponent, 10 fraction
    - Also known as `fp16` or half precision
    - Has a small dynamic range: small values underflow to zero (and large values overflow), which can make training unstable
- `bfloat16` - 16 bits: 1 sign, 8 exponent, 7 fraction
    - Trades fraction precision for dynamic range: it keeps the 8 exponent bits of `float32` and therefore the same range; proposed by Google Brain ("brain floating point")
- `fp8` - 8 bits, with hardware support on NVIDIA Hopper GPUs such as the H100 - [doc](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html)
    - Two variants: E4M3 (4 exponent, 3 fraction bits) and E5M2 (5 exponent, 2 fraction bits)
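The bit layouts above can be sanity-checked against `torch.finfo`. Assuming IEEE-754-style encodings, a format with `f` fraction bits has machine epsilon `2**-f`, and a format with `e` exponent bits (bias `2**(e-1) - 1`) has largest finite value `(2 - 2**-f) * 2**bias` and smallest positive normal value `2**(1 - bias)`. The sketch below checks those formulas for the three 16/32-bit dtypes; `fp8` is left out because its dtypes and `finfo` support depend on the PyTorch version.

```python
import torch

# (sign, exponent, fraction) bit counts, as listed above
LAYOUTS = {
    torch.float32: (1, 8, 23),
    torch.float16: (1, 5, 10),
    torch.bfloat16: (1, 8, 7),
}

for dtype, (sign_bits, exp_bits, frac_bits) in LAYOUTS.items():
    info = torch.finfo(dtype)
    bias = 2 ** (exp_bits - 1) - 1  # exponent bias
    assert info.bits == sign_bits + exp_bits + frac_bits
    # eps: gap between 1.0 and the next representable value
    assert info.eps == 2.0 ** -frac_bits
    # max: all fraction bits set, largest non-reserved exponent
    assert info.max == (2 - 2.0 ** -frac_bits) * 2.0 ** bias
    # tiny: smallest positive normal value
    assert info.tiny == 2.0 ** (1 - bias)
    print(f"{str(dtype):15} eps={info.eps:.3g}  max={info.max:.3g}  tiny={info.tiny:.3g}")
```

The shared 8 exponent bits are why `float32` and `bfloat16` report the same `tiny`, while their `eps` differ by a factor of `2**16`.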


Implications for training:
- Training in `float32` works, but requires a lot of memory.
- Training purely in `fp8`, `float16`, or even `bfloat16` is risky because it can be unstable.
- Solution: **mixed precision training** (use lower precision where possible, `float32` elsewhere).
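One common way to do mixed precision in PyTorch is automatic mixed precision via `torch.autocast`. The sketch below is a toy under stated assumptions (a single `nn.Linear` on CPU, `bfloat16`, random data; sizes are arbitrary), not a full training recipe: parameters and optimizer state stay in `float32`, while eligible ops inside the autocast region run in `bfloat16`. With `float16` on a GPU, one would typically also add gradient scaling so small gradients do not underflow.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy model and data (arbitrary sizes, for illustration only)
model = nn.Linear(16, 4)  # parameters are created in float32
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
x = torch.randn(8, 16)
target = torch.randn(8, 4)

# Forward pass: autocast runs eligible ops (e.g. the matmul in nn.Linear) in bfloat16
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = model(x)

# Compute the loss in float32, outside the autocast region
loss = nn.functional.mse_loss(out.float(), target)

loss.backward()   # gradients land on the float32 parameters
optimizer.step()  # the update is applied to the float32 weights

print("activation dtype:", out.dtype)              # torch.bfloat16
print("weight dtype:", model.weight.dtype)         # torch.float32
print("gradient dtype:", model.weight.grad.dtype)  # torch.float32
```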

```python
def get_memory_usage(tensor):
    """
    Calculate the memory usage of a PyTorch tensor.

    Args:
        tensor (torch.Tensor): The tensor to calculate memory usage for.

    Returns:
        int: Memory usage in bytes.
    """
    return tensor.numel() * tensor.element_size()


# Memory usage of fp32 tensors
x32 = torch.zeros(4, 8)
assert x32.dtype == torch.float32  # float32 is the default dtype
assert x32.numel() == 4 * 8
assert x32.element_size() == 4  # float32 is 4 bytes
assert get_memory_usage(x32) == 4 * 8 * 4  # 128 bytes

# Memory usage of fp16 tensors
x16 = torch.zeros(4, 8, dtype=torch.float16)
assert x16.element_size() == 2  # float16 is 2 bytes
assert get_memory_usage(x16) == 4 * 8 * 2  # 64 bytes

# 1e-8 is below float16's smallest subnormal (~6e-8), so it underflows to zero
x16 = torch.tensor([1e-8], dtype=torch.float16)
assert x16 == 0

# bfloat16 has the same exponent range as float32, so 1e-8 does not underflow.
# With only 7 fraction bits, the stored value is approximately (not exactly) 1e-8.
xb16 = torch.tensor([1e-8], dtype=torch.bfloat16)
assert xb16 != 0
assert abs(xb16.item() - 1e-8) / 1e-8 < 1e-2

print(torch.finfo(torch.float32))   # float32 info
print(torch.finfo(torch.float16))   # float16 info
print(torch.finfo(torch.bfloat16))  # bfloat16 info
```

```
finfo(resolution=1e-06, min=-3.40282e+38, max=3.40282e+38, eps=1.19209e-07, smallest_normal=1.17549e-38, tiny=1.17549e-38, dtype=float32)
finfo(resolution=0.001, min=-65504, max=65504, eps=0.000976562, smallest_normal=6.10352e-05, tiny=6.10352e-05, dtype=float16)
finfo(resolution=0.01, min=-3.38953e+38, max=3.38953e+38, eps=0.0078125, smallest_normal=1.17549e-38, tiny=1.17549e-38, dtype=bfloat16)
```
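The same per-element accounting scales to model-sized tensors. As a rough sketch toward the motivating questions, the snippet below treats the weights of a hypothetical 70B-parameter model as one flat tensor on PyTorch's `meta` device (shape and dtype only, no storage is allocated) and reports weight memory in decimal GB (10^9 bytes). It counts parameters only; gradients, optimizer states, and activations come on top.

```python
import torch


def get_memory_usage(tensor):
    # Same formula as above: number of elements x bytes per element
    return tensor.numel() * tensor.element_size()


num_params = 70 * 10**9  # 70B parameters, as in the motivating question

for dtype in (torch.float32, torch.bfloat16):
    # "meta" tensors carry shape and dtype but allocate no memory
    w = torch.empty(num_params, dtype=dtype, device="meta")
    print(f"{dtype}: {get_memory_usage(w) / 1e9:.0f} GB")
# torch.float32: 280 GB
# torch.bfloat16: 140 GB
```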


### Example: verifying GPT-2's parameter count

```python
from transformers import GPT2Model

model = GPT2Model.from_pretrained('gpt2')
print(model)
```

```
GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)
```
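To see where the parameters live, one can count them per top-level module of the printed architecture. The sketch below builds a randomly initialized `GPT2Model` from the default `GPT2Config`, whose defaults match the 124M `gpt2` checkpoint (12 layers, hidden size 768, vocabulary 50257, 1024 positions), so nothing is downloaded; parameter counts do not depend on the weight values.

```python
from transformers import GPT2Config, GPT2Model

# Randomly initialized model with the default (GPT-2 small) configuration; no download needed
model = GPT2Model(GPT2Config())

# Parameter count per top-level submodule (wte, wpe, drop, h, ln_f)
per_module = {
    name: sum(p.numel() for p in child.parameters())
    for name, child in model.named_children()
}
for name, count in per_module.items():
    print(f"{name:5} {count:>12,}")

total = sum(p.numel() for p in model.parameters())
print(f"total {total:>12,}")  # 124,439,808
```

The dropout module contributes 0 parameters, and the final LayerNorm `ln_f` contributes `2 * 768 = 1,536`.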


```python
# Count GPT-2 (small) parameters by hand from the model configuration
hidden_size = 768
vocab_size = 50257
max_positions = 1024  # maximum sequence length

# Token embedding (in GPT2LMHeadModel its weight is shared with the output layer)
wte_size = vocab_size * hidden_size
# Learned position embedding
wpe_size = max_positions * hidden_size

# Transformer blocks
num_layers = 12
num_heads = 12
head_dim = 64  # num_heads * head_dim == hidden_size, so the head split doesn't change the count

layer_norm = 2 * hidden_size  # LayerNorm has two parameter vectors: weight and bias
attn = 3 * hidden_size * hidden_size + 3 * hidden_size  # fused Q, K, V weights + biases (c_attn)
attn_proj = hidden_size * hidden_size + hidden_size  # attention output projection (c_proj)
first_fc = hidden_size * 4 * hidden_size + 4 * hidden_size  # MLP up-projection weights + biases (c_fc)
second_fc = 4 * hidden_size * hidden_size + hidden_size  # MLP down-projection weights + biases (c_proj)

transformer = num_layers * (
    2 * layer_norm  # ln_1 and ln_2
    + attn
    + attn_proj
    + first_fc
    + second_fc
)

final_layer_norm = layer_norm  # ln_f after the last block

total_params = wte_size + wpe_size + transformer + final_layer_norm
print(f"Total parameters in GPT-2: {total_params / 10**6}M parameters")
```

```
Total parameters in GPT-2: 124.439808M parameters
```


This matches the 124M parameters reported in the [GPT-2 documentation](https://huggingface.co/openai-community/gpt2#:~:text=This%20is%20the%20smallest%20version%20of%20GPT%2D2%2C%20with%20124M%20parameters.).
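The per-layer terms can be folded into a closed-form formula. Assuming the GPT-2 layout (fused QKV, 4x MLP expansion, biases on every linear layer, two LayerNorms per block, and the final LayerNorm `ln_f`), each block has `3d^2 + d^2 + 4d^2 + 4d^2 = 12d^2` weights and `4d` (LayerNorms) `+ 3d + d + 4d + d = 13d` other parameters, so the total is `V*d + T*d + L*(12d^2 + 13d) + 2d`, with `V` the vocabulary size, `T` the number of positions, `d` the hidden size, and `L` the number of layers. A minimal sketch; the helper name `gpt2_param_count` is hypothetical:

```python
def gpt2_param_count(vocab_size, n_positions, d, n_layers):
    """Hypothetical helper: parameters of a GPT-2-style model without a separate LM head."""
    embeddings = vocab_size * d + n_positions * d  # wte + wpe
    per_block = 12 * d * d + 13 * d                # 12d^2 weights + 13d biases/LayerNorm params
    final_ln = 2 * d                               # ln_f
    return embeddings + n_layers * per_block + final_ln


print(gpt2_param_count(50257, 1024, 768, 12))   # 124439808 (GPT-2 small shape)
print(gpt2_param_count(50257, 1024, 1024, 24))  # 354823168 (GPT-2 medium shape)
```

For large `d` the `12 * L * d^2` term dominates, which is why transformer parameter counts are often approximated that way.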

## Compute Accounting