In [34]:
# Notebook to calculate backprop for the grand loss in MTT, i.e. 'gradient of gradients'

In [35]:
import torch
import torch.nn as nn
import numpy as np

In [52]:
# First we test using only linear layers

# Linear layers y = xWT + b


# Layer 1
in_dim = (5, 4)
out_dim = 7
x1 = torch.rand(*in_dim, requires_grad=True)
WT1 = torch.rand(in_dim[-1], out_dim, requires_grad=True)
b1 = torch.randn(out_dim, requires_grad=True)
y1 = x1 @ WT1 + b1
y1.retain_grad()
# print(y1.shape)

# Layer 2
in_dim = y1.shape
out_dim = 4
WT2 = torch.rand(in_dim[-1], out_dim, requires_grad=True)
b2 = torch.randn(out_dim, requires_grad=True)
y2 = y1 @ WT2 + b2
y2.retain_grad()
# print(y2.shape)

# Layer 3
in_dim = y2.shape
out_dim = 3
WT3 = torch.rand(in_dim[-1], out_dim, requires_grad=True)
b3 = torch.randn(out_dim, requires_grad=True)
out = y2 @ WT3 + b3
out.retain_grad() # Output of network

# Sanity check
# print(out == (((x1 @ WT1 + b1) @ WT2 + b2)) @ WT3 + b3)

# Loss
target = torch.randn(out.shape)
loss_fn = nn.MSELoss()
output = loss_fn(out, target)
output.retain_grad()
output.backward()

dout_val = 2 * (out - target) / np.prod(out.shape) # derivative of MSELoss
dx1_val = 0

# print(dout_val)
# print(out.grad)


In [37]:
# Only tensors that are not computed from other tensors are leaves.
# This tells us that all weight and bias tensors are leaves initally.
# for var in [x1, y2, out, WT1, WT2, WT3, b1, b2, b3]:
#     print(var.is_leaf)

# print(WT1)
# print(WT1.grad)

# optimizer = torch.optim.SGD([WT1, WT2, WT3, b1, b2, b3], lr=0.1)
# # Perform one step of SGD
# optimizer.step()

# # WT1 has now been updated, but is still a leaf (no dependencies!)
# print(WT1)

# p_stacked = torch.cat[]

In [74]:
# Calculating the grand loss as described in MTT

# First flatten and concatenate all the parameters
parameters = [WT1, WT2, WT3, b1, b2, b3]
# All the parameters should be leaves because they don't have dependencies
# for var in parameters:
#     print(var.is_leaf)
p_stacked = torch.cat([p.reshape(-1) for p in parameters])
p_stacked.retain_grad()
# Sanity check
num_param = sum(np.prod(p.shape) for p in parameters)
# print(p_stacked.shape[0] == num_param)

# Imaginary start and end expert parameters
exp_start_stacked = torch.randn(p_stacked.shape)
exp_end_stacked = torch.randn(p_stacked.shape)

param_loss_mse = nn.MSELoss(reduction='sum')
param_dist_mse = nn.MSELoss(reduction='sum')

param_loss = param_loss_mse(p_stacked, exp_end_stacked)
param_loss /= num_param
param_dist = param_dist_mse(exp_start_stacked, exp_end_stacked)
param_dist /= num_param

grand_loss = param_loss / param_dist

print(grand_loss)

print(p_stacked.grad)
grand_loss.backward()
# print(grand_loss)
print(p_stacked.grad)


tensor(0.6664, grad_fn=<DivBackward0>)
None
tensor([ 0.0258,  0.0180, -0.0020, -0.0198,  0.0234,  0.0145,  0.0194,  0.0134,
        -0.0003,  0.0318,  0.0197,  0.0276,  0.0084,  0.0098,  0.0188,  0.0068,
         0.0016,  0.0029, -0.0031, -0.0044, -0.0030,  0.0082, -0.0004, -0.0117,
         0.0126, -0.0087,  0.0401, -0.0018, -0.0074,  0.0162, -0.0184,  0.0083,
         0.0078, -0.0076,  0.0018, -0.0013,  0.0152, -0.0048, -0.0216,  0.0157,
        -0.0019,  0.0010, -0.0014,  0.0116,  0.0028,  0.0244,  0.0031,  0.0005,
         0.0152,  0.0075,  0.0206, -0.0133, -0.0091, -0.0113, -0.0295,  0.0281,
        -0.0168, -0.0091,  0.0024, -0.0074, -0.0117,  0.0164, -0.0103, -0.0116,
         0.0151, -0.0035,  0.0132,  0.0209, -0.0029, -0.0088, -0.0212, -0.0243,
         0.0004,  0.0065,  0.0111, -0.0009,  0.0128,  0.0265, -0.0011, -0.0164,
         0.0159, -0.0255])
