In [1]:
import matplotlib.pyplot as plt
import numpy as np
import PIL
import os

import tensorflow as tf
# Query the GPU list once through the stable API (the `experimental` alias is
# deprecated) and enable on-demand memory growth so TensorFlow does not grab
# all device memory up front. TensorFlow requires set_memory_growth to be
# called before the GPUs are initialised, hence it runs in the first cell.
gpus = tf.config.list_physical_devices('GPU')
print(f"gpu: { len(gpus) }")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
    
from tensorflow.keras.optimizers import Adam
import math
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from vae_adapted import VAEAdapted

gpu: 1


In [2]:
def load_data_for_model(model_name, batch_size=64, data_dir='./data', random_state=None):
    """Load pre-extracted LFSD features for one backbone and split them into batches.

    Samples are shuffled once and grouped into consecutive mini-batches of
    ``batch_size``. The list of batches is then split 70/30 into train/test
    with ``train_test_split``. The final batch may hold fewer than
    ``batch_size`` samples.

    Args:
        model_name: Backbone name embedded in the feature file names
            (e.g. ``'vgg'`` -> ``LFSD_imgs_vgg_feat.npy``).
        batch_size: Number of samples per batch.
        data_dir: Directory containing the ``LFSD_*.npy`` files.
        random_state: Optional integer seed. When given, both the sample
            shuffle and the train/test split are reproducible. When ``None``,
            NumPy's global RNG is used (the previous behaviour).

    Returns:
        ``(train_dataset, test_dataset)``: two lists of
        ``(img_batch, depth_batch, mask_batch)`` tuples.

    Raises:
        ValueError: if the loaded arrays do not all contain the same number
            of samples, or if ``batch_size`` is not positive.
    """
    labels = np.load(os.path.join(data_dir, 'LFSD_labels.npy'))
    depths = np.load(os.path.join(data_dir, 'LFSD_depths_repeated_%s_feat.npy' % model_name))
    imgs = np.load(os.path.join(data_dir, 'LFSD_imgs_%s_feat.npy' % model_name))
    masks = np.load(os.path.join(data_dir, 'LFSD_masks_single.npy'))

    # Labels are only used for their length; make sure every array agrees so a
    # mismatched feature file fails here instead of deep inside indexing.
    n_samples = len(labels)
    for name, arr in (('imgs', imgs), ('depths', depths), ('masks', masks)):
        if len(arr) != n_samples:
            raise ValueError(
                "%s has %d samples but labels has %d (model_name=%r)"
                % (name, len(arr), n_samples, model_name)
            )
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got %r" % (batch_size,))

    if random_state is None:
        order = np.random.permutation(n_samples)
    else:
        order = np.random.RandomState(random_state).permutation(n_samples)

    batch_indices = (order[start:start + batch_size] for start in range(0, n_samples, batch_size))
    dataset = [(imgs[b], depths[b], masks[b]) for b in batch_indices]

    # The split is made over whole batches, not over individual samples.
    train_dataset, test_dataset = train_test_split(
        dataset, test_size=0.3, random_state=random_state
    )
    print("Train dataset contains %d batches of %d samples each" % (len(train_dataset), batch_size))
    print("Test dataset contains %d batches of %d samples each" % (len(test_dataset), batch_size))
    return train_dataset, test_dataset

In [3]:
def merge_images(image_batch, size):
    """Tile a batch of single-channel images into one 2-D grid image.

    Images are placed row-major: image ``k`` goes to grid row ``k // size[1]``
    and grid column ``k % size[1]``. Grid cells without an image stay 0.

    Args:
        image_batch: Array of shape ``(N, h, w, 1)`` or ``(N, h, w)``.
        size: ``(rows, cols)`` of the grid.

    Returns:
        A float64 array of shape ``(rows * h, cols * w)``.

    Raises:
        ValueError: if the images are not 2-D after dropping a trailing
            channel axis of length 1, or if the batch has more images than
            the grid has cells.
    """
    image_batch = np.asarray(image_batch)
    rows, cols = int(size[0]), int(size[1])
    if image_batch.ndim == 4:
        # Drop the channel axis; np.squeeze raises ValueError if it is not 1.
        image_batch = np.squeeze(image_batch, axis=3)
    if image_batch.ndim != 3:
        raise ValueError(
            "expected a batch of shape (N, h, w) or (N, h, w, 1), got %s"
            % (image_batch.shape,)
        )
    n_images, h, w = image_batch.shape
    if n_images > rows * cols:
        raise ValueError(
            "%d images do not fit in a %dx%d grid" % (n_images, rows, cols)
        )

    grid = np.zeros((rows * h, cols * w))
    for idx, im in enumerate(image_batch):
        r, c = divmod(idx, cols)
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = im
    return grid

In [4]:
def _train_vae(vae, train_dataset, epochs):
    # Run `epochs` passes of vae.train_step over train_dataset, saving one sample grid per epoch;
    # returns {loss_name: [per-epoch mean loss, ...]}.
    losses_across_epochs = {
        "loss": [],
        "reconstruction_loss": [],
        "kl_loss": [],
    }
    batch_num = len(train_dataset)
    for i in range(epochs):
        for history in losses_across_epochs.values():
            history.append(0)
        for data in train_dataset:
            cur_loss = vae.train_step(data)
            # Every batch weighs equally in the epoch mean, including a smaller last batch.
            for k, v in cur_loss.items():
                losses_across_epochs[k][-1] += v.numpy() / batch_num
        # Only the last batch's sample is plotted, so sample once per epoch
        # instead of once per batch.
        generated_image = vae.sample(data)
        print(f"Epoch {i} Total loss: { losses_across_epochs['loss'][-1]}")
        # The 8x8 grid assumes batches of at most 64 samples.
        im_merged = merge_images(generated_image.numpy(), [8, 8])
        # NOTE(review): this path does not include model_name, so each backbone
        # trained in the same session overwrites the previous one's images.
        plt.imsave('./images/vae_adapted/%d.png' % i, im_merged, cmap='gray')
    return losses_across_epochs


def _evaluate_vae(vae, test_dataset):
    # Return the mean per-batch reconstruction BCE on test_dataset, saving each batch's sample grid.
    test_loss = 0
    for i, data in enumerate(test_dataset):
        _, _, mask_batch = data
        generated_image = vae.sample(data)
        # BCE reduces the last axis; summing axes 1 and 2 gives one value per
        # sample (assumes images are (batch, h, w, 1) — TODO confirm).
        reconstruction_loss = tf.reduce_sum(
            tf.keras.losses.binary_crossentropy(mask_batch, generated_image), [1, 2]
        )
        test_loss += tf.reduce_mean(reconstruction_loss).numpy()
        im_merged = merge_images(generated_image.numpy(), [8, 8])
        # NOTE(review): like the training images, these are overwritten per model.
        plt.imsave('./images/vae_adapted/test_batch_%d.png' % i, im_merged, cmap='gray')
    return test_loss / len(test_dataset)


def train_round(train_dataset, test_dataset, learning_rate, model_name, epochs):
    """Train a fresh VAEAdapted on one backbone's features, then evaluate it.

    Per-epoch mean training losses are saved as
    ``./results/vae_adapted/<model_name>_<loss_name>.npy``. The mean test
    reconstruction loss is saved as
    ``./results/vae_adapted/<model_name>_test_loss.npy``. Sample image grids
    are written under ``./images/vae_adapted/``.

    Args:
        train_dataset: Non-empty list of ``(img, depth, mask)`` batch tuples.
        test_dataset: Non-empty list of ``(img, depth, mask)`` batch tuples.
        learning_rate: Adam learning rate.
        model_name: Backbone name; selects the VAE latent size.
        epochs: Number of passes over ``train_dataset``.

    Returns:
        The trained VAEAdapted model.

    Raises:
        KeyError: if ``model_name`` is not a known backbone.
        ValueError: if a dataset needed for the run is empty.
    """
    latent_lookup = {
        'inception': 2048,
        'vgg': 512,
        'efficientnet': 1280,
        'mobilenet': 1280,
        'resnet': 2048,
    }
    try:
        latent_dim = latent_lookup[model_name]
    except KeyError:
        raise KeyError(
            "unknown model_name %r; expected one of %s" % (model_name, sorted(latent_lookup))
        ) from None
    if epochs > 0 and len(train_dataset) == 0:
        raise ValueError("train_dataset is empty")
    if len(test_dataset) == 0:
        raise ValueError("test_dataset is empty")

    # Nothing else creates the output folders; imsave/np.save fail without them.
    os.makedirs('./images/vae_adapted', exist_ok=True)
    os.makedirs('./results/vae_adapted', exist_ok=True)

    vae = VAEAdapted(latent_dim)
    vae.compile(optimizer=Adam(learning_rate))

    # Training step
    losses_across_epochs = _train_vae(vae, train_dataset, epochs)
    for k, v in losses_across_epochs.items():
        np.save('./results/vae_adapted/%s_%s' % (model_name, k), np.array(v))

    # Testing step
    test_loss = _evaluate_vae(vae, test_dataset)
    np.save('./results/vae_adapted/%s_test_loss' % model_name, np.array([test_loss]))
    return vae

In [5]:
def _encode_in_batches(vae, feats, batch_size):
    # Run vae.encode over `feats` in chunks of `batch_size` and concatenate the first output along axis 0.
    chunks = []
    for start in range(0, len(feats), batch_size):
        chunk = feats[start:start + batch_size]
        # The noise tensor has the same shape as the input chunk.
        activation, _, _ = vae.encode(chunk, tf.random.normal(chunk.shape))
        chunks.append(activation)
    # One concat at the end avoids the quadratic cost of growing a tensor in the loop.
    return tf.concat(chunks, axis=0)


def get_encoding_for_model(vae, model_name, batch_size=128):
    """Encode CIFAR-100 train/test features with ``vae.encode`` and save the results.

    Reads ``./data/CIFAR100_<model_name>_<split>_feat.npy`` for split in
    ``('train', 'test')``. Writes the first output of ``vae.encode`` for every
    row to ``./data/CIFAR100_vae_adapted_<model_name>_encoding_<split>.npy``.

    Args:
        vae: A trained model exposing ``encode(x, noise)`` that returns three
            values.
        model_name: Backbone name embedded in the feature file names.
        batch_size: Number of rows passed to ``vae.encode`` per call.

    Returns:
        ``(train_encoding, test_encoding)`` as NumPy arrays (previously the
        function returned ``None``; callers ignoring the result are unaffected).

    Raises:
        ValueError: if a feature file contains no rows.
    """
    encodings = {}
    for split in ('train', 'test'):
        feats = np.load('./data/CIFAR100_%s_%s_feat.npy' % (model_name, split))
        if len(feats) == 0:
            raise ValueError("no %s features to encode for model %r" % (split, model_name))
        encoded = np.array(_encode_in_batches(vae, feats, batch_size))
        np.save('./data/CIFAR100_vae_adapted_%s_encoding_%s.npy' % (model_name, split), encoded)
        encodings[split] = encoded
    return encodings['train'], encodings['test']


In [6]:
%%time

# Hyper-parameters shared by every backbone.
learning_rate = 1e-4
epochs = 100
# Seed used for every backbone. Re-seeding NumPy's global RNG before each model
# gives all backbones the same sample shuffle and train/test batch split
# (np.random.permutation and sklearn's train_test_split with random_state=None
# both draw from it). This makes their losses directly comparable.
# tf.random.set_seed covers TensorFlow random ops such as the tf.random.normal
# noise used when encoding. GPU kernels may still add small nondeterminism.
SEED = 42
for model_name in ['mobilenet', 'efficientnet', 'inception', 'resnet', 'vgg']:
    np.random.seed(SEED)
    tf.random.set_seed(SEED)
    train_dataset, test_dataset = load_data_for_model(model_name)
    trained_model = train_round(train_dataset, test_dataset, learning_rate, model_name, epochs)
    print(f" {model_name} Gen encoding... ")
    get_encoding_for_model(trained_model, model_name)

Train dataset contains 9 batches of 64 samples each
Test dataset contains 5 batches of 64 samples each
Epoch 0 Total loss: 2633721.861111111
Epoch 1 Total loss: 2633065.944444445
Epoch 2 Total loss: 2609566.1666666665
Epoch 3 Total loss: 2525977.916666667
Epoch 4 Total loss: 2416643.388888889
Epoch 5 Total loss: 2312822.8333333335
Epoch 6 Total loss: 2225860.583333333
Epoch 7 Total loss: 2155152.0555555555
Epoch 8 Total loss: 2097084.5833333335
Epoch 9 Total loss: 2048095.6944444445
Epoch 10 Total loss: 2006120.4444444445
Epoch 11 Total loss: 1970055.4027777778
Epoch 12 Total loss: 1938567.7222222225
Epoch 13 Total loss: 1910380.125
Epoch 14 Total loss: 1885724.4166666667
Epoch 15 Total loss: 1863045.0416666665
Epoch 16 Total loss: 1840966.9722222222
Epoch 17 Total loss: 1820218.9444444445
Epoch 18 Total loss: 1798372.3888888888
Epoch 19 Total loss: 1777743.0694444445
Epoch 20 Total loss: 1757126.3333333333
Epoch 21 Total loss: 1735857.4166666665
Epoch 22 Total loss: 1720367.125
Epoch 

Epoch 0 Total loss: 2920473.083333333
Epoch 1 Total loss: 2905832.5833333335
Epoch 2 Total loss: 2860749.861111112
Epoch 3 Total loss: 2758303.888888889
Epoch 4 Total loss: 2635412.861111111
Epoch 5 Total loss: 2522940.583333334
Epoch 6 Total loss: 2430036.722222222
Epoch 7 Total loss: 2354925.333333333
Epoch 8 Total loss: 2293664.277777778
Epoch 9 Total loss: 2243189.555555556
Epoch 10 Total loss: 2201238.8333333335
Epoch 11 Total loss: 2165644.8055555555
Epoch 12 Total loss: 2135420.0555555555
Epoch 13 Total loss: 2109193.611111111
Epoch 14 Total loss: 2086165.6666666665
Epoch 15 Total loss: 2065987.8333333335
Epoch 16 Total loss: 2048170.1805555555
Epoch 17 Total loss: 2032212.6666666665
Epoch 18 Total loss: 2017924.9722222222
Epoch 19 Total loss: 2004957.9305555553
Epoch 20 Total loss: 1993137.8611111112
Epoch 21 Total loss: 1982449.4861111115
Epoch 22 Total loss: 1972570.6805555555
Epoch 23 Total loss: 1963318.125
Epoch 24 Total loss: 1954890.972222222
Epoch 25 Total loss: 1947140

Epoch 3 Total loss: 13423144618.666666
Epoch 4 Total loss: 10456446520.88889
Epoch 5 Total loss: 8567547278.222221
Epoch 6 Total loss: 7258502371.555556
Epoch 7 Total loss: 6297508750.222222
Epoch 8 Total loss: 5561894798.222222
Epoch 9 Total loss: 4980616305.777778
Epoch 10 Total loss: 4509669148.444445
Epoch 11 Total loss: 4120333937.7777786
Epoch 12 Total loss: 3793067719.1111116
Epoch 13 Total loss: 3514107050.666667
Epoch 14 Total loss: 3273479196.444444
Epoch 15 Total loss: 3063779953.7777777
Epoch 16 Total loss: 2879399367.1111116
Epoch 17 Total loss: 2716010467.5555553
Epoch 18 Total loss: 2570218097.7777777
Epoch 19 Total loss: 2439323989.333333
Epoch 20 Total loss: 2321153934.2222223
Epoch 21 Total loss: 2213937066.666667
Epoch 22 Total loss: 2116216803.5555556
Epoch 23 Total loss: 2026783843.5555553
Epoch 24 Total loss: 1944625720.8888888
Epoch 25 Total loss: 1868889230.2222223
Epoch 26 Total loss: 1798849294.2222223
Epoch 27 Total loss: 1733885824.0000002
Epoch 28 Total los