In [38]:
from importlib.metadata import version

import torch
import torch.nn as nn

#### The rank of a matrix is the maximum number of linearly independent columns (equivalently, rows) in the matrix — the column rank and row rank are always equal.

In [39]:
# Seed the RNG so that Restart & Run All reproduces the same matrix and outputs.
torch.manual_seed(123)

# A (15 x 4) @ (4 x 15) product is 15 x 15 but has rank at most 4: every column of w
# is a linear combination of the 4 columns of the left factor.
w = torch.rand((15, 4)) @ torch.rand((4, 15))
w.shape

torch.Size([15, 15])

In [40]:
# Numerical rank of w: count of singular values above torch's default tolerance.
# (hermitian=False is the default; spelled out because w is not symmetric.)
w_rank = torch.linalg.matrix_rank(w, hermitian=False)
w_rank

tensor(4)

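#### To determine the rank of a matrix, you can use row reduction (Gaussian elimination) or singular value decomposition (SVD). In floating point, SVD is the robust choice: the rank is the number of singular values above a small tolerance.

Eigenvalue decomposition is not a general method: it applies only to square matrices, and it gives the rank only when the matrix is diagonalizable. The SVD factors a matrix w of shape m × n (with k = min(m, n)) into: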

- U: The left singular vectors. In the reduced SVD (the default returned by `torch.svd`), U has shape m × k; the full SVD gives m × m. Its columns are orthonormal and span the column space of w, i.e. the output space of the map x ↦ w x.
- S: A 1-D tensor of length k = min(m, n) holding the singular values, which are non-negative and sorted in descending order. Each one is the factor by which w stretches the corresponding singular direction. The number of non-zero singular values equals the rank.
- V: The right singular vectors. In the reduced SVD, V has shape n × k (n × n for the full SVD). Its **columns** (not its rows) are orthonormal and span the row space of w, i.e. the input space of x ↦ w x. Note that `torch.svd` returns V itself, whereas `torch.linalg.svd` returns its conjugate transpose V^H.
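Together, these matrices satisfy

$$w = U \,\operatorname{diag}(S)\, V^{\top}$$

Keep only the first r = rank(w) singular values and the matching columns of U and V. This gives $w = \big(U_{:,:r}\,\operatorname{diag}(S_{:r})\big)\,V_{:,:r}^{\top}$, which is exact up to round-off because the discarded singular values are zero. Storing the two factors costs $r(m + n)$ numbers instead of $mn$.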


In [68]:
# A small integer-valued example. It gets its own name so it does not overwrite the
# 15 x 15 `w` that the SVD / low-rank cells depend on; reusing `w` here breaks them
# under Restart & Run All.
# `.to(torch.float)` converts the tensor directly. Re-wrapping an existing tensor in
# `torch.tensor(...)` makes a copy and emits a UserWarning.
w_small = torch.randint(0, 4, (2, 3)).to(torch.float)
w_small

  w = torch.tensor(w,dtype=torch.float)


tensor([[2., 0., 0.],
        [1., 2., 3.]])

In [41]:
# Singular value decomposition of w.
# `torch.svd` is deprecated in favour of `torch.linalg.svd`, which returns V^H (`Vh`)
# rather than V. Converting back keeps the convention used below: the right singular
# vectors are the *columns* of V. full_matrices=False matches torch.svd's default
# reduced output (some=True).
U, S, Vh = torch.linalg.svd(w, full_matrices=False)
V = Vh.mH
print(U.shape, S.shape, V.shape)

torch.Size([15, 15]) torch.Size([15]) torch.Size([15, 15])


In [42]:
# Truncated SVD: keep only the top-r singular triplets, where r is the rank of w.
# a = U_r diag(S_r) has shape (m, r) and b = V_r^T has shape (r, n), so a @ b has the
# same shape as w while storing only r * (m + n) numbers.
r = int(w_rank)  # plain Python int for slicing, instead of a 0-d tensor
a = U[:, :r] @ torch.diag(S[:r])
b = V[:, :r].T
w_lr = a @ b

# w has rank r, so the dropped singular values are round-off noise and the
# reconstruction should match w closely.
assert torch.allclose(w_lr, w, atol=1e-3), "rank-r reconstruction should reproduce w"

In [43]:
# Number of values stored by the dense matrix (rows * cols); numel is nelement's alias.
w.numel()

225

In [44]:
# Number of values stored by the two low-rank factors together.
sum(factor.numel() for factor in (a, b))

120

In [54]:
# z = w @ x + bias
# The bias is named `bias`, not `b`, so this cell does not overwrite the low-rank
# factor `b` from the truncated-SVD cell. Overwriting it would silently change the
# factor parameter count on a re-run.
# Sizes come from w's shape instead of a hard-coded 15.
x = torch.rand(w.shape[1])
bias = torch.rand(w.shape[0])

z_full = w @ x + bias
z_lr = w_lr @ x + bias

# The low-rank layer should reproduce the dense layer up to float round-off.
assert torch.allclose(z_full, z_lr, atol=1e-3), "low-rank layer should match the dense layer"
print(z_full)
print(z_lr)

tensor([ 7.8744,  5.7316,  9.0137,  5.0998,  9.1993, 13.8932,  9.4766, 12.2242,
        13.1615,  5.5654,  9.5690,  8.8458, 13.6005,  7.6197,  7.7441])
tensor([ 7.8744,  5.7316,  9.0137,  5.0998,  9.1993, 13.8932,  9.4766, 12.2242,
        13.1615,  5.5654,  9.5690,  8.8458, 13.6005,  7.6197,  7.7441])
