In [53]:
import torch
import torch.nn as nn
import numpy as np
import copy
from torchviz import make_dot

from src.reparam_module import ReparamModule

# These are our reference variables that we wish to keep, to compare between models

# Input batch
# N = 7 # No batch for now, just a single variable

in_shape = (2, 3) # Batch, features
l1 = (in_shape[1], 4) # Shape of Linear Layer

class TestLinear(nn.Module):

    def __init__(self):
        super().__init__()

        # No bias because in MTT, we are using normalization layers
        self.layers = nn.Sequential(
            nn.Linear(in_features=l1[0], out_features=l1[1], bias=False),
            nn.ReLU(),
            nn.Linear(in_features=l1[1], out_features=8, bias=False),
            nn.ReLU(),
            nn.Linear(8, out_features=5, bias=  ),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)


# Avoid using these directly, instead use deepcopies
_model = TestLinear()
num_params = sum(np.prod(l.weight.shape) for l in _model.layers if isinstance(l, nn.Linear))
_img_syn = torch.randn(*in_shape).requires_grad_(True)  # Input of network
_student_params = torch.cat(
    [[l.weight.reshape(-1) for l in _model.layers if isinstance(l, nn.Linear)]]
).requires_grad_(
    True
)  # Imaginary parameters, in MTT sometimes called forward_param. These are the weights!
starting_params = torch.randn(num_params)  # Imaginary starting expert param
target_params = torch.randn(num_params)  # Imaginary ending expert param
y_hat = torch.empty(in_shape[0], dtype=torch.long).random_(4) # Imaginary targets, shape output of final layer

TypeError: expected Tensor as element 0 in argument 0, but got list

In [51]:
# MTT Ground Truth, we are trying to replicate the input image gradients

# Deep copies, so we know not used by manual implementation
img_syn = copy.deepcopy(_img_syn)
# I don't think it is possible to do a deepcopy as this removes dependencies
student_params = _student_params
student_params.retain_grad()
student_net_mtt = _model
student_net_mtt = ReparamModule(student_net_mtt)  # MTT black magic

optimizer_img = torch.optim.SGD([img_syn], lr=0.1, momentum=0.5)
optimizer_img.zero_grad()

student_net_mtt.train()

out = student_net_mtt(img_syn, flat_param=student_params) # Forward pass
out.retain_grad() # For validation
criterion = nn.CrossEntropyLoss()  # Used to calculate gradients of parameters
ce_loss = criterion(out, y_hat) # NllLossBackward0
ce_loss.retain_grad() # For validation

# Computes graph for student_params, but does not populate the .grad attributes!
# Returns a tuple with only a single term
grad = torch.autograd.grad(ce_loss, student_params, create_graph=True)[0] # CatBackward0
grad.retain_grad()

# Update student params (instead of using torch.optim.SGD)
student_params_new = (student_params - 0.1 * grad) # SubBackward0, MulBackward0, CatBackward0, ViewBackward0,
student_params_new.retain_grad() # For validation

param_loss = torch.nn.functional.mse_loss(student_params_new, target_params, reduction="sum") # MseLossBackward0
param_loss.retain_grad() # For validation
param_dist = torch.nn.functional.mse_loss(starting_params, target_params, reduction="sum")

# Believe this is their manual way of taking the mean lol. But this isn't actually needed
# because they have done num_param / num_param = 1
# param_loss /= num_params
# param_dist /= num_params

# MTT grand loss, with the full computational graph determined by autograd
grand_loss = param_loss / param_dist
grand_loss.retain_grad() # For validation

# Before backwards pass of grand loss, gradient should be zero
assert img_syn.grad == None
optimizer_img.zero_grad()

# Now do backprop, and calculate
grand_loss.backward(retain_graph=True) # retain_graph for torchviz

print(grand_loss)
print(img_syn.grad)


RuntimeError: split_with_sizes expects split_sizes to sum exactly to 84 (input tensor's size at dimension 0), but got split_sizes=[12, 4, 32, 8, 40, 5]

In [49]:
# graph = make_dot(loss_grand, show_attrs=True) # Verbose, requires retaining graph
graph = make_dot(grand_loss, show_attrs=True)
# Img may be cut off, hence save to a file
graph.render("img/mtt_grand_linear", view=True, format="pdf")
# graph

# # graph = make_dot(loss_grand, show_attrs=True) # Verbose, requires retaining graph
# graph = make_dot(grad, show_attrs=True)
# # Img may be cut off, hence save to a file
# graph.render("img/grad", view=True, format="pdf")
# # graph

'img/mtt_grand_linear.pdf'

In [27]:
# let us calculate the forward computational graph manually

# print(ce_loss)

# Calculate CELoss forward
flat_reshaped = torch.reshape(student_params, (4,3))
# y_truth_one_hot = torch.zeros_like(out, dtype=int)  # Same shape as model output
y_truth_one_hot = torch.zeros_like(img_syn @ flat_reshaped.T, dtype=int)  # This is the forward pass of normal DNN
y_truth_one_hot[torch.arange(out.shape[0]), y_hat] = (
    1  # Assign index of label as 1, others are 0
)
y_max = torch.max(out, axis=1, keepdims=True)[0]
y_out_exp = torch.exp(out - y_max)
y_out_probs = y_out_exp / torch.sum(y_out_exp, axis=1, keepdims=True) # Cached for backwards
loss_val = -y_truth_one_hot * torch.log(y_out_probs)
loss_val = loss_val.sum(axis=1).mean()
assert torch.allclose(loss_val, ce_loss)
# print(y_out_probs.shape)
# Here we are forward passing the backprop of CELoss (yes, this will confuse you)
dout_val = y_out_probs
print(dout_val)
print(torch.arange(out.shape[0]))
dout_val[torch.arange(out.shape[0]), y_hat] -= 1
print(dout_val)
dout_val /= out.shape[0]  # Hand calculated dout to validate

print(dout_val.shape)

# # Calculate MmBackward0 forward (the 'hidden' layer)
MmForward = dout_val.T @ img_syn
MmForward = MmForward.view(-1)
assert torch.allclose(MmForward, grad)
_ = 0

tensor([[0.3276, 0.1205, 0.4679, 0.0840],
        [0.2291, 0.1715, 0.5255, 0.0739]], grad_fn=<DivBackward0>)
tensor([0, 1])
tensor([[ 0.3276,  0.1205, -0.5321,  0.0840],
        [ 0.2291,  0.1715,  0.5255, -0.9261]], grad_fn=<IndexPutBackward0>)
torch.Size([2, 4])


In [24]:
# Now let us calculate img_syn.grad manually

# dparam_loss (MSELoss wrt student_params_new)
dMSE = 2 * (student_params_new - target_params) / param_dist
# print(student_params.grad)
# print(dparam_loss)
# print(student_params.grad - dparam_loss)
assert(torch.allclose(student_params_new.grad, dMSE))

# dstudent_params_grad
# Take derivative of student_param_new wrt grad * incoming gradient
# This is the gradient that leaves the first MulBackward0 at bottom of graph
dstudent_params_grad = -0.1 * dMSE
# print(grad.grad)
# print(d_student_params)
# print(grad.grad- d_student_params)
assert(torch.allclose(grad.grad, dstudent_params_grad))

# dstudent_params_student_params (incoming grad. wrt student_params)
# Take derivative of student_param wrt student_param * incoming gradient
# TODO: This is dparam_loss plus whatever the other incoming gradient is (accumulate)
# This is the gradient that enters the CatBackward0 at the top of the graph
# dstudent_params_student_params = dparam_loss
# assert(torch.allclose(student_params.grad, dstudent_params_student_params))

# dMmBackward0 (hidden layer)

# Calculate backprop through the CELoss layer
# back = dstudent_params_grad / out.shape[0]
# back = torch.reshape(back, (3,4))
# # back[torch.arange(out.shape[0]), y_hat] -= 1
# print(back.shape)
# TODO: Figure out to backprop through the backprop of CELoss
dtmp = img_syn @ torch.reshape(dstudent_params_grad, (4,3)).T # Should be same shape as dout_val (2,4)
dtmp /= out.shape[0]

print(dtmp.shape)
# dtmp = dtmp @ flat_reshaped


# Incoming gradient is dstudent_params_grad or grad.grad
# d_hidden = dout_val @ torch.reshape(grad.grad, (4,3))
d_hidden = dout_val @ torch.reshape(dstudent_params_grad, (4,3))
print(d_hidden)
print(img_syn.grad)
# print(img_syn.grad - dtmp - d_hidden) # Trying to get this to zero

# dgrad (incoming grad. wrt grad)
# TODO: This must be examined in detail, check why it is 1
# Think it is because we backprop CE itself, not its gradient
# print(ce_loss.grad)
# dgrad = torch.tensor([1.0])
# # print(ce_loss.grad - dgrad)
# assert(torch.allclose(ce_loss.grad, dgrad))

# dstudent_params wrt grad
# First we have to calculate this weird 'hidden' layer

# dhidden = torch.ones_like(grad) * -0.1
# print(dhidden)

# dCELoss
# out is the last MMBackward0 before the AccumulateGrad for x

# print(out.shape)
# print(out - d_student_params.reshape([4,3]))
# out_test = out - img_syn @ torch.reshape(d_student_params, (3,4))

# out_exp = torch.exp(out - torch.max(out, axis=1, keepdims=True)[0])
# y_out_probs = out_exp / torch.sum(out_exp, axis=1, keepdims=True) # Forwards caches this
# dce_loss = y_out_probs
# dce_loss[torch.arange(out.shape[0]), y_hat] -= 1
# dce_loss /= out.shape[0]

# out_npy = out.detach().numpy()
# y_hat = y_hat.numpy()
# out_exp = np.exp(out_npy - np.max(out_npy, axis=1, keepdims=True)[0])
# y_out_probs = out_exp / np.sum(out_exp, axis=1, keepdims=True) # Forwards caches this
# dce_loss = y_out_probs
# dce_loss[np.arange(out_npy.shape[0]), y_hat] -= 1.0
# dce_loss /= out_npy.shape[0]
# print(dce_loss - out.grad.detach().numpy())

# print(dce_loss - out.grad)
# print(img_syn)
# print(dce_loss / out.grad)
# assert(torch.allclose(out.grad, dce_loss)) # The difference here is quite large, maybe torch calculates differently?

# dout (can't calculate it directly yet)

# print(out.grad.shape)
# print(torch.reshape(student_params, [4,3]).shape)

# dhidden_truth = out.grad - img_syn.grad
# print(dhidden_truth)



_ = 0

torch.Size([2, 4])
tensor([[ 0.0030,  0.0004,  0.0008],
        [-0.0061, -0.0013, -0.0025]], grad_fn=<MmBackward0>)
tensor([[ 0.0016,  0.0010,  0.0007],
        [-0.0064, -0.0011, -0.0023]])


In [18]:
# print(grad)

# current = grad

# while current.grad_fn is not None:
#     print(current.grad_fn)
#     current = current.grad_fn.next_functions[0][0]  # Move to the next function in the graph