# Mixed Precision Training

So far, we have trained our neural networks with the `torch.float32` datatype. Some operations, such as linear layers and convolutions, can run much faster in a lower precision like `torch.float16`. Mixed precision training lets us train a neural network in which different operations run at different levels of precision.

Mixed precision training has at least two advantages:

1. Some layers run faster in `torch.float16` precision, especially on GPUs with Tensor Cores, so the whole training process can become significantly faster.
2. Tensors stored in `torch.float16` need half the memory of `torch.float32` tensors. This reduces the VRAM requirements and allows us to use a larger batch size.
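The memory advantage comes from the storage size of each element: a `torch.float16` value occupies 2 bytes, a `torch.float32` value 4 bytes. A minimal sketch, assuming a batch shaped like the MNIST batches used below (256 grayscale 28×28 images); `tensor_bytes` is a hypothetical helper. Keep in mind that AMP does not convert every tensor (the model parameters, for example, stay in `torch.float32`), so the savings in practice are smaller than this 2× ratio.

```python
import torch

def tensor_bytes(t: torch.Tensor) -> int:
    """Hypothetical helper: number of bytes occupied by the tensor's elements."""
    return t.element_size() * t.nelement()

# A batch with the same shape as one MNIST batch: 256 images, 1 channel, 28x28 pixels.
batch32 = torch.zeros(256, 1, 28, 28, dtype=torch.float32)
batch16 = batch32.to(torch.float16)

print(batch32.element_size(), batch16.element_size())  # prints: 4 2
print(tensor_bytes(batch32), tensor_bytes(batch16))    # prints: 802816 401408
```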

PyTorch provides so-called automatic mixed precision (AMP), which automatically decides which operations run in which precision. We do not have to make any of those decisions manually.
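These per-operation decisions can be observed by inspecting output dtypes inside an autocast region. This sketch runs on the CPU with `torch.bfloat16` so that it works without a GPU; the notebook itself uses `device_type="cuda"` with `torch.float16`, and the exact lists of affected operations depend on the device and the PyTorch version. Matrix multiplication is on the lower-precision list, while an operation such as elementwise addition is left alone and runs in the dtype of its inputs.

```python
import torch

a = torch.randn(8, 8)  # created as torch.float32
b = torch.randn(8, 8)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    mm_out = a @ b    # matmul is autocast to the lower precision
    add_out = a + b   # addition is not on autocast's lists, so it stays float32

mm_after = a @ b      # outside the region, the matmul runs in float32 again

# The inputs themselves are never modified by autocast.
print(a.dtype, mm_out.dtype, add_out.dtype, mm_after.dtype)
```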

We will demonstrate this performance boost using the MNIST dataset.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torchvision import transforms as T

import time

In [2]:
assert torch.cuda.is_available()

In [3]:
%%capture
train_dataset = MNIST(root="../datasets", train=True, download=True, transform=T.ToTensor())

In [4]:
train_dataloader=DataLoader(dataset=train_dataset, 
                            batch_size=256, 
                            shuffle=True, 
                            drop_last=True,
                            num_workers=2)

We use a much larger network than MNIST requires for good performance, in order to demonstrate the potential of mixed precision training. The network has 14 convolutional layers and 3 fully connected layers.

In [5]:
cfg = [[1, 32, 3, 1, 1],
       [32, 64, 3, 1, 1],
       [64, 64, 2, 2, 0],
       [64, 128, 3, 1, 1],
       [128, 128, 3, 1, 1],
       [128, 128, 3, 1, 1],
       [128, 128, 2, 2, 0],
       [128, 256, 3, 1, 1],
       [256, 256, 2, 1, 0],
       [256, 512, 3, 1, 1],
       [512, 512, 3, 1, 1],
       [512, 512, 3, 1, 1],
       [512, 512, 2, 2, 0],
       [512, 1024, 3, 1, 1],
]

class BasicBlock(nn.Module):
  def __init__(self, **kwargs):
    super().__init__()
    self.block = nn.Sequential(
        nn.Conv2d(**kwargs),
        nn.BatchNorm2d(num_features=kwargs['out_channels']),
        nn.ReLU()
    )
  
  def forward(self, x):
    return self.block(x)

class Model(nn.Module):
  def __init__(self, cfg):
    super().__init__()
    self.features = self._build_layers(cfg)
    self.avgpool = nn.AdaptiveAvgPool2d(1)
    self.classifier = nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_features=1024, out_features=1000),
        nn.ReLU(),
        nn.Linear(in_features=1000, out_features=1000),
        nn.ReLU(),
        nn.Linear(in_features=1000, out_features=10),
    )
  
  def _build_layers(self, cfg):
    layers = []
    for layer in cfg:
      layers += [BasicBlock(in_channels=layer[0],
                           out_channels=layer[1],
                           kernel_size=layer[2],
                           stride=layer[3],
                           padding=layer[4])]
    return nn.Sequential(*layers)
  
  def forward(self, x):
    x = self.features(x)
    x = self.avgpool(x)
    x = self.classifier(x)
    return x
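It is worth checking that this configuration fits 28×28 MNIST images. For a convolution with kernel size `k`, stride `s` and padding `p`, the output side length is `floor((n + 2p - k) / s) + 1`. The plain-Python sketch below (with `cfg` copied from above and a hypothetical helper `conv_out`) applies this formula layer by layer: the feature map shrinks to 3×3 with 1024 channels, which `nn.AdaptiveAvgPool2d(1)` reduces to the 1024 features expected by the first `nn.Linear` layer.

```python
# cfg rows: [in_channels, out_channels, kernel_size, stride, padding]
cfg = [[1, 32, 3, 1, 1], [32, 64, 3, 1, 1], [64, 64, 2, 2, 0],
       [64, 128, 3, 1, 1], [128, 128, 3, 1, 1], [128, 128, 3, 1, 1],
       [128, 128, 2, 2, 0], [128, 256, 3, 1, 1], [256, 256, 2, 1, 0],
       [256, 512, 3, 1, 1], [512, 512, 3, 1, 1], [512, 512, 3, 1, 1],
       [512, 512, 2, 2, 0], [512, 1024, 3, 1, 1]]

def conv_out(n, kernel_size, stride, padding):
    """Side length of a square feature map after a Conv2d layer."""
    return (n + 2 * padding - kernel_size) // stride + 1

size = 28  # MNIST images are 28x28
sizes = []
for in_ch, out_ch, k, s, p in cfg:
    size = conv_out(size, k, s, p)
    sizes.append(size)
    print(f"{in_ch:>4} -> {out_ch:>4} channels, {size}x{size}")
```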

In [6]:
NUM_EPOCHS=10
LR=0.0001
DEVICE = torch.device('cuda')

We start by training the neural network in the familiar way, measuring the time each epoch takes and the reserved GPU memory. We will use those values as a baseline.

In [7]:
def train(data_loader, model, optimizer, criterion):
  for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    losses = []
    for img, label in data_loader:
      img = img.to(DEVICE)
      label = label.to(DEVICE)
      optimizer.zero_grad()
      prediction = model(img)
      loss = criterion(prediction, label)
      losses.append(loss.item())
      loss.backward()
      optimizer.step()

    end_time = time.time()
    s = f'Epoch: {epoch+1}, ' \
      f'Loss: {sum(losses)/len(losses):.4f}, ' \
      f'Elapsed Time: {end_time-start_time:.2f}sec, ' \
      f'Reserved Memory: {torch.cuda.memory_reserved() / 2**20:.2f}MB'
    print(s)
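`train` measures wall-clock time with `time.time()`. CUDA kernels are launched asynchronously, so a timer can stop before the GPU has finished its work. In `train` this is not a problem, because `loss.item()` copies the loss to the host and therefore waits for the GPU on every iteration. For code without such a synchronization point, a helper along the lines of the hypothetical `timed` below calls `torch.cuda.synchronize()` around the measured region. Also note that `torch.cuda.memory_reserved()` reports the memory held by PyTorch's caching allocator, which can be larger than the memory occupied by tensors (`torch.cuda.memory_allocated()`).

```python
import time

import torch

def timed(fn, *args, **kwargs):
    """Hypothetical helper: run fn and return (result, elapsed_seconds).

    On a GPU, synchronize before and after so that asynchronously
    queued CUDA kernels are included in the measurement.
    """
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if use_cuda:
        torch.cuda.synchronize()
    return result, time.perf_counter() - start

x = torch.randn(512, 512)
y, seconds = timed(torch.matmul, x, x)
print(f"matmul took {seconds * 1000:.2f} ms")
```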

In [8]:
model = Model(cfg)
model.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

In [9]:
torch.cuda.empty_cache()

Each epoch takes slightly over 30 seconds to complete, and we need roughly 1 GB of VRAM.

In [10]:
train(train_dataloader, model, optimizer, criterion)

Epoch: 1, Loss: 0.2546, Elapsed Time: 33.29sec, Reserved Memory: 988.00MB
Epoch: 2, Loss: 0.0325, Elapsed Time: 30.91sec, Reserved Memory: 988.00MB
Epoch: 3, Loss: 0.0200, Elapsed Time: 31.16sec, Reserved Memory: 988.00MB
Epoch: 4, Loss: 0.0144, Elapsed Time: 31.36sec, Reserved Memory: 988.00MB
Epoch: 5, Loss: 0.0125, Elapsed Time: 31.56sec, Reserved Memory: 988.00MB
Epoch: 6, Loss: 0.0140, Elapsed Time: 31.73sec, Reserved Memory: 988.00MB
Epoch: 7, Loss: 0.0078, Elapsed Time: 31.71sec, Reserved Memory: 988.00MB
Epoch: 8, Loss: 0.0082, Elapsed Time: 31.67sec, Reserved Memory: 988.00MB
Epoch: 9, Loss: 0.0105, Elapsed Time: 31.73sec, Reserved Memory: 988.00MB
Epoch: 10, Loss: 0.0070, Elapsed Time: 31.86sec, Reserved Memory: 988.00MB


We repeat the training procedure, but this time we use mixed precision training. For that we use the `torch.amp` module (automatic mixed precision). See the [official documentation](https://pytorch.org/docs/stable/amp.html) if you need more information.

- The `torch.amp.autocast` context manager runs the code inside its `with` block in mixed precision. In our case, the forward pass and the loss calculation run under autocast.

- We use a `torch.cuda.amp.GradScaler` object to scale the loss and, with it, the gradients. Backward operations run in the same precision that autocast chose for the corresponding forward operations, so if the forward pass of a layer uses 16-bit precision, so does its backward pass. Some gradients are so small that `torch.float16` cannot represent them: they underflow to zero. To prevent this, the loss is multiplied by a scale factor before backpropagation, which scales the gradients up by the same factor. `scaler.step(optimizer)` then unscales the gradients and performs the optimizer step (skipping it if the gradients contain `inf` or `NaN`), and `scaler.update()` adjusts the scale factor for the next iteration.

The following three lines do exactly that:
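```
scaler.scale(loss).backward()  # backpropagate the scaled loss
scaler.step(optimizer)         # unscale the gradients, then step (skipped on inf/NaN)
scaler.update()                # adjust the scale factor for the next iteration
```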
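The underflow problem, and why scaling helps, can be reproduced without a GPU. Python's `struct` module can round-trip a number through the IEEE 754 half-precision format (the `'e'` format code), the same 16-bit format as `torch.float16`. The sketch below uses a hypothetical helper `to_fp16` and a scale of `2**16`, which matches `GradScaler`'s default initial scale; it is a toy illustration of the idea, not how `GradScaler` is implemented.

```python
import struct

def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

grad = 1e-8         # a tiny gradient, below float16's smallest subnormal (~6e-8)
scale = 2.0 ** 16   # loss scale factor

unscaled = to_fp16(grad)                # underflows to 0.0: the gradient is lost
scaled = to_fp16(grad * scale) / scale  # scale up, store in fp16, divide back in full precision

print(unscaled)  # prints: 0.0
print(scaled)    # close to 1e-8 (within float16's ~0.05% relative rounding error)
```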

In [11]:
def optimized_train(data_loader, model, optimizer, criterion):
  scaler = torch.cuda.amp.GradScaler()
  for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    losses = []
    for img, label in data_loader:
      img = img.to(DEVICE)
      label = label.to(DEVICE)
      optimizer.zero_grad()
      with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
        prediction = model(img)
        loss = criterion(prediction, label)
      losses.append(loss.item())
      scaler.scale(loss).backward()
      scaler.step(optimizer)
      scaler.update()

    end_time = time.time()
    s = f'Epoch: {epoch+1}, ' \
      f'Loss: {sum(losses)/len(losses):.4f}, ' \
      f'Elapsed Time: {end_time-start_time:.2f}sec, ' \
      f'Reserved Memory: {torch.cuda.memory_reserved() / 2**20:.2f}MB'
    print(s)
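`scaler.update()` is what makes the scaling dynamic. According to the PyTorch AMP documentation, `GradScaler` starts from a large scale (`2**16` by default). Whenever the scaled gradients contain `inf` or `NaN`, the optimizer step is skipped and the scale is multiplied by `backoff_factor` (0.5 by default); after `growth_interval` (2000 by default) consecutive steps without such values, the scale is multiplied by `growth_factor` (2.0 by default). The hypothetical `ToyLossScaler` below is a simplified pure-Python model of that policy acting on a list of gradient values. It only illustrates the logic and is not a replacement for `GradScaler`.

```python
import math

class ToyLossScaler:
    """Hypothetical, simplified model of GradScaler's dynamic loss scaling."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0      # consecutive steps without inf/NaN
        self._found_inf = False

    def unscale(self, scaled_grads):
        """Return the unscaled gradients, or None if any is inf/NaN (skip the step)."""
        self._found_inf = any(not math.isfinite(g) for g in scaled_grads)
        if self._found_inf:
            return None
        return [g / self.scale for g in scaled_grads]

    def update(self):
        """Shrink the scale after an overflow, grow it after enough clean steps."""
        if self._found_inf:
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                self.scale *= self.growth_factor
                self._good_steps = 0

scaler = ToyLossScaler(growth_interval=2)  # short interval so growth shows up quickly
history = []
for scaled_grads in ([65536.0], [float("inf")], [32768.0], [32768.0]):
    grads = scaler.unscale(scaled_grads)   # None means: skip optimizer.step()
    scaler.update()
    history.append((grads, scaler.scale))
print(history)
```

The overflow in the second step halves the scale, and two clean steps later it doubles again, so the scale keeps hovering near the largest value that does not overflow.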

In [12]:
model = Model(cfg)
model.to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

In [13]:
torch.cuda.empty_cache()

We improve the training speed by a factor of more than 2 (roughly 14 seconds instead of roughly 32 seconds per epoch) and reduce the reserved memory from 988 MB to 622 MB. The overhead of automatic mixed precision, a few extra lines of code, is negligible compared to its benefits.

In [14]:
optimized_train(train_dataloader, model, optimizer, criterion)

Epoch: 1, Loss: 0.2636, Elapsed Time: 13.82sec, Reserved Memory: 622.00MB
Epoch: 2, Loss: 0.0307, Elapsed Time: 13.73sec, Reserved Memory: 622.00MB
Epoch: 3, Loss: 0.0185, Elapsed Time: 14.48sec, Reserved Memory: 622.00MB
Epoch: 4, Loss: 0.0146, Elapsed Time: 13.81sec, Reserved Memory: 622.00MB
Epoch: 5, Loss: 0.0113, Elapsed Time: 13.78sec, Reserved Memory: 622.00MB
Epoch: 6, Loss: 0.0105, Elapsed Time: 14.08sec, Reserved Memory: 622.00MB
Epoch: 7, Loss: 0.0104, Elapsed Time: 13.75sec, Reserved Memory: 622.00MB
Epoch: 8, Loss: 0.0077, Elapsed Time: 13.74sec, Reserved Memory: 622.00MB
Epoch: 9, Loss: 0.0077, Elapsed Time: 13.79sec, Reserved Memory: 622.00MB
Epoch: 10, Loss: 0.0080, Elapsed Time: 13.73sec, Reserved Memory: 622.00MB
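To put numbers on the comparison, we can average the epoch times from the two training logs above. This is plain arithmetic on the logged values of these particular runs; other hardware will give different numbers.

```python
# Epoch times (seconds) and reserved memory (MB) copied from the two training logs.
fp32_times = [33.29, 30.91, 31.16, 31.36, 31.56, 31.73, 31.71, 31.67, 31.73, 31.86]
amp_times = [13.82, 13.73, 14.48, 13.81, 13.78, 14.08, 13.75, 13.74, 13.79, 13.73]
fp32_mem, amp_mem = 988.0, 622.0

mean_fp32 = sum(fp32_times) / len(fp32_times)
mean_amp = sum(amp_times) / len(amp_times)
speedup = mean_fp32 / mean_amp
mem_reduction = 1 - amp_mem / fp32_mem

print(f"float32: {mean_fp32:.2f} s/epoch, AMP: {mean_amp:.2f} s/epoch")
print(f"speedup: {speedup:.2f}x")
print(f"reserved memory reduced by {mem_reduction:.0%}")
```

With these logs the speedup comes out at roughly 2.3× and the reserved memory drops by about 37%.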
