In [1]:
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

%load_ext autoreload
%autoreload 2

In [2]:
import torch
import math
from einops import einsum, rearrange
import numpy as np
import json
import pathlib

from llm.transformer import TransformerLM

from fvcore.nn import FlopCountAnalysis
from fvcore.nn import parameter_count_table

In [3]:
vocab_size = 50257
context_length = 1024
# context_length = 16384
num_layers = 48
d_model = 1600
num_heads = 25
d_ff = 6400
model = TransformerLM(
    vocab_size=vocab_size,
    context_length=context_length,
    num_layers=num_layers,
    num_heads=num_heads,
    d_model=d_model,
    d_ff = d_ff,
    rope_theta = 1000.0
)

In [4]:
model

TransformerLM(
  (token_embeddings): Embedding(vocab_size=50257, d=1600)
  (RoPE): RotaryPositionalEmbedding(context_length=1024, dim/2=32)
  (layers): ModuleList(
    (0-47): 48 x TransformerBlock(
      (attn): CausalMHSARoPE(
        (qkv_proj): Linear(d_out=4800, d_in=1600)
        (output_proj): Linear(d_out=1600, d_in=1600)
        (RoPE): RotaryPositionalEmbedding(context_length=1024, dim/2=32)
      )
      (ffn): SwiGLU(
        (w1): Linear(d_out=6400, d_in=1600)
        (w2): Linear(d_out=1600, d_in=6400)
        (w3): Linear(d_out=6400, d_in=1600)
      )
      (ln1): RMSNorm(hidden_size=1600, eps=1e-05)
      (ln2): RMSNorm(hidden_size=1600, eps=1e-05)
    )
  )
  (ln_final): RMSNorm(hidden_size=1600, eps=1e-05)
  (lm_head): Linear(d_out=50257, d_in=1600)
)

In [5]:
total_params = sum(p.numel() for p in model.parameters())
total_params

2127057600

In [6]:
def count_parameters_in_millions(model):
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Trainable parameters: {total_params / 1e6:.2f}M")
def parameter_memory_in_megabytes(model):
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_bytes = total_params * 4  # float32 = 4 bytes
    total_megabytes = total_bytes / (1024 ** 2)  # bytes to MB
    print(f"Estimated memory (float32): {total_megabytes:.2f} MB")

In [7]:
count_parameters_in_millions(model)
parameter_memory_in_megabytes(model)

Trainable parameters: 2127.06M
Estimated memory (float32): 8114.08 MB


In [8]:
with torch.no_grad():
    in_indices = torch.zeros(1, context_length, dtype=torch.int64)
    flops = FlopCountAnalysis(model, in_indices)
    print(f"FLOPs: {flops.total()/1e9:.2f} GFLOPs")  # in billions

Unsupported operator aten::pow encountered 97 time(s)
Unsupported operator aten::mean encountered 97 time(s)
Unsupported operator aten::add encountered 289 time(s)
Unsupported operator aten::sqrt encountered 97 time(s)
Unsupported operator aten::div encountered 145 time(s)
Unsupported operator aten::mul encountered 625 time(s)
Unsupported operator aten::repeat_interleave encountered 192 time(s)
Unsupported operator aten::neg encountered 96 time(s)
Unsupported operator aten::reshape_as encountered 96 time(s)
Unsupported operator aten::tril encountered 48 time(s)
Unsupported operator aten::div_ encountered 48 time(s)
Unsupported operator aten::sub encountered 48 time(s)
Unsupported operator aten::exp encountered 48 time(s)
Unsupported operator aten::sum encountered 48 time(s)
Unsupported operator aten::mul_ encountered 48 time(s)
Unsupported operator aten::sigmoid encountered 48 time(s)


FLOPs: 2256.58 GFLOPs


In [9]:
print("Parameter count table:")
print(parameter_count_table(model, max_depth=2))

Parameter count table:
| name                      | #elements or shape   |
|:--------------------------|:---------------------|
| model                     | 2.1G                 |
|  token_embeddings         |  80.4M               |
|   token_embeddings.weight |   (50257, 1600)      |
|  layers                   |  2.0G                |
|   layers.0                |   41.0M              |
|   layers.1                |   41.0M              |
|   layers.2                |   41.0M              |
|   layers.3                |   41.0M              |
|   layers.4                |   41.0M              |
|   layers.5                |   41.0M              |
|   layers.6                |   41.0M              |
|   layers.7                |   41.0M              |
|   layers.8                |   41.0M              |
|   layers.9                |   41.0M              |
|   layers.10               |   41.0M              |
|   layers.11               |   41.0M              |
|   layers.12          

In [10]:
for name, num_flops in flops.by_module().items():
    gflops = f"{num_flops/1e9:.2f} GFLOPs"
    if name == "":
        name = "TOTAL"
    if num_flops > 0:
        print(f"{name}: {gflops}")

TOTAL: 2256.58 GFLOPs
layers: 2174.23 GFLOPs
layers.0: 45.30 GFLOPs
layers.0.attn: 13.84 GFLOPs
layers.0.attn.qkv_proj: 7.87 GFLOPs
layers.0.attn.output_proj: 2.62 GFLOPs
layers.0.ffn: 31.45 GFLOPs
layers.0.ffn.w1: 10.48 GFLOPs
layers.0.ffn.w2: 10.48 GFLOPs
layers.0.ffn.w3: 10.48 GFLOPs
layers.1: 45.30 GFLOPs
layers.1.attn: 13.84 GFLOPs
layers.1.attn.qkv_proj: 7.87 GFLOPs
layers.1.attn.output_proj: 2.62 GFLOPs
layers.1.ffn: 31.45 GFLOPs
layers.1.ffn.w1: 10.48 GFLOPs
layers.1.ffn.w2: 10.48 GFLOPs
layers.1.ffn.w3: 10.48 GFLOPs
layers.2: 45.30 GFLOPs
layers.2.attn: 13.84 GFLOPs
layers.2.attn.qkv_proj: 7.87 GFLOPs
layers.2.attn.output_proj: 2.62 GFLOPs
layers.2.ffn: 31.45 GFLOPs
layers.2.ffn.w1: 10.48 GFLOPs
layers.2.ffn.w2: 10.48 GFLOPs
layers.2.ffn.w3: 10.48 GFLOPs
layers.3: 45.30 GFLOPs
layers.3.attn: 13.84 GFLOPs
layers.3.attn.qkv_proj: 7.87 GFLOPs
layers.3.attn.output_proj: 2.62 GFLOPs
layers.3.ffn: 31.45 GFLOPs
layers.3.ffn.w1: 10.48 GFLOPs
layers.3.ffn.w2: 10.48 GFLOPs
layers.3.ffn.

In [11]:
print(f"lm_head FLOPs ratio: {100 * 82.35 / 2174.23:.2f}%")

lm_head FLOPs ratio: 3.79%


In [12]:
# (1,25,1024,64) @ (1,25,1024,64)
# FLOPs = 2 * 1024 * 1024 * 64 for the (1024,64) @ (1024,64) multiplication
# FLOPs across batch = 1*25*FLOPs from above
flops_attn_QV =  num_heads * (2 * context_length * context_length * (d_model / num_heads))
print(f"Attention Q@K^T {flops_attn_QV/1e9:.2f} GFLOPs")
# (1,25,1024,1024) * (1,25,1024,64)^T = (1,25,1024,64)
flops_attn_out = num_heads * (2 * context_length * context_length* (d_model / num_heads))
print(f"Attention * V^T {flops_attn_out/1e9:.2f} GFLOPs")

sqrt_flops = num_heads * (context_length**2)
print(f"Scaling by 1/sqrt(d_k): {sqrt_flops/1e9:.2f} GFLOPs")

Attention Q@K^T 3.36 GFLOPs
Attention * V^T 3.36 GFLOPs
Scaling by 1/sqrt(d_k): 0.03 GFLOPs


In [13]:
model_small = TransformerLM(
    vocab_size=50257,
    context_length=1024,
    num_layers=12,
    num_heads=12,
    d_model=768,
    d_ff = 768*4,
    rope_theta = 1000.0
)
model_med = TransformerLM(
    vocab_size=50257,
    context_length=1024,
    num_layers=24,
    num_heads=16,
    d_model=1024,
    d_ff = 1024*4,
    rope_theta = 1000.0
)
model_large = TransformerLM(
    vocab_size=50257,
    context_length=1024,
    num_layers=36,
    num_heads=20,
    d_model=1280,
    d_ff = 1280*4,
    rope_theta = 1000.0
)

In [14]:
for gpt in (model_small, model_med, model_large):
    with torch.no_grad():
        in_indices = torch.zeros(1, context_length, dtype=torch.int64)
        flops = FlopCountAnalysis(gpt, in_indices)
        print(f"FLOPs: {flops.total()/1e9:.2f} GFLOPs")  # in billions
        total_flops = flops.total()
        print(f"lm_head FLOPs ratio: {100 * flops.by_module()["lm_head"] / total_flops:.2f}%")
        for name, num_flops in flops.by_module().items():
            gflops = f"{num_flops/1e9:.2f} GFLOPs"
            if name == "":
                name = "TOTAL"
            if num_flops > 0:
                print(f"{name}: {gflops}")
        


Unsupported operator aten::pow encountered 25 time(s)
Unsupported operator aten::mean encountered 25 time(s)
Unsupported operator aten::add encountered 73 time(s)
Unsupported operator aten::sqrt encountered 25 time(s)
Unsupported operator aten::div encountered 37 time(s)
Unsupported operator aten::mul encountered 157 time(s)
Unsupported operator aten::repeat_interleave encountered 48 time(s)
Unsupported operator aten::neg encountered 24 time(s)
Unsupported operator aten::reshape_as encountered 24 time(s)
Unsupported operator aten::tril encountered 12 time(s)
Unsupported operator aten::div_ encountered 12 time(s)
Unsupported operator aten::sub encountered 12 time(s)
Unsupported operator aten::exp encountered 12 time(s)
Unsupported operator aten::sum encountered 12 time(s)
Unsupported operator aten::mul_ encountered 12 time(s)
Unsupported operator aten::sigmoid encountered 12 time(s)


FLOPs: 174.82 GFLOPs
lm_head FLOPs ratio: 22.61%
TOTAL: 174.82 GFLOPs
layers: 135.30 GFLOPs
layers.0: 11.28 GFLOPs
layers.0.attn: 4.03 GFLOPs
layers.0.attn.qkv_proj: 1.81 GFLOPs
layers.0.attn.output_proj: 0.60 GFLOPs
layers.0.ffn: 7.25 GFLOPs
layers.0.ffn.w1: 2.42 GFLOPs
layers.0.ffn.w2: 2.42 GFLOPs
layers.0.ffn.w3: 2.42 GFLOPs
layers.1: 11.28 GFLOPs
layers.1.attn: 4.03 GFLOPs
layers.1.attn.qkv_proj: 1.81 GFLOPs
layers.1.attn.output_proj: 0.60 GFLOPs
layers.1.ffn: 7.25 GFLOPs
layers.1.ffn.w1: 2.42 GFLOPs
layers.1.ffn.w2: 2.42 GFLOPs
layers.1.ffn.w3: 2.42 GFLOPs
layers.2: 11.28 GFLOPs
layers.2.attn: 4.03 GFLOPs
layers.2.attn.qkv_proj: 1.81 GFLOPs
layers.2.attn.output_proj: 0.60 GFLOPs
layers.2.ffn: 7.25 GFLOPs
layers.2.ffn.w1: 2.42 GFLOPs
layers.2.ffn.w2: 2.42 GFLOPs
layers.2.ffn.w3: 2.42 GFLOPs
layers.3: 11.28 GFLOPs
layers.3.attn: 4.03 GFLOPs
layers.3.attn.qkv_proj: 1.81 GFLOPs
layers.3.attn.output_proj: 0.60 GFLOPs
layers.3.ffn: 7.25 GFLOPs
layers.3.ffn.w1: 2.42 GFLOPs
layers.3.ffn.w

Unsupported operator aten::pow encountered 49 time(s)
Unsupported operator aten::mean encountered 49 time(s)
Unsupported operator aten::add encountered 145 time(s)
Unsupported operator aten::sqrt encountered 49 time(s)
Unsupported operator aten::div encountered 73 time(s)
Unsupported operator aten::mul encountered 313 time(s)
Unsupported operator aten::repeat_interleave encountered 96 time(s)
Unsupported operator aten::neg encountered 48 time(s)
Unsupported operator aten::reshape_as encountered 48 time(s)
Unsupported operator aten::tril encountered 24 time(s)
Unsupported operator aten::div_ encountered 24 time(s)
Unsupported operator aten::sub encountered 24 time(s)
Unsupported operator aten::exp encountered 24 time(s)
Unsupported operator aten::sum encountered 24 time(s)
Unsupported operator aten::mul_ encountered 24 time(s)
Unsupported operator aten::sigmoid encountered 24 time(s)


FLOPs: 516.54 GFLOPs
lm_head FLOPs ratio: 10.20%
TOTAL: 516.54 GFLOPs
layers: 463.84 GFLOPs
layers.0: 19.33 GFLOPs
layers.0.attn: 6.44 GFLOPs
layers.0.attn.qkv_proj: 3.22 GFLOPs
layers.0.attn.output_proj: 1.07 GFLOPs
layers.0.ffn: 12.88 GFLOPs
layers.0.ffn.w1: 4.29 GFLOPs
layers.0.ffn.w2: 4.29 GFLOPs
layers.0.ffn.w3: 4.29 GFLOPs
layers.1: 19.33 GFLOPs
layers.1.attn: 6.44 GFLOPs
layers.1.attn.qkv_proj: 3.22 GFLOPs
layers.1.attn.output_proj: 1.07 GFLOPs
layers.1.ffn: 12.88 GFLOPs
layers.1.ffn.w1: 4.29 GFLOPs
layers.1.ffn.w2: 4.29 GFLOPs
layers.1.ffn.w3: 4.29 GFLOPs
layers.2: 19.33 GFLOPs
layers.2.attn: 6.44 GFLOPs
layers.2.attn.qkv_proj: 3.22 GFLOPs
layers.2.attn.output_proj: 1.07 GFLOPs
layers.2.ffn: 12.88 GFLOPs
layers.2.ffn.w1: 4.29 GFLOPs
layers.2.ffn.w2: 4.29 GFLOPs
layers.2.ffn.w3: 4.29 GFLOPs
layers.3: 19.33 GFLOPs
layers.3.attn: 6.44 GFLOPs
layers.3.attn.qkv_proj: 3.22 GFLOPs
layers.3.attn.output_proj: 1.07 GFLOPs
layers.3.ffn: 12.88 GFLOPs
layers.3.ffn.w1: 4.29 GFLOPs
layers.3.f

Unsupported operator aten::pow encountered 73 time(s)
Unsupported operator aten::mean encountered 73 time(s)
Unsupported operator aten::add encountered 217 time(s)
Unsupported operator aten::sqrt encountered 73 time(s)
Unsupported operator aten::div encountered 109 time(s)
Unsupported operator aten::mul encountered 469 time(s)
Unsupported operator aten::repeat_interleave encountered 144 time(s)
Unsupported operator aten::neg encountered 72 time(s)
Unsupported operator aten::reshape_as encountered 72 time(s)
Unsupported operator aten::tril encountered 36 time(s)
Unsupported operator aten::div_ encountered 36 time(s)
Unsupported operator aten::sub encountered 36 time(s)
Unsupported operator aten::exp encountered 36 time(s)
Unsupported operator aten::sum encountered 36 time(s)
Unsupported operator aten::mul_ encountered 36 time(s)
Unsupported operator aten::sigmoid encountered 36 time(s)


FLOPs: 1128.80 GFLOPs
lm_head FLOPs ratio: 5.83%
TOTAL: 1128.80 GFLOPs
layers: 1062.95 GFLOPs
layers.0: 29.53 GFLOPs
layers.0.attn: 9.40 GFLOPs
layers.0.attn.qkv_proj: 5.04 GFLOPs
layers.0.attn.output_proj: 1.68 GFLOPs
layers.0.ffn: 20.13 GFLOPs
layers.0.ffn.w1: 6.71 GFLOPs
layers.0.ffn.w2: 6.71 GFLOPs
layers.0.ffn.w3: 6.71 GFLOPs
layers.1: 29.53 GFLOPs
layers.1.attn: 9.40 GFLOPs
layers.1.attn.qkv_proj: 5.04 GFLOPs
layers.1.attn.output_proj: 1.68 GFLOPs
layers.1.ffn: 20.13 GFLOPs
layers.1.ffn.w1: 6.71 GFLOPs
layers.1.ffn.w2: 6.71 GFLOPs
layers.1.ffn.w3: 6.71 GFLOPs
layers.2: 29.53 GFLOPs
layers.2.attn: 9.40 GFLOPs
layers.2.attn.qkv_proj: 5.04 GFLOPs
layers.2.attn.output_proj: 1.68 GFLOPs
layers.2.ffn: 20.13 GFLOPs
layers.2.ffn.w1: 6.71 GFLOPs
layers.2.ffn.w2: 6.71 GFLOPs
layers.2.ffn.w3: 6.71 GFLOPs
layers.3: 29.53 GFLOPs
layers.3.attn: 9.40 GFLOPs
layers.3.attn.qkv_proj: 5.04 GFLOPs
layers.3.attn.output_proj: 1.68 GFLOPs
layers.3.ffn: 20.13 GFLOPs
layers.3.ffn.w1: 6.71 GFLOPs
layers.3