<a href="https://colab.research.google.com/github/01PrathamS/LoRA-implementation/blob/main/parametrizations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

def symmetric(X):
  """Return a symmetric matrix built from the upper triangle of ``X``.

  The diagonal and the entries above it are kept as they are; the strictly
  upper part is mirrored below the diagonal. The transpose is taken over the
  last two dimensions of ``X``.
  """
  upper_with_diagonal = X.triu()
  # Strictly-upper entries, flipped across the diagonal to fill the lower part.
  mirrored_lower = X.triu(1).transpose(-1, -2)
  return upper_with_diagonal + mirrored_lower

# Sanity check: a random square matrix pushed through `symmetric` must equal
# its own transpose.
X = torch.rand(3, 3)
A = symmetric(X)
assert torch.allclose(A, A.transpose(0, 1))
print(A)

tensor([[0.3591, 0.4300, 0.2926],
        [0.4300, 0.3828, 0.5075],
        [0.2926, 0.5075, 0.7881]])


In [4]:
class LinearSymmetric(nn.Module):
  """Bias-free square linear map whose effective weight is symmetric.

  The stored parameter is an unconstrained ``n_features x n_features`` matrix;
  the symmetric matrix is rebuilt from it on every forward pass.
  """

  def __init__(self, n_features):
    super().__init__()
    initial_weight = torch.rand(n_features, n_features)
    self.weight = nn.Parameter(initial_weight)

  def forward(self, x):
    """Right-multiply ``x`` by the symmetric version of ``self.weight``."""
    symmetric_weight = symmetric(self.weight)
    return torch.matmul(x, symmetric_weight)

# Smoke test: run a batch of 8 three-feature rows through the hand-written layer.
layer = LinearSymmetric(n_features=3)
out = layer.forward(torch.rand(8, 3)) if False else layer(torch.rand(8, 3))

In [5]:
class Symmetric(nn.Module):
  """Parametrization module mapping a square matrix to a symmetric one.

  The diagonal and upper triangle of the input are kept; the strictly upper
  triangle is mirrored below the diagonal (over the last two dimensions).
  """

  def forward(self, X):
    upper_with_diagonal = X.triu()
    mirrored_lower = X.triu(1).transpose(-1, -2)
    return upper_with_diagonal + mirrored_lower

In [6]:
# Turn the weight of a stock linear layer into a symmetric matrix by
# registering the parametrization on it. The call is left as the last
# expression so the notebook displays the resulting module.
layer = nn.Linear(in_features=3, out_features=3)
parametrize.register_parametrization(
  module=layer,
  tensor_name="weight",
  parametrization=Symmetric(),
)

ParametrizedLinear(
  in_features=3, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Symmetric()
    )
  )
)

In [7]:
# `layer.weight` is now produced by the Symmetric parametrization, so it
# should equal its own transpose.
A = layer.weight
assert torch.allclose(A, A.transpose(0, 1))
print(A)

tensor([[ 0.5212, -0.1496,  0.2497],
        [-0.1496, -0.0876, -0.1411],
        [ 0.2497, -0.1411,  0.3449]], grad_fn=<AddBackward0>)


In [9]:
class Skew(nn.Module):
  """Parametrization module mapping a square matrix to a skew-symmetric one.

  Only the strictly upper triangle of the input is used: it is kept above the
  diagonal and negated below it, leaving zeros on the diagonal. The transpose
  is taken over the last two dimensions.
  """

  def forward(self, X):
    strict_upper = X.triu(diagonal=1)
    strict_lower = strict_upper.transpose(-1, -2)
    return strict_upper - strict_lower

# Parametrizations are not limited to linear layers: register Skew on a
# convolution weight and print two of its 3x3 kernels.
cnn = nn.Conv2d(in_channels=5, out_channels=8, kernel_size=3)
parametrize.register_parametrization(
  module=cnn,
  tensor_name="weight",
  parametrization=Skew(),
)

for out_channel, in_channel in ((0, 1), (2, 2)):
  print(cnn.weight[out_channel, in_channel])

tensor([[ 0.0000, -0.0672, -0.0074],
        [ 0.0672,  0.0000, -0.1400],
        [ 0.0074,  0.1400,  0.0000]], grad_fn=<SelectBackward0>)
tensor([[ 0.0000, -0.1461, -0.0274],
        [ 0.1461,  0.0000,  0.0941],
        [ 0.0274, -0.0941,  0.0000]], grad_fn=<SelectBackward0>)


In [10]:
# Show how the module's repr changes once a parametrization is registered.
layer = nn.Linear(in_features=3, out_features=3)

print(f"Unparametrized:\n{layer}")
parametrize.register_parametrization(
  module=layer,
  tensor_name="weight",
  parametrization=Symmetric(),
)
print(f"\nParametrized:\n{layer}")

Unparametrized:
Linear(in_features=3, out_features=3, bias=True)

Parametrized:
ParametrizedLinear(
  in_features=3, out_features=3, bias=True
  (parametrizations): ModuleDict(
    (weight): ParametrizationList(
      (0): Symmetric()
    )
  )
)


In [11]:
# Print the container holding all registered parametrizations, then the
# entry for the `weight` tensor.
registered_parametrizations = layer.parametrizations
print(registered_parametrizations)
print(registered_parametrizations.weight)

ModuleDict(
  (weight): ParametrizationList(
    (0): Symmetric()
  )
)
ParametrizationList(
  (0): Symmetric()
)


In [12]:
# The entry for `weight` can be indexed; position 0 is the first
# parametrization registered on it.
first_parametrization = layer.parametrizations.weight[0]
print(first_parametrization)

Symmetric()


In [13]:
# List the parameters by name to see under which key the weight is stored
# after parametrization.
print({name: param for name, param in layer.named_parameters()})

{'bias': Parameter containing:
tensor([0.4899, 0.4747, 0.2256], requires_grad=True), 'parametrizations.weight.original': Parameter containing:
tensor([[ 0.0733,  0.4334, -0.2508],
        [ 0.5538,  0.0838, -0.3891],
        [-0.2744, -0.2643,  0.4379]], requires_grad=True)}


In [14]:
# The unconstrained tensor is reachable through the `original` attribute.
weight_parametrizations = layer.parametrizations.weight
print(weight_parametrizations.original)

Parameter containing:
tensor([[ 0.0733,  0.4334, -0.2508],
        [ 0.5538,  0.0838, -0.3891],
        [-0.2744, -0.2643,  0.4379]], requires_grad=True)


In [16]:
# Recompute the weight by hand from the stored original tensor and confirm it
# matches what `layer.weight` returns (the distance should be zero).
# The instance gets its own name instead of `symmetric`, which is already the
# helper function that LinearSymmetric.forward looks up at call time; rebinding
# it to a module would leave the notebook depending on cell execution order.
symmetric_module = Symmetric()
weight_orig = layer.parametrizations.weight.original
print(torch.dist(layer.weight, symmetric_module(weight_orig)))

tensor(0., grad_fn=<DistBackward0>)


In [17]:
class NoisyParametrization(nn.Module):
  """Identity parametrization that prints a message each time it is evaluated.

  Used to make visible how often the parametrization is computed.
  """

  def forward(self, X):
    message = "Computing the Parametrization"
    print(message)
    # The tensor is handed back untouched.
    return X

# Compare how often the parametrization runs without and with caching.
layer = nn.Linear(in_features=4, out_features=4)
parametrize.register_parametrization(
  module=layer,
  tensor_name="weight",
  parametrization=NoisyParametrization(),
)

# Outside the cache context, each access to `layer.weight` prints the message.
print("Here, layer.weight is recomputed every time we call it")
foo = layer.weight + layer.weight.transpose(0, 1)
bar = torch.sum(layer.weight)

# Inside the cache context, the same three accesses print the message once.
with parametrize.cached():
  print("Here, it is computed just the first time layer.weight is called.")
  foo = layer.weight + layer.weight.transpose(0, 1)
  bar = torch.sum(layer.weight)

Computing the Parametrization
Here, layer.weight is recomputed every time we call it
Computing the Parametrization
Computing the Parametrization
Computing the Parametrization
Here, it is computed just the first time layer.weight is called.
Computing the Parametrization


In [None]:
# Concatenating Parametrizations

