In [3]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
import json

# Utilities
from PIL import Image
import glob # Khoa's stuff
from IPython import display
from tqdm import tqdm # for training progress

# Parameters
Tune `BATCH_SIZE` and `DATASET_PATH` for your system; the remaining values should normally stay fixed.

In [4]:
# --- Training settings ---
EPOCHS: int = 100                 # keep at 100 unless experimenting
BATCH_SIZE: int = 100             # tune to the memory available on this machine
DATASET_PATH: str = "./dataset/"  # tune to wherever the image folder lives

# --- Model settings ---
IMAGE_SHAPE: tuple = (64, 64, 3)  # (height, width, channels); leave as is
LATENT_DIM: int = 32              # encoded vector size; smaller = stronger compression

# Utility Functions

In [5]:
def compare_and_show_images(original: np.ndarray, generated: np.ndarray, id: str, save: bool=True):
    """
    Compare generated images against the originals in a 2x2 grid.

    The top row shows the first two generated images and the bottom row the
    first two original images, giving a quick visual check of the model's
    progress. The figure is always displayed; when ``save`` is true it is also
    written to ``./generated_images/<id>.png`` first.

    Parameters:
        - original: array of original images with at least 2 entries along the
          first axis (e.g. shape (N, 64, 64, 3)); only the first 2 are shown
        - generated: array of generated images aligned with ``original``;
          only the first 2 are shown
        - id: identifier used as the saved file name
        - save: whether to save the figure to disk before showing it

    Raises:
        ValueError: if ``original`` or ``generated`` holds fewer than 2 images.
    """
    import os

    n_pairs = 2  # 2 generated images (top row) vs 2 originals (bottom row)
    if len(original) < n_pairs or len(generated) < n_pairs:
        raise ValueError(
            "compare_and_show_images needs at least {} original and {} generated "
            "images".format(n_pairs, n_pairs)
        )

    fig, axes = plt.subplots(2, n_pairs, figsize=(2, 2))
    for col in range(n_pairs):
        axes[0, col].imshow(generated[col])
        axes[0, col].axis('off')
        axes[1, col].imshow(original[col])
        axes[1, col].axis('off')

    if save:
        # savefig does not create missing directories.
        os.makedirs("./generated_images", exist_ok=True)
        fig.savefig("./generated_images/{}.png".format(id))
    # Show regardless of `save`; previously the figure was only shown when saving.
    plt.show()

# Load Data
Images are loaded from `DATASET_PATH` in batches of `BATCH_SIZE`, with 10% held out for validation.

In [6]:
from keras.utils import image_dataset_from_directory

# Keras rejects `validation_split` combined with shuffling (the default) unless a
# seed is given; the seed keeps the train/validation split reproducible and disjoint.
SPLIT_SEED = 42

x_train, x_test = image_dataset_from_directory(
    DATASET_PATH,
    labels=None,
    batch_size=BATCH_SIZE,
    image_size=IMAGE_SHAPE[:-1],
    validation_split=0.1,
    subset="both",
    seed=SPLIT_SEED,
)

# Pixels are loaded as float32 in [0, 255]; scale them to [0, 1] so they suit
# MSE training and can be passed straight to plt.imshow.
x_train = x_train.map(lambda images: images / 255.0)
x_test = x_test.map(lambda images: images / 255.0)

Show a few images from the dataset

In [None]:
# tf.data datasets cannot be indexed, so pull the first training batch as NumPy.
sample_batch = next(x_train.as_numpy_iterator())
print("sample batch shape:", sample_batch.shape)

# imshow expects float pixels in [0, 1]; rescale if the batch is still in [0, 255].
if sample_batch.max() > 1.0:
    sample_batch = sample_batch / 255.0

fig, axes = plt.subplots(2, 2, figsize=(2, 2))
for ax in axes.flat:
    ax.axis('off')
# zip stops early if the batch holds fewer than 4 images.
for ax, image in zip(axes.flat, sample_batch):
    ax.imshow(image)

plt.show()

# Model Definition

In [None]:
from keras import models, layers, Model

### Encoder Model

In [None]:
encoder = models.Sequential([
    # Take the input shape from the config cell instead of repeating it here.
    layers.Input(shape=IMAGE_SHAPE),
    # Fill in the rest of the model
    layers.Flatten(),
    # Condense each image to LATENT_DIM features; sigmoid bounds them to (0, 1).
    layers.Dense(LATENT_DIM, activation='sigmoid'),
], name="face_encoder")

### Decoder Model

In [None]:
decoder = models.Sequential([
    # The latent input must be a shape tuple: `(LATENT_DIM)` is just an int.
    layers.Input(shape=(LATENT_DIM,)),
    # Project the flat latent vector onto a small spatial feature map. Three
    # stride-2 transposed convolutions below then upsample it by 8x, back to
    # the IMAGE_SHAPE resolution, so the output can be compared with the input.
    layers.Dense((IMAGE_SHAPE[0] // 8) * (IMAGE_SHAPE[1] // 8) * 128, use_bias=False),
    layers.BatchNormalization(),
    layers.ReLU(),
    layers.Reshape((IMAGE_SHAPE[0] // 8, IMAGE_SHAPE[1] // 8, 128)),
    layers.Conv2DTranspose(64, (5, 5), strides=2, padding='same', use_bias=False),
    layers.BatchNormalization(),
    layers.ReLU(),
    layers.Conv2DTranspose(32, (5, 5), strides=2, padding='same', use_bias=False),
    layers.BatchNormalization(),
    layers.ReLU(),
    # Output layer: one filter per colour channel. No BatchNormalization after
    # it, so the reconstruction stays in (non-negative) pixel space for the MSE loss.
    layers.Conv2DTranspose(IMAGE_SHAPE[-1], (5, 5), strides=2, padding='same', activation='relu'),
], name="face_decoder")

### Autoencoder Subclassing
Subclass `keras.Model` so the autoencoder can be trained and evaluated with the standard Keras API (`compile`, `fit`, `predict`).

In [None]:
class Autoencoder(Model):
  """Keras model that chains an encoder and a decoder to reconstruct its input.

  Datasets that yield bare image batches can be passed directly to fit() and
  evaluate() (including as fit's validation_data). train_step and test_step use
  each batch as its own reconstruction target. Batches that are already
  (x, y[, sample_weight]) tuples are passed through unchanged.

  Args:
    encoder_model: model mapping an image batch to its latent representation.
    decoder_model: model mapping a latent batch back to image space.
  """

  def __init__(self, encoder_model, decoder_model):
    super().__init__()
    self.encoder = encoder_model
    self.decoder = decoder_model

  def call(self, x):
    """Encode `x` to the latent space and decode it back."""
    encoded = self.encoder(x)
    decoded = self.decoder(encoded)
    return decoded

  @staticmethod
  def _as_reconstruction_pair(data):
    """Return `data` as an (input, target) pair, using the input as the target."""
    # A tuple already carries an explicit target (and possibly sample weights);
    # wrapping it again would feed a tuple into the encoder.
    if isinstance(data, tuple):
      return data
    return (data, data)

  def train_step(self, data):
    return super().train_step(self._as_reconstruction_pair(data))

  def test_step(self, data):
    # Without this override, validation/evaluate on an image-only dataset has
    # no target, and the loss cannot be computed.
    return super().test_step(self._as_reconstruction_pair(data))

# Join the two halves and train them end to end to minimise pixel-wise MSE.
autoencoder = Autoencoder(encoder_model=encoder, decoder_model=decoder)
autoencoder.compile(loss='mse', optimizer='adam')

# Train Model

In [None]:
# x_train is a tf.data.Dataset of image batches: Keras rejects a separate `y`
# alongside a dataset, so pass only the images and let train_step use each
# batch as its own reconstruction target.
# validation_data batches are handed to test_step, so pair each image batch
# with itself to make the reconstruction target explicit.
x_test_pairs = x_test.map(lambda images: (images, images))

history = autoencoder.fit(
    x_train,
    epochs=EPOCHS,
    validation_data=x_test_pairs,
)

# Test Model 

In [None]:
# Reconstruct every image in the held-out split in one pass.
# NOTE(review): if x_test reshuffles on each iteration, these rows will not line
# up with batches drawn from x_test later.
decoded_imgs = autoencoder.predict(x=x_test)

In [None]:
x_test_iterator = x_test.as_numpy_iterator()
test_batch = next(x_test_iterator)

# Reconstruct exactly the batch being displayed. Re-iterating a shuffled dataset
# may yield a different order, so predictions from an earlier pass over x_test
# might not correspond to these originals.
reconstructed_batch = autoencoder.predict(test_batch, verbose=0)

# imshow expects float pixels in [0, 1]: rescale both rows the same way if the
# data is still in [0, 255], and clip reconstructions that overshoot.
pixel_scale = 255.0 if test_batch.max() > 1.0 else 1.0

n = min(10, len(test_batch))  # the last batch may hold fewer than 10 images
fig, axes = plt.subplots(2, n, figsize=(2 * n, 4), squeeze=False)
for i in range(n):
  rows = (
      (axes[0, i], test_batch[i], "original"),
      (axes[1, i], reconstructed_batch[i], "reconstructed"),
  )
  for ax, image, title in rows:
    ax.imshow(np.clip(image / pixel_scale, 0.0, 1.0))
    ax.set_title(title)
    ax.axis('off')
plt.show()