In [None]:
import torch

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

In [None]:
from model import Discriminator, Generator, invert

In [None]:
import tensorflow as tf

## Load data

In [None]:
# MNIST digits are grayscale, so both the discriminator and generator below use one image channel.
CHANNELS_IMG: int = 1

In [None]:
# MNIST is bundled with Keras; load_data() returns a (train, test) pair of (images, labels) tuples.
dataset = tf.keras.datasets.mnist
train_split, test_split = dataset.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

In [None]:
# Derive the class count and image side length from the data instead of hard-coding them.
num_classes = len(np.unique(y_train))
# Only the last axis is read, so square images are assumed (later cells reshape to IMAGE_SIZE x IMAGE_SIZE).
IMAGE_SIZE = np.shape(x_train)[-1]

In [None]:
size_train, size_test = len(x_train), len(x_test)
# Both splits are divided by the *training* maximum so they share a single transform.
scale = x_train.max()
# ((p / scale) - 0.5) / 0.5 maps 0 -> -1 and scale -> +1.
x_train_scale, x_test_scale = (((pixels / scale) - 0.5) / 0.5 for pixels in (x_train, x_test))

In [None]:
# Mean (rescaled) training image for each digit class, one titled panel per class.
# axs.flat walks the 2x5 grid row-major, matching the original i*5+j indexing.
fig, axs = plt.subplots(2, 5, sharex=True, sharey=True)
for digit, ax in enumerate(axs.flat):
    ax.imshow(np.mean(x_train_scale[y_train == digit], axis=0))
    ax.set_title(str(digit))
fig.suptitle('Mean training image per class')
fig.tight_layout()

In [None]:
# Mean (rescaled) test image for each digit class, one titled panel per class.
# axs.flat walks the 2x5 grid row-major, matching the original i*5+j indexing.
fig, axs = plt.subplots(2, 5, sharex=True, sharey=True)
for digit, ax in enumerate(axs.flat):
    ax.imshow(np.mean(x_test_scale[y_test == digit], axis=0))
    ax.set_title(str(digit))
fig.suptitle('Mean test image per class')
fig.tight_layout()

## Load models (discriminator and generator)

In [None]:
# Base feature count passed to both Discriminator and Generator; presumably it must match
# the architecture the checkpoints were saved with.
FEATURES = 32
# Prefer Apple's MPS backend (the original hard-coded choice), then CUDA, then CPU,
# so the notebook also runs on machines without an Apple GPU.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Rebuild the discriminator and restore its trained weights.
disc = Discriminator(features=FEATURES, channels_img=CHANNELS_IMG)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a different
# device (e.g. CUDA) still loads here.
disc.load_state_dict(torch.load('dcgan_disc_2024-04-19_1341.pt', map_location=device))
disc.to(device)
disc.eval();  # inference mode; the trailing ';' suppresses the module repr

In [None]:
# Length of the latent (noise) vector the generator takes as `channels_noise`;
# presumably must match the value used when the generator checkpoint was trained.
NOISE_DIM: int = 100

In [None]:
# Rebuild the generator and restore its trained weights.
gen = Generator(channels_noise=NOISE_DIM, features=FEATURES, channels_img=CHANNELS_IMG)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a different
# device (e.g. CUDA) still loads here.
gen.load_state_dict(torch.load('dcgan_gen_2024-04-19_1341.pt', map_location=device))
gen.to(device)
gen.eval();  # inference mode; the trailing ';' suppresses the module repr

## Inversion (one image)

In [None]:
# One test image, taken with a slice (not a scalar index) so the leading batch axis of length 1 is kept.
img = x_test_scale[slice(10, 11)]

In [None]:
# Invert the generator for the single image above. Judging by later use, the outputs are
# latents, generated images and a loss history — TODO confirm against model.invert.
z_one, x_gens_one, loss_one = invert(
    img,
    generator=gen,
    device=device,
    steps=1000,
)

In [None]:
# Loss history of the single-image inversion on a log y-axis.
fig, ax = plt.subplots()
ax.plot(loss_one)
ax.set_yscale('log')
ax.set_xlabel('step')  # assumes invert records one loss value per optimisation step
ax.set_ylabel('loss')
ax.set_title('Inversion loss (single image)');

In [None]:
# Target image, its reconstruction (last entry of x_gens_one — presumably the final
# generated image), and the per-pixel squared error.
fig, axs = plt.subplots(1, 3, figsize=[15, 5])
axs[0].imshow(img[0])
axs[0].set_title('original')
axs[1].imshow(x_gens_one[-1])
axs[1].set_title('reconstruction')
# Inputs were rescaled to [-1, 1]; assuming generator output shares that range,
# the squared error is bounded by (1 - (-1))**2 = 4.
err_im = axs[2].imshow((img[0] - x_gens_one[-1]) ** 2, vmin=0, vmax=4)
axs[2].set_title('squared error')
fig.colorbar(err_im, ax=axs[2]);

## Inversion (multiple images)

In [None]:
# Invert the generator for the whole (rescaled) test set in one call; outputs mirror the
# single-image case (latents, generated images, loss history — TODO confirm in model.invert).
z, x_gens, loss = invert(
    x_test_scale,
    generator=gen,
    device=device,
    steps=1000,
)

In [None]:
# Loss history of the whole-test-set inversion on a log y-axis.
fig, ax = plt.subplots()
ax.plot(loss)
ax.set_yscale('log')
ax.set_xlabel('step')  # assumes invert records one loss value per optimisation step
ax.set_ylabel('loss')
ax.set_title('Inversion loss (test set)');

In [None]:
# Inspect the reconstruction of one test image from the batch inversion.
ind = 2
fig, axs = plt.subplots(1, 3, figsize=[15, 5])
axs[0].imshow(x_test_scale[ind])
axs[0].set_title('original')
axs[1].imshow(x_gens[ind])
axs[1].set_title('reconstruction')
# Inputs were rescaled to [-1, 1]; assuming generator output shares that range,
# the squared error is bounded by (1 - (-1))**2 = 4.
err_im = axs[2].imshow((x_test_scale[ind] - x_gens[ind]) ** 2, vmin=0, vmax=4)
axs[2].set_title('squared error')
fig.colorbar(err_im, ax=axs[2]);

## Find mean/std latent vectors for each class

In [None]:
# Group the inverted latents by ground-truth test label and summarise each group
# (rows of the resulting arrays are ordered by class 0..num_classes-1).
class_masks = [y_test == label for label in range(num_classes)]
latent_mean = np.array([np.mean(z[m], axis=0) for m in class_masks])
latent_std = np.array([np.std(z[m], axis=0) for m in class_masks])

In [None]:
# Sanity check: one row per class; trailing dims follow the latent array returned by invert.
(latent_mean.shape, latent_std.shape)

In [None]:
# Heatmaps of per-class latent mean and std: rows are classes, columns are latent dimensions.
fig, axs = plt.subplots(2, 1, sharex=True)
for ax, (title, stat) in zip(axs, (('mean', latent_mean), ('std', latent_std))):
    # aspect='auto' stretches the (num_classes x NOISE_DIM) array to fill the axes
    # instead of rendering it as a thin strip.
    im = ax.imshow(stat.reshape(num_classes, NOISE_DIM), aspect='auto')
    ax.set_title(title)
    ax.set_ylabel('class')
    fig.colorbar(im, ax=ax)
axs[-1].set_xlabel('latent dimension')
fig.tight_layout()

In [None]:
# Decode each class-mean latent vector into an image. no_grad() avoids building an
# autograd graph for this inference-only forward pass (so no .detach() is needed).
with torch.no_grad():
    latent_mean_t = torch.as_tensor(latent_mean, dtype=torch.float32).to(device)
    gen_mean = gen(latent_mean_t).cpu().numpy().reshape(num_classes, IMAGE_SIZE, IMAGE_SIZE)

In [None]:
# Generator output for each class-mean latent vector, one titled panel per class.
# axs.flat walks the 2x5 grid row-major, matching the original i*5+j indexing.
fig, axs = plt.subplots(2, 5, sharex=True, sharey=True)
for digit, ax in enumerate(axs.flat):
    ax.imshow(gen_mean[digit])
    ax.set_title(str(digit))
fig.suptitle('Generated image from each class-mean latent')
fig.tight_layout()

In [None]:
# Mean of the inverted reconstructions for each test class, one titled panel per class;
# compare with the mean test images plotted earlier.
fig, axs = plt.subplots(2, 5, sharex=True, sharey=True)
for digit, ax in enumerate(axs.flat):
    ax.imshow(np.mean(x_gens[y_test == digit], axis=0))
    ax.set_title(str(digit))
fig.suptitle('Mean reconstruction per class')
fig.tight_layout()