# Demonstration of the FLOPs calculations

In [2]:
from m2_cw.flops import (
    flops_embedding,
    flops_linear,
    flops_mlp,
    flops_qwen,
    flops_rmsnorm,
    flops_self_atte,
    flops_transformer,
    printsf,
)

# Model hyper-parameters shared by every FLOP estimate below.
embedding_dimension = 896
hidden_features = 4864        # MLP hidden width (embedding_dimension -> hidden_features -> embedding_dimension)
attention_heads = 14
transformer_layers = 24
sequence_length = 512         # context length

# The full vocabulary has 151936 entries; it has been reduced to the
# 13 tokens that are actually needed.
vocabulary_size = 13

lookup_table = True           # embed tokens via a lookup table rather than a matrix

# Total FLOP budget, later split between inference and training.
total_flops = 10**17

So we have a budget of 88,000 optimiser steps at a context length of 512. This almost doubles the number of optimiser steps we can take, because we have reduced the vocabulary to only the tokens we need.

## Embeddings

In [None]:
# Compare the embedding cost of the full vocabulary (151936 entries, before
# reduction) with the vocabulary size configured for this notebook.
# dict.fromkeys keeps order and drops the duplicate if both sizes coincide.
full_vocabulary_size = 151936
vocabulary_sizes = list(dict.fromkeys([full_vocabulary_size, vocabulary_size]))

for vocab in vocabulary_sizes:
    embedding = flops_embedding(embedding_dimension=embedding_dimension,
                                sequence_length=sequence_length,
                                vocabulary_size=vocab)
    printsf(embedding, 2, prefix=f"vocab size {vocab}\n")

## Transformer Block

In [None]:
# FLOPs of a single self-attention layer for `sequence_length` tokens.
one_self_atte = flops_self_atte(
    sequence_length=sequence_length,
    embedding_dimension=embedding_dimension,
    attention_heads=attention_heads,
)
printsf(one_self_atte)

In [None]:
# FLOPs of a single RMSNorm of width `embedding_dimension` over `sequence_length` positions.
one_rms_norm = flops_rmsnorm(
    sequence_length=sequence_length,
    in_features=embedding_dimension,
)
printsf(one_rms_norm)

In [None]:
# FLOPs of a single MLP: embedding_dimension -> hidden_features -> embedding_dimension.
one_mlp = flops_mlp(
    sequence_length=sequence_length,
    in_features=embedding_dimension,
    out_features=embedding_dimension,
    hidden_features=hidden_features,
)
printsf(one_mlp)

In [None]:
# FLOPs of one transformer block, then of the full stack of
# `transformer_layers` identical blocks.
one_transformer_block = flops_transformer(
    sequence_length=sequence_length,
    embedding_dimension=embedding_dimension,
    attention_heads=attention_heads,
    hidden_features=hidden_features,
)
all_transformer_blocks = transformer_layers * one_transformer_block

printsf(one_transformer_block)
printsf(all_transformer_blocks)

## Language Model Head

In [None]:
# Compare the language-model-head cost (a biased linear layer from the
# embedding dimension to the vocabulary) for the full vocabulary
# (151936 entries, before reduction) and for the configured vocabulary size.
# dict.fromkeys keeps order and drops the duplicate if both sizes coincide.
full_vocabulary_size = 151936
vocabulary_sizes = list(dict.fromkeys([full_vocabulary_size, vocabulary_size]))

for vocab in vocabulary_sizes:
    one_lm_head = flops_linear(in_features=embedding_dimension,
                               out_features=vocab,
                               sequence_length=sequence_length,
                               bias=True)

    printsf(one_lm_head, 2, prefix=f"vocab size {vocab}\n", add_newline=True)

In [None]:
# Whole-model forward-pass FLOPs, computed once with the token embedding as a
# lookup table (lookup_table=True) and once as a matrix (lookup_table=False).
qwen_kwargs = dict(
    hidden_features=hidden_features,
    vocabulary_size=vocabulary_size,
    embedding_dimension=embedding_dimension,
    sequence_length=sequence_length,
    attention_heads=attention_heads,
    transformer_layers=transformer_layers,
)

forward_pass_via_lookup_table_embedding = flops_qwen(**qwen_kwargs, lookup_table=True)
forward_pass_via_matrix_embedding = flops_qwen(**qwen_kwargs, lookup_table=False)

printsf(forward_pass_via_lookup_table_embedding)
printsf(forward_pass_via_matrix_embedding)

In [None]:
# Forward-pass FLOPs across context lengths, for the full vocabulary
# (151936 entries, before reduction) and the configured vocabulary size.
# Every model hyper-parameter is passed explicitly so the sweep uses the same
# configuration as the rest of the notebook rather than whatever defaults
# flops_qwen may provide.
context_lengths = [128, 512, 768]
full_vocabulary_size = 151936
vocabulary_sizes = list(dict.fromkeys([full_vocabulary_size, vocabulary_size]))

for seq_length in context_lengths:
    for vocab_size in vocabulary_sizes:
        forward_pass = flops_qwen(hidden_features=hidden_features,
                                  vocabulary_size=vocab_size,
                                  embedding_dimension=embedding_dimension,
                                  sequence_length=seq_length,
                                  attention_heads=attention_heads,
                                  transformer_layers=transformer_layers,
                                  lookup_table=False)
        printsf(forward_pass, 2, prefix=f"context {seq_length}, vocab {vocab_size}\n", add_newline=True)

In [7]:
# Split the total FLOP budget between inference and training, then estimate
# how many optimiser steps the training share pays for.
model_kwargs = dict(embedding_dimension=embedding_dimension,
                    hidden_features=hidden_features,
                    sequence_length=sequence_length,
                    attention_heads=attention_heads,
                    transformer_layers=transformer_layers,
                    vocabulary_size=vocabulary_size,
                    lookup_table=False,
                    lora_rank=1)

# Cost of one training step with a batch size of 1.
flops_training = flops_qwen(**model_kwargs, mode="training", batch_size=1)

# Cost of one inference call.
# NOTE(review): the meaning of the 20 * 13 generation length is not stated here — confirm.
flops_inference = flops_qwen(**model_kwargs, mode="inference", generation_length=20 * 13)

# Inference budget, counted in inference calls:
#   - three big inference steps (beginning, middle, end) of 100 calls each;
#   - small inference steps during the hyperparameter search, 10 * 10 calls in total.
big_inference_calls = 3 * 100
small_inference_calls = 10 * 10
total_inference_flops = flops_inference * (big_inference_calls + small_inference_calls)

total_training_flops = total_flops - total_inference_flops
# Fail loudly instead of printing a negative step count.
if total_training_flops <= 0:
    raise ValueError(
        f"Inference alone ({total_inference_flops:.2e} FLOPs) exhausts the total budget "
        f"({total_flops:.2e} FLOPs); no FLOPs are left for training."
    )

printsf(total_inference_flops, prefix="Inference Flops", sf=2)
printsf(total_training_flops, prefix="Training Flops", sf=2)

optimiser_steps = total_training_flops // flops_training
print(optimiser_steps)

Inference Flops: 3.6e+16
Training Flops: 6.4e+16
59200
