In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
tf.enable_eager_execution()

import os
import time
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
import imageio
import skimage

from IPython import display
%matplotlib inline

In [None]:
# Short aliases for the Keras API and its layers module.
# NOTE(review): neither alias is referenced in the visible cells — confirm before removing.
tfk = tf.keras
tfkl = tf.keras.layers

In [None]:
# Load the pickled image collection from a gzip-compressed file into `images`.
path = '/nfs/kun1/users/justinvyu/data/fixed_data.pkl'
import gzip
import pickle
# SECURITY: pickle.load can execute arbitrary code embedded in the file;
# only load pickles that come from a trusted source.
with gzip.open(path, 'rb') as f:
    images = pickle.load(f)

## Fixing the data

In [None]:
# Split the collection at index 300000 into two parts that are converted differently below.
# NOTE(review): presumably the first 300000 images are float-valued and the rest hold raw
# 0-255 values — confirm against how the pickle was produced.
normalized_images, unnormalized_images = images[:300000], images[300000:]

In [None]:
# Sanity check: display the shapes of both parts.
normalized_images.shape, unnormalized_images.shape

In [None]:
# Convert the first part to uint8 with skimage's img_as_ubyte, which rescales values
# according to the input dtype's expected range.
fixed_images_0 = skimage.util.img_as_ubyte(normalized_images)

In [None]:
# Cast the second part to uint8 directly; astype does not rescale the values.
fixed_images_1 = unnormalized_images.astype(np.uint8)

In [None]:
# Show one sample from each converted part side by side. Two consecutive
# plt.imshow calls draw onto the same axes, so only the second image would be
# visible; separate axes keep both on screen for comparison.
comparison_fig, (ax_rescaled, ax_cast) = plt.subplots(1, 2)
ax_rescaled.imshow(fixed_images_0[15000])
ax_rescaled.set_title('fixed_images_0[15000]')
ax_rescaled.axis('off')
ax_cast.imshow(fixed_images_1[15000])
ax_cast.set_title('fixed_images_1[15000]')
ax_cast.axis('off')
plt.show()

In [None]:
# Join both parts end to end along the image axis. np.stack would instead add a
# new leading axis of length 2, which breaks the later per-image indexing
# (fixed_dataset[599999]) and the shape[0]-based train/test split.
fixed_dataset = np.concatenate([fixed_images_0, fixed_images_1], axis=0)

In [None]:
# Show the image at index 599999 and display the dataset shape.
# NOTE(review): the hard-coded index assumes at least 600000 images along axis 0 — confirm.
plt.imshow(fixed_dataset[599999])
fixed_dataset.shape

In [None]:
# Persist the converted dataset as a gzip-compressed pickle.
# NOTE(review): the data was read from '/nfs/kun1/...' but is written under '/root/nfs/kun1/...';
# confirm whether both paths refer to the same mount, since the file name is identical.
import gzip
import pickle
with gzip.open('/root/nfs/kun1/users/justinvyu/data/fixed_data.pkl', 'wb') as f:
    pickle.dump(fixed_dataset, f)

In [None]:
# Shuffle in place along the first axis. No random seed is set in the visible cells,
# so the order differs between runs.
np.random.shuffle(fixed_dataset)

In [None]:
# Rebind `images` to the converted dataset (same object, no copy) for the cells below.
images = fixed_dataset

## Create training/eval sets

In [None]:
# Shuffle in place before splitting so the test slice is a random subset.
# This is an unseeded shuffle; if the previous shuffle cell was also run, this is a second pass.
np.random.shuffle(images)

In [None]:
# Hold out the first 10% of the (shuffled) images for evaluation; train on the rest.
num_images = images.shape[0]
split_index = int(0.1 * num_images)
train_images = images[split_index:]
test_images = images[:split_index]

In [None]:
# Build the batched tf.data pipelines used for training and evaluation.
BATCH_SIZE = 128


def train_generator():
    """Yield the training images one at a time, in array order."""
    for index in range(len(train_images)):
        yield train_images[index]


def test_generator():
    """Yield the held-out images one at a time, in array order."""
    for index in range(len(test_images)):
        yield test_images[index]


# Only the element dtype (uint8) is declared; element shapes are left unspecified.
train_dataset = (
    tf.data.Dataset
    .from_generator(train_generator, tf.uint8)
    .batch(BATCH_SIZE)
)
test_dataset = (
    tf.data.Dataset
    .from_generator(test_generator, tf.uint8)
    .batch(BATCH_SIZE)
)

## Define the model

In [None]:
from softlearning.models.vae import VAE

## Create the optimizer + ELBO loss function

In [None]:
# Adam optimizer with a learning rate of 1e-4, shared by every training step.
optimizer = tf.keras.optimizers.Adam(1e-4)

def log_normal_pdf(sample, mean, logvar, raxis=1):
    """Log-density of `sample` under a diagonal Gaussian, summed over `raxis`.

    Args:
        sample: Points at which the density is evaluated.
        mean: Gaussian mean (broadcast against `sample`).
        logvar: Log of the Gaussian variance (broadcast against `sample`).
        raxis: Axis over which the per-dimension log-densities are summed.

    Returns:
        The summed log-density, with `raxis` reduced away.
    """
    log2pi = tf.math.log(2. * np.pi)
    squared_error = (sample - mean) ** 2.
    # Per-dimension term: -0.5 * ((x - mu)^2 / sigma^2 + log sigma^2 + log 2*pi)
    per_dimension = -.5 * (squared_error * tf.exp(-logvar) + logvar + log2pi)
    return tf.reduce_sum(per_dimension, axis=raxis)

@tf.function
def compute_loss(model, x, beta=1.0):
    """Negative (beta-weighted) ELBO of `model` on batch `x`, averaged over the batch.

    Args:
        model: Object exposing `encode(x) -> (mean, logvar)`,
            `reparameterize(mean, logvar)`, `decode(z)` (logits) and
            `preprocess(x)` (the reconstruction target).
        x: Batch of input images.
        beta: Weight on the KL term; 1.0 gives the standard ELBO. Passed as a
            Python number, so each distinct value is traced separately by
            tf.function.

    Returns:
        Scalar tensor: the batch mean of -(log p(x|z) + beta * (log p(z) - log q(z|x))).
    """
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    x_logit = model.decode(z)

    # Cross entropy reconstruction term assumes that the pixels are all
    # independent Bernoulli r.v.s. The label goes through model.preprocess so it
    # is on the same scale as the sigmoid of the logits.
    cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=x_logit, labels=model.preprocess(x))
    # Sum across all pixels (row/col) + channels: log p(x|z) per example.
    logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])

    # Single-sample estimate of -KL(q(z|x) || p(z)) with a unit-normal prior.
    logpz = log_normal_pdf(z, 0., 0.)  # Prior log-density
    logqz_x = log_normal_pdf(z, mean, logvar)  # Posterior log-density
    negative_kl = logpz - logqz_x

    # Per-example ELBO (to be maximised); the returned loss is its negated mean.
    elbo = logpx_z + beta * negative_kl
    return -tf.reduce_mean(elbo)

@tf.function
def compute_apply_gradients(model, x, optimizer):
    """Run one optimisation step of `model` on batch `x`.

    Records the loss under a gradient tape, differentiates it with respect to
    the model's trainable variables, and lets `optimizer` apply the update.
    """
    with tf.GradientTape() as tape:
        batch_loss = compute_loss(model, x)
    # Read the variable list after the forward pass, as the original did.
    variables = model.trainable_variables
    grads = tape.gradient(batch_loss, variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

In [None]:
# Training hyperparameters and model construction.
epochs = 500
latent_dim = 4
num_examples_to_generate = 16
image_shape = (32, 32, 3)

# Fixed latent vectors reused at every epoch so the generated samples can be
# compared across epochs.
random_vector_for_generation = tf.random.normal(
    shape=[num_examples_to_generate, latent_dim])
vae = VAE(image_shape=image_shape, latent_dim=latent_dim)

In [None]:
# Print the encoder and decoder summaries.
# NOTE(review): presumably both are Keras models — VAE is defined in softlearning, not shown here.
vae.encoder.summary()
vae.decoder.summary()

## Visualization

In [None]:
def plot_images(images, title=''):
    """Show a batch of images in a near-square grid.

    Args:
        images: Array indexed as `images[i, ...]`; `images.shape[0]` is the
            number of images and each entry is passed to `plt.imshow`.
        title: Optional title; it is printed and also set as the figure title.
    """
    num_images = int(images.shape[0])
    print(title)
    if num_images == 0:
        # Nothing to draw; also avoids dividing by a zero row count below.
        return

    rows = int(np.sqrt(num_images))
    # Ceiling division so rows * cols always has room for every image
    # (floor division dropped slots for counts such as 10 -> 3 x 3).
    cols = -(-num_images // rows)

    # figsize is (width, height), i.e. (cols, rows).
    fig = plt.figure(figsize=(cols, rows))
    # Title the figure itself; plt.title would attach to an implicit extra axes.
    fig.suptitle(title)
    for i in range(num_images):
        plt.subplot(rows, cols, i + 1)
        plt.axis('off')
        plt.imshow(images[i, ...])
    plt.show()

def generate_and_save_images(
        model, epoch, test_input,
        output_dir='/home/justinyu/Developer/softlearning/notebooks/vae_images'):
    """Decode `test_input` with `model.sample`, then plot and save the result.

    Args:
        model: Object exposing `sample(test_input)`, which returns a batch of
            images indexed as `predictions[i, :, :, :]`.
        epoch: Integer used in the output file name (`image_at_epoch_NNNN.png`).
        test_input: Latent vectors passed to `model.sample`.
        output_dir: Directory that receives the PNG; created when missing.
            Defaults to the location this notebook has always written to.
    """
    predictions = model.sample(test_input)
    num_predictions = int(predictions.shape[0])

    # Keep the original 4 x 4 layout for up to 16 samples and grow the square
    # grid beyond that, so larger batches no longer overflow the subplot grid.
    side = 4
    while side * side < num_predictions:
        side += 1

    fig = plt.figure(figsize=(side, side))
    for i in range(num_predictions):
        plt.subplot(side, side, i + 1)
        plt.imshow(predictions[i, :, :, :])
        plt.axis('off')

    # savefig does not create missing directories.
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    plt.savefig(os.path.join(
        output_dir, 'image_at_epoch_{:04d}.png'.format(epoch)))
    plt.show()
    # Release the figure explicitly; this function runs once per epoch.
    plt.close(fig)

## Training loop

In [None]:
# Main training loop: one pass over train_dataset per epoch, followed by an
# evaluation pass over test_dataset and a sample plot.
# Epoch 0 snapshot: samples from the untrained model.
generate_and_save_images(vae, 0, random_vector_for_generation)
elbo_history = []

for epoch in range(1, epochs + 1):
    # Only the training pass is timed; evaluation below is not included.
    start_time = time.time()
    for train_x in train_dataset:
        compute_apply_gradients(vae, train_x, optimizer)
    end_time = time.time()

    if epoch % 25 == 0:
        # Save weights
        # The file names are fixed, so each save overwrites the previous one.
        vae.encoder.save_weights('/home/justinyu/Developer/softlearning/notebooks/vae_weights/invisible_claw_encoder_weights_4.h5')
        vae.decoder.save_weights('/home/justinyu/Developer/softlearning/notebooks/vae_weights/invisible_claw_decoder_weights_4.h5')

    # `epoch % 1 == 0` is always true, so evaluation runs after every epoch.
    if epoch % 1 == 0:
        # Unweighted mean of the per-batch losses; a smaller final batch counts
        # as much as a full one.
        loss = tf.keras.metrics.Mean()
        for test_x in test_dataset:
            loss(compute_loss(vae, test_x))
        # compute_loss returns the negative ELBO, so negate it back.
        elbo = -loss.result()
        display.clear_output(wait=False)
        print('Epoch: {}, Test set ELBO: {}, '
              'time elapse for current epoch {}'.format(epoch, elbo, end_time - start_time))
        elbo_history.append(elbo)
        generate_and_save_images(
            vae, epoch, random_vector_for_generation)

## Visualizing ground truth vs. reconstructions

In [None]:
# Compare the first 64 held-out images with what the trained VAE produces when called on them.
n = 64
eval_images = test_images[:n]
plot_images(eval_images, title='Ground truth images')
reconstructions = vae(eval_images)
plot_images(np.array(reconstructions), title='VAE Reconstructions')

In [None]:
# Save the final encoder/decoder weights under separate `_final` file names.
vae.encoder.save_weights('/home/justinyu/Developer/softlearning/notebooks/vae_weights/invisible_claw_encoder_weights_4_final.h5')
vae.decoder.save_weights('/home/justinyu/Developer/softlearning/notebooks/vae_weights/invisible_claw_decoder_weights_4_final.h5')

## Test loading a model

In [None]:
# Build a fresh VAE and load previously saved weights into it.
# NOTE(review): this model uses latent_dim=16 and reads weight files without the `_4`
# suffix from a different directory than the one the training loop writes to, so it is
# not the model trained above (latent_dim = 4) — confirm this is intentional.
loaded_vae = VAE(image_shape=(32, 32, 3), latent_dim=16)
path = '/home/justinyu/Developer/softlearning/softlearning/models/vae_weights'
encoder_path = os.path.join(path, 'invisible_claw_encoder_weights.h5')
decoder_path = os.path.join(path, 'invisible_claw_decoder_weights.h5')
loaded_vae.encoder.load_weights(encoder_path)
loaded_vae.decoder.load_weights(decoder_path)
loaded_reconstructions = loaded_vae(eval_images)
plot_images(np.array(loaded_reconstructions), title='Loaded VAE Reconstructions')
# Not referenced again in the visible cells.
loaded_encodings = loaded_vae.encode(eval_images)

In [None]:
# Display whatever get_encoder() returns (defined in softlearning's VAE, not shown here).
loaded_vae.get_encoder()

In [None]:
# Load a training checkpoint (plain pickle) and a replay pool (gzip-compressed pickle).
# SECURITY: pickle.load can execute arbitrary code embedded in the file;
# only load pickles that come from a trusted source.
checkpoint_fn = '/nfs/kun1/users/justinvyu/data/checkpoint.pkl'
replay_pool_fn = '/nfs/kun1/users/justinvyu/data/replay_pool.pkl'
import pickle
import gzip
with open(checkpoint_fn, 'rb') as f:
    checkpoint = pickle.load(f)
    
with gzip.open(replay_pool_fn, 'rb') as f:
    replay_pool = pickle.load(f)

In [None]:
# Take the first 8 entries of the checkpoint's policy weights as the encoder weights.
# NOTE(review): presumably 8 is the number of weight arrays in the VAE encoder and they come
# first in 'policy_weights' — confirm against the policy architecture.
vae_encoder_weights = checkpoint['policy_weights'][:8]

In [None]:
# Overwrite the loaded encoder's weights with the ones taken from the checkpoint.
# The decoder keeps the weights loaded from file.
loaded_vae.encoder.set_weights(vae_encoder_weights)

In [None]:
# Re-run the reconstructions now that the encoder weights have been replaced.
loaded_reconstructions = loaded_vae(eval_images)
plot_images(np.array(loaded_reconstructions), title='Loaded VAE Reconstructions')

In [None]:
# Pick 100 pixel observations from the replay pool at random indices.
# np.random.randint samples with replacement, so an index can repeat; no seed is set.
replay_pool_images = replay_pool['observations']['pixels']
random_indices = np.random.randint(replay_pool_images.shape[0], size=100)
eval_replay_pool = replay_pool_images[random_indices]

In [None]:
# Plot the sampled replay-pool images (untitled) next to what loaded_vae returns for them.
eval_reconst = loaded_vae(eval_replay_pool)
plot_images(eval_replay_pool)
plot_images(np.array(eval_reconst), title='Loaded VAE Reconstructions')

In [None]:
# Encode the evaluation images with the VAE trained in this notebook (`vae`, not `loaded_vae`)
# and display the first element of the result, which compute_loss unpacks as the mean.
eval_encodings = vae.encode(eval_images)
np.set_printoptions(precision=3)
np.array(eval_encodings[0])

In [None]:
# Hand-entered 4-dimensional latent vector (same size as latent_dim = 4).
# NOTE(review): presumably copied from an earlier encoder output — confirm its origin.
test_encoded =[ 2.0327365 , -0.48694006,  1.119025  , -0.08618406]

In [None]:
# Encode a single image (index 10000, with a leading batch axis added via [None]),
# then decode both that encoding and the hand-entered latent vector.
example_encoding = vae.encode(images[10000][None])[0]
# Bare expression in the middle of a cell: it is not displayed; the print below shows the value.
example_encoding
print(example_encoding)
# NOTE(review): example_encoding presumably already has a leading batch axis of size 1, so
# wrapping it in a list adds one more axis than np.array([test_encoded]) has — confirm that
# vae.decode accepts both shapes.
plt.imshow(vae.decode(np.array([example_encoding]), apply_sigmoid=True)[0])
plt.show()
plt.imshow(vae.decode(np.array([test_encoded]), apply_sigmoid=True)[0])
plt.show()