In [54]:
import torch
import torch.nn as nn
from torch.nn.utils import vector_to_parameters, parameters_to_vector

In [107]:
from models import MLPModel

In [115]:
def zero_grad(model):
    for param in model.parameters():
        param.grad = None

def view_grad(model):
    grads = []
    for param in model.parameters():
        grads.append(param.grad.view(-1))
    grads = torch.cat(grads)
    print(grads)

In [116]:
def flip_parameters_to_tensors(module):
    attr = []
    while bool(module._parameters):
        attr.append( module._parameters.popitem() )
    setattr(module, 'registered_parameters_name', [])

    for i in attr:
        setattr(module, i[0], torch.zeros(i[1].shape,requires_grad=True))
        module.registered_parameters_name.append(i[0])

    module_name = [k for k,v in module._modules.items()]

    for name in module_name:
        flip_parameters_to_tensors(module._modules[name])

In [130]:
def set_all_parameters(module, theta):
    count = 0  

    for name in module.registered_parameters_name:
        a = count
        b = a + getattr(module, name).numel()
        t = torch.reshape(theta[a:b], getattr(module, name).shape)
        setattr(module, name, t)

        count += getattr(module, name).numel()

    module_name = [k for k,v in module._modules.items()]
    for name in module_name:
        count += set_all_parameters(module._modules[name], theta)
    return count

In [131]:
model = MLPModel(2, 2, 4, 1)
zero_grad(model)

In [132]:
data = torch.randn(3,2)
data

tensor([[ 0.1628,  1.3552],
        [-0.2975, -0.7004],
        [-1.5971, -1.1088]])

In [133]:
nparam = len(parameters_to_vector(model.parameters()))
w = torch.randn(nparam, requires_grad=True)
print(w.shape)

torch.Size([57])


In [134]:
vector_to_parameters(w, model.parameters())

In [135]:
zero_grad(model)
output = torch.sum( model(data) )
output.backward()

In [136]:
view_grad(model)

tensor([ 0.0000e+00,  0.0000e+00,  6.0401e-01,  5.7570e-01,  3.0632e-01,
         2.9196e-01, -2.7470e-04, -2.2868e-03,  0.0000e+00, -6.4582e-01,
        -3.2753e-01, -1.6874e-03,  0.0000e+00, -4.2719e-01, -2.0477e+00,
        -4.2029e-03,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00, -8.6706e-01,  0.0000e+00,
         0.0000e+00,  0.0000e+00, -7.9182e-01,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         1.7573e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00, -4.9329e-01,  0.0000e+00,
         1.5833e+00,  0.0000e+00,  4.5765e-01,  0.0000e+00,  4.9648e-01,
         0.0000e+00,  1.3579e+00])


In [137]:
w.grad # parameters_to_vector does not keep track of gradients

In [138]:
l = 0
for param in model.parameters():
    nl = param.numel()
    param = w[l:l+nl].reshape(param.shape)
    l += nl

In [139]:
zero_grad(model)
output = torch.sum( model(data) )
output.backward()

In [140]:
view_grad(model)

tensor([ 0.0000e+00,  0.0000e+00,  6.0401e-01,  5.7570e-01,  3.0632e-01,
         2.9196e-01, -2.7470e-04, -2.2868e-03,  0.0000e+00, -6.4582e-01,
        -3.2753e-01, -1.6874e-03,  0.0000e+00, -4.2719e-01, -2.0477e+00,
        -4.2029e-03,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00, -8.6706e-01,  0.0000e+00,
         0.0000e+00,  0.0000e+00, -7.9182e-01,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         1.7573e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00, -4.9329e-01,  0.0000e+00,
         1.5833e+00,  0.0000e+00,  4.5765e-01,  0.0000e+00,  4.9648e-01,
         0.0000e+00,  1.3579e+00])


In [141]:
w.grad # simple slicing and assigning does not keep track of gradients

In [142]:
flip_parameters_to_tensors(model)
set_all_parameters(model, w)

57

In [143]:
w.grad

In [144]:
output = torch.sum( model(data) )
output.backward()

In [145]:
w.grad

tensor([ 0.8476,  2.5635, -0.2921, -1.0884, -0.5252, -0.1708, -0.1968, -1.0864,
        -0.6340, -0.2164, -0.1360, -0.1719, -0.5305, -0.3840, -0.2965, -0.4678,
        -1.6629,  0.0000,  0.0000, -0.3666,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
         0.0000])