In [1]:
from math import log, sqrt
import functools as ft

import jax
import jax.numpy as jnp
import jax.random as jrand
import jax.scipy as jsp
import tensorflow_probability.substrates.jax.distributions as jd
from jax import lax, jit, vmap


@ft.partial(jit, static_argnums=1)
def _pca(x, n_components):
    n_samples, n_features = x.shape

    mean = jnp.mean(x, axis=0)

    # Center data
    x_centered = x - mean

    U, S, Vt = jsp.linalg.svd(x_centered, full_matrices=False)

    components = Vt

    explained_variance = (S**2) / (n_samples - 1)
    total_var = explained_variance.sum()
    explained_variance_ratio = explained_variance / total_var
    singular_values = S

    return U, S, Vt, components[:n_components], explained_variance[:n_components], explained_variance_ratio[:n_components], singular_values[:n_components]

In [2]:
import numpy as np

# Toy data (the example used in scikit-learn's PCA docs): three points
# followed by their mirror images through the origin.
X_half = np.array([[1, 1], [2, 1], [3, 2]])
X = np.concatenate([-X_half, X_half])

In [3]:
# Run the JAX PCA on the toy data and display all seven returned arrays.
pca_jax = _pca(X, 2)
pca_jax

(DeviceArray([[-0.21956691,  0.53397   ],
              [-0.35264796, -0.45713517],
              [-0.57221484,  0.07683427],
              [ 0.2195669 , -0.5339696 ],
              [ 0.35264796,  0.4571353 ],
              [ 0.57221484, -0.07683427]], dtype=float32),
 DeviceArray([6.3006124, 0.5498041], dtype=float32),
 DeviceArray([[ 0.8384922,  0.5449136],
              [ 0.5449136, -0.8384922]], dtype=float32),
 DeviceArray([[ 0.8384922,  0.5449136],
              [ 0.5449136, -0.8384922]], dtype=float32),
 DeviceArray([7.9395432 , 0.06045691], dtype=float32),
 DeviceArray([0.9924429 , 0.00755711], dtype=float32),
 DeviceArray([6.3006124, 0.5498041], dtype=float32))

In [4]:
%timeit -n10 -r3 _pca(X, 2)

The slowest run took 10.85 times longer than the fastest. This could mean that an intermediate result is being cached.
38.3 µs ± 41.2 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)


In [5]:
from sklearn.decomposition import PCA

In [6]:
# scikit-learn reference estimator keeping the same two components as _pca(X, 2).
pca = PCA(2)

In [7]:
%timeit -n10 -r3 pca.fit(X)

The slowest run took 8.63 times longer than the fastest. This could mean that an intermediate result is being cached.
277 µs ± 241 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
