In [32]:
import m2_utilities.flops as flops
from m2_utilities.qwen import load_qwen

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [33]:
model, tokenizer = load_qwen()

# Architecture constants used by every FLOP-counting cell below. They are read
# from the loaded model's config so the notebook does not depend on names left
# behind in the kernel by other (possibly deleted) cells.
# NOTE(review): assumes load_qwen returns a Hugging Face Qwen2ForCausalLM whose
# `.config` is a Qwen2Config (consistent with the printed repr) — confirm.
D_MODEL = model.config.hidden_size             # 896 in the printed model
N_HEADS = model.config.num_attention_heads
HIDDEN_SIZE = model.config.intermediate_size   # MLP width, 4864 in the printed model
N_LAYERS = model.config.num_hidden_layers      # 24 in the printed model
VOCAB_SIZE = model.config.vocab_size           # 151936 in the printed model

# Fail fast if the config disagrees with the actual module structure.
assert tuple(model.model.embed_tokens.weight.shape) == (VOCAB_SIZE, D_MODEL)
assert len(model.model.layers) == N_LAYERS

model

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(151936, 896)
    (layers): ModuleList(
      (0-23): 24 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=896, out_features=896, bias=True)
          (k_proj): Linear(in_features=896, out_features=128, bias=True)
          (v_proj): Linear(in_features=896, out_features=128, bias=True)
          (o_proj): Linear(in_features=896, out_features=896, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=896, out_features=4864, bias=False)
          (up_proj): Linear(in_features=896, out_features=4864, bias=False)
          (down_proj): Linear(in_features=4864, out_features=896, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((896,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbe

### Single Forward Pass

In [46]:
# Sequence length being costed; every cell below reuses it.
N_TOKENS = 512

# Whole-model cost of a single inference pass (no backward pass).
total_flops = flops.compute_flops(N_TOKENS, backpropagate=False)
print(f"Total FLOPS: {total_flops:.4e}")

# Reference value recorded from an earlier run: 5.6518e+11

Total FLOPS: 5.6518e+11


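### Adding Positional Embeddings

The count below is for the element-wise operation over the `N_TOKENS × D_MODEL` embedding matrix. The printed Qwen2 model applies rotary position embeddings (`rotary_emb`) inside attention rather than adding a positional vector to the token embeddings, so treat this figure as an approximation of the embedding-stage cost.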

In [47]:
# Embedding-stage FLOPs for N_TOKENS tokens of width D_MODEL.
# The recorded 4.59e+05 equals 512 * 896, i.e. one operation per element of the
# embedding matrix (assuming D_MODEL is the 896-wide hidden size printed above).
n_flops = flops.embedding(N_TOKENS, D_MODEL)
print(f"Embedding Layer: {n_flops:.2e}")

Embbeding Layer: 4.59e+05


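### All Decoder Blocks

Each `Qwen2DecoderLayer` contains both self-attention and the MLP (FFN), plus two RMSNorms. "Block" below therefore means the whole decoder layer, not just its attention sub-layer.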

In [48]:
# Cost of one decoder layer, then of the whole stack of identical layers.
block_flops = flops.block(N_TOKENS, N_HEADS, D_MODEL, HIDDEN_SIZE)
stack_flops = N_LAYERS * block_flops

print(f"Single Block: {block_flops:.2e}")
print(f"{N_LAYERS} Blocks: {stack_flops:.2e}")

Single Block: 1.77e+10
24 Blocks: 4.25e+11


### Breakdown of a Single Block

In [49]:
# Per-component FLOPs inside a single decoder layer. One table plus one loop
# replaces four copy-pasted compute/print pairs, so labels and number
# formatting cannot drift apart.
# NOTE(review): the printed model shows k_proj/v_proj with out_features=128
# (grouped-query attention), yet the recorded MHSA count (4.27e+09) appears
# consistent with four full 896x896 projections — confirm that
# flops.multi_head_self_attention accounts for the narrower K/V projections.
block_components = {
    "FFN": flops.ffn(N_TOKENS, D_MODEL, HIDDEN_SIZE),
    "MHSA": flops.multi_head_self_attention(N_TOKENS, N_HEADS, D_MODEL),
    "RMSNorm": flops.rms_norm(N_TOKENS, D_MODEL),
    "Residual": flops.add_residual(N_TOKENS, D_MODEL),
}

# Leaves n_flops holding the last (Residual) value, as the original cell did.
for label, n_flops in block_components.items():
    print(f"{label}: {n_flops:.2e}")

FFN: 1.34e+10
MHSA: 4.27e+09
RMSNorm: 5.97e+06
Residual: 4.59e+05


### Output Head (After the Decoder Blocks)

In [53]:
# Output head: project each token's hidden state onto the vocabulary, then
# normalize the resulting logits.
n_flops = flops.final_linear(N_TOKENS, D_MODEL, VOCAB_SIZE)
print(f"Final Linear Transform: {n_flops:.2e}")

# FIXME(review): the recorded softmax count (4.77e+11) works out to ~6.1e3 FLOPs
# per logit (4.77e11 / (512 * 151936)). It is larger than the final projection,
# and 1.39e11 + 4.77e11 already exceeds the 5.6518e+11 total from
# compute_flops. The value is close to 12 * VOCAB_SIZE * N_TOKENS**2, which
# suggests flops.softmax is the attention-score softmax taking
# (n_tokens, n_heads), so VOCAB_SIZE is being treated as a head count.
# Confirm the helper's signature. An element-wise softmax over the
# N_TOKENS x VOCAB_SIZE logits should scale linearly in N_TOKENS * VOCAB_SIZE.
n_flops = flops.softmax(N_TOKENS, VOCAB_SIZE)
print(f"Final Softmax: {n_flops:.2e}")


Final Linear Transform: 1.39e+11
Final Softmax: 4.77e+11
