# Normal ANN

In [1]:
import torch
import torch.nn.functional as F
from torch import nn

from copy import deepcopy

class ANN(nn.Module):
  """Plain three-layer MLP used as the reference network.

  Architecture: LazyLinear(50) -> ReLU -> LazyLinear(50) -> ReLU -> LazyLinear(1).
  The input width is inferred by the lazy layers on the first forward call.
  """

  def __init__(self):
    super().__init__()
    # Registration order matters: parameters() yields them in this order,
    # which net_eq relies on when comparing two networks pairwise.
    self.relu = nn.ReLU()  # registered but unused: forward() calls F.relu
    self.dense_1 = nn.LazyLinear(50)
    self.dense_2 = nn.LazyLinear(50)
    self.dense_3 = nn.LazyLinear(1)

  def forward(self, x):
    """Return the (un-activated) output of the final linear layer."""
    h = x
    for hidden_layer in (self.dense_1, self.dense_2):
      h = F.relu(hidden_layer(h))
    return self.dense_3(h)


def net_eq(self, other):   # Overloading __eq__ gives issues
  """
    Defines two networks as the same if all their parameters are the same,
    regardless of which class they are.

    Parameters are compared pairwise in ``parameters()`` order. Networks with
    a different number of parameters, or with mismatched parameter shapes,
    are never considered the same.
  """
  s_params = list(self.parameters())
  o_params = list(other.parameters())
  # zip() alone stops at the shorter list, so a network that merely shares a
  # prefix of parameters with the other would be reported as equal.
  if len(s_params) != len(o_params):
    return False
  # torch.equal also checks shapes, so broadcasting cannot produce a false match.
  return all(torch.equal(s, o) for s, o in zip(s_params, o_params))


#############################################
# Test: Does deep copy work as anticipated? #
#############################################
# A deep-copied network trained with one Adam optimizer per layer should end
# up with exactly the same parameters as the original trained with a single
# Adam optimizer over all parameters (Adam updates each parameter independently).
loss_fun = nn.MSELoss()
for _ in range(10):
  x = torch.randn((1,50))
  # Same (1, 1) shape as the prediction; a (1,) target makes MSELoss warn
  # about implicit broadcasting.
  y = torch.randn((1,1))

  # Inference to initialize network weights
  net1 = ANN()
  pred1 = net1(x)
  net2 = deepcopy(net1)   # So they get the same parameters
  pred2 = net2(x)
  assert net_eq(net1, net2)

  # Define optimizers differently: one for the whole net vs. one per layer
  optim1 = torch.optim.Adam(net1.parameters())
  optim2s = [
      torch.optim.Adam(layer.parameters()) for layer in [
          net2.dense_1, net2.dense_2, net2.dense_3
      ]]

  # Do a weight update on the first net
  loss1 = loss_fun(pred1, y)
  loss1.backward()
  optim1.step()

  # Do a weight update on the second net, last layer first
  loss2 = loss_fun(pred2, y)
  loss2.backward()
  for optim2 in reversed(optim2s):
    optim2.step()

  # See if they are the same
  assert net_eq(net1, net2)


########################################################################################################################
# Test to be sure that the parameters are not just references to each other (otherwise it would give a false positive) #
########################################################################################################################
loss_fun = nn.MSELoss()
for _ in range(10):
  # Just to be entirely sure we did not end up sharing the parameters (from a
  # failure in the deep copy), we repeat the test but deliberately skip the
  # update of one layer (dense_1); the two networks must then differ.
  x = torch.randn((1,50))
  y = torch.randn((1,1))  # same shape as the prediction; avoids MSELoss broadcasting

  # Inference to initialize network weights
  net1 = ANN()
  pred1 = net1(x)
  net2 = deepcopy(net1)   # So they get the same parameters
  pred2 = net2(x)
  assert net_eq(net1, net2)

  # Define optimizers differently: one for the whole net vs. one per layer
  optim1 = torch.optim.Adam(net1.parameters())
  optim2s = [
      torch.optim.Adam(layer.parameters()) for layer in [
          net2.dense_1, net2.dense_2, net2.dense_3
      ]]

  # Do a weight update on the first net
  loss1 = loss_fun(pred1, y)
  loss1.backward()
  optim1.step()

  # Do a weight update on the second net, intentionally skipping
  # optim2s[0] (the dense_1 optimizer)
  loss2 = loss_fun(pred2, y)
  loss2.backward()
  for optim2 in reversed(optim2s[1:]):
    optim2.step()

  # The skipped layer must make the networks differ
  assert not net_eq(net1, net2)


# Manually computing the gradient across a single linear+ReLU layer and comparing it with PyTorch's autograd

In [1]:
#######################################################################################################################
# Sanity check: Computing the gradient across a single layer and comparing with the chain rule (manual derivation)... #
#######################################################################################################################
import torch
from torch import nn
from torch.nn import functional as F

for _ in range(10):   # Just for a sanity check...
  B = 2
  layer = nn.Linear(5, 3, bias=True)
  z0 = torch.randn([B, 5], requires_grad=True, dtype=torch.float32)  # Note: This is transposed compared to the derivation in the thesis.
  z1 = layer(z0)
  z2 = F.relu(z1)

  upstream_grad = torch.randn([B,3], dtype=torch.float32)
  z2.backward(upstream_grad, inputs=[z0]+list(layer.parameters()))

  relu_grad = (z1>0).float()  # Note: This is transposed compared to the derivation in the thesis.

  # Error signal at the pre-activation z1: upstream gradient gated by ReLU'(z1)
  e = upstream_grad * relu_grad
  assert e.shape == torch.Size([B,3])

  # allclose instead of ==: autograd and the manual formulas may perform the
  # matmuls/reductions in a different order, so the results can differ in
  # the last floating-point bits (more likely for larger B).

  # Theoretical layer weight grad
  weight_grad = e.T @ z0
  assert torch.allclose(layer.weight.grad, weight_grad)

  # Theoretical bias grad (e summed over the batch)
  bias_grad = (e.T @ torch.ones([B,1])).squeeze(1)
  assert torch.allclose(layer.bias.grad, bias_grad)

  # Theoretical error signal passed to the layer input
  downstream_grad = e @ layer.weight
  assert torch.allclose(z0.grad, downstream_grad)



#print((layer.weight.grad, weight_grad))
#print((layer.bias.grad, bias_grad))
#print((z0.grad, downstream_grad))

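# A proof-of-concept implementation of MEBP (memory-efficient backpropagation: each layer's optimizer step is applied as soon as that layer's gradients are available during the backward pass):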

In [3]:
class ANN(nn.Module):
  """Reference three-layer MLP that owns its optimizer.

  Architecture: LazyLinear(50) -> ReLU -> LazyLinear(50) -> ReLU -> LazyLinear(1).
  ``backward`` runs a full backward pass, one Adam step over all parameters,
  and then clears the gradients.
  """

  def __init__(self):
    super().__init__()
    self.relu = nn.ReLU()  # registered but unused: forward() calls F.relu
    self.dense_1 = nn.LazyLinear(50)
    self.dense_2 = nn.LazyLinear(50)
    self.dense_3 = nn.LazyLinear(1)
    # NOTE(review): the optimizer is built before the lazy layers have
    # materialized their parameters; this presumably works only because
    # materialization keeps the same Parameter objects. PyTorch's lazy-module
    # docs create the optimizer after a dry run instead.
    self.optimizer = torch.optim.Adam(self.parameters())

  def forward(self, x):
    """Return the (un-activated) output of the final linear layer."""
    h = x
    for hidden_layer in (self.dense_1, self.dense_2):
      h = F.relu(hidden_layer(h))
    return self.dense_3(h)

  def backward(self, loss):
    """Backpropagate ``loss``, apply one optimizer step, then reset the grads."""
    loss.backward()
    optimizer = self.optimizer
    optimizer.step()
    optimizer.zero_grad()


class MemoryEfficientCopyANN(nn.Module):
  """
    Copy of an ANN that is trained with layer-by-layer optimizer updates.

    Instead of one full backward pass followed by one optimizer step,
    ``backward`` walks from the last layer to the first. For each layer it
    computes the gradients of that layer's parameters and of its input,
    immediately applies that layer's own Adam step, and clears the layer's
    gradients. Only then does it move on to the previous layer.

    Args:
      other: network providing ``dense_1``, ``dense_2`` and ``dense_3``. These
        are deep-copied so both networks start from identical parameters.
        (In this notebook ``other`` has already run a forward pass, so its
        lazy layers are initialized when they are copied.)
  """
  def __init__(self, other):
    super().__init__()
    self.relu = nn.ReLU()   # registered but unused: forward() calls F.relu
    self.dense_1 = deepcopy(other.dense_1) #nn.LazyLinear(50)
    self.dense_2 = deepcopy(other.dense_2) #nn.LazyLinear(50)
    self.dense_3 = deepcopy(other.dense_3) #nn.LazyLinear(1)
    # One optimizer per layer, so each layer can be stepped independently
    # during backward().
    self.dense_1_optim = torch.optim.Adam(self.dense_1.parameters())
    self.dense_2_optim = torch.optim.Adam(self.dense_2.parameters())
    self.dense_3_optim = torch.optim.Adam(self.dense_3.parameters())
    # Activations of the latest forward(); consumed and cleared by backward().
    self.activations = []

  def forward(self, x):
    """Run the 3-layer MLP and remember every intermediate activation."""
    z0 = x
    z1 = F.relu(self.dense_1(z0))
    z2 = F.relu(self.dense_2(z1))
    z3 = self.dense_3(z2)
    self.activations = [z0,z1,z2,z3]   # Store references to tensors
    return z3

  def backward(self, loss):
    """
      Backpropagate ``loss`` one layer at a time. Each layer's optimizer is
      stepped as soon as that layer's parameter gradients are available.

      Must follow a ``forward`` call whose output ``loss`` was computed from,
      because it unpacks the four activations stored by ``forward``.
    """
    [z0,z1,z2,z3] = self.activations
    self.activations = []

    # Target error
    # Requesting only z3 stops autograd at the network output and stores
    # dloss/dz3 in z3.grad (z3 is a non-leaf tensor).
    # NOTE(review): every backward call in this method passes
    # retain_graph=True, so the graph's saved tensors are kept until the graph
    # itself is released. This may limit the memory savings; confirm whether
    # the later calls can drop it.
    loss.backward(inputs=[z3], retain_graph=True)
    delta_L = z3.grad

    # Backpropagation with in-place optimization application and deletion
    # Layer 3: gradient w.r.t. its input z2 (the next error signal) plus its
    # own parameter gradients, then step and clear layer 3 only.
    z3.backward(delta_L, inputs=[z2]+list(self.dense_3.parameters()), retain_graph=True)
    delta_3 = z2.grad
    self.dense_3_optim.step()
    self.dense_3_optim.zero_grad()
    # zero_grad() must have released the gradient tensors (set them to None)
    # rather than zero-filling them.
    assert self.dense_3.weight.grad is None
    assert self.dense_3.bias.grad is None
    # del only drops these local names; the tensors stay alive while anything
    # else (e.g. the caller or the retained graph) still references them.
    del delta_L
    del z3

    # Layer 2: same pattern, producing the error signal at z1.
    z2.backward(delta_3, inputs=[z1]+list(self.dense_2.parameters()), retain_graph=True)
    delta_2 = z1.grad
    self.dense_2_optim.step()
    self.dense_2_optim.zero_grad()
    assert self.dense_2.weight.grad is None
    assert self.dense_2.bias.grad is None
    del delta_3
    del z2

    # Layer 1: its input z0 is the network input, so only the parameter
    # gradients are requested.
    z1.backward(delta_2, inputs=list(self.dense_1.parameters()), retain_graph=True)
    self.dense_1_optim.step()
    self.dense_1_optim.zero_grad()
    assert self.dense_1.weight.grad is None
    assert self.dense_1.bias.grad is None
    del delta_2
    del z1


def net_eq(self, other):   # Overloading __eq__ gives issues
  """
    Defines two networks as the same if all their parameters are the same,
    regardless of which class they are.

    Parameters are compared pairwise in ``parameters()`` order. Networks with
    a different number of parameters, or with mismatched parameter shapes,
    are never considered the same.
  """
  s_params = list(self.parameters())
  o_params = list(other.parameters())
  # zip() alone stops at the shorter list, so a network that merely shares a
  # prefix of parameters with the other would be reported as equal.
  if len(s_params) != len(o_params):
    return False
  # torch.equal also checks shapes, so broadcasting cannot produce a false match.
  return all(torch.equal(s, o) for s, o in zip(s_params, o_params))


##########################################################
# Test that it gives the same result as the previous ANN #
##########################################################
# Train both networks on the same stream of samples; after every update the
# layer-wise network must match the reference network exactly.
loss_fun = nn.MSELoss()
for i in range(10):
  # Initialize networks
  x = torch.randn((1,50))
  net1 = ANN()
  pred1 = net1(x)  # Dry run to initialize the lazy layers' weights
  net2 = MemoryEfficientCopyANN(net1) # They get the same parameters
  assert net_eq(net1, net2)

  for j in range(10):
    x = torch.randn((1,50))
    y = torch.randn((1,1))  # same shape as the predictions; avoids MSELoss broadcasting

    # Inference and backward
    pred1 = net1(x)
    loss1 = loss_fun(pred1, y)
    net1.backward(loss1)

    pred2 = net2(x)
    loss2 = loss_fun(pred2, y)
    net2.backward(loss2)

    # See if their parameters are equal
    assert net_eq(net1, net2)

In [4]:
# Smaller test setup:
dense_1 = nn.Linear(50,50)
dense_2 = nn.Linear(50,50)
dense_3 = nn.Linear(50,1)

loss_fun = nn.MSELoss()
x = torch.randn((1,50), requires_grad=True)
y = torch.randn([1])

z0 = x
z1 = F.relu(dense_1(z0))
z2 = F.relu(dense_2(z1))
z3 = dense_3(z2)


loss = loss_fun(z3, y)
loss.backward(inputs=[z3], retain_graph=True)
delta_L = z3.grad
print(z3.grad)
print(z2.grad)

z3.backward(delta_L, inputs=[z2], retain_graph=True)
delta_3 = z2.grad
print(z2.grad)

z2.backward(delta_3, inputs=[z1], retain_graph=True)
delta_2 = z1.grad
print(z1.grad)

z1.backward(delta_2, inputs=[z0], retain_graph=True)
print(z0.grad)

tensor([[1.0622]])
None
tensor([[-0.1235,  0.0092, -0.0157,  0.1239,  0.1425, -0.0889,  0.1394, -0.0997,
         -0.0472,  0.0039, -0.0728,  0.0582, -0.0176, -0.1152,  0.0290,  0.0760,
         -0.0502, -0.0193, -0.1090, -0.0832, -0.0413, -0.0591,  0.1392,  0.1319,
          0.1019, -0.1360, -0.1283,  0.0134, -0.1374,  0.0843,  0.1450,  0.1448,
          0.0108, -0.0203,  0.1238, -0.0971, -0.1053,  0.0800, -0.0249,  0.0301,
          0.1482, -0.1274, -0.1231,  0.0243, -0.0395, -0.0117, -0.0603,  0.0351,
         -0.0201,  0.1259]])
tensor([[-0.0125,  0.0138,  0.0082,  0.0052, -0.0614, -0.0092,  0.0079, -0.0165,
          0.0175, -0.0181,  0.0129,  0.0071,  0.0183, -0.0718,  0.0019, -0.0054,
          0.0170, -0.0150, -0.0127,  0.0065,  0.0844, -0.0414,  0.0362, -0.0496,
          0.0241, -0.0246, -0.0027, -0.0040,  0.0093,  0.0221, -0.0219,  0.0251,
         -0.0045,  0.0125, -0.0197, -0.0251,  0.0449,  0.0015,  0.0338, -0.0589,
          0.0163,  0.0328,  0.0352, -0.0282, -0.0039, -0

  print(z2.grad)


# Testing the approach suggested in the PyTorch tutorial on running the optimizer step inside the backward pass (using hooks):
# https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html

In [9]:
class MemoryEfficientCopyANN2(nn.Module):
  def __init__(self, other):
    super().__init__()
    self.relu = nn.ReLU()
    self.dense_1 = deepcopy(other.dense_1) #nn.LazyLinear(50)
    self.dense_2 = deepcopy(other.dense_2) #nn.LazyLinear(50)
    self.dense_3 = deepcopy(other.dense_3) #nn.LazyLinear(1)
    self.setup_optimizer()

  def _optimizer_hook(self, parameter) -> None:
    self.optimizer_dict[parameter].step()
    self.optimizer_dict[parameter].zero_grad()

  def setup_optimizer(self):
    self.optimizer_dict = {p: torch.optim.Adam([p], foreach=False) for p in self.parameters()}
    for p in self.parameters():
      p.register_post_accumulate_grad_hook(self._optimizer_hook)

  def forward(self, x):
    z0 = x
    z1 = F.relu(self.dense_1(z0))
    z2 = F.relu(self.dense_2(z1))
    z3 = self.dense_3(z2)
    return z3

  def backward(self, loss):
    loss.backward()


# Same equivalence test as for MemoryEfficientCopyANN, now with the hook-based
# variant: after every update both networks must have identical parameters.
loss_fun = nn.MSELoss()
for i in range(10):
  # Initialize networks
  x = torch.randn((1,50))
  net1 = ANN()
  pred1 = net1(x)  # Dry run to initialize the lazy layers' weights
  net2 = MemoryEfficientCopyANN2(net1) # They get the same parameters
  assert net_eq(net1, net2)

  for j in range(10):
    x = torch.randn((1,50))
    y = torch.randn((1,1))  # same shape as the predictions; avoids MSELoss broadcasting

    # Inference and backward
    pred1 = net1(x)
    loss1 = loss_fun(pred1, y)
    net1.backward(loss1)

    pred2 = net2(x)
    loss2 = loss_fun(pred2, y)
    net2.backward(loss2)

    # See if their parameters are equal
    assert net_eq(net1, net2)