In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

In [None]:
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

In [None]:
import torch

In [None]:
from model import Discriminator, Generator

In [None]:
import tensorflow as tf

## Load data

Importing torchvision is incompatible with scikit-learn in this environment, so the MNIST data is loaded through TensorFlow (`tf.keras.datasets`) instead.

In [None]:
# Number of image channels; passed as `channels_img` to both Discriminator and Generator.
CHANNELS_IMG = 1

In [None]:
# MNIST from the Keras dataset loader: load_data() yields (images, labels) pairs
# for the train and test splits.
dataset = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = dataset.load_data()

In [None]:
# Number of distinct labels present in the training split.
num_classes = np.unique(y_train).size
# Side length of an image, read from the last axis of the training array.
IMAGE_SIZE = x_train.shape[-1]

In [None]:
# Largest pixel value in the training set; used to normalise both splits.
scale = x_train.max()
size_train, size_test = len(x_train), len(x_test)
# Map pixels to [0, 1] by the training maximum, then shift/scale to [-1, 1].
x_train_scale, x_test_scale = (
    ((images / scale) - 0.5) / 0.5 for images in (x_train, x_test)
)

In [None]:
# Mean training image per label on a 2x5 grid; panels are filled row by row,
# so the panel at (row, col) shows label row * 5 + col.
fig, axs = plt.subplots(2, 5, sharex=True, sharey=True)
for label, ax in enumerate(axs.flat):
    members = x_train_scale[y_train == label]
    ax.imshow(np.mean(members, axis=0))
plt.tight_layout()

In [None]:
# Same per-label mean image as for the training split, computed on the test split.
fig, axs = plt.subplots(2, 5, sharex=True, sharey=True)
for (row, col), ax in np.ndenumerate(axs):
    label = row * 5 + col
    selected = y_test == label
    ax.imshow(np.mean(x_test_scale[selected], axis=0))
plt.tight_layout()

## Load models (discriminator and generator)

PyTorch models loaded onto the CPU are incompatible with scikit-learn in this environment, so the models are loaded onto the GPU instead.

In [None]:
# Feature-width hyperparameter passed as `features` to Discriminator and Generator.
FEATURES = 32
# Prefer Apple's MPS backend (the original setting); fall back to CUDA and then
# CPU so the notebook still runs on machines where MPS is not available.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Build the discriminator from the FEATURES / CHANNELS_IMG settings and restore
# its saved weights.
disc = Discriminator(features=FEATURES, channels_img=CHANNELS_IMG)
# map_location remaps the stored tensors onto `device`, so loading does not
# depend on which device the checkpoint was saved from.
disc.load_state_dict(torch.load('dcgan_disc_2024-04-19_1341.pt', map_location=device))
disc.to(device)
# Evaluation mode; the trailing ';' suppresses the module repr in the cell output.
disc.eval();

In [None]:
# Length of the latent noise vector; passed as `channels_noise` to the Generator.
NOISE_DIM = 100

In [None]:
# Build the generator from the NOISE_DIM / FEATURES / CHANNELS_IMG settings and
# restore its saved weights.
gen = Generator(channels_noise=NOISE_DIM, features=FEATURES, channels_img=CHANNELS_IMG)
# map_location remaps the stored tensors onto `device`, so loading does not
# depend on which device the checkpoint was saved from.
gen.load_state_dict(torch.load('dcgan_gen_2024-04-19_1341.pt', map_location=device))
gen.to(device)
# Evaluation mode; the trailing ';' suppresses the module repr in the cell output.
gen.eval();

## Plot metrics

In [None]:
# Training metrics log; the first CSV column is used as the row index.
loss = pd.read_csv(
    'dcgan_loss_2024-04-19_1341.csv',
    index_col=0,
)

In [None]:
# Left panel: the three loss curves. Right panel: the mean discriminator outputs.
fig, (ax_loss, ax_mean) = plt.subplots(1, 2, figsize=[10, 5])

loss_curves = (
    ('Loss Disc Real', 'loss disc real'),
    ('Loss Disc Fake', 'loss disc fake'),
    ('Loss Gen', 'loss gen'),
)
for column, caption in loss_curves:
    ax_loss.plot(loss[column], label=caption)
ax_loss.legend()

mean_curves = (
    ('Mean Disc Real', 'mean disc real'),
    ('Mean Disc Fake', 'mean disc fake'),
    ('Mean Disc Fake (Gen Training)', 'mean disc fake (gen training)'),
)
for column, caption in mean_curves:
    ax_mean.plot(loss[column], label=caption)
ax_mean.legend()

## Generate samples

In [None]:
# Seed both RNGs before the first random draw so that a Restart & Run All
# reproduces the same samples. NumPy's global RNG is seeded as well because the
# later np.random.randint / np.random.normal calls draw from it.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# One standard-normal latent vector per test image, shaped (N, NOISE_DIM, 1, 1).
latent = torch.randn(size_test, NOISE_DIM, 1, 1).to(device)

In [None]:
# Inference only: no_grad stops autograd from recording a graph for the whole
# size_test-sample batch, which would otherwise be held in memory for nothing.
with torch.no_grad():
    x_gen_torch = gen(latent)
# NumPy copy with one (IMAGE_SIZE, IMAGE_SIZE) image per sample, for plotting and PCA.
x_gen = x_gen_torch.cpu().numpy().reshape(size_test, IMAGE_SIZE, IMAGE_SIZE)

In [None]:
# Show the first n*n generated images on an n-by-n grid, filled row by row.
n = 4
fig, axs = plt.subplots(n, n, figsize=[5,5], sharex=True, sharey=True)
for ax, image in zip(axs.flat, x_gen):
    ax.imshow(image)
plt.tight_layout()

In [None]:
# Discriminator output for every generated image, flattened to 1-D on the CPU.
# no_grad avoids recording an autograd graph for this inference-only pass.
with torch.no_grad():
    disc_sample = disc(x_gen_torch).view(-1).cpu()

In [None]:
# Histogram of discriminator outputs on generated images, with log-scaled counts.
plt.hist(x=disc_sample, bins=100)
plt.gca().set_yscale('log')

## Train PCA-RF classifier

In [None]:
# A float n_components keeps enough components to explain that fraction (0.9)
# of the variance; whiten=True rescales the components to unit variance.
pca = PCA(
    whiten=True,
    n_components=0.9,
)

In [None]:
# Flatten every training image to a vector, then fit the PCA and project in one step.
pca_train = pca.fit_transform(
    X=x_train_scale.reshape(size_train, IMAGE_SIZE * IMAGE_SIZE)
)

In [None]:
# Random forest on the PCA features, using all available cores. A fixed
# random_state makes the fitted forest, and the class labels later assigned to
# generated images, reproducible across runs.
rf = RandomForestClassifier(n_jobs=-1, random_state=0)

In [None]:
# Fit the forest on the PCA-projected training images and their digit labels.
rf.fit(X=pca_train, y=y_train)

In [None]:
# Project the flattened test images with the PCA fitted on the training set.
pca_test = pca.transform(
    X=x_test_scale.reshape(size_test, IMAGE_SIZE * IMAGE_SIZE)
)

In [None]:
# Predicted digit label for each test image.
y_pred = rf.predict(X=pca_test)

In [None]:
# Confusion matrix normalised over the true labels, rounded to two decimals.
cm = np.round(
    confusion_matrix(y_true=y_test, y_pred=y_pred, normalize='true'),
    2,
)

In [None]:
# Draw the normalised confusion matrix.
cm_display = ConfusionMatrixDisplay(confusion_matrix=cm).plot()

## Predict generated images with classifier

In [None]:
# Project the flattened generated images into the PCA space fitted on the training data.
pca_gen = pca.transform(
    X=x_gen.reshape(size_test, IMAGE_SIZE * IMAGE_SIZE)
)

In [None]:
# Digit label the classifier assigns to each generated image.
y_gen = rf.predict(X=pca_gen)

In [None]:
# Overlay the first two PCA components of the train, test and generated sets.
for points, name in ((pca_train, 'train'), (pca_test, 'test'), (pca_gen, 'gen')):
    plt.scatter(points[:, 0], points[:, 1], s=1, alpha=0.1, label=name)
plt.legend()

In [None]:
# Training set as a black background cloud; generated samples on top, one
# scatter series per predicted class.
plt.scatter(pca_train[:, 0], pca_train[:, 1], s=1, alpha=0.1, label='train', color='k')
for cls in range(num_classes):
    members = pca_gen[y_gen == cls]
    plt.scatter(members[:, 0], members[:, 1], s=1, alpha=1, label=cls)
plt.legend()

In [None]:
# Mean generated image per predicted class, on a 2x5 grid filled row by row.
fig, axs = plt.subplots(2, 5, sharex=True, sharey=True)
for cls, ax in enumerate(axs.ravel()):
    class_images = x_gen[y_gen == cls]
    ax.imshow(np.mean(class_images, axis=0))
plt.tight_layout()

## Interpolation

In [None]:
# Convert the latent batch to a NumPy array for the NumPy-based analysis below.
# The guard keeps the cell idempotent: on a second run `latent` is already an
# ndarray, and calling .cpu() on it would raise.
if torch.is_tensor(latent):
    latent = latent.cpu().numpy()

In [None]:
# Number of points on the interpolation path, endpoints included.
steps = 20
# Two random latent vectors serve as the endpoints of the path.
rand_ind0, rand_ind1 = np.random.randint(low=0, high=size_test, size=2)
# Linear interpolation between the two endpoints.
interp = np.linspace(start=latent[rand_ind0], stop=latent[rand_ind1], num=steps)

In [None]:
# Generate one image per interpolation step. no_grad skips autograd bookkeeping
# for this inference-only pass; as_tensor with an explicit float32 dtype replaces
# the legacy torch.Tensor constructor.
with torch.no_grad():
    interp_batch = torch.as_tensor(interp, dtype=torch.float32).to(device)
    gen_interp = gen(interp_batch).cpu().numpy().reshape(steps, IMAGE_SIZE, IMAGE_SIZE)

In [None]:
# Interpolated images in path order on a 4x5 grid, filled row by row.
fig, axs = plt.subplots(4, 5, sharex=True, sharey=True)
for ax, frame in zip(axs.flat, gen_interp):
    ax.imshow(frame)
plt.tight_layout()

## "Eigen" vectors in latent space

In [None]:
# Magnitude given to each one-hot latent direction.
factor = 10
# Row k is a latent vector that is zero everywhere except dimension k, which holds `factor`.
one_hot_latents = torch.eye(NOISE_DIM).view(NOISE_DIM, NOISE_DIM, 1, 1)
id_matrix = one_hot_latents.to(device) * factor

In [None]:
# One generated image per latent axis. no_grad skips autograd bookkeeping for
# this inference-only pass.
with torch.no_grad():
    gen_id = gen(id_matrix).cpu().numpy().reshape(NOISE_DIM, IMAGE_SIZE, IMAGE_SIZE)

In [None]:
# 10x10 grid filled row by row: the panel at (row, col) shows latent axis row * 10 + col.
fig, axs = plt.subplots(10, 10, figsize=[10,10], sharex=True, sharey=True)
for ax, axis_image in zip(axs.flat, gen_id):
    ax.imshow(axis_image)
plt.tight_layout()

## Manifold

In [None]:
# Sweep the first two latent dimensions over a regular grid (all other
# dimensions zero) and generate one image per grid point.
n_manifold = 11
x_min, x_max = -11, 11
y_min, y_max = -11, 11

# Rows run from y_max down to y_min and columns from x_min up to x_max, so that
# flat index row * n_manifold + col matches the row-major order used for tiling.
grid_x, grid_y = np.meshgrid(
    np.linspace(x_min, x_max, n_manifold),
    np.linspace(y_max, y_min, n_manifold),
)
# (x, y) latent coordinates of every grid point, in the same flat order.
gen_manifold_ind = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)

# Build the full latent batch up front: dimension 0 holds x, dimension 1 holds y.
latent_m = np.zeros((n_manifold * n_manifold, NOISE_DIM, 1, 1), dtype=np.float32)
latent_m[:, 0, 0, 0] = gen_manifold_ind[:, 0]
latent_m[:, 1, 0, 0] = gen_manifold_ind[:, 1]

# One batched, gradient-free forward pass replaces n_manifold**2 single-sample
# passes. This relies on the generator being in eval mode so that outputs do not
# depend on batch composition.
with torch.no_grad():
    gen_manifold = gen(torch.from_numpy(latent_m).to(device)).cpu().numpy()
gen_manifold = gen_manifold.reshape(-1, IMAGE_SIZE, IMAGE_SIZE)

In [None]:
# Tile the generated images into one large canvas; image k lands in grid cell
# (k // n_manifold, k % n_manifold).
manifold = np.zeros((n_manifold*IMAGE_SIZE, n_manifold*IMAGE_SIZE))
for k, tile in enumerate(gen_manifold):
    row, col = divmod(k, n_manifold)
    rows = slice(row * IMAGE_SIZE, (row + 1) * IMAGE_SIZE)
    cols = slice(col * IMAGE_SIZE, (col + 1) * IMAGE_SIZE)
    manifold[rows, cols] = tile

In [None]:
plt.figure(figsize=[10,10])
# `extent` gives the coordinates of the image *edges*. The grid samples sit at
# n_manifold evenly spaced values from min to max inclusive, so pad by half a
# grid step on each side; each tile is then centred on the latent value that
# produced it and the axis ticks read correctly.
half_x = (x_max - x_min) / max(n_manifold - 1, 1) / 2
half_y = (y_max - y_min) / max(n_manifold - 1, 1) / 2
plt.imshow(manifold, extent=[x_min - half_x, x_max + half_x, y_min - half_y, y_max + half_y])

## Latent space vector addition

In [None]:
# Pick two random sample indices whose latent vectors will be added.
rand_ind0, rand_ind1 = np.random.randint(low=0, high=size_test, size=2)

In [None]:
# Element-wise sum of the two chosen latent vectors, as a single-sample batch on `device`.
latent_add = (
    torch.Tensor(latent[rand_ind0] + latent[rand_ind1])
    .reshape(1, NOISE_DIM, 1, 1)
    .to(device)
)

In [None]:
# Image generated from the summed latent vector. no_grad skips autograd
# bookkeeping for this inference-only pass.
with torch.no_grad():
    gen_add = gen(latent_add).cpu().numpy()

In [None]:
# Left and middle: the two source samples. Right: the image from their summed latent vector.
fig, (ax_first, ax_second, ax_sum) = plt.subplots(1, 3)
ax_first.imshow(x_gen[rand_ind0])
ax_second.imshow(x_gen[rand_ind1])
ax_sum.imshow(gen_add[0, 0])

## Find mean/std latent vectors for each class

In [None]:
# Group the latent vectors by the class the classifier assigned to their
# generated image, then take the per-dimension mean and standard deviation.
latents_by_class = [latent[y_gen == cls] for cls in range(num_classes)]
latent_mean = np.array([np.mean(group, axis=0) for group in latents_by_class])
latent_std = np.array([np.std(group, axis=0) for group in latents_by_class])

In [None]:
# Sanity check: display the shapes of the per-class statistics.
(latent_mean.shape, latent_std.shape)

In [None]:
# Heat maps of the per-class latent statistics: one row per class, one column
# per latent dimension.
fig, axs = plt.subplots(2, 1, sharex=True)
for ax, title, stats in zip(axs, ('mean', 'std'), (latent_mean, latent_std)):
    ax.set_title(title)
    ax.imshow(stats.reshape(num_classes, NOISE_DIM))
plt.tight_layout()

In [None]:
# Image generated from each class's mean latent vector. no_grad skips autograd
# bookkeeping; as_tensor with an explicit float32 dtype replaces the legacy
# torch.Tensor constructor.
with torch.no_grad():
    mean_batch = torch.as_tensor(latent_mean, dtype=torch.float32).to(device)
    gen_mean = gen(mean_batch).cpu().numpy().reshape(num_classes, IMAGE_SIZE, IMAGE_SIZE)

In [None]:
# One panel per class on a 2x5 grid, filled row by row.
fig, axs = plt.subplots(2, 5, sharex=True, sharey=True)
for ax, mean_image in zip(axs.flat, gen_mean):
    ax.imshow(mean_image)
plt.tight_layout()

## Sample the latent space from a per-class normal distribution (naive conditional generation)

In [None]:
# Class to sample for, and how many latent vectors to draw.
digit = 9
n_samples = 100
# Draw each latent dimension independently from a normal distribution with that
# class's mean and std for the dimension. Dimensions are drawn one after another,
# n_samples values at a time, then stacked so that each row is one latent vector.
latent_class = np.stack(
    [
        np.random.normal(
            loc=latent_mean[digit, dim, 0, 0],
            scale=latent_std[digit, dim, 0, 0],
            size=n_samples,
        )
        for dim in range(NOISE_DIM)
    ],
    axis=1,
).reshape(n_samples, NOISE_DIM, 1, 1)

In [None]:
# Images generated from the class-conditioned latent samples. no_grad skips
# autograd bookkeeping; as_tensor with an explicit float32 dtype replaces the
# legacy torch.Tensor constructor (the samples are float64 NumPy values).
with torch.no_grad():
    class_batch = torch.as_tensor(latent_class, dtype=torch.float32).to(device)
    gen_class = gen(class_batch).cpu().numpy().reshape(n_samples, IMAGE_SIZE, IMAGE_SIZE)

In [None]:
# Show the class-conditioned samples on a 10x10 grid, filled row by row.
fig, axs = plt.subplots(10, 10, figsize=[10,10], sharex=True, sharey=True)
for ax, sample_image in zip(axs.flat, gen_class):
    ax.imshow(sample_image)
plt.tight_layout()