In [138]:
import numpy as np
from sklearn.datasets import fetch_olivetti_faces
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import torch_gs, grassmann_distance
from data import generate_data, generate_batched_data
from scipy import linalg
import os

In [139]:
# Set PyTorch's MPS-fallback flag for this process.
# NOTE(review): `torch` has already been imported by the import cell above;
# this flag may only be honoured when it is set before that import — confirm,
# and if so export it in the shell or move this line above `import torch`.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"

In [174]:
# Load the Olivetti faces through scikit-learn (cached under ./olivetti/) and
# wrap the 'data' array as a torch tensor.
# Presumably one flattened 64x64 face per row — later cells reshape to (-1, 64, 64).
data = torch.from_numpy(
    fetch_olivetti_faces(data_home="./olivetti/")["data"]
)

In [202]:
# Experiment configuration.
n = data.shape[-1]  # ambient dimension: length of one row of `data`
d = 64              # number of columns in the learned bases V / W
# Prefer the Apple MPS backend, but fall back to CPU so the notebook still runs
# on machines where MPS is not available (a hard-coded "mps" would raise there).
device = "mps" if torch.backends.mps.is_available() else "cpu"
num_epochs = 500    # full passes over `data` in each training loop

In [203]:
# Move the image tensor onto the compute device selected in the configuration
# cell. `data` is rebound to the moved tensor.
data = data.to(device)

In [177]:
# Random (n, d) starting points for the generator basis V and the
# discriminator basis W. Both are drawn on CPU (V first, then W) and then
# moved to `device`.
init_shape = (n, d)
V = torch.randn(init_shape).to(torch.float32).to(device)
W = torch.randn(init_shape).to(torch.float32).to(device)

# Pass each matrix through torch_gs (from utils — presumably Gram-Schmidt on
# the columns; confirm) and rescale the result by sqrt(n).
V, W = (
    torch_gs(start, row_vecs=False, norm=True) * np.sqrt(n)
    for start in (V, W)
)

In [162]:
# Coefficient scaling used by generate_data: 10 times the d x d identity.
# It is deliberately left on the CPU (the move to `device` is disabled).
# An alternative that was tried: a linearly increasing diagonal,
# torch.diag(torch.arange(1, d+1)) / 10.0.
sigma = torch.eye(d) * 10

# eta: variance of the additive noise in generate_data.
# disc_lr / gen_lr: step sizes for the W (discriminator) and V (generator) updates.
eta, disc_lr, gen_lr = 1, 0.5, 0.01

In [147]:
# Reference subspace: the top-d principal components of the images, fitted
# with scikit-learn on a CPU/numpy copy of `data`. `components_` is stored
# row-wise (one component per row), which is why later cells pass
# `gt_subspace.T`.
pca = PCA(n_components=d).fit(data.cpu().flatten(start_dim=1).numpy())
gt_subspace = torch.from_numpy(pca.components_)

In [148]:
def generate_data(U, eta, d, n, sigma, scale=True):
    """Draw one synthetic sample y = U @ c / scale + a.

    NOTE: this definition shadows the `generate_data` imported from `data`
    in the import cell.

    Args:
        U: torch tensor whose columns span the generator subspace; it must be
            matmul-compatible with the (d, 1) coefficient vector. The outputs
            are placed on ``U.device``.
        eta: variance of the additive noise (``eta=0`` gives a noise-free sample).
        d: number of latent coefficients to draw.
        n: length of the additive-noise vector (and of ``y``).
        sigma: torch tensor applied to the standard-normal latent draw.
        scale: when true, ``U @ c`` is divided by ``sqrt(n)``.

    Returns:
        Tuple ``(y, c, a)``: the sample, the latent coefficients ``sigma @ z``
        and the additive noise, all float32 tensors on ``U.device``.
    """
    # Use the basis' own device instead of a notebook-level global so the
    # function does not depend on cell execution order.
    target = U.device

    # Randomness still comes from numpy's global RNG, in the same order as
    # before (latent draw first, then additive noise). The draw is converted
    # to a tensor *before* the matmul rather than relying on implicit
    # tensor/ndarray interoperability.
    latent = torch.from_numpy(np.random.randn(d, 1)).to(torch.float32).to(sigma.device)
    c = (sigma.to(torch.float32) @ latent).to(target)

    a = torch.from_numpy(np.random.randn(n, 1) * np.sqrt(eta)).to(torch.float32).to(target)

    scale_factor = np.sqrt(n) if scale else 1
    y = U @ c / scale_factor + a
    return y, c, a


In [None]:
# Multi-feature GAN-style training: generator basis V against a d-column
# discriminator basis W, one image per step.
Qs_gan = []  # Grassmann distance between V and the PCA subspace
ts_gan = []  # step counts at which the distances were recorded

total_count = 0

for j in tqdm(range(num_epochs)):
    pbar = tqdm(range(data.shape[0]))
    for i in pbar:
        # Log the distance to the reference subspace every 100 steps.
        if total_count % 100 == 0:
            g_distance = grassmann_distance(V.detach().cpu().numpy(), gt_subspace.T)
            Qs_gan.append(g_distance)
            ts_gan.append(total_count)
            pbar.set_description(f"Grassmann distance: {g_distance}")

        # One fake sample from the current generator.
        y_gen, c_gen, a_gen = generate_data(V, eta, d, n, sigma, scale=True)

        # One real sample as an (n, 1) column: centred, then scaled to norm sqrt(n).
        img = data[i].unsqueeze(0).T
        img = img - torch.mean(img)
        img = img / torch.norm(img) * np.sqrt(n)

        # Discriminator step: real-sample term minus generated-sample term.
        d_grad_true = img @ (img.T @ W)
        d_grad_gen = -1 * y_gen @ (y_gen.T @ W)
        g_w = disc_lr * (d_grad_true + d_grad_gen) / n

        # Generator step. Parenthesised so W.T @ y_gen (d x 1) is formed first;
        # the unparenthesised product builds the n x n matrix W @ W.T on every
        # step. Same quantity up to floating-point rounding.
        gradient = W @ (W.T @ y_gen) @ c_gen.T
        n_g = gen_lr * gradient / (n * np.sqrt(n))

        V = V + n_g
        W = W + g_w
        W = torch_gs(W, row_vecs=False, norm=True) * np.sqrt(n)
        # Rescale every column of V to norm sqrt(n); broadcasting replaces the
        # explicit ones((n, 1)) tensor that used to be allocated each step.
        V = V / torch.sqrt(torch.sum(V ** 2, dim=0)) * np.sqrt(n)
        total_count += 1

        # Snapshot after the first step of every 25th epoch.
        if j % 25 == 0 and i == 0:
            # The reduced SVD yields the same leading left singular vectors
            # without computing an n x n U.
            U, S, Vh = np.linalg.svd(V.detach().cpu().numpy(), full_matrices=False)
            top_16 = U.T.reshape(-1, 64, 64)
            fig = plt.figure()
            # Separate index name so the sample index `i` is not overwritten.
            for k in range(16):
                plt.subplot(4, 4, k + 1)
                plt.imshow(top_16[k, :, :], cmap="gray")
                plt.axis("off")

            plt.savefig(f"eigenfaces_gan_epoch_{j+1}.png")
            # Close the snapshot so figures do not pile up across the run.
            plt.close(fig)

In [None]:
# Final eigenfaces of the multi-feature GAN run: the 16 leading left singular
# vectors of V, each shown as a 64x64 image.
# torch.linalg.svd replaces the deprecated torch.svd. It runs on a CPU copy so
# this cell does not depend on the MPS operator fallback.
U, S, Vh = torch.linalg.svd(V.detach().cpu(), full_matrices=False)
top_16 = U.T.reshape(-1, 64, 64).numpy()
plt.figure()
for i in range(16):
    plt.subplot(4, 4, i+1)
    plt.imshow(top_16[i, :, :], cmap="gray")
    plt.axis("off")

# Save before showing: once plt.show() has displayed the figure, a following
# savefig would write a new, empty one.
plt.savefig("eigenfaces.png")
plt.show()

In [None]:
# Sample 16 images from the trained generator. eta=0 switches the additive
# noise off, so each sample is V @ c / sqrt(n). Only y (first returned value)
# is kept, as a flat length-n vector.
gen_images = torch.stack(
    [generate_data(V, 0, d, n, sigma, scale=True)[0][:, 0] for _ in range(16)]
)
# (16, n) -> (16, 64, 64) numpy array for plotting.
gen_images = gen_images.reshape(-1, 64, 64).detach().cpu().numpy()

In [None]:
# Show the 16 generated samples on a 4x4 grid, without axes.
plt.figure()
for slot, picture in enumerate(gen_images[:16], start=1):
    plt.subplot(4, 4, slot)
    plt.imshow(picture, cmap="gray")
    plt.axis("off")

plt.show()

In [None]:
# Re-fit the reference PCA subspace (same computation as the earlier PCA cell)
# so the baseline cells below can run without depending on it.
images_np = data.cpu().flatten(start_dim=1).numpy()
pca = PCA(n_components=d)
pca.fit(images_np)
# One principal component per row; downstream code transposes it.
gt_subspace = torch.from_numpy(pca.components_)

In [None]:
# Baseline run (labelled "Oja's" in the comparison plot): streaming update of
# an (n, d) basis, one image per step.
V_oja = torch.randn((n, d), device=device).to(torch.float32)
# NOTE(review): torch_gs is called with its defaults here, while the GAN cells
# pass row_vecs=False, norm=True — confirm the defaults are equivalent.
V_oja = torch_gs(V_oja)

lr = 0.01

Qs = []  # Grassmann distance between V_oja and the PCA subspace
ts = []  # step counts at which the distances were recorded

total_count = 0

for j in tqdm(range(num_epochs)):
    for i in range(data.shape[0]):
        # Log the distance to the reference subspace every 1000 steps.
        if total_count % 1000 == 0:
            Qs.append(grassmann_distance(V_oja.cpu().numpy(), gt_subspace.T))
            ts.append(total_count)

        # One real sample as an (n, 1) column: centred, then scaled to norm sqrt(n).
        img = data[i].unsqueeze(0).T
        img = img - torch.mean(img)
        img = img / torch.norm(img) * np.sqrt(n)

        # Least-squares coefficients of the sample in the current basis,
        # followed by a rank-one update and re-orthogonalisation via torch_gs.
        w = torch.linalg.lstsq(V_oja, img).solution[:, 0]
        V_oja = torch_gs(V_oja + lr / n * torch.outer(img[:, 0], w))
        total_count += 1

    # Snapshot at the end of every 25th epoch. This check runs after the inner
    # loop, where `i` is always the last sample index, so the former extra
    # `i == 0` condition could never hold and no snapshot was ever written.
    if j % 25 == 0:
        # np.linalg.svd needs a CPU numpy array, not a device tensor. The
        # reduced SVD avoids computing an n x n U.
        U, S, Vh = np.linalg.svd(V_oja.detach().cpu().numpy(), full_matrices=False)
        top_16 = U.T.reshape(-1, 64, 64)
        fig = plt.figure()
        for k in range(16):
            plt.subplot(4, 4, k + 1)
            plt.imshow(top_16[k, :, :], cmap="gray")
            plt.axis("off")

        plt.savefig(f"eigenfaces_oja_epoch_{j+1}.png")
        # Close the snapshot so figures do not pile up across the run.
        plt.close(fig)

In [None]:
# Final eigenfaces of the baseline run: the 16 leading left singular vectors
# of V_oja, each shown as a 64x64 image.
# V_oja is a torch tensor created on `device`, so it is copied to CPU and
# converted to numpy before the SVD. The reduced SVD avoids an n x n U.
U, S, Vh = np.linalg.svd(V_oja.detach().cpu().numpy(), full_matrices=False)
top_16 = U.T.reshape(-1, 64, 64)
plt.figure()
for i in range(16):
    plt.subplot(4, 4, i+1)
    plt.imshow(top_16[i, :, :], cmap="gray")
    plt.axis("off")

# Save before showing: once plt.show() has displayed the figure, a following
# savefig would write a new, empty one.
plt.savefig("eigenfaces_oja.png")
plt.show()

In [None]:
# Single-feature GAN-style training: same procedure as the multi-feature run,
# but the discriminator W_single has one column.
V_single = torch.randn((n, d)).to(torch.float32).to(device)
W_single = torch.randn((n, 1)).to(torch.float32).to(device)

V_single = torch_gs(V_single, row_vecs=False, norm=True) * np.sqrt(n)
W_single = torch_gs(W_single, row_vecs=False, norm=True) * np.sqrt(n)

Qs_gan_single = []  # Grassmann distance between V_single and the PCA subspace
ts_gan_single = []  # step counts at which the distances were recorded

total_count = 0

for j in tqdm(range(num_epochs)):
    pbar = tqdm(range(data.shape[0]))
    for i in pbar:
        # Log the distance to the reference subspace every 100 steps.
        if total_count % 100 == 0:
            g_distance = grassmann_distance(V_single.detach().cpu().numpy(), gt_subspace.T)
            Qs_gan_single.append(g_distance)
            ts_gan_single.append(total_count)
            pbar.set_description(f"Grassmann distance: {g_distance}")

        # One fake sample from the current generator.
        y_gen, c_gen, a_gen = generate_data(V_single, eta, d, n, sigma, scale=True)

        # One real sample as an (n, 1) column: centred, then scaled to norm sqrt(n).
        img = data[i].unsqueeze(0).T
        img = img - torch.mean(img)
        img = img / torch.norm(img) * np.sqrt(n)

        # Discriminator step: real-sample term minus generated-sample term.
        d_grad_true = img @ (img.T @ W_single)
        d_grad_gen = -1 * y_gen @ (y_gen.T @ W_single)
        g_w = disc_lr * (d_grad_true + d_grad_gen) / n

        # Generator step. Parenthesised so W_single.T @ y_gen is formed first;
        # the unparenthesised product builds the n x n matrix
        # W_single @ W_single.T on every step. Same quantity up to
        # floating-point rounding.
        gradient = W_single @ (W_single.T @ y_gen) @ c_gen.T
        n_g = gen_lr * gradient / (n * np.sqrt(n))

        V_single = V_single + n_g
        W_single = W_single + g_w
        W_single = torch_gs(W_single, row_vecs=False, norm=True) * np.sqrt(n)
        # Rescale every column of V_single to norm sqrt(n); broadcasting
        # replaces the explicit ones((n, 1)) tensor allocated each step.
        V_single = V_single / torch.sqrt(torch.sum(V_single ** 2, dim=0)) * np.sqrt(n)
        total_count += 1

        # Snapshot after the first step of every 25th epoch.
        if j % 25 == 0 and i == 0:
            # The reduced SVD yields the same leading left singular vectors
            # without computing an n x n U.
            U, S, Vh = np.linalg.svd(V_single.detach().cpu().numpy(), full_matrices=False)
            top_16 = U.T.reshape(-1, 64, 64)
            fig = plt.figure()
            # Separate index name so the sample index `i` is not overwritten.
            for k in range(16):
                plt.subplot(4, 4, k + 1)
                plt.imshow(top_16[k, :, :], cmap="gray")
                plt.axis("off")

            plt.savefig(f"eigenfaces_gan_single_epoch_{j+1}.png")
            # Close the snapshot so figures do not pile up across the run.
            plt.close(fig)

In [199]:
# Persist the baseline run's curve. `distances` is only defined by a later
# cell, so on a fresh top-to-bottom run it does not exist yet; the distances
# are taken directly from `Qs`, stacked as in the plotting cell's
# `distances = np.vstack(Qs)` recipe.
np.save("oja_ts_2.npy", ts)
np.save("oja_distances_2.npy", np.vstack(Qs))

In [None]:
# Comparison of the three runs: Grassmann distance to the PCA subspace
# against the number of processed samples.

# The baseline curve is read back from disk rather than taken from the
# in-memory `ts` / `Qs` lists (the in-memory route would be np.vstack(ts) and
# np.vstack(Qs), saved with np.save).
# NOTE(review): the save cell writes "oja_ts_2.npy" / "oja_distances_2.npy",
# but the files loaded here lack the "_2" suffix — confirm which files are intended.
ts = np.load("oja_ts.npy")
distances = np.load("oja_distances.npy")

# Turn the logged Python lists into column arrays.
ts_gan, distances_gan = np.vstack(ts_gan), np.vstack(Qs_gan)
ts_gan_single, Qs_gan_single = np.vstack(ts_gan_single), np.vstack(Qs_gan_single)

# Tick density and label size, applied to the current axes before plotting.
for axis_name, bin_count in (('y', 6), ('x', 5)):
    plt.locator_params(axis=axis_name, nbins=bin_count)
for tick_kind in ('major', 'minor'):
    plt.tick_params(axis='both', which=tick_kind, labelsize=14)

curves = (
    (ts, distances, "Oja's"),
    (ts_gan, distances_gan, "GAN Multi-feature"),
    (ts_gan_single, Qs_gan_single, "GAN Single-feature"),
)
for steps, curve, curve_label in curves:
    plt.plot(steps, curve, label=curve_label, linewidth=4.0)

plt.legend(prop={'weight':'bold', 'size': 13})
plt.show()