# PCA on the Grassmann Manifold

In this notebook we show how to perform principal component analysis (PCA) by viewing it as an optimization problem on the Grassmann manifold, i.e.,

$$\min_{\textbf{U} \in \mathcal{G}(m, r)} \frac{1}{n} \sum_{i=1}^{n} \Vert\textbf{z}_{i} - \textbf{U}\textbf{U}^T \textbf{z}_i\Vert_{2}^2,$$

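where $\textbf{z}_1, \dots, \textbf{z}_n \in \mathbb{R}^m$ denote the $n$ data points and $\mathcal{G}(m, r)$ denotes the Grassmann manifold of $r$-dimensional linear subspaces of $\mathbb{R}^m$, each represented by a matrix $\textbf{U} \in \mathbb{R}^{m \times r}$ with orthonormal columns.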
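Two facts make this objective easy to sanity-check. For orthonormal $\textbf{U}$, $\Vert\textbf{z} - \textbf{U}\textbf{U}^T\textbf{z}\Vert_2^2 = \Vert\textbf{z}\Vert_2^2 - \Vert\textbf{U}^T\textbf{z}\Vert_2^2$, so minimizing the reconstruction error maximizes the captured energy; and a global minimizer is spanned by the top-$r$ right singular vectors of the data matrix $\textbf{Z}$ (rows $\textbf{z}_i^T$). For mean-centred data this is exactly PCA. The sketch below checks both on random toy data in plain JAX; the helper names and sizes are hypothetical and not part of Rieoptax.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def pca_cost(U, Z):
    """(1/n) * sum_i ||z_i - U U^T z_i||^2 for the rows z_i of Z (shape n x m)."""
    residuals = Z - (Z @ U) @ U.T  # row i is z_i - U U^T z_i
    return jnp.mean(jnp.sum(residuals**2, axis=1))


def pca_cost_identity(U, Z):
    """Same value via ||z||^2 - ||U^T z||^2, valid only when U^T U = I."""
    return jnp.mean(jnp.sum(Z**2, axis=1) - jnp.sum((Z @ U) ** 2, axis=1))


def svd_solution(Z, r):
    """Top-r right singular vectors of Z: a global minimizer of pca_cost."""
    _, _, Vt = jnp.linalg.svd(Z, full_matrices=False)
    return Vt[:r].T  # shape (m, r), orthonormal columns


key_data, key_u = jax.random.split(jax.random.PRNGKey(0))
n, m, r = 200, 10, 3  # illustrative sizes
Z = jax.random.normal(key_data, (n, m)) * jnp.arange(1, m + 1)  # anisotropic toy data
U_rand, _ = jnp.linalg.qr(jax.random.normal(key_u, (m, r)))  # random point on G(m, r)
U_star = svd_solution(Z, r)

print(pca_cost(U_rand, Z), pca_cost_identity(U_rand, Z))  # equal up to round-off
print(pca_cost(U_star, Z) <= pca_cost(U_rand, Z))
```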

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # run JAX at float64 precision

from jax import jit, vmap
from jax.numpy.linalg import norm
from rieoptax.core import ManifoldArray, rgrad
from rieoptax.geometry.grassmann import GrassmannCanonical
from rieoptax.mechanism.gradient_perturbation import DP_RGD_Mechanism
from rieoptax.optimizers.first_order import dp_rsgd, rsgd
from rieoptax.optimizers.update import apply_updates
```

### Fit function 

The Rieoptax fit loop is similar to an Optax training loop, with one important difference: **rgrad** (the Riemannian gradient) is used instead of **grad**.
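What `rgrad` adds over `grad` can be sketched in plain JAX. Assuming a point on $\mathcal{G}(m, r)$ is stored as an $m \times r$ matrix with orthonormal columns, the Riemannian gradient under the canonical metric is the Euclidean gradient projected onto the horizontal space, $(\textbf{I} - \textbf{U}\textbf{U}^T)\nabla f(\textbf{U})$. The sketch below runs Riemannian gradient descent with a QR retraction, a cheaper swapped-in alternative to the exponential map; it is not a description of Rieoptax internals, and `riemannian_grad`, `qr_retract`, and `rgd_step` are hypothetical names.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def cost(U, Z):
    # (1/n) * sum_i ||z_i - U U^T z_i||^2, rows of Z are the data points
    return jnp.mean(jnp.sum((Z - (Z @ U) @ U.T) ** 2, axis=1))


def riemannian_grad(f, U, *args):
    # Euclidean gradient, then projection onto the horizontal space at U
    egrad = jax.grad(f)(U, *args)
    return egrad - U @ (U.T @ egrad)  # (I - U U^T) egrad


def qr_retract(U, xi):
    # Step along the tangent vector xi, then re-orthonormalize the columns;
    # Q spans the same subspace as U + xi
    Q, _ = jnp.linalg.qr(U + xi)
    return Q


@jax.jit
def rgd_step(U, Z, lr):
    return qr_retract(U, -lr * riemannian_grad(cost, U, Z))


key_data, key_u = jax.random.split(jax.random.PRNGKey(1))
Z = jax.random.normal(key_data, (300, 8)) * jnp.linspace(0.5, 3.0, 8)
U, _ = jnp.linalg.qr(jax.random.normal(key_u, (8, 2)))

initial = cost(U, Z)
for _ in range(500):
    U = rgd_step(U, Z, 0.05)  # illustrative step size

print(initial, cost(U, Z))  # the cost decreases
```

The projection is what keeps each step tangent to the manifold: a plain Euclidean step would also change the basis within the current subspace, which does not move the point on $\mathcal{G}(m, r)$ at all.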

```python
def fit(params, data, optimizer, epochs, private=False):
    @jit
    def step(params, opt_state, data):
        def cost(params, data):
            def _cost(params, z):
                # squared reconstruction error ||z - U U^T z||^2 of one data point
                return norm(z - params.value @ (params.value.T @ z)) ** 2
            # average over the data points (rows of `data`)
            return vmap(_cost, in_axes=(None, 0))(params, data).mean()

        loss_value = cost(params, data)  # objective before this update
        rgrad_fn = rgrad(cost)
        if private:
            data = data[:, None]  # (n, m) -> (n, 1, m): one example per vmapped call
            rgrad_fn = vmap(rgrad_fn, in_axes=(None, 0))  # per-example gradients
        rgrads = rgrad_fn(params, data)  # calculates Riemannian gradients
        updates, opt_state = optimizer.update(rgrads, opt_state, params)
        params = apply_updates(params, updates)  # update using Riemannian Exp
        return params, opt_state, loss_value

    opt_state = optimizer.init(params)
    for epoch in range(epochs):
        params, opt_state, loss_value = step(params, opt_state, data)
        if epoch % 50 == 0:
            print(f"epoch {epoch}: loss {float(loss_value):.6f}")
    return params
```
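With `private=True`, `fit` vmaps `rgrad` to obtain one Riemannian gradient per example, which a differentially private optimizer needs in order to bound each example's influence. The sketch below shows, in plain JAX rather than Rieoptax, what such an optimizer could do with those gradients: clip each one to Frobenius norm at most `clip_norm`, sum, add Gaussian noise with standard deviation `noise_multiplier * clip_norm` projected onto the tangent space, and divide by $n$. This is the generic DP-SGD recipe; it is not claimed to match `dp_rsgd` or `DP_RGD_Mechanism`, calibrating `noise_multiplier` to a given $(\epsilon, \delta)$ is not shown, and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def example_cost(U, z):
    # squared reconstruction error of a single data point z
    return jnp.sum((z - U @ (U.T @ z)) ** 2)


def per_example_rgrads(U, Z):
    # one Riemannian gradient (m x r) per row of Z
    def rgrad_one(z):
        g = jax.grad(example_cost)(U, z)
        return g - U @ (U.T @ g)  # project onto the horizontal space at U
    return jax.vmap(rgrad_one)(Z)  # shape (n, m, r)


def private_direction(U, Z, clip_norm, noise_multiplier, key):
    grads = per_example_rgrads(U, Z)
    norms = jnp.sqrt(jnp.sum(grads**2, axis=(1, 2)))  # Frobenius norm per example
    scale = jnp.minimum(1.0, clip_norm / jnp.maximum(norms, 1e-12))
    clipped_sum = jnp.sum(grads * scale[:, None, None], axis=0)
    noise = noise_multiplier * clip_norm * jax.random.normal(key, U.shape)
    noise = noise - U @ (U.T @ noise)  # keep the noise in the tangent space
    return (clipped_sum + noise) / Z.shape[0]


key_data, key_u, key_noise = jax.random.split(jax.random.PRNGKey(2), 3)
Z = jax.random.normal(key_data, (50, 6))
U, _ = jnp.linalg.qr(jax.random.normal(key_u, (6, 2)))
direction = private_direction(U, Z, clip_norm=0.1, noise_multiplier=1.0, key=key_noise)
print(direction.shape)  # (6, 2)
```

Clipping per example, rather than clipping the averaged gradient, is what bounds how much any single image can change the update.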

### Downloading the data 

For this notebook we use Tiny ImageNet, whose training set consists of 100,000 images of shape (3, 64, 64).
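Flattening each image into one row of the data matrix gives an ambient dimension of $m = 3 \cdot 64 \cdot 64 = 12{,}288$, so $\textbf{U}$ lives on $\mathcal{G}(12288, r)$. The snippet below checks this arithmetic on a toy stack of all-zero images (a stand-in, not real data) and estimates the size of the float64 data matrix for the 5,000-image subset used in this notebook.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

height, width, channels = 64, 64, 3  # cv2.imread returns (H, W, C) arrays
m = height * width * channels        # features per flattened image
n = 5000                             # images used in this notebook

# Toy stand-in for a stack of uint8 images, flattened to one row per image
fake_images = jnp.zeros((4, height, width, channels), dtype=jnp.uint8)
Z_toy = fake_images.reshape(fake_images.shape[0], -1).astype(jnp.float64) / 255

print(m, Z_toy.shape)          # 12288 (4, 12288)
print(n * m * 8 / 1e6, "MB")   # float64 data matrix: 491.52 MB
```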

```python
# download Tiny ImageNet
!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip

# unzip the raw zip file
!unzip -qq 'tiny-imagenet-200.zip'
```

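```python
import glob

import cv2
import jax.numpy as jnp
from tqdm import tqdm

# training images from the archive unzipped in the previous cell
TRAINPATH = "tiny-imagenet-200/train/*/images/*.JPEG"
n_images = 5000  # use a 5,000-image subset

# cv2.imread returns each image as a (64, 64, 3) uint8 array
images = [cv2.imread(file) for file in tqdm(glob.glob(TRAINPATH)[:n_images])]
# flatten every image into one row and scale pixel values to [0, 1]
jnp_data = jnp.asarray(images, dtype=jnp.float64).reshape(len(images), -1) / 255
```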

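```python
import jax
import jax.numpy as jnp

Z = jnp_data
n, m = Z.shape  # number of data points, ambient dimension
r = 10  # number of principal components (illustrative value, assumed)

# random m x r matrix with orthonormal columns as the initial point on G(m, r)
init, _ = jnp.linalg.qr(jax.random.normal(jax.random.PRNGKey(0), (m, r)))
U_init = ManifoldArray(value=init, manifold=GrassmannCanonical())

# non-private PCA
lr, epochs = (3e-3, 400)
optimizer = rsgd(lr)
non_private_U = fit(U_init, Z, optimizer, epochs)

# (eps, delta)-differentially private PCA
eps, delta, clip_norm, epochs = (1.0, 1e-6, 0.1, 200)
sigma = DP_RGD_Mechanism(eps, delta, clip_norm, n)
private_optimizer = dp_rsgd(lr, sigma, clip_norm)
private_U = fit(U_init, Z, private_optimizer, epochs, private=True)
```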
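A natural way to compare the learned subspaces is by their distance on the Grassmann manifold itself. The cosines of the principal angles between two subspaces are the singular values of $\textbf{U}_1^T\textbf{U}_2$, and under the canonical metric the geodesic distance is the 2-norm of those angles. The sketch below (hypothetical helper names, toy data) checks that two different bases of the same subspace are at distance zero. Assuming `fit` returns the final `ManifoldArray`, the same functions could be applied to `non_private_U.value` and `private_U.value` against `svd_subspace(Z, r)`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def principal_angles(U1, U2):
    # cosines of the principal angles are the singular values of U1^T U2
    cosines = jnp.linalg.svd(U1.T @ U2, compute_uv=False)
    return jnp.arccos(jnp.clip(cosines, -1.0, 1.0))


def grassmann_distance(U1, U2):
    # geodesic distance on G(m, r) under the canonical metric
    return jnp.linalg.norm(principal_angles(U1, U2))


def svd_subspace(Z, r):
    # reference solution: top-r right singular vectors of the data matrix
    _, _, Vt = jnp.linalg.svd(Z, full_matrices=False)
    return Vt[:r].T


key_z, key_u, key_q = jax.random.split(jax.random.PRNGKey(3), 3)
Z = jax.random.normal(key_z, (100, 10)) * jnp.arange(1, 11)
r = 3
U_star = svd_subspace(Z, r)
Q, _ = jnp.linalg.qr(jax.random.normal(key_q, (r, r)))  # random r x r orthogonal matrix
U_rot = U_star @ Q                                       # same subspace, different basis
U_rand, _ = jnp.linalg.qr(jax.random.normal(key_u, (10, r)))

print(grassmann_distance(U_star, U_rot))   # ~0: same point on G(m, r)
print(grassmann_distance(U_star, U_rand))  # > 0
```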