In [None]:
from numpy.random import RandomState
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from sklearn import datasets

from sklearn.datasets import fetch_olivetti_faces
 
from sklearn import decomposition 
# Grid dimensions; their product gives the component count used for galleries.
n_row = 2
n_col = 5
n_components = n_row * n_col

# Height and width of a single face image once its flat pixel row is reshaped.
image_shape = (64, 64)

# Seeded generator handed to the dataset loader so the shuffle is repeatable
# on a fresh kernel.
rng = RandomState(0)

In [None]:
# IPython magic: render matplotlib figures inline, below the cell that draws them.
%matplotlib inline

In [None]:
# Load the Olivetti faces, caching the download in the current directory.
# Samples are shuffled using the seeded generator from the configuration cell.
faces = fetch_olivetti_faces(
    data_home='./',
    shuffle=True,
    random_state=rng,
)

# Show (n_samples, n_pixels) of the flattened image matrix as the cell output.
faces.data.shape

In [None]:
# Preview ten faces (every tenth row of the shuffled dataset) side by side.
n_preview = 10
fig, axes = plt.subplots(1, n_preview, figsize=(10, 10))
for i, ax in enumerate(axes):
    # Each row of faces.data is a flat pixel vector; restore its 2-D layout
    # with the shared image_shape constant instead of a repeated literal.
    ax.imshow(faces.data[i * 10].reshape(image_shape), cmap=plt.cm.gray)
    ax.axis('off')

In [None]:
# Fit PCA with default settings (no explicit component limit) on the flat
# face vectors; fit() returns the estimator itself, so it can be chained.
pca = decomposition.PCA().fit(faces.data)

# Print the dimensions of the learned component matrix.
print(pca.components_.shape)

In [None]:
# Show the first ten principal components as images in a single row.
n_shown = 10
fig, axes = plt.subplots(1, n_shown, figsize=(10, 10))
for i, ax in enumerate(axes):
    # Each component is a flat vector in pixel space; reshape it with the
    # shared image_shape constant so it can be drawn like a face.
    ax.imshow(pca.components_[i].reshape(image_shape), cmap=plt.cm.gray)
    ax.axis('off')

In [None]:
from skimage.io import imsave

# Pick the first face of the dataset and project it into the PCA basis.
# transform() takes a 2-D array, so the flat vector gets a leading axis of
# length one (a single-sample batch).
face = faces.data[0]
trans = pca.transform(face[None, :])

In [None]:
# Inspect the shape of the projected face: a single row, because exactly one
# face was passed to pca.transform above.
trans.shape

In [None]:
import os

from skimage import img_as_ubyte

# Folder that receives the reconstruction snapshots; created if missing.
log_folder = "log"
os.makedirs(log_folder, exist_ok=True)

# Only every `save_every`-th rank is written to disk, so only those ranks are
# reconstructed. The upper bound is the number of fitted components (the
# earlier version of this loop hard-coded it as 400).
save_every = 10
n_ranks = pca.components_.shape[0]

for k in range(0, n_ranks, save_every):
    # Rank-k reconstruction: first k coefficients times first k components,
    # plus the mean that PCA removed. For k == 0 the product is all zeros,
    # so the result is just the mean face.
    rank_k_approx = trans[:, :k].dot(pca.components_[:k]) + pca.mean_

    # A truncated reconstruction may overshoot the valid intensity range.
    # img_as_ubyte raises ValueError for float input outside [-1, 1], so
    # clamp to [0, 1] before converting to 'uint8'.
    rank_k_image = np.clip(rank_k_approx.reshape(image_shape), 0.0, 1.0)
    rank_k_approx_uint8 = img_as_ubyte(rank_k_image)

    # Zero-padded names ("000.jpg", "010.jpg", ...) keep the files in rank
    # order when sorted by name.
    filename = os.path.join(log_folder, f'{k:03d}.jpg')
    imsave(filename, rank_k_approx_uint8)

In [None]:
import os

from matplotlib.image import imread

# Number of images shown per grid row.
images_per_row = 10

# All JPEG files in the "log" folder, sorted by name.
jpeg_files = sorted(name for name in os.listdir(log_folder) if name.endswith(".jpg"))

if not jpeg_files:
    # plt.subplots() rejects a grid with zero rows, so report and skip plotting.
    print(f"No .jpg files found in '{log_folder}'; nothing to display.")
else:
    # Ceiling division: a partially filled last row still needs its own row.
    num_rows = -(-len(jpeg_files) // images_per_row)

    # squeeze=False keeps `axes` two-dimensional even when there is only a
    # single row, so the grid can always be traversed the same way.
    fig, axes = plt.subplots(
        num_rows, images_per_row, figsize=(12, 3 * num_rows), squeeze=False
    )

    # Switch every axis off up front so unused trailing cells stay blank.
    for ax in axes.flat:
        ax.axis('off')

    # axes.flat walks the grid row by row, matching the order of the files.
    for i, (ax, jpeg_file) in enumerate(zip(axes.flat, jpeg_files)):
        image = imread(os.path.join(log_folder, jpeg_file))
        ax.imshow(image, cmap='gray')  # Assuming grayscale images
        ax.set_title(f'Image {i+1}')

    # Adjust layout and show the plot
    plt.tight_layout()
    plt.show()