# Backpropagation trick: applying the optimizer step during the backward pass
The experiments here require PyTorch, preferably with CUDA.

In [1]:
import torch
import torch.nn.functional as F
from torch import nn
from copy import deepcopy
from pickle import dump
from datetime import datetime, timedelta

# Run on the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fix the RNG so the random layers, inputs and targets below are reproducible
# on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

## Manual back propagation implementation
Sanity check: how exactly is the gradient computed across a single dense layer (followed by a ReLU)?

**Key takeaway:** The weight gradient is simply a sum of outer products, one per example in the batch. It can be computed from only the "error signal" (the upstream gradient multiplied by the ReLU derivative) and the layer's input activation.

In [2]:
# Layer sizes for the single-dense-layer gradient check.
batch_size = 5
n_inputs = 10
n_hidden_units = 3
n_trials = 10_000  # number of random layers / inputs / upstream gradients to check

for _ in range(n_trials):
  layer = nn.Linear(n_inputs, n_hidden_units, bias=True)
  z0 = torch.randn([batch_size, n_inputs], requires_grad=True, dtype=torch.float32)  # layer input: (batch_size, n_inputs)
  z1 = layer(z0)      # pre-activation: (batch_size, n_hidden_units)
  z2 = F.relu(z1)     # activation:     (batch_size, n_hidden_units)

  # An arbitrary gradient arriving from "above" this layer: (batch_size, n_hidden_units)
  upstream_grad = torch.randn([batch_size, n_hidden_units], dtype=torch.float32)
  z2.backward(upstream_grad, inputs=[z0] + list(layer.parameters()))

  # The manual gradients are plain arithmetic: no autograd graph needed.
  with torch.no_grad():
    relu_grad = (z1 > 0).float()               # derivative of ReLU w.r.t. its input
    error_signal = relu_grad * upstream_grad   # (batch_size, n_hidden_units)

    # Theoretical layer weight grad: a sum over the batch of the outer products
    # error_signal[b] (x) z0[b]:
    # (n_hidden_units, batch_size) @ (batch_size, n_inputs) = (n_hidden_units, n_inputs)
    weight_grad = error_signal.T @ z0

    # Theoretical bias grad: the error signal summed over the batch axis -> (n_hidden_units,)
    bias_grad = error_signal.sum(axis=0)

    # Theoretical "pass-through gradient" to the layer input:
    # (batch_size, n_hidden_units) @ (n_hidden_units, n_inputs) = (batch_size, n_inputs)
    pass_through_grad = error_signal @ layer.weight

  # allclose instead of ==: exact float equality depends on autograd and the
  # manual matmul using the same kernels / summation order.
  assert torch.allclose(layer.weight.grad, weight_grad, atol=1e-6)
  assert torch.allclose(layer.bias.grad, bias_grad, atol=1e-6)
  assert torch.allclose(z0.grad, pass_through_grad, atol=1e-6)


print(f"Weight computation: {error_signal.T.shape} @ {z0.shape} = {weight_grad.shape}")
print(f"Bias computation: {bias_grad.shape}")
print(f"Pass-through computation: {error_signal.shape} @ {layer.weight.shape} = {pass_through_grad.shape}")

Weight computation: torch.Size([3, 5]) @ torch.Size([5, 10]) = torch.Size([3, 10])
Bias computation: torch.Size([3])
Pass-through computation: torch.Size([5, 3]) @ torch.Size([3, 10]) = torch.Size([5, 10])


## Testing the fused approach
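Do the two approaches produce exactly the same result during training? If Adam normalized the gradients over the entire network, they would not. It turns out it doesn't: Adam's update is element-wise per parameter, so one optimizer per layer yields exactly the same parameters as one optimizer for the whole network (asserted below).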

In [3]:
class ANN(nn.Module):
  """Baseline three-layer MLP (hidden widths 50, 50; one output) trained with a
  single Adam optimizer over all of its parameters.

  The input width is inferred lazily (LazyLinear), so one forward pass must run
  before the parameters have their real shapes.
  """

  def __init__(self):
    super().__init__()
    self.dense_1 = nn.LazyLinear(50)
    self.dense_2 = nn.LazyLinear(50)
    self.dense_3 = nn.LazyLinear(1)
    # Created while the lazy parameters are still uninitialized; the same
    # Parameter objects are filled in on the first forward pass, so the
    # optimizer keeps tracking them.
    self.optimizer = torch.optim.Adam(self.parameters())

  def forward(self, x):
    """Return the network output for a batch `x`: two ReLU layers, then a linear head."""
    out = x
    for hidden in (self.dense_1, self.dense_2):
      out = F.relu(hidden(out))
    return self.dense_3(out)

  def backward(self, loss):
    """Backpropagate `loss`, apply one Adam update, then clear all gradients."""
    loss.backward()
    self.optimizer.step()
    self.optimizer.zero_grad()


class MemoryEfficientCopyANN(nn.Module):
  """Copy of a three-layer network whose backward pass updates each layer as
  soon as that layer's gradients are available ("fused" update).

  Every layer has its own Adam optimizer. Because Adam's update is element-wise
  per parameter, this is expected to produce the same parameters as a single
  Adam over the whole network.
  """

  def __init__(self, other):
    super().__init__()
    # Independent copies of the source network's layers (must already be
    # materialized if they are lazy).
    self.dense_1 = deepcopy(other.dense_1)
    self.dense_2 = deepcopy(other.dense_2)
    self.dense_3 = deepcopy(other.dense_3)
    # One optimizer per layer, so each layer can be stepped on its own.
    self.dense_1_optim = torch.optim.Adam(self.dense_1.parameters())
    self.dense_2_optim = torch.optim.Adam(self.dense_2.parameters())
    self.dense_3_optim = torch.optim.Adam(self.dense_3.parameters())
    # [input, hidden_1, hidden_2, output] from the most recent forward call.
    self.activations = []

  def forward(self, x):
    """Compute the output and remember the intermediate tensors for backward()."""
    hidden_1 = F.relu(self.dense_1(x))
    hidden_2 = F.relu(self.dense_2(hidden_1))
    output = self.dense_3(hidden_2)
    self.activations = [x, hidden_1, hidden_2, output]
    return output

  def backward(self, loss):
    """Backpropagate `loss` one layer at a time, from the output layer down.

    For each layer: compute the gradients of its parameters and of its input
    (the input gradient is the error signal for the layer below), apply that
    layer's Adam step, and clear its gradients before moving on. Must follow a
    forward() call, which fills `self.activations`.
    """
    z1, z2, z3 = self.activations[1], self.activations[2], self.activations[3]
    self.activations = []  # drop the stored references

    # Error signal at the output: dL/dz3.
    loss.backward(inputs=[z3], retain_graph=True)
    delta = z3.grad

    # (layer output, layer input, layer, optimizer), top layer first. The
    # bottom layer's input is the network input, which needs no gradient.
    schedule = (
      (z3, z2, self.dense_3, self.dense_3_optim),
      (z2, z1, self.dense_2, self.dense_2_optim),
      (z1, None, self.dense_1, self.dense_1_optim),
    )
    for out, inp, layer, optim in schedule:
      bottom = inp is None
      params = list(layer.parameters())
      wrt = params if bottom else [inp] + params
      # Keep the graph alive until the bottom layer; the last call frees it.
      out.backward(delta, inputs=wrt, retain_graph=not bottom)
      delta = None if bottom else inp.grad  # error signal for the next layer down
      optim.step()
      optim.zero_grad()
      # zero_grad() is expected to release (set to None) the gradients.
      assert layer.weight.grad is None
      assert layer.bias.grad is None


def net_eq(self, other):
  """Return True if two networks hold identical parameters, regardless of class.

  Parameters are compared pairwise in registration order (``.parameters()``).
  Networks with a different number of parameters, or with a pair of parameters
  of different shapes, are NOT equal.

  This is a free function rather than an ``__eq__`` override on purpose:
  defining ``__eq__`` on a class sets its ``__hash__`` to None, and nn.Module
  instances need to stay hashable (PyTorch keeps modules in sets, e.g. the memo
  used by ``named_modules``).
  """
  s_params = list(self.parameters())
  o_params = list(other.parameters())
  # zip() alone would silently ignore any extra parameters of the longer net.
  if len(s_params) != len(o_params):
    return False
  # torch.equal also checks shapes; torch.all(a == b) would broadcast, e.g. (1,) vs (3,).
  return all(torch.equal(s_param, o_param) for s_param, o_param in zip(s_params, o_params))


def _train_step(net, x, y):
  """Run one forward pass plus the net's own backward/update; return (pred, loss)."""
  pred = net(x)
  loss = loss_fun(pred, y)
  net.backward(loss)
  return pred, loss


# 100 independent trials: build a fresh pair of networks, then train both for
# 10 steps on identical random batches, checking after every step that their
# parameters are still bit-identical.
loss_fun = nn.MSELoss()
for _trial in range(100):
  x = torch.randn((5, 50)).to(device)   # batch of 5, 50 inputs
  net1 = ANN().to(device)
  pred1 = net1(x)  # dry run: materializes net1's lazy parameters
  net2 = MemoryEfficientCopyANN(net1).to(device)  # starts from net1's parameters
  assert net_eq(net1, net2)

  for _step in range(10):
    x = torch.randn((5, 50)).to(device)
    y = torch.randn([5, 1]).to(device)
    pred1, loss1 = _train_step(net1, x, y)
    pred2, loss2 = _train_step(net2, x, y)
    assert net_eq(net1, net2)




## Using the PyTorch hooks implementation
Source: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html

In [4]:
class MemoryEfficientANNWithHooks(nn.Module):
  # Using this implementation: https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
  def __init__(self):
    super().__init__()
    self.dense_1 = nn.Linear(50, 50)
    self.dense_2 = nn.Linear(50, 50)
    self.dense_3 = nn.Linear(50, 1)
    self.optimizer = torch.optim.Adam(self.parameters())
    self.setup_optimizer()
  
  def _optimizer_hook(self, param, optimizer):
    optimizer.step()
    optimizer.zero_grad()

  def setup_optimizer(self):
    self.optimizer_list = [(p, torch.optim.Adam([p])) for p in self.parameters()]
    for p, optimizer in self.optimizer_list:
      hook = lambda param: self._optimizer_hook(param, optimizer)
      p.register_post_accumulate_grad_hook(hook)

  def forward(self, x):
    z0 = x
    z1 = F.relu(self.dense_1(z0))
    z2 = F.relu(self.dense_2(z1))
    z3 = self.dense_3(z2)
    return z3
  
  def backward(self, loss):
    loss.backward()

## Timing the methods

In [5]:
# Networks to time. net2 starts from a copy of net1's weights; net3 gets its
# own random initialization (the weights do not matter for timing).
x = torch.randn(5, 50).to(device)
net1 = ANN().to(device)
pred1 = net1(x)  # dry run: materializes net1's lazy parameters
net2 = MemoryEfficientCopyANN(net1).to(device)
net3 = MemoryEfficientANNWithHooks().to(device)

def experiment(net, max_iter=1000, data_device=None):
  """Train `net` for `max_iter` steps on random regression batches (for timing).

  Each step draws a fresh (5, 50) input and (5, 1) target, runs the forward
  pass, and hands the MSE loss to ``net.backward``, which computes the
  gradients and updates the parameters.

  Args:
    net: module with a ``backward(loss)`` method.
    max_iter: number of training steps.
    data_device: device for the random batches; defaults to the notebook's
      global ``device``.
  """
  # Only falls back to the global `device` when no device is given.
  data_device = torch.device(device if data_device is None else data_device)
  loss_fun = nn.MSELoss()
  for _ in range(max_iter):
    # Sample a random batch directly on the target device (no host->device copy).
    x = torch.randn((5, 50), device=data_device)
    y = torch.randn((5, 1), device=data_device)

    # Do inference and the net's own backward/update
    pred = net(x)
    loss = loss_fun(pred, y)
    net.backward(loss)

  # CUDA kernels run asynchronously: wait for them so wall-clock timers such as
  # %timeit include the queued GPU work, not only the kernel launches.
  if data_device.type == "cuda":
    torch.cuda.synchronize(data_device)


# Time 1000 training steps (one `experiment` call) for each network:
#   net 1: single Adam over the whole network, update after loss.backward()
#   net 2: manual layer-by-layer backward with one Adam per layer
#   net 3: per-parameter Adam stepped from post-accumulate-grad hooks
# Each %timeit run keeps training the same net, so later runs start from the
# previous runs' weights; this does not affect the timing comparison.
print("Timing experiment: net 1")
%timeit experiment(net1)
print("\nTiming experiment: net 2")
%timeit experiment(net2)
print("\nTiming experiment: net 3")
%timeit experiment(net3)

Timing experiment: net 1
1.71 s ± 109 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Timing experiment: net 2
3.68 s ± 226 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Timing experiment: net 3
3.33 s ± 184 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


## Memory profiling of all 3 approaches
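Run this on a CUDA machine, e.g. Google Colab: recording the memory history is not supported on every platform (see the error below). Then download the generated `.pickle` files and upload them to https://pytorch.org/memory_viz for visualization.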

In [6]:
TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"

def experiment_memory_profiling(net, max_iter=1000, file_prefix=None):
  """Train `net` on random CUDA batches while recording the CUDA memory history,
  then pickle the allocator snapshot for https://pytorch.org/memory_viz.

  Args:
    net: module on a CUDA device with a ``backward(loss)`` method that
      computes gradients and updates the parameters.
    max_iter: number of training steps to record.
    file_prefix: output path without the ``.pickle`` suffix. Defaults to the
      net's class name plus a timestamp, so snapshots of different nets can be
      told apart and do not overwrite each other within the same second.

  Returns:
    The path of the written snapshot file.
  """
  loss_fun = nn.MSELoss()

  # tell CUDA to start recording memory allocations
  torch.cuda.memory._record_memory_history(enabled='all')
  try:
    for _ in range(max_iter):
      # Sample random datapoint. The target must be (5, 1) like the prediction:
      # a (5,) target would be broadcast by MSELoss into a (5, 5) comparison.
      x = torch.randn((5, 50)).cuda()
      y = torch.randn([5, 1]).cuda()

      # Do inference
      pred = net(x)
      loss = loss_fun(pred, y)
      net.backward(loss)

    # save a snapshot of the memory allocations
    if file_prefix is None:
      timestamp = datetime.now().strftime(TIME_FORMAT_STR)
      file_prefix = f"{type(net).__name__}_{timestamp}"
    path = f"{file_prefix}.pickle"

    snapshot = torch.cuda.memory._snapshot()
    with open(path, "wb") as f:
      dump(snapshot, f)
  finally:
    # tell CUDA to stop recording memory allocations, even if training failed
    torch.cuda.memory._record_memory_history(enabled=None)
  return path

# Memory-history recording needs a CUDA device: skip (instead of crashing a
# Restart & Run All) when none is available.
if torch.cuda.is_available():
  experiment_memory_profiling(net1.cuda())
  experiment_memory_profiling(net2.cuda())
  experiment_memory_profiling(net3.cuda())
else:
  print("CUDA is not available: skipping memory profiling (run this cell on a CUDA machine, e.g. Google Colab).")

RuntimeError: record_context_cpp is not support on non-linux non-x86_64 platforms

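## Related paper
Lv et al., *Full Parameter Fine-tuning for Large Language Models with Limited Resources* (LOMO), which fuses the gradient computation and the parameter update into one step: https://arxiv.org/pdf/2306.09782.pdf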