In [61]:
import numpy as np
import dataclasses

@dataclasses.dataclass(frozen=True)
class OptConfig:
    """Immutable hyper-parameter record for an OPT model, with size estimators.

    The defaults correspond to the configuration named ``opt-125m``. The
    estimator methods use ``input_dim`` as the model width ``h``.
    """

    name: str = "opt-125m"
    num_hidden_layers: int = 12
    max_seq_len: int = 2048
    hidden_size: int = 768
    n_head: int = 12
    input_dim: int = 768
    ffn_embed_dim: int = 3072
    pad: int = 1
    activation_fn: str = 'relu'
    vocab_size: int = 50272
    layer_norm_eps: float = 0.00001
    pad_token_id: int = 1
    dtype: type = np.float16

    def model_bytes(self):
        """Estimate the model size, split into embedding and layer parts.

        Returns:
            A 2-tuple ``(embedding, layers)`` of integers, where
            ``embedding = vocab_size * (h + 1)`` and ``layers`` is the
            per-layer total multiplied by ``num_hidden_layers``
            (``h = input_dim``).

        NOTE(review): despite the method name, neither value is multiplied
        by a per-element byte width, so these are element counts. Scale by
        the element size to obtain bytes -- TODO confirm with callers.
        """
        h = self.input_dim
        embedding = self.vocab_size * (h + 1)
        per_layer = (
            # self-attention
            h * (3 * h + 1) + h * (h + 1)
            # mlp (hard-coded as 4 * h wide; ffn_embed_dim is not consulted)
            + h * (4 * h + 1) + h * 4 * (h + 1)
            # layer norm
            + h * 4
        )
        return embedding, self.num_hidden_layers * per_layer

    def cache_bytes(self, batch_size, seq_len):
        """Estimate the cache size for ``batch_size`` sequences of ``seq_len``.

        Computes ``2 * batch_size * seq_len * num_hidden_layers * input_dim * 2``.
        Presumably the leading 2 covers keys and values and the trailing 2 is
        the byte width of a float16 element -- TODO confirm.
        """
        return 2 * batch_size * seq_len * self.num_hidden_layers * self.input_dim * 2

    def hidden_bytes(self, batch_size, seq_len):
        """Estimate the size of one hidden-state tensor.

        Computes ``batch_size * seq_len * input_dim * 2``; the factor 2 is
        presumably the byte width of a float16 element -- TODO confirm.
        """
        return batch_size * seq_len * self.input_dim * 2


# Configuration named '175b': 96 layers and heads, width 12288, FFN 4x the width.
config = OptConfig(
    name='175b',
    num_hidden_layers=96,
    n_head=96,
    max_seq_len=2048,
    hidden_size=12288,
    input_dim=12288,
    ffn_embed_dim=12288 * 4,
)

In [63]:
# Express each value returned by model_bytes() in units of 2**30.
[part / 2 ** 30 for part in config.model_bytes()]

[0.5753642022609711, 162.0120849609375]

In [45]:
def dfs(model, prefix='model'):
    """Walk the module tree depth-first and print each leaf's weight size.

    A child with no children of its own and a ``weight`` attribute is
    reported as ``<dotted path> <size> MB``, where the size assumes two
    bytes per element (float16) and is floored to whole MiB. Every other
    child is descended into with its dotted path as the new prefix.

    Args:
        model: Object exposing ``named_children()``; each child must expose
            ``children()``.
        prefix: Dotted path prepended to every printed name.
    """
    for child_name, child in model.named_children():
        path = prefix + '.' + child_name
        is_leaf = not list(child.children())
        if is_leaf and hasattr(child, 'weight'):
            size_mb = child.weight.numel() * 2 // (2 ** 20)  # float16
            print(path, size_mb, 'MB')
            continue
        dfs(child, path)
# Print the weight size of every leaf module in `model` (defined outside this excerpt).
dfs(model)

model.model.decoder.embed_tokens 73 MB
model.model.decoder.embed_positions 3 MB
model.model.decoder.final_layer_norm 0 MB
model.model.decoder.layers.0.self_attn.k_proj 1 MB
model.model.decoder.layers.0.self_attn.v_proj 1 MB
model.model.decoder.layers.0.self_attn.q_proj 1 MB
model.model.decoder.layers.0.self_attn.out_proj 1 MB
model.model.decoder.layers.0.self_attn_layer_norm 0 MB
model.model.decoder.layers.0.fc1 4 MB
model.model.decoder.layers.0.fc2 4 MB
model.model.decoder.layers.0.final_layer_norm 0 MB
model.model.decoder.layers.1.self_attn.k_proj 1 MB
model.model.decoder.layers.1.self_attn.v_proj 1 MB
model.model.decoder.layers.1.self_attn.q_proj 1 MB
model.model.decoder.layers.1.self_attn.out_proj 1 MB
model.model.decoder.layers.1.self_attn_layer_norm 0 MB
model.model.decoder.layers.1.fc1 4 MB
model.model.decoder.layers.1.fc2 4 MB
model.model.decoder.layers.1.final_layer_norm 0 MB
model.model.decoder.layers.2.self_attn.k_proj 1 MB
model.model.decoder.layers.2.self_attn.v_proj 1 MB
