Gradient Accumulation
In many situations we want to train with a large batch size (the desired batch size), but our GPU can only fit a smaller batch in memory (the tolerable batch size). One option is to use multiple GPUs with distributed data-parallel training. But what if only one GPU is available? The solution is gradient accumulation.


- Gradient accumulation (summation) means performing **multiple** backward passes **before** updating the parameters. The goal is to keep the same model parameters across several inputs (batches) and then update them based on all of those batches, instead of updating after every single batch. So we run each batch of the tolerable size individually with the same model parameters and compute its gradients without updating the model. Once the number of samples processed reaches the desired batch size, we update the parameters using the accumulated gradients.

- A common point of confusion: calling .backward() automatically frees the **computational graph** (unless retain_graph=True is specified), but **NOT** the gradients. The gradients stay stored on the parameters and keep accumulating until they are reset explicitly, for example by calling optimizer.zero_grad().
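The difference between the graph and the gradients can be checked on a toy parameter. The following is a minimal sketch, not part of the original notebook: the names w, x1 and x2 and the SGD optimizer are illustrative assumptions chosen to keep the example small.

```python
import torch

# Toy parameter and inputs; these names are illustrative only.
w = torch.tensor([1.0, 2.0], requires_grad=True)
x1 = torch.tensor([3.0, 4.0])
x2 = torch.tensor([5.0, 6.0])
optimizer = torch.optim.SGD([w], lr=0.1)

# First backward pass: d(w . x1)/dw = x1, stored in w.grad.
loss1 = (w * x1).sum()
loss1.backward()
grad_after_first = w.grad.clone()

# Second backward pass on a fresh graph: the new gradient is ADDED to w.grad.
loss2 = (w * x2).sum()
loss2.backward()
grad_after_second = w.grad.clone()

print(grad_after_first)   # gradient from the first batch only
print(grad_after_second)  # sum of the gradients from both batches

# The gradients are cleared only when we ask for it. Depending on the
# PyTorch version, zero_grad() either fills .grad with zeros or sets it to None.
optimizer.zero_grad()
print(w.grad)
```

Note that no optimizer.step() is called between the two backward passes, so both gradients are computed with the same parameter values, which is exactly the situation gradient accumulation relies on.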


In [None]:
import torch
import torch.nn as nn
import torchvision

In [None]:
model = torchvision.models.resnet101()
num_iterations = 10
xe = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

In [None]:
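batch_size = 50
# Baseline: one parameter update per batch.
for i in range(num_iterations):
  inputs = torch.randn(batch_size, 3, 224, 224)
  labels = torch.randint(0, 100, (batch_size,))  # random class indices in [0, 100)
  loss = xe(model(inputs), labels)
  loss.backward()        # compute the gradients for this batch
  optimizer.step()       # update the parameters
  optimizer.zero_grad()  # reset the gradients before the next batch
  print('One batch')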

One batch
One batch
One batch
One batch
One batch
One batch
One batch
One batch
One batch
One batch


In [None]:
desired_batch_size = 100
tolerable_batch_size = 50
# Integer division keeps accum_steps an int (plain / would give the float 2.0).
accum_steps = desired_batch_size // tolerable_batch_size

In [None]:
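for i in range(num_iterations):
  inputs = torch.randn(tolerable_batch_size, 3, 224, 224)
  labels = torch.randint(0, 100, (tolerable_batch_size,))  # random class indices in [0, 100)
  loss = xe(model(inputs), labels)
  # Scale the loss so the summed gradients average over the desired batch size.
  loss = loss / accum_steps
  loss.backward()  # gradients are added to those already stored on the parameters

  # Update every accum_steps batches, and on the last iteration to flush any leftover gradients.
  if ((i + 1) % accum_steps == 0) or ((i + 1) == num_iterations):
    optimizer.step()
    optimizer.zero_grad()
    print('One Batch')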
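The division of the loss by accum_steps is what makes the accumulated gradient match the gradient of one large batch: CrossEntropyLoss averages over the batch by default, so summing the gradients of several unscaled batch losses would give a gradient accum_steps times too large. The sketch below checks this numerically. It is an illustration under stated assumptions: a small nn.Linear layer stands in for the ResNet so the check runs quickly, and the batch is split into equally sized chunks.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Small stand-in model (an assumption for speed); sizes are arbitrary.
model = nn.Linear(8, 4)
xe = nn.CrossEntropyLoss()  # default reduction='mean'

inputs = torch.randn(6, 8)
labels = torch.randint(0, 4, (6,))

# Reference: gradient of the mean loss over the full batch of 6 samples.
model.zero_grad()
xe(model(inputs), labels).backward()
full_grad = model.weight.grad.clone()

# Accumulated: two chunks of 3 samples, each loss divided by accum_steps.
accum_steps = 2
model.zero_grad()
for x_chunk, y_chunk in zip(inputs.chunk(accum_steps), labels.chunk(accum_steps)):
    loss = xe(model(x_chunk), y_chunk) / accum_steps
    loss.backward()  # adds this chunk's gradient to model.weight.grad
accum_grad = model.weight.grad.clone()

# The two gradients should agree up to floating-point error.
matches = torch.allclose(full_grad, accum_grad, atol=1e-6)
print(matches)
```

The match is exact only up to floating-point error and only when the model treats each sample independently. Layers that compute statistics over the batch, such as the BatchNorm layers in ResNet, still see the smaller tolerable batch size, so in that case accumulation approximates training with the desired batch size rather than reproducing it.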