In [1]:
import torch
import os
from transformers import AutoConfig, AutoModelForCausalLM
import math
import pandas as pd

### GPT-2

From https://huggingface.co/docs/transformers/en/model_doc/gpt2#transformers.GPT2Config we get the following information about GPT-2:

(vocab_size = 50257, n_positions = 1024, n_embd = 768, n_layer = 12, n_head = 12, ...).

**We can use this information to define the parameters.**

**Note that ideally we need to load the model and extract parameters from there. But I decided to make it less time consuming, and just found the information about the parameters (moreover, it might take a lot of time and memory to load all these models).**

In [2]:
bytes_per_param = 2 # 16 bits = 2 bytes

L = 1024
V = 50257
d = 768
N_layer = 12
d_ff = 3072 # Found in the internet, 768 * 4
n_head = 12
head_dim = 64 # d_gpt2 / n_head_gpt2



n_h = 12
n_kv = 4    # For GQA
d_ff = 3072 # FFN intermediate size (GPT-2)


E = 2       # Expansion ratio (Mamba-2)
N = 128     # State size (Mamba-2)
d_k = 64    # Key dimension (GLA)
d_v = 64    # Value dimension (GLA)

In [3]:
# We will ignore small terms such as LayerNorms, biases, etc. Even though it will result in a tiny fraction of the parameters, but it will simplify the calculations.

# Embeddings + position embeddings
param_embed = V * d
param_pos_embed = L * d

# Each layer has:
#   - QKV projection: 3 x (d x d)
#   - output projection: (d x d)
#   - FFN: first weight (d x d_ff)
#          second weight (d_ff x d)
param_per_layer = 3 * d * d + 1 * d * d + d * d_ff + d_ff * d
# total parameters
param_total = param_embed + param_pos_embed + N_layer * param_per_layer
# convert to MiB
size_bytes = param_total * bytes_per_param
size_MiB   = size_bytes / (1024**2)

print(f"Total parameters:  {param_total:,.0f}")
print(f"Model size (bfloat16): {size_MiB:,.1f} MiB")

Total parameters:  124,318,464
Model size (bfloat16): 237.1 MiB


In [4]:
# since the batch size is 1 on inference, the number of cache elements either for k or v per layer is number of attention heads times the head dimension times the sequence length
# then we multiply it by the number of layers and 2 (making sure we calculate the elements both for k and v)
kv_elements = 2 * N_layer * n_head * L * head_dim
kv_bytes = kv_elements * bytes_per_param
kv_MiB = kv_bytes / (1024**2)

print(f"KV cache elements: {kv_elements:,}")
print(f"KV cache size (bfloat16): {kv_MiB:,.1f} MiB")

KV cache elements: 18,874,368
KV cache size (bfloat16): 36.0 MiB


In [5]:
flops_per_layer = (
    6 * L * d**2        # QKV proj
  + 4 * d * L**2        # attention core
  + 2 * L * d**2        # output proj
  + 4 * L * d * d_ff    # FFN
)

# LM-head (tied embeddings): final logits = 2·L·d·V
flops_lm_head = 2 * L * d * V

# total
total_flops = N_layer * flops_per_layer + flops_lm_head

# convert to GFLOPs
gflops = total_flops / 1e9

print(f"FLOPs per layer: {flops_per_layer/1e9:.1f} GFLOPs")
print(f"LM-head FLOPs: {flops_lm_head/1e9:.1f} GFLOPs")
print(f"Total FLOPs: {gflops:.1f} GFLOPs")

FLOPs per layer: 17.7 GFLOPs
LM-head FLOPs: 79.0 GFLOPs
Total FLOPs: 291.6 GFLOPs


In [6]:
total_flops


291648307200

### GPT-2 with GQA

In [7]:
G = 4 # number of KV groups in GQA

embed_params = V * d
pos_emb_params = L * d

# GQA GPT-2 model size
q_params = d * d
kv_params_gqa = 2 * d * (G * head_dim)
o_params = d * d
ffn_params = d * d_ff + d_ff * d
layer_params_gqa = q_params + kv_params_gqa + o_params + ffn_params
total_params_gqa = embed_params + pos_emb_params + N_layer * layer_params_gqa
model_size_gqa = total_params_gqa * bytes_per_param
model_size_gqa_mib = model_size_gqa / (2**20)

# KV-Cache Size (bytes → MiB)
kv_cache_gqa = L * N_layer * G * head_dim * 2 * bytes_per_param
kv_cache_gqa_mib = kv_cache_gqa / (2**20)

# FLOPs per Forward Pass (→ GFLOPs)
attn_score_flops = 2 * n_head * L * L * head_dim
attn_apply_flops = 2 * n_head * L * L * head_dim
ffn_flops = 2 * L * d * d_ff + 2 * L * d_ff * d
# GQA projections: Q,O full; K,V reduced
proj_flops_gqa = (2 * L * d * d) + 2 * (2 * L * d * (G * head_dim)) + (2 * L * d * d)
flops_per_layer_gqa = proj_flops_gqa + attn_score_flops + attn_apply_flops + ffn_flops
total_flops_gqa = N_layer * flops_per_layer_gqa + 2 * L * d * V
total_flops_gqa_g = total_flops_gqa / 1e9

comparison = pd.DataFrame({
    "Model": ["GPT-2 Small", f"GPT-2 Small + GQA (G={G})"],
    "Total Parameters": [param_total, total_params_gqa],
    "Model Size (MiB)": [size_MiB, model_size_gqa_mib],
    "KV-Cache (MiB)": [kv_MiB, kv_cache_gqa_mib],
    "FLOPs per Forward (GFLOPs)": [gflops, total_flops_gqa_g]
})

comparison

Unnamed: 0,Model,Total Parameters,Model Size (MiB),KV-Cache (MiB),FLOPs per Forward (GFLOPs)
0,GPT-2 Small,124318464,237.118652,36.0,291.648307
1,GPT-2 Small + GQA (G=4),114881280,219.118652,12.0,272.320954


### Mamba-2

In [None]:
L = 1024
V = 50257
d = 768
N_layer = 12
E = 2
bfloat16_bytes = 2

# Model Size (parameters and MiB)
# P_total = V*d + N_layer * [3*d*d (QKV) + d*d (output) + 2*(d * (E*d)) (MLP)]
param_count = V*d + N_layer*(4*d*d + 2*d*(E*d))
model_size_bytes = param_count * bfloat16_bytes
model_size_mib = model_size_bytes / (2**20)

# KV‑Cache Size per sequence (MiB)
# KV elements = N_layer * L * 2 * d
kv_elements = N_layer * L * 2 * d
kv_bytes = kv_elements * bfloat16_bytes
kv_mib = kv_bytes / (2**20)

# FLOPs per forward (GFLOPs)
# Attention: 8*L*d^2 + 4*L^2*d per layer
# MLP: 4*E*L*d^2 per layer
attn_flops_per_layer = 8 * L * d**2 + 4 * L**2 * d
mlp_flops_per_layer = 4 * E * L * d**2
total_flops = N_layer * (attn_flops_per_layer + mlp_flops_per_layer)
total_gflops = total_flops / 1e9

comparison_new_row = ['Mamba-2', param_count, model_size_mib, kv_mib, total_gflops]
comparison = pd.concat([comparison, pd.DataFrame([comparison_new_row], columns=comparison.columns)], ignore_index=True)

print("Total parameters: {:,}".format(param_count))
print("Model size: {:.2f} MiB".format(model_size_mib))
print("KV-cache size: {:.2f} MiB".format(kv_mib))
print("FLOPs per forward: {:.1f} GFLOPs".format(total_gflops))

Total parameters: 95,220,480
Model size: 181.62 MiB
KV-cache size: 36.00 MiB
FLOPs per forward: 154.6 GFLOPs


In [9]:
comparison

Unnamed: 0,Model,Total Parameters,Model Size (MiB),KV-Cache (MiB),FLOPs per Forward (GFLOPs)
0,GPT-2 Small,124318464,237.118652,36.0,291.648307
1,GPT-2 Small + GQA (G=4),114881280,219.118652,12.0,272.320954
2,Mamba-2,95220480,181.618652,36.0,154.618823


### GLA

In [10]:
d = 768
N_layer = 12
n_h = 12
d_k = d // n_h
d_v = d_k
d_ff = 4 * d
V = 50257
L = 1024

# Model Size: count parameters
# Embedding + LM head (weight tying cancels doubling?)
emb_params = V * d
# Per-layer parameters:
# Attention: QKV proj (3*d*d) + output proj (d*d)
attn_params = 3 * d * d + d * d
# Feed-forward: d*d_ff (in) + d_ff*d (out)
ffn_params = d * d_ff + d_ff * d
# Total per layer
per_layer_params = attn_params + ffn_params
# Total parameters
total_params = emb_params + per_layer_params * N_layer
# Bytes in bfloat16
bytes_per_param = 2
model_size_bytes = total_params * bytes_per_param
model_size_mib = model_size_bytes / (1024**2)

# KV Cache Size: keys + values for all layers and heads
def calc_kv_cache_bytes(L, N_layer, n_h, d_k, d_v):
    # total elements: N_layer * n_h * L * (d_k + d_v)
    elements = N_layer * n_h * L * (d_k + d_v)
    return elements * bytes_per_param

total_kv_bytes = calc_kv_cache_bytes(L, N_layer, n_h, d_k, d_v)
total_kv_mib = total_kv_bytes / (1024**2)

# FLOPs per forward pass
def matmul_flops(a, b, c):
    return 2 * a * b * c

# Attention FLOPs per layer:
# QKV projections: 3 * (L*d*d)
qkv_flops = 3 * matmul_flops(L, d, d)
# Q*K^T: L * (d_k) * L
qk_flops = matmul_flops(L, d_k, L)
# attn weights * V: L * L * d_v
attn_v_flops = matmul_flops(L, L, d_v)
# output projection: L * d * d
proj_flops = matmul_flops(L, d, d)
# Feed-forward FLOPs per layer: input and output
ffn_fflops = matmul_flops(L, d, d_ff) + matmul_flops(L, d_ff, d)
# Total per layer
per_layer_flops = qkv_flops + qk_flops + attn_v_flops + proj_flops + ffn_fflops
# Total FLOPs
total_flops = per_layer_flops * N_layer
# Convert to GFLOPs
total_gflops = total_flops / 1e9

comparison_new_row = ['GLA', total_params, model_size_mib, total_kv_mib, total_gflops]
comparison = pd.concat([comparison, pd.DataFrame([comparison_new_row], columns=comparison.columns)], ignore_index=True)

print(f"Total parameters: {total_params:,d} ({model_size_mib:.1f} MiB)")
print(f"KV cache size: {total_kv_mib:.1f} MiB")
print(f"FLOPs per forward pass: {total_gflops:.1f} GFLOPs")


Total parameters: 123,532,032 (235.6 MiB)
KV cache size: 36.0 MiB
FLOPs per forward pass: 177.2 GFLOPs


**print the table summarizing all the results**

In [11]:
comparison

Unnamed: 0,Model,Total Parameters,Model Size (MiB),KV-Cache (MiB),FLOPs per Forward (GFLOPs)
0,GPT-2 Small,124318464,237.118652,36.0,291.648307
1,GPT-2 Small + GQA (G=4),114881280,219.118652,12.0,272.320954
2,Mamba-2,95220480,181.618652,36.0,154.618823
3,GLA,123532032,235.618652,36.0,177.167401
