In [1164]:
import jax
import jax.numpy as jnp
from scipy.linalg import norm

jax.config.update('jax_enable_x64', False)

In [1165]:
def normalize_vector(v:jax.Array):
    return v / norm(v)

In [1166]:
def householder_reflection(x:jax.Array):
    w_0 = x[0] + jnp.sign(x[0]) * norm(x) if x[0] != 0 else norm(x)
    w = x.at[0].set(w_0)
    w = normalize_vector(w)
    return jnp.identity(n=len(w)) - 2*jnp.linalg.outer(w, w) 

In [1167]:
def apply_householder_reflection(H, x):
    return jnp.dot(H, x)

In [1168]:
x = jnp.array([1.,2.])
y = jnp.array([2.,3.])

x + y

Array([3., 5.], dtype=float32)

In [1169]:
H = householder_reflection(x) 

In [1170]:
apply_householder_reflection(H, x)

Array([-2.2360678e+00,  1.7881393e-07], dtype=float32)

In [1171]:
A = jnp.array([[1,2,3], [4,5,6], [7,8,9]], dtype=jnp.float32)
A

Array([[1., 2., 3.],
       [4., 5., 6.],
       [7., 8., 9.]], dtype=float32)

In [1172]:
H1 = householder_reflection(A.T[0])
A2 = jnp.dot(H1, A)
A2

Array([[-8.1240377e+00, -9.6011353e+00, -1.1078234e+01],
       [ 2.6822090e-07, -8.5965395e-02, -1.7193082e-01],
       [ 4.1723251e-07, -9.0043926e-01, -1.8008790e+00]], dtype=float32)

In [1173]:
H1 @ A

Array([[-8.1240377e+00, -9.6011353e+00, -1.1078234e+01],
       [ 2.6822090e-07, -8.5965395e-02, -1.7193082e-01],
       [ 4.1723251e-07, -9.0043926e-01, -1.8008790e+00]], dtype=float32)

In [1174]:
def qr_decomposition(A:jax.Array):
    n, m = A.shape

    R = A.copy()
    Q = jnp.identity(n)
    for i in range(m-1):
        H_i = householder_reflection(R[i:, i:].T[0])
        H_i = jax.scipy.linalg.block_diag(jnp.eye(i), H_i) if i != 0 else H_i
        R = jnp.dot(H_i, R)
        Q = jnp.dot(Q, H_i.T)
    
    return Q, R

In [1175]:
Q, R = qr_decomposition(jnp.array([[1.3,2.],[2.,1.]]))
jnp.dot(Q, R)

Array([[1.2999997 , 1.9999998 ],
       [1.9999998 , 0.99999976]], dtype=float32)

In [1176]:
A = jnp.array([[1,2,3], [4,5,6], [7,8,9]], dtype=jnp.float32)

def bidiagonalisation_decomposition(A:jax.Array):
    n, m = A.shape
    Q_2 = jnp.identity(m)
    Q_1 = jnp.identity(n)
    B = A.copy()

    for i in range(min(n, m)):
        if i <= n-1:
            H_1 = householder_reflection(B[i:, i:].T[0])
            H_1 = jax.scipy.linalg.block_diag(jnp.eye(i), H_1)
            B = jnp.dot(H_1, B)
            Q_1 = jnp.dot(Q_1, H_1.T)

        if i < m-1:
            H_2 = householder_reflection(B[i:, i+1:][0])
            H_2 = jax.scipy.linalg.block_diag(jnp.eye(i+1), H_2)
            B = jnp.dot(B, H_2.T)
            Q_2 = jnp.dot(H_2, Q_2)

    return Q_1, B, Q_2

In [1177]:
Q_1, B, Q_2 = bidiagonalisation_decomposition(jnp.array([[1,2,3,4,5,6], [6,5,4,3,2,1]], dtype=jnp.float32))
Q_1 @ B @ Q_2

Array([[1.0000011, 1.999998 , 2.9999995, 4.000001 , 5.000002 , 6.000002 ],
       [6.0000005, 4.999997 , 3.9999998, 2.9999998, 2.       , 1.0000004]],      dtype=float32)

In [1178]:
B = jnp.array([[1, 2, 0, 0],
               [3, 4, 5, 0],
               [0, 6, 7, 8],
               [0, 0, 9, 10]])

In [1179]:
# Define block size (example: 2x2)
block_size = 2

# Extract blocks
block_11 = B[:block_size, :block_size]  # Top-left block
block_12 = B[:block_size, block_size:]  # Top-right block
block_21 = B[block_size:, :block_size]  # Bottom-left block
block_22 = B[block_size:, block_size:]  # Bottom-right block

print("Block 11:")
print(block_11)
print("\nBlock 12:")
print(block_12)
print("\nBlock 21:")
print(block_21)
print("\nBlock 22:")
print(block_22)

Block 11:
[[1 2]
 [3 4]]

Block 12:
[[0 0]
 [5 0]]

Block 21:
[[0 6]
 [0 0]]

Block 22:
[[ 7  8]
 [ 9 10]]


In [1180]:
def split_matrix_into_blocks(B, block_size:int=2):

    assert block_size > 0, f'The block size should be greater than 0. Instead {block_size}'
    assert block_size <= min(B.shape), f'The block size should be less than or equal to the size of the matrix. Instead {block_size} > {min(B.shape)}'

    block_11 = B[:block_size, :block_size]  # Top-left block
    block_12 = B[:block_size, block_size:]  # Top-right block
    block_21 = B[block_size:, :block_size]  # Bottom-left block
    block_22 = B[block_size:, block_size:]  # Bottom-right block

    return [block_11, block_12, block_21, block_22]

In [1181]:
split_matrix_into_blocks(B)

[Array([[1, 2],
        [3, 4]], dtype=int32),
 Array([[0, 0],
        [5, 0]], dtype=int32),
 Array([[0, 6],
        [0, 0]], dtype=int32),
 Array([[ 7,  8],
        [ 9, 10]], dtype=int32)]

In [1182]:
def perform_svd_on_blocks(blocks: list):
    U_list, S_list, Vh_list = [], [], []
    singular_values = []

    # Perform SVD on each block
    for block in blocks:
        U_block, S_block, Vh_block = svd(block, full_matrices=False)
        U_list.append(U_block)
        singular_values.append(S_block)
        Vh_list.append(Vh_block)

    # Convert lists to arrays
    U = jnp.hstack(U_list)
    Vh = jnp.vstack(Vh_list)

    # Concatenate singular values and sort them if needed
    S = jnp.concatenate(singular_values)
    
    # Sorting singular values in descending order
    sorted_indices = jnp.argsort(S)[::-1]
    S_sorted = jnp.sort(S)[::-1]
    Vh_sorted = Vh[sorted_indices]

    # Construct diagonal matrix for singular values
    S_diag = jnp.diag(S_sorted)

    return U, S_diag, Vh_sorted

In [1183]:
B = jnp.array([[1, 2, 0, 0],
               [3, 4, 5, 0],
               [0, 6, 7, 8],
               [0, 0, 9, 10]])

blocks = split_matrix_into_blocks(B)
U,S,V = perform_svd_on_blocks(blocks)
S

Array([[17.146032  ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  6.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  5.464986  ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  5.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.3659662 ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.11664507,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ]], dtype=float32)

In [1184]:
B_scaled = (B - jnp.mean(B))/jnp.std(B)

In [1185]:
U, S, Vt = svd(B)
Vt.T.shape

(4, 4)

In [1186]:
explained_variance = S**2/(len(B) - 1)

In [1187]:
explained_variance_ratio = explained_variance / jnp.sum(explained_variance)

In [1188]:
explained_variance_ratio

Array([0.85295266, 0.11453852, 0.03120052, 0.00130828], dtype=float32)

In [1189]:
B_pca = jnp.dot(B, Vt.T[:, :3])

In [1190]:
U2 = jnp.linalg.eig(jnp.dot(B, B.T))[1].astype(jnp.float32)
S2 = jnp.sqrt(jnp.linalg.eig(jnp.dot(B.T, B))[0].astype(jnp.float32))
sorted_indices = jnp.argsort(S2, descending=True)
U, U2

(array([[-0.03292135,  0.29417493, -0.17891654,  0.9382783 ],
        [-0.25488858,  0.74092764,  0.60859233, -0.12519354],
        [-0.65354626,  0.27927089, -0.66242614, -0.23680513],
        [-0.71191663, -0.53525239,  0.39849198,  0.21882354]]),
 Array([[ 0.03292142,  0.29417497, -0.9382782 , -0.17891674],
        [ 0.25488868,  0.74092764,  0.1251934 ,  0.60859215],
        [ 0.65354633,  0.2792707 ,  0.23680526, -0.6624261 ],
        [ 0.71191657, -0.5352525 , -0.21882364,  0.39849204]],      dtype=float32))

In [1191]:
S2, Vt2 = jnp.linalg.eig(jnp.dot(B.T, B))
S2 = jnp.sqrt(S2).astype(jnp.float32).sort()
Vt2 = Vt2.astype(jnp.float32).T
(S, Vt), (S2, Vt2)

((array([18.12144537,  6.64058211,  3.4658617 ,  0.70971094]),
  array([[-0.04401344, -0.27628451, -0.67635424, -0.68137701],
         [ 0.37902669,  0.78723305,  0.12683571, -0.46959088],
         [ 0.47516623, -0.54763311,  0.57486611, -0.37926769],
         [ 0.79285474, -0.06347418, -0.44270374,  0.41396351]])),
 (Array([ 0.7097136,  3.4658616,  6.6405845, 18.121445 ], dtype=float32),
  Array([[ 0.04401338,  0.27628446,  0.6763542 ,  0.681377  ],
         [ 0.37902707,  0.7872326 ,  0.12683602, -0.46959126],
         [-0.79285526,  0.06347464,  0.4427031 , -0.41396323],
         [-0.47516647,  0.54763305, -0.57486606,  0.37926766]],      dtype=float32)))

In [1192]:
def my_svd(B:jax.Array):
    U = jnp.linalg.eig(jnp.dot(B, B.T))[1].astype(jnp.float32)
    S_u = jnp.sqrt(jnp.linalg.eig(jnp.dot(B, B.T))[0].astype(jnp.float32))
    Vt = jnp.linalg.eig(jnp.dot(B.T, B))[1].T.astype(jnp.float32)
    S_vt = jnp.sqrt(jnp.linalg.eig(jnp.dot(B.T, B))[0].astype(jnp.float32))

    sorted_indices_u = jnp.argsort(S_u, descending=True)
    sorted_indices_vt = jnp.argsort(S_vt, descending=True)

    return U[:, sorted_indices_u], jnp.diag(S_u[sorted_indices_u]), Vt[:, sorted_indices_vt]

In [1193]:
U2, S2, Vt2 = my_svd(B)
U, S, Vt = svd(B)

U2 @ S2 @ Vt2

Array([[ 1.5747514e+00,  1.2986461e+00, -5.0670332e-01,  7.5956261e-01],
       [ 3.5359529e-01,  5.3320231e+00, -2.7035451e-03,  4.6308184e+00],
       [ 2.9646125e+00,  4.6783385e+00,  8.2129698e+00,  7.1324143e+00],
       [-1.8006270e+00,  7.6882154e-01,  9.8288918e+00,  8.9754953e+00]],      dtype=float32)

In [1194]:
B = jnp.array([[1, 2, 0, 0, 0],
               [0, 4, 5, 0, 0],
               [0, 0, 7, 8, 0],
               [0, 0, 0, 10, 11],
               [1, 2, 0, 0, 0],
               [1, 2, 0, 0, 0],
               [1, 2, 0, 0, 0]])

B = (B - B.mean())/B.std()
n, m = B.shape

num_components = 3

U, S, Vt = jax.scipy.linalg.svd(B, full_matrices=True)

print(U.shape, S.shape, V.shape)

if n < m:
    S = jnp.concatenate((jnp.diag(S), jnp.zeros((n, m-n))), axis=1)
elif n > m:
    S = jnp.concatenate((jnp.diag(S), jnp.zeros((n-m, m))), axis=0)


B @ Vt[:num_components].T

(7, 7) (5,) (8, 2)


Array([[ 0.7975866 ,  0.33475694, -0.29629555],
       [ 0.8001167 , -1.0575129 ,  0.9375936 ],
       [-1.5112098 , -2.5334752 , -0.49882537],
       [-4.2634344 ,  0.9500482 ,  0.13105164],
       [ 0.7975866 ,  0.33475694, -0.29629552],
       [ 0.7975866 ,  0.33475694, -0.29629552],
       [ 0.7975866 ,  0.33475694, -0.29629552]], dtype=float32)

In [1195]:
U2, U

(Array([[ 0.03292142,  0.29417497, -0.17891674, -0.9382782 ],
        [ 0.25488868,  0.74092764,  0.60859215,  0.1251934 ],
        [ 0.65354633,  0.2792707 , -0.6624261 ,  0.23680526],
        [ 0.71191657, -0.5352525 ,  0.39849204, -0.21882364]],      dtype=float32),
 Array([[ 1.64022923e-01,  1.12288535e-01, -2.42227674e-01,
          3.89632523e-01,  8.66025388e-01, -1.58074709e-09,
         -1.58074709e-09],
        [ 1.64543286e-01, -3.54724824e-01,  7.66501606e-01,
          5.09481072e-01, -1.92716243e-09,  1.22538957e-09,
          1.22538957e-09],
        [-3.10778767e-01, -8.49811614e-01, -4.07799274e-01,
          1.22214384e-01, -1.21136452e-07,  1.23196564e-09,
          1.23196564e-09],
        [-8.76771271e-01,  3.18677723e-01,  1.07136950e-01,
          3.43857884e-01, -3.17294848e-08, -2.37709741e-09,
         -8.71655370e-09],
        [ 1.64022923e-01,  1.12288557e-01, -2.42227525e-01,
          3.89632583e-01, -2.88675159e-01, -5.77350259e-01,
         -5.77350259e-

In [1196]:
S2, S

(Array([[18.121445  ,  0.        ,  0.        ,  0.        ],
        [ 0.        ,  6.640581  ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  3.465861  ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.70971274]],      dtype=float32),
 Array([[4.8626533e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00],
        [0.0000000e+00, 2.9812198e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00],
        [0.0000000e+00, 0.0000000e+00, 1.2232121e+00, 0.0000000e+00,
         0.0000000e+00],
        [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 9.8522407e-01,
         0.0000000e+00],
        [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         9.6406190e-08],
        [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00],
        [0.0000000e+00, 0.0000000e+00, 0.0000000e+00, 0.0000000e+00,
         0.0000000e+00]], dtype=float32))

In [1197]:
Vt2, Vt

(Array([[ 0.04401338,  0.27628446,  0.681377  ,  0.6763542 ],
        [ 0.37902707,  0.7872326 , -0.46959126,  0.12683602],
        [-0.79285526,  0.06347464, -0.41396323,  0.4427031 ],
        [-0.47516647,  0.54763305,  0.37926766, -0.57486606]],      dtype=float32),
 Array([[ 0.08774178,  0.1797265 , -0.05285541, -0.74583757, -0.63319194],
        [ 0.13246639,  0.02176034, -0.8003858 , -0.3311217 ,  0.48137274],
        [-0.04162986,  0.5418473 ,  0.5000368 , -0.38166016,  0.55584824],
        [-0.8872296 ,  0.35512865, -0.25010833,  0.09995602, -0.11900399],
        [ 0.43109712,  0.739939  , -0.20975605,  0.42240852, -0.21028268]],      dtype=float32))

In [1198]:
U2 @ S2 @ Vt2

Array([[ 1.5747514e+00,  1.2986461e+00, -5.0670332e-01,  7.5956261e-01],
       [ 3.5359529e-01,  5.3320231e+00, -2.7035451e-03,  4.6308184e+00],
       [ 2.9646125e+00,  4.6783385e+00,  8.2129698e+00,  7.1324143e+00],
       [-1.8006270e+00,  7.6882154e-01,  9.8288918e+00,  8.9754953e+00]],      dtype=float32)

In [1199]:
from sklearn.decomposition import PCA

In [1200]:
model = PCA(n_components=3)

B = jnp.array([[1, 2, 0, 0, 0],
               [0, 4, 5, 0, 0],
               [0, 0, 7, 8, 0],
               [0, 0, 0, 10, 11],
               [1, 2, 0, 0, 0],
               [1, 2, 0, 0, 0],
               [1, 2, 0, 0, 0]])

fitted_model = model.fit(B)
B_transformed = fitted_model.transform(B)
fitted_model.inverse_transform(B_transformed)

array([[ 1.00000000e+00,  2.00000000e+00,  2.22044605e-16,
        -4.44089210e-16,  0.00000000e+00],
       [-5.55111512e-16,  4.00000000e+00,  5.00000000e+00,
        -4.44089210e-16, -2.22044605e-16],
       [ 0.00000000e+00, -1.33226763e-15,  7.00000000e+00,
         8.00000000e+00, -4.21884749e-15],
       [ 3.33066907e-16,  2.22044605e-16, -3.33066907e-15,
         1.00000000e+01,  1.10000000e+01],
       [ 1.00000000e+00,  2.00000000e+00,  2.22044605e-16,
        -4.44089210e-16,  0.00000000e+00],
       [ 1.00000000e+00,  2.00000000e+00,  2.22044605e-16,
        -4.44089210e-16,  0.00000000e+00],
       [ 1.00000000e+00,  2.00000000e+00,  2.22044605e-16,
        -4.44089210e-16, -2.22044605e-16]])

In [1201]:
class myPCA():
    def __init__(self, num_components:int):
        self.num_components = num_components
        self.mean = None
        self.principal_components = None
        self.explained_variance = None

    def fit(self, X:jax.Array):
        n, m = X.shape
        
        self.mean = X.mean(axis=0)
        X_centred = X - self.mean
        S, self.principal_components = svd(X_centred, full_matrices=True)[1:]

        if n < m:
            S = jnp.concatenate((jnp.diag(S), jnp.zeros((n, m-n))), axis=1)
        elif n > m:
            S = jnp.concatenate((jnp.diag(S), jnp.zeros((n-m, m))), axis=0)

        self.explained_variance = S**2 / jnp.sum(S**2)

    def transform(self, X:jax.Array):
        if self.principal_components is None:
            raise RuntimeError('Must fit before transforming.')
        
        X_centred = X - X.mean(axis=0)
        return jnp.dot(X_centred, self.principal_components[:self.num_components].T)
    
    def fit_transform(self, X:jax.Array):
        X_centred = X - X.mean(axis=0)

        self.principal_components = svd(X_centred, full_matrices=True)[2]

        return jnp.dot(X_centred, self.principal_components[:self.num_components].T)

    def inverse_transform(self, X_transformed:jax.Array):
        if self.principal_components is None:
            raise RuntimeError('Must fit before transforming.')
        
        return jnp.dot(X_transformed, self.principal_components[:self.num_components]) + self.mean

In [1202]:
mymodel= myPCA(num_components=3)
mymodel.fit(B)
B_transformed = mymodel.transform(B)
mymodel.inverse_transform(B_transformed)

Array([[ 1.0000002e+00,  2.0000007e+00, -2.3841858e-07, -1.9073486e-06,
        -9.5367432e-07],
       [ 2.9802322e-07,  4.0000010e+00,  4.9999976e+00, -4.7683716e-07,
        -4.7683716e-07],
       [-4.7683716e-07, -1.4305115e-06,  6.9999995e+00,  8.0000029e+00,
         2.3841858e-06],
       [-4.7683716e-07, -2.0265579e-06,  4.1723251e-06,  1.0000006e+01,
         1.1000003e+01],
       [ 1.0000002e+00,  2.0000007e+00, -2.3841858e-07, -1.9073486e-06,
        -9.5367432e-07],
       [ 1.0000002e+00,  2.0000007e+00, -2.3841858e-07, -1.9073486e-06,
        -9.5367432e-07],
       [ 1.0000002e+00,  2.0000007e+00, -2.3841858e-07, -1.9073486e-06,
        -9.5367432e-07]], dtype=float32)