In [None]:
# Clone the assignment repository and make it the working directory.
# Replace the placeholder with a personal access token only when running, and
# never save or commit the notebook with a real token in this cell.
!git clone https://<Insert token here>@github.com/Ekanshsomani/es335-24-fall-assignment-2/
%cd es335-24-fall-assignment-2

You have an image patch of size $50\times 50$ that you want to compress using matrix factorization. To do this, you'll factorize the $[N\times N]$ patch into two smaller matrices of size $[N\times r]$ and $[r\times N]$. Reconstruct the compressed patch by multiplying these two matrices and compare the reconstructed image patch with the original patch. Compute the Root Mean Squared Error (RMSE) and Peak Signal-to-Noise Ratio (PSNR) between the original and reconstructed image patches.

- Test different values for the low-rank $r = [5, 10, 25, 50]$.
- Use Gradient Descent to learn the compressed matrices.
- Display the reconstructed image patches, keeping the original pixel values outside the patch unchanged, and use your compressed matrix for the patch to show how well the reconstruction works.
- Compute the RMSE and PSNR for each value of $r$. 

In [None]:
import torch

# Silence every Python warning for the rest of the session. This also hides
# deprecation notices, so comment it out when debugging.
import warnings
warnings.filterwarnings('ignore')

# Set CUDA_LAUNCH_BLOCKING=1 so CUDA kernels run synchronously and errors are
# reported at the call that caused them. `device` is the notebook-wide target
# for all tensors: the GPU when one is available, otherwise the CPU.
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Render inline matplotlib figures at high-DPI ("retina") resolution.
%config InlineBackend.figure_format = 'retina'

# Optional einops import, currently disabled:
# try:
#     from einops import rearrange
# except ImportError:
#     %pip install einops
#     from einops import rearrange

In [None]:
# Fetch the sample image only when it is not already in the working
# directory, so re-running the notebook does not download it again.
if os.path.exists('dog.jpg'):
    print('dog.jpg exists')
else:
    !curl -o dog.jpg https://segment-anything.com/assets/gallery/AdobeStock_94274587_welsh_corgi_pembroke_CD.jpg

In [None]:
# functions for simple matrix factorization

import torch.optim as optim

# gradient descent
def fac_grad(A: torch.tensor, r: int, epochs: int):
    """Factorize a matrix as ``U @ V`` by joint gradient descent.

    Both factors are initialised uniformly at random in [0, 1) and updated
    together with Adam (learning rate 0.01) to minimise the Frobenius norm of
    ``U @ V - A``.

    Args:
        A: 2-D tensor of shape (n, m) to approximate. It is moved to the
            notebook-level ``device`` before optimisation.
        r: Inner dimension (rank) of the factorization.
        epochs: Number of optimizer steps to run.

    Returns:
        A tuple ``(U, V, losses)``. ``U`` has shape (n, r), ``V`` has shape
        (r, m), and ``losses`` holds the loss of every 50th step (starting
        with step 0) as detached 0-d tensors.

    Raises:
        ValueError: If ``A`` is not 2-D.
    """
    A = A.to(device)
    # Unpack both dimensions so rectangular matrices are supported; the
    # original sized V from the row count and only worked for square input.
    n, m = A.shape
    U = torch.rand(n, r, requires_grad = True, device = device)
    V = torch.rand(r, m, requires_grad = True, device = device)

    losses = []
    optimizer = optim.Adam([U, V], lr = 0.01)

    for i in range(epochs):
        loss = torch.norm((U @ V) - A)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 50 == 0:
            # Detach before storing: keeping the live loss tensor would hold
            # on to its autograd graph and make the value awkward to plot.
            losses.append(loss.detach())
    return U, V, losses

# wals method
def fac_wals(A: torch.tensor, r: int, epochs: int):
    """Factorize a matrix as ``U @ V`` by alternating updates of U and V.

    Each epoch takes one Adam step (learning rate 0.01) on ``U`` with ``V``
    held fixed, then one Adam step on ``V`` with the updated ``U`` held
    fixed. Both steps minimise the Frobenius norm of ``U @ V - A``. Note that
    the sub-problems are advanced by gradient steps, not solved in closed
    form.

    Args:
        A: 2-D tensor of shape (n, m) to approximate. It is moved to the
            notebook-level ``device`` before optimisation.
        r: Inner dimension (rank) of the factorization.
        epochs: Number of alternating U/V rounds.

    Returns:
        A tuple ``(U, V, losses)``. ``U`` has shape (n, r), ``V`` has shape
        (r, m), and ``losses`` holds the loss measured before the V step of
        every 50th epoch (starting with epoch 0) as detached 0-d tensors.

    Raises:
        ValueError: If ``A`` is not 2-D.
    """
    A = A.to(device)
    # Unpack both dimensions so rectangular matrices are supported.
    n, m = A.shape
    U = torch.rand(n, r, requires_grad = True, device = device)
    V = torch.rand(r, m, requires_grad = True, device = device)

    losses = []
    optimizer1 = optim.Adam([U], lr = 0.01)
    optimizer2 = optim.Adam([V], lr = 0.01)

    for i in range(epochs):
        # Fix V, update U. Detaching V keeps it out of the graph, so no
        # gradient is computed for the factor that is being held constant.
        loss = torch.norm((U @ V.detach()) - A)
        optimizer1.zero_grad()
        loss.backward()
        optimizer1.step()

        # Fix U (already updated above), update V.
        loss = torch.norm((U.detach() @ V) - A)
        optimizer2.zero_grad()
        loss.backward()
        optimizer2.step()

        if i % 50 == 0:
            # Store a detached copy so the autograd graph can be freed.
            losses.append(loss.detach())

    return U, V, losses

In [None]:
# loop over the functions for image factorization

# gradient descent
def fac_grad_loop(A: torch.tensor, r: int, epochs: int):
    """Factorize every channel of an image tensor with ``fac_grad``.

    Args:
        A: 3-D tensor laid out as (rows, cols, channels). It is moved to the
            notebook-level ``device``.
        r: Rank passed to ``fac_grad`` for each channel.
        epochs: Number of optimizer steps per channel.

    Returns:
        A tuple ``(U, V, losses)``. ``U`` has shape (rows, r, channels), ``V``
        has shape (r, cols, channels), and ``losses`` is a list with one
        entry per channel holding whatever ``fac_grad`` returned as its loss
        history.
    """
    A = A.to(device)
    n, m, d = A.shape

    U = torch.zeros(n, r, d, device = device)
    V = torch.zeros(r, m, d, device = device)
    # One history per channel, however many channels there are.
    losses = []
    for i in range(d):
        # This is the joint gradient-descent variant, so it must call
        # fac_grad; calling fac_wals here would make both methods identical.
        U[:, :, i], V[:, :, i], channel_losses = fac_grad(A[:, :, i], r, epochs)
        losses.append(channel_losses)

    return U, V, losses

# WALS method
def fac_wals_loop(A: torch.tensor, r: int, epochs: int):
    """Factorize every channel of an image tensor with ``fac_wals``.

    Args:
        A: 3-D tensor laid out as (rows, cols, channels). It is moved to the
            notebook-level ``device``.
        r: Rank passed to ``fac_wals`` for each channel.
        epochs: Number of alternating rounds per channel.

    Returns:
        A tuple ``(U, V, losses)``. ``U`` has shape (rows, r, channels), ``V``
        has shape (r, cols, channels), and ``losses`` is a list with one
        entry per channel holding whatever ``fac_wals`` returned as its loss
        history.
    """
    A = A.to(device)
    n, m, d = A.shape

    U = torch.zeros(n, r, d, device = device)
    V = torch.zeros(r, m, d, device = device)
    # Build the history list from the actual channel count; a fixed
    # three-slot list raised IndexError for inputs with more channels.
    losses = []

    for i in range(d):
        U[:, :, i], V[:, :, i], channel_losses = fac_wals(A[:, :, i], r, epochs)
        losses.append(channel_losses)

    return U, V, losses

In [None]:
from math import log10, sqrt

def diffs(org: torch.tensor, comp: torch.tensor, max_pixel: float = 1.0) -> tuple[float, float]:
    """Return ``(RMSE, PSNR in dB)`` between an original and a reconstruction.

    Args:
        org: Reference tensor.
        comp: Tensor to compare against ``org``.
        max_pixel: Largest possible pixel value, used as the PSNR peak.

    Returns:
        The root mean squared error and the peak signal-to-noise ratio. A mean
        squared error below 1e-8 is treated as a perfect match and reported
        as ``(0.0, inf)``, which also avoids dividing by zero.
    """
    squared_error = (org - comp).pow(2)
    mean_sq = squared_error.mean().item()

    # Near-identical inputs: zero error, unbounded PSNR.
    if mean_sq < 1e-8:
        return 0.0, float('inf')

    rmse = sqrt(mean_sq)
    psnr = 10 * log10((max_pixel ** 2) / mean_sq)
    return rmse, psnr

In [None]:
import matplotlib.pyplot as plt

def createPlots(patch: torch.tensor, crop: torch.tensor, y: int, x: int,
                gfactorize = fac_grad_loop, wfactorize = fac_wals_loop, epochs: int = 1000,
                r_vals = (5, 10, 25, 50)):
    """Compress one patch at several ranks and plot both factorizers side by side.

    For every rank the patch is factorized by ``gfactorize`` and
    ``wfactorize``, the reconstruction is pasted back into a copy of ``crop``
    at the patch position, and RMSE / PSNR against the original patch are
    collected. Two figures are shown: a grid of images (one row per rank,
    columns original / gradient descent / WALS) and a pair of line plots
    comparing PSNR and RMSE across ranks.

    Args:
        patch: Patch to compress, laid out as (rows, cols, channels).
        crop: Surrounding image region, laid out as (rows, cols, channels).
            It is not modified; reconstructions are written into clones.
        y: Row of the patch's top-left corner inside ``crop``.
        x: Column of the patch's top-left corner inside ``crop``.
        gfactorize: Callable ``(patch, r, epochs) -> (U, V, losses)`` used for
            the "Gradient Descent" column.
        wfactorize: Callable with the same signature used for the "WALS"
            column.
        epochs: Number of epochs handed to both factorizers.
        r_vals: Ranks to evaluate, one image row each.

    Raises:
        ValueError: If ``r_vals`` is empty.
    """
    r_vals = list(r_vals)
    if not r_vals:
        raise ValueError("r_vals must contain at least one rank")

    # Take the patch window from the patch itself instead of assuming 50x50.
    height, width = patch.shape[0], patch.shape[1]
    rows = slice(y, y + height)
    cols = slice(x, x + width)

    def _paste_reconstruction(factorize, r):
        """Factorize the patch at rank r and return crop with the result pasted in."""
        U, V, _ = factorize(patch, r, epochs)
        # Per-channel U @ V for any number of channels:
        # (rows, r, c) x (r, cols, c) -> (rows, cols, c).
        approx = torch.einsum('nrc,rmc->nmc', U, V).detach()
        pasted = crop.clone()
        pasted[rows, cols, :] = approx.to(device = crop.device, dtype = crop.dtype)
        return pasted

    gd_psnr = []
    gd_rms = []
    wals_psnr = []
    wals_rms = []

    # squeeze=False keeps axs two-dimensional even for a single rank.
    fig, axs = plt.subplots(len(r_vals), 3, figsize = (14, 4.5 * len(r_vals)), squeeze = False)
    for i, r in enumerate(r_vals):
        crop_grad = _paste_reconstruction(gfactorize, r)
        crop_wals = _paste_reconstruction(wfactorize, r)

        rms, psnr = diffs(patch, crop_grad[rows, cols, :])
        gd_rms.append(rms)
        gd_psnr.append(psnr)

        rms, psnr = diffs(patch, crop_wals[rows, cols, :])
        wals_rms.append(rms)
        wals_psnr.append(psnr)

        axs[i, 0].imshow(crop.cpu().numpy())
        axs[i, 1].imshow(crop_grad.cpu().numpy())
        axs[i, 2].imshow(crop_wals.cpu().numpy())

    for ax, col in zip(axs[0], ['Original', 'Gradient Descent', 'WALS']):
        ax.set_title(col)

    for ax, r in zip(axs[:, 0], r_vals):
        ax.set_ylabel(f'r = {r}', size='large')

    plt.tight_layout()
    plt.show()

    fig1, axs1 = plt.subplots(1, 2, figsize=(10, 5))

    # PSNR comparison
    axs1[0].plot(r_vals, gd_psnr, label='Gradient Descent PSNR', marker='o')
    axs1[0].plot(r_vals, wals_psnr, label='WALS PSNR', marker='x')
    axs1[0].set_xlabel('r values')
    axs1[0].set_ylabel('PSNR')
    axs1[0].set_title('PSNR Comparison')
    axs1[0].legend()
    axs1[0].grid()

    # RMS comparison
    axs1[1].plot(r_vals, gd_rms, label='Gradient Descent RMS', marker='o')
    axs1[1].plot(r_vals, wals_rms, label='WALS RMS', marker='x')
    axs1[1].set_xlabel('r values')
    axs1[1].set_ylabel('RMSE')
    axs1[1].set_title('RMSE Comparison')
    # The curves are labelled, so show the legend here as on the PSNR plot.
    axs1[1].legend()
    axs1[1].grid()

    plt.show()

In [None]:
import torchvision
from sklearn import preprocessing

# Read the image and rescale its pixel values with one min-max scaler fitted
# on the flattened image, so every channel shares the same scale
# (MinMaxScaler's default output range is [0, 1]).
img = torchvision.io.read_image("dog.jpg")
scaler_img = preprocessing.MinMaxScaler().fit(img.reshape(-1, 1))
img_scaled = scaler_img.transform(img.reshape(-1, 1)).reshape(img.shape)
img_scaled = torch.tensor(img_scaled)


# crop(img, top, left, height, width): cut a 300x300 working region out of
# the image, then three 50x50 patches whose (top, left) offsets are relative
# to that region. The same offsets are passed to createPlots as (y, x).
crop = torchvision.transforms.functional.crop(img_scaled, 600, 800, 300, 300)
patch1 = torchvision.transforms.functional.crop(crop, 0, 0, 50, 50)
patch2 = torchvision.transforms.functional.crop(crop, 50, 50, 50, 50)
patch3 = torchvision.transforms.functional.crop(crop, 175, 75, 50, 50)

In [None]:
# Move the first axis to the end so the tensors are indexed as
# [row, col, channel], which is how the factorization loops and createPlots
# slice them (presumably (C, H, W) -> (H, W, C); confirm against read_image).
# NOTE: this cell is not idempotent - running it a second time permutes the
# axes again, so re-run the loading cell before re-running this one.
crop = crop.permute(1, 2, 0)
patch1 = patch1.permute(1, 2, 0)
patch2 = patch2.permute(1, 2, 0)
patch3 = patch3.permute(1, 2, 0)

In [None]:
# Patch 1 (crop rows 0-49, cols 0-49), 200 epochs per factorization.
createPlots(patch1, crop, 0, 0, epochs=200)

In [None]:
# Patch 2 (crop rows 50-99, cols 50-99), 200 epochs per factorization.
createPlots(patch2, crop, 50, 50, epochs=200)

In [None]:
# Patch 3 (crop rows 175-224, cols 75-124), 200 epochs per factorization.
createPlots(patch3, crop, 175, 75, epochs=200)

In [None]:
# Patch 1 again with 500 epochs, to compare against the 200-epoch run.
createPlots(patch1, crop, 0, 0, epochs=500)

In [None]:
# Patch 2 again with 500 epochs, to compare against the 200-epoch run.
createPlots(patch2, crop, 50, 50, epochs=500)

In [None]:
# Patch 3 again with 500 epochs, to compare against the 200-epoch run.
createPlots(patch3, crop, 175, 75, epochs=500)