In [1]:
import itertools as it
import torch
from torch.utils.data import Subset
from transformers import GPT2Tokenizer

%reload_ext autoreload
%autoreload 2

from src import data, modules, pipeline

In [2]:
# Model / run configuration; the architecture values match GPT-2 small.
batch_size = 16
num_heads = 12
embed_dim = 768
context_len = 1024
vocab_size = 50257
device = "cuda"

# Reuse `context_len` instead of repeating the literal 1024, so the dataset's chunk
# length cannot drift from the model's context size if one of them is edited.
# NOTE(review): assumes the first positional argument of TinyStoriesDataset is the
# sequence length — confirm in src.data.
dataset = data.TinyStoriesDataset(context_len, num_stories=500)
# Two disjoint single-batch subsets: one to overfit on, one to evaluate on.
train_ds = Subset(dataset, list(range(batch_size)))
eval_ds = Subset(dataset, list(range(batch_size, 2 * batch_size)))

Tokenizing Stories: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:00<00:00, 1298.82 stories/s]


In [3]:
# Overfit on a single batch, sweeping the four fp32 / TF32 / bf16-autocast combinations.
# The global RNG is reseeded before each model is built so every variant starts from
# identical weights; otherwise the loss curves differ by initialisation as well as by
# numeric precision, and the comparison is confounded.
seed = 42
for enable_bf16_amp, enable_tf32 in it.product([False, True], repeat=2):
    # Global torch RNG: default nn.Module parameter initialisation draws from it.
    torch.manual_seed(seed)
    # Dedicated generator handed to the training loop (presumably for data order —
    # confirm in src.pipeline).
    g = torch.Generator().manual_seed(seed)
    model = modules.GPT2(vocab_size, embed_dim, context_len, num_heads)
    model.to(device)

    # Run label: "fp32", "tf32", "bf16" or "tf32, bf16".
    precision = [
        name
        for name, enabled in (("tf32", enable_tf32), ("bf16", enable_bf16_amp))
        if enabled
    ]
    label = ", ".join(precision or ["fp32"])

    pipeline.train_gpt2(
        model,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        num_epochs=50,
        batch_size=batch_size,
        device=device,
        generator=g,
        enable_tf32=enable_tf32,
        enable_bf16_amp=enable_bf16_amp,
        label=label,
        logging_interval=10,
    )

  0%|          | 0/50 [00:00<?, ?batches/s]

fp32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.967 | Eval Loss: 9.727 | Tokens/ms: 11.41 | Avg Forward Time: 655.31 | Avg Backward Time: 781.05
fp32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.030 | Eval Loss: 7.048 | Tokens/ms: 13.80 | Avg Forward Time: 451.55 | Avg Backward Time: 735.85
fp32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.833 | Eval Loss: 6.002 | Tokens/ms: 13.74 | Avg Forward Time: 453.49 | Avg Backward Time: 738.93
fp32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.803 | Eval Loss: 5.840 | Tokens/ms: 13.69 | Avg Forward Time: 455.19 | Avg Backward Time: 741.41
fp32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.056 | Eval Loss: 5.730 | Tokens/ms: 13.66 | Avg Forward Time: 456.08 | Avg Backward Time: 743.36
fp32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.378 | Eval Loss: 5.754 | Tokens/ms: 13.62 | Avg Forward Time: 457.91 | Avg Backward Time: 745.35


  0%|          | 0/50 [00:00<?, ?batches/s]

tf32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.034 | Eval Loss: 9.754 | Tokens/ms: 31.80 | Avg Forward Time: 241.66 | Avg Backward Time: 273.50
tf32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.605 | Eval Loss: 7.407 | Tokens/ms: 31.81 | Avg Forward Time: 242.78 | Avg Backward Time: 272.35
tf32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.246 | Eval Loss: 6.281 | Tokens/ms: 31.80 | Avg Forward Time: 242.71 | Avg Backward Time: 272.49
tf32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.138 | Eval Loss: 5.866 | Tokens/ms: 31.81 | Avg Forward Time: 242.76 | Avg Backward Time: 272.37
tf32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.305 | Eval Loss: 5.765 | Tokens/ms: 31.77 | Avg Forward Time: 243.00 | Avg Backward Time: 272.79
tf32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.712 | Eval Loss: 5.725 | Tokens/ms: 31.76 | Avg Forward Time: 243.16 | Avg Backward Time: 272.73


  0%|          | 0/50 [00:00<?, ?batches/s]

bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.962 | Eval Loss: 9.681 | Tokens/ms: 28.05 | Avg Forward Time: 328.57 | Avg Backward Time: 255.56
bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.594 | Eval Loss: 7.227 | Tokens/ms: 33.27 | Avg Forward Time: 237.81 | Avg Backward Time: 254.66
bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.158 | Eval Loss: 6.248 | Tokens/ms: 33.27 | Avg Forward Time: 237.77 | Avg Backward Time: 254.73
bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.066 | Eval Loss: 5.856 | Tokens/ms: 33.24 | Avg Forward Time: 238.05 | Avg Backward Time: 254.83
bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.353 | Eval Loss: 5.726 | Tokens/ms: 33.23 | Avg Forward Time: 238.21 | Avg Backward Time: 254.83
bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.723 | Eval Loss: 5.727 | Tokens/ms: 33.21 | Avg Forward Time: 238.41 | Avg Backward Time: 254.95


  0%|          | 0/50 [00:00<?, ?batches/s]

tf32, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.007 | Eval Loss: 9.740 | Tokens/ms: 33.08 | Avg Forward Time: 239.37 | Avg Backward Time: 255.97
tf32, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 7.954 | Eval Loss: 6.861 | Tokens/ms: 33.24 | Avg Forward Time: 238.15 | Avg Backward Time: 254.80
tf32, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.587 | Eval Loss: 6.000 | Tokens/ms: 33.22 | Avg Forward Time: 238.22 | Avg Backward Time: 254.93
tf32, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.519 | Eval Loss: 5.706 | Tokens/ms: 33.22 | Avg Forward Time: 238.29 | Avg Backward Time: 254.91
tf32, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.671 | Eval Loss: 5.677 | Tokens/ms: 33.21 | Avg Forward Time: 238.36 | Avg Backward Time: 255.04
tf32, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 2.946 | Eval Loss: 5.864 | Tokens/ms: 33.19 | Avg Forward Time: 238.64 | Avg Backward Time: 255.02


In [12]:
# Same, but with flash attention.
# The global RNG is reseeded before each model is built so every precision variant
# starts from identical weights and only the numeric settings differ between runs.
seed = 42
for enable_bf16_amp, enable_tf32 in it.product([False, True], repeat=2):
    # Global torch RNG: default nn.Module parameter initialisation draws from it.
    torch.manual_seed(seed)
    # Dedicated generator handed to the training loop (presumably for data order —
    # confirm in src.pipeline).
    g = torch.Generator().manual_seed(seed)
    model = modules.GPT2(
        vocab_size, embed_dim, context_len, num_heads, use_flash_attention=True
    )
    model.to(device)

    # Run label derived from the flags rather than from the list length, e.g.
    # "flash_attn, fp32" or "flash_attn, tf32, bf16".
    precision = [
        name
        for name, enabled in (("tf32", enable_tf32), ("bf16", enable_bf16_amp))
        if enabled
    ]
    label = ", ".join(["flash_attn", *(precision or ["fp32"])])

    pipeline.train_gpt2(
        model,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        num_epochs=50,
        batch_size=batch_size,
        device=device,
        generator=g,
        enable_tf32=enable_tf32,
        enable_bf16_amp=enable_bf16_amp,
        label=label,
        logging_interval=10,
    )

  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, fp32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.971 | Eval Loss: 9.673 | Tokens/ms: 15.07 | Avg Forward Time: 407.94 | Avg Backward Time: 678.96
flash_attn, fp32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.208 | Eval Loss: 7.326 | Tokens/ms: 17.12 | Avg Forward Time: 366.97 | Avg Backward Time: 589.95
flash_attn, fp32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.197 | Eval Loss: 6.810 | Tokens/ms: 17.10 | Avg Forward Time: 368.13 | Avg Backward Time: 590.16
flash_attn, fp32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.685 | Eval Loss: 6.351 | Tokens/ms: 17.06 | Avg Forward Time: 368.89 | Avg Backward Time: 591.58
flash_attn, fp32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 5.561 | Eval Loss: 6.516 | Tokens/ms: 16.99 | Avg Forward Time: 369.23 | Avg Backward Time: 595.04
flash_attn, fp32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 5.357 | Eval Loss: 6.605 | Tokens/ms: 16.94 | Avg Forward Time: 371.01 | Avg Backward Time: 595.91


  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, tf32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.966 | Eval Loss: 9.719 | Tokens/ms: 39.69 | Avg Forward Time: 193.29 | Avg Backward Time: 219.47
flash_attn, tf32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.407 | Eval Loss: 7.499 | Tokens/ms: 39.72 | Avg Forward Time: 195.63 | Avg Backward Time: 216.90
flash_attn, tf32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.464 | Eval Loss: 6.371 | Tokens/ms: 39.73 | Avg Forward Time: 195.74 | Avg Backward Time: 216.63
flash_attn, tf32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.639 | Eval Loss: 6.308 | Tokens/ms: 39.72 | Avg Forward Time: 195.74 | Avg Backward Time: 216.78
flash_attn, tf32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 5.394 | Eval Loss: 6.114 | Tokens/ms: 39.71 | Avg Forward Time: 195.89 | Avg Backward Time: 216.72
flash_attn, tf32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 4.920 | Eval Loss: 5.956 | Tokens/ms: 39.68 | Avg Forward Time: 195.82 | Avg Backward Time: 217.09


  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.959 | Eval Loss: 9.655 | Tokens/ms: 49.58 | Avg Forward Time: 167.46 | Avg Backward Time: 163.00
flash_attn, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.534 | Eval Loss: 7.489 | Tokens/ms: 49.52 | Avg Forward Time: 168.89 | Avg Backward Time: 161.99
flash_attn, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.493 | Eval Loss: 6.463 | Tokens/ms: 49.52 | Avg Forward Time: 168.81 | Avg Backward Time: 162.07
flash_attn, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.760 | Eval Loss: 6.316 | Tokens/ms: 49.51 | Avg Forward Time: 168.86 | Avg Backward Time: 162.07
flash_attn, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 5.353 | Eval Loss: 6.260 | Tokens/ms: 49.47 | Avg Forward Time: 169.10 | Avg Backward Time: 162.07
flash_attn, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 5.298 | Eval Loss: 6.270 | Tokens/ms: 49.49 | Avg Forward Time: 168.95 | Avg Backward Time: 162.13


  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, tf32, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.907 | Eval Loss: 9.596 | Tokens/ms: 49.69 | Avg Forward Time: 166.37 | Avg Backward Time: 163.36
flash_attn, tf32, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.717 | Eval Loss: 7.326 | Tokens/ms: 49.46 | Avg Forward Time: 169.20 | Avg Backward Time: 162.05
flash_attn, tf32, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.575 | Eval Loss: 6.692 | Tokens/ms: 49.55 | Avg Forward Time: 168.70 | Avg Backward Time: 161.96
flash_attn, tf32, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.976 | Eval Loss: 6.506 | Tokens/ms: 49.56 | Avg Forward Time: 168.60 | Avg Backward Time: 161.98
flash_attn, tf32, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 5.789 | Eval Loss: 6.509 | Tokens/ms: 49.58 | Avg Forward Time: 168.58 | Avg Backward Time: 161.91
flash_attn, tf32, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 5.720 | Eval Loss: 6.507 | Tokens/ms: 49.55 | Avg Forward Time: 168.75 | Avg B

In [4]:
# Same, but with torch.compile.
# The global RNG is reseeded before each model is built so every precision variant
# starts from identical weights and only the numeric settings differ between runs.
# Each variant compiles a fresh model, so the epoch-0 timings include the one-off
# compilation cost and should be ignored when comparing throughput.
seed = 42
for enable_bf16_amp, enable_tf32 in it.product([False, True], repeat=2):
    # Global torch RNG: default nn.Module parameter initialisation draws from it.
    torch.manual_seed(seed)
    # Dedicated generator handed to the training loop (presumably for data order —
    # confirm in src.pipeline).
    g = torch.Generator().manual_seed(seed)
    model = modules.GPT2(vocab_size, embed_dim, context_len, num_heads)
    model.to(device)
    model = torch.compile(model)

    # Run label derived from the flags rather than from the list length, e.g.
    # "torch.compile, fp32" or "torch.compile, tf32, bf16".
    precision = [
        name
        for name, enabled in (("tf32", enable_tf32), ("bf16", enable_bf16_amp))
        if enabled
    ]
    label = ", ".join(["torch.compile", *(precision or ["fp32"])])

    pipeline.train_gpt2(
        model,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        num_epochs=50,
        batch_size=batch_size,
        device=device,
        generator=g,
        enable_tf32=enable_tf32,
        enable_bf16_amp=enable_bf16_amp,
        label=label,
        logging_interval=10,
    )

  0%|          | 0/50 [00:00<?, ?batches/s]



torch.compile, fp32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.032 | Eval Loss: 9.736 | Tokens/ms: 0.71 | Avg Forward Time: 15093.34 | Avg Backward Time: 7892.28
torch.compile, fp32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.639 | Eval Loss: 7.529 | Tokens/ms: 17.25 | Avg Forward Time: 305.12 | Avg Backward Time: 644.86
torch.compile, fp32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.219 | Eval Loss: 6.170 | Tokens/ms: 17.14 | Avg Forward Time: 306.53 | Avg Backward Time: 649.11
torch.compile, fp32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.846 | Eval Loss: 5.809 | Tokens/ms: 17.09 | Avg Forward Time: 307.90 | Avg Backward Time: 650.79
torch.compile, fp32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.981 | Eval Loss: 5.675 | Tokens/ms: 17.05 | Avg Forward Time: 308.27 | Avg Backward Time: 652.72
torch.compile, fp32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.261 | Eval Loss: 5.842 | Tokens/ms: 16.99 | Avg Forward Time: 308.61 | Avg Backward Time: 65

  0%|          | 0/50 [00:00<?, ?batches/s]

torch.compile, tf32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.001 | Eval Loss: 9.714 | Tokens/ms: 0.75 | Avg Forward Time: 14287.59 | Avg Backward Time: 7610.61
torch.compile, tf32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.659 | Eval Loss: 7.294 | Tokens/ms: 64.92 | Avg Forward Time: 84.38 | Avg Backward Time: 168.00
torch.compile, tf32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.187 | Eval Loss: 6.204 | Tokens/ms: 64.88 | Avg Forward Time: 84.46 | Avg Backward Time: 168.06
torch.compile, tf32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.980 | Eval Loss: 5.862 | Tokens/ms: 64.79 | Avg Forward Time: 84.62 | Avg Backward Time: 168.26
torch.compile, tf32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.214 | Eval Loss: 5.744 | Tokens/ms: 64.64 | Avg Forward Time: 84.71 | Avg Backward Time: 168.74
torch.compile, tf32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.541 | Eval Loss: 5.798 | Tokens/ms: 64.63 | Avg Forward Time: 84.87 | Avg Backward Time: 168.62


  0%|          | 0/50 [00:00<?, ?batches/s]

torch.compile, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.959 | Eval Loss: 9.696 | Tokens/ms: 0.61 | Avg Forward Time: 16613.58 | Avg Backward Time: 10352.44
torch.compile, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.609 | Eval Loss: 7.454 | Tokens/ms: 94.95 | Avg Forward Time: 73.56 | Avg Backward Time: 99.00
torch.compile, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.349 | Eval Loss: 6.432 | Tokens/ms: 94.78 | Avg Forward Time: 73.68 | Avg Backward Time: 99.18
torch.compile, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.725 | Eval Loss: 6.290 | Tokens/ms: 94.81 | Avg Forward Time: 73.72 | Avg Backward Time: 99.08
torch.compile, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 5.417 | Eval Loss: 6.138 | Tokens/ms: 94.79 | Avg Forward Time: 73.67 | Avg Backward Time: 99.18
torch.compile, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 4.964 | Eval Loss: 5.956 | Tokens/ms: 94.62 | Avg Forward Time: 73.89 | Avg Backward Time: 99.26


  0%|          | 0/50 [00:00<?, ?batches/s]

torch.compile, tf32, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.974 | Eval Loss: 9.655 | Tokens/ms: 0.66 | Avg Forward Time: 15907.80 | Avg Backward Time: 9020.55
torch.compile, tf32, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.167 | Eval Loss: 7.274 | Tokens/ms: 94.98 | Avg Forward Time: 73.65 | Avg Backward Time: 98.85
torch.compile, tf32, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.309 | Eval Loss: 6.407 | Tokens/ms: 94.76 | Avg Forward Time: 73.76 | Avg Backward Time: 99.13
torch.compile, tf32, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.465 | Eval Loss: 6.050 | Tokens/ms: 94.60 | Avg Forward Time: 74.04 | Avg Backward Time: 99.15
torch.compile, tf32, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.880 | Eval Loss: 5.831 | Tokens/ms: 94.61 | Avg Forward Time: 74.00 | Avg Backward Time: 99.17
torch.compile, tf32, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 4.317 | Eval Loss: 5.658 | Tokens/ms: 94.50 | Avg Forward Time: 74

In [11]:
# Same, but with torch.compile & flash attention.
# The global RNG is reseeded before each model is built so every precision variant
# starts from identical weights and only the numeric settings differ between runs.
# Each variant compiles a fresh model, so the epoch-0 timings include the one-off
# compilation cost and should be ignored when comparing throughput.
seed = 42
for enable_bf16_amp, enable_tf32 in it.product([False, True], repeat=2):
    # Global torch RNG: default nn.Module parameter initialisation draws from it.
    torch.manual_seed(seed)
    # Dedicated generator handed to the training loop (presumably for data order —
    # confirm in src.pipeline).
    g = torch.Generator().manual_seed(seed)
    model = modules.GPT2(
        vocab_size, embed_dim, context_len, num_heads, use_flash_attention=True
    )
    model.to(device)
    model = torch.compile(model)

    # Run label derived from the flags rather than from the list length, e.g.
    # "flash_attn, torch.compile, fp32" or "flash_attn, torch.compile, tf32, bf16".
    precision = [
        name
        for name, enabled in (("tf32", enable_tf32), ("bf16", enable_bf16_amp))
        if enabled
    ]
    label = ", ".join(["flash_attn", "torch.compile", *(precision or ["fp32"])])

    pipeline.train_gpt2(
        model,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        num_epochs=50,
        batch_size=batch_size,
        device=device,
        generator=g,
        enable_tf32=enable_tf32,
        enable_bf16_amp=enable_bf16_amp,
        label=label,
        logging_interval=10,
    )

  0%|          | 0/50 [00:00<?, ?batches/s]



flash_attn, torch.compile, fp32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.967 | Eval Loss: 9.670 | Tokens/ms: 0.90 | Avg Forward Time: 11374.39 | Avg Backward Time: 6808.77
flash_attn, torch.compile, fp32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.372 | Eval Loss: 7.470 | Tokens/ms: 20.40 | Avg Forward Time: 257.53 | Avg Backward Time: 545.66
flash_attn, torch.compile, fp32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.452 | Eval Loss: 6.399 | Tokens/ms: 20.30 | Avg Forward Time: 258.65 | Avg Backward Time: 548.29
flash_attn, torch.compile, fp32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.395 | Eval Loss: 6.017 | Tokens/ms: 20.22 | Avg Forward Time: 259.56 | Avg Backward Time: 550.71
flash_attn, torch.compile, fp32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.720 | Eval Loss: 5.787 | Tokens/ms: 20.14 | Avg Forward Time: 260.32 | Avg Backward Time: 553.19
flash_attn, torch.compile, fp32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 4.190 | Eval Loss: 5.7

  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, torch.compile, tf32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.022 | Eval Loss: 9.726 | Tokens/ms: 1.00 | Avg Forward Time: 9995.72 | Avg Backward Time: 6338.91
flash_attn, torch.compile, tf32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.618 | Eval Loss: 7.350 | Tokens/ms: 67.24 | Avg Forward Time: 75.76 | Avg Backward Time: 167.90
flash_attn, torch.compile, tf32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.092 | Eval Loss: 6.109 | Tokens/ms: 66.98 | Avg Forward Time: 75.97 | Avg Backward Time: 168.65
flash_attn, torch.compile, tf32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.752 | Eval Loss: 5.764 | Tokens/ms: 66.96 | Avg Forward Time: 75.96 | Avg Backward Time: 168.71
flash_attn, torch.compile, tf32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.856 | Eval Loss: 5.744 | Tokens/ms: 66.90 | Avg Forward Time: 76.14 | Avg Backward Time: 168.75
flash_attn, torch.compile, tf32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.060 | Eval Loss: 5.853 | 

  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, torch.compile, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.047 | Eval Loss: 9.761 | Tokens/ms: 0.82 | Avg Forward Time: 12892.33 | Avg Backward Time: 7049.79
flash_attn, torch.compile, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.705 | Eval Loss: 7.393 | Tokens/ms: 117.67 | Avg Forward Time: 57.44 | Avg Backward Time: 81.80
flash_attn, torch.compile, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.140 | Eval Loss: 6.187 | Tokens/ms: 117.63 | Avg Forward Time: 57.35 | Avg Backward Time: 81.93
flash_attn, torch.compile, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.910 | Eval Loss: 5.860 | Tokens/ms: 117.64 | Avg Forward Time: 57.31 | Avg Backward Time: 81.97
flash_attn, torch.compile, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.097 | Eval Loss: 5.732 | Tokens/ms: 117.33 | Avg Forward Time: 57.58 | Avg Backward Time: 82.06
flash_attn, torch.compile, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.351 | Eval Loss: 5.875 |

  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, torch.compile, tf32, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 10.937 | Eval Loss: 9.709 | Tokens/ms: 0.81 | Avg Forward Time: 13226.99 | Avg Backward Time: 7056.87
flash_attn, torch.compile, tf32, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.595 | Eval Loss: 7.305 | Tokens/ms: 117.23 | Avg Forward Time: 57.96 | Avg Backward Time: 81.80
flash_attn, torch.compile, tf32, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.291 | Eval Loss: 6.368 | Tokens/ms: 117.26 | Avg Forward Time: 57.77 | Avg Backward Time: 81.96
flash_attn, torch.compile, tf32, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.229 | Eval Loss: 5.900 | Tokens/ms: 117.12 | Avg Forward Time: 57.91 | Avg Backward Time: 81.98
flash_attn, torch.compile, tf32, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.528 | Eval Loss: 5.769 | Tokens/ms: 117.23 | Avg Forward Time: 57.70 | Avg Backward Time: 82.06
flash_attn, torch.compile, tf32, bf16 | Epoch   49 | Minibatch    0 | Avg Tr