# Principal Component Analysis
PCA implementation using eigen and SVD

ref [1](http://agnesmustar.com/2017/11/01/principal-component-analysis-pca-implemented-pytorch/)

ref [2](https://zhuanlan.zhihu.com/p/58064462)

ref [3](https://blog.csdn.net/Dark_Scope/article/details/53150883)

ref [4](https://blog.csdn.net/Little_Fire/article/details/80445987) Eigende Composition

In [1]:
import torch

In [2]:
torch.manual_seed(886)
m=5 # num of sample
n=8 # dim of one sample
ratio=0.99

In [3]:
src = torch.rand((m,n))
# 1. centerize (n,m)
src = src - src.mean()
src

tensor([[-0.1389,  0.3233,  0.1166, -0.4258,  0.2467,  0.0271, -0.1988,  0.4109],
        [ 0.2700, -0.4517, -0.4135,  0.4607,  0.0278, -0.2408, -0.2052, -0.0065],
        [ 0.1202,  0.3165, -0.4434,  0.0486, -0.2088,  0.3180,  0.0377,  0.0705],
        [-0.2457, -0.2703, -0.0092,  0.3814,  0.2892, -0.0494,  0.3654, -0.3693],
        [ 0.3405, -0.2373, -0.1980,  0.3845, -0.2701, -0.1027, -0.2083,  0.1383]])

## Eigen
$$ X^TX = VEV^T$$

In [4]:
# 2. conv (m,m)
conv = torch.mm(src.t(), src)

In [5]:
# 3. eigen
# eigen value is sorted, ascending
# eigen vector is column by column 
e, v = conv.symeig(eigenvectors=True)

In [6]:
# check eig
rec_conv = torch.matmul(torch.matmul(v, torch.diag(e)), torch.t(v))
assert conv.allclose(rec_conv)

In [7]:
# find the perfect k
total = e.sum()
for k in range(1, n+1):
    if e[-k:].sum()/total >= ratio:
        print('perfect k is', k)
        break

perfect k is 4


In [8]:
# 4. compute pca (m,k)
pca_eig = torch.mm(src, v[:,-k:])
pca_eig

tensor([[-0.2578,  0.2384,  0.1905,  0.6529],
        [-0.2050,  0.1887,  0.2000, -0.7995],
        [-0.1288, -0.5249,  0.4153,  0.0067],
        [-0.1565, -0.1421, -0.6577, -0.3916],
        [ 0.1006,  0.1023,  0.4043, -0.5562]])

## SVD
$$X=U \Sigma V^T$$

In [9]:
u, s, v = src.svd()
# s is sorted, but is descending
pca_svd = torch.mm(src, v[:,:k])
pca_svd

tensor([[-0.6529, -0.1905,  0.2384, -0.2578],
        [ 0.7995, -0.2000,  0.1887, -0.2050],
        [-0.0067, -0.4153, -0.5249, -0.1288],
        [ 0.3916,  0.6577, -0.1421, -0.1565],
        [ 0.5562, -0.4043,  0.1023,  0.1006]])