In [None]:
from fairPCA_jaxnumpy import StreamingFairBlockPCA
import matplotlib.pyplot as plt
import jax.numpy as np
from jax import random

# Constructor settings for the StreamingFairBlockPCA instance that every
# later cell in this notebook reads from (as `Algo`).
algo_config = dict(
    data_dim=100,
    probability=0.5,
    nullity=0,  # nullity of Sigma_gap
    seed=None,  # NOTE(review): no fixed seed is given here - confirm how the class treats None
    eps=0.1,
    mu_scale=0.1,
    max_cov_eig0=1,
    max_cov_eig1=1.5,
)

Algo = StreamingFairBlockPCA(**algo_config)

In [None]:
# One row of three panels:
#   [0] eigenvalues of Sigma0 and Sigma1,
#   [1] eigenvalues of Sigma and Sigma_gap,
#   [2] entries of the mean vectors mu_gap, mu0, mu1 and mu.
fig, ax = plt.subplots(1, 3, figsize=(15, 5))

# Only the eigenvalues are plotted, so eigvalsh is used instead of eigh.
# This also avoids binding the unused eigenvectors to `_`, which would
# shadow IPython's "last output" variable.
w0 = np.linalg.eigvalsh(Algo.Sigma0)
w1 = np.linalg.eigvalsh(Algo.Sigma1)
w = np.linalg.eigvalsh(Algo.Sigma)
wg = np.linalg.eigvalsh(Algo.Sigma_gap)

ax[0].plot(w0, label='Eigenvalues: Sigma0')
ax[0].plot(w1, label='Eigenvalues: Sigma1')
ax[0].legend()

ax[1].plot(w, label='Eigenvalues: Sigma')
ax[1].plot(wg, label='Eigenvalues: Sigma_gap')
ax[1].legend()

# Raw strings: '\m' is not a valid Python escape sequence and triggers a
# SyntaxWarning on recent Python versions. The label text is unchanged.
ax[2].plot(Algo.mu_gap, c='tab:green', label=r'Entries: $\mu_{gap}$=$\mu_1$-$\mu_0$')
ax[2].plot(Algo.mu0, c='tab:orange', label=r'Entries: $\mu_0$')
ax[2].plot(Algo.mu1, c='tab:red', label=r'Entries: $\mu_1$')
ax[2].plot(Algo.mu, c='gray', label=r'Entries: $\mu=(1-p)\mu_0 + p\mu_1$')
# Trailing ';' keeps the cell from echoing the Legend object's repr.
ax[2].legend();

In [None]:
# For each column of Algo.eigvec_Sigma, walked from the last column backwards
# (i = 0 is the last column), record the norm of its product with the last
# column of Algo.eigvec_Sigma_gap_sq, paired with the offset i.
# The number of columns is read from the array instead of being hard-coded
# to 100, so the cell keeps working if data_dim is changed.
n_directions = Algo.eigvec_Sigma.shape[1]
gap_direction = Algo.eigvec_Sigma_gap_sq[:, -1:]  # kept 2-D (one column); loop-invariant
unfair_pc_score = [
    (np.linalg.norm(Algo.eigvec_Sigma[:, -i - 1] @ gap_direction), i)
    for i in range(n_directions)
]

# Display the (score, offset) pair with the largest score.
max(unfair_pc_score, key=lambda x: x[0])

In [None]:
# Inner product between the last column of eigvec_Sigma and the last column
# of eigvec_Sigma_gap_sq; shown as the cell output.
last_sigma_vec = Algo.eigvec_Sigma[:, -1]
last_gap_vec = Algo.eigvec_Sigma_gap_sq[:, -1]
last_sigma_vec @ last_gap_vec

In [None]:
k = 4

# Keep the first of the two values returned by get_ground_truth; the second
# is discarded.
# NOTE(review): the meaning of the positional arguments (k, 10) is defined in
# fairPCA_jaxnumpy and is not visible here - confirm against get_ground_truth.
a, _ = Algo.get_ground_truth(k, 10, 'covariance')

# Show it next to the sum of the last k entries of Algo.eigval_Sigma.
last_k_eigval_sum = Algo.eigval_Sigma[-k:].sum()
(a, last_k_eigval_sum)

In [None]:
# R holds the first (Algo.d - 20) columns of eigvec_Sigma_gap_sq;
# M = R^T Sigma R is then eigendecomposed with eigh.
n_kept_columns = Algo.d - 20
R = Algo.eigvec_Sigma_gap_sq[:, :n_kept_columns]
M = (R.T @ Algo.Sigma) @ R
eigval, eigvec = np.linalg.eigh(M)

In [None]:
# Side-by-side display: trace of Sigma, trace of M, and the sum of the last
# four entries of eigval (M and eigval come from the previous cell).
trace_Sigma = np.trace(Algo.Sigma)
trace_M = np.trace(M)
last4_eigval_sum = eigval[-4:].sum()
(trace_Sigma, trace_M, last4_eigval_sum)

## Offline Training

In [None]:
n_iter = 100

# Offline run with constraint='vanilla'.
run_config = dict(
    target_dim=3,
    rank=10,
    n_iter=n_iter,
    lr=1,
    mode='pm',
    constraint='vanilla',
    seed=0,
    tol=0,
    # lr_scheduler=lambda t: 0.99**(t-1)   (not passed)
)
V = Algo.offline_train(**run_config)

# plot_buffer is called without fig/axes here, and the figure and axes it
# returns are kept so that the next cell can draw onto them.
fig, axes = Algo.plot_buffer(save=None)

In [None]:
# Offline run with constraint='all', drawn onto the figure created by the
# preceding offline 'vanilla' cell (requires `fig` and `axes` from that cell).

# Set n_iter in this cell, as every other training cell does. Without it the
# run silently uses whatever n_iter is left in the kernel (e.g. 200 after an
# online cell has been executed).
n_iter = 100

V = Algo.offline_train(
    target_dim=3,
    rank=10,
    n_iter=n_iter,
    lr=1,
    mode='pm',
    constraint='all',
    seed=0,
    tol=0,
    # lr_scheduler=lambda t: 0.99**(t-1)   (not passed)
)

# Reuse the existing figure/axes, then display the figure.
Algo.plot_buffer(save=None, fig=fig, axes=axes)
fig

## Online Training

### Noisy Power Method

In [None]:
n_iter = 200

# Online run: constraint='vanilla', pca_optimization='oja'.
run_config = dict(
    target_dim=3,
    rank=10,
    n_iter=n_iter,
    batch_size=50,
    constraint='vanilla',
    center_by_mean=None,
    subspace_optimization=None,
    pca_optimization='oja',
    lr_pca=0.1,
    n_iter_inner=1,
    n_iter_history=10,
    landing_lambda=1,
    seed=0,
    tol=0,
    # lr_scheduler=lambda t: 0.998**(t-1)   (not passed)
)
V = Algo.train(**run_config)

# plot_buffer is called without fig/axes here; the returned figure and axes
# are kept so that the following cells can draw onto them.
fig, axes = Algo.plot_buffer(save=None)

In [None]:
n_iter = 200

# Online run: constraint='vanilla', pca_optimization='npm'.
run_config = dict(
    target_dim=3,
    rank=10,
    n_iter=n_iter,
    batch_size=50,
    constraint='vanilla',
    center_by_mean=None,
    subspace_optimization=None,
    pca_optimization='npm',
    lr_pca=0.1,
    n_iter_inner=1,
    n_iter_history=10,
    landing_lambda=1,
    seed=0,
    tol=0,
    # lr_scheduler=lambda t: 0.998**(t-1)   (not passed)
)
V = Algo.train(**run_config)

# Draw onto the existing figure/axes (from an earlier cell), then display it.
Algo.plot_buffer(save=None, fig=fig, axes=axes)
fig

In [None]:
n_iter = 200

# Online run: constraint='vanilla', pca_optimization='npmfd'.
run_config = dict(
    target_dim=3,
    rank=10,
    n_iter=n_iter,
    batch_size=50,
    constraint='vanilla',
    center_by_mean=None,
    subspace_optimization=None,
    pca_optimization='npmfd',
    lr_pca=0.1,
    n_iter_inner=1,
    n_iter_history=10,
    landing_lambda=1,
    seed=0,
    tol=0,
    # lr_scheduler=lambda t: 0.998**(t-1)   (not passed)
)
V = Algo.train(**run_config)

# Draw onto the existing figure/axes (from an earlier cell), then display it.
Algo.plot_buffer(save=None, fig=fig, axes=axes)
fig

In [None]:
n_iter = 200

# Online run: constraint='vanilla', pca_optimization='riemannian'.
# lr_pca is 0.2 in this cell.
run_config = dict(
    target_dim=3,
    rank=10,
    n_iter=n_iter,
    batch_size=50,
    constraint='vanilla',
    center_by_mean=None,
    subspace_optimization=None,
    pca_optimization='riemannian',
    lr_pca=0.2,
    n_iter_inner=1,
    n_iter_history=10,
    landing_lambda=1,
    seed=0,
    tol=0,
    # lr_scheduler=lambda t: 0.998**(t-1)   (not passed)
)
V = Algo.train(**run_config)

# Draw onto the existing figure/axes (from an earlier cell), then display it.
Algo.plot_buffer(save=None, fig=fig, axes=axes)
fig

In [None]:
n_iter = 200

# Online run: constraint='vanilla', pca_optimization='history'.
# lr_pca is 1 in this cell.
run_config = dict(
    target_dim=3,
    rank=10,
    n_iter=n_iter,
    batch_size=50,
    constraint='vanilla',
    center_by_mean=None,
    subspace_optimization=None,
    pca_optimization='history',
    lr_pca=1,
    n_iter_inner=1,
    n_iter_history=10,
    landing_lambda=1,
    seed=0,
    tol=0,
    # lr_scheduler=lambda t: 0.998**(t-1)   (not passed)
)
V = Algo.train(**run_config)

# Draw onto the existing figure/axes (from an earlier cell), then display it.
Algo.plot_buffer(save=None, fig=fig, axes=axes)
fig

## Online - Fair PCA

In [None]:
n_iter = 200

# Online run: constraint='vanilla', pca_optimization='oja'.
# NOTE(review): these arguments are identical to the vanilla 'oja' run at the
# start of the "Online Training" section - presumably repeated as a baseline
# for this section; confirm.
run_config = dict(
    target_dim=3,
    rank=10,
    n_iter=n_iter,
    batch_size=50,
    constraint='vanilla',
    center_by_mean=None,
    subspace_optimization=None,
    pca_optimization='oja',
    lr_pca=0.1,
    n_iter_inner=1,
    n_iter_history=10,
    landing_lambda=1,
    seed=0,
    tol=0,
    # lr_scheduler=lambda t: 0.998**(t-1)   (not passed)
)
V = Algo.train(**run_config)

# plot_buffer is called without fig/axes here; the returned figure and axes
# are bound to `fig` / `axes`.
fig, axes = Algo.plot_buffer(save=None)

In [None]:
n_iter = 200

# Online run: constraint='mean', subspace_optimization='npmfd',
# pca_optimization='oja'.
# NOTE(review): batch_size is 10 here, while the other training cells shown
# in this notebook use 50 - confirm this is intended.
run_config = dict(
    target_dim=3,
    rank=10,
    n_iter=n_iter,
    batch_size=10,
    constraint='mean',
    center_by_mean=None,
    subspace_optimization='npmfd',
    pca_optimization='oja',
    lr_pca=0.1,
    n_iter_inner=1,
    n_iter_history=10,
    landing_lambda=1,
    seed=0,
    tol=0,
    # lr_scheduler=lambda t: 0.998**(t-1)   (not passed)
)
V = Algo.train(**run_config)

# NOTE(review): this rebinds `fig` / `axes` to the result of a plot_buffer call
# made without fig/axes, rather than drawing onto the previous cell's figure -
# confirm this is intended.
fig, axes = Algo.plot_buffer(save=None)

In [None]:
n_iter = 200

# Online run: constraint='all', subspace_optimization='npmfd',
# pca_optimization='npmfd'.
run_config = dict(
    target_dim=3,
    rank=10,
    n_iter=n_iter,
    batch_size=50,
    constraint='all',
    center_by_mean=None,
    subspace_optimization='npmfd',
    pca_optimization='npmfd',
    lr_pca=0.1,
    n_iter_inner=1,
    n_iter_history=10,
    landing_lambda=1,
    seed=0,
    tol=0,
    # lr_scheduler=lambda t: 0.998**(t-1)   (not passed)
)
V = Algo.train(**run_config)

# Draw onto the existing figure/axes (from an earlier cell), then display it.
Algo.plot_buffer(save=None, fig=fig, axes=axes)
fig

In [None]:
n_iter = 200

# Online run: constraint='all', subspace_optimization='history',
# pca_optimization='history'.
run_config = dict(
    target_dim=3,
    rank=10,
    n_iter=n_iter,
    batch_size=50,
    constraint='all',
    center_by_mean=None,
    subspace_optimization='history',
    pca_optimization='history',
    lr_pca=0.1,
    n_iter_inner=1,
    n_iter_history=10,
    landing_lambda=1,
    seed=0,
    tol=0,
    # lr_scheduler=lambda t: 0.998**(t-1)   (not passed)
)
V = Algo.train(**run_config)

# Draw onto the existing figure/axes (from an earlier cell), then display it.
Algo.plot_buffer(save=None, fig=fig, axes=axes)
fig

In [None]:
n_iter = 200

# Online run: constraint='all', subspace_optimization='history',
# pca_optimization='oja'.
run_config = dict(
    target_dim=3,
    rank=10,
    n_iter=n_iter,
    batch_size=50,
    constraint='all',
    center_by_mean=None,
    subspace_optimization='history',
    pca_optimization='oja',
    lr_pca=0.1,
    n_iter_inner=1,
    n_iter_history=10,
    landing_lambda=1,
    seed=0,
    tol=0,
    # lr_scheduler=lambda t: 0.998**(t-1)   (not passed)
)
V = Algo.train(**run_config)

# Draw onto the existing figure/axes (from an earlier cell), then display it.
Algo.plot_buffer(save=None, fig=fig, axes=axes)
fig