In [1]:
# Make the project importable from this notebook:
#   * ../doclang_scaling — presumably so modules inside it can import each other
#     by bare name (TODO confirm);
#   * ..                 — so `doclang_scaling.<module>` resolves as a package.
# Both paths are resolved relative to the kernel's working directory, so this
# assumes the notebook is launched from a directory that is a sibling of
# doclang_scaling — TODO confirm if the notebook moves.
import sys
import os

for _extra_path in (
    os.path.realpath(os.path.join('..', 'doclang_scaling')),
    os.path.realpath('..'),
):
    # Guard so re-running this cell does not keep appending duplicate entries.
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

In [2]:
from doclang_scaling.config import ModelShape
from doclang_scaling.alibi_transformer import AlibiTransformer


In [3]:

# Small model shapes defined for comparison in this notebook.
# NOTE(review): the YAML `model_shape` configs that m7/m14 mirror specify
# `d_vocab: 88`, but the shapes below use 50,257 (the GPT-2 tokenizer size)
# for m7/m14 and 8,000 for m2. Confirm which vocabulary is intended before
# comparing these numbers against the configs.
GPT2_VOCAB_SIZE = 50_257

m2_model_shape = ModelShape(
    layers=4,
    d_model=128,
    n_heads=4,
    d_vocab=8000,
)

# Mirrors config model_shape: layers 6, d_model 256, n_heads 4
# (config has d_vocab 88; GPT-2 vocab size used here — see note above).
m7_model_shape = ModelShape(
    layers=6,
    d_model=256,
    n_heads=4,
    d_vocab=GPT2_VOCAB_SIZE,
)

# Mirrors config model_shape: layers 6, d_model 384, n_heads 6
# (config has d_vocab 88; GPT-2 vocab size used here — see note above).
m14_model_shape = ModelShape(
    layers=6,
    d_model=384,
    n_heads=6,
    d_vocab=GPT2_VOCAB_SIZE,
)



In [5]:
# GPT-2 small dimensions, used to sanity-check the analytic parameter/FLOP
# formulas on ModelShape. `ffw_size=4` is presumably the feed-forward expansion
# factor (hidden width = 4 * d_model) — confirm against ModelShape.
#
# NOTE(review): published GPT-2 small has 124,439,808 parameters. The
# 162,250,752 reported here equals that figure + an untied 50,257 x 768
# unembedding (+38,597,376) - the 1,024 x 768 learned position table (-786,432).
# This is consistent with num_params counting untied embeddings and no position
# embeddings (ALiBi), but verify against ModelShape.num_params.
gpt2_small = ModelShape(
    layers=12,
    d_model=768,
    n_heads=12,
    d_vocab=50_257,
    ffw_size=4,
)

gpt2_small_params_millions = gpt2_small.num_params / 1e6
print(gpt2_small_params_millions, "Million parameters")

162.250752 Million parameters


In [6]:
# Build the real module from the same shape and count its parameters, to
# cross-check the analytic ModelShape.num_params figure.
# NOTE(review): the recorded outputs (162.301009M here vs 162.250752M
# analytic) differ by exactly 50,257 = d_vocab. This may be an output-layer
# bias that num_params does not count — confirm.
trans = AlibiTransformer(**vars(gpt2_small))
n_params_millions = trans.count_params() / 1e6
print("Num params: ", n_params_millions, "Million parameters")


Num params:  162.301009 Million parameters


In [7]:
# Per-token forward FLOPs of gpt2_small at several context lengths (tokens).
# `results` maps context length -> FLOPs/token for use in later cells.
context_lengths = [128, 512, 2048, 8192]
results = dict()
for context_length in context_lengths:
    results[context_length] = result = gpt2_small.calculate_flops_per_token(
        context_length=context_length
    )
    megaflops_per_token = result / 1e6
    print(f"{context_length}: {megaflops_per_token:.2f} MegaFLOPs/token")

128: 215.93 MegaFLOPs/token
512: 230.26 MegaFLOPs/token
2048: 287.54 MegaFLOPs/token
8192: 516.69 MegaFLOPs/token
