In [None]:
import torch

In [None]:
# (a) How much peak memory does running AdamW require?
# Decompose the answer into the memory used by the parameters, activations, gradients, and optimizer state.
# Answer in terms of batch_size and the model hyperparameters (vocab_size, context_length, num_layers, d_model, num_heads).
# Assume d_ff = 4 * d_model.

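# Notation: B = batch_size, T = context_length, L = num_layers,
#           D = d_model, H = num_heads, V = vocab_size

# Parameters (P)
# Token embedding: V * D
# Attention weights: 4 * D^2 per layer (Q, K, V, O projections)
# FFN weights: 2 * (D * 4D) = 8 * D^2 per layer
# Norm gains: 2 * D per layer (two norms per block), plus D for the final norm
# Output (LM head) projection: V * D
# Total per layer: 12 * D^2 + 2 * D
# P = 2 * V * D + L * (12 * D^2 + 2 * D) + D

# Gradients: one value per parameter
# G = P

# Optimizer state: AdamW keeps two moment estimates (m, v) per parameter
# O = 2 * P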

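# Activations (A)
# Per layer:
#   Q, K, V projections (3), attention-weighted values (1), output projection (1),
#   FFN expansion (4), FFN nonlinearity output (4), FFN projection (1): 14 * B * T * D
#   Attention matrices (B, H, T, T), pre-softmax scores and softmax output: 2 * B * H * T^2
# NOTE(review): part (b) uses 19 * B * T * D per layer plus B * T * D for the final
# norm; reconcile the activation count between the two parts.

# Logits (B, T, V): B * T * V   (not named "L", which already denotes num_layers)

# A = L * (14 * B * T * D + 2 * B * H * T^2) + B * T * V

# Total bytes (float32, 4 bytes per value) = 4 * (P + G + O + A)
#   = 4 * [4P + L(14BTD + 2BHT^2) + BTV]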

In [None]:
# GPT-2 XL hyperparameters.
V = 50257  # Vocab Size
T = 1024   # Context Length
L = 48     # Num Layers
d = 1600   # Model Dim
H = 25     # Num Heads
d_ff = 4 * d  # Feedforward Dimension (not used in this cell: d_ff = 4d is baked into the 12*d^2 per-layer term)

# (b) Instantiate GPT-2 XL into the formula from (a)
# From part (a):
#   N = 2*V*d + L*(12*d^2 + 2*d) + d
#   M_peak = 16*N + 4*[L*(19*B*T*d + 2*B*H*T^2) + B*T*d + B*T*V]
# NOTE(review): the 19*B*T*d per-layer activation coefficient differs from the
# 14*B*T*D derived in part (a); reconcile the two counts.
#
# Separate into constant (independent of B) and B-dependent terms:
#   M_peak = 16*N + 4*B*T*[L*(19*d + 2*H*T) + d + V]
#          = (constant term) + (per-batch term) * B
#
# "GB" below means GiB (1024**3 bytes).

N = 2 * V * d + L * (12 * d**2 + 2 * d) + d
# params + grads + optimizer state (m, v): 4 float32 copies of N values, 4 bytes each
constant_bytes = 16 * N
constant_gb = constant_bytes / (1024**3)

# float32 activation bytes contributed by each sequence in the batch
per_batch_bytes = 4 * T * (L * (19 * d + 2 * H * T) + d + V)
per_batch_gb = per_batch_bytes / (1024**3)

print(f"Total Parameters N = {N:,}")
print(f"N ≈ {N/1e9:.2f} billion")
print(f"Constant term (16N):  {constant_gb:.2f} GB")
print(f"Per-batch coefficient: {per_batch_gb:.4f} GB per sample")
print(f"M_peak (GB) = {per_batch_gb:.4f} * B + {constant_gb:.2f}")

# Max batch size for 80 GB. Clamp at 0 so a constant term that alone exceeds
# the budget reports "nothing fits" instead of a negative batch size.
max_memory_gb = 80
max_B = max(0, int((max_memory_gb - constant_gb) / per_batch_gb))
print(f"Max batch size for {max_memory_gb} GB: B = {max_B}")

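Total Parameters N = 1,635,537,600
N ≈ 1.64 billion
Constant term (16N):  24.37 GB
Per-batch coefficient: 15.1392 GB per sample
M_peak (GB) = 15.1392 * B + 24.37
Max batch size for 80 GB: B = 3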


In [None]:
# (c) How many FLOPs does running one step of AdamW take?
# Answer in algebraic expression & explanation
# The AdamW update is applied element-wise to every parameter, so one step costs
#   FLOPs = flops_per_param_adamw * N
# NOTE(review): the 8 FLOPs/param estimate includes "2 for gradient computation";
# gradients come from the backward pass, whose cost scales with the number of
# tokens, not only with N. Counting just the AdamW update rule (m update, v update,
# sqrt, + eps, divide, lr scale, subtract, weight decay) gives roughly 14 FLOPs/param
# - confirm which count the question intends.
flops_per_param_adamw = 8 # 2 for gradient computation, 6 for AdamW update
# Use the parameter count from part (b) (~1.64B) rather than a hard-coded 1.56e9,
# which is ~N - V*d, i.e. counts only one vocab-sized matrix.
total_params = N
total_flops_adamw = flops_per_param_adamw * total_params
print(f"Total FLOPs for AdamW step: {total_flops_adamw / 1e12} TFLOPs")

In [None]:
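# (d) Training time for 400,000 steps
# Effective Throughput = Peak FLOP/s * MFU
# Given Peak FLOP/s = 19.5 teraFLOP/s and MFU = 0.5:
#   Effective Throughput = 9.75 teraFLOP/s = 9.75 × 10¹² FLOP/s

# FLOPs per token, with N ≈ 1.5 × 10⁹ parameters (rounded; part (b) computes N ≈ 1.64 × 10⁹)
# Forward pass ≈ 2N = 3 × 10⁹ FLOPs/token
# Backward pass = 2 × forward = 6 × 10⁹ FLOPs/token
# Forward + Backward = 3 × 10⁹ + 2(3 × 10⁹) = 9 × 10⁹ FLOPs/token  (the familiar 6N)

# Each step processes B × T tokens (T = 1024; B = batch size, not stated in this cell — TODO confirm)
# FLOPs per step = 9 × 10⁹ × 1024 × B ≈ 9.22 × 10¹² × B FLOPs

# Total FLOPs = 400,000 steps × 9.22 × 10¹² × B ≈ 3.69 × 10¹⁸ × B FLOPs

# Time = Total FLOPs / Effective Throughput
#      = 3.69 × 10¹⁸ × B FLOPs / 9.75 × 10¹² FLOP/s
#      ≈ 3.78 × 10⁵ × B seconds
#      ≈ 4.38 × B days  (e.g. B = 1024 gives ≈ 4,480 days ≈ 12.3 years)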