In [1]:
import itertools as it
import torch
from torch.utils.data import Subset
from transformers import GPT2Tokenizer

%reload_ext autoreload
%autoreload 2

from src import data, modules, pipeline

In [2]:
# Model / training configuration shared by every experiment below.
batch_size = 16
num_heads = 12
embed_dim = 768
context_len = 1024
vocab_size = 50257  # presumably the GPT-2 BPE tokenizer's vocabulary size
device = "cuda"

# Reuse `context_len` instead of repeating the literal 1024 so the dataset and the
# model cannot silently disagree if one is changed. (Assumes the first positional
# argument of TinyStoriesDataset is the tokens-per-example length — TODO confirm.)
dataset = data.TinyStoriesDataset(context_len, num_stories=500)
# Train and eval are the first two disjoint `batch_size`-sized slices, i.e. exactly
# one batch each: these runs overfit a single batch rather than measure generalization.
train_ds = Subset(dataset, list(range(batch_size)))
eval_ds = Subset(dataset, list(range(batch_size, 2 * batch_size)))

Tokenizing Stories: 100%|███████████████████████████| 500/500 [00:00<00:00, 1277.10 stories/s]


In [3]:
# Overfit on a single batch: train a fresh GPT-2 for every combination of TF32
# matmuls and bf16 autocast, comparing loss curves and throughput.
for enable_bf16_amp, enable_tf32 in it.product([False, True], repeat=2):
    # Seed the global RNG before constructing the model so every configuration
    # starts from identical weights; otherwise differences between runs mix
    # precision effects with initialization noise. (Assumes modules.GPT2 draws its
    # initial parameters from torch's default RNG — TODO confirm.)
    torch.manual_seed(42)
    # Dedicated generator handed to train_gpt2 (presumably for data ordering).
    g = torch.Generator().manual_seed(42)
    model = modules.GPT2(vocab_size, embed_dim, context_len, num_heads)
    model.to(device)
    # The log label lists the enabled precision options, falling back to "fp32".
    precision_tags = [
        tag
        for tag, enabled in (("tf32", enable_tf32), ("bf16", enable_bf16_amp))
        if enabled
    ] or ["fp32"]
    label = ", ".join(precision_tags)
    pipeline.train_gpt2(
        model,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        num_epochs=50,
        batch_size=batch_size,
        device=device,
        generator=g,
        enable_tf32=enable_tf32,
        enable_bf16_amp=enable_bf16_amp,
        label=label,
        logging_interval=10,
    )

  0%|          | 0/50 [00:00<?, ?batches/s]

fp32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.967 | Eval Loss: 9.727 | Tokens/ms: 11.41 | Avg Forward Time: 655.31 | Avg Backward Time: 781.05
fp32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.030 | Eval Loss: 7.048 | Tokens/ms: 13.80 | Avg Forward Time: 451.55 | Avg Backward Time: 735.85
fp32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.833 | Eval Loss: 6.002 | Tokens/ms: 13.74 | Avg Forward Time: 453.49 | Avg Backward Time: 738.93
fp32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.803 | Eval Loss: 5.840 | Tokens/ms: 13.69 | Avg Forward Time: 455.19 | Avg Backward Time: 741.41
fp32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.056 | Eval Loss: 5.730 | Tokens/ms: 13.66 | Avg Forward Time: 456.08 | Avg Backward Time: 743.36
fp32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.378 | Eval Loss: 5.754 | Tokens/ms: 13.62 | Avg Forward Time: 457.91 | Avg Backward Time: 745.35


  0%|          | 0/50 [00:00<?, ?batches/s]

tf32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.034 | Eval Loss: 9.754 | Tokens/ms: 31.80 | Avg Forward Time: 241.66 | Avg Backward Time: 273.50
tf32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.605 | Eval Loss: 7.407 | Tokens/ms: 31.81 | Avg Forward Time: 242.78 | Avg Backward Time: 272.35
tf32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.246 | Eval Loss: 6.281 | Tokens/ms: 31.80 | Avg Forward Time: 242.71 | Avg Backward Time: 272.49
tf32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.138 | Eval Loss: 5.866 | Tokens/ms: 31.81 | Avg Forward Time: 242.76 | Avg Backward Time: 272.37
tf32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.305 | Eval Loss: 5.765 | Tokens/ms: 31.77 | Avg Forward Time: 243.00 | Avg Backward Time: 272.79
tf32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.712 | Eval Loss: 5.725 | Tokens/ms: 31.76 | Avg Forward Time: 243.16 | Avg Backward Time: 272.73


  0%|          | 0/50 [00:00<?, ?batches/s]

bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.962 | Eval Loss: 9.681 | Tokens/ms: 28.05 | Avg Forward Time: 328.57 | Avg Backward Time: 255.56
bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.594 | Eval Loss: 7.227 | Tokens/ms: 33.27 | Avg Forward Time: 237.81 | Avg Backward Time: 254.66
bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.158 | Eval Loss: 6.248 | Tokens/ms: 33.27 | Avg Forward Time: 237.77 | Avg Backward Time: 254.73
bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.066 | Eval Loss: 5.856 | Tokens/ms: 33.24 | Avg Forward Time: 238.05 | Avg Backward Time: 254.83
bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.353 | Eval Loss: 5.726 | Tokens/ms: 33.23 | Avg Forward Time: 238.21 | Avg Backward Time: 254.83
bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.723 | Eval Loss: 5.727 | Tokens/ms: 33.21 | Avg Forward Time: 238.41 | Avg Backward Time: 254.95


  0%|          | 0/50 [00:00<?, ?batches/s]

tf32, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.007 | Eval Loss: 9.740 | Tokens/ms: 33.08 | Avg Forward Time: 239.37 | Avg Backward Time: 255.97
tf32, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 7.954 | Eval Loss: 6.861 | Tokens/ms: 33.24 | Avg Forward Time: 238.15 | Avg Backward Time: 254.80
tf32, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.587 | Eval Loss: 6.000 | Tokens/ms: 33.22 | Avg Forward Time: 238.22 | Avg Backward Time: 254.93
tf32, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.519 | Eval Loss: 5.706 | Tokens/ms: 33.22 | Avg Forward Time: 238.29 | Avg Backward Time: 254.91
tf32, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.671 | Eval Loss: 5.677 | Tokens/ms: 33.21 | Avg Forward Time: 238.36 | Avg Backward Time: 255.04
tf32, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 2.946 | Eval Loss: 5.864 | Tokens/ms: 33.19 | Avg Forward Time: 238.64 | Avg Backward Time: 255.02


In [6]:
# Same sweep, but with the model's flash-attention path enabled.
for enable_bf16_amp, enable_tf32 in it.product([False, True], repeat=2):
    # Seed the global RNG before constructing the model so every configuration
    # starts from identical weights (assumes modules.GPT2 initializes from torch's
    # default RNG — TODO confirm).
    torch.manual_seed(42)
    # Dedicated generator handed to train_gpt2 (presumably for data ordering).
    g = torch.Generator().manual_seed(42)
    model = modules.GPT2(
        vocab_size, embed_dim, context_len, num_heads, use_flash_attention=True
    )
    model.to(device)
    # Build the precision tags separately so the "fp32" fallback does not depend on
    # how many fixed prefix tags precede them.
    precision_tags = [
        tag
        for tag, enabled in (("tf32", enable_tf32), ("bf16", enable_bf16_amp))
        if enabled
    ] or ["fp32"]
    label = ", ".join(["flash_attn", *precision_tags])
    pipeline.train_gpt2(
        model,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        num_epochs=50,
        batch_size=batch_size,
        device=device,
        generator=g,
        enable_tf32=enable_tf32,
        enable_bf16_amp=enable_bf16_amp,
        label=label,
        logging_interval=10,
    )

  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, fp32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.978 | Eval Loss: 9.687 | Tokens/ms: 16.22 | Avg Forward Time: 384.06 | Avg Backward Time: 625.82
flash_attn, fp32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.623 | Eval Loss: 7.357 | Tokens/ms: 16.23 | Avg Forward Time: 384.99 | Avg Backward Time: 624.20
flash_attn, fp32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.230 | Eval Loss: 6.222 | Tokens/ms: 16.12 | Avg Forward Time: 387.60 | Avg Backward Time: 628.67
flash_attn, fp32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.021 | Eval Loss: 5.917 | Tokens/ms: 16.09 | Avg Forward Time: 388.62 | Avg Backward Time: 629.88
flash_attn, fp32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.256 | Eval Loss: 5.808 | Tokens/ms: 16.05 | Avg Forward Time: 389.18 | Avg Backward Time: 631.92
flash_attn, fp32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.606 | Eval Loss: 5.799 | Tokens/ms: 16.01 | Avg Forward Time: 390.19 | Avg Backward Time: 633.23


  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, tf32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.945 | Eval Loss: 9.659 | Tokens/ms: 34.95 | Avg Forward Time: 212.28 | Avg Backward Time: 256.49
flash_attn, tf32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.612 | Eval Loss: 7.344 | Tokens/ms: 34.84 | Avg Forward Time: 214.63 | Avg Backward Time: 255.65
flash_attn, tf32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.329 | Eval Loss: 6.334 | Tokens/ms: 34.81 | Avg Forward Time: 214.64 | Avg Backward Time: 255.97
flash_attn, tf32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.341 | Eval Loss: 5.959 | Tokens/ms: 34.76 | Avg Forward Time: 215.05 | Avg Backward Time: 256.29
flash_attn, tf32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.729 | Eval Loss: 5.785 | Tokens/ms: 34.73 | Avg Forward Time: 215.29 | Avg Backward Time: 256.41
flash_attn, tf32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 4.157 | Eval Loss: 5.710 | Tokens/ms: 34.73 | Avg Forward Time: 215.43 | Avg Backward Time: 256.26


  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.934 | Eval Loss: 9.695 | Tokens/ms: 45.59 | Avg Forward Time: 176.26 | Avg Backward Time: 183.12
flash_attn, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.646 | Eval Loss: 7.395 | Tokens/ms: 45.50 | Avg Forward Time: 178.13 | Avg Backward Time: 181.97
flash_attn, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.353 | Eval Loss: 6.464 | Tokens/ms: 45.51 | Avg Forward Time: 177.98 | Avg Backward Time: 182.03
flash_attn, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.483 | Eval Loss: 6.167 | Tokens/ms: 45.53 | Avg Forward Time: 177.85 | Avg Backward Time: 182.02
flash_attn, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.940 | Eval Loss: 5.848 | Tokens/ms: 45.47 | Avg Forward Time: 178.05 | Avg Backward Time: 182.26
flash_attn, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 4.359 | Eval Loss: 5.729 | Tokens/ms: 45.42 | Avg Forward Time: 178.24 | Avg Backward Time: 182.45


  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, tf32, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.938 | Eval Loss: 9.649 | Tokens/ms: 45.55 | Avg Forward Time: 176.44 | Avg Backward Time: 183.22
flash_attn, tf32, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.574 | Eval Loss: 7.306 | Tokens/ms: 45.51 | Avg Forward Time: 178.07 | Avg Backward Time: 181.95
flash_attn, tf32, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.276 | Eval Loss: 6.372 | Tokens/ms: 45.51 | Avg Forward Time: 177.84 | Avg Backward Time: 182.17
flash_attn, tf32, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.177 | Eval Loss: 5.896 | Tokens/ms: 45.46 | Avg Forward Time: 178.28 | Avg Backward Time: 182.09
flash_attn, tf32, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.420 | Eval Loss: 5.754 | Tokens/ms: 45.43 | Avg Forward Time: 178.29 | Avg Backward Time: 182.37
flash_attn, tf32, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.749 | Eval Loss: 5.756 | Tokens/ms: 45.44 | Avg Forward Time: 178.21 | Avg B

In [4]:
# Same sweep, but with the model wrapped in torch.compile. Epoch-0 timings include
# compilation, so compare steady-state numbers from later epochs.
for enable_bf16_amp, enable_tf32 in it.product([False, True], repeat=2):
    # Seed the global RNG before constructing the model so every configuration
    # starts from identical weights (assumes modules.GPT2 initializes from torch's
    # default RNG — TODO confirm).
    torch.manual_seed(42)
    # Dedicated generator handed to train_gpt2 (presumably for data ordering).
    g = torch.Generator().manual_seed(42)
    model = modules.GPT2(vocab_size, embed_dim, context_len, num_heads)
    model.to(device)
    # Clear Dynamo's compilation caches so each configuration compiles from scratch.
    # This stops graphs and recompile counters left over from earlier runs or cells
    # from influencing this measurement.
    torch._dynamo.reset()
    model = torch.compile(model)
    # Build the precision tags separately so the "fp32" fallback does not depend on
    # how many fixed prefix tags precede them.
    precision_tags = [
        tag
        for tag, enabled in (("tf32", enable_tf32), ("bf16", enable_bf16_amp))
        if enabled
    ] or ["fp32"]
    label = ", ".join(["torch.compile", *precision_tags])
    pipeline.train_gpt2(
        model,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        num_epochs=50,
        batch_size=batch_size,
        device=device,
        generator=g,
        enable_tf32=enable_tf32,
        enable_bf16_amp=enable_bf16_amp,
        label=label,
        logging_interval=10,
    )

  0%|          | 0/50 [00:00<?, ?batches/s]



torch.compile, fp32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.032 | Eval Loss: 9.736 | Tokens/ms: 0.71 | Avg Forward Time: 15093.34 | Avg Backward Time: 7892.28
torch.compile, fp32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.639 | Eval Loss: 7.529 | Tokens/ms: 17.25 | Avg Forward Time: 305.12 | Avg Backward Time: 644.86
torch.compile, fp32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.219 | Eval Loss: 6.170 | Tokens/ms: 17.14 | Avg Forward Time: 306.53 | Avg Backward Time: 649.11
torch.compile, fp32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.846 | Eval Loss: 5.809 | Tokens/ms: 17.09 | Avg Forward Time: 307.90 | Avg Backward Time: 650.79
torch.compile, fp32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.981 | Eval Loss: 5.675 | Tokens/ms: 17.05 | Avg Forward Time: 308.27 | Avg Backward Time: 652.72
torch.compile, fp32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.261 | Eval Loss: 5.842 | Tokens/ms: 16.99 | Avg Forward Time: 308.61 | Avg Backward Time: 65

  0%|          | 0/50 [00:00<?, ?batches/s]

torch.compile, tf32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.001 | Eval Loss: 9.714 | Tokens/ms: 0.75 | Avg Forward Time: 14287.59 | Avg Backward Time: 7610.61
torch.compile, tf32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.659 | Eval Loss: 7.294 | Tokens/ms: 64.92 | Avg Forward Time: 84.38 | Avg Backward Time: 168.00
torch.compile, tf32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.187 | Eval Loss: 6.204 | Tokens/ms: 64.88 | Avg Forward Time: 84.46 | Avg Backward Time: 168.06
torch.compile, tf32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.980 | Eval Loss: 5.862 | Tokens/ms: 64.79 | Avg Forward Time: 84.62 | Avg Backward Time: 168.26
torch.compile, tf32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.214 | Eval Loss: 5.744 | Tokens/ms: 64.64 | Avg Forward Time: 84.71 | Avg Backward Time: 168.74
torch.compile, tf32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.541 | Eval Loss: 5.798 | Tokens/ms: 64.63 | Avg Forward Time: 84.87 | Avg Backward Time: 168.62


  0%|          | 0/50 [00:00<?, ?batches/s]

torch.compile, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.959 | Eval Loss: 9.696 | Tokens/ms: 0.61 | Avg Forward Time: 16613.58 | Avg Backward Time: 10352.44
torch.compile, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.609 | Eval Loss: 7.454 | Tokens/ms: 94.95 | Avg Forward Time: 73.56 | Avg Backward Time: 99.00
torch.compile, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.349 | Eval Loss: 6.432 | Tokens/ms: 94.78 | Avg Forward Time: 73.68 | Avg Backward Time: 99.18
torch.compile, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.725 | Eval Loss: 6.290 | Tokens/ms: 94.81 | Avg Forward Time: 73.72 | Avg Backward Time: 99.08
torch.compile, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 5.417 | Eval Loss: 6.138 | Tokens/ms: 94.79 | Avg Forward Time: 73.67 | Avg Backward Time: 99.18
torch.compile, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 4.964 | Eval Loss: 5.956 | Tokens/ms: 94.62 | Avg Forward Time: 73.89 | Avg Backward Time: 99.26


  0%|          | 0/50 [00:00<?, ?batches/s]

torch.compile, tf32, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.974 | Eval Loss: 9.655 | Tokens/ms: 0.66 | Avg Forward Time: 15907.80 | Avg Backward Time: 9020.55
torch.compile, tf32, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.167 | Eval Loss: 7.274 | Tokens/ms: 94.98 | Avg Forward Time: 73.65 | Avg Backward Time: 98.85
torch.compile, tf32, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.309 | Eval Loss: 6.407 | Tokens/ms: 94.76 | Avg Forward Time: 73.76 | Avg Backward Time: 99.13
torch.compile, tf32, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.465 | Eval Loss: 6.050 | Tokens/ms: 94.60 | Avg Forward Time: 74.04 | Avg Backward Time: 99.15
torch.compile, tf32, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.880 | Eval Loss: 5.831 | Tokens/ms: 94.61 | Avg Forward Time: 74.00 | Avg Backward Time: 99.17
torch.compile, tf32, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 4.317 | Eval Loss: 5.658 | Tokens/ms: 94.50 | Avg Forward Time: 74

In [5]:
# Same sweep, with both torch.compile and the flash-attention path enabled. Epoch-0
# timings include compilation, so compare steady-state numbers from later epochs.
for enable_bf16_amp, enable_tf32 in it.product([False, True], repeat=2):
    # Seed the global RNG before constructing the model so every configuration
    # starts from identical weights (assumes modules.GPT2 initializes from torch's
    # default RNG — TODO confirm).
    torch.manual_seed(42)
    # Dedicated generator handed to train_gpt2 (presumably for data ordering).
    g = torch.Generator().manual_seed(42)
    model = modules.GPT2(
        vocab_size, embed_dim, context_len, num_heads, use_flash_attention=True
    )
    model.to(device)
    # Clear Dynamo's compilation caches so each configuration compiles from scratch.
    # This stops graphs and recompile counters left over from earlier runs or cells
    # from influencing this measurement.
    torch._dynamo.reset()
    model = torch.compile(model)
    # Build the precision tags separately so the "fp32" fallback does not depend on
    # how many fixed prefix tags precede them.
    precision_tags = [
        tag
        for tag, enabled in (("tf32", enable_tf32), ("bf16", enable_bf16_amp))
        if enabled
    ] or ["fp32"]
    label = ", ".join(["flash_attn", "torch.compile", *precision_tags])
    pipeline.train_gpt2(
        model,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        num_epochs=50,
        batch_size=batch_size,
        device=device,
        generator=g,
        enable_tf32=enable_tf32,
        enable_bf16_amp=enable_bf16_amp,
        label=label,
        logging_interval=10,
    )

  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, torch.compile, fp32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.976 | Eval Loss: 9.711 | Tokens/ms: 0.93 | Avg Forward Time: 11386.35 | Avg Backward Time: 6199.19
flash_attn, torch.compile, fp32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.580 | Eval Loss: 7.527 | Tokens/ms: 18.92 | Avg Forward Time: 278.51 | Avg Backward Time: 587.64
flash_attn, torch.compile, fp32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.438 | Eval Loss: 6.363 | Tokens/ms: 18.86 | Avg Forward Time: 279.46 | Avg Backward Time: 589.48
flash_attn, torch.compile, fp32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.325 | Eval Loss: 5.908 | Tokens/ms: 18.80 | Avg Forward Time: 280.25 | Avg Backward Time: 591.30
flash_attn, torch.compile, fp32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.571 | Eval Loss: 5.721 | Tokens/ms: 18.70 | Avg Forward Time: 281.77 | Avg Backward Time: 594.32
flash_attn, torch.compile, fp32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.956 | Eval Loss: 5.6

  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, torch.compile, tf32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.026 | Eval Loss: 9.683 | Tokens/ms: 1.01 | Avg Forward Time: 10227.77 | Avg Backward Time: 5971.18
flash_attn, torch.compile, tf32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.035 | Eval Loss: 7.018 | Tokens/ms: 54.28 | Avg Forward Time: 95.05 | Avg Backward Time: 206.81
flash_attn, torch.compile, tf32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.043 | Eval Loss: 6.319 | Tokens/ms: 54.27 | Avg Forward Time: 95.21 | Avg Backward Time: 206.70
flash_attn, torch.compile, tf32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.700 | Eval Loss: 6.350 | Tokens/ms: 54.16 | Avg Forward Time: 95.20 | Avg Backward Time: 207.29
flash_attn, torch.compile, tf32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 5.547 | Eval Loss: 6.499 | Tokens/ms: 54.18 | Avg Forward Time: 95.22 | Avg Backward Time: 207.19
flash_attn, torch.compile, tf32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 5.415 | Eval Loss: 6.406 |

  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, torch.compile, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.033 | Eval Loss: 9.774 | Tokens/ms: 0.81 | Avg Forward Time: 11740.24 | Avg Backward Time: 8395.95
flash_attn, torch.compile, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.620 | Eval Loss: 7.567 | Tokens/ms: 97.66 | Avg Forward Time: 65.67 | Avg Backward Time: 102.10
flash_attn, torch.compile, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.459 | Eval Loss: 6.410 | Tokens/ms: 97.53 | Avg Forward Time: 65.75 | Avg Backward Time: 102.23
flash_attn, torch.compile, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.305 | Eval Loss: 5.957 | Tokens/ms: 97.38 | Avg Forward Time: 65.85 | Avg Backward Time: 102.40
flash_attn, torch.compile, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.595 | Eval Loss: 5.705 | Tokens/ms: 97.19 | Avg Forward Time: 66.08 | Avg Backward Time: 102.50
flash_attn, torch.compile, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.977 | Eval Loss: 5.618 |

  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, torch.compile, tf32, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.937 | Eval Loss: 9.633 | Tokens/ms: 0.82 | Avg Forward Time: 13254.40 | Avg Backward Time: 6786.15
flash_attn, torch.compile, tf32, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.597 | Eval Loss: 7.333 | Tokens/ms: 97.76 | Avg Forward Time: 65.71 | Avg Backward Time: 101.88
flash_attn, torch.compile, tf32, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.253 | Eval Loss: 6.291 | Tokens/ms: 97.60 | Avg Forward Time: 65.78 | Avg Backward Time: 102.10
flash_attn, torch.compile, tf32, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.253 | Eval Loss: 5.974 | Tokens/ms: 97.61 | Avg Forward Time: 65.80 | Avg Backward Time: 102.05
flash_attn, torch.compile, tf32, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.639 | Eval Loss: 5.785 | Tokens/ms: 97.10 | Avg Forward Time: 66.24 | Avg Backward Time: 102.49
flash_attn, torch.compile, tf32, bf16 | Epoch   49 | Minibatch    0 | Avg Tr

```
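TODO: Dig into why flash attention + torch.compile is only slightly faster than torch.compile alone (e.g. ~97 vs ~95 tokens/ms with bf16, ~19 vs ~17 tokens/ms with FP32).
Notably, token throughput is actually worse with TF32 alone (~54 vs ~65 tokens/ms), and backward passes are slower in every configuration except plain FP32 (where they drop from ~650 ms to ~590 ms).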
```