In [None]:
# Imports
import numpy as np, matplotlib.pyplot as plt
import torch, torch_pca, gc
from torchvision import datasets, transforms
from skimage.metrics import structural_similarity, peak_signal_noise_ratio
from PIL import Image
from tqdm import tqdm

In [None]:
# Params
# Side length in pixels: the transforms in this notebook resize every image to
# SIDE_LENGTH x SIDE_LENGTH, and the PCA helpers reshape flat vectors back to that size.
SIDE_LENGTH = 128

In [None]:
# Data imports
# Training pipeline: fixed-size resize, random horizontal/vertical flips, tensor conversion.
train_transforms = transforms.Compose(
    [
        transforms.Resize((SIDE_LENGTH, SIDE_LENGTH)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        # transforms.Grayscale(),
        transforms.ToTensor(),
    ]
)

# Test pipeline: same resize and tensor conversion, without the random flips.
test_transforms = transforms.Compose(
    [
        transforms.Resize((SIDE_LENGTH, SIDE_LENGTH)),
        # transforms.Grayscale(),
        transforms.ToTensor(),
    ]
)

# Imagenette (320px variant) read from ./data: 'train' split for fitting, 'val' split for testing.
train_dataset = datasets.Imagenette(
    root='data',
    split='train',
    size="320px",
    transform=train_transforms,
)
test_dataset = datasets.Imagenette(
    root='data',
    split='val',
    size="320px",
    transform=test_transforms,
)

In [None]:
def channel_pca(channel_index: int, n_components: int, dtype=torch.float16):
    """Fit a PCA on one channel of every image in the training set.

    Every item of the module-level ``train_dataset`` is read (i.e. after its
    transforms have been applied), the requested channel is copied into one
    row of a ``(len(train_dataset), SIDE_LENGTH * SIDE_LENGTH)`` matrix, and a
    ``torch_pca.PCA`` with the randomized SVD solver is fitted on that matrix.

    Args:
        channel_index: Index of the channel to extract from each image.
        n_components: Number of components passed to ``torch_pca.PCA``.
        dtype: dtype of the buffer that holds the stacked channel data.

    Returns:
        The fitted ``torch_pca.PCA`` object.
    """
    n_images = len(train_dataset)
    data = torch.empty(n_images, SIDE_LENGTH, SIDE_LENGTH, dtype=dtype)
    for i, sample in enumerate(train_dataset):
        # sample[0] is the transformed image; index the channel directly
        # instead of copying the whole image through NumPy and back.
        channel = torch.as_tensor(sample[0])[channel_index]
        data[i] = channel.reshape(SIDE_LENGTH, SIDE_LENGTH)
    data = data.reshape(n_images, SIDE_LENGTH * SIDE_LENGTH)
    gc.collect()

    pca = torch_pca.PCA(n_components=n_components, svd_solver='randomized')
    pca.fit(data)
    # Drop the local reference first so the collection below can reclaim the buffer.
    del data
    gc.collect()
    return pca

In [None]:
# May take a couple of minutes, depending on parameters
# One fitted PCA per image channel (indices 0, 1, 2), each with 100 components.
pcas = [
    channel_pca(channel_index=channel_index, n_components=100)
    for channel_index in (0, 1, 2)
]

In [None]:
# Save the PCA results as a .pt file
# torch.save({
#     'components0': pcas[0].components_,
#     'explained_variance0': pcas[0].explained_variance_,
#     'mean0': pcas[0].mean_,
#     'components1': pcas[1].components_,
#     'explained_variance1': pcas[1].explained_variance_,
#     'mean1': pcas[1].mean_,
#     'components2': pcas[2].components_,
#     'explained_variance2': pcas[2].explained_variance_,
#     'mean2': pcas[2].mean_,
# }, f'pcas{pcas[0].n_components_}.pt')

In [None]:
def load_test_img(idx: int):
    """Fetch test image ``idx`` as a float32 torch tensor of shape (s, s, 3).

    The image stored in ``test_dataset`` is channels-first; the result is
    permuted so that the channel axis comes last.
    """
    channels_first = np.array(test_dataset[idx][0])
    as_tensor = torch.tensor(channels_first, dtype=torch.float32)
    return as_tensor.permute(1, 2, 0)

def pca_project(pcas, idx: int, n=None):
    """Reconstruct test image ``idx`` from its first ``n`` PCA coefficients per channel.

    Args:
        pcas: One fitted PCA per image channel, in channel order.
        idx: Index of the image in the test set.
        n: How many leading components to keep; ``None`` keeps all of them.

    Returns:
        A tensor of shape (s, s, number of PCAs) with the per-channel reconstructions.
    """
    total = pcas[0].n_components_
    keep = total if n is None else n
    # 1 for the coefficients that are kept, 0 for those that are discarded.
    mask = torch.cat((torch.ones(keep), torch.zeros(total - keep)), dim=0)
    image = load_test_img(idx)

    reconstructed = []
    for ch, pca in enumerate(pcas):
        flat_channel = image[:, :, ch].reshape(1, -1)
        kept_coefficients = pca.transform(flat_channel) * mask
        plane = pca.inverse_transform(kept_coefficients)
        reconstructed.append(plane.reshape(SIDE_LENGTH, SIDE_LENGTH, 1))
    return torch.cat(reconstructed, dim=2)

def pca_multi_project(pcas, idx: int):
    """Projects the given test image on every prefix of the PCA components.

    Returns a (3, nc+1, s, s) tensor where [:, 0, :, :] holds the channel-wise
    means and [:, n, :, :] is the projection on the n first components
    (incl. means).

    Args:
        pcas: One fitted PCA per image channel, in channel order.
        idx: Index of the image in the test set.
    """
    nc = pcas[0].n_components_
    sl = SIDE_LENGTH
    img = load_test_img(idx)
    with torch.no_grad():
        out = torch.zeros((3, nc + 1, sl, sl))
        for ch, pca in enumerate(pcas):
            # One coefficient per component, as a column so it broadcasts over pixels.
            coeffs = pca.transform(img[:, :, ch].reshape(1, -1)).reshape(nc, 1)
            out[ch, 0, :, :] = pca.mean_.reshape(sl, sl)
            # Row k of components_ is component k, so scale each row by its own
            # coefficient. Scaling the columns of the transposed matrix and then
            # reshaping to (nc, sl, sl) would interleave pixels of different
            # components, so the layout must stay (nc, sl*sl) before the reshape.
            out[ch, 1:, :, :] = (coeffs * pca.components_).reshape(nc, sl, sl)
        # Running sum over the component axis: mean, mean + 1 component, ...
        out = torch.cumsum(out, dim=1)
    return out

In [None]:
def show_img(t):
    """Turn a tensor into a PIL image: scale by 255, clamp to [0, 255], cast to bytes."""
    scaled = 255 * t
    pixels = torch.clamp(scaled, 0, 255).byte()
    return Image.fromarray(pixels.numpy())

def show_test_img(idx):
    """Return test image ``idx`` as a PIL image (scaled by 255, channel axis moved last)."""
    channels_first = (255 * test_dataset[idx][0]).byte()
    channels_last = np.array(channels_first).transpose((1, 2, 0))
    return Image.fromarray(channels_last)

def normalize_gray(x: torch.Tensor):
    """Normalizes 0 to 0.5, and scales everything evenly so that all elements are in [0, 1].

    (Almost surely) either the smallest element will be 0 or the largest will be 1.
    An empty or all-zero input has nothing to scale, so every element maps to 0.5
    instead of producing a division by zero.
    """
    if x.numel() == 0:
        return 0.5 + x
    peak = torch.max(torch.abs(x))
    if peak == 0:
        # 0 / 0 would give NaN; zero maps to mid-gray by definition.
        return 0.5 + torch.zeros_like(x)
    return 0.5 + x / (2 * peak)

def normalize_01(x: torch.Tensor):
    """Normalizes x so that the smallest element is 0 and the largest is 1.

    If x is empty or all of its elements are equal there is no range to
    stretch, so zeros of the same shape are returned instead of NaN.
    """
    if x.numel() == 0:
        return torch.zeros_like(x) + 0.0
    low = torch.min(x)
    span = torch.max(x) - low
    if span == 0:
        # Constant input: (x - min) / 0 would be NaN everywhere.
        return torch.zeros_like(x) + 0.0
    return (x - low) / span

def get_component(pcas, idx):
    """Return one PCA component of every channel stacked into an (s, s, 3) tensor.

    ``idx == -1`` returns the per-channel means instead. For any other index,
    channels 1 and 2 are multiplied by the sign of their dot product with
    channel 0, so that all three channels point the same way.
    """
    n_components = pcas[0].n_components_
    side = SIDE_LENGTH
    if idx == -1:
        mean_planes = [pca.mean_.reshape(side, side, 1) for pca in pcas]
        return torch.cat(mean_planes, dim=2)

    # Selecting coefficient `idx` alone and removing the mean leaves that component.
    one_hot = torch.eye(n_components)[idx]
    planes = []
    for pca in pcas:
        direction = pca.inverse_transform(one_hot) - pca.mean_
        planes.append(direction.reshape(side, side, 1))
    q = torch.cat(planes, dim=2)

    # Attempt to align the components in sign, using channel 0 as the reference.
    reference = q[:, :, 0].flatten()
    for ch in (1, 2):
        flip = torch.sign(torch.dot(reference, q[:, :, ch].flatten()))
        assert flip != 0, f"The {idx}th components for channel 0 and {ch} are orthogonal. This is extremely unlikely."
        q[:, :, ch] *= flip
    return q

In [None]:
num_rows = 4
num_cols = 4  # Adjust based on how many images you want to plot

# Grid of panels: the first shows the mean image, the others the leading components.
fig, axes = plt.subplots(num_rows, num_cols, figsize=(8, 8))

# axes.flat walks the grid row by row.
for i, ax in enumerate(axes.flat):
    if i == 0:
        ax.imshow(show_img(get_component(pcas, -1)))
        ax.set_title('µ', fontsize=10)
    else:
        # Panel i displays the component with zero-based index i - 1.
        ax.imshow(show_img(normalize_gray(get_component(pcas, i - 1))))
        ax.set_title(f'C{i}', fontsize=10)
    ax.axis('off')

# Adjust layout to prevent overlap
plt.tight_layout()
plt.show()

In [None]:
idx = 100
num_rows = 3
num_cols = 3  # Adjust based on how many images you want to plot
nc = pcas[0].n_components_

# Grid of panels: the original image first, then reconstructions with more and more components.
fig, axes = plt.subplots(num_rows, num_cols, figsize=(6, 6))

# Component counts spread evenly from 0 to nc, one per reconstruction panel.
component_counts = [round(value) for value in np.linspace(0, nc, num_rows * num_cols - 1)]

for i, ax in enumerate(axes.flat):
    if i == 0:
        ax.imshow(show_test_img(idx))
        ax.set_title('Original', fontsize=10)
    else:
        n = component_counts[i - 1]
        ax.imshow(show_img(pca_project(pcas, idx, n=n)))
        ax.set_title(f'{n}C', fontsize=10)
    ax.axis('off')

# Adjust layout to prevent overlap
plt.tight_layout()
plt.show()

In [None]:
# Reconstruction of test image `idx` from its first 100 components.
reconstruction = pca_project(pcas, idx, n=100)
plt.imshow(show_img(reconstruction))

In [None]:
# Veeery slow...
l = len(test_dataset)
nc = pcas[0].n_components_

# Row i: test image i. Column n: projection on the n+1 first components.
mse = np.zeros((l, nc))
ssim = np.zeros((l, nc))
psnr = np.zeros((l, nc))

# This may benefit from parallelization
# For N=100 with step 0, should take just under 1hr.
for i in tqdm(range(l)):
    original = np.array(load_test_img(i)) # (s, s, 3)
    # Drop index 0 (means only) so that entry n is the projection on n+1 components.
    cprojected = np.array(pca_multi_project(pcas, i)).transpose((1, 2, 3, 0))[1:, :, :, :] # (nc, s, s, 3)
    mse[i, :] = np.mean(np.square(cprojected - original[None, :, :, :]), axis=(1, 2, 3))
    # peak_signal_noise_ratio reduces the whole stacked array to ONE scalar, which
    # would then be broadcast into every column. Derive PSNR per component count
    # from the per-count MSE instead: 10 * log10(data_range**2 / mse), data_range = 1.
    with np.errstate(divide='ignore'):
        psnr[i, :] = 10 * np.log10(1.0 / mse[i, :])
    for n in range(nc):
        ssim[i, n] = structural_similarity(original, cprojected[n, :, :, :], data_range=1, channel_axis=2) # Faster when unvectorized for some reason.

In [None]:
# Store the three metric arrays as tensors in a single file.
metrics_to_save = {
    name: torch.tensor(values)
    for name, values in (('mse', mse), ('ssim', ssim), ('psnr', psnr))
}
torch.save(metrics_to_save, 'pca_test_metrics.pt')

In [None]:
# Read the saved metrics back and convert each tensor to a NumPy array.
data = torch.load('pca_test_metrics.pt', weights_only=True)
mse, ssim, psnr = (np.array(data[key]) for key in ('mse', 'ssim', 'psnr'))

In [None]:
# X axis is number of components.
mean_mse = np.mean(mse, axis=0)
# Column k of `mse` holds the error for k+1 components, so the x axis starts at 1.
plt.plot(np.arange(1, len(mean_mse) + 1), mean_mse)
plt.xlabel("Component count")
plt.ylabel("Pixel MSE")

In [None]:
mean_psnr = np.mean(psnr, axis=0)
# Column k of `psnr` holds the value for k+1 components, so the x axis starts at 1.
plt.plot(np.arange(1, len(mean_psnr) + 1), mean_psnr)
plt.xlabel("Component count")
plt.ylabel("PSNR")

In [None]:
mean_ssim = np.mean(ssim, axis=0)
# Column k of `ssim` holds the value for k+1 components, so the x axis starts at 1.
plt.plot(np.arange(1, len(mean_ssim) + 1), mean_ssim)
plt.xlabel("Component count")
plt.ylabel("SSIM")