In [1]:
import matplotlib.pyplot as plt
import numpy as np
import PIL
import os

import tensorflow as tf
print(f"gpu: { len(tf.config.list_physical_devices('GPU')) }")
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True)

from skimage.transform import resize
from tensorflow.keras import datasets


from tensorflow.keras.optimizers import Adam
import math
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from vae_injected import VAEInjected

gpu: 1


In [2]:
def load_data_for_model(model_name, batch_size=64, test_size=0.3, seed=None, data_dir='./data'):
    """Load the LFSD arrays for one backbone and split them into train/test batches.

    Sample indices are shuffled, chunked into batches of ``batch_size`` (the
    final batch may be smaller), and the resulting *batches* are split into
    train and test sets with ``train_test_split``.

    Args:
        model_name: backbone name used in the pre-computed feature filenames,
            e.g. ``'vgg'`` -> ``LFSD_imgs_vgg_feat.npy``.
        batch_size: number of samples per batch.
        test_size: fraction of batches assigned to the test split.
        seed: optional integer seed making both the shuffle and the split
            reproducible. ``None`` keeps the previous behaviour of drawing from
            NumPy's global random state.
        data_dir: directory containing the ``LFSD_*.npy`` files.

    Returns:
        ``(train_dataset, test_dataset)``: lists of tuples
        ``(img_batch, img_feat_batch, depth_batch, depth_feat_batch, mask_batch)``.

    Raises:
        ValueError: if the loaded arrays do not all contain the same number of samples.
    """
    def load(name):
        return np.load(os.path.join(data_dir, name))

    labels = load('LFSD_labels.npy')
    depths = load('LFSD_depths_repeated.npy')
    imgs = load('LFSD_imgs.npy')
    masks = load('LFSD_masks_single.npy')
    depths_feat = load('LFSD_depths_repeated_%s_feat.npy' % model_name)
    imgs_feat = load('LFSD_imgs_%s_feat.npy' % model_name)

    # Labels are only used for the sample count; every other array must agree
    # with it, otherwise batches would silently mix unrelated samples.
    n_samples = len(labels)
    arrays = {
        'depths': depths,
        'imgs': imgs,
        'masks': masks,
        'depths_feat': depths_feat,
        'imgs_feat': imgs_feat,
    }
    mismatched = {name: len(arr) for name, arr in arrays.items() if len(arr) != n_samples}
    if mismatched:
        raise ValueError(
            'expected %d samples (from LFSD_labels.npy) but got %s' % (n_samples, mismatched)
        )

    rng = np.random if seed is None else np.random.RandomState(seed)
    order = rng.permutation(n_samples)
    dataset = []
    for start in range(0, n_samples, batch_size):
        batch = order[start:start + batch_size]
        dataset.append((imgs[batch], imgs_feat[batch], depths[batch], depths_feat[batch], masks[batch]))

    # Split whole batches (not samples); random_state=None falls back to the global state.
    train_dataset, test_dataset = train_test_split(dataset, test_size=test_size, random_state=seed)

    def count_samples(split):
        return sum(len(batch_tuple[0]) for batch_tuple in split)

    print("Train dataset contains %d batches (%d samples, batch size %d)"
          % (len(train_dataset), count_samples(train_dataset), batch_size))
    print("Test dataset contains %d batches (%d samples, batch size %d)"
          % (len(test_dataset), count_samples(test_dataset), batch_size))
    return train_dataset, test_dataset

In [3]:
def merge_images(image_batch, size):
    """Tile a batch of single-channel images into one 2-D grid image, row-major.

    Args:
        image_batch: array-like of shape ``(N, H, W, 1)`` or ``(N, H, W)``.
        size: ``(rows, cols)`` of the grid. Image ``k`` is placed at row
            ``k // cols``, column ``k % cols``; unused cells stay 0.

    Returns:
        A float64 array of shape ``(rows * H, cols * W)``.

    Raises:
        ValueError: if the images are not single-channel, or if ``N`` exceeds
            ``rows * cols``.
    """
    images = np.asarray(image_batch)
    if images.ndim == 4:
        if images.shape[3] != 1:
            raise ValueError('expected single-channel images, got shape %s' % (images.shape,))
        images = images[..., 0]
    elif images.ndim != 3:
        raise ValueError('expected a (N, H, W[, 1]) batch, got shape %s' % (images.shape,))

    rows, cols = int(size[0]), int(size[1])
    n_images, h, w = images.shape
    if n_images > rows * cols:
        raise ValueError('%d images do not fit in a %dx%d grid' % (n_images, rows, cols))

    grid = np.zeros((rows * h, cols * w))
    for k, im in enumerate(images):
        r, c = divmod(k, cols)
        grid[r * h:(r + 1) * h, c * w:(c + 1) * w] = im
    return grid

In [4]:
def _tile_grid_size(n_images):
    """Return a square ``[side, side]`` grid that fits ``n_images`` tiles (side >= 8)."""
    # Keeps the 8x8 layout for batches of up to 64 images while no longer
    # overflowing the grid when a larger batch size is used.
    side = max(8, math.ceil(math.sqrt(n_images)))
    return [side, side]


def train_round(train_dataset, test_dataset, learning_rate, model_name, epochs,
                images_dir='./images/vae_injected', results_dir='./results/vae_injected'):
    """Train a fresh ``VAEInjected`` for one backbone, record its losses and evaluate it.

    Args:
        train_dataset: non-empty list of 5-tuples
            ``(img, img_feat, depth, depth_feat, mask)`` batches, as returned
            by ``load_data_for_model``.
        test_dataset: non-empty list of batches in the same format; element 4
            (the mask batch) is the reconstruction target for the test loss.
        learning_rate: Adam learning rate.
        model_name: backbone name; selects the latent size and prefixes every
            output file so different backbones do not overwrite each other.
        epochs: number of passes over ``train_dataset``.
        images_dir: directory for per-epoch and per-test-batch sample grids.
        results_dir: directory for the ``.npy`` loss histories and test loss.

    Returns:
        The trained ``VAEInjected`` instance.

    Raises:
        KeyError: if ``model_name`` has no known latent size.
        ValueError: if either dataset is empty.
    """
    latent_lookup = {
        'inception': 2048,
        'efficientnet': 1280,
        'vgg': 512,
        'mobilenet': 1280,
        'resnet': 2048,
    }
    if model_name not in latent_lookup:
        raise KeyError('unknown model_name %r; expected one of %s' % (model_name, sorted(latent_lookup)))
    if not train_dataset:
        raise ValueError('train_dataset is empty')
    if not test_dataset:
        raise ValueError('test_dataset is empty')
    os.makedirs(images_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)

    vae = VAEInjected(latent_lookup[model_name])
    vae.compile(optimizer=Adam(learning_rate))

    # Training step: each history entry is the mean of that loss over the epoch's batches.
    losses_across_epochs = {
        "loss": [],
        "reconstruction_loss": [],
        "kl_loss": [],
    }
    n_batches = len(train_dataset)
    for epoch in range(epochs):
        for history in losses_across_epochs.values():
            history.append(0)
        for data in train_dataset:
            step_losses = vae.train_step(data)
            for key, value in step_losses.items():
                losses_across_epochs[key][-1] += value.numpy() / n_batches
        # Only the last batch's sample is saved, so decode once per epoch
        # instead of after every training step.
        generated_image = vae.sample(data)
        print(f"Epoch {epoch} Total loss: { losses_across_epochs['loss'][-1]}")
        samples = generated_image.numpy()
        plt.imsave(
            os.path.join(images_dir, '%s_%d.png' % (model_name, epoch)),
            merge_images(samples, _tile_grid_size(len(samples))),
            cmap='gray',
        )
    for key, history in losses_across_epochs.items():
        np.save(os.path.join(results_dir, '%s_%s' % (model_name, key)), np.array(history))

    # Testing step: mean over test batches of the per-image summed BCE against the masks.
    test_loss = 0
    for batch_i, data in enumerate(test_dataset):
        _, _, _, _, mask_batch = data
        generated_image = vae.sample(data)
        # Keras averages BCE over the last axis; summing axes 1 and 2 then gives
        # one value per image (assumes (N, H, W, C) outputs — TODO confirm).
        reconstruction_loss = tf.reduce_sum(
            tf.keras.losses.binary_crossentropy(mask_batch, generated_image), [1, 2]
        )
        test_loss += tf.reduce_mean(reconstruction_loss).numpy()
        samples = generated_image.numpy()
        plt.imsave(
            os.path.join(images_dir, '%s_test_batch_%d.png' % (model_name, batch_i)),
            merge_images(samples, _tile_grid_size(len(samples))),
            cmap='gray',
        )

    test_loss = test_loss / len(test_dataset)
    np.save(os.path.join(results_dir, '%s_test_loss' % model_name), np.array([test_loss]))
    return vae

In [5]:
def _encode_in_batches(vae, images, feats, batch_size, desc=None):
    """Encode raw uint8-range ``images`` with matching ``feats`` chunk by chunk and stack the results.

    Each chunk is scaled to [0, 1] and resized to 256x256 on the CPU just
    before encoding, so only one resized chunk is held in memory at a time.
    The first output of ``vae.encode`` is kept; the other two are discarded.
    """
    if len(images) != len(feats):
        raise ValueError('got %d images but %d feature rows' % (len(images), len(feats)))
    chunks = []
    for start in tqdm(range(0, len(images), batch_size), desc=desc):
        stop = start + batch_size
        with tf.device('/cpu:0'):
            img = tf.image.resize(images[start:stop] / 255.0, (256, 256)).numpy()
        # Depth, depth features and mask are passed as None; rand_depth=True
        # presumably substitutes a random depth input — TODO confirm in VAEInjected.encode.
        activation, _, _ = vae.encode((img, feats[start:stop], None, None, None), rand_depth=True)
        chunks.append(np.asarray(activation))
    if not chunks:
        raise ValueError('no images to encode')
    # One concatenation at the end instead of re-copying the growing result every step.
    return np.concatenate(chunks, axis=0)


def get_encoding_for_model(vae, model_name, batch_size=128, data_dir='./data'):
    """Encode the CIFAR-100 train and test splits with ``vae`` and save the encodings.

    Reads ``CIFAR100_<model_name>_<split>_feat.npy`` from ``data_dir`` and writes
    ``CIFAR100_vae_injected_<model_name>_encoding_<split>.npy`` next to it,
    for ``split`` in ``train``/``test``.

    Args:
        vae: trained model exposing ``encode(batch_tuple, rand_depth=...)``.
        model_name: backbone name used in the feature and output filenames.
        batch_size: number of images encoded per call.
        data_dir: directory holding the feature files and receiving the outputs.

    Raises:
        ValueError: if a feature file's length does not match its image split.
    """
    (train_images, _), (test_images, _) = datasets.cifar100.load_data()
    for split, images in (('train', train_images), ('test', test_images)):
        feats = np.load(os.path.join(data_dir, 'CIFAR100_%s_%s_feat.npy' % (model_name, split)))
        encoding = _encode_in_batches(vae, images, feats, batch_size, desc=split)
        np.save(
            os.path.join(data_dir, 'CIFAR100_vae_injected_%s_encoding_%s.npy' % (model_name, split)),
            encoding,
        )

In [6]:
%%time
learning_rate = 1e-4
epochs = 100
seed = 0  # fixes the batch shuffle, train/test split and TF sampling so reruns are comparable

# load_data_for_model draws from NumPy's global random state; VAE sampling uses TF's.
np.random.seed(seed)
tf.random.set_seed(seed)

for model_name in ['mobilenet', 'efficientnet', 'inception', 'resnet', 'vgg']:
    train_dataset, test_dataset = load_data_for_model(model_name)
    trained_model = train_round(train_dataset, test_dataset, learning_rate, model_name, epochs)
    # Release the LFSD batches before the CIFAR-100 encoding pass; reassigning
    # them inside get_encoding_for_model would only create locals.
    del train_dataset, test_dataset
    print(f" {model_name} Gen encoding... ")
    get_encoding_for_model(trained_model, model_name)

Train dataset contains 9 batches of 64 samples each
Test dataset contains 5 batches of 64 samples each
Epoch 0 Total loss: 2918101.0
Epoch 1 Total loss: 2908452.4444444445
Epoch 2 Total loss: 2885707.3333333335
Epoch 3 Total loss: 2805139.388888889
Epoch 4 Total loss: 2684355.8055555555
Epoch 5 Total loss: 2558966.0
Epoch 6 Total loss: 2445195.388888889
Epoch 7 Total loss: 2335355.638888889
Epoch 8 Total loss: 2228063.8055555555
Epoch 9 Total loss: 2124560.4444444445
Epoch 10 Total loss: 2025443.1666666665
Epoch 11 Total loss: 1931876.3611111115
Epoch 12 Total loss: 1845497.8194444445
Epoch 13 Total loss: 1766921.666666667
Epoch 14 Total loss: 1695316.8750000002
Epoch 15 Total loss: 1629795.7916666665
Epoch 16 Total loss: 1569392.4027777778
Epoch 17 Total loss: 1513202.138888889
Epoch 18 Total loss: 1461356.6111111112
Epoch 19 Total loss: 1413029.0138888888
Epoch 20 Total loss: 1367807.9583333333
Epoch 21 Total loss: 1325614.875
Epoch 22 Total loss: 1286084.6249999998
Epoch 23 Total lo

100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [00:08<00:00, 45.54it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:01<00:00, 40.73it/s]


Train dataset contains 9 batches of 64 samples each
Test dataset contains 5 batches of 64 samples each
Epoch 0 Total loss: 2868214.4166666665
Epoch 1 Total loss: 2710182.6111111115
Epoch 2 Total loss: 2681330.694444444
Epoch 3 Total loss: 2639985.8055555555
Epoch 4 Total loss: 2561795.0833333335
Epoch 5 Total loss: 2463370.277777778
Epoch 6 Total loss: 2365100.0555555555
Epoch 7 Total loss: 2272080.4444444445
Epoch 8 Total loss: 2180643.3888888885
Epoch 9 Total loss: 2090913.736111111
Epoch 10 Total loss: 2003916.1944444447
Epoch 11 Total loss: 1921397.2499999998
Epoch 12 Total loss: 1843086.4027777778
Epoch 13 Total loss: 1772001.4027777778
Epoch 14 Total loss: 1708900.861111111
Epoch 15 Total loss: 1649659.666666667
Epoch 16 Total loss: 1592867.2222222222
Epoch 17 Total loss: 1539536.1944444445
Epoch 18 Total loss: 1490126.986111111
Epoch 19 Total loss: 1443010.027777778
Epoch 20 Total loss: 1398757.4305555555
Epoch 21 Total loss: 1357136.4444444443
Epoch 22 Total loss: 1318104.36111

100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [00:07<00:00, 52.70it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:01<00:00, 61.44it/s]


Train dataset contains 9 batches of 64 samples each
Test dataset contains 5 batches of 64 samples each
Epoch 0 Total loss: 2911092.388888889
Epoch 1 Total loss: 2883826.7222222225
Epoch 2 Total loss: 2782636.4722222225
Epoch 3 Total loss: 2642597.0277777775
Epoch 4 Total loss: 2504549.3055555555
Epoch 5 Total loss: 2372569.1944444445
Epoch 6 Total loss: 2245844.222222222
Epoch 7 Total loss: 2121757.861111111
Epoch 8 Total loss: 2005185.6666666665
Epoch 9 Total loss: 1896105.611111111
Epoch 10 Total loss: 1797391.6666666667
Epoch 11 Total loss: 1710089.0416666665
Epoch 12 Total loss: 1633626.5972222222
Epoch 13 Total loss: 1564585.9166666667
Epoch 14 Total loss: 1502588.5972222222
Epoch 15 Total loss: 1445433.5416666665
Epoch 16 Total loss: 1392201.4583333333
Epoch 17 Total loss: 1343143.1527777778
Epoch 18 Total loss: 1298205.652777778
Epoch 19 Total loss: 1256928.0694444445
Epoch 20 Total loss: 1218908.513888889
Epoch 21 Total loss: 1182657.6944444445
Epoch 22 Total loss: 1148274.4722

100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [00:06<00:00, 57.73it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:01<00:00, 66.02it/s]


Train dataset contains 9 batches of 64 samples each
Test dataset contains 5 batches of 64 samples each
Epoch 0 Total loss: 2900884.194444444
Epoch 1 Total loss: 2724581.5
Epoch 2 Total loss: 2664438.027777778
Epoch 3 Total loss: 2581711.472222222
Epoch 4 Total loss: 2476153.111111111
Epoch 5 Total loss: 2358423.833333333
Epoch 6 Total loss: 2239801.9444444445
Epoch 7 Total loss: 2123851.6805555555
Epoch 8 Total loss: 2011996.3888888892
Epoch 9 Total loss: 1906332.5555555557
Epoch 10 Total loss: 1808465.8194444445
Epoch 11 Total loss: 1719791.6805555555
Epoch 12 Total loss: 1640715.7916666667
Epoch 13 Total loss: 1569707.138888889
Epoch 14 Total loss: 1505413.486111111
Epoch 15 Total loss: 1447182.0138888888
Epoch 16 Total loss: 1394902.5
Epoch 17 Total loss: 1349127.4722222222
Epoch 18 Total loss: 1305897.013888889
Epoch 19 Total loss: 1264384.1944444445
Epoch 20 Total loss: 1225013.736111111
Epoch 21 Total loss: 1187999.611111111
Epoch 22 Total loss: 1153279.9722222222
Epoch 23 Total 

100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [00:06<00:00, 58.25it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:01<00:00, 63.91it/s]


Train dataset contains 9 batches of 64 samples each
Test dataset contains 5 batches of 64 samples each
Epoch 0 Total loss: 2865435.3055555555
Epoch 1 Total loss: 2708445.0
Epoch 2 Total loss: 2669652.6666666665
Epoch 3 Total loss: 2593712.638888889
Epoch 4 Total loss: 2484932.9166666665
Epoch 5 Total loss: 2372940.0
Epoch 6 Total loss: 2272894.5833333335
Epoch 7 Total loss: 2183917.222222222
Epoch 8 Total loss: 2100720.9583333335
Epoch 9 Total loss: 2017666.6249999998
Epoch 10 Total loss: 1934743.3194444443
Epoch 11 Total loss: 1855231.7083333335
Epoch 12 Total loss: 1779585.3333333335
Epoch 13 Total loss: 1708178.75
Epoch 14 Total loss: 1641368.4722222222
Epoch 15 Total loss: 1579138.513888889
Epoch 16 Total loss: 1521089.7916666667
Epoch 17 Total loss: 1467215.888888889
Epoch 18 Total loss: 1417333.4305555553
Epoch 19 Total loss: 1371113.9861111112
Epoch 20 Total loss: 1327775.416666667
Epoch 21 Total loss: 1287641.013888889
Epoch 22 Total loss: 1249816.0833333333
Epoch 23 Total loss

100%|████████████████████████████████████████████████████████████████████████████████| 390/390 [00:05<00:00, 67.08it/s]
100%|██████████████████████████████████████████████████████████████████████████████████| 78/78 [00:01<00:00, 69.62it/s]


Wall time: 13min 49s
