# WTI Tutorial: Using Automatic Mixed Precision and Model Compilation in PyTorch

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import time
from torch.utils.data import DataLoader, TensorDataset
from torch.amp import autocast, GradScaler

In [2]:

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dummy dataset
dim = 512
sequence_length = 128
batch_size = 32
num_batches = 100

dummy_data = torch.randn(num_batches * batch_size, sequence_length, dim).to(device)
dummy_target = torch.randint(0, 2, (num_batches * batch_size, sequence_length)).to(device)

dataset = TensorDataset(dummy_data, dummy_target)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Transformer model
nlayers, nheads = 8, 8
model = nn.Sequential(*[nn.TransformerEncoderLayer(d_model=dim, nhead=nheads) for _ in range(nlayers)])
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

## Automatic Mixed Precision

In [3]:
# Test training time with and without AMP

def eval_training_time(use_amp=False):
    model.train()
    start_time = time.time()
    for data, target in dataloader:
        optimizer.zero_grad()
        with autocast(device_type=device.type, enabled=use_amp, dtype=torch.bfloat16):
            output = model(data)
            loss = criterion(output.view(-1, output.size(-1)), target.view(-1))
        loss.backward()
    end_time = time.time()
    run_time = end_time - start_time
    toks_per_sec = num_batches * batch_size * sequence_length / run_time
    return run_time, toks_per_sec

In [4]:
# Measure training time and tokens per second
time_standard, toks_per_sec_standard = eval_training_time(use_amp=False)
time_amp, toks_per_sec_amp = eval_training_time(use_amp=True)

print(f"Training time without AMP: {time_standard:.2f} seconds ({toks_per_sec_standard:.2f} tokens/s)")
print(f"Training time with AMP (bfloat16): {time_amp:.2f} seconds ({toks_per_sec_amp:.2f} tokens/s)")
print(f"Speedup: {time_standard / time_amp:.2f}x")

Training time without AMP: 4.66 seconds (87840.74 tokens/s)
Training time with AMP (bfloat16): 1.39 seconds (294713.14 tokens/s)
Speedup: 3.36x


## Model compilation with `torch.compile`

In [5]:
compiled_model = torch.compile(model)

In [6]:
def eval_training_time(use_amp=False, compile_model=False):
    if compile_model:
        model_ = compiled_model
    else:
        model_ = model

    model_.train()
    start_time = time.time()
    for data, target in dataloader:
        optimizer.zero_grad()
        with autocast(device_type=device.type, enabled=use_amp, dtype=torch.bfloat16):
            output = model_(data)
            loss = criterion(output.view(-1, output.size(-1)), target.view(-1))
        loss.backward()
    end_time = time.time()
    run_time = end_time - start_time
    toks_per_sec = num_batches * batch_size * sequence_length / run_time
    return run_time, toks_per_sec

In [7]:
# Measure training time and tokens per second

# run compiled_model once to remove overhead
_, _ = eval_training_time(use_amp=False, compile_model=True)
_, _ = eval_training_time(use_amp=True, compile_model=True)

time_standard, toks_per_sec_standard = eval_training_time(use_amp=False, compile_model=False)
time_amp, toks_per_sec_amp = eval_training_time(use_amp=True, compile_model=False)
time_comp, toks_per_sec_comp = eval_training_time(use_amp=False, compile_model=True)
time_comp_amp, toks_per_sec_comp_amp = eval_training_time(use_amp=True, compile_model=True)

print(f"Training time without Compilation or AMP: {time_standard:.2f} seconds ({toks_per_sec_standard:.2f} tokens/s)")
print(f"Training time with AMP but no Compilation: {time_amp:.2f} seconds ({toks_per_sec_amp:.2f} tokens/s)")
print(f"Training time with Compilation but no AMP: {time_comp:.2f} seconds ({toks_per_sec_comp:.2f} tokens/s)")
print(f"Training time with AMP and Compilation: {time_comp_amp:.2f} seconds ({toks_per_sec_comp_amp:.2f} tokens/s)")
print()
print(f"Speedup with AMP: {time_standard / time_amp:.2f}x")
print(f"Speedup with Compilation: {time_standard / time_comp:.2f}x")
print(f"Speedup with AMP and Compilation: {time_standard / time_comp_amp:.2f}x")



Training time without Compilation or AMP: 4.58 seconds (89409.98 tokens/s)
Training time with AMP but no Compilation: 1.29 seconds (316841.28 tokens/s)
Training time with Compilation but no AMP: 4.25 seconds (96439.70 tokens/s)
Training time with AMP and Compilation: 1.00 seconds (408656.09 tokens/s)

Speedup with AMP: 3.54x
Speedup with Compilation: 1.08x
Speedup with AMP and Compilation: 4.57x
