In [None]:
# Mount Google Drive: the project code, model checkpoints and the CelebA archive live there.
from google.colab import drive
drive.mount('/content/gdrive')

# Make the project package (`experiments.*`) importable. The membership check keeps
# re-runs of this cell from appending duplicate entries to sys.path.
import sys
PROJECT_DIR = '/content/gdrive/MyDrive/PROJECTS/vae'
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

In [None]:
!unzip /content/gdrive/MyDrive/DATASETS/Celeb_A/celeba-dataset.zip

In [None]:
from experiments.celebA.architecture import build_model
# Build the VAE and restore previously trained weights from a checkpoint on Drive
# (the '0024' file name presumably identifies a training epoch/step — confirm).
vae = build_model()
vae.load_weights(
    '/content/gdrive/MyDrive/PROJECTS/vae/experiments/celebA/results/checkpoints/0024.tf'
)


In [None]:
# lets verify that the vae works: We simply try to generate a new human face
import tensorflow as tf
# Sanity check: draw one latent vector from the prior, prepend a batch axis of
# size 1 (the decoder is called on batches), and decode it.
z = vae.prior_distribution.sample()[tf.newaxis, ...]
decoded = vae.decode(z)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Draw an image from the decoder's output distribution and drop the batch axis.
# A sample may fall outside the displayable [0, 1] range, so clip it explicitly
# (as the other plotting cells do) instead of relying on imshow's clipping warning.
reconstruction = decoded.sample()[0]
reconstruction = np.clip(reconstruction, 0, 1)
plt.imshow(reconstruction)
plt.axis('off')
plt.show()

In [None]:
import tensorflow as tf
import os

# Folder containing the extracted CelebA images (nested one level by the archive layout).
DATA_DIR = os.path.join('/content', 'img_align_celeba', 'img_align_celeba')
# Number of images per batch for both the training and validation pipelines.
BATCH_SIZE = 256

def preprocessing(x):
    """Scale an image by 1/255 and pair it with an empty placeholder target.

    Args:
        x: image tensor (presumably pixel values in [0, 255] — confirm).

    Returns:
        ``(scaled_image, empty_tensor)``. The empty second element is presumably there
        so the dataset has an (input, target) structure; only the images are used here.
    """
    scaled = x / 255
    placeholder_target = tf.constant([])
    return scaled, placeholder_target

# Read the image folder once and split off 5% of the files (seeded) for validation.
# With batch_size=None the splits yield individual 128x128 RGB images with no labels.
dataset_train, dataset_val = tf.keras.utils.image_dataset_from_directory(
    directory=DATA_DIR,
    subset='both',
    validation_split=0.05,
    seed=42,
    shuffle=True,
    label_mode=None,
    class_names=None,
    color_mode='rgb',
    image_size=(128, 128),
    batch_size=None,
)

# Counted while the training split is still unbatched, i.e. in single images.
n_trainingssamples = len(dataset_train)

# Same scaling + batching pipeline for both splits (train first, then val).
dataset_train, dataset_val = (
    split.map(preprocessing).batch(BATCH_SIZE)
    for split in (dataset_train, dataset_val)
)

In [None]:
import numpy as np

# Take one validation image, show it next to its VAE reconstruction, then
# display its latent representation.
batch = dataset_val.as_numpy_iterator().next()[0]  # [0]: images, not the placeholder target
elem = batch[0]
plt.imshow(elem)
plt.axis('off')
plt.show()

r = vae(tf.constant([elem]))[0]
# Clip to the displayable [0, 1] range, consistent with the other plotting cells.
r = np.clip(r, 0, 1)
plt.imshow(r)
plt.axis('off')
plt.show()

z = vae.get_latent_representation(tf.constant([elem]))[0]
z

In [None]:
import numpy as np
# Decode 225 latent vectors drawn from the prior and tile them in a 15x15 grid.
fig, axes = plt.subplots(15, 15, figsize=(20, 20))
z = vae.prior_distribution.sample(15 * 15)
reconstruction = vae.get_reconstruction(z)

# Pair each grid cell (row-major) with one decoded image; clip to the displayable range.
for ax, r in zip(axes.flat, reconstruction):
    r = np.clip(r, 0, 1)
    ax.imshow(r)
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Linear interpolation in latent space between two validation images.
# Fetch ONE batch and take both images from it. Each `as_numpy_iterator()` call starts
# a new pass over a shuffled dataset, so two separate calls may return different batches.
val_images = dataset_val.as_numpy_iterator().next()[0]  # [0]: images, not the placeholder target
img1, img2 = val_images[46], val_images[46 + 1]

fig, axes = plt.subplots(1, 20, figsize=(30, 6))

z1, z2 = vae.get_latent_representation(np.stack([img1, img2]))
interpolation_range = np.linspace(0, 1, 20)
for ax, fraction in zip(np.ravel(axes), interpolation_range):
    z = z1 * (1 - fraction) + z2 * fraction
    # Decode one interpolated latent at a time (batch axis added, then removed).
    r = vae.get_reconstruction(z[np.newaxis, ...])[0]
    r = np.clip(r, 0, 1)
    ax.imshow(r)
    ax.axis('off')
plt.show()

In [None]:
# Pixel-space interpolation between the reconstructions of two validation images,
# for comparison with latent-space interpolation.
# Fetch ONE batch and take both images from it. Each `as_numpy_iterator()` call starts
# a new pass over a shuffled dataset, so two separate calls may return different batches.
val_images = dataset_val.as_numpy_iterator().next()[0]  # [0]: images, not the placeholder target
img1, img2 = val_images[46], val_images[46 + 1]

# Blend the VAE reconstructions rather than the raw inputs.
img1, img2 = vae(np.stack([img1, img2]))

fig, axes = plt.subplots(1, 20, figsize=(30, 6))

interpolation_range = np.linspace(0, 1, 20)
for ax, fraction in zip(np.ravel(axes), interpolation_range):
    img = img1 * (1 - fraction) + img2 * fraction
    img = np.clip(img, 0, 1)
    ax.imshow(img)
    ax.axis('off')
plt.show()

In [None]:
import pandas as pd
# CelebA attribute table (presumably one row per image, keyed by 'image_id').
df = pd.read_csv(filepath_or_buffer='/content/list_attr_celeba.csv')
# Overview: dimensions, first rows, and the available attribute columns.
print(df.shape)
display(df.head(5))
print(df.columns)

In [None]:
# Rows whose 'Male' attribute equals 1.
df_male = df[df['Male'] == 1]
print(df_male.shape)

In [None]:
# Encode the same validation image three times. For each encoding, plot the latent
# vector and decode it nine times, to see how much repeated calls vary
# (presumably the encoder/decoder sample stochastically — confirm).
img = dataset_val.as_numpy_iterator().next()[0][42]

for _ in range(3):
    fig, axes = plt.subplots(1, 10, figsize=(30, 3))
    bar_ax, img_axes = axes[0], axes[1:]

    z = vae.get_latent_representation(img[np.newaxis, ...])[0]
    bar_ax.bar(range(len(z)), z, width=1)

    for ax in img_axes:
        r = np.clip(vae.get_reconstruction(z[np.newaxis, ...])[0], 0, 1)
        ax.axis('off')
        ax.imshow(r)

    plt.show()

In [None]:
# One latent vector drawn from the prior, drawn as a bar per dimension
# (compare with the encoded latents plotted in the previous cell).
z = vae.prior_distribution.sample()
fig, ax = plt.subplots()
ax.bar(range(len(z)), z, width=1, color='red')
plt.show()


In [None]:
# Encode one large (2048-image) training batch for the PCA comparison below.
# Dataset elements are (images, placeholder_target) pairs from `preprocessing`, so
# take [0] and pass only the images to the encoder, as the other cells do.
dataset_train = dataset_train.rebatch(2048)
batch = dataset_train.as_numpy_iterator().next()[0]
latents = vae.get_latent_representation(batch)

In [None]:
import sklearn.decomposition as decomp
pca = decomp.PCA(n_components=2)

# Fit a 2-D PCA on samples from the prior, then project the encoded training
# latents into the same plane to compare where the two point clouds lie.
random_latents = vae.prior_distribution.sample(20000)
components = pca.fit_transform(random_latents)

import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.scatter(components[:, 0], components[:, 1], s=0.1, label='prior samples')

components = pca.transform(latents)
ax.scatter(components[:, 0], components[:, 1], s=0.1, label='encoded training images')

ax.set(title='Latent space: first two principal components (PCA fit on prior samples)',
       xlabel='PC 1', ylabel='PC 2')
# The points are tiny (s=0.1); enlarge the legend markers so the colours are visible.
ax.legend(markerscale=20)
plt.show()

In [None]:
import os

# grab a batch of male images
# NOTE(review): this rebinds the global BATCH_SIZE (256 in the dataset cell) to 1024.
# Re-running the dataset cell afterwards would silently batch with 1024; consider a
# distinct name such as MALE_BATCH_SIZE.
BATCH_SIZE = 1024
# NOTE(review): same folder as DATA_DIR from the dataset cell, under a second name.
DATADIR = '/content/img_align_celeba/img_align_celeba'

# Iterate over the 'image_id' values of the first BATCH_SIZE rows with Male == 1.
for filename in df_male['image_id'].iloc[:BATCH_SIZE]:
    # FIXME(review): plt.imread() is called without a file argument and will raise TypeError.
    # It was presumably meant to be plt.imread(os.path.join(DATADIR, filename)).
    # In the visible lines `imgs` is also overwritten on every iteration, so nothing
    # accumulates across files — confirm intent.
    imgs = tf.constant([plt.imread()])