Final results: With all the following improvements

- BFloat16 mixed precision training
- `torch.compile()`
- Flash Attention 2
- Vocabulary size rounded up to the next multiple of 64

I see a ~9.3x improvement in token throughput over the fp32 baseline.
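As a rough check of that figure, the `Tokens/ms` values logged further down can be compared directly. The sketch below parses a few epoch-40 log lines copied from this notebook (trimmed after the `Tokens/ms` field) and divides each by the fp32 baseline. The regular expression and the helper name `tokens_per_ms` are my own and are not part of `src.pipeline`.

```python
import re

# Epoch-40 log lines copied from the runs in this notebook,
# trimmed after the Tokens/ms field.
LOG_LINES = [
    "fp32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.890 | Eval Loss: 5.663 | Tokens/ms: 13.49",
    "bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.878 | Eval Loss: 5.661 | Tokens/ms: 33.17",
    "flash_attn, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.871 | Eval Loss: 5.653 | Tokens/ms: 49.01",
    "torch.compile, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.679 | Eval Loss: 5.738 | Tokens/ms: 93.82",
    "flash_attn, torch.compile, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.681 | Eval Loss: 5.697 | Tokens/ms: 115.54",
    "rounded_vocab, flash_attn, torch.compile | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.513 | Eval Loss: 5.751 | Tokens/ms: 124.33",
]

# label | Epoch N | ... | Tokens/ms: X
LINE_RE = re.compile(
    r"^(?P<label>.+?) \| Epoch\s+(?P<epoch>\d+) \|.*Tokens/ms: (?P<tps>\d+(?:\.\d+)?)"
)


def tokens_per_ms(lines):
    """Map each run label to the Tokens/ms value of its last parsed line."""
    result = {}
    for line in lines:
        match = LINE_RE.match(line)
        if match:
            result[match.group("label")] = float(match.group("tps"))
    return result


throughput = tokens_per_ms(LOG_LINES)
baseline = throughput["fp32"]
speedups = {label: tps / baseline for label, tps in throughput.items()}

for label, speedup in speedups.items():
    print(f"{label}: {speedup:.2f}x")
```

At epoch 40 the ratio for the final configuration comes out a little above 9. The exact value shifts slightly depending on which logged epoch is used.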

Notes:
- I think the BFloat16 option clobbers the TFloat32 option: under `torch.amp` autocast, matrix multiplications run in BFloat16, so the `torch.set_float32_matmul_precision('high')` setting no longer applies to them. This matches the results below, where the `bf16` and `tf32, bf16` runs have nearly identical throughput. For more information, see the [autocast op reference](https://pytorch.org/docs/stable/amp.html#cuda-op-specific-behavior).
- `F.scaled_dot_product_attention` does not necessarily use a Flash Attention kernel. I previously had a bug where I passed an attention mask to the function, which caused it to fall back to the "Memory-Efficient Attention" kernel. I haven't investigated this further, but other configurations, such as float32 inputs, could plausibly cause the same fallback.

In [1]:
import itertools as it
import torch
from torch.utils.data import Subset

%reload_ext autoreload
%autoreload 2

from src import data, modules, pipeline, utils

In [2]:
batch_size = 16
num_heads = 12
embed_dim = 768
context_len = 1024
vocab_size = 50257
device = "cuda"

dataset = data.TinyStoriesDataset(1024, num_stories=500)
train_ds = Subset(dataset, list(range(batch_size)))
eval_ds = Subset(dataset, list(range(batch_size, 2 * batch_size)))

Tokenizing Stories: 100%|████████████████████████████████████████████████████████████████████████████| 500/500 [00:00<00:00, 1402.65 stories/s]


In [3]:
# Overfit on a single batch
for enable_bf16_amp, enable_tf32 in it.product([False, True], repeat=2):
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    model = modules.GPT2(vocab_size, embed_dim, context_len, num_heads)
    model.to(device)
    labels = []
    if enable_tf32:
        labels.append("tf32")
    if enable_bf16_amp:
        labels.append("bf16")
    if len(labels) == 0:
        labels.append("fp32")
    label = ", ".join(labels)
    pipeline.train_gpt2(
        model,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        num_epochs=50,
        batch_size=batch_size,
        device=device,
        enable_tf32=enable_tf32,
        enable_bf16_amp=enable_bf16_amp,
        label=label,
        logging_interval=10,
    )

  0%|          | 0/50 [00:00<?, ?batches/s]

fp32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.006 | Eval Loss: 9.714 | Tokens/ms: 11.53 | Avg Forward Time: 589.73 | Avg Backward Time: 830.79
fp32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.705 | Eval Loss: 7.283 | Tokens/ms: 13.68 | Avg Forward Time: 456.10 | Avg Backward Time: 741.99
fp32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.975 | Eval Loss: 6.102 | Tokens/ms: 13.60 | Avg Forward Time: 458.27 | Avg Backward Time: 746.44
fp32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.741 | Eval Loss: 5.743 | Tokens/ms: 13.54 | Avg Forward Time: 460.97 | Avg Backward Time: 749.33
fp32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.890 | Eval Loss: 5.663 | Tokens/ms: 13.49 | Avg Forward Time: 462.36 | Avg Backward Time: 751.97
fp32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.081 | Eval Loss: 5.848 | Tokens/ms: 13.46 | Avg Forward Time: 463.21 | Avg Backward Time: 754.43


  0%|          | 0/50 [00:00<?, ?batches/s]

tf32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.006 | Eval Loss: 9.714 | Tokens/ms: 29.52 | Avg Forward Time: 280.80 | Avg Backward Time: 274.21
tf32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.705 | Eval Loss: 7.284 | Tokens/ms: 31.62 | Avg Forward Time: 245.79 | Avg Backward Time: 272.29
tf32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.975 | Eval Loss: 6.103 | Tokens/ms: 31.57 | Avg Forward Time: 246.35 | Avg Backward Time: 272.66
tf32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.741 | Eval Loss: 5.743 | Tokens/ms: 31.56 | Avg Forward Time: 246.42 | Avg Backward Time: 272.77
tf32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.895 | Eval Loss: 5.666 | Tokens/ms: 31.54 | Avg Forward Time: 246.42 | Avg Backward Time: 273.05
tf32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.090 | Eval Loss: 5.841 | Tokens/ms: 31.55 | Avg Forward Time: 246.47 | Avg Backward Time: 272.89


  0%|          | 0/50 [00:00<?, ?batches/s]

bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.006 | Eval Loss: 9.714 | Tokens/ms: 27.90 | Avg Forward Time: 331.45 | Avg Backward Time: 255.75
bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.705 | Eval Loss: 7.282 | Tokens/ms: 33.20 | Avg Forward Time: 239.25 | Avg Backward Time: 254.25
bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.975 | Eval Loss: 6.090 | Tokens/ms: 33.19 | Avg Forward Time: 239.41 | Avg Backward Time: 254.29
bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.722 | Eval Loss: 5.746 | Tokens/ms: 33.16 | Avg Forward Time: 239.64 | Avg Backward Time: 254.42
bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.878 | Eval Loss: 5.661 | Tokens/ms: 33.17 | Avg Forward Time: 239.58 | Avg Backward Time: 254.33
bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.074 | Eval Loss: 5.848 | Tokens/ms: 33.15 | Avg Forward Time: 239.87 | Avg Backward Time: 254.39


  0%|          | 0/50 [00:00<?, ?batches/s]

tf32, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.006 | Eval Loss: 9.714 | Tokens/ms: 33.16 | Avg Forward Time: 238.54 | Avg Backward Time: 255.55
tf32, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.705 | Eval Loss: 7.282 | Tokens/ms: 33.20 | Avg Forward Time: 239.29 | Avg Backward Time: 254.24
tf32, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.975 | Eval Loss: 6.090 | Tokens/ms: 33.17 | Avg Forward Time: 239.57 | Avg Backward Time: 254.41
tf32, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.722 | Eval Loss: 5.746 | Tokens/ms: 33.14 | Avg Forward Time: 239.85 | Avg Backward Time: 254.49
tf32, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.878 | Eval Loss: 5.661 | Tokens/ms: 33.09 | Avg Forward Time: 240.48 | Avg Backward Time: 254.68
tf32, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.074 | Eval Loss: 5.848 | Tokens/ms: 33.09 | Avg Forward Time: 240.40 | Avg Backward Time: 254.73
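The `bf16` and `tf32, bf16` runs above report nearly the same throughput, which is what I'd expect if autocast decides the matmul dtype before the float32 precision setting comes into play. The sketch below illustrates the dtype side of that on CPU, so it runs without a GPU. I'm assuming the CUDA autocast path behaves analogously, since `matmul` is on the autocast list in the AMP docs. It doesn't measure speed.

```python
import torch

# Request reduced-precision (TF32-class) float32 matmuls where supported.
torch.set_float32_matmul_precision("high")

a = torch.randn(64, 64)
b = torch.randn(64, 64)

# Outside autocast the matmul stays in float32, so the setting above applies.
plain = a @ b

# Inside autocast the operands are cast to bfloat16 before the matmul,
# so the float32 matmul precision setting is no longer involved.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    autocast_out = a @ b

print(plain.dtype, autocast_out.dtype)

# Restore the PyTorch default so later cells are unaffected.
torch.set_float32_matmul_precision("highest")
```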


In [7]:
# Same, but with flash attention
for enable_bf16_amp, enable_tf32 in it.product([False, True], repeat=2):
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    model = modules.GPT2(
        vocab_size, embed_dim, context_len, num_heads, use_flash_attention=True
    )
    model.to(device)
    labels = ["flash_attn"]
    if enable_tf32:
        labels.append("tf32")
    if enable_bf16_amp:
        labels.append("bf16")
    if len(labels) == 1:
        labels.append("fp32")
    label = ", ".join(labels)
    pipeline.train_gpt2(
        model,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        num_epochs=50,
        batch_size=batch_size,
        device=device,
        enable_tf32=enable_tf32,
        enable_bf16_amp=enable_bf16_amp,
        label=label,
        logging_interval=10,
    )

  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, fp32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.006 | Eval Loss: 9.714 | Tokens/ms: 16.75 | Avg Forward Time: 374.49 | Avg Backward Time: 603.39
flash_attn, fp32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.705 | Eval Loss: 7.283 | Tokens/ms: 16.76 | Avg Forward Time: 376.10 | Avg Backward Time: 601.50
flash_attn, fp32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.975 | Eval Loss: 6.102 | Tokens/ms: 16.70 | Avg Forward Time: 376.96 | Avg Backward Time: 604.17
flash_attn, fp32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.741 | Eval Loss: 5.743 | Tokens/ms: 16.64 | Avg Forward Time: 378.40 | Avg Backward Time: 606.06
flash_attn, fp32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.890 | Eval Loss: 5.663 | Tokens/ms: 16.58 | Avg Forward Time: 379.65 | Avg Backward Time: 608.60
flash_attn, fp32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.081 | Eval Loss: 5.848 | Tokens/ms: 16.58 | Avg Forward Time: 379.83 | Avg Backward Time: 608.62


  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, tf32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.006 | Eval Loss: 9.714 | Tokens/ms: 39.12 | Avg Forward Time: 197.02 | Avg Backward Time: 221.84
flash_attn, tf32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.706 | Eval Loss: 7.284 | Tokens/ms: 38.88 | Avg Forward Time: 201.00 | Avg Backward Time: 220.36
flash_attn, tf32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.975 | Eval Loss: 6.103 | Tokens/ms: 38.83 | Avg Forward Time: 201.26 | Avg Backward Time: 220.72
flash_attn, tf32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.741 | Eval Loss: 5.743 | Tokens/ms: 38.84 | Avg Forward Time: 200.92 | Avg Backward Time: 220.91
flash_attn, tf32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.895 | Eval Loss: 5.667 | Tokens/ms: 38.87 | Avg Forward Time: 200.78 | Avg Backward Time: 220.70
flash_attn, tf32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.090 | Eval Loss: 5.841 | Tokens/ms: 38.86 | Avg Forward Time: 200.75 | Avg Backward Time: 220.83


  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.007 | Eval Loss: 9.714 | Tokens/ms: 49.24 | Avg Forward Time: 168.74 | Avg Backward Time: 163.97
flash_attn, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.704 | Eval Loss: 7.282 | Tokens/ms: 49.02 | Avg Forward Time: 171.63 | Avg Backward Time: 162.57
flash_attn, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.975 | Eval Loss: 6.086 | Tokens/ms: 48.98 | Avg Forward Time: 171.83 | Avg Backward Time: 162.65
flash_attn, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.717 | Eval Loss: 5.747 | Tokens/ms: 49.00 | Avg Forward Time: 171.82 | Avg Backward Time: 162.54
flash_attn, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.871 | Eval Loss: 5.653 | Tokens/ms: 49.01 | Avg Forward Time: 171.81 | Avg Backward Time: 162.50
flash_attn, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.067 | Eval Loss: 5.851 | Tokens/ms: 48.97 | Avg Forward Time: 171.98 | Avg Backward Time: 162.57


  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, tf32, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.007 | Eval Loss: 9.714 | Tokens/ms: 49.21 | Avg Forward Time: 169.07 | Avg Backward Time: 163.84
flash_attn, tf32, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.705 | Eval Loss: 7.282 | Tokens/ms: 49.01 | Avg Forward Time: 171.86 | Avg Backward Time: 162.45
flash_attn, tf32, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.973 | Eval Loss: 6.068 | Tokens/ms: 48.97 | Avg Forward Time: 171.98 | Avg Backward Time: 162.59
flash_attn, tf32, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.682 | Eval Loss: 5.758 | Tokens/ms: 48.95 | Avg Forward Time: 172.10 | Avg Backward Time: 162.58
flash_attn, tf32, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.822 | Eval Loss: 5.660 | Tokens/ms: 48.98 | Avg Forward Time: 171.90 | Avg Backward Time: 162.61
flash_attn, tf32, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 2.993 | Eval Loss: 5.859 | Tokens/ms: 48.97 | Avg Forward Time: 171.97 | Avg B

In [4]:
# Same, but with torch.compile
for enable_bf16_amp, enable_tf32 in it.product([False, True], repeat=2):
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    model = modules.GPT2(vocab_size, embed_dim, context_len, num_heads)
    model.to(device)
    model = torch.compile(model)
    labels = ["torch.compile"]
    if enable_tf32:
        labels.append("tf32")
    if enable_bf16_amp:
        labels.append("bf16")
    if len(labels) == 1:
        labels.append("fp32")
    label = ", ".join(labels)
    pipeline.train_gpt2(
        model,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        num_epochs=50,
        batch_size=batch_size,
        device=device,
        enable_tf32=enable_tf32,
        enable_bf16_amp=enable_bf16_amp,
        label=label,
        logging_interval=10,
    )

  0%|          | 0/50 [00:00<?, ?batches/s]



torch.compile, fp32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.006 | Eval Loss: 9.714 | Tokens/ms: 0.74 | Avg Forward Time: 14797.34 | Avg Backward Time: 7258.47
torch.compile, fp32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.705 | Eval Loss: 7.283 | Tokens/ms: 16.96 | Avg Forward Time: 309.42 | Avg Backward Time: 656.58
torch.compile, fp32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.975 | Eval Loss: 6.102 | Tokens/ms: 16.88 | Avg Forward Time: 311.67 | Avg Backward Time: 658.91
torch.compile, fp32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.741 | Eval Loss: 5.743 | Tokens/ms: 16.77 | Avg Forward Time: 314.33 | Avg Backward Time: 662.72
torch.compile, fp32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.890 | Eval Loss: 5.663 | Tokens/ms: 16.72 | Avg Forward Time: 314.42 | Avg Backward Time: 665.24
torch.compile, fp32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.081 | Eval Loss: 5.848 | Tokens/ms: 16.70 | Avg Forward Time: 315.24 | Avg Backward Time: 66

  0%|          | 0/50 [00:00<?, ?batches/s]

torch.compile, tf32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.006 | Eval Loss: 9.714 | Tokens/ms: 0.81 | Avg Forward Time: 13226.05 | Avg Backward Time: 6946.94
torch.compile, tf32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.706 | Eval Loss: 7.284 | Tokens/ms: 63.84 | Avg Forward Time: 86.11 | Avg Backward Time: 170.54
torch.compile, tf32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.975 | Eval Loss: 6.104 | Tokens/ms: 63.75 | Avg Forward Time: 86.22 | Avg Backward Time: 170.77
torch.compile, tf32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.741 | Eval Loss: 5.743 | Tokens/ms: 63.72 | Avg Forward Time: 86.24 | Avg Backward Time: 170.89
torch.compile, tf32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.896 | Eval Loss: 5.667 | Tokens/ms: 63.62 | Avg Forward Time: 86.39 | Avg Backward Time: 171.14
torch.compile, tf32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.091 | Eval Loss: 5.841 | Tokens/ms: 63.55 | Avg Forward Time: 86.40 | Avg Backward Time: 171.41


  0%|          | 0/50 [00:00<?, ?batches/s]

torch.compile, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.006 | Eval Loss: 9.714 | Tokens/ms: 0.66 | Avg Forward Time: 15459.08 | Avg Backward Time: 9517.88
torch.compile, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.705 | Eval Loss: 7.278 | Tokens/ms: 94.35 | Avg Forward Time: 75.21 | Avg Backward Time: 98.44
torch.compile, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.969 | Eval Loss: 6.037 | Tokens/ms: 94.01 | Avg Forward Time: 75.56 | Avg Backward Time: 98.71
torch.compile, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.598 | Eval Loss: 5.717 | Tokens/ms: 93.86 | Avg Forward Time: 75.91 | Avg Backward Time: 98.65
torch.compile, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.679 | Eval Loss: 5.738 | Tokens/ms: 93.82 | Avg Forward Time: 75.80 | Avg Backward Time: 98.84
torch.compile, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 2.856 | Eval Loss: 5.896 | Tokens/ms: 93.60 | Avg Forward Time: 76.09 | Avg Backward Time: 98.95


  0%|          | 0/50 [00:00<?, ?batches/s]

torch.compile, tf32, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.006 | Eval Loss: 9.714 | Tokens/ms: 0.72 | Avg Forward Time: 14580.64 | Avg Backward Time: 8195.06
torch.compile, tf32, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.705 | Eval Loss: 7.277 | Tokens/ms: 94.51 | Avg Forward Time: 75.25 | Avg Backward Time: 98.10
torch.compile, tf32, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.969 | Eval Loss: 6.031 | Tokens/ms: 94.17 | Avg Forward Time: 75.56 | Avg Backward Time: 98.42
torch.compile, tf32, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.590 | Eval Loss: 5.751 | Tokens/ms: 94.07 | Avg Forward Time: 75.89 | Avg Backward Time: 98.27
torch.compile, tf32, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.717 | Eval Loss: 5.716 | Tokens/ms: 93.84 | Avg Forward Time: 76.00 | Avg Backward Time: 98.59
torch.compile, tf32, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 2.876 | Eval Loss: 5.891 | Tokens/ms: 93.69 | Avg Forward Time: 75
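The epoch-0 lines of the `torch.compile` runs include compilation, which is why they report under 1 token/ms. When timing compiled models outside this pipeline, it helps to discard a few warm-up steps first. Below is a minimal sketch of such a timer. `mean_step_time` and its `sync` parameter are hypothetical names and are not part of `src.pipeline`. On a GPU, `sync` would typically be `torch.cuda.synchronize`, so that queued kernels are included in the measurement.

```python
import time


def mean_step_time(step_fn, num_steps, warmup_steps=1, sync=None):
    """Average wall-clock seconds per call of step_fn, skipping warm-up calls."""
    # Warm-up calls absorb one-off costs such as compilation.
    for _ in range(warmup_steps):
        step_fn()
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing starts

    start = time.perf_counter()
    for _ in range(num_steps):
        step_fn()
    if sync is not None:
        sync()  # wait for asynchronous work before stopping the clock
    return (time.perf_counter() - start) / num_steps


# Stand-in workload so the sketch runs anywhere; replace it with a real
# training step.
def dummy_step():
    sum(i * i for i in range(10_000))


print(f"{mean_step_time(dummy_step, num_steps=20, warmup_steps=3) * 1e3:.3f} ms/step")
```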

In [5]:
# Same, but with torch.compile & flash attention
for enable_bf16_amp, enable_tf32 in it.product([False, True], repeat=2):
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    model = modules.GPT2(
        vocab_size, embed_dim, context_len, num_heads, use_flash_attention=True
    )
    model.to(device)
    model = torch.compile(model)
    labels = ["flash_attn", "torch.compile"]
    if enable_tf32:
        labels.append("tf32")
    if enable_bf16_amp:
        labels.append("bf16")
    if len(labels) == 2:
        labels.append("fp32")
    label = ", ".join(labels)
    pipeline.train_gpt2(
        model,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        num_epochs=50,
        batch_size=batch_size,
        device=device,
        enable_tf32=enable_tf32,
        enable_bf16_amp=enable_bf16_amp,
        label=label,
        logging_interval=10,
    )

  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, torch.compile, fp32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.006 | Eval Loss: 9.714 | Tokens/ms: 1.03 | Avg Forward Time: 10227.33 | Avg Backward Time: 5682.05
flash_attn, torch.compile, fp32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.705 | Eval Loss: 7.283 | Tokens/ms: 19.90 | Avg Forward Time: 264.31 | Avg Backward Time: 558.82
flash_attn, torch.compile, fp32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.975 | Eval Loss: 6.102 | Tokens/ms: 19.79 | Avg Forward Time: 265.91 | Avg Backward Time: 562.14
flash_attn, torch.compile, fp32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.741 | Eval Loss: 5.743 | Tokens/ms: 19.71 | Avg Forward Time: 266.83 | Avg Backward Time: 564.55
flash_attn, torch.compile, fp32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.890 | Eval Loss: 5.663 | Tokens/ms: 19.64 | Avg Forward Time: 267.94 | Avg Backward Time: 566.39
flash_attn, torch.compile, fp32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.081 | Eval Loss: 5.8

  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, torch.compile, tf32 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.006 | Eval Loss: 9.714 | Tokens/ms: 1.13 | Avg Forward Time: 9128.33 | Avg Backward Time: 5363.55
flash_attn, torch.compile, tf32 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.705 | Eval Loss: 7.284 | Tokens/ms: 65.18 | Avg Forward Time: 79.35 | Avg Backward Time: 172.02
flash_attn, torch.compile, tf32 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.975 | Eval Loss: 6.104 | Tokens/ms: 64.78 | Avg Forward Time: 79.45 | Avg Backward Time: 173.48
flash_attn, torch.compile, tf32 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.741 | Eval Loss: 5.743 | Tokens/ms: 64.85 | Avg Forward Time: 79.37 | Avg Backward Time: 173.29
flash_attn, torch.compile, tf32 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.896 | Eval Loss: 5.667 | Tokens/ms: 64.80 | Avg Forward Time: 79.62 | Avg Backward Time: 173.21
flash_attn, torch.compile, tf32 | Epoch   49 | Minibatch    0 | Avg Train Loss: 3.091 | Eval Loss: 5.841 | 

  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, torch.compile, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.006 | Eval Loss: 9.714 | Tokens/ms: 0.90 | Avg Forward Time: 10552.16 | Avg Backward Time: 7608.76
flash_attn, torch.compile, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.705 | Eval Loss: 7.278 | Tokens/ms: 116.19 | Avg Forward Time: 59.02 | Avg Backward Time: 81.99
flash_attn, torch.compile, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.967 | Eval Loss: 6.056 | Tokens/ms: 115.77 | Avg Forward Time: 59.13 | Avg Backward Time: 82.40
flash_attn, torch.compile, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.636 | Eval Loss: 5.765 | Tokens/ms: 116.06 | Avg Forward Time: 59.00 | Avg Backward Time: 82.17
flash_attn, torch.compile, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.681 | Eval Loss: 5.697 | Tokens/ms: 115.54 | Avg Forward Time: 59.39 | Avg Backward Time: 82.42
flash_attn, torch.compile, bf16 | Epoch   49 | Minibatch    0 | Avg Train Loss: 2.819 | Eval Loss: 5.916 |

  0%|          | 0/50 [00:00<?, ?batches/s]

flash_attn, torch.compile, tf32, bf16 | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.006 | Eval Loss: 9.714 | Tokens/ms: 0.91 | Avg Forward Time: 10367.14 | Avg Backward Time: 7734.33
flash_attn, torch.compile, tf32, bf16 | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.705 | Eval Loss: 7.278 | Tokens/ms: 116.29 | Avg Forward Time: 58.94 | Avg Backward Time: 81.95
flash_attn, torch.compile, tf32, bf16 | Epoch   20 | Minibatch    0 | Avg Train Loss: 5.969 | Eval Loss: 6.038 | Tokens/ms: 115.92 | Avg Forward Time: 59.36 | Avg Backward Time: 81.97
flash_attn, torch.compile, tf32, bf16 | Epoch   30 | Minibatch    0 | Avg Train Loss: 4.607 | Eval Loss: 5.752 | Tokens/ms: 116.20 | Avg Forward Time: 59.17 | Avg Backward Time: 81.83
flash_attn, torch.compile, tf32, bf16 | Epoch   40 | Minibatch    0 | Avg Train Loss: 3.701 | Eval Loss: 5.716 | Tokens/ms: 115.92 | Avg Forward Time: 59.27 | Avg Backward Time: 82.07
flash_attn, torch.compile, tf32, bf16 | Epoch   49 | Minibatch    0 | Avg Tr

In [6]:
# torch.compile, flash attention, and vocab_size rounded to a multiple of 64
torch.manual_seed(42)
torch.cuda.manual_seed(42)
model = modules.GPT2(
    utils.round_to_multiple(vocab_size, 64), embed_dim, context_len, num_heads, use_flash_attention=True
)
model.to(device)
model = torch.compile(model)
labels = ["rounded_vocab", "flash_attn", "torch.compile"]
label = ", ".join(labels)
pipeline.train_gpt2(
    model,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    num_epochs=50,
    batch_size=batch_size,
    device=device,
    enable_tf32=True,
    enable_bf16_amp=True,
    label=label,
    logging_interval=10,
)

  0%|          | 0/50 [00:00<?, ?batches/s]

rounded_vocab, flash_attn, torch.compile | Epoch    0 | Minibatch    0 | Avg Train Loss: 11.002 | Eval Loss: 9.692 | Tokens/ms: 0.89 | Avg Forward Time: 10446.84 | Avg Backward Time: 7903.99
rounded_vocab, flash_attn, torch.compile | Epoch   10 | Minibatch    0 | Avg Train Loss: 8.606 | Eval Loss: 7.465 | Tokens/ms: 124.38 | Avg Forward Time: 57.40 | Avg Backward Time: 74.32
rounded_vocab, flash_attn, torch.compile | Epoch   20 | Minibatch    0 | Avg Train Loss: 6.363 | Eval Loss: 6.410 | Tokens/ms: 124.77 | Avg Forward Time: 57.19 | Avg Backward Time: 74.12
rounded_vocab, flash_attn, torch.compile | Epoch   30 | Minibatch    0 | Avg Train Loss: 5.270 | Eval Loss: 5.925 | Tokens/ms: 124.50 | Avg Forward Time: 57.42 | Avg Backward Time: 74.17
rounded_vocab, flash_attn, torch.compile | Epoch   40 | Minibatch    0 | Avg Train Loss: 4.513 | Eval Loss: 5.751 | Tokens/ms: 124.33 | Avg Forward Time: 57.47 | Avg Backward Time: 74.31
rounded_vocab, flash_attn, torch.compile | Epoch   49 | Minib
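`utils.round_to_multiple` isn't shown in this notebook. A plausible implementation, assuming it rounds up so that every token id still has an embedding row, could look like the sketch below. The name `round_up_to_multiple` is mine.

```python
def round_up_to_multiple(n: int, multiple: int) -> int:
    """Return the smallest multiple of `multiple` that is >= n."""
    if multiple <= 0:
        raise ValueError("multiple must be positive")
    return ((n + multiple - 1) // multiple) * multiple


vocab_size = 50257
padded_vocab_size = round_up_to_multiple(vocab_size, 64)
print(padded_vocab_size)  # 50304
```

The extra ids are never produced by the tokenizer. They only pad the embedding and output-projection matrices to a size that GPU kernels tend to handle more efficiently.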