# Example: PCA with MNIST

In [None]:
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset

%matplotlib inline
import matplotlib.pyplot as plt

## Downloading the data

In [None]:
# Local cache directory for the MNIST files; with download=True torchvision
# fetches them into this directory when they are not already present.
path = './data/mnist/'

# Height and width of a single digit, used to flatten / un-flatten images.
img_size = [28, 28]

trainset = torchvision.datasets.MNIST(
    path,
    train=True,  # training split
    download=True,
    transform=transforms.ToTensor(),
)
# Shuffled batches of 100 samples, loaded in the main process (num_workers=0).
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=100, shuffle=True, num_workers=0
)

In [None]:
# Drain the loader into per-batch NumPy arrays, then stack them.
data, labels = [], []
for batch_images, batch_labels in trainloader:
    data.append(batch_images.detach().numpy())
    labels.append(batch_labels.detach().numpy())
data = np.concatenate(data)
labels = np.concatenate(labels)

# Flatten each image into one row, keep the first 10000 samples of the
# (shuffled) loader order, and transpose so that X stores one sample per
# column: X has shape (pixels, samples) and y holds the matching labels.
X = data.reshape(len(data), np.prod(img_size))[:10000].T
y = labels[:10000]
m, n = X.shape
print(f'Dimension:\t{m}')
print(f'Samples:\t{n}')

## Principal Component Analysis

In [None]:
import numpy as np
from sklearn import decomposition

# sklearn expects samples as rows, while X stores samples as columns;
# hence the transpose going in and coming out.
pca = decomposition.PCA(whiten=False)
X_prime = pca.fit_transform(X.T).T  # (components, samples)

mu = pca.mean_                           # mean image as a flat vector
U = pca.components_.T                    # principal directions as columns
D = pca.singular_values_ ** 2 / (n - 1)  # variance along each direction
exp_var = pca.explained_variance_ratio_  # fraction of variance per mode

for caption, arr in (
    ('Size before PCA: ', X),
    ('Size after PCA: ', X_prime),
    ('Size of U: ', U),
    ('Size of D: ', D),
):
    print(caption + str(arr.shape))

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

# Cumulative fraction of the total variance retained by the leading modes.
# The x-axis is the number of modes kept (1, 2, ..., K). Plotting the bare
# cumsum array put the value for "first k modes" at x = k - 1.
fig, ax = plt.subplots()
ax.plot(np.arange(1, exp_var.size + 1), np.cumsum(exp_var))
ax.set_xlabel('Number of modes')
ax.set_ylabel('Retained Variance')
plt.show()

In [None]:
# Show the mean digit followed by the first four principal modes.
# Each mode is scaled by its standard deviation sqrt(D[k]). Every panel is
# then min/max-rescaled to [0, 1] for display.
#
# `mu_img` must be a copy. `mu` is `pca.mean_` itself, and the in-place
# `-=` / `/=` below used to rescale that shared array, corrupting the mean
# used later for the interactive reconstruction.
mu_img = mu.copy()
mode1 = U[:, 0] * np.sqrt(D[0])
mode2 = U[:, 1] * np.sqrt(D[1])
mode3 = U[:, 2] * np.sqrt(D[2])
mode4 = U[:, 3] * np.sqrt(D[3])

panels = (mu_img, mode1, mode2, mode3, mode4)
titles = ('Mean', 'Mode 1', 'Mode 2', 'Mode 3', 'Mode 4')

# Rescale each panel in place. All panels are fresh arrays, so nothing
# shared is modified.
for panel in panels:
    panel -= np.min(panel)
    panel_max = np.max(panel)
    if panel_max > 0:  # a constant panel would otherwise become NaN (0/0)
        panel /= panel_max

fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5, figsize=(14, 4))
for panel_ax, panel, title in zip((ax1, ax2, ax3, ax4, ax5), panels, titles):
    panel_ax.imshow(panel.reshape(img_size))
    panel_ax.set_title(title)
plt.show()

In [None]:
# Scatter every sample in the plane of two principal modes, coloured by label.
mode_x = 0  # 0-based index of the mode on the horizontal axis
mode_y = 1  # 0-based index of the mode on the vertical axis
f = plt.figure()
plt.scatter(X_prime[mode_x, :], X_prime[mode_y, :], c=y, marker='.', cmap='tab10')
plt.grid()
plt.colorbar()
# Derive the axis labels from the selected modes. They were hard-coded as
# 'PCA 1' / 'PCA 2' and went stale whenever mode_x / mode_y changed.
plt.xlabel(f'PCA {mode_x + 1}')
plt.ylabel(f'PCA {mode_y + 1}')
plt.show()

In [None]:
from ipywidgets import interact, fixed

def plot_digit(mean_shape, modes, s1, s2, s3, s4):
    """Display the image mean_shape + sum_k s_k * modes[:, k] for k = 0..3.

    Parameters
    ----------
    mean_shape : ndarray
        Mean image as a flat vector (reshaped with ``img_size`` for display).
    modes : ndarray
        Principal directions stored as columns; only the first four are used.
    s1, s2, s3, s4 : float
        Weights for ``modes[:, 0]`` ... ``modes[:, 3]``.

    The reconstruction is min/max-rescaled to [0, 1] for display only.
    """
    # Use the arguments, not the globals `mu` / `U`. The previous version
    # ignored `mean_shape` and `modes`, so the values passed through
    # `interact(..., fixed(...))` had no effect.
    weights = np.array([s1, s2, s3, s4], dtype=float)
    image = mean_shape + modes[:, :4] @ weights
    image = image - np.min(image)
    peak = np.max(image)
    if peak > 0:  # a constant image would otherwise become NaN (0/0)
        image = image / peak
    plt.imshow(image.reshape(img_size))

def interactive_pca(mu, U, D):
    """Attach sliders s1..s4 to plot_digit, one per leading PCA mode.

    Slider i spans -6 to +6 times sqrt(D[i]), in steps of sqrt(D[i]).
    The mean and the mode matrix are passed through unchanged via `fixed`.
    """
    sliders = {}
    for i in range(4):
        sigma = np.sqrt(D[i])
        sliders['s%d' % (i + 1)] = (-sigma * 6, sigma * 6, sigma)
    interact(plot_digit, mean_shape=fixed(mu), modes=fixed(U), **sliders)


interactive_pca(mu, U, D)