In [59]:
import jax
import jax.numpy as jnp
from jax.numpy.linalg import norm, svd

jax.config.update('jax_enable_x64', False)

In [60]:
def normalize_vector(v:jax.Array):
    return v / norm(v)

In [61]:
def householder_reflection(x:jax.Array):
    w_0 = x[0] + jnp.sign(x[0]) * norm(x) if x[0] != 0 else norm(x)
    w = x.at[0].set(w_0)
    w = normalize_vector(w)
    return jnp.identity(n=len(w)) - 2*jnp.linalg.outer(w, w) 

In [62]:
def apply_householder_reflection(H, x):
    return jnp.dot(H, x)

In [63]:
x = jnp.array([1.,2.])
y = jnp.array([2.,3.])

x + y

Array([3., 5.], dtype=float32)

In [64]:
H = householder_reflection(x) 

In [65]:
apply_householder_reflection(H, x)

Array([-2.2360678e+00,  1.7881393e-07], dtype=float32)

In [66]:
A = jnp.array([[1,2,3], [4,5,6], [7,8,9]], dtype=jnp.float32)
A

Array([[1., 2., 3.],
       [4., 5., 6.],
       [7., 8., 9.]], dtype=float32)

In [67]:
H1 = householder_reflection(A.T[0])
A2 = jnp.dot(H1, A)
A2

Array([[-8.1240387e+00, -9.6011372e+00, -1.1078235e+01],
       [-1.7881393e-07, -8.5965633e-02, -1.7193133e-01],
       [ 1.7881393e-07, -9.0043950e-01, -1.8008795e+00]], dtype=float32)

In [68]:
H1 @ A

Array([[-8.1240387e+00, -9.6011372e+00, -1.1078235e+01],
       [-1.7881393e-07, -8.5965633e-02, -1.7193133e-01],
       [ 1.7881393e-07, -9.0043950e-01, -1.8008795e+00]], dtype=float32)

In [69]:
def qr_decomposition(A:jax.Array):
    n, m = A.shape

    R = A.copy()
    Q = jnp.identity(n)
    for i in range(m-1):
        H_i = householder_reflection(R[i:, i:].T[0])
        H_i = jax.scipy.linalg.block_diag(jnp.eye(i), H_i) if i != 0 else H_i
        R = jnp.dot(H_i, R)
        Q = jnp.dot(Q, H_i.T)
    
    return Q, R

In [70]:
Q, R = qr_decomposition(jnp.array([[1.3,2.],[2.,1.]]))
jnp.dot(Q, R)

Array([[1.2999997 , 1.9999998 ],
       [1.9999998 , 0.99999976]], dtype=float32)

In [71]:
A = jnp.array([[1,2,3], [4,5,6], [7,8,9]], dtype=jnp.float32)

def bidiagonalisation_decomposition(A:jax.Array):
    n, m = A.shape
    Q_2 = jnp.identity(m)
    Q_1 = jnp.identity(n)
    B = A.copy()

    for i in range(min(n, m)):
        if i <= n-1:
            H_1 = householder_reflection(B[i:, i:].T[0])
            H_1 = jax.scipy.linalg.block_diag(jnp.eye(i), H_1)
            B = jnp.dot(H_1, B)
            Q_1 = jnp.dot(Q_1, H_1.T)

        if i < m-1:
            H_2 = householder_reflection(B[i:, i+1:][0])
            H_2 = jax.scipy.linalg.block_diag(jnp.eye(i+1), H_2)
            B = jnp.dot(B, H_2.T)
            Q_2 = jnp.dot(H_2, Q_2)

    return Q_1, B, Q_2

In [72]:
Q_1, B, Q_2 = bidiagonalisation_decomposition(jnp.array([[1,2,3,4,5,6], [6,5,4,3,2,1]], dtype=jnp.float32))
Q_1 @ B @ Q_2

Array([[1.0000021, 2.0000029, 2.9999998, 4.000003 , 5.0000033, 6.0000043],
       [6.000002 , 5.000002 , 4.0000024, 3.0000021, 2.000002 , 1.0000017]],      dtype=float32)

In [73]:
B = jnp.array([[1, 2, 0, 0],
               [3, 4, 5, 0],
               [0, 6, 7, 8],
               [0, 0, 9, 10]])

In [74]:
# Define block size (example: 2x2)
block_size = 2

# Extract blocks
block_11 = B[:block_size, :block_size]  # Top-left block
block_12 = B[:block_size, block_size:]  # Top-right block
block_21 = B[block_size:, :block_size]  # Bottom-left block
block_22 = B[block_size:, block_size:]  # Bottom-right block

print("Block 11:")
print(block_11)
print("\nBlock 12:")
print(block_12)
print("\nBlock 21:")
print(block_21)
print("\nBlock 22:")
print(block_22)

Block 11:
[[1 2]
 [3 4]]

Block 12:
[[0 0]
 [5 0]]

Block 21:
[[0 6]
 [0 0]]

Block 22:
[[ 7  8]
 [ 9 10]]


In [75]:
def split_matrix_into_blocks(B, block_size:int=2):

    assert block_size > 0, f'The block size should be greater than 0. Instead {block_size}'
    assert block_size <= min(B.shape), f'The block size should be less than or equal to the size of the matrix. Instead {block_size} > {min(B.shape)}'

    block_11 = B[:block_size, :block_size]  # Top-left block
    block_12 = B[:block_size, block_size:]  # Top-right block
    block_21 = B[block_size:, :block_size]  # Bottom-left block
    block_22 = B[block_size:, block_size:]  # Bottom-right block

    return [block_11, block_12, block_21, block_22]

In [76]:
split_matrix_into_blocks(B)

[Array([[1, 2],
        [3, 4]], dtype=int32),
 Array([[0, 0],
        [5, 0]], dtype=int32),
 Array([[0, 6],
        [0, 0]], dtype=int32),
 Array([[ 7,  8],
        [ 9, 10]], dtype=int32)]

In [77]:
def perform_svd_on_blocks(blocks: list):
    U_list, S_list, Vh_list = [], [], []
    singular_values = []

    # Perform SVD on each block
    for block in blocks:
        U_block, S_block, Vh_block = svd(block, full_matrices=False)
        U_list.append(U_block)
        singular_values.append(S_block)
        Vh_list.append(Vh_block)

    # Convert lists to arrays
    U = jnp.hstack(U_list)
    Vh = jnp.vstack(Vh_list)

    # Concatenate singular values and sort them if needed
    S = jnp.concatenate(singular_values)
    
    # Sorting singular values in descending order
    sorted_indices = jnp.argsort(S)[::-1]
    S_sorted = jnp.sort(S)[::-1]
    Vh_sorted = Vh[sorted_indices]

    # Construct diagonal matrix for singular values
    S_diag = jnp.diag(S_sorted)

    return U, S_diag, Vh_sorted

In [78]:
B = jnp.array([[1, 2, 0, 0],
               [3, 4, 5, 0],
               [0, 6, 7, 8],
               [0, 0, 9, 10]])

blocks = split_matrix_into_blocks(B)
U,S,V = perform_svd_on_blocks(blocks)
S

Array([[17.146032  ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  6.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  5.4649854 ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  5.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.36596614,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.11664554,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ]], dtype=float32)

In [79]:
B_scaled = (B - jnp.mean(B))/jnp.std(B)

In [80]:
U, S, Vt = svd(B)
Vt.T.shape

(4, 4)

In [81]:
explained_variance = S**2/(len(B) - 1)

In [82]:
explained_variance_ratio = explained_variance / jnp.sum(explained_variance)

In [83]:
explained_variance_ratio

Array([0.8529526 , 0.11453851, 0.0312005 , 0.00130828], dtype=float32)

In [84]:
B_pca = jnp.dot(B, Vt.T[:, :3])

In [85]:
U2 = jnp.linalg.eig(jnp.dot(B, B.T))[1].astype(jnp.float32)
S2 = jnp.sqrt(jnp.linalg.eig(jnp.dot(B.T, B))[0].astype(jnp.float32))
sorted_indices = jnp.argsort(S2, descending=True)
U, U2

(Array([[-0.03292128,  0.294175  , -0.17891642,  0.9382783 ],
        [-0.25488862,  0.7409277 ,  0.6085924 , -0.12519355],
        [-0.65354645,  0.27927098, -0.66242605, -0.23680505],
        [-0.7119168 , -0.5352525 ,  0.39849195,  0.21882348]],      dtype=float32),
 Array([[-0.03292139, -0.29417548, -0.93827826, -0.17891629],
        [-0.25488853, -0.74092764,  0.1251936 ,  0.6085925 ],
        [-0.6535463 , -0.27927062,  0.23680519, -0.6624261 ],
        [-0.7119166 ,  0.53525215, -0.21882358,  0.39849186]],      dtype=float32))

In [86]:
S2, Vt2 = jnp.linalg.eig(jnp.dot(B.T, B))
S2 = jnp.sqrt(S2).astype(jnp.float32).sort()
Vt2 = Vt2.astype(jnp.float32).T
(S, Vt), (S2, Vt2)

((Array([18.121445 ,  6.6405816,  3.465861 ,  0.7097109], dtype=float32),
  Array([[-0.04401344, -0.27628452, -0.6763542 , -0.681377  ],
         [ 0.37902662,  0.7872333 ,  0.12683554, -0.46959096],
         [ 0.47516632, -0.547633  ,  0.5748661 , -0.37926766],
         [ 0.79285467, -0.06347415, -0.44270378,  0.41396353]],      dtype=float32)),
 (Array([ 0.7097096,  3.4658616,  6.6405826, 18.121445 ], dtype=float32),
  Array([[-0.04401337, -0.27628452, -0.6763542 , -0.68137705],
         [-0.3790269 , -0.78723294, -0.12683566,  0.4695908 ],
         [-0.7928546 ,  0.06347422,  0.4427038 , -0.41396356],
         [-0.47516638,  0.5476332 , -0.57486606,  0.3792676 ]],      dtype=float32)))

In [87]:
def my_svd(B:jax.Array):
    U = jnp.linalg.eig(jnp.dot(B, B.T))[1].astype(jnp.float32)
    S_u = jnp.sqrt(jnp.linalg.eig(jnp.dot(B, B.T))[0].astype(jnp.float32))
    Vt = jnp.linalg.eig(jnp.dot(B.T, B))[1].T.astype(jnp.float32)
    S_vt = jnp.sqrt(jnp.linalg.eig(jnp.dot(B.T, B))[0].astype(jnp.float32))

    sorted_indices_u = jnp.argsort(S_u, descending=True)
    sorted_indices_vt = jnp.argsort(S_vt, descending=True)

    return U[:, sorted_indices_u], jnp.diag(S_u[sorted_indices_u]), Vt[:, sorted_indices_vt]

In [88]:
U2, S2, Vt2 = my_svd(B)
U, S, Vt = svd(B)

U2 @ S2 @ Vt2

Array([[ 1.5747491e+00,  1.2986512e+00, -5.0670344e-01,  7.5956029e-01],
       [ 3.5359454e-01,  5.3320231e+00, -2.7045100e-03,  4.6308179e+00],
       [ 2.9646111e+00,  4.6783400e+00,  8.2129717e+00,  7.1324124e+00],
       [-1.8006245e+00,  7.6882309e-01,  9.8288898e+00,  8.9754982e+00]],      dtype=float32)

In [89]:
B = jnp.array([[1, 2, 0, 0, 0],
               [0, 4, 5, 0, 0],
               [0, 0, 7, 8, 0],
               [0, 0, 0, 10, 11],
               [1, 2, 0, 0, 0],
               [1, 2, 0, 0, 0],
               [1, 2, 0, 0, 0]])

B = (B - B.mean())/B.std()
n, m = B.shape

num_components = 3

U, S, Vt = jax.scipy.linalg.svd(B, full_matrices=True)

print(U.shape, S.shape, V.shape)

if n < m:
    S = jnp.concatenate((jnp.diag(S), jnp.zeros((n, m-n))), axis=1)
elif n > m:
    S = jnp.concatenate((jnp.diag(S), jnp.zeros((n-m, m))), axis=0)


B @ Vt[:num_components].T

(7, 7) (5,) (8, 2)


Array([[ 0.79758686,  0.3347571 , -0.29629523],
       [ 0.8001173 , -1.0575135 ,  0.9375937 ],
       [-1.51121   , -2.5334764 , -0.49882543],
       [-4.263436  ,  0.9500487 ,  0.13105209],
       [ 0.79758686,  0.3347571 , -0.29629523],
       [ 0.79758686,  0.3347571 , -0.29629523],
       [ 0.79758686,  0.3347571 , -0.29629523]], dtype=float32)

In [90]:
U2, U

(Array([[-0.03292139, -0.29417548, -0.17891629, -0.93827826],
        [-0.25488853, -0.74092764,  0.6085925 ,  0.1251936 ],
        [-0.6535463 , -0.27927062, -0.6624261 ,  0.23680519],
        [-0.7119166 ,  0.53525215,  0.39849186, -0.21882358]],      dtype=float32),
 Array([[ 1.64022967e-01,  1.12288654e-01, -2.42227316e-01,
          3.89632732e-01,  8.66025329e-01,  0.00000000e+00,
         -8.96705643e-09],
        [ 1.64543360e-01, -3.54724795e-01,  7.66501904e-01,
          5.09480536e-01,  1.56532224e-07,  1.80300059e-08,
          9.59345225e-09],
        [-3.10778737e-01, -8.49811375e-01, -4.07799214e-01,
          1.22215040e-01, -2.22817107e-08, -7.23325755e-09,
         -2.47567788e-09],
        [-8.76771212e-01,  3.18677604e-01,  1.07137300e-01,
          3.43857735e-01, -1.11688678e-07, -1.03446371e-08,
         -6.83913148e-09],
        [ 1.64022967e-01,  1.12288617e-01, -2.42227197e-01,
          3.89632732e-01, -2.88675159e-01, -5.77350259e-01,
         -5.77350259e-

In [91]:
S2, S

(Array([[18.121445  ,  0.        ,  0.        ,  0.        ],
        [ 0.        ,  6.6405807 ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  3.4658618 ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.70970935]],      dtype=float32),
 Array([[4.8626552e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00],
        [0.0000000e+00, 2.9812212e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00],
        [0.0000000e+00, 0.0000000e+00, 1.2232124e+00, 0.0000000e+00,
         0.0000000e+00],
        [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 9.8522401e-01,
         0.0000000e+00],
        [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         1.1885693e-07],
        [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00],
        [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00]], dtype=float32))

In [92]:
Vt2, Vt

(Array([[-0.04401337, -0.27628452, -0.68137705, -0.6763542 ],
        [-0.3790269 , -0.78723294,  0.4695908 , -0.12683566],
        [-0.7928546 ,  0.06347422, -0.41396356,  0.4427038 ],
        [-0.47516638,  0.5476332 ,  0.3792676 , -0.57486606]],      dtype=float32),
 Array([[ 0.08774175,  0.17972684, -0.05285533, -0.74583775, -0.6331922 ],
        [ 0.13246639,  0.02176011, -0.80038613, -0.33112183,  0.4813729 ],
        [-0.04163036,  0.5418476 ,  0.5000364 , -0.38165998,  0.5558481 ],
        [-0.88722956,  0.35512844, -0.2501087 ,  0.09995639, -0.1190044 ],
        [-0.4310973 , -0.73993903,  0.20975612, -0.4224085 ,  0.21028274]],      dtype=float32))

In [93]:
U2 @ S2 @ Vt2

Array([[ 1.5747491e+00,  1.2986512e+00, -5.0670344e-01,  7.5956029e-01],
       [ 3.5359454e-01,  5.3320231e+00, -2.7045100e-03,  4.6308179e+00],
       [ 2.9646111e+00,  4.6783400e+00,  8.2129717e+00,  7.1324124e+00],
       [-1.8006245e+00,  7.6882309e-01,  9.8288898e+00,  8.9754982e+00]],      dtype=float32)

In [97]:
class myPCA():
    def __init__(self, num_components:int):
        self.num_components = num_components
        self.mean = None
        self.principal_components = None
        self.explained_variance = None

    def fit(self, X:jax.Array):
        n, m = X.shape
        
        self.mean = X.mean(axis=0)
        X_centred = X - self.mean
        S, self.principal_components = svd(X_centred, full_matrices=True)[1:]

        if n < m:
            S = jnp.concatenate((jnp.diag(S), jnp.zeros((n, m-n))), axis=1)
        elif n > m:
            S = jnp.concatenate((jnp.diag(S), jnp.zeros((n-m, m))), axis=0)

        self.explained_variance = S**2 / jnp.sum(S**2)

    def transform(self, X:jax.Array):
        if self.principal_components is None:
            raise RuntimeError('Must fit before transforming.')
        
        X_centred = X - X.mean(axis=0)
        return jnp.dot(X_centred, self.principal_components[:self.num_components].T)
    
    def fit_transform(self, X:jax.Array):
        if self.mean is None:
            self.mean = X.mean(axis=0)

        X_centred = X - self.mean

        self.principal_components = svd(X_centred, full_matrices=True)[2]

        return jnp.dot(X_centred, self.principal_components[:self.num_components].T)

    def inverse_transform(self, X_transformed:jax.Array):
        if self.principal_components is None:
            raise RuntimeError('Must fit before transforming.')
        
        return jnp.dot(X_transformed, self.principal_components[:self.num_components]) + self.mean

In [104]:
mymodel= myPCA(num_components=3)
B_transformed = mymodel.fit_transform(B)
jnp.isclose(B[:,0], B_transformed[:, 0], rtol=10)

Array([ True,  True,  True,  True,  True,  True,  True], dtype=bool)