## torch.nn.utils.parametrize

In [1]:
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

# Assume that we want to have a square linear layer with symmetric weights, that is, with weights X such that X = Xᵀ.
# One way to do so is to copy the upper-triangular part of the matrix into its lower-triangular part.


# Implementing parametrizations by hand
def symmetric(X):
    """Build a symmetric matrix from the upper triangle of ``X``.

    The upper triangle (diagonal included) is kept as-is, and the strictly
    upper triangle is transposed into the lower half. This works on the last
    two dimensions, so leading batch dimensions are carried through.

    Args:
        X: tensor whose last two dimensions form a square matrix.

    Returns:
        Tensor of the same shape as ``X`` that equals its own transpose.
    """
    upper = torch.triu(X)
    mirrored_lower = torch.triu(X, diagonal=1).transpose(-2, -1)
    return upper + mirrored_lower


# Seed the global RNG so that Restart & Run All reproduces the random
# values printed by this cell (input matrix, layer initialisations, batches).
torch.manual_seed(0)

X = torch.rand(3, 3)
A = symmetric(X)
assert torch.allclose(A, A.T)  # A is symmetric
print(A)


# Implement "a linear layer with symmetric weights"
class LinearSymmetric(nn.Module):
    """Square linear layer ``x -> x @ S`` with a hand-rolled symmetric weight ``S``.

    The raw, unconstrained matrix is stored as ``self.weight``. ``forward``
    rebuilds the symmetric matrix from it with ``symmetric`` on every call.
    No bias term is used.

    Args:
        n_features: size of both the input and the output feature dimension.
    """

    def __init__(self, n_features):
        super().__init__()
        raw = torch.rand(n_features, n_features)
        self.weight = nn.Parameter(raw)

    def forward(self, x):
        # The symmetric matrix is recomputed here each time the layer is called.
        return torch.matmul(x, symmetric(self.weight))


# Run the hand-written symmetric layer on a batch of 8 random 3-feature inputs.
layer = LinearSymmetric(n_features=3)
out = layer(torch.rand(8, 3))
print(out.size())  # expected: torch.Size([8, 3])

# Problems with this hand-written approach:

# 1. It does not separate the layer from the parametrization. If the parametrization were more complex, we would have to rewrite its code for every layer that uses it.
# 2. It recomputes the parametrization every time we use the layer.
#    If we use the layer several times during the forward pass (imagine the recurrent kernel of an RNN), it computes the same A every time the layer is called.

# Parametrizations (torch.nn.utils.parametrize) solve both of these problems, among others.


class Symmetric(nn.Module):
    """Parametrization that turns a square matrix into a symmetric one.

    The upper triangle (diagonal included) of the input is kept, and the
    strictly upper triangle is transposed into the lower half, so the output
    equals its own transpose over the last two dimensions.

    The two ``print`` calls exist only so the notebook output shows when the
    parametrization is evaluated and what kind of object it receives.
    """

    def forward(self, x):
        print(f"Type of x: {type(x)} = nn.Parameter")
        print("x equals as `source for making symmtric matrix` below")
        upper = torch.triu(x)
        mirrored_lower = torch.triu(x, diagonal=1).transpose(-2, -1)
        return upper + mirrored_lower


layer = nn.Linear(3, 3)
# Attach Symmetric() to `layer.weight`. From now on, reading `layer.weight`
# passes the underlying tensor through Symmetric.forward, so the module behaves
# conceptually as if `layer.weight = Symmetric()(<underlying weight>)`.
# Registration itself already evaluates the parametrization once (see the first
# pair of "Type of x" lines in the output), and the printed module becomes a
# ParametrizedLinear that holds the Symmetric() in a `parametrizations` dict.
parametrize.register_parametrization(layer, "weight", Symmetric())
print(layer)

# This attribute read runs Symmetric.forward again; the result is symmetric.
A = layer.weight
assert torch.allclose(A, A.t())
print(A)


# Similar case example (Skew)
# We can do the same thing with any other layer.
# For example, we can create a CNN with skew-symmetric kernels.
# We use a similar parametrization, copying the upper-triangular part with signs reversed into the lower-triangular part.
class Skew(nn.Module):
    """Parametrization that turns a square matrix into a skew-symmetric one.

    Only the strictly upper triangle of the input is used. It is copied with
    its sign flipped into the lower half, so the diagonal is zero and the
    output equals the negation of its transpose over the last two dimensions.
    """

    def forward(self, X):
        strictly_upper = torch.triu(X, diagonal=1)
        return strictly_upper - strictly_upper.transpose(-2, -1)


# Same idea on a convolution: every read of `cnn.weight` now goes through Skew.
cnn = nn.Conv2d(in_channels=5, out_channels=8, kernel_size=3)
parametrize.register_parametrization(cnn, "weight", Skew())

# Show two of the 3x3 kernels: zero diagonal, entries mirrored with flipped sign.
print(cnn.weight[0][1])
print(cnn.weight[2][2])

tensor([[0.5177, 0.4886, 0.2425],
        [0.4886, 0.3903, 0.7523],
        [0.2425, 0.7523, 0.3229]])
torch.Size([8, 3])
Type of x: <class 'torch.nn.parameter.Parameter'> = nn.Parameter
x equals as `source for making symmtric matrix` below
ParametrizedLinear(
  in_features=3, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Symmetric()
    )
  )
)
Type of x: <class 'torch.nn.parameter.Parameter'> = nn.Parameter
x equals as `source for making symmtric matrix` below
tensor([[-0.2961,  0.5623,  0.2940],
        [ 0.5623, -0.3487,  0.3402],
        [ 0.2940,  0.3402,  0.1318]], grad_fn=<AddBackward0>)
tensor([[ 0.0000,  0.0506,  0.0236],
        [-0.0506,  0.0000,  0.1251],
        [-0.0236, -0.1251,  0.0000]], grad_fn=<SelectBackward0>)
tensor([[ 0.0000,  0.1292,  0.0750],
        [-0.1292,  0.0000, -0.1343],
        [-0.0750,  0.1343,  0.0000]], grad_fn=<SelectBackward0>)


In [42]:
# Check whether `layer` is parametrized by looking for its `parametrizations` container.
parametrizations = getattr(layer, "parametrizations", None)
if parametrizations is None:
    # getattr's default signals "not parametrized"; calling len(None) would raise TypeError.
    print("`layer` has no parametrizations")
else:
    print(type(parametrizations), len(parametrizations))
    print(parametrizations)

# The same checks via the helper functions of torch.nn.utils.parametrize
from torch.nn.utils.parametrize import is_parametrized, type_before_parametrizations

print(is_parametrized(layer))
# The class of a parametrized module is swapped (e.g. ParametrizedLinear);
# type_before_parametrizations recovers the class it had before registration.
print(type(layer), type_before_parametrizations(layer))

<class 'torch.nn.modules.container.ModuleDict'> 1
ModuleDict(
  (weight): ParametrizationList(
    (0): Symmetric()
  )
)
True
<class 'torch.nn.utils.parametrize.ParametrizedLinear'> <class 'torch.nn.modules.linear.Linear'>


## torch.ao.quantization.fuse_modules

In [7]:
from torchvision.models import resnet18

# ResNet-18 built from scratch: weights=None loads no pretrained checkpoint,
# and num_classes=10 sets the size of the classifier output.
model = resnet18(weights=None, num_classes=10)

In [63]:
# Modules to fuse: conv1, bn1, relu

# Function: get_submodule -- look up a child module by its (dotted) name
print("--------- layer 1 ---------")
print(model.get_submodule("layer1"))

print("--------- layer 1 > first module > conv1 ---------")
print(model.get_submodule("layer1")[0].get_submodule("conv1"))

# Function: named_modules -- yields (qualified_name, module) pairs
named_modules = list(model.named_modules())
# The first entry is ("", model), i.e. the whole model itself;
# the first real submodule therefore sits at index 1.
print("--------- named_modules ---------")
print(named_modules[1][0], named_modules[1][1])

# Function: get the attribute directly
print("--------- named_modules vs. __getattr__---------")
# Identity check: both lookups should return the very same module object.
print(getattr(model, "conv1") is named_modules[1][1])


# Function: fuse modules (the steps below mirror what fuse_known_modules does)
print("--------- Fuse modules ---------")
from typing import List, Optional
from torch.ao.quantization.fuse_modules import fuse_known_modules
from torch.ao.quantization.fuser_method_mappings import get_fuser_method

# Non-QAT fusion (is_qat=False below) is performed on a model in eval mode.
model.eval()

modules_to_fuse = ["conv1", "bn1", "relu"]
mod_list = [model.get_submodule(mod) for mod in modules_to_fuse]
# The fuser is looked up by the modules' original (pre-parametrization) types.
types = tuple(type_before_parametrizations(m) for m in mod_list)
print(types)
fuser_method = get_fuser_method(types)

is_qat = False
new_mod: List[Optional[nn.Module]] = [None] * len(mod_list)
fused = fuser_method(is_qat, *mod_list)
print(fused)

# NOTE: forward hooks not processed in the two following for loops will be lost after the fusion.
# Each hook dict is iterated first and cleared afterwards. Deleting entries while
# iterating over `.items()` raises "RuntimeError: dictionary changed size during
# iteration" as soon as a module actually has a hook registered.
# Move pre forward hooks of the base module to resulting fused module
for pre_hook_fn in mod_list[0]._forward_pre_hooks.values():
    fused.register_forward_pre_hook(pre_hook_fn)
mod_list[0]._forward_pre_hooks.clear()

# Move post forward hooks of the last module to resulting fused module
for hook_fn in mod_list[-1]._forward_hooks.values():
    fused.register_forward_hook(hook_fn)
mod_list[-1]._forward_hooks.clear()

# The first element in the output module list performs the fused operation.
# The rest of the elements are set to nn.Identity(), with the same training
# flag as the first module so the replacement list is consistent.
new_mod[0] = fused
for i in range(1, len(mod_list)):
    identity = nn.Identity()
    identity.training = mod_list[0].training
    new_mod[i] = identity
print(new_mod)

--------- layer 1 ---------
Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
--------- layer 1 > first module > conv1 ---------
Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=