## A (silly?) way to get mutable value-only views of module parameters

This is a different approach to the one discussed at
https://discuss.pytorch.org/t/how-to-flatten-and-then-unflatten-all-model-parameters/34730

Rather than copying the parameters into a flat vector and then requiring a copy back
into the module, this copies all parameters into a flat vector once and then replaces
each parameter with a view into that vector.

The *value* of each parameter is shared with its slice of the vector: writing to one is visible in the other.

However, the *gradient* of the view will **not** be connected to each parameter (see below).

In [1]:
import dataclasses as dc
from functools import partial
from pprint import pprint
from textwrap import indent
from typing import Callable

import torch
from torch import nn

In [2]:
@dc.dataclass
class ParamView:
    """
    Accessor bundle for a single named attribute (parameter slot) of a module.

    Holds the element count of the parameter together with callables that
    read and overwrite the attribute on its owning module.
    """

    # Element count reported by the parameter when the view was made.
    numel: int
    # Zero-argument callable: returns the current value of the attribute.
    get: Callable
    # One-argument callable: assigns a new value to the attribute.
    set: Callable

    @staticmethod
    def make(module, name, param):
        """
        Build a ``ParamView`` for attribute ``name`` of ``module``.

        ``param`` is only used for its ``numel()``; reads and writes always
        go through ``getattr`` / ``setattr`` on ``module``.
        """
        size = param.numel()
        reader = partial(getattr, module, name)
        writer = partial(setattr, module, name)
        return ParamView(size, reader, writer)

def replace_parameters_with_view(module, predicate=lambda p: p.requires_grad):
    """
    Replaces all parameters (filtered by ``predicate``) from ``module`` with
    a view into a larger vector of all parameters, copying data.

    The flat vector is allocated on the device of the first selected
    parameter, and with its dtype when that dtype can carry gradients
    (floating point or complex). This keeps e.g. a float64 or non-CPU module
    from being silently converted to the ``torch.zeros`` defaults. Selected
    parameters whose dtype or device differs from the vector's are converted
    by the copy into it. When nothing is selected, an empty vector with the
    ``torch.zeros`` defaults is returned.

    Args:
        module: Module whose parameters (its own and those of all
            descendants) are replaced in place.
        predicate: Called with each parameter; only parameters for which it
            returns a true value are replaced.

    Returns:
        The flat 1-D tensor (``requires_grad=True``) holding the values of
        every replaced parameter, in traversal order: a module's own
        parameters first, then each child recursively.

    Raises:
        AssertionError: If a selected parameter already has a gradient.
    """
    views = []

    def recurse(sub):
        # First add immediate parameters.
        for name, p in sub.named_parameters(recurse=False):
            if predicate(p):
                assert p.grad is None, (
                    f"parameter {name!r} already has a gradient"
                )
                views.append(ParamView.make(sub, name, p))
        # Then descend into the immediate children.
        for child in sub.children():
            recurse(child)

    # Recurse manually (rather than using named_parameters(recurse=True)) so
    # that each view knows the submodule object that owns the attribute.
    recurse(module)

    numel = sum(view.numel for view in views)

    # Match the flat vector to the parameters instead of always using the
    # default dtype on the default device.
    factory_kwargs = {}
    if views:
        first = views[0].get()
        factory_kwargs["device"] = first.device
        # Only floating point / complex tensors can require gradients; for
        # anything else keep the default dtype, as before.
        if first.is_floating_point() or first.is_complex():
            factory_kwargs["dtype"] = first.dtype
    p_full = torch.zeros(numel, requires_grad=True, **factory_kwargs)

    offset = 0
    for view in views:
        # Get a sliced view.
        p_view = p_full[offset:offset + view.numel]
        # Copy data into full. no_grad is needed because this writes in place
        # into a tensor that requires grad.
        p = view.get()
        p_shape = p.shape
        with torch.no_grad():
            p_view[:] = p.reshape(-1)
        # Replace parameter with (mutable) view.
        view.set(nn.Parameter(p_view.reshape(p_shape)))
        offset += view.numel

    return p_full

In [3]:
def print_param(module, extract=lambda p: p.data):
    """
    Print every named parameter of ``module`` as a ``name:`` header followed
    by the formatted result of ``extract(parameter)``, indented two spaces.
    """
    for label, tensor in module.named_parameters():
        # format() (not str()) keeps the exact rendering an f-string gives.
        rendered = format(extract(tensor))
        print(label + ":")
        print(indent(rendered, "  "))

In [4]:
class Test(nn.Module):
    """Tiny two-layer module (one linear, one 1-D conv, both bias-free) used by the demo."""

    def __init__(self):
        super().__init__()
        # Creation order matters: it fixes both the parameter order and the
        # random initial values obtained under a given seed.
        self.fc = nn.Linear(in_features=1, out_features=2, bias=False)
        self.conv = nn.Conv1d(
            in_channels=1, out_channels=1, kernel_size=1, bias=False
        )

In [5]:
# Build the network with a fixed seed so the initial values are reproducible.
torch.manual_seed(0)
net = Test()
# Swap the network's parameters for views into one flat vector.
p = replace_parameters_with_view(net)

print("Initial values (seem about right as random init)")
print_param(net)
print()
print(p)

# Writing through the flat vector should show up in every parameter.
print("\nChange full vector to something silly (but unique)")
with torch.no_grad():
    p.copy_(torch.arange(p.numel()))

print(p)
print()
print_param(net)

# Writing through a single parameter should show up in the flat vector.
print("\nChange one param and show that it reflects in full vector")
with torch.no_grad():
    net.fc.weight.fill_(3.0)
print_param(net)
print()
print(p)

Initial values (seem about right as random init)
fc.weight:
  tensor([[-0.0075],
          [ 0.5364]])
conv.weight:
  tensor([[[-0.8230]]])

tensor([-0.0075,  0.5364, -0.8230], requires_grad=True)

Change full vector to something silly (but unique)
tensor([0., 1., 2.], requires_grad=True)

fc.weight:
  tensor([[0.],
          [1.]])
conv.weight:
  tensor([[[2.]]])

Change one param and show that it reflects in full vector
fc.weight:
  tensor([[3.],
          [3.]])
conv.weight:
  tensor([[[2.]]])

tensor([3., 3., 2.], requires_grad=True)


In [6]:
# WARNING: gradients do *not* flow between the flat vector and the module
# parameters, in either direction. The underlying limitation is that a
# module's parameters cannot be non-leaf tensors; see:
# https://pytorch.org/docs/1.8.1/_modules/torch/nn/modules/module.html#Module.register_parameter
# https://discuss.pytorch.org/t/non-leaf-variables-as-a-modules-parameters/65775

# Start clean: drop any gradient held by the flat vector or by a parameter.
p.grad = None
for param in net.parameters():
    param.grad = None

print("From p -> param: no grad connection")
loss = (p ** 2).sum()
loss.backward()
# Only the flat vector receives a gradient here; the parameters get none.
print(p.grad)
print()
print_param(net, lambda q: q.grad)

# Reset the flat vector's gradient before going the other way.
print()
print("From param -> p: also no grad connection")
p.grad = None
loss = (net.fc.weight ** 2).sum() + (net.conv.weight ** 2).sum()
loss.backward()
print(p.grad)
print()
print_param(net, lambda q: q.grad)

From p -> param: no grad connection
tensor([6., 6., 4.])

fc.weight:
  None
conv.weight:
  None

From param -> p: also no grad connection
None

fc.weight:
  tensor([[6.],
          [6.]])
conv.weight:
  tensor([[[4.]]])
