In [1]:
import torch
import torch.nn as nn
import torch.optim as optim


In [57]:
class SimpleModel(nn.Module):
    def __init__(self, vocab, num_classes, dim=4, dtype=torch.float32):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab, dim, dtype=dtype)
        self.classifier = torch.nn.Linear(dim, num_classes, dtype=dtype)

    def forward(self, x):
        x = self.embedding(x)
        x = torch.mean(x, dim=-2)
        return self.classifier(x)
    

def test_train(
    weight_dtype=torch.float32,
    use_amp=False,
):
    vocab = 10
    num_classes = 3
    n_samples = 50
    input_seq = 6
    epochs = 4
    lr = 1e-4

    train_x = torch.randint(0, vocab, [n_samples, input_seq], dtype=torch.long)
    train_y = torch.randint(0, num_classes, [n_samples], dtype=torch.long)

    model = SimpleModel(vocab, num_classes, dtype=weight_dtype)
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, )
    scaler = torch.amp.GradScaler() if use_amp else None

    for epoch in range(epochs):
        with torch.autocast('cpu', dtype=torch.float16, enabled=use_amp):
            logits = model(train_x)
            loss = criterion(logits, train_y)
        print(f"epoch {epoch}: train loss {loss.item()}")

        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        optimizer.zero_grad()


test_train()

epoch 0: train loss 1.1121506690979004
epoch 1: train loss 1.1121093034744263
epoch 2: train loss 1.1120681762695312
epoch 3: train loss 1.1120266914367676


In [59]:
test_train(weight_dtype=torch.float16)

epoch 0: train loss 1.2509765625
epoch 1: train loss nan
epoch 2: train loss nan
epoch 3: train loss nan


In [60]:
# Unfortunetely Automatic Mixed Precision doesn't work on CPU
test_train(weight_dtype=torch.float16, use_amp=True)

epoch 0: train loss 1.1404216289520264
epoch 1: train loss nan
epoch 2: train loss nan
epoch 3: train loss nan


