In [6]:
from typing import Tuple

In [14]:
# utils to assess model size given its architecture
def n_params(model_dim: int, n_heads: int, n_kv_heads: int, head_dim : int, n_layers: int, hidden_dim: int, vocab_size: int, verbose: bool = False, num_experts: int = 1) -> Tuple[int, int]:
    """
    Computes the number of parameters of a Transformer model given its architecture.
    """
    hidden_dim = int(2 * 2 * hidden_dim / 3)
    real_hidden_dim = 16 * ((hidden_dim + 16 - 1) // 16)
    embedding_dim = model_dim * vocab_size
    attention_qo = model_dim * n_heads * head_dim * 2 # no bias
    attention_kv = model_dim * n_kv_heads * head_dim * 2 # no bias
    ff = model_dim * real_hidden_dim * 3 # no bias
    ff_moe = ff * num_experts # no bias
    attention_norm = model_dim * 2
    transformer_block = attention_qo + attention_kv + ff + attention_norm
    transformer_block_moe = attention_qo + attention_kv + ff_moe + attention_norm
    final_norm = model_dim 
    
    total = embedding_dim + n_layers * transformer_block + final_norm 
    total_moe = embedding_dim + n_layers * transformer_block_moe + final_norm

    if verbose:

        print("== EMBEDDING ==")
        print(f"EMBEDDING DIM: {embedding_dim: ,}")

        print("== TRANSFORMER BLOCK ==")
        print(f"ATTENTION Query / Output: {attention_qo: ,}")
        print(f"ATTENTION Key / Value: {attention_kv: ,}")
        print(f"FF: {ff: ,}")
        print(f"FF MOE for {num_experts} experts: {ff_moe: ,}")
        print(f"ATTENTION NORM: {attention_norm: ,}")
        print(f"TOTAL PER TRANSFORMER BLOCK: {transformer_block: ,}")

        print("== FINAL NORM ==")
        print(f"FINAL NORM: {final_norm: ,}")

        print("== TOTAL PARAMS ==")
        print(f"TOTAL PARAMS: {total: ,}")
        print(f"TOTAL PARAMS MOE: {total_moe: ,}")
    
    return total, total_moe

In [20]:
_ = n_params(model_dim=1024, n_heads=16, n_kv_heads=4, head_dim=64, n_layers=14, hidden_dim=1536, vocab_size=32000, num_experts=12, verbose=True)

== EMBEDDING ==
EMBEDDING DIM:  32,768,000
== TRANSFORMER BLOCK ==
ATTENTION Query / Output:  2,097,152
ATTENTION Key / Value:  524,288
FF:  6,291,456
FF MOE for 12 experts:  75,497,472
ATTENTION NORM:  2,048
TOTAL PER TRANSFORMER BLOCK:  8,914,944
== FINAL NORM ==
FINAL NORM:  1,024
== TOTAL PARAMS ==
TOTAL PARAMS:  157,578,240
TOTAL PARAMS MOE:  1,126,462,464


In [29]:
def compute_flops(context_length: int, vocab_size: int, dim_model: int, n_heads: int, n_layers: int, hidden_dim: int, verbose: bool = False) -> int:
    """
    Computes the number of FLOPS of a Transformer model given its architecture.
    Following Chinchilla paper and nanoGPT implementation from A. Karpathy.
    """
    key_size = dim_model // n_heads
    hidden_dim = int(2 * 2 * hidden_dim / 3)
    real_hidden_dim = 16 * ((hidden_dim + 16 - 1) // 16)

    # embedding
    embedding = 2 * vocab_size * dim_model * context_length # TODO: check why we have 2 operations here
    # attention
    ## key, query, value projections
    attention = 2 * 3 * dim_model * key_size * context_length 
    ## key @ query logits
    attlogits = 2 * context_length * context_length * n_heads * key_size
    ## softmax
    attsoftmax = 3 * n_heads * context_length * context_length
    ## value @ attention
    attvalue = 2 * context_length * context_length * n_heads * key_size
    ## attention output projection
    attout = 2 * context_length * key_size * n_heads * dim_model

    att= attention + attlogits + attsoftmax + attvalue + attout

    # feedforward
    ff = 2 * context_length * dim_model * real_hidden_dim * 3 # we have 3 projections with the activation function

    # logits
    logits = 2 * context_length * vocab_size * dim_model

    forward_flops = embedding + n_layers * (att + ff) + logits # NB this apparently differs from Chinchilla paper as they do not count embedding and logits
    backward_flops = 2 * forward_flops # from Kaplan et al. 2020 paper

    total_flops = forward_flops + backward_flops

    if verbose:
        print(f"TOTAL FLOPS: {total_flops: e}") 
    return total_flops

    

In [30]:
_ = compute_flops(context_length=1024, vocab_size=8000, dim_model=2048, n_heads=16, n_layers=6, hidden_dim=2048, verbose=True)

TOTAL FLOPS:  1.160144e+12


In [31]:
# Defining optimal compute budget

# raw data (params, tokens) from Chinchilla paper
raw = [
    [400e6, 7.7e9],
    [1e9, 20.0e9],
    [10e9, 219.5e9],
    [67e9, 1.7e12],
    [175e9, 4.3e12],
    [280e9, 7.1e12],
    [520e9, 13.4e12],
    [1e12, 26.5e12],
    [10e12, 292.0e12]
]

In [32]:
# fitting a linear regression
import numpy as np

x = np.array([np.log10(x[0]) for x in raw])
y = np.array([np.log10(x[1]) for x in raw])
A = np.vstack([x, np.ones(len(x))]).T
m, c = np.linalg.lstsq(A, y, rcond=None)[0]
print(f"y = {m:.2f}x + {c:.2f}")

y = 1.04x + 0.94


In [33]:
def get_optimal_tokens_for_params(params: int) -> None:
    """
    Computes the optimal number of tokens given a number of parameters.
    """
    print(f"Number of optimal tokens for model params {params: e} ==> {10 ** ((np.log10(params) * m) + c): e}" ) 

In [35]:
get_optimal_tokens_for_params(2.3e9)

Number of optimal tokens for model params  2.300000e+09 ==>  4.792342e+10


### Simulate largest possible model for a given flops budget

In [None]:
import optuna

In [None]:
def objective(trial):

    vocab_size = trial.suggest_categorical('vocab_size', [2000, 4000, 8000, 16000, 32000])
    model_dim = trial.suggest_categorical('model_dim', [256, 512, 640, 768, 1024])
    n_layers = trial.suggest_int('n_layers', 1, 12)
    n_heads = trial.suggest_int('n_heads', 1, 12)
    head_dim = model_dim // n_heads
    hidden_dim = trial.suggest_categorical('hidden_dim', [256, 512, 640, 768, 1024])

    flops = compute_flops(context_length=256, vocab_size=vocab_size, dim_model=model_dim, n_heads=n_heads, n_layers=n_layers, hidden_dim=hidden_dim)

    if abs(flops - 4.9e10) < 1e10: # this is our compute budget
        params = n_params(model_dim=model_dim, n_heads=n_heads, head_dim=head_dim, n_layers=n_layers, hidden_dim=hidden_dim, vocab_size=vocab_size)
    else:
        params = 0
    return params

In [None]:
study = optuna.create_study(study_name="max_params_for_budget",direction='maximize')
study.optimize(objective, n_trials=3000)

In [None]:
study.best_params

In [None]:
_ = n_params(head_dim=1024//16, model_dim=1024, n_layers=11, n_heads=16, hidden_dim=256, vocab_size=2000, verbose=True)

In [None]:
_ = compute_flops(context_length=256, vocab_size=2000, dim_model=1024, n_heads=16, n_layers=11, hidden_dim=256, verbose=True)

In [None]:
get_optimal_tokens_for_params(0.2e6)

In [None]:
# Number of iteration to reach 1e9 tokens

total_tokens = 1.1e9
batch_size = 32
context_length = 256

total_tokens / (batch_size * context_length)