In [6]:
from m2_cw.flops import TrainConfig, InferenceConfig, flops_qwen
import numpy as np

# Cost of a single optimiser step (training) and a single forecast (inference)
# at two context lengths, followed by the cost ratio shorter/longer context.
context_lengths = [512, 768]

train_costs = []
infer_costs = []


def _flops_for_context(seq_len):
    """Return ``(train_flops, infer_flops)`` for one step / one forecast at ``seq_len``.

    The training FLOPs are evaluated first, then the inference FLOPs, matching
    the order in which ``flops_qwen`` is called for each context length.
    """
    step_cfg = TrainConfig(
        sequence_length=seq_len,
        lora_rank=2,
        batch_size=2,
    )
    forecast_cfg = InferenceConfig(
        sequence_length=seq_len,
        # NOTE(review): presumably 20 generated steps x 11 tokens each — confirm.
        generation_length=20 * 11,
        lora_rank=2,
    )
    return flops_qwen(step_cfg), flops_qwen(forecast_cfg)


for length in context_lengths:
    step_flops, forecast_flops = _flops_for_context(length)
    print(f"Context Length: {length}")
    print(f" - One step: {step_flops:.3g}")
    print(f" - One forecast: {forecast_flops:.3g}\n")
    train_costs.append(step_flops)
    infer_costs.append(forecast_flops)

# Ratio of the first (shorter) context's cost to the second (longer) one's.
train_ratio = train_costs[0] / train_costs[1]
infer_ratio = infer_costs[0] / infer_costs[1]
print(f"Train ratio: {train_ratio:.2f}")
print(f"Infer ratio: {infer_ratio:.2f}")

Context Length: 512
 - One step: 1.48e+12
 - One forecast: 7.79e+13

Context Length: 768
 - One step: 2.29e+12
 - One forecast: 1.19e+14

Train ratio: 0.65
Infer ratio: 0.66


In [8]:
# Total compute budget: for every context length, `opt_steps` training steps
# plus `num_forecasts` forecasts, then summed over all context lengths.
# NOTE(review): this counts a separate `opt_steps`-step training run per
# context length — confirm that is the intended experiment design.
opt_steps = 1500
num_forecasts = 5

context_lengths = [128, 512, 768]

train_costs = []
infer_costs = []

for length in context_lengths:
    train_conf = TrainConfig(
        sequence_length=length,
        lora_rank=2,
        batch_size=2,
    )
    inf_conf = InferenceConfig(
        sequence_length=length,
        # NOTE(review): presumably 20 generated steps x 11 tokens each — confirm.
        generation_length=20*11,
        lora_rank=2
    )
    # Scale per-call FLOPs up to the full number of steps / forecasts.
    flops_train = flops_qwen(train_conf) * opt_steps
    flops_infer = flops_qwen(inf_conf) * num_forecasts
    train_costs.append(flops_train)
    infer_costs.append(flops_infer)

# Iterate over the lists themselves rather than a hard-coded range(3), so the
# report stays in step with `context_lengths` if that list is edited.
for length, train_cost, infer_cost in zip(context_lengths, train_costs, infer_costs):
    print(f"Context Length: {length}")
    print(f" - Train cost: {train_cost:.3g}")
    print(f" - Infer cost: {infer_cost:.3g}\n")

total_train_flops = np.sum(train_costs)
total_infer_flops = np.sum(infer_costs)
total_flops = total_infer_flops + total_train_flops

print(f"Train: {total_train_flops:.3g}")
print(f"Inference: {total_infer_flops:.3g}")
print(f"Total: {total_flops:.3g}")

Context Length: 128
 - Train cost: 5.27e+14
 - Infer cost: 9.48e+13

Context Length: 512
 - Train cost: 2.22e+15
 - Infer cost: 3.89e+14

Context Length: 768
 - Train cost: 3.43e+15
 - Infer cost: 5.94e+14

Train: 6.17e+15
Inference: 1.08e+15
Total: 7.25e+15
