In [2]:
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


In [3]:
def symmetric(X):
    return X.triu() + X.triu(1).transpose(-1, -2)

X = torch.rand(3, 3)
A = symmetric(X)
assert torch.allclose(A, A.T)
print(A)


tensor([[0.6928, 0.6676, 0.3568],
        [0.6676, 0.9464, 0.4415],
        [0.3568, 0.4415, 0.1418]])


In [4]:
class LinearSymmetric(nn.Module):
    def __init__(self, n_features):
        super().__init__()
        self.weight = nn.Parameter(torch.rand(n_features, n_features))
    
    def forward(self, x):
        A = symmetric(self.weight)
        return x @ A

In [13]:
layer = LinearSymmetric(3)
out = layer(torch.rand(8,3))

In [14]:
class Symmetric(nn.Module):
    def forward(self, X):
        return X.triu() + X.triu(1).transpose(-1, -2)

In [15]:
layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Symmetric())

ParametrizedLinear(
  in_features=3, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Symmetric()
    )
  )
)

In [16]:
A = layer.weight
assert torch.allclose(A, A.T)
print(A)

tensor([[ 0.0837,  0.1981, -0.1527],
        [ 0.1981, -0.0275, -0.2829],
        [-0.1527, -0.2829,  0.4862]], grad_fn=<AddBackward0>)


In [17]:
class Skew(nn.Module):
    def forward(self, X):
        A = X.triu(1)
        return A - A.transpose(-1, -2)
    
cnn = nn.Conv2d(in_channels=5, out_channels=8, kernel_size=3)
parametrize.register_parametrization(cnn, "weight", Skew())

print(cnn.weight[0, 1])
print(cnn.weight[2, 2])


tensor([[ 0.0000, -0.1216,  0.1468],
        [ 0.1216,  0.0000, -0.1141],
        [-0.1468,  0.1141,  0.0000]], grad_fn=<SelectBackward0>)
tensor([[ 0.0000,  0.0427, -0.0385],
        [-0.0427,  0.0000, -0.0232],
        [ 0.0385,  0.0232,  0.0000]], grad_fn=<SelectBackward0>)


In [18]:
layer = nn.Linear(3,3)
print(f"Unparametrized: \n {layer}")
parametrize.register_parametrization(layer, "weight", Symmetric())
print(f"Parametrized: \n {layer}")


Unparametrized: 
 Linear(in_features=3, out_features=3, bias=True)
Parametrized: 
 ParametrizedLinear(
  in_features=3, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Symmetric()
    )
  )
)


In [19]:
print(layer.parametrizations)
print(layer.parametrizations.weight)

ModuleDict(
  (weight): ParametrizationList(
    (0): Symmetric()
  )
)
ParametrizationList(
  (0): Symmetric()
)


In [20]:
print(layer.parametrizations.weight[0])

Symmetric()


In [21]:
print(dict(layer.named_parameters()))

{'bias': Parameter containing:
tensor([-0.1309, -0.2774,  0.5645], requires_grad=True), 'parametrizations.weight.original': Parameter containing:
tensor([[-0.1720, -0.3462, -0.0612],
        [-0.1022, -0.3147, -0.2553],
        [-0.3584, -0.3606, -0.5419]], requires_grad=True)}


In [22]:
symmetric = Symmetric()
weight_orig = layer.parametrizations.weight.original
print(torch.dist(layer.weight, symmetric(weight_orig)))

tensor(0., grad_fn=<DistBackward0>)


In [23]:
class NoisyParametrization(nn.Module):
    def forward(self, X):
        print("Computing the Parametrization")
        return X

layer = nn.Linear(4, 4)
parametrize.register_parametrization(layer, "weight", NoisyParametrization())
print("Here, layer.weight is recomputed every time we call it")
foo = layer.weight + layer.weight.T
bar = layer.weight.sum()
with parametrize.cached():
    print("Here, it is computed just the first time layer.weight is called.")
    foo = layer.weight + layer.weight.T
    bar = layer.weight.sum()

Computing the Parametrization
Here, layer.weight is recomputed every time we call it
Computing the Parametrization
Computing the Parametrization
Computing the Parametrization
Here, it is computed just the first time layer.weight is called.
Computing the Parametrization


In [25]:
class CayleyMap(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.register_buffer("Id", torch.eye(n))
    
    def forward(self, X):
        # (I + X)(I - X)^{-1}
        return torch.linalg.solve(self.Id - X, self.Id + X)

layer = nn.Linear(3,3)
parametrize.register_parametrization(layer, "weight", Skew())
parametrize.register_parametrization(layer, "weight", CayleyMap(3))

X = layer.weight
print(torch.dist(X.T @ X, torch.eye(3)))


tensor(1.6859e-07, grad_fn=<DistBackward0>)


In [27]:
class MatrixExponential(nn.Module):
    def forward(self, X):
        return torch.matrix_exp(X)

layer_orthogonal = nn.Linear(3, 3)
parametrize.register_parametrization(layer_orthogonal, "weight", Skew())
parametrize.register_parametrization(layer_orthogonal, "weight", MatrixExponential())
X = layer_orthogonal.weight
print(torch.dist(X.T @ X, torch.eye(3)))         # X is orthogonal

layer_spd = nn.Linear(3, 3)
parametrize.register_parametrization(layer_spd, "weight", Symmetric())
parametrize.register_parametrization(layer_spd, "weight", MatrixExponential())
X = layer_spd.weight
print(torch.dist(X, X.T))                        # X is symmetric
print((torch.linalg.eigvalsh(X) > 0.).all())  # X is positive definite

tensor(1.9158e-07, grad_fn=<DistBackward0>)
tensor(5.2684e-09, grad_fn=<DistBackward0>)
tensor(True)


In [43]:
class Skew(nn.Module):
    def forward(self, X):
        A = X.triu(1)
        return A - A.transpose(-1, -2)

    def right_inverse(self, A):
        # We assume that A is skew-symmetric
        # We take the upper-triangular elements, as these are those used in the forward function
        return A.triu(1)

In [44]:
layer = nn.Linear(3, 3)
parametrize.register_parametrization(layer, "weight", Skew())
X = torch.rand(3, 3)
X = X - X.T                             # X is now skew-symmetric
layer.weight = X                        # Initialize layer.weight to be X
print(torch.dist(layer.weight, X))      # layer.weight == X

tensor(0., grad_fn=<DistBackward0>)


In [46]:
class CayleyMap(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.register_buffer("Id", torch.eye(n))

    def forward(self, X):
        # Assume X skew-symmetric
        # (I + X)(I - X)^{-1}
        return torch.linalg.solve(self.Id - X, self.Id + X)

    def right_inverse(self, A):
        # Assume A orthogonal
        # See https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
        # (A - I)(A + I)^{-1}
        return torch.linalg.solve(A + self.Id, self.Id - A)

layer_orthogonal = nn.Linear(3, 3)
parametrize.register_parametrization(layer_orthogonal, "weight", Skew())
parametrize.register_parametrization(layer_orthogonal, "weight", CayleyMap(3))
# Sample an orthogonal matrix with positive determinant
X = torch.empty(3, 3)
nn.init.orthogonal_(X)
if X.det() < 0.:
    X[0].neg_()
layer_orthogonal.weight = X
print(torch.dist(layer_orthogonal.weight, X))  # layer_orthogonal.weight == X

tensor(1.9122, grad_fn=<DistBackward0>)


In [47]:
class PruningParametrization(nn.Module):
    def __init__(self, X, p_drop=0.2):
        super().__init__()
        # sample zeros with probability p_drop
        mask = torch.full_like(X, 1.0 - p_drop)
        self.mask = torch.bernoulli(mask)

    def forward(self, X):
        return X * self.mask

    def right_inverse(self, A):
        return A

In [55]:
layer = nn.Linear(3, 4)
X = torch.rand_like(layer.weight)
print(f"Initialization matrix:\n{X}")
parametrize.register_parametrization(layer, "weight", PruningParametrization(layer.weight))
layer.weight = X
print(f"\nInitialized weight:\n{layer.weight}")

Initialization matrix:
tensor([[0.9677, 0.1552, 0.1022],
        [0.6760, 0.2348, 0.3606],
        [0.8375, 0.4543, 0.1959],
        [0.0444, 0.1885, 0.0218]])

Initialized weight:
tensor([[0.0000, 0.1552, 0.0000],
        [0.0000, 0.0000, 0.3606],
        [0.8375, 0.0000, 0.1959],
        [0.0444, 0.0000, 0.0218]], grad_fn=<MulBackward0>)


In [56]:
layer = nn.Linear(3, 3)
print("Before:")
print(layer)
print(layer.weight)
parametrize.register_parametrization(layer, "weight", Skew())
print("\nParametrized:")
print(layer)
print(layer.weight)
parametrize.remove_parametrizations(layer, "weight")
print("\nAfter. Weight has skew-symmetric values but it is unconstrained:")
print(layer)
print(layer.weight)

Before:
Linear(in_features=3, out_features=3, bias=True)
Parameter containing:
tensor([[-0.5208, -0.1568, -0.0597],
        [ 0.4842, -0.5202, -0.2884],
        [-0.5494, -0.4066, -0.5236]], requires_grad=True)

Parametrized:
ParametrizedLinear(
  in_features=3, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Skew()
    )
  )
)
tensor([[ 0.0000, -0.1568, -0.0597],
        [ 0.1568,  0.0000, -0.2884],
        [ 0.0597,  0.2884,  0.0000]], grad_fn=<SubBackward0>)

After. Weight has skew-symmetric values but it is unconstrained:
Linear(in_features=3, out_features=3, bias=True)
Parameter containing:
tensor([[ 0.0000, -0.1568, -0.0597],
        [ 0.1568,  0.0000, -0.2884],
        [ 0.0597,  0.2884,  0.0000]], requires_grad=True)


In [57]:
layer = nn.Linear(3, 3)
print("Before:")
print(layer)
print(layer.weight)
parametrize.register_parametrization(layer, "weight", Skew())
print("\nParametrized:")
print(layer)
print(layer.weight)
parametrize.remove_parametrizations(layer, "weight", leave_parametrized=False)
print("\nAfter. Same as Before:")
print(layer)
print(layer.weight)

Before:
Linear(in_features=3, out_features=3, bias=True)
Parameter containing:
tensor([[-0.2894, -0.0838, -0.0855],
        [ 0.4133,  0.3882, -0.1892],
        [-0.1343,  0.3262,  0.3094]], requires_grad=True)

Parametrized:
ParametrizedLinear(
  in_features=3, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Skew()
    )
  )
)
tensor([[ 0.0000, -0.0838, -0.0855],
        [ 0.0838,  0.0000, -0.1892],
        [ 0.0855,  0.1892,  0.0000]], grad_fn=<SubBackward0>)

After. Same as Before:
Linear(in_features=3, out_features=3, bias=True)
Parameter containing:
tensor([[ 0.0000, -0.0838, -0.0855],
        [ 0.0000,  0.0000, -0.1892],
        [ 0.0000,  0.0000,  0.0000]], requires_grad=True)
