In [None]:
from PIL import Image
import numpy as np
import scipy
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
import sys
import os
from multiprocessing import Pool
from functools import partial
# Output folder for the saved figures. exist_ok=True makes this idempotent on
# notebook re-runs and avoids the check-then-create race of
# os.path.exists() followed by os.makedirs().
directory = 'figures'
os.makedirs(directory, exist_ok=True)
from concurrent.futures import ProcessPoolExecutor, as_completed
from dask.distributed import Client, progress
from dask import compute, delayed
import dask.array as da
from dask.diagnostics import ProgressBar
# Render every figure at 400 dpi by default.
plt.rcParams.update({'figure.dpi': 400})

In [None]:
def load_image(num, image_dir='images'):
    """Load ``<image_dir>/image<num>.png`` as a grayscale pixel matrix.

    Parameters
    ----------
    num : int or str
        Image number; the file read is ``image{num}.png``.
    image_dir : str, optional
        Folder containing the images (default ``'images'``, relative to the
        current working directory).

    Returns
    -------
    numpy.ndarray
        2-D array of the image's 8-bit luminance ("L" mode) pixels.
    """
    # The context manager closes the underlying file handle as soon as the
    # pixels are converted; a bare Image.open() keeps it open until GC.
    with Image.open(os.path.join(image_dir, f'image{num}.png')) as image:
        gray_image = image.convert('L')
    return np.array(gray_image)

def compute_svd(image_matrix):
    """Return the thin (economy-size) SVD ``(U, s, Vt)`` of ``image_matrix``.

    ``s`` holds the singular values in descending order, and
    ``U @ np.diag(s) @ Vt`` reconstructs the input.
    """
    factors = np.linalg.svd(image_matrix, full_matrices=False)
    left, sigma, right_t = factors
    return left, sigma, right_t

def compress_image(U, s, Vt, k):
    """Rank-``k`` approximation ``U[:, :k] @ diag(s[:k]) @ Vt[:k, :]``.

    Parameters
    ----------
    U, s, Vt : numpy.ndarray
        Thin SVD factors, as returned by ``compute_svd``.
    k : int
        Number of leading singular triplets to keep. Values larger than
        ``len(s)`` are clamped by slicing. ``k == 0`` yields an all-zero matrix.

    Returns
    -------
    numpy.ndarray
        Matrix with the same shape as ``U @ Vt`` (the original image).
    """
    # Broadcasting scales each column of U by its singular value. This is
    # equivalent to multiplying by diag(s[:k]) but skips building a dense
    # k x k matrix and the extra O(m*k^2) matrix product.
    return (U[:, :k] * s[:k]) @ Vt[:k, :]

def frobenius_percent(original, approx):
    """Relative Frobenius-norm error of ``approx`` w.r.t. ``original``, in percent.

    Computes ``100 * ||original - approx||_F / ||original||_F``.

    Parameters
    ----------
    original, approx : array_like
        Matrices of the same shape. Integer arrays (e.g. uint8 images) are
        accepted.

    Returns
    -------
    float
        The percentage error. If ``original`` is all zeros, the division by a
        zero norm gives nan/inf (numpy emits a RuntimeWarning).
    """
    # Cast to float first: subtracting two unsigned-integer arrays (e.g. two
    # uint8 images) would otherwise wrap around instead of going negative.
    original = np.asarray(original, dtype=np.float64)
    approx = np.asarray(approx, dtype=np.float64)
    residual_norm = np.linalg.norm(original - approx, 'fro')
    reference_norm = np.linalg.norm(original, 'fro')
    return 100 * (residual_norm / reference_norm)

def memory_saved(m, n, k, s):
    """Percentage of storage saved by keeping only the rank-``k`` SVD factors.

    A rank-``r`` truncated SVD of an ``m x n`` matrix stores ``m*r`` entries of
    ``U``, ``r`` singular values and ``n*r`` entries of ``Vt``. The baseline is
    the full thin SVD with ``r = len(s)``. Note that this is *not* the ``m*n``
    pixel count of the raw image.

    Parameters
    ----------
    m, n : int
        Matrix dimensions.
    k : int
        Retained rank.
    s : sequence
        Full singular-value vector; only its length is used.

    Returns
    -------
    float
        ``100 * (baseline - kept) / baseline``.
    """
    def _storage(rank):
        # entries in U[:, :rank] + s[:rank] + Vt[:rank, :]
        return m * rank + rank + n * rank

    baseline = _storage(len(s))
    return 100 * (baseline - _storage(k)) / baseline
    

In [None]:
path = 2  # image number to load, i.e. images/image{path}.png
A = load_image(path)
U, s, Vt = compute_svd(A)
m, n = A.shape

# Ranks to try: len(s)//4, len(s)//8, ..., len(s)//2**N with N = floor(log2(len(s))).
N = int(np.log2(len(s)))
ks = [len(s) // 2**i for i in range(2, N + 1)]

# Entry 0 of each list describes the uncompressed original image.
# NOTE(review): the plotting cell saves to 'compressed-milkyway.png' whatever
# value `path` has here.
images = [A]
titles = [f'Original, k = {len(s)}']
frobenius = [0]
memory = [0]

for k in ks:
    A_k = compress_image(U, s, Vt, k)
    images.append(A_k)
    titles.append(f'k = {k}')
    frobenius.append(frobenius_percent(A, A_k))
    memory.append(memory_saved(m, n, k, s))

In [None]:
num_images = len(images)
cols = min(num_images, 3)
rows = (num_images + cols - 1) // cols  # ceiling division

# NOTE(review): m is the image height and n its width. The per-panel width
# here scales with m/n and the height with n/m, which is the inverse of the
# image aspect ratio. The 4 / 1.25 constants look hand-tuned for one image;
# confirm before reusing this on a differently shaped image.
fig, axs = plt.subplots(rows, cols, figsize=(cols * (4*(m/n)), rows * 1.25*(n/m)))
fig.subplots_adjust(hspace=0, wspace=0)

# plt.subplots returns a bare Axes for a 1x1 grid and an ndarray otherwise.
if num_images > 1:
    axs = axs.ravel()
else:
    axs = [axs]

for ax, img, title, err, saved in zip(axs, images, titles, frobenius, memory):
    ax.imshow(img, cmap='gray')
    ax.set_title(f'{title}\nFrobenius error: {err:.3g}%\nMemory saved: {saved:.3g}%', fontsize=5)
    ax.axis('off')

# Blank out any leftover grid cells after the last image.
for ax in axs[num_images:]:
    ax.axis('off')

fig.savefig(os.path.join(directory, 'compressed-milkyway.png'), dpi=400)
plt.show()