In [2]:
import torch
import torch.nn as nn
import numpy as np
import copy
from torchviz import make_dot

from src.reparam_module import ReparamModule


In [113]:
class Model(nn.Module):
    """Small ConvNet used to exercise the MTT unrolled-update graph.

    ``depth`` repetitions of Conv3x3 -> GroupNorm -> ReLU -> AvgPool(2), then a linear head.
    With the defaults this is identical to the original hard-coded network
    (3 blocks of 32 filters on 3x32x32 inputs, 9 outputs).

    Args:
        in_channels: number of channels of the input images.
        width: number of conv filters in every block.
        depth: number of conv/norm/relu/pool blocks.
        num_classes: size of the classifier output.
        im_size: side length of the square input images; used to size the classifier.
    """

    def __init__(self, in_channels=3, width=32, depth=3, num_classes=9, im_size=32):
        super().__init__()

        layers = []
        channels = in_channels
        for _ in range(depth):
            layers += [
                # Convs keep their bias (bias=True). Because the per-channel normalization right
                # after it subtracts each channel's mean, that bias is cancelled (redundant but harmless).
                nn.Conv2d(
                    in_channels=channels, out_channels=width, kernel_size=3, padding=1, bias=True
                ),
                # num_groups == num_channels: each channel is normalized on its own (instance-norm style).
                nn.GroupNorm(num_groups=width, num_channels=width, affine=True),
                nn.ReLU(inplace=True),
                nn.AvgPool2d(kernel_size=2, stride=2),
            ]
            channels = width
        # Same Sequential indices as the hand-written version, so parameter names
        # (features.0.weight, features.1.weight, ...) are unchanged.
        self.features = nn.Sequential(*layers)

        # Each AvgPool halves H and W (flooring), so the final map is channels x side x side.
        side = im_size // (2 ** depth)
        self.classifier = nn.Linear(channels * side * side, num_classes)

    def forward(self, x):
        """Map images of shape [B, in_channels, im_size, im_size] to logits of shape [B, num_classes]."""
        out = self.features(x)
        # Flatten the feature map to [batch_size, C * H' * W'] for the linear head.
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out


# One simulated MTT (trajectory-matching) step through the conv Model defined above:
# the student starts at a random "expert" checkpoint, takes one differentiable SGD step on a
# cross-entropy loss over synthetic images, and is compared with a random "target" checkpoint.
# The resulting grand loss is backpropagated through that step into the synthetic images.
SEED = 0
torch.manual_seed(SEED)  # reproducible params / images / targets on Restart & Run All

model = Model()  # Inits with random parameters that are replaced by student_params
model_val = copy.deepcopy(model)
mtt_model = ReparamModule(model)

img_syn = torch.randn(7, 3, 32, 32).requires_grad_(True)  # Synthetic image batch being optimized
optimizer_img = torch.optim.SGD([img_syn], lr=0.1, momentum=0.5)
optimizer_img.zero_grad()

mtt_model.train()

# Total number of learnable scalars; the expert/student params are flat vectors of this length.
num_params = sum(p.numel() for p in mtt_model.parameters())
starting_params = torch.randn(num_params)  # Imaginary starting expert param
# Leaf tensor the unrolled step differentiates through, initially equal to starting_params.
student_params = starting_params.detach().clone().requires_grad_(True)
target_params = torch.randn(num_params)  # Imaginary ending expert param

out = mtt_model(img_syn, flat_param=student_params)  # Forward pass using student_params

# Imaginary targets: one label per image (out.shape[0]), drawn from classes 0..3 only,
# even though the classifier has 9 outputs.
y_hat = torch.empty(out.shape[0], dtype=torch.long).random_(4)

criterion = nn.CrossEntropyLoss()  # Used to calculate gradients of parameters
ce_loss = criterion(out, y_hat)

# torch.autograd.grad returns the gradient (a 1-tuple here) without populating
# student_params.grad. create_graph=True makes the gradient itself differentiable,
# so the grand loss can later flow back through this step into img_syn.
grad = torch.autograd.grad(ce_loss, student_params, create_graph=True)[0]  # CatBackward0
grad.retain_grad()  # For validation

# Differentiable SGD step on the student params (instead of using torch.optim.SGD)
student_params_new = student_params - 0.1 * grad  # SubBackward0, MulBackward0, ...
student_params_new.retain_grad()  # For validation

param_loss = torch.nn.functional.mse_loss(student_params_new, target_params, reduction="sum")
param_loss.retain_grad()  # For validation
# Normalizer: squared distance between the two expert checkpoints (constant, no grad).
param_dist = torch.nn.functional.mse_loss(starting_params, target_params, reduction="sum")

# MTT grand loss, with the full computational graph determined by autograd
grand_loss = param_loss / param_dist
grand_loss.retain_grad()  # For validation

# No backward pass has touched img_syn yet, so its .grad is still None (not a zero tensor).
assert img_syn.grad is None

# Backprop the grand loss through the unrolled step all the way into img_syn.
grand_loss.backward(retain_graph=True)  # keep the graph alive for the torchviz cell below

# Grand loss should be about 1: one small step leaves student_params_new ~ starting_params,
# so param_loss ~ param_dist.
print(grand_loss)
print(img_syn.grad.shape)
assert img_syn.grad.shape == img_syn.shape

tensor(0.9998, grad_fn=<DivBackward0>)
torch.Size([7, 3, 32, 32])


In [None]:
# The inline rendering of this large autograd graph gets cut off, so write the conv
# grand-loss graph (including node attributes) to a PDF under img/ and open it.
graph = make_dot(grand_loss, show_attrs=True)
graph.render(filename="img/conv_grand_loss", format="pdf", view=True)

In [114]:
class Model(nn.Module):
    """Minimal bias-free linear model mapping 8 input features to 5 outputs.

    The single layer lives inside ``nn.Sequential``, so its parameter is registered as
    ``classifier.0.weight``. Deeper variants (ReLU + further Linear layers) can be
    appended to ``head_layers``.
    """

    def __init__(self):
        super().__init__()
        head_layers = [nn.Linear(8, 5, bias=False)]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        logits = self.classifier(x)
        return logits


# One simulated MTT (trajectory-matching) step through the linear Model defined above:
# the student starts at a random "expert" checkpoint, takes one differentiable SGD step on a
# cross-entropy loss over synthetic inputs, and is compared with a random "target" checkpoint.
# The resulting grand loss is backpropagated through that step into the synthetic inputs.
SEED = 0
torch.manual_seed(SEED)  # reproducible params / inputs / targets on Restart & Run All

model = Model()  # Inits with random parameters that are replaced by student_params
model_val = copy.deepcopy(model)
mtt_model = ReparamModule(model)

img_syn = torch.randn(7, 8).requires_grad_(True)  # Synthetic input batch being optimized
optimizer_img = torch.optim.SGD([img_syn], lr=0.1, momentum=0.5)
optimizer_img.zero_grad()

mtt_model.train()

# Total number of learnable scalars; the expert/student params are flat vectors of this length.
num_params = sum(p.numel() for p in mtt_model.parameters())
starting_params = torch.randn(num_params)  # Imaginary starting expert param
# Leaf tensor the unrolled step differentiates through, initially equal to starting_params.
student_params = starting_params.detach().clone().requires_grad_(True)
target_params = torch.randn(num_params)  # Imaginary ending expert param

out = mtt_model(img_syn, flat_param=student_params)  # Forward pass using student_params

# Imaginary targets: one label per input row (out.shape[0]), drawn from classes 0..3 only,
# even though the classifier has 5 outputs.
y_hat = torch.empty(out.shape[0], dtype=torch.long).random_(4)

criterion = nn.CrossEntropyLoss()  # Used to calculate gradients of parameters
ce_loss = criterion(out, y_hat)

# torch.autograd.grad returns the gradient (a 1-tuple here) without populating
# student_params.grad. create_graph=True makes the gradient itself differentiable,
# so the grand loss can later flow back through this step into img_syn.
grad = torch.autograd.grad(ce_loss, student_params, create_graph=True)[0]  # CatBackward0
grad.retain_grad()  # For validation

# Differentiable SGD step on the student params (instead of using torch.optim.SGD)
student_params_new = student_params - 0.1 * grad  # SubBackward0, MulBackward0, ...
student_params_new.retain_grad()  # For validation

param_loss = torch.nn.functional.mse_loss(student_params_new, target_params, reduction="sum")
param_loss.retain_grad()  # For validation
# Normalizer: squared distance between the two expert checkpoints (constant, no grad).
param_dist = torch.nn.functional.mse_loss(starting_params, target_params, reduction="sum")

# MTT grand loss, with the full computational graph determined by autograd
grand_loss = param_loss / param_dist
grand_loss.retain_grad()  # For validation

# No backward pass has touched img_syn yet, so its .grad is still None (not a zero tensor).
assert img_syn.grad is None

# Backprop the grand loss through the unrolled step all the way into img_syn.
grand_loss.backward(retain_graph=True)  # keep the graph alive for the torchviz cell below

# Grand loss should be about 1: one small step leaves student_params_new ~ starting_params,
# so param_loss ~ param_dist.
print(grand_loss)
print(img_syn.grad.shape)
assert img_syn.grad.shape == img_syn.shape

tensor(0.9937, grad_fn=<DivBackward0>)
torch.Size([7, 8])


In [116]:
# Write the linear-model grand-loss autograd graph (including node attributes) to a PDF
# under img/ and open it, since the inline rendering can be cut off.
graph = make_dot(grand_loss, show_attrs=True)
graph.render(filename="img/linear_grand_loss", format="pdf", view=True)

'img/linear_grand_loss.pdf'