In [58]:
from data import Openwebtext

# Hyper-parameters
batch_size = 4
block_size = 1024
vocabulary_size = Openwebtext.VOCABULARY_SIZE
embedding_dimensions = 768
num_hidden_layers = 12

First, we'll estimate the total number of parameters in the network.

In [65]:
parameter_counts = {
    "token_embeddings": vocabulary_size * embedding_dimensions,
    "position_embeddings": block_size * embedding_dimensions,
    "multihead_attention": (embedding_dimensions ** 2 + embedding_dimensions * 3 * embedding_dimensions) * num_hidden_layers,
    "mlp": embedding_dimensions * 4 * embedding_dimensions * 2 * num_hidden_layers,
    "layer_norm": embedding_dimensions * num_hidden_layers * 2 + embedding_dimensions,
}

total_parameter_count = sum(parameter_counts.values())

for name, count in parameter_counts.items():
    print(f"{name:20s} {count:20,d} {count / total_parameter_count * 100:10.2f}%")

print("\n")

print(f"Total parameters: {total_parameter_count:,}")

token_embeddings               38,598,144      31.04%
position_embeddings               786,432       0.63%
multihead_attention            28,311,552      22.77%
mlp                            56,623,104      45.54%
layer_norm                         19,200       0.02%


Total parameters: 124,338,432


Next, we'll estimate the size of the model in memory and on disk. Note that this does not include any intermediate variables that get memorized during training.

In [66]:
bytes_per_parameter = 32 // 8 # Assuming 32-bit floating point

buffers_per_parameter = 2 # Assuming AdamW optimizer

total_values = total_parameter_count + buffers_per_parameter * total_parameter_count

total_bytes_per_parameter = total_values * bytes_per_parameter

total_gigabytes = total_bytes_per_parameter / 1e9

print(f"Total gigabytes: {total_gigabytes:,.2f}")

Total gigabytes: 1.49


Next, we'll estimate the number of floating point operations (FLOPs) required to perform a full forward pass of the network. Note that we do not include layer norm operations in this estimate as they are negligible.

In [72]:
ops_per_matmul = 2 # Multiply + accumulate
ops_per_activation = 9 # Assuming GELU

# K, Q, V projections
attention = ops_per_matmul * block_size * embedding_dimensions * 3 * embedding_dimensions

# Attention logits
attention += 2 * ops_per_matmul * block_size ** 2 * embedding_dimensions

# Output projection
attention += ops_per_matmul * block_size * embedding_dimensions ** 2

attention *= num_hidden_layers

# Linear transformations
mlp = 2 * ops_per_matmul * block_size * embedding_dimensions * 4 * embedding_dimensions

# Non-linear activations
mlp += ops_per_activation * 4 * embedding_dimensions

mlp *= num_hidden_layers

output_layer = ops_per_matmul * block_size * embedding_dimensions * vocabulary_size

flops = {
    "attention": attention,
    "mlp": mlp,
    "output_layer": output_layer,
}

total_forward_flops = sum(flops.values())

for name, count in flops.items():
    print(f"{name:20s} {count:20,d} {count / total_forward_flops * 100:10.2f}%")

print("\n")

print(f"Total forward FLOPs: {total_forward_flops:,}")

attention                  96,636,764,160      33.13%
mlp                       115,964,448,768      39.76%
output_layer               79,048,998,912      27.10%


Total forward FLOPs: 291,650,211,840


Next, we'll estimate the number of FLOPs for the backward pass. For this we use a simple heuristic of 2X the forward pass.

In [73]:
total_backward_flops = 2 * total_forward_flops

print(f"Total backward FLOPs: {total_backward_flops:,}")

Total backward FLOPs: 583,300,423,680
