In [2]:
from dataclasses import dataclass
from transformers import AutoTokenizer


@dataclass
class ArchParams:
    name: str
    model_dim: int
    layers: int
    vocab_size: int
    ffn_factor: float
    sequence_length: int

def ff_cost(model_dim, ffn_factor):
    return model_dim * 2 * ffn_factor

def attention_cost(model_dim, sequence_length):
    return (model_dim * 4 + sequence_length )

def env_transformer_cost(arch_params: ArchParams) -> dict:
    # These calculations are per token, and have a model_dim factored out, which cancels
    # in all three branches.
    ff = arch_params.layers * ff_cost(arch_params.model_dim, arch_params.ffn_factor)
    att = arch_params.layers * attention_cost(arch_params.model_dim, arch_params.sequence_length)    
    total = ff + att + arch_params.vocab_size

    return {
        'ff': ff / total,
        'att': att / total,
        'vocab': arch_params.vocab_size / total
    }

arch_params_list = [
    ArchParams(
        name='gemma2b',
        model_dim=2048,
        layers=18,
        vocab_size=256128,
        ffn_factor=16.0,
        sequence_length=8192
    ),
    ArchParams(
        name='gemma7b',
        model_dim=3072,
        layers=28,
        vocab_size=256128,
        ffn_factor=16.0,
        sequence_length=8192
    ),
    ArchParams(
        name='llama8b',
        model_dim=4096,
        layers=32,
        vocab_size=128000,
        ffn_factor=3.5,
        sequence_length=8192
    ),
    ArchParams(
        name='llama70b',
        model_dim=8192,
        layers=64,
        vocab_size=128000,
        ffn_factor=3.5,
        sequence_length=8192
    ),
]
print(f"{'name'.ljust(10)}{'vocab':>10}{'att':>12}{'ff':>15}")
for param in arch_params_list:
    costs = env_transformer_cost(param)
    print(f"{param.name.ljust(10)} {costs['vocab']:>10.1%} {costs['att']:>12.1%} {costs['ff']:>15.1%}")



name           vocab         att             ff
gemma2b         14.8%        17.0%           68.2%
gemma7b          7.2%        16.0%           76.8%
llama8b          7.0%        42.9%           50.1%
llama70b         2.0%        40.8%           57.2%


In [3]:
# tok = AutoTokenizer.from_pretrained('NousResearch/gemma-2b-it-tokenizer')
# v = tok.get_vocab()

tokenizer_config.json:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]