In [1]:
import torch
from model import GPTModel
from thop import profile
from loguru import logger

# Hyperparameter dictionary handed to GPTModel. The context length is
# deliberately shortened to 256 tokens (the original setting is 1024).
GPT_CONFIG_124M = dict(
    vocab_size=50257,    # Vocabulary size
    context_length=256,  # Shortened context length (orig: 1024)
    emb_dim=768,         # Embedding dimension
    n_heads=12,          # Number of attention heads
    n_layers=12,         # Number of layers
    drop_rate=0.1,       # Dropout rate
    qkv_bias=False,      # Query-key-value bias
)

# Seed torch's RNG before building the model so that any torch-based random
# initialisation inside GPTModel is repeatable across runs.
torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
# Put the model in evaluation mode before it is profiled.
model.eval()

logger.info("model loaded")

2024-11-28 15:12:45.578 | INFO     | __main__:<module>:20 - model loaded


In [3]:
# Profile one forward pass of `model` with thop for several batch sizes and
# report the operation counts.
# MACS = multiply-accumulate operations
# MACS are typically counted as two FLOPS (one multiply and one accumulate)

# Use the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = [16, 32, 64, 128]

# Read the input dimensions from the model configuration instead of repeating
# the literal values, so the profiling input always matches the model that
# was actually built.
vocab_size = GPT_CONFIG_124M["vocab_size"]
context_length = GPT_CONFIG_124M["context_length"]

model.to(device)

for _bs in batch_size:
    # Random token ids of shape (batch, context_length), created directly on
    # the target device to avoid an extra host-to-device copy.
    input_tensor = torch.randint(
        0, vocab_size, (_bs, context_length), device=device
    )
    # profile() returns a (macs, params) pair; only the MAC count is reported.
    macs, params = profile(model, inputs=(input_tensor,), verbose=False)
    flops = 2*macs
    print(f"{flops:.1e} FLOPS for batch size {_bs}")
    print(f"{macs:.1e} MACS for batch size {_bs}")
    print('\n\n')

1.0e+12 FLOPS for batch size 16
5.1e+11 MACS for batch size 16



2.0e+12 FLOPS for batch size 32
1.0e+12 MACS for batch size 32



4.0e+12 FLOPS for batch size 64
2.0e+12 MACS for batch size 64



8.1e+12 FLOPS for batch size 128
4.0e+12 MACS for batch size 128



