# Singular Value Decomposition

In [2]:
import torch
import numpy as np

torch.manual_seed(0)


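Generate a rank-deficient matrix W, i.e., one that is not full rank because some of its rows/columns are linearly dependent. Multiplying a d×2 matrix by a 2×k matrix gives a d×k matrix whose rank is at most 2.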

In [3]:
d, k = 10, 10

W_rank = 2
W = torch.randn(d, W_rank) @ torch.randn(W_rank, k)
W

tensor([[-1.0797,  0.5545,  0.8058, -0.7140, -0.1518,  1.0773,  2.3690,  0.8486,
         -1.1825, -3.2632],
        [-0.3303,  0.2283,  0.4145, -0.1924, -0.0215,  0.3276,  0.7926,  0.2233,
         -0.3422, -0.9614],
        [-0.5256,  0.9864,  2.4447, -0.0290,  0.2305,  0.5000,  1.9831, -0.0311,
         -0.3369, -1.1376],
        [ 0.7900, -1.1336, -2.6746,  0.1988, -0.1982, -0.7634, -2.5763, -0.1696,
          0.6227,  1.9294],
        [ 0.1258,  0.1458,  0.5090,  0.1768,  0.1071, -0.1327, -0.0323, -0.2294,
          0.2079,  0.5128],
        [ 0.7697,  0.0050,  0.5725,  0.6870,  0.2783, -0.7818, -1.2253, -0.8533,
          0.9765,  2.5786],
        [ 1.4157, -0.7814, -1.2121,  0.9120,  0.1760, -1.4108, -3.1692, -1.0791,
          1.5325,  4.2447],
        [-0.0119,  0.6050,  1.7245,  0.2584,  0.2528, -0.0086,  0.7198, -0.3620,
          0.1865,  0.3410],
        [ 1.0485, -0.6394, -1.0715,  0.6485,  0.1046, -1.0427, -2.4174, -0.7615,
          1.1147,  3.1054],
        [ 0.9088,  

Check the numerical rank of W (by construction it should be 2):

In [None]:
W_rank = np.linalg.matrix_rank(W)
print(f'Rank of W: {W_rank}')

Rank of W: 2


### What is SVD?
At its core, SVD is a way to decompose any matrix into three simpler matrices that, when multiplied together, reproduce the original matrix.
For any real matrix W of dimensions m×n, SVD expresses it as:

W = U × Σ × V^T

where:

- U is an m×m orthogonal matrix (its columns, the left singular vectors, are orthonormal eigenvectors of WW^T)
- Σ (Sigma) is an m×n rectangular diagonal matrix whose diagonal entries are the singular values: non-negative real numbers equal to the square roots of the eigenvalues of W^TW
- V^T is the transpose of an n×n orthogonal matrix V (its columns, the right singular vectors, are orthonormal eigenvectors of W^TW)
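The definitions above can be checked numerically. The sketch below uses NumPy's `np.linalg.svd` rather than PyTorch, on a small illustrative 4×3 matrix (the shape and seed are arbitrary choices), and verifies the orthogonality of U and V, the reconstruction W = U Σ V^T, and the eigenvector relationships.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 4, 3
W = rng.standard_normal((m, n))   # a small non-square example (illustrative)

# full_matrices=True requests the full SVD: U is m x m and Vt (= V^T) is n x n
U, s, Vt = np.linalg.svd(W, full_matrices=True)

# NumPy returns the singular values as a 1-D array; place them on the
# diagonal of an m x n rectangular matrix Sigma
p = min(m, n)
Sigma = np.zeros((m, n))
Sigma[:p, :p] = np.diag(s)

# U and V are orthogonal, and U @ Sigma @ Vt reproduces W
assert np.allclose(U.T @ U, np.eye(m))
assert np.allclose(Vt @ Vt.T, np.eye(n))
assert np.allclose(U @ Sigma @ Vt, W)

# Columns of V are eigenvectors of W^T W with eigenvalues s**2;
# columns of U are eigenvectors of W W^T (the extra m - p columns have eigenvalue 0)
V = Vt.T
assert np.allclose(W.T @ W @ V, V * s**2)
eig_u = np.concatenate([s**2, np.zeros(m - p)])
assert np.allclose(W @ W.T @ U, U * eig_u)

print(U.shape, Sigma.shape, Vt.shape)  # (4, 4) (4, 3) (3, 3)
```

Note that `torch.linalg.svd` and `np.linalg.svd` both return V^T rather than V, and return Σ as a vector of singular values rather than a matrix.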

### The Geometric Interpretation
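SVD gives a geometric reading of any linear map:

- V^T is an orthogonal transformation of the input space (a rotation, possibly combined with a reflection)
- Σ scales each coordinate axis by the corresponding singular value (and, when m ≠ n, also maps between spaces of different dimension)
- U is an orthogonal transformation of the output space (again a rotation, possibly with a reflection)

Think of it this way: W takes a vector from the input space, rotates it (V^T), stretches or compresses it along the coordinate axes (Σ), and then rotates it again into the output space (U). In particular, W maps the unit sphere onto a (possibly flattened) ellipsoid whose semi-axis lengths are the singular values.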
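The three-step picture can be traced on a single vector. This is a sketch in NumPy using an arbitrary random 3×3 matrix (square, so Σ needs no padding; the seed is illustrative): only the Σ step changes a vector's length, and composing the steps gives the same result as applying W directly.

```python
import numpy as np

rng = np.random.default_rng(1)
W = rng.standard_normal((3, 3))   # square example, so Sigma is a plain diagonal
x = rng.standard_normal(3)

U, s, Vt = np.linalg.svd(W)

step1 = Vt @ x        # orthogonal change of basis in the input space
step2 = s * step1     # scale coordinate i by singular value s[i]
step3 = U @ step2     # orthogonal map into the output space

# Orthogonal steps preserve length; only the Sigma step changes it
assert np.isclose(np.linalg.norm(step1), np.linalg.norm(x))
assert np.isclose(np.linalg.norm(step3), np.linalg.norm(step2))

# Composing the three steps is the same as applying W directly
assert np.allclose(step3, W @ x)

# The first right singular vector is stretched the most: ||W v_1|| = s[0]
v1 = Vt[0]
assert np.isclose(np.linalg.norm(W @ v1), s[0])

# Orthogonal matrices have determinant +1 (pure rotation) or -1 (with a reflection)
print(np.linalg.det(U), np.linalg.det(Vt))
```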

### Why Singular Values Matter
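The singular values (the diagonal entries of Σ) are conventionally sorted in descending order, σ_1 ≥ σ_2 ≥ ... ≥ 0, and measure the "importance" of each direction: σ_i is how much W stretches space along the i-th right singular vector.
The number of non-zero singular values equals the rank of the matrix. In floating-point arithmetic, the "zero" singular values of a rank-deficient matrix come out as tiny numbers rather than exact zeros, so the numerical rank is found by counting singular values above a small tolerance (this is what `np.linalg.matrix_rank` does). This link between singular values and rank is what makes low-rank approximations possible.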
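To see the tolerance-based rank count in action, the sketch below rebuilds a rank-2 matrix the same way as the notebook but with NumPy (seed and sizes are illustrative) and counts singular values above a threshold. The tolerance formula used here, `s.max() * max(M, N) * eps`, is NumPy's documented default for `matrix_rank`.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r = 10, 10, 2
# Same construction as in the notebook: a 10 x 10 matrix of rank at most 2
W = rng.standard_normal((d, r)) @ rng.standard_normal((r, k))

s = np.linalg.svd(W, compute_uv=False)   # singular values only, in descending order
print(s[:4])   # two sizeable values, then values many orders of magnitude smaller

# Count singular values above a tolerance instead of testing for exact zeros;
# this tolerance mirrors NumPy's default for matrix_rank
tol = s.max() * max(W.shape) * np.finfo(s.dtype).eps
numerical_rank = int(np.sum(s > tol))

print(numerical_rank, np.linalg.matrix_rank(W))
```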

### Low-Rank Approximations
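One of the most powerful applications of SVD is creating low-rank approximations of matrices. If we keep only the r largest singular values (and the corresponding columns of U and V), we get the best possible rank-r approximation of the original matrix: by the Eckart–Young theorem, no other rank-r matrix has a smaller error in the Frobenius norm (or the spectral norm). The Frobenius error of this truncation equals the square root of the sum of the squares of the discarded singular values.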
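The optimality claim can be checked numerically. This NumPy sketch truncates the SVD of an arbitrary full-rank 8×6 matrix to rank 2, confirms the error formulas, and compares against a baseline chosen for illustration: the best approximation whose columns are restricted to a random 2-dimensional subspace.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((8, 6))   # a generic (full-rank) matrix, illustrative
r = 2

U, s, Vt = np.linalg.svd(W, full_matrices=False)

# Truncated SVD: keep the r largest singular values and their singular vectors
W_r = U[:, :r] @ np.diag(s[:r]) @ Vt[:r, :]

# The truncation error is determined by the discarded singular values
err_fro = np.linalg.norm(W - W_r, 'fro')
assert np.isclose(err_fro, np.sqrt(np.sum(s[r:] ** 2)))   # Frobenius norm
assert np.isclose(np.linalg.norm(W - W_r, 2), s[r])        # spectral norm

# Baseline: orthogonal projection of W onto a random r-dimensional column space
Q, _ = np.linalg.qr(rng.standard_normal((8, r)))   # 8 x r orthonormal basis
W_q = Q @ (Q.T @ W)                                 # also a rank-r matrix
err_q = np.linalg.norm(W - W_q, 'fro')

# Eckart-Young: no rank-r matrix beats the truncated SVD
assert err_fro <= err_q
print(f'truncated SVD error: {err_fro:.3f}, random-subspace error: {err_q:.3f}')
```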

In [None]:
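# Perform SVD on W (W = U @ diag(S) @ Vh, where Vh = V^T)
# torch.linalg.svd supersedes the deprecated torch.svd and returns V^T directly
U, S, Vh = torch.linalg.svd(W)  # U: (10, 10), S: (10,), Vh: (10, 10)

# For a rank-r factorization, keep only the r largest singular values
# (and the corresponding columns of U and rows of V^T)
U_r = U[:, :W_rank]
S_r = torch.diag(S[:W_rank])
Vh_r = Vh[:W_rank, :]

# Factor W ≈ B @ A, with B = U_r S_r (d x r) and A = V_r^T (r x k)
B = U_r @ S_r
A = Vh_r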

print(f'Shape of B: {B.shape}')
print(f'Shape of A: {A.shape}')

Shape of B: torch.Size([10, 2])
Shape of A: torch.Size([2, 10])


Feed the same input x through the original W and through the factorization B @ A (adding the same bias to both), then compare the outputs:

In [6]:
bias = torch.randn(d)
x = torch.randn(k)  # input has k features because W is d x k

# Compute y = Wx + b
y = W @ x + bias

# Compute y' = BAx + b
y_prime = (B @ A) @ x + bias

print(f'y using W:\n{y}')
print(f"y' using BA:\n{y_prime}")

y using W:
tensor([ 7.2684e+00,  2.3162e+00,  7.7151e+00, -1.0446e+01, -8.1639e-03,
        -3.7270e+00, -1.1146e+01,  2.0207e+00, -9.6258e+00, -4.1163e+00])
y' using BA:
tensor([ 7.2684e+00,  2.3162e+00,  7.7151e+00, -1.0446e+01, -8.1640e-03,
        -3.7270e+00, -1.1146e+01,  2.0207e+00, -9.6258e+00, -4.1163e+00])


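The two outputs agree to within floating-point round-off (e.g., -8.1639e-03 vs. -8.1640e-03 in the fifth entry). Because W has rank exactly 2, keeping the top 2 singular values discards nothing: B @ A reconstructs W up to round-off, so A and B capture all of the information in W, not just most of it.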
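The cell above forms B @ A before multiplying by x, which rebuilds a full d×k matrix. A factorization pays off when the factors are applied one at a time, as B @ (A @ x). The NumPy sketch below (an illustrative re-creation, not the notebook's PyTorch code; seeds and variable names are assumptions) checks that both orderings give the same output and compares their multiply counts.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, r = 10, 10, 2
W = rng.standard_normal((d, r)) @ rng.standard_normal((r, k))   # rank 2
bias = rng.standard_normal(d)
x = rng.standard_normal(k)

U, s, Vt = np.linalg.svd(W)
B = U[:, :r] * s[:r]   # d x r, same as U_r @ diag(s_r)
A = Vt[:r, :]          # r x k

# Exact rank 2, so the rank-2 factors reproduce W up to round-off
assert np.allclose(W, B @ A)

y = W @ x + bias
# Apply the factors one at a time: A @ x has only r entries, then B maps back to d
y_factored = B @ (A @ x) + bias
assert np.allclose(y, y_factored)

# Multiply counts: W @ x needs d*k, B @ (A @ x) needs r*k + d*r
print(d * k, r * k + d * r)   # 100 40
```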

In [7]:
print(f'Total parameters of W: {W.nelement()}')
print(f'Total parameters of A and B: {A.nelement() + B.nelement()}')

Total parameters of W: 100
Total parameters of A and B: 40
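The saving generalizes: a dense d×k matrix stores d·k numbers, while the factors B (d×r) and A (r×k) store r·(d + k). The helper functions below are hypothetical names introduced only for this arithmetic, and the 4096×4096, rank-8 layer is an illustrative size, not one used in this notebook.

```python
def dense_params(d, k):
    """Number of entries in a dense d x k weight matrix W."""
    return d * k


def factored_params(d, k, r):
    """Number of entries in B (d x r) plus A (r x k)."""
    return d * r + r * k


# The notebook's example: d = k = 10, r = 2
print(dense_params(10, 10), factored_params(10, 10, 2))          # 100 40

# Hypothetical larger layer: d = k = 4096, r = 8
print(dense_params(4096, 4096), factored_params(4096, 4096, 8))  # 16777216 65536

# Break-even rank: factoring saves parameters only when r < d*k / (d + k)
print(10 * 10 / (10 + 10))                                        # 5.0
```

For the 10×10 example, the factorization stores fewer numbers only when r < 5; the relative saving grows with the matrix size, since r·(d + k) scales linearly in the dimensions while d·k scales quadratically.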
