In [85]:
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

# Assume that we want to have a square linear layer with symmetric weights, that is, with weights X such that X = Xᵀ.
# One way to do so is to copy the upper-triangular part of the matrix into its lower-triangular part.


# Implementing parametrizations by hand
def symmetric(X):
    """Build a symmetric matrix from the upper triangle of ``X``.

    The upper triangle (diagonal included) is kept as-is, and the strictly
    upper part is mirrored below the diagonal. The operation acts on the
    last two dimensions, so leading batch dimensions are carried through.

    Args:
        X: tensor whose last two dimensions form a square matrix.

    Returns:
        A tensor of the same shape as ``X`` that equals its own transpose
        over the last two dimensions.
    """
    upper_with_diag = X.triu()
    # Offset 1 drops the diagonal so it is not counted twice in the sum.
    mirrored_lower = X.triu(1).transpose(-1, -2)
    return upper_with_diag + mirrored_lower


# Sanity check: symmetrising a random 3x3 matrix must yield a matrix equal to its transpose.
X = torch.rand(3, 3)
A = symmetric(X)
assert torch.allclose(A, A.T)  # A is symmetric
print(A)


# Implement "a linear layer with symmetric weights"
class LinearSymmetric(nn.Module):
    """Square, bias-free linear layer whose effective weight is symmetric.

    The stored parameter is an unconstrained square matrix; the symmetric
    matrix actually applied to the input is derived from it with
    ``symmetric`` on every forward call.

    Args:
        n_features: size of both the input and the output feature dimension.
    """

    def __init__(self, n_features):
        super().__init__()
        # Unconstrained square parameter, initialised uniformly in [0, 1).
        initial_weight = torch.rand(n_features, n_features)
        self.weight = nn.Parameter(initial_weight)

    def forward(self, x):
        """Right-multiply ``x`` by the symmetrised weight."""
        # The symmetric matrix is rebuilt from the raw parameter on each call.
        return x @ symmetric(self.weight)


# Smoke test: push a batch of 8 samples with 3 features each through the hand-written layer.
layer = LinearSymmetric(n_features=3)
out = layer(torch.rand(8, 3))
print(out.shape)  # torch.Size([8, 3]) in the recorded output

# Problems with this approach:

# 1. It does not separate the layer and the parametrization. If the parametrization were more difficult, we would have to rewrite its code for each layer that we want to use it in.
# 2. It recomputes the parametrization every time we use the layer.
# If we use the layer several times during the forward pass, (imagine the recurrent kernel of an RNN), it would compute the same A every time that the layer is called.

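# Parametrizations can solve all of these problems, as well as others.
# (For problem 2, note that a registered parametrization is still re-evaluated on every access by default;
# wrap the forward pass in `with parametrize.cached():` to compute the weight once and reuse it.)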


class Symmetric(nn.Module):
    """Parametrization module that maps a square matrix to a symmetric one.

    Intended to be registered on a weight with
    ``parametrize.register_parametrization``; ``forward`` receives the
    underlying tensor and returns the symmetric matrix built from its
    upper triangle.
    """

    def forward(self, x):
        """Return ``x``'s upper triangle mirrored onto its lower triangle."""
        # Tutorial tracing: show what kind of object the parametrization receives.
        print(f"Type of x: {type(x)} = nn.Parameter")
        print("x equals as `source for making symmtric matrix` below")
        upper_with_diag = x.triu()
        # Offset 1 leaves the diagonal out so it is only counted once.
        mirrored_lower = x.triu(1).transpose(-1, -2)
        return upper_with_diag + mirrored_lower


layer = nn.Linear(3, 3)
# Register Symmetric() as a parametrization of `layer.weight`.
# The layer's original 3x3 weight becomes the input of Symmetric.forward(), and from now on
# reading `layer.weight` returns the result of that call, conceptually:
#     layer.weight == Symmetric()(original_weight)
# Per the trace prints in the recorded output, forward() already runs once during registration.
parametrize.register_parametrization(layer, "weight", Symmetric())
print(layer)

# Accessing the attribute evaluates the parametrization, so the result must be symmetric.
A = layer.weight
assert torch.allclose(A, A.T)
print(A)


# Similar case example (Skew)
# We can do the same thing with any other layer.
# For example, we can create a CNN with skew-symmetric kernels.
# We use a similar parametrization, copying the upper-triangular part with signs reversed into the lower-triangular part.
class Skew(nn.Module):
    """Parametrization module that maps a square matrix to a skew-symmetric one.

    Only the strictly upper triangle of the input is used; it is copied
    below the diagonal with its sign flipped, which leaves a zero diagonal.
    The operation acts on the last two dimensions of the input.
    """

    def forward(self, X):
        """Return ``U - Uᵀ`` where ``U`` is the strictly upper triangle of ``X``."""
        strict_upper = X.triu(diagonal=1)
        negated_mirror = strict_upper.transpose(-1, -2)
        return strict_upper - negated_mirror


# Skew acts on the last two dimensions of the weight, i.e. on each 3x3 spatial kernel of the convolution.
cnn = nn.Conv2d(in_channels=5, out_channels=8, kernel_size=3)
parametrize.register_parametrization(cnn, "weight", Skew())

# Print a few kernels: each has a zero diagonal and sign-flipped off-diagonal pairs.
print(cnn.weight[0, 1])
print(cnn.weight[2, 2])

tensor([[0.2041, 0.5863, 0.9884],
        [0.5863, 0.8330, 0.8776],
        [0.9884, 0.8776, 0.3282]])
torch.Size([8, 3])
Type of x: <class 'torch.nn.parameter.Parameter'> = nn.Parameter
x equals as `source for making symmtric matrix` below
ParametrizedLinear(
  in_features=3, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Symmetric()
    )
  )
)
Type of x: <class 'torch.nn.parameter.Parameter'> = nn.Parameter
x equals as `source for making symmtric matrix` below
tensor([[-0.0467,  0.2155, -0.4387],
        [ 0.2155,  0.1236, -0.3720],
        [-0.4387, -0.3720, -0.5221]], grad_fn=<AddBackward0>)
tensor([[ 0.0000,  0.1193, -0.0805],
        [-0.1193,  0.0000, -0.1297],
        [ 0.0805,  0.1297,  0.0000]], grad_fn=<SelectBackward0>)
tensor([[ 0.0000, -0.0061, -0.0668],
        [ 0.0061,  0.0000,  0.0471],
        [ 0.0668, -0.0471,  0.0000]], grad_fn=<SelectBackward0>)
