More details [here](https://pytorch.org/docs/stable/amp.html), [here](https://pytorch.org/docs/stable/notes/amp_examples.html).

In [1]:
import torch
from torch import nn
import torchvision
import numpy as np
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

from torch.cuda.amp import GradScaler, autocast

In [4]:
BS = 1048

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=BS, shuffle=True)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
valid_loader = torch.utils.data.DataLoader(
    testset, batch_size=BS, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


## Training AMP

In [6]:
%%time

model = models.resnet18(pretrained=True).cuda(0)
model.fc = torch.nn.Linear(in_features=512, out_features=10).cuda(0)
model.fc.requires_grad_(True)

LR = 0.0003
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
cost_function = torch.nn.CrossEntropyLoss()
TRAIN_STEP = len(trainset)/BS
VALID_STEP = len(testset)/BS
EPOCHS = 5

TRAIN_LOSS = []
VAL_LOSS = []
VAL_ACC = []

scaler = GradScaler()

for epoch in range(EPOCHS):
  print("Epoch", epoch+1)
  running_loss = 0.0
  model.train()
  for (x, y) in train_loader:

    optimizer.zero_grad()
    x, y = x.cuda(0), y.cuda(0)

    with autocast():
      z = model(x)
      loss = cost_function(z, y)
      running_loss+=loss.detach()

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    
  train_l = running_loss/TRAIN_STEP
  TRAIN_LOSS.append(train_l.item())
  print("Training loss:", train_l.item())
  running_loss = 0.0
  correct = 0.0
  model.eval()    

  for (x,y) in valid_loader:
    x, y = x.cuda(0), y.cuda(0)
    z = model(x)
    loss = cost_function(z, y)
    running_loss+=loss.detach()
    _, yhat = torch.max(z,1)
    correct += torch.eq(yhat, y).sum().item()

  valid_l = running_loss/VALID_STEP
  valid_a = (correct / len(testset)) * 100
  print("Valid loss:", valid_l.item()) 
  VAL_LOSS.append(valid_l.item())
  print('VAL_Accuracy: %d %%' % valid_a)
  VAL_ACC.append(valid_a)
  print('')

Epoch 1
Training loss: 1.2421400547027588
Valid loss: 0.9059109091758728
VAL_Accuracy: 70 %

Epoch 2
Training loss: 0.7551761865615845
Valid loss: 0.7223936319351196
VAL_Accuracy: 76 %

Epoch 3
Training loss: 0.6292315125465393
Valid loss: 0.6798090934753418
VAL_Accuracy: 77 %

Epoch 4
Training loss: 0.5495712757110596
Valid loss: 0.6177423596382141
VAL_Accuracy: 79 %

Epoch 5
Training loss: 0.492849200963974
Valid loss: 0.6252909302711487
VAL_Accuracy: 79 %

CPU times: user 2min 15s, sys: 16.8 s, total: 2min 32s
Wall time: 2min 32s
