In [1]:
import matplotlib.pyplot as plt
import numpy as np
import PIL
import os
import tensorflow as tf
# Turn on memory growth for every GPU TensorFlow can see. This has to run
# before any op touches the GPUs, which is why it sits between the imports.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
  tf.config.experimental.set_memory_growth(gpu, True)

from tensorflow.keras.optimizers import Adam
import math
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from vae_vanilla import VAE

In [2]:
def load_data_for_model(batch_size=64):
    """Load the LFSD arrays, shuffle them into batches and split the batches.

    The label array is read only to learn how many samples exist. Sample
    indices are shuffled once, cut into consecutive chunks of ``batch_size``
    (the final chunk may be shorter), and every chunk becomes one
    ``(img_batch, depth_batch, mask_batch)`` tuple. The list of tuples is then
    divided 70/30 into train and test batches.

    Args:
        batch_size: number of samples per batch.

    Returns:
        ``(train_dataset, test_dataset)``, each a list of batch tuples.
    """
    sample_count = len(np.load('./data/LFSD_labels.npy'))
    depths = np.load('./data/LFSD_depths_repeated.npy')
    imgs = np.load('./data/LFSD_imgs.npy')
    masks = np.load('./data/LFSD_masks_single.npy')

    shuffled = np.random.permutation(sample_count)
    chunk_starts = range(0, sample_count, batch_size)
    index_chunks = (shuffled[start:start + batch_size] for start in chunk_starts)
    batches = [(imgs[chunk], depths[chunk], masks[chunk]) for chunk in index_chunks]

    # The split is made over whole batches, not over individual samples.
    train_batches, test_batches = train_test_split(batches, test_size=0.3)
    print("Train dataset contains %d batches of %d samples each" % (len(train_batches), batch_size))
    print("Test dataset contains %d batches of %d samples each" % (len(test_batches), batch_size))
    return train_batches, test_batches

In [3]:
def merge_images(image_batch, size):
    """Tile a batch of single-channel images into one 2-D mosaic.

    Args:
        image_batch: array indexed as (image, row, column, channel). Each
            image is squeezed on its last axis, so that axis must have length 1.
        size: ``(grid_rows, grid_cols)`` of the mosaic.

    Returns:
        A 2-D float array of shape ``(h * grid_rows, w * grid_cols)``. Images
        are placed left-to-right, then top-to-bottom; unused cells stay zero.
    """
    tile_h = image_batch.shape[1]
    tile_w = image_batch.shape[2]
    n_cols = size[1]
    canvas = np.zeros((int(tile_h * size[0]), tile_w * n_cols))
    for position, tile in enumerate(image_batch):
        grid_row, grid_col = divmod(position, n_cols)
        top = grid_row * tile_h
        left = grid_col * tile_w
        canvas[top:top + tile_h, left:left + tile_w] = np.squeeze(tile, axis=2)
    return canvas

In [4]:
def train_round(train_dataset, test_dataset, learning_rate, model_name, epochs):
    """Train a VAE for one backbone, save its loss curves and images, and evaluate it.

    Args:
        train_dataset: sized sequence of batches; each batch is passed as-is
            to ``vae.train_step`` and ``vae.sample``.
        test_dataset: sized sequence of ``(img_batch, depth_batch, mask_batch)``
            tuples.
        learning_rate: learning rate handed to the Adam optimizer.
        model_name: backbone name; selects the latent size from the table below.
        epochs: number of passes over ``train_dataset``.

    Returns:
        The trained ``VAE`` instance.

    Raises:
        KeyError: if ``model_name`` is not one of the known backbones.
        ValueError: if ``train_dataset`` or ``test_dataset`` is empty.
    """
    latent_lookup = {
        'inception': 2048,
        'efficientnet': 1280,
        'vgg': 512,
        'mobilenet': 1280,
        'resnet': 2048,
    }
    latent_dim = latent_lookup[model_name]

    # Fail before training starts instead of hitting a NameError /
    # ZeroDivisionError after the (long) training loop.
    if len(train_dataset) == 0:
        raise ValueError("train_dataset must contain at least one batch")
    if len(test_dataset) == 0:
        raise ValueError("test_dataset must contain at least one batch")

    # plt.imsave / np.save do not create missing directories.
    image_dir = './images/vae_vanilla'
    result_dir = './results/vae_vanilla'
    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(result_dir, exist_ok=True)

    vae = VAE(latent_dim)
    vae.compile(optimizer=Adam(learning_rate))

    # Train step
    # One running mean per epoch for every tracked loss. train_step is
    # expected to return a dict of tensors keyed by these names.
    losses_across_epochs = {
        "loss": [],
        "reconstruction_loss": [],
        "kl_loss": [],
    }
    batch_num = len(train_dataset)
    for epoch in range(epochs):
        for history in losses_across_epochs.values():
            history.append(0)
        for data in train_dataset:
            cur_loss = vae.train_step(data)
            for loss_name, loss_value in cur_loss.items():
                losses_across_epochs[loss_name][-1] += loss_value.numpy() / batch_num
        # Only the sample of the last batch is saved, so draw it once per
        # epoch instead of after every training step.
        generated_image = vae.sample(data)
        print(f"Epoch {epoch} Total loss: { losses_across_epochs['loss'][-1]}")
        # The 8x8 grid holds up to 64 images, i.e. one default-sized batch.
        im_merged = merge_images(generated_image.numpy(), [8,8])
        plt.imsave('%s/%d.png' % (image_dir, epoch), im_merged, cmap='gray')
    for loss_name, history in losses_across_epochs.items():
        np.save('%s/%s_%s' % (result_dir, model_name, loss_name), np.array(history))

    # Testing step
    # Mean (over batches) of the per-batch mean reconstruction loss between
    # the ground-truth masks and the generated images.
    test_loss = 0
    for batch_no, data in enumerate(test_dataset):
        _, _, mask_batch = data
        generated_image = vae.sample(data)
        reconstruction_loss = tf.reduce_sum(
            tf.keras.losses.binary_crossentropy(mask_batch, generated_image), [1,2]
        )
        test_loss += tf.reduce_mean(reconstruction_loss).numpy()
        im_merged = merge_images(generated_image.numpy(), [8,8])
        plt.imsave('%s/test_batch_%d.png' % (image_dir, batch_no), im_merged, cmap='gray')

    test_loss = test_loss / len(test_dataset)
    np.save('%s/%s_test_loss' % (result_dir, model_name), np.array([test_loss]))
    return vae

In [5]:
def _encode_in_batches(vae, images, batch_size=128):
    """Resize ``images`` to 256x256 and encode them with ``vae``, batch by batch."""
    activations = []
    for start in tqdm(range(0, len(images), batch_size)):
        # Resize one batch at a time (on the CPU, as before) so the upscaled
        # copy of the whole dataset never has to fit in memory at once.
        with tf.device('/cpu:0'):
            batch = tf.image.resize(images[start:start + batch_size], (256, 256)).numpy()
        activation, _, _ = vae.encode(batch, None, rand_depth=True)
        activations.append(activation)
    # A single concat at the end avoids re-copying the growing result on
    # every iteration.
    return tf.concat(activations, axis=0)


def get_encoding_for_model(vae, model_name):
    """Encode the CIFAR-100 train and test images with ``vae`` and save the results.

    Images are scaled to [0, 1], resized to 256x256 and passed through
    ``vae.encode(batch, None, rand_depth=True)`` in batches of 128. The first
    value returned by ``encode`` is kept for every batch.

    Args:
        vae: trained model exposing ``encode``.
        model_name: backbone name, used only in the output file names.

    Writes:
        ``./data/CIFAR100_vae_vanilla_<model_name>_encoding_train.npy`` and
        ``./data/CIFAR100_vae_vanilla_<model_name>_encoding_test.npy``.
    """
    from tensorflow.keras import datasets

    (train_images, _), (test_images, _) = datasets.cifar100.load_data()
    train_images = train_images / 255.0
    test_images = test_images / 255.0

    train_result = _encode_in_batches(vae, train_images)
    np.save('./data/CIFAR100_vae_vanilla_%s_encoding_train.npy' % model_name, train_result.numpy())

    # The test encoding is built from the test images and accumulated into
    # its own result. Previously the loop read train_images and wrote to
    # train_result, so only the first 128 test encodings were ever saved.
    test_result = _encode_in_batches(vae, test_images)
    np.save('./data/CIFAR100_vae_vanilla_%s_encoding_test.npy' % model_name, test_result.numpy())

In [7]:
%%time
# Hyperparameters shared by every backbone in the sweep below.
learning_rate = 1e-4
epochs = 100
# For each backbone name: load and split the data, train a VAE with the
# matching latent size, then export CIFAR-100 encodings from the trained model.
for model_name in ['efficientnet','inception', 'resnet', 'mobilenet', 'vgg']:
    # NOTE(review): the data is reloaded, reshuffled and re-split on every
    # iteration and no random seed is set in the visible code, so each
    # backbone is trained and tested on a different random split.
    train_dataset, test_dataset = load_data_for_model()
    trained_model = train_round(train_dataset, test_dataset, learning_rate, model_name, epochs)
    print("Gen encoding...")
    get_encoding_for_model(trained_model, model_name)

Train dataset contains 9 batches of 64 samples each
Test dataset contains 5 batches of 64 samples each
Epoch 0 Total loss: 2908242.444444444
Epoch 1 Total loss: 2904333.444444445
Epoch 2 Total loss: 2880892.8333333335
Epoch 3 Total loss: 2814354.833333333
Epoch 4 Total loss: 2713130.5277777775
Epoch 5 Total loss: 2587788.6111111115
Epoch 6 Total loss: 2447528.055555556
Epoch 7 Total loss: 2308516.638888889
Epoch 8 Total loss: 2182060.6666666665
Epoch 9 Total loss: 2067482.4166666665
Epoch 10 Total loss: 1963549.944444444
Epoch 11 Total loss: 1868865.5694444445
Epoch 12 Total loss: 1782617.4999999998
Epoch 13 Total loss: 1703761.6944444443
Epoch 14 Total loss: 1631415.2638888888
Epoch 15 Total loss: 1565192.361111111
Epoch 16 Total loss: 1504388.9305555555
Epoch 17 Total loss: 1448474.5972222222
Epoch 18 Total loss: 1396794.8333333335
Epoch 19 Total loss: 1348848.5000000002
Epoch 20 Total loss: 1304495.388888889
Epoch 21 Total loss: 1263370.4583333335
Epoch 22 Total loss: 1225446.555555

100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [00:09<00:00, 42.65it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:01<00:00, 50.46it/s]


Train dataset contains 9 batches of 64 samples each
Test dataset contains 5 batches of 64 samples each
Epoch 0 Total loss: 2881126.1944444445
Epoch 1 Total loss: 2709657.8055555555
Epoch 2 Total loss: 2616610.361111111
Epoch 3 Total loss: 2487864.2777777775
Epoch 4 Total loss: 2347294.361111111
Epoch 5 Total loss: 2203783.638888889
Epoch 6 Total loss: 2065879.1111111112
Epoch 7 Total loss: 1939857.4861111112
Epoch 8 Total loss: 1828539.4027777775
Epoch 9 Total loss: 1729399.7916666665
Epoch 10 Total loss: 1638541.8194444445
Epoch 11 Total loss: 1555953.3888888888
Epoch 12 Total loss: 1480860.3194444443
Epoch 13 Total loss: 1412535.2083333333
Epoch 14 Total loss: 1350362.5833333335
Epoch 15 Total loss: 1293820.375
Epoch 16 Total loss: 1242230.5833333333
Epoch 17 Total loss: 1195516.9444444443
Epoch 18 Total loss: 1152755.8611111112
Epoch 19 Total loss: 1113450.861111111
Epoch 20 Total loss: 1077072.9722222222
Epoch 21 Total loss: 1043414.1875
Epoch 22 Total loss: 1012230.9583333333
Epoc

100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [00:07<00:00, 55.26it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:01<00:00, 60.14it/s]


Train dataset contains 9 batches of 64 samples each
Test dataset contains 5 batches of 64 samples each
Epoch 0 Total loss: 2849275.222222222
Epoch 1 Total loss: 2702035.388888889
Epoch 2 Total loss: 2643367.277777778
Epoch 3 Total loss: 2551540.6111111115
Epoch 4 Total loss: 2434087.583333333
Epoch 5 Total loss: 2301143.5
Epoch 6 Total loss: 2162925.0694444445
Epoch 7 Total loss: 2030672.2083333333
Epoch 8 Total loss: 1911214.708333333
Epoch 9 Total loss: 1802713.0
Epoch 10 Total loss: 1703699.5694444445
Epoch 11 Total loss: 1614014.5694444447
Epoch 12 Total loss: 1532890.5972222225
Epoch 13 Total loss: 1459589.5416666665
Epoch 14 Total loss: 1393745.2222222222
Epoch 15 Total loss: 1334093.375
Epoch 16 Total loss: 1280136.0138888888
Epoch 17 Total loss: 1231018.027777778
Epoch 18 Total loss: 1185684.125
Epoch 19 Total loss: 1143983.3055555555
Epoch 20 Total loss: 1105538.9722222222
Epoch 21 Total loss: 1070003.9305555557
Epoch 22 Total loss: 1037170.4791666667
Epoch 23 Total loss: 1006

100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [00:06<00:00, 56.50it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:01<00:00, 56.85it/s]


Train dataset contains 9 batches of 64 samples each
Test dataset contains 5 batches of 64 samples each
Epoch 0 Total loss: 2908789.9444444445
Epoch 1 Total loss: 2902572.6388888895
Epoch 2 Total loss: 2862036.0555555555
Epoch 3 Total loss: 2754722.25
Epoch 4 Total loss: 2614886.0277777775
Epoch 5 Total loss: 2469485.805555556
Epoch 6 Total loss: 2325638.972222222
Epoch 7 Total loss: 2189006.0277777775
Epoch 8 Total loss: 2065064.6388888888
Epoch 9 Total loss: 1952100.7777777778
Epoch 10 Total loss: 1849864.5277777778
Epoch 11 Total loss: 1757247.2777777778
Epoch 12 Total loss: 1673148.2083333333
Epoch 13 Total loss: 1596659.8194444445
Epoch 14 Total loss: 1526956.375
Epoch 15 Total loss: 1463066.513888889
Epoch 16 Total loss: 1404564.5138888888
Epoch 17 Total loss: 1350951.6805555555
Epoch 18 Total loss: 1301594.2083333333
Epoch 19 Total loss: 1256026.8055555555
Epoch 20 Total loss: 1213935.5555555555
Epoch 21 Total loss: 1174988.5138888888
Epoch 22 Total loss: 1138869.9444444445
Epoch

100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [00:08<00:00, 44.79it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:01<00:00, 50.17it/s]


Train dataset contains 9 batches of 64 samples each
Test dataset contains 5 batches of 64 samples each
Epoch 0 Total loss: 2178800.986111111
Epoch 1 Total loss: 2558555.611111111
Epoch 2 Total loss: 2558074.9166666665
Epoch 3 Total loss: 2478587.277777778
Epoch 4 Total loss: 2368469.416666667
Epoch 5 Total loss: 2257922.9444444445
Epoch 6 Total loss: 2154644.0
Epoch 7 Total loss: 2057426.6944444445
Epoch 8 Total loss: 1965483.6249999998
Epoch 9 Total loss: 1879958.2777777778
Epoch 10 Total loss: 1800970.6111111108
Epoch 11 Total loss: 1727764.611111111
Epoch 12 Total loss: 1659684.7916666667
Epoch 13 Total loss: 1596062.986111111
Epoch 14 Total loss: 1536235.6249999998
Epoch 15 Total loss: 1479838.4861111112
Epoch 16 Total loss: 1426575.402777778
Epoch 17 Total loss: 1376353.972222222
Epoch 18 Total loss: 1329338.8194444445
Epoch 19 Total loss: 1285287.1666666665
Epoch 20 Total loss: 1244403.0972222222
Epoch 21 Total loss: 1206326.4166666667
Epoch 22 Total loss: 1170670.4166666667
Epoc

100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [00:07<00:00, 53.07it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:01<00:00, 65.94it/s]


Wall time: 16min 21s
