# Demostration of the FLOPS calculations

In [2]:
from m2_cw.flops import FlopsConfig

config = FlopsConfig()

total_flops = int(1e17)

## Embeddings

In [3]:
from m2_cw.flops import flops_embedding

print(f"{flops_embedding(config):.3g}")

1.15e+07


## Self Attention

In [4]:
from m2_cw.flops import flops_self_attention

print(f"{flops_self_attention(config):.3g}")

2e+09


## RMSNorm

In [5]:
from m2_cw.flops import flops_rmsnorm

print(f"{flops_rmsnorm(config.embedding_dimension, config.sequence_length):.3g}")

1.38e+06


## MLP

In [6]:
from m2_cw.flops import flops_mlp

print(f"{flops_mlp(config):.3g}")

1.34e+10


## Language Model Head

In [7]:
from m2_cw.flops import flops_linear

print(f"{flops_linear(config.embedding_dimension, config.vocabulary_size, config.sequence_length):.3g}")

1.19e+07


## Full Training Pass

In [8]:
from dataclasses import fields
from m2_cw.flops import flops_qwen

training_config = FlopsConfig(batch_size=2)
inference_config = FlopsConfig(
    generation_length=20*10,
    mode="inference",
)

sf = 2

training_pass = flops_qwen(training_config)
forward_pass = training_pass / (training_config.batch_size + 2)
inference_pass = flops_qwen(inference_config)

print(f"Training Pass: {training_pass:.{sf}g}")
print(f"Inference Pass: {inference_pass:.{sf}g}")
print(f"Forward Pass: {forward_pass:.{sf}g}")
print(f"Inference / Forward: {inference_pass // forward_pass}")

Training Pass: 1.5e+12
Inference Pass: 7.1e+13
Forward Pass: 3.7e+11
Inference / Forward: 191.0


In [9]:
max_training_passes = total_flops // training_pass
max_inference_passes = total_flops // inference_pass

print(f"{max_training_passes}")
print(f"{max_inference_passes}")

67662.0
1412.0


## Flop Breakdown

In [63]:
from m2_cw.flops import InferenceConfig, TrainConfig

tokens_per_time_step = 11
num_time_steps = 20
generation_length = tokens_per_time_step * num_time_steps

full_forecasts = 50
mid_forecasts = 25
h1_forecasts = 5
h2_forecasts = 5


inference_configs = {
    "baseline": [
        {
            "config": InferenceConfig(generation_length=generation_length,
                                      lora_rank=0),
            "num_forecasts": full_forecasts
        },
    ],
    "default": [
        {
            "config": InferenceConfig(generation_length=generation_length,
                                  lora_rank=4),
            "num_forecasts": mid_forecasts
        },
    ],
    "hyper1": [
        {
            "config": InferenceConfig(generation_length=generation_length,
                                  lora_rank=2),
            "num_forecasts": 3 * h1_forecasts,
        },
        {
            "config": InferenceConfig(generation_length=generation_length,
                                  lora_rank=4),
            "num_forecasts": 3 * h1_forecasts,
        },
        {
            "config": InferenceConfig(generation_length=generation_length,
                                  lora_rank=8),
            "num_forecasts": 3* h1_forecasts,
        },
    ],
    "hyper2": [
        {
            "config": InferenceConfig(sequence_length=128,
                                  generation_length=generation_length,
                                  lora_rank=4),
            "num_forecasts": h2_forecasts,
        },
        {
            "config": InferenceConfig(sequence_length=512,
                                  generation_length=generation_length,
                                  lora_rank=4),
            "num_forecasts": h2_forecasts,
        },
        {
            "config": InferenceConfig(sequence_length=768),
            "num_forecasts": h2_forecasts,
        },
    ],
    "final": [
        {
            "config": InferenceConfig(sequence_length=512,
                                  generation_length=generation_length,
                                  lora_rank=4),
            "num_forecasts": full_forecasts,
        },
    ]
}

batch_size = 2
max_steps = 20000
mid_steps = 5000
h1_steps = 3150
h2_steps = 1500

train_configs = {
    "default": [
        {
            "config": TrainConfig(lora_rank=4,
                                  batch_size=batch_size),
            "num_steps": mid_steps
        },
    ],
    "hyper1": [
        {
            "config": TrainConfig(lora_rank=2,
                                  batch_size=batch_size),
            "num_steps": 3 * h1_steps,
        },
        {
            "config": TrainConfig(lora_rank=4,
                                  batch_size=batch_size),
            "num_steps": 3 * h1_steps,
        },
        {
            "config": TrainConfig(lora_rank=8,
                                  batch_size=batch_size),
            "num_steps": 3* h1_steps,
        },
    ],
    "hyper2": [
        {
            "config": TrainConfig(sequence_length=128,
                                  lora_rank=4,
                                  batch_size=batch_size),
            "num_steps": h2_steps,
        },
        {
            "config": TrainConfig(sequence_length=512,
                                  lora_rank=4,
                                  batch_size=batch_size),
            "num_steps": h2_steps,
        },
        {
            "config": TrainConfig(sequence_length=768,
                                  lora_rank=4,
                                  batch_size=batch_size),
            "num_steps": h2_steps,
        },
    ],
    "final": [
        {
            "config": TrainConfig(sequence_length=512,
                                  lora_rank=4,
                                  batch_size=batch_size),
            "num_steps": max_steps,
        },
    ]
}

inference_flops = {}
inference_passes = 0
for title, config_list in inference_configs.items():
    flops = 0
    for run in config_list:
        flops += flops_qwen(run["config"]) * run["num_forecasts"]
        inference_passes += run["num_forecasts"]
    
    inference_flops[title] = flops

total_inference_flops = 0
for k, v in inference_flops.items():
    print(f"{k}: {v:.3g}")
    total_inference_flops += v

print(f"Total Inference Flops: {total_inference_flops:.3g}, {100 * total_inference_flops / total_flops:.2f}% of budget. \nTotal Forecasts: {inference_passes}.\n")

train_flops = {}
opt_steps = 0
for title, config_list in train_configs.items():
    flops = 0
    for run in config_list:
        flops += flops_qwen(run["config"]) * run["num_steps"]
        opt_steps += run["num_steps"]
    
    train_flops[title] = flops

total_train_flops = 0
for k, v in train_flops.items():
    print(f"{k}: {v:.3g}")
    total_train_flops += v

print(f"Total Train Flops: {total_train_flops:.3g}, {100 * total_train_flops / total_flops:.2f}% of budget.\nTotal Optimiser Steps: {opt_steps}.\n")

print(f"Total Usage: { 100 * (total_train_flops + total_inference_flops) / total_flops :.2f}%.")

baseline: 3.89e+15
default: 1.95e+15
hyper1: 3.51e+15
hyper2: 1.19e+15
final: 3.89e+15
Total Inference Flops: 1.44e+16, 14.43% of budget. 
Total Forecasts: 185.

default: 7.39e+15
hyper1: 4.19e+16
hyper2: 6.18e+15
final: 2.96e+16
Total Train Flops: 8.5e+16, 85.03% of budget.
Total Optimiser Steps: 57850.

Total Usage: 99.46%.
