Consider GPT-2 XL, which has the following configuration:
- vocab_size: 50,257
- context_length: 1,024
- num_layers: 48
- d_model: 1,600
- num_heads: 25
- d_ff: 6,400

Suppose we constructed our model using this configuration. How many trainable parameters would our model have? Assuming each parameter is represented using single-precision floating point, how much memory is required just to load this model?

In [204]:
vocab_size = 50257
context_len = 1024
num_layers = 48
d_model = 1600
num_heads = 25
d_ff = 6400

token_embeddings = vocab_size * d_model
position_embeddings = context_len * d_model
embd_layer = token_embeddings + position_embeddings

q_k_v_o = 4 * d_model * d_model
ffn = 2 * d_model * d_ff
transformer_blocks = num_layers * (q_k_v_o + ffn)

lm_head = d_model * vocab_size

total_params = embd_layer + transformer_blocks + lm_head
print(f"Total trainable params {total_params/10**9:0.2f}B")
bytes_float32 = total_params * 4
mem_fp32 = bytes_float32 / 1024 ** 3
bytes_bf16 = total_params * 2
mem_bf16 = bytes_bf16 / 1024 ** 3
print(f"Memory fp32: {mem_fp32:,.2f} GiB")
print(f"Memory bf16: {mem_bf16:,.2f} GiB")

Total trainable params 1.64B
Memory fp32: 6.10 GiB
Memory bf16: 3.05 GiB
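
The count above covers only the embedding, attention, feed-forward, and LM-head weight matrices. As a minimal sketch, the hypothetical helper below (`count_params` and its `include_norm_gains` flag are not part of the original notebook) repeats the same arithmetic and optionally adds normalization gains, assuming two RMSNorms per block plus one final RMSNorm, each with a `d_model`-sized gain vector. Under that assumption the gains change the total by well under 0.01%.

```python
def count_params(vocab_size, context_length, num_layers, d_model, d_ff,
                 include_norm_gains=False):
    """Count trainable parameters of a GPT-2-style Transformer LM (no biases)."""
    embeddings = vocab_size * d_model + context_length * d_model
    attention = 4 * d_model * d_model   # W_q, W_k, W_v, W_o
    ffn = 2 * d_model * d_ff            # W1 and W2
    lm_head = d_model * vocab_size
    total = embeddings + num_layers * (attention + ffn) + lm_head
    if include_norm_gains:
        # Assumption: two RMSNorms per block plus one final RMSNorm,
        # each holding d_model learnable gains.
        total += (2 * num_layers + 1) * d_model
    return total


base = count_params(50257, 1024, 48, 1600, 6400)
with_gains = count_params(50257, 1024, 48, 1600, 6400, include_norm_gains=True)
print(f"without gains: {base:,}")
print(f"with gains:    {with_gains:,} (+{with_gains - base:,})")
```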


Identify the matrix multiplies required to complete a forward pass of our GPT-2 XL-shaped model.
How many FLOPs do these matrix multiplies require in total? Assume that our input
sequence has `context_length` tokens.
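
All of the matrix-multiply counts rest on one rule: multiplying an (m × n) matrix by an (n × p) matrix takes 2·m·n·p FLOPs, since each of the m·p outputs needs n multiplies and n adds. The sketch below applies that rule to the matmuls of one block for a single sequence. The helper name `matmul_flops` and the dictionary labels are illustrative, and the per-head split assumes `d_head = d_model // num_heads`.

```python
def matmul_flops(m, n, p):
    """FLOPs for an (m x n) @ (n x p) product: m*p outputs, n multiplies + n adds each."""
    return 2 * m * n * p


T, d_model, d_ff, num_heads, vocab_size = 1024, 1600, 6400, 25, 50257
d_head = d_model // num_heads  # assumption: heads split d_model evenly

per_block = {
    "QKV_proj": 3 * matmul_flops(T, d_model, d_model),
    "QK^T":     num_heads * matmul_flops(T, d_head, T),
    "attn @ V": num_heads * matmul_flops(T, T, d_head),
    "O_proj":   matmul_flops(T, d_model, d_model),
    "FFN W1":   matmul_flops(T, d_model, d_ff),
    "FFN W2":   matmul_flops(T, d_ff, d_model),
}
lm_head = matmul_flops(T, d_model, vocab_size)

for name, flops in per_block.items():
    print(f"{name:<10} {flops:.3e}")
print(f"{'LM head':<10} {lm_head:.3e}")
```

Summing over heads recovers the `2 * T**2 * d_model` form used for the attention terms, because `num_heads * d_head == d_model`.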

In [205]:
def get_overall_flops(
    vocab_size,
    context_length,
    num_layers,
    d_model,
    num_heads,
    d_ff,
    b
):
    F_forward_pass = get_forward_flops_TransformerLM(
        vocab_size, context_length, num_layers, d_model, num_heads, d_ff, b
    )
    F_adamw = get_adamw_flops(
        vocab_size,
        context_length,
        num_layers,
        d_model,
        d_ff
    )
    F_backward_pass = get_backward_flops_TransformerLM(
        F_forward_pass
    )
    return F_forward_pass + F_backward_pass + F_adamw

def get_forward_flops_TransformerLM(
    vocab_size,
    context_length,
    num_layers,
    d_model,
    num_heads,
    d_ff,
    b
):
    F_transformer_blocks = get_flops_TransformerBlock(
        b, context_length, d_model, num_heads, d_ff
    ) * num_layers
    F_ln_final = get_flops_RMSNorm(b, context_length, d_model)
    F_lm_head = get_flops_lm_head(b, context_length, d_model, vocab_size)
    return F_transformer_blocks + F_ln_final + F_lm_head

def get_adamw_flops(
    vocab_size,
    context_length,
    num_layers,
    d_model,
    d_ff
):
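    # Per-parameter FLOPs for one AdamW update, treating (1 - beta) and
    # other scalar factors as precomputed constants:
    # beta_1 * m + (1 - beta_1) * grad        # 3 FLOPs (2 mul, 1 add)
    # beta_2 * v + (1 - beta_2) * (grad**2)   # 4 FLOPs (3 mul, 1 add)
    # m / (1 - beta_1**t)                     # 1 FLOP
    # v / (1 - beta_2**t)                     # 1 FLOP
    # (torch.sqrt(v) + eps)                   # 2 FLOPs
    # m / (torch.sqrt(v) + eps)               # 1 FLOP
    # p - lr * step - lr * wd * p             # 4 FLOPs

    # TOTAL CONSTANT FLOPs per parameter = 16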
    total_params = (
        2 * vocab_size * d_model  # token and lm head embeddings
        + context_length * d_model  # positional embedding
        + num_layers * (4 * d_model**2 + 2 * d_model * d_ff) # Transformer blocks: attention + FFN
    )
    return 16 * total_params

def get_backward_flops_TransformerLM(
    F_forward_pass
):
    F_backward_pass = 2 * F_forward_pass
    return F_backward_pass


def get_flops_lm_head(b, context_length, d_model, vocab_size):
    return 2 * b * context_length * d_model * vocab_size

def get_flops_TransformerBlock(b, context_length, d_model, num_heads, d_ff):
    F_rmsnorm_1 = get_flops_RMSNorm(b, context_length, d_model)
    F_mha = get_flops_MultiHeadAttention(b, context_length, d_model, num_heads)
    F_rmsnorm_2 = F_rmsnorm_1
    F_ffn = get_flops_FFN(b, context_length, d_model, d_ff)
    return F_rmsnorm_1 + F_mha + F_rmsnorm_2 + F_ffn

def get_flops_RMSNorm(b, context_length, d_model):
    return (3*d_model + 2) * b * context_length

def get_flops_MultiHeadAttention(b, context_length, d_model, num_heads):
    F_QKV_linear_proj = get_qkv_proj(b, context_length, d_model)
    F_attention = get_flops_atten_scores(b, context_length, d_model)
    F_softmax = get_flops_Softmax(b, context_length, num_heads)
    F_weighted = get_flops_atten_scores(b, context_length, d_model)
    F_out_proj = get_flops_output_proj(b, context_length, d_model)
    return F_QKV_linear_proj + F_attention + F_softmax + F_weighted + F_out_proj

def get_qkv_proj(b, context_length, d_model):
    return 3 * (2 * b * context_length * (d_model**2))

def get_flops_atten_scores(b, context_length, d_model):
    return 2 * b * (context_length**2) * d_model

def get_flops_Softmax(b, context_length, num_heads):
    return 5 * b * (context_length**2) * num_heads

def get_flops_output_proj(b, context_length, d_model):
    return 2 * b * context_length * d_model * d_model

def get_flops_FFN(b, context_length, d_model, d_ff):
    F_w1 = 2 * b * context_length * d_model * d_ff
    F_gelu = get_flops_GELU(b, context_length, d_ff)
    F_w2 = F_w1
    return F_w1 + F_gelu + F_w2

def get_flops_GELU(b, context_length, d_ff):
    return 5 * b * context_length * d_ff

def get_table_all_flops(
    vocab_size,
    context_length,
    num_layers,
    d_model,
    num_heads,
    d_ff,
    b = 1
):
    all_model_flops = {
        "Transformer_block": {
            "RMSNorm_1": get_flops_RMSNorm(b, context_length, d_model),
            "MultiHeadAttentio": {
                "QKV_proj": get_qkv_proj(b, context_length, d_model),
                "Attention_scores": get_flops_atten_scores(b, context_length, d_model),
                "Softmax": get_flops_Softmax(b, context_length, num_heads),
                "Attention_weighted": get_flops_atten_scores(b, context_length, d_model),
                "O_proj": get_flops_output_proj(b, context_length, d_model),
                "Overall_MultiHeadAttentio": get_flops_MultiHeadAttention(b, context_length, d_model, num_heads)
            },
            "RMSNorm_2": get_flops_RMSNorm(b, context_length, d_model),
            "FFN": get_flops_FFN(b, context_length, d_model, d_ff),
            "Overall_Transformer_block": get_flops_TransformerBlock(
                b, context_length, d_model, num_heads, d_ff
            ) * num_layers
        },
        "RMSNorm_final": get_flops_RMSNorm(b, context_length, d_model),
        "LM_head": get_flops_lm_head(b, context_length, d_model, vocab_size),
        "forward_pass_TransformerLM": get_forward_flops_TransformerLM(
            vocab_size, context_length, num_layers, d_model, num_heads, d_ff, b
        ),
        "backward_pass_TransformerLM": get_backward_flops_TransformerLM(
            get_forward_flops_TransformerLM(
                vocab_size, context_length, num_layers, d_model, num_heads, d_ff, b)
        ),
        "AdamW-step": get_adamw_flops(
        vocab_size, context_length, num_layers, d_model, d_ff
        ),
        "Overall_TransformerLM": get_overall_flops(
            vocab_size, context_length, num_layers, d_model, num_heads, d_ff, b
        )
    }
    return all_model_flops

In [206]:
import pandas as pd


def explode_to_tuples(d, prefix=()):
    """Yield (path_tuple, value) pairs from a nested dict."""
    for k, v in d.items():
        path = prefix + (k,)
        if isinstance(v, dict):
            yield from explode_to_tuples(v, path)
        else:
            yield path, v

def flops_table_multiindex(
    vocab_size, context_len, n_layers, d_model, n_heads, d_ff, b=1
):
    # 1. get the nested FLOPs dict
    nested = get_table_all_flops(
        vocab_size, context_len, n_layers, d_model, n_heads, d_ff, b
    )
    overall_flops = nested['Overall_TransformerLM']
    # 2. explode to (tuple, value) pairs
    tuples, vals = zip(*explode_to_tuples(nested))

    # 3. make all tuples the same length
    max_depth = max(len(t) for t in tuples)
    padded = [t + ("",) * (max_depth - len(t)) for t in tuples]

    # 4. build the MultiIndex DataFrame
    names = [f"Level-{i+1}" for i in range(max_depth)]
    mi = pd.MultiIndex.from_tuples(padded, names=names)
    df = pd.DataFrame({"FLOPs": vals}, index=mi)

    # 5. optional pretty formatting; proportions are relative to one full
    #    training step (forward + backward + AdamW), not the forward pass alone
    df["FLOPs_fmt"] = df["FLOPs"].astype(float).map("{:.2e}".format)
    df["FLOPs_proportion"] = df["FLOPs"] / overall_flops
    return df


gpt2_xl = flops_table_multiindex(
    vocab_size=50257,
    context_len=1024,
    n_layers=48,
    d_model=1600,
    n_heads=25,
    d_ff=6400,
)

gpt2_xl

Level-1,Level-2,Level-3,FLOPs,FLOPs_fmt,FLOPs_proportion
Transformer_block,RMSNorm_1,,4917248,4920000.0,4.651496e-07
Transformer_block,MultiHeadAttentio,QKV_proj,15728640000,15700000000.0,0.001487859
Transformer_block,MultiHeadAttentio,Attention_scores,3355443200,3360000000.0,0.0003174098
Transformer_block,MultiHeadAttentio,Softmax,131072000,131000000.0,1.239882e-05
Transformer_block,MultiHeadAttentio,Attention_weighted,3355443200,3360000000.0,0.0003174098
Transformer_block,MultiHeadAttentio,O_proj,5242880000,5240000000.0,0.0004959529
Transformer_block,MultiHeadAttentio,Overall_MultiHeadAttentio,27813478400,27800000000.0,0.00263103
Transformer_block,RMSNorm_2,,4917248,4920000.0,4.651496e-07
Transformer_block,FFN,,41975808000,42000000000.0,0.003970723
Transformer_block,Overall_Transformer_block,,3350357803008,3350000000000.0,0.3169288


In [207]:
gpt2_small = flops_table_multiindex(
    vocab_size=50257,
    context_len=1024,
    n_layers=12,
    d_model=768,
    n_heads=12,
    d_ff=6400,  # XL value reused here; with d_ff = 4 * d_model this would be 3072
)

gpt2_small

Level-1,Level-2,Level-3,FLOPs,FLOPs_fmt,FLOPs_proportion
Transformer_block,RMSNorm_1,,2361344,2360000.0,2e-06
Transformer_block,MultiHeadAttentio,QKV_proj,3623878656,3620000000.0,0.002878
Transformer_block,MultiHeadAttentio,Attention_scores,1610612736,1610000000.0,0.001279
Transformer_block,MultiHeadAttentio,Softmax,62914560,62900000.0,5e-05
Transformer_block,MultiHeadAttentio,Attention_weighted,1610612736,1610000000.0,0.001279
Transformer_block,MultiHeadAttentio,O_proj,1207959552,1210000000.0,0.000959
Transformer_block,MultiHeadAttentio,Overall_MultiHeadAttentio,8115978240,8120000000.0,0.006446
Transformer_block,RMSNorm_2,,2361344,2360000.0,2e-06
Transformer_block,FFN,,20165427200,20200000000.0,0.016017
Transformer_block,Overall_Transformer_block,,339433537536,339000000000.0,0.269598


In [208]:
gpt2_medium = flops_table_multiindex(
    vocab_size=50257,
    context_len=1024,
    n_layers=24,
    d_model=1024,
    n_heads=16,
    d_ff=6400,  # XL value reused here; with d_ff = 4 * d_model this would be 4096
)

gpt2_medium

Level-1,Level-2,Level-3,FLOPs,FLOPs_fmt,FLOPs_proportion
Transformer_block,RMSNorm_1,,3147776,3150000.0,9.855874e-07
Transformer_block,MultiHeadAttentio,QKV_proj,6442450944,6440000000.0,0.00201717
Transformer_block,MultiHeadAttentio,Attention_scores,2147483648,2150000000.0,0.0006723899
Transformer_block,MultiHeadAttentio,Softmax,83886080,83900000.0,2.626523e-05
Transformer_block,MultiHeadAttentio,Attention_weighted,2147483648,2150000000.0,0.0006723899
Transformer_block,MultiHeadAttentio,O_proj,2147483648,2150000000.0,0.0006723899
Transformer_block,MultiHeadAttentio,Overall_MultiHeadAttentio,12968787968,13000000000.0,0.004060605
Transformer_block,RMSNorm_2,,3147776,3150000.0,9.855874e-07
Transformer_block,FFN,,26876313600,26900000000.0,0.008415134
Transformer_block,Overall_Transformer_block,,956433530880,956000000000.0,0.299465


In [209]:
gpt2_large = flops_table_multiindex(
    vocab_size=50257,
    context_len=1024,
    n_layers=36,
    d_model=1280,
    n_heads=20,
    d_ff=6400,  # XL value reused here; with d_ff = 4 * d_model this would be 5120
)

gpt2_large

Level-1,Level-2,Level-3,FLOPs,FLOPs_fmt,FLOPs_proportion
Transformer_block,RMSNorm_1,,3934208,3930000.0,6.471264e-07
Transformer_block,MultiHeadAttentio,QKV_proj,10066329600,10100000000.0,0.001655781
Transformer_block,MultiHeadAttentio,Attention_scores,2684354560,2680000000.0,0.0004415417
Transformer_block,MultiHeadAttentio,Softmax,104857600,105000000.0,1.724772e-05
Transformer_block,MultiHeadAttentio,Attention_weighted,2684354560,2680000000.0,0.0004415417
Transformer_block,MultiHeadAttentio,O_proj,3355443200,3360000000.0,0.0005519271
Transformer_block,MultiHeadAttentio,Overall_MultiHeadAttentio,18895339520,18900000000.0,0.003108039
Transformer_block,RMSNorm_2,,3934208,3930000.0,6.471264e-07
Transformer_block,FFN,,33587200000,33600000000.0,0.005524661
Transformer_block,Overall_Transformer_block,,1889654685696,1890000000000.0,0.3108238


In [210]:
gpt2_xl = flops_table_multiindex(
    vocab_size=50257,
    context_len=1024,
    n_layers=48,
    d_model=1600,
    n_heads=25,
    d_ff=6400,
)

gpt2_xl

Level-1,Level-2,Level-3,FLOPs,FLOPs_fmt,FLOPs_proportion
Transformer_block,RMSNorm_1,,4917248,4920000.0,4.651496e-07
Transformer_block,MultiHeadAttentio,QKV_proj,15728640000,15700000000.0,0.001487859
Transformer_block,MultiHeadAttentio,Attention_scores,3355443200,3360000000.0,0.0003174098
Transformer_block,MultiHeadAttentio,Softmax,131072000,131000000.0,1.239882e-05
Transformer_block,MultiHeadAttentio,Attention_weighted,3355443200,3360000000.0,0.0003174098
Transformer_block,MultiHeadAttentio,O_proj,5242880000,5240000000.0,0.0004959529
Transformer_block,MultiHeadAttentio,Overall_MultiHeadAttentio,27813478400,27800000000.0,0.00263103
Transformer_block,RMSNorm_2,,4917248,4920000.0,4.651496e-07
Transformer_block,FFN,,41975808000,42000000000.0,0.003970723
Transformer_block,Overall_Transformer_block,,3350357803008,3350000000000.0,0.3169288


In [211]:
gpt2_xl = flops_table_multiindex(
    vocab_size=50257,
    context_len=16384,
    n_layers=48,
    d_model=1600,
    n_heads=25,
    d_ff=6400,
)

gpt2_xl

Level-1,Level-2,Level-3,FLOPs,FLOPs_fmt,FLOPs_proportion
Transformer_block,RMSNorm_1,,78675968,78700000.0,1.941625e-07
Transformer_block,MultiHeadAttentio,QKV_proj,251658240000,252000000000.0,0.0006210612
Transformer_block,MultiHeadAttentio,Attention_scores,858993459200,859000000000.0,0.002119889
Transformer_block,MultiHeadAttentio,Softmax,33554432000,33600000000.0,8.280816e-05
Transformer_block,MultiHeadAttentio,Attention_weighted,858993459200,859000000000.0,0.002119889
Transformer_block,MultiHeadAttentio,O_proj,83886080000,83900000000.0,0.0002070204
Transformer_block,MultiHeadAttentio,Overall_MultiHeadAttentio,2087085670400,2090000000000.0,0.005150668
Transformer_block,RMSNorm_2,,78675968,78700000.0,1.941625e-07
Transformer_block,FFN,,671612928000,672000000000.0,0.001657457
Transformer_block,Overall_Transformer_block,,132425085616128,132000000000000.0,0.3268086
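
Comparing the 1,024-token and 16,384-token tables, the two attention matmuls (scores and weighted sum) grow with the square of the context length, while every other matmul grows linearly. The sketch below makes that shift explicit by computing the share of forward-pass matmul FLOPs that comes from those two quadratic terms. `quadratic_share` is a hypothetical helper; it counts matmuls only (no RMSNorm, softmax, or GELU) for a single sequence.

```python
def quadratic_share(context_length, num_layers=48, d_model=1600,
                    d_ff=6400, vocab_size=50257):
    """Fraction of forward matmul FLOPs spent on QK^T and attention-times-V."""
    T = context_length
    proj = 4 * (2 * T * d_model * d_model)   # Q, K, V, O projections
    ffn = 2 * (2 * T * d_model * d_ff)       # W1 and W2
    lm_head = 2 * T * d_model * vocab_size
    linear = num_layers * (proj + ffn) + lm_head
    quadratic = num_layers * 2 * (2 * T * T * d_model)  # scores + weighted sum
    return quadratic / (linear + quadratic)


for T in (1024, 16384):
    print(f"context_length={T:>6}: quadratic attention share = {quadratic_share(T):.1%}")
```

For the GPT-2 XL shape this share goes from under a tenth of the forward matmul FLOPs at 1,024 tokens to well over half at 16,384 tokens.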


---

How much peak memory does running AdamW require? Decompose your answer based on the memory usage of the parameters, activations, gradients, and optimizer state. Express your answer in terms of the batch_size and the model hyperparameters (vocab_size, context_length, num_layers, d_model, num_heads). Assume d_ff = 4 × d_model.
For simplicity, when calculating memory usage of activations, consider only the following components:
* Transformer block
    - RMSNorm(s)
    - Multi-head self-attention sublayer: QKV projections, QK^T matrix multiply, softmax, weighted sum of values, output projection.
    - Position-wise feed-forward: W1 matrix multiply, GELU, W2 matrix multiply
* final RMSNorm
* output embedding
* cross-entropy on logits
Deliverable: An algebraic expression for each of parameters, activations, gradients, and optimizer state, as well as the total.
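
Read literally, the component list above suggests storing one output tensor per listed operation. The sketch below tallies elements per sequence under that reading. It is an assumption-laden upper-style count, not the only valid accounting: a leaner tally that keeps only the tensors backpropagation strictly needs will come out smaller. The function name `activation_elements_per_sequence` is hypothetical, and cross-entropy is assumed to keep one extra logits-sized tensor.

```python
def activation_elements_per_sequence(vocab_size, context_length, num_layers,
                                     d_model, num_heads):
    """Elements stored per sequence, one tensor per listed operation (d_ff = 4 * d_model)."""
    T, d, h = context_length, d_model, num_heads
    d_ff = 4 * d
    per_layer = (
        2 * T * d        # two RMSNorm outputs
        + 3 * T * d      # Q, K, V projections
        + h * T * T      # QK^T scores
        + h * T * T      # softmax output
        + T * d          # weighted sum of values
        + T * d          # output projection
        + T * d_ff       # W1 output
        + T * d_ff       # GELU output
        + T * d          # W2 output
    )
    final_norm = T * d
    logits = T * vocab_size
    cross_entropy = T * vocab_size  # assumption: one logits-sized tensor
    return num_layers * per_layer + final_norm + logits + cross_entropy


elems = activation_elements_per_sequence(50257, 1024, 48, 1600, 25)
print(f"{elems:,} elements per sequence")
```

In closed form this reading gives `num_layers * (16 * T * d_model + 2 * num_heads * T**2) + T * d_model + 2 * T * vocab_size` elements per sequence.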

In [212]:
def get_constant_params(
    vocab_size,
    context_len,
    num_layers,
    d_model,
    d_ff
):
  token_embeddings = vocab_size * d_model
  position_embeddings = context_len * d_model
  embd_layer = token_embeddings + position_embeddings

  # Transformer parameter memory:
  q_k_v_o = 4 * d_model * d_model
  ffn = 2 * d_model * d_ff
  transformer_blocks = num_layers * (q_k_v_o + ffn)

  # lm head parameter:
  lm_head = d_model * vocab_size

  # Total trainable parameters:
  total_params = embd_layer + transformer_blocks + lm_head

  return total_params

def get_activations_per_seq(
    vocab_size,
    context_len,
    num_layers,
    d_model,
    num_heads,
    d_ff,
):
  # RMSNorm
  # Each token goes through this
  rmsnorma_act = context_len * d_model

  # Q, K, V
  # q, k, v projections
  qkv_act = 3 * context_len * d_model

  # scaled dot product attention:
  pre_softmax_act = num_heads * context_len * context_len

  # Concatenate the heads and pass to the next sub-layer:
  attn_weight_act = context_len * d_model

  # FFN:
  # for h1 = W1 @ x -> need to backprop through GELU
  h1_act = context_len * d_ff
  # for h2 = GELU(h1) -> need to backprop through W2
  h2_act = context_len * d_ff
  # total
  ffn_act = h1_act + h2_act

  # Total per-layer activations per sequence
  per_layer_act = rmsnorma_act + qkv_act + pre_softmax_act + attn_weight_act + ffn_act

  # Total for all layers, then add final layer head
  transformer_activations = num_layers * per_layer_act

  final_rmsnorm = context_len * d_model
  logits = context_len * vocab_size

  # Total activations per sequence
  activations_per_sequence = transformer_activations + final_rmsnorm + logits
  return activations_per_sequence


def get_constant_memory(
    vocab_size,
    context_len,
    num_layers,
    d_model,
    d_ff,
    dtype_size
  ):
  # --- Parameter memory ---
  total_params = get_constant_params(
      vocab_size=vocab_size,
      context_len=context_len,
      num_layers=num_layers,
      d_model=d_model,
      d_ff=d_ff,
  )

  # --- Gradients ---:
  gradients = total_params

  # --- Optimizer state (m and v per parameter) --- 
  AdamW_state = 2 * total_params

  # --- Constant Memory ---:
  total_constant = total_params + gradients + AdamW_state
  total_mem_in_bytes = total_constant * dtype_size
  return total_mem_in_bytes

def get_activations_memory(
  vocab_size,
  context_len,
  num_layers,
  d_model,
  num_heads,
  d_ff,
  dtype_size
  ):
  # --- Activation per sequence ---
  activations = get_activations_per_seq(
      vocab_size=vocab_size,
      context_len=context_len,
      num_layers=num_layers,
      d_model=d_model,
      num_heads=num_heads,
      d_ff = d_ff,
  )
  total_mem_in_bytes = activations * dtype_size
  return total_mem_in_bytes

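def calc_mem(
    vocab_size,
    context_len,
    num_layers,
    d_model,
    num_heads,
    d_ff=None,
    batch_size=1,
    dtype_size=4
):
  # A default argument cannot refer to another parameter, so resolve
  # d_ff = 4 * d_model here instead of in the signature.
  if d_ff is None:
    d_ff = 4 * d_model
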
  constant_memory = get_constant_memory(
    vocab_size=vocab_size,
    context_len=context_len,
    num_layers=num_layers,
    d_model=d_model,
    d_ff=d_ff,
    dtype_size=dtype_size
  )
  
  activation_memory = get_activations_memory(
    vocab_size,
    context_len,
    num_layers,
    d_model,
    num_heads,
    d_ff,
    dtype_size
  )

  # --- Memory totals in bytes --- 
  total_mem = activation_memory * batch_size + constant_memory
  constant_memory = constant_memory / (1024 ** 3)
  activation_memory = activation_memory / (1024 ** 3)
  total_mem_in_gib = total_mem / (1024 ** 3)

  print(f"Total Constant Memory: {constant_memory:.2f} GiB ({'fp32' if dtype_size==4 else 'bf16'})")
  print(f"Total Activation Memory: {activation_memory:.2f} GiB ({'fp32' if dtype_size==4 else 'bf16'})")
  print(f"Total Memory (a * batch_size + b): {total_mem_in_gib:.2f} GiB ({'fp32' if dtype_size==4 else 'bf16'})")


In [213]:
calc_mem(
    vocab_size=50257,
    context_len=1024,
    num_layers=48,
    d_model=1600,
    d_ff=4 * 1600,  # 4 * d_model, written out to avoid relying on a global d_model
    num_heads=25,
    batch_size=1,
    dtype_size=2
)

Total Constant Memory: 12.20 GiB (bf16)
Total Activation Memory: 4.35 GiB (bf16)
Total Memory (a * batch_size + b): 16.54 GiB (bf16)
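
The functions above can be summarized algebraically. The block below is a transcription of what the code counts, with d_ff = 4 × d_model substituted; the symbol names (V, T, L, d, h, B, s) are shorthand introduced here, not notation from the original problem.

```latex
% V = vocab_size, T = context_length, L = num_layers, d = d_model,
% h = num_heads, B = batch_size, s = bytes per element (dtype_size)
\begin{align*}
P &= 2Vd + Td + L\,(4d^2 + 2d \cdot 4d) = 2Vd + Td + 12Ld^2 \\
\text{parameters} &= sP, \qquad \text{gradients} = sP, \qquad \text{optimizer state} = 2sP \\
A &= L\,\bigl(\underbrace{Td}_{\text{RMSNorm}} + \underbrace{3Td}_{Q,K,V}
      + \underbrace{hT^2}_{QK^\top} + \underbrace{Td}_{\text{weighted sum}}
      + \underbrace{8Td}_{\text{FFN}}\bigr) + \underbrace{Td}_{\text{final RMSNorm}}
      + \underbrace{TV}_{\text{logits}} \\
  &= L\,(13Td + hT^2) + Td + TV \\
\text{total} &= s\,(4P + B \cdot A)
\end{align*}
```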


Instantiate your answer for a GPT-2 XL-shaped model to get an expression that only depends on
the batch_size. What is the maximum batch size you can use and still fit within 80GB memory?
Deliverable: An expression that looks like a · batch_size + b for numerical values a, b, and a
number representing the maximum batch size.

In [214]:
calc_mem(
    vocab_size=50257,
    context_len=1024,
    num_layers=48,
    d_model=1600,
    d_ff=4 * 1600,  # 4 * d_model, written out to avoid relying on a global d_model
    num_heads=25,
    batch_size=15,
    dtype_size=2
)

Total Constant Memory: 12.20 GiB (bf16)
Total Activation Memory: 4.35 GiB (bf16)
Total Memory (a * batch_size + b): 77.40 GiB (bf16)
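
Rather than trying batch sizes by hand, the expression a · batch_size + b can be solved directly for the largest batch that fits. The sketch below recomputes a and b with the same element counts as the memory functions above, in bf16 as in the runs shown. `max_batch_size` is a hypothetical helper, and whether "80GB" means 80 GiB or 80 × 10^9 bytes is an assumption, so both are evaluated.

```python
def max_batch_size(budget_bytes, constant_bytes, activation_bytes_per_seq):
    """Largest integer batch with activation * batch + constant <= budget."""
    return int((budget_bytes - constant_bytes) // activation_bytes_per_seq)


V, T, L, d, h = 50257, 1024, 48, 1600, 25
d_ff = 4 * d
dtype_size = 2  # bf16, matching the runs above

params = 2 * V * d + T * d + L * (4 * d * d + 2 * d * d_ff)
b = 4 * params * dtype_size                       # params + grads + 2x AdamW state
per_layer = T * d + 3 * T * d + h * T * T + T * d + 2 * T * d_ff
a = (L * per_layer + T * d + T * V) * dtype_size  # activations per sequence

print(f"a = {a / 1024**3:.2f} GiB, b = {b / 1024**3:.2f} GiB")
print("max batch (80 GiB):", max_batch_size(80 * 1024**3, b, a))
print("max batch (80 GB): ", max_batch_size(80 * 10**9, b, a))
```

With the binary reading of 80 GB the largest batch size is 15, which agrees with the 77.40 GiB run above; with the decimal reading it drops to 14.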


How many FLOPs does running one step of AdamW take?

In [231]:
adam_flops = 16 * (
    2 * vocab_size * d_model # tok embedding + lm head
    + context_len * d_model # position embedding
    + num_layers * (4 * d_model**2 + 2 * d_model * d_ff) # Attention + FFN
)
f"{adam_flops/1e12:.4f} TeraFLOPs"

'0.0262 TeraFLOPs'
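
The factor of 16 is a per-parameter count of elementwise operations in one AdamW update. To show where those operations come from, here is a minimal scalar sketch of a single update. `adamw_step` and its default hyperparameters are illustrative assumptions, not the notebook's optimizer; the update is written in the common form with bias-corrected moments and decoupled weight decay.

```python
import math


def adamw_step(p, grad, m, v, t, lr=1e-3, beta_1=0.9, beta_2=0.999,
               eps=1e-8, wd=0.01):
    """One AdamW update for a single scalar parameter (hyperparameters are assumed defaults)."""
    m = beta_1 * m + (1 - beta_1) * grad          # first-moment estimate
    v = beta_2 * v + (1 - beta_2) * grad * grad   # second-moment estimate
    m_hat = m / (1 - beta_1 ** t)                 # bias correction
    v_hat = v / (1 - beta_2 ** t)
    step = m_hat / (math.sqrt(v_hat) + eps)
    p = p - lr * step - lr * wd * p               # gradient step + decoupled weight decay
    return p, m, v


p, m, v = adamw_step(p=1.0, grad=0.5, m=0.0, v=0.0, t=1)
print(p, m, v)
```

Every line is applied once per parameter and none depends on the batch size, which is why the optimizer cost is a constant multiple of the parameter count.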

Model FLOPs utilization (MFU) is defined as the ratio of observed throughput (tokens per second)
relative to the hardware’s theoretical peak FLOP throughput [Chowdhery et al., 2022]. An
NVIDIA A100 GPU has a theoretical peak of 19.5 teraFLOP/s for float32 operations. Assuming
you are able to get 50% MFU, how long would it take to train a GPT-2 XL for 400K steps and a
batch size of 1024 on a single A100? Following Kaplan et al. [2020] and Hoffmann et al. [2022],
assume that the backward pass has twice the FLOPs of the forward pass.

Deliverable: The number of days training would take, with a brief justification.

In [262]:
gpt2_xl = flops_table_multiindex(
    vocab_size=50257,
    context_len=1024,
    n_layers=48,
    d_model=1600,
    n_heads=25,
    d_ff=6400,
    b=1024,  # batch size from the question; forward/backward FLOPs scale linearly with it
)

total_transformer_flops_per_step = gpt2_xl["FLOPs"]["Overall_TransformerLM"].values[0]
print("Total transformer flops per step:")
print(f"{total_transformer_flops_per_step / 1e12:.4f} TeraFLOPs")
print("Total transformer flops over 400K steps:")
total_transformer_flops = 400*1e3 * total_transformer_flops_per_step
print(f"{total_transformer_flops / 1e12:,.4f} TeraFLOPs")
print("Total training time with NVIDIA A100 GPU:")
# 50% MFU of the 19.5 teraFLOP/s float32 peak
seconds = total_transformer_flops / (19.5*1e12 * 0.5)
days = seconds / (60 * 60 * 24)
print(f"{days:.2f} days")


Total transformer flops per step:
10798.2440 TeraFLOPs
Total transformer flops over 400K steps:
4,319,297,598.2666 TeraFLOPs
Total training time with NVIDIA A100 GPU:
5127.37 days
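
As a rough cross-check that takes the batch size of 1024 from the question into account, a common rule of thumb puts training cost at about 6 FLOPs per parameter per token (2 for the forward pass, 4 for the backward pass). The sketch below applies it; `six_nd_days` is a hypothetical helper, and the rule ignores attention's context-length-dependent terms and the optimizer step, so it is expected to land somewhat below a full per-operation count.

```python
def six_nd_days(num_params, tokens_per_step, steps, peak_flops_per_s, mfu):
    """Training time in days using the ~6 * N * D FLOPs approximation."""
    total_flops = 6 * num_params * tokens_per_step * steps
    seconds = total_flops / (peak_flops_per_s * mfu)
    return seconds / (60 * 60 * 24)


num_params = 1_637_020_800      # GPT-2 XL-shaped model, from the count above
tokens_per_step = 1024 * 1024   # batch size 1024 x context length 1024
estimate = six_nd_days(num_params, tokens_per_step, steps=400_000,
                       peak_flops_per_s=19.5e12, mfu=0.5)
print(f"{estimate:,.0f} days (~{estimate / 365:.1f} years)")
```

Under these assumptions the estimate is on the order of several thousand days on a single A100, which shows that the batch size cannot be left out of the per-step FLOP count.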
