In [None]:
import matplotlib.pyplot as plt
import numpy as np
import PIL
import os
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
import math
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from vae_vanilla import VAE

In [None]:
def load_data_for_model(batch_size=64, seed=None, data_dir='./data'):
    """Load the LFSD arrays, shuffle them into mini-batches and split 70/30 by batch.

    Args:
        batch_size: Maximum number of samples per batch. The final batch holds
            the remainder and may therefore be smaller.
        seed: Optional integer seed. When given, both the shuffle and the
            train/test split are reproducible. When ``None`` (the default) the
            global NumPy RNG is used, exactly as before.
        data_dir: Directory containing the ``LFSD_*.npy`` files.

    Returns:
        ``(train_dataset, test_dataset)``: lists of
        ``(img_batch, depth_batch, mask_batch)`` tuples. ``test_size=0.3`` is
        applied to the list of batches, not to individual samples.

    Raises:
        ValueError: if ``batch_size`` is not a positive integer.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))

    ## Loading dataset
    # Labels are only used for the sample count; they are not returned.
    labels = np.load(os.path.join(data_dir, 'LFSD_labels.npy'))
    depths = np.load(os.path.join(data_dir, 'LFSD_depths_repeated.npy'))
    imgs = np.load(os.path.join(data_dir, 'LFSD_imgs.npy'))
    masks = np.load(os.path.join(data_dir, 'LFSD_masks_single.npy'))
    n_samples = len(labels)

    # Shuffle once, then cut the permutation into consecutive batch_size chunks.
    if seed is None:
        order = np.random.permutation(n_samples)
    else:
        order = np.random.default_rng(seed).permutation(n_samples)

    dataset = []
    for start in range(0, n_samples, batch_size):
        batch_idx = order[start:start + batch_size]
        dataset.append((imgs[batch_idx], depths[batch_idx], masks[batch_idx]))

    # random_state=None keeps sklearn on the global NumPy RNG (previous behaviour).
    train_dataset, test_dataset = train_test_split(dataset, test_size=0.3, random_state=seed)
    print("Train dataset contains %d batches of up to %d samples each" % (len(train_dataset), batch_size))
    print("Test dataset contains %d batches of up to %d samples each" % (len(test_dataset), batch_size))
    return train_dataset, test_dataset

In [None]:
def merge_images(image_batch, size):
    """Tile a batch of single-channel images into one 2-D grid image.

    Images are placed row-major: image ``k`` goes to grid row ``k // cols`` and
    grid column ``k % cols``. Grid cells without an image stay zero.

    Args:
        image_batch: Array-like of shape ``(N, H, W, 1)`` or ``(N, H, W)``.
        size: ``(rows, cols)`` of the grid; ``rows * cols`` must be >= ``N``.

    Returns:
        A float64 array of shape ``(rows * H, cols * W)``.

    Raises:
        ValueError: if the images have more than one channel, the batch has an
            unsupported rank, or there are more images than grid cells.
    """
    image_batch = np.asarray(image_batch)
    if image_batch.ndim == 4:
        if image_batch.shape[3] != 1:
            raise ValueError("expected single-channel images, got shape %s" % (image_batch.shape,))
        image_batch = image_batch[..., 0]
    elif image_batch.ndim != 3:
        raise ValueError("expected a batch of shape (N, H, W[, 1]), got %s" % (image_batch.shape,))

    rows, cols = int(size[0]), int(size[1])
    n_images, h, w = image_batch.shape
    if n_images > rows * cols:
        raise ValueError("%d images do not fit in a %dx%d grid" % (n_images, rows, cols))

    grid = np.zeros((rows * h, cols * w))
    for k, im in enumerate(image_batch):
        r, c = divmod(k, cols)
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = im
    return grid

In [None]:
def _save_grid(images, path):
    """Tile a batch of generated images into an 8x8 grid and save it as a grayscale PNG."""
    # NOTE(review): the fixed 8x8 grid matches the default batch_size=64 used by
    # load_data_for_model; batches larger than 64 will not fit.
    plt.imsave(path, merge_images(images.numpy(), [8, 8]), cmap='gray')


def _train_one_epoch(vae, train_dataset, history):
    """Run ``vae.train_step`` on every batch once and return the last batch.

    Appends a new 0.0 slot to each list in ``history`` and adds every batch's
    loss values, divided by the number of batches, into that slot.
    """
    batch_num = len(train_dataset)
    for key in history:
        history[key].append(0.0)
    last_batch = None
    for data in train_dataset:
        cur_loss = vae.train_step(data)
        for key, value in cur_loss.items():
            history[key][-1] += value.numpy() / batch_num
        last_batch = data
    return last_batch


def _evaluate(vae, test_dataset):
    """Return the mean per-batch reconstruction BCE on ``test_dataset`` and save one grid per batch."""
    total = 0.0
    for i, data in enumerate(test_dataset):
        _, _, mask_batch = data
        generated_image = vae.sample(data)
        # BCE is reduced over the last axis by Keras; sum the remaining spatial axes
        # per sample, then average over the batch.
        reconstruction_loss = tf.reduce_sum(
            tf.keras.losses.binary_crossentropy(mask_batch, generated_image), [1, 2]
        )
        total += tf.reduce_mean(reconstruction_loss).numpy()
        _save_grid(generated_image, './images/vae_vanilla/test_batch_%d.png' % i)
    return total / len(test_dataset)


def train_round(train_dataset, test_dataset, learning_rate, model_name, epochs):
    """Train a fresh VAE on ``train_dataset`` and evaluate it on ``test_dataset``.

    Side effects (directories are created if missing):
      * ``./images/vae_vanilla/<epoch>.png``: a sample grid after each epoch.
      * ``./results/vae_vanilla/<loss_name>.npy``: per-epoch loss curves.
      * ``./images/vae_vanilla/test_batch_<i>.png``: one grid per test batch.
      * ``./results/vae_vanilla/test_loss.npy``: mean test reconstruction loss.

    NOTE(review): these paths do not include ``model_name``, so calling this once
    per backbone overwrites the previous backbone's outputs.

    Args:
        train_dataset: Sequence of ``(img_batch, depth_batch, mask_batch)`` tuples.
        test_dataset: Same structure as ``train_dataset``.
        learning_rate: Adam learning rate.
        model_name: One of ``'inception'``, ``'vgg'``, ``'mobilenet'``, ``'resnet'``;
            selects the VAE latent dimensionality.
        epochs: Number of passes over ``train_dataset``.

    Returns:
        The trained ``VAE`` instance.

    Raises:
        KeyError: if ``model_name`` is not a known backbone.
        ValueError: if ``test_dataset`` is empty, or ``train_dataset`` is empty
            while ``epochs > 0``.
    """
    latent_lookup = {
        'inception': 2048,
        'vgg': 512,
        'mobilenet': 1280,
        'resnet': 2048,
    }
    if model_name not in latent_lookup:
        raise KeyError("unknown model_name %r; expected one of %s" % (model_name, sorted(latent_lookup)))
    if epochs > 0 and len(train_dataset) == 0:
        raise ValueError("train_dataset must be non-empty when epochs > 0")
    if len(test_dataset) == 0:
        raise ValueError("test_dataset must be non-empty")

    os.makedirs('./images/vae_vanilla', exist_ok=True)
    os.makedirs('./results/vae_vanilla', exist_ok=True)

    vae = VAE(latent_lookup[model_name])
    vae.compile(optimizer=Adam(learning_rate))

    # Train Step
    losses_across_epochs = {
        "loss": [],
        "reconstruction_loss": [],
        "kl_loss": [],
    }
    for epoch in range(epochs):
        print("Epoch %d: " % epoch)
        last_batch = _train_one_epoch(vae, train_dataset, losses_across_epochs)
        print("Total loss: %.4f" % losses_across_epochs['loss'][-1])
        # Only the final batch's sample was ever saved, so sample once per epoch
        # (same weights as sampling right after the last train step).
        _save_grid(vae.sample(last_batch), './images/vae_vanilla/%d.png' % epoch)
    for name, values in losses_across_epochs.items():
        np.save('./results/vae_vanilla/%s' % name, np.array(values))

    # Testing Step
    test_loss = _evaluate(vae, test_dataset)
    np.save('./results/vae_vanilla/test_loss', np.array([test_loss]))
    return vae

In [None]:
def _encode_cifar_images(vae, images, desc=None):
    """Resize each image to 256x256x3 and encode it one at a time with ``vae``.

    Returns an array that stacks the first output of ``vae.encode`` for every
    image, with the leading size-1 batch axis squeezed away.
    """
    from skimage.transform import resize

    encodings = []
    for img in tqdm(images, desc=desc):
        batch = np.expand_dims(resize(img, (256, 256, 3)), 0)
        activation, _, _ = vae.encode(batch, None, rand_depth=True)
        encodings.append(np.squeeze(activation, axis=0))
    return np.array(encodings)


def get_encoding_for_model(vae, model_name):
    """Encode the CIFAR-100 train and test images with ``vae`` and save them.

    Writes ``./data/CIFAR100_vae_vanilla_<model_name>_encoding_train.npy`` and
    ``./data/CIFAR100_vae_vanilla_<model_name>_encoding_test.npy``.
    CIFAR-100 labels are loaded by Keras but not used here.
    """
    from tensorflow.keras import datasets

    (train_images, _), (test_images, _) = datasets.cifar100.load_data()
    # Both splits are encoded with the same trained model passed in as ``vae``.
    for split, images in (('train', train_images), ('test', test_images)):
        encodings = _encode_cifar_images(vae, images, desc=split)
        np.save('./data/CIFAR100_vae_vanilla_%s_encoding_%s.npy' % (model_name, split), encodings)

In [None]:
# Train one VAE per backbone on a single shared train/test split, then export
# CIFAR-100 encodings for each trained model.
SEED = 0
np.random.seed(SEED)      # drives the batch shuffle and sklearn's train_test_split
tf.random.set_seed(SEED)  # drives TensorFlow weight init and sampling

learning_rate = 1e-4
epochs = 1

# Load once so every backbone is trained and evaluated on identical data,
# keeping the per-backbone results comparable.
train_dataset, test_dataset = load_data_for_model()
for model_name in ['inception', 'vgg', 'resnet', 'mobilenet']:
    trained_model = train_round(train_dataset, test_dataset, learning_rate, model_name, epochs)
    print("Gen encoding...")
    get_encoding_for_model(trained_model, model_name)