Consider GPT-2 XL, which has the following configuration:
-  vocab_size: 50,257
-   context_length: 1,024
-   num_layers: 48
-   d_model: 1,600
-   num_heads: 25
-   d_ff: 6,400

Suppose we constructed our model using this configuration. How many trainable parameters would our model have? Assuming each parameter is represented using singl-precision floating point, how much memory is required to just load this model?

In [137]:
vocab_size = 50257
context_length = 1024
num_layers = 48
d_model = 1600
num_heads = 25
d_ff = 6400

token_embeddings = vocab_size * d_model
position_embeddings = context_length * d_model
embd_layer = token_embeddings + position_embeddings

q_k_v_o = 4 * d_model * d_model
ffn = 2 * d_model * d_ff
transformer_blocks = num_layers * (q_k_v_o + ffn)

lm_head = d_model * vocab_size

total_params = embd_layer + transformer_blocks + lm_head
print(f"Total trainable params {total_params/10**9:0.2f}B")
bytes_float32 = total_params * 4
mem_fp32 = bytes_float32 / 1024 ** 3
bytes_bf16 = total_params * 2
mem_bf16 = bytes_bf16 / 1024 ** 3
print(f"Memory fp32: {mem_fp32:,.2f} GiB")
print(f"Memory bf16: {mem_bf16:,.2f} GiB")

Total trainable params 1.64B
Memory fp32: 6.10 GiB
Memory bf16: 3.05 GiB


Identify the matrix multiplies required to complete a forward pass of our GPT-2 XL-shaped model.
How many FLOPs do these matrix multiplies require in total? Assume that our input
sequence has `context_length` tokens

In [136]:
def get_overall_flops(
    vocab_size,
    context_length,
    num_layers,
    d_model,
    num_heads,
    d_ff,
    b = 1
):
    F_forward_pass = get_forward_flops_TransformerLM(
        vocab_size, context_length, num_layers, d_model, num_heads, d_ff, b
    )
    F_backward_pass = get_backward_flops_TransformerLM(
        F_forward_pass
    )
    return F_forward_pass + F_backward_pass

def get_forward_flops_TransformerLM(
    vocab_size,
    context_length,
    num_layers,
    d_model,
    num_heads,
    d_ff,
    b = 1
):
    F_transformer_blocks = get_flops_TransformerBlock(
        b, context_length, d_model, num_heads, d_ff
    ) * num_layers
    F_ln_final = get_flops_RMSNorm(b, context_length, d_model)
    F_lm_head = get_flops_lm_head(b, context_length, d_model, vocab_size)
    return F_transformer_blocks + F_ln_final + F_lm_head

def get_backward_flops_TransformerLM(
    F_forward_pass
):
    F_backward_pass = 2 * F_forward_pass
    return F_backward_pass


def get_flops_lm_head(b, context_length, d_model, vocab_size):
    return 2 * b * context_length * d_model * vocab_size

def get_flops_TransformerBlock(b, context_length, d_model, num_heads, d_ff):
    F_rmsnorm_1 = get_flops_RMSNorm(b, context_length, d_model)
    F_mha = get_flops_MultiHeadAttention(b, context_length, d_model, num_heads)
    F_rmsnorm_2 = F_rmsnorm_1
    F_ffn = get_flops_FFN(b, context_length, d_model, d_ff)
    return F_rmsnorm_1 + F_mha + F_rmsnorm_2 + F_ffn

def get_flops_RMSNorm(b, context_length, d_model):
    return (3*d_model + 2) * b * context_length

def get_flops_MultiHeadAttention(b, context_length, d_model, num_heads):
    F_QKV_linear_proj = get_qkv_proj(b, context_length, d_model)
    F_attention = get_flops_atten_scores(b, context_length, d_model)
    F_softmax = get_flops_Softmax(b, context_length, num_heads)
    F_weighted = get_flops_atten_scores(b, context_length, d_model)
    F_out_proj = get_flops_output_proj(b, context_length, d_model)
    return F_QKV_linear_proj + F_attention + F_softmax + F_weighted + F_out_proj

def get_qkv_proj(b, context_length, d_model):
    return 3 * (2 * b * context_length * (d_model**2))

def get_flops_atten_scores(b, context_length, d_model):
    return 2 * b * (context_length**2) * d_model

def get_flops_Softmax(b, context_length, num_heads):
    return 5 * b * (context_length**2) * num_heads

def get_flops_output_proj(b, context_length, d_model):
    return 2 * b * context_length * d_model * d_model

def get_flops_FFN(b, context_length, d_model, d_ff):
    F_w1 = 2 * b * context_length * d_model * d_ff
    F_gelu = get_flops_GELU(b, context_length, d_ff)
    F_w2 = F_w1
    return F_w1 + F_gelu + F_w2

def get_flops_GELU(b, context_length, d_ff):
    return 5 * b * context_length * d_ff

def get_table_all_flops(
    vocab_size,
    context_length,
    num_layers,
    d_model,
    num_heads,
    d_ff,
    b = 1
):
    all_model_flops = {
        "Transformer_block": {
            "RMSNorm_1": get_flops_RMSNorm(b, context_length, d_model),
            "MultiHeadAttentio": {
                "QKV_proj": get_qkv_proj(b, context_length, d_model),
                "Attention_scores": get_flops_atten_scores(b, context_length, d_model),
                "Softmax": get_flops_Softmax(b, context_length, num_heads),
                "Attention_weighted": get_flops_atten_scores(b, context_length, d_model),
                "O_proj": get_flops_output_proj(b, context_length, d_model),
                "Overall_MultiHeadAttentio": get_flops_MultiHeadAttention(b, context_length, d_model, num_heads)
            },
            "RMSNorm_2": get_flops_RMSNorm(b, context_length, d_model),
            "FFN": get_flops_FFN(b, context_length, d_model, d_ff),
            "Overall_Transformer_block": get_flops_TransformerBlock(
                b, context_length, d_model, num_heads, d_ff
            ) * num_layers
        },
        "RMSNorm_final": get_flops_RMSNorm(b, context_length, d_model),
        "LM_head": get_flops_lm_head(b, context_length, d_model, vocab_size),
        "forward_pass_TransformerLM": get_forward_flops_TransformerLM(
            vocab_size, context_length, num_layers, d_model, num_heads, d_ff, b
        ),
        "backward_pass_TransformerLM": get_backward_flops_TransformerLM(
            get_forward_flops_TransformerLM(
                vocab_size, context_length, num_layers, d_model, num_heads, d_ff, b)
        ),
        "Overall_TransformerLM": get_overall_flops(
            vocab_size, context_length, num_layers, d_model, num_heads, d_ff, b
        )
    }
    return all_model_flops

In [138]:
import pandas as pd

# --- your FLOP-computing helpers remain unchanged here ------------------ #
# (get_flops_RMSNorm, get_flops_MultiHeadAttention, … get_table_all_flops)
# ----------------------------------------------------------------------- #

def explode_to_tuples(d, prefix=()):
    """Yield (path_tuple, value) pairs from a nested dict."""
    for k, v in d.items():
        path = prefix + (k,)
        if isinstance(v, dict):
            yield from explode_to_tuples(v, path)
        else:
            yield path, v

def flops_table_multiindex(
    vocab_size, context_len, n_layers, d_model, n_heads, d_ff, b=1
):
    # 1. get the nested FLOPs dict
    nested = get_table_all_flops(
        vocab_size, context_len, n_layers, d_model, n_heads, d_ff, b
    )
    overll_flops = nested['Overall_TransformerLM']
    # 2. explode to (tuple, value) pairs
    tuples, vals = zip(*explode_to_tuples(nested))

    # 3. make all tuples the same length
    max_depth = max(len(t) for t in tuples)
    padded = [t + ("",) * (max_depth - len(t)) for t in tuples]

    # 4. build the MultiIndex DataFrame
    names = [f"Level-{i+1}" for i in range(max_depth)]
    mi = pd.MultiIndex.from_tuples(padded, names=names)
    df = pd.DataFrame({"FLOPs": vals}, index=mi)

    # 5. optional pretty formatting
    df["FLOPs_fmt"] = df["FLOPs"].astype(float).map("{:.2e}".format)
    df["FLOPs_proportion"] = df["FLOPs"] / overll_flops
    return df


gpt2_xl = flops_table_multiindex(
    vocab_size=50257,
    context_len=1024,
    n_layers=48,
    d_model=1600,
    n_heads=25,
    d_ff=6400,
)

gpt2_xl

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,FLOPs,FLOPs_fmt,FLOPs_proportion
Level-1,Level-2,Level-3,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Transformer_block,RMSNorm_1,,4917248,4920000.0,4.663049e-07
Transformer_block,MultiHeadAttentio,QKV_proj,15728640000,15700000000.0,0.001491554
Transformer_block,MultiHeadAttentio,Attention_scores,3355443200,3360000000.0,0.0003181982
Transformer_block,MultiHeadAttentio,Softmax,131072000,131000000.0,1.242962e-05
Transformer_block,MultiHeadAttentio,Attention_weighted,3355443200,3360000000.0,0.0003181982
Transformer_block,MultiHeadAttentio,O_proj,5242880000,5240000000.0,0.0004971847
Transformer_block,MultiHeadAttentio,Overall_MultiHeadAttentio,27813478400,27800000000.0,0.002637565
Transformer_block,RMSNorm_2,,4917248,4920000.0,4.663049e-07
Transformer_block,FFN,,41975808000,42000000000.0,0.003980585
Transformer_block,Overall_Transformer_block,,3350357803008,3350000000000.0,0.317716


In [151]:
a = 6 * total_params * 50e3
a = 6 * num_layers * d_model*d_model * context_length + 12 * num_layers * context_length * context_length  * d_model
a / 10545134573568

0.16323569411002548

In [130]:
gpt2_small = flops_table_multiindex(
    vocab_size=50257,
    context_len=1024,
    n_layers=12,
    d_model=768,
    n_heads=12,
    d_ff=6400,
)

gpt2_small

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,FLOPs,FLOPs_fmt,FLOPs_proportion
Level-1,Level-2,Level-3,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Transformer_block,RMSNorm_1,,2361344,2360000.0,2e-06
Transformer_block,MultiHeadAttentio,QKV_proj,3623878656,3620000000.0,0.002887
Transformer_block,MultiHeadAttentio,Attention_scores,1610612736,1610000000.0,0.001283
Transformer_block,MultiHeadAttentio,Softmax,62914560,62900000.0,5e-05
Transformer_block,MultiHeadAttentio,Attention_weighted,1610612736,1610000000.0,0.001283
Transformer_block,MultiHeadAttentio,O_proj,1207959552,1210000000.0,0.000962
Transformer_block,MultiHeadAttentio,Overall_MultiHeadAttentio,8115978240,8120000000.0,0.006465
Transformer_block,RMSNorm_2,,2361344,2360000.0,2e-06
Transformer_block,FFN,,20165427200,20200000000.0,0.016062
Transformer_block,Overall_Transformer_block,,339433537536,339000000000.0,0.270368


In [131]:
gpt2_medium = flops_table_multiindex(
    vocab_size=50257,
    context_len=1024,
    n_layers=24,
    d_model=1024,
    n_heads=16,
    d_ff=6400,
)

gpt2_medium

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,FLOPs,FLOPs_fmt,FLOPs_proportion
Level-1,Level-2,Level-3,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Transformer_block,RMSNorm_1,,3147776,3150000.0,9.881577e-07
Transformer_block,MultiHeadAttentio,QKV_proj,6442450944,6440000000.0,0.00202243
Transformer_block,MultiHeadAttentio,Attention_scores,2147483648,2150000000.0,0.0006741434
Transformer_block,MultiHeadAttentio,Softmax,83886080,83900000.0,2.633373e-05
Transformer_block,MultiHeadAttentio,Attention_weighted,2147483648,2150000000.0,0.0006741434
Transformer_block,MultiHeadAttentio,O_proj,2147483648,2150000000.0,0.0006741434
Transformer_block,MultiHeadAttentio,Overall_MultiHeadAttentio,12968787968,13000000000.0,0.004071194
Transformer_block,RMSNorm_2,,3147776,3150000.0,9.881577e-07
Transformer_block,FFN,,26876313600,26900000000.0,0.008437079
Transformer_block,Overall_Transformer_block,,956433530880,956000000000.0,0.300246


In [132]:
gpt2_large = flops_table_multiindex(
    vocab_size=50257,
    context_len=1024,
    n_layers=36,
    d_model=1280,
    n_heads=20,
    d_ff=6400,
)

gpt2_large

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,FLOPs,FLOPs_fmt,FLOPs_proportion
Level-1,Level-2,Level-3,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Transformer_block,RMSNorm_1,,3934208,3930000.0,6.487582e-07
Transformer_block,MultiHeadAttentio,QKV_proj,10066329600,10100000000.0,0.001659956
Transformer_block,MultiHeadAttentio,Attention_scores,2684354560,2680000000.0,0.0004426551
Transformer_block,MultiHeadAttentio,Softmax,104857600,105000000.0,1.729121e-05
Transformer_block,MultiHeadAttentio,Attention_weighted,2684354560,2680000000.0,0.0004426551
Transformer_block,MultiHeadAttentio,O_proj,3355443200,3360000000.0,0.0005533188
Transformer_block,MultiHeadAttentio,Overall_MultiHeadAttentio,18895339520,18900000000.0,0.003115877
Transformer_block,RMSNorm_2,,3934208,3930000.0,6.487582e-07
Transformer_block,FFN,,33587200000,33600000000.0,0.005538592
Transformer_block,Overall_Transformer_block,,1889654685696,1890000000000.0,0.3116076


In [134]:
gpt2_xl = flops_table_multiindex(
    vocab_size=50257,
    context_len=1024,
    n_layers=48,
    d_model=1600,
    n_heads=25,
    d_ff=6400,
)

gpt2_xl

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,FLOPs,FLOPs_fmt,FLOPs_proportion
Level-1,Level-2,Level-3,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Transformer_block,RMSNorm_1,,4917248,4920000.0,4.663049e-07
Transformer_block,MultiHeadAttentio,QKV_proj,15728640000,15700000000.0,0.001491554
Transformer_block,MultiHeadAttentio,Attention_scores,3355443200,3360000000.0,0.0003181982
Transformer_block,MultiHeadAttentio,Softmax,131072000,131000000.0,1.242962e-05
Transformer_block,MultiHeadAttentio,Attention_weighted,3355443200,3360000000.0,0.0003181982
Transformer_block,MultiHeadAttentio,O_proj,5242880000,5240000000.0,0.0004971847
Transformer_block,MultiHeadAttentio,Overall_MultiHeadAttentio,27813478400,27800000000.0,0.002637565
Transformer_block,RMSNorm_2,,4917248,4920000.0,4.663049e-07
Transformer_block,FFN,,41975808000,42000000000.0,0.003980585
Transformer_block,Overall_Transformer_block,,3350357803008,3350000000000.0,0.317716


In [135]:
gpt2_xl = flops_table_multiindex(
    vocab_size=50257,
    context_len=16384,
    n_layers=48,
    d_model=1600,
    n_heads=25,
    d_ff=6400,
)

gpt2_xl

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,FLOPs,FLOPs_fmt,FLOPs_proportion
Level-1,Level-2,Level-3,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Transformer_block,RMSNorm_1,,78675968,78700000.0,1.941752e-07
Transformer_block,MultiHeadAttentio,QKV_proj,251658240000,252000000000.0,0.000621102
Transformer_block,MultiHeadAttentio,Attention_scores,858993459200,859000000000.0,0.002120028
Transformer_block,MultiHeadAttentio,Softmax,33554432000,33600000000.0,8.28136e-05
Transformer_block,MultiHeadAttentio,Attention_weighted,858993459200,859000000000.0,0.002120028
Transformer_block,MultiHeadAttentio,O_proj,83886080000,83900000000.0,0.000207034
Transformer_block,MultiHeadAttentio,Overall_MultiHeadAttentio,2087085670400,2090000000000.0,0.005151006
Transformer_block,RMSNorm_2,,78675968,78700000.0,1.941752e-07
Transformer_block,FFN,,671612928000,672000000000.0,0.001657566
Transformer_block,Overall_Transformer_block,,132425085616128,132000000000000.0,0.3268301
