In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

import tensorflow as tf
tf.enable_eager_execution()

import os
import time
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
import imageio

from IPython import display
%matplotlib inline

from softlearning.utils.tensorflow import nest

tfk = tf.keras
tfkl = tf.keras.layers

# `!export ...` would run in a throw-away subshell and never reach this
# kernel's environment, so set the variable on the Python process itself.
# CUDA reads it when the GPU runtime is first initialized, so this must run
# before TensorFlow executes its first op (above `import tensorflow` is safest).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
from softlearning.models.state_estimation import (
    get_dumped_pkl_data
)
# Absolute, machine-specific path to the pickled image data; adjust for your
# setup. NOTE(review): presumably `get_dumped_pkl_data` unpickles this file —
# only load pickles from trusted sources.
images_path = '/home/justinvyu/dev/softlearning-vice/goal_classifier/free_screw_state_estimator_data_invisible_claw/more_data.pkl'
# Only the first returned value (the images) is used by this notebook.
images, _ = get_dumped_pkl_data(images_path)

In [None]:
# Inspect the loaded array's dimensions. `images` must stay alive: every
# following cell (shape extraction, shuffling, train/test split, inference
# examples) reads from it.
images.shape

In [None]:
# Number of examples (leading axis) and the per-image shape (everything after).
num_images, image_shape = images.shape[0], images.shape[1:]

# Shuffle in place along the first axis so the contiguous train/test split
# below is random. No seed is set, so the split differs between runs.
np.random.shuffle(images)

In [None]:
# Hold out the first 10% of the shuffled images for evaluation and take up to
# `num_training_examples` of the remaining images for training.
validation_split = 0.1
num_training_examples = 500000
split_index = int(num_images * validation_split)

_test_images = images[:split_index]
_train_images = images[split_index:split_index + num_training_examples]

# Images are kept in their stored dtype here (no normalization in this cell).
train_images, test_images = _train_images, _test_images

In [None]:
# Cap the training set size. Slice from `_train_images` rather than from
# `train_images` itself so re-running this cell is idempotent.
max_train_examples = 200000
train_images = _train_images[:max_train_examples]
train_images.shape, test_images.shape

In [None]:
# Shapes of the training and held-out image arrays.
train_images.shape, test_images.shape

In [None]:
BATCH_SIZE = 128


def train_generator():
    """Yield the training images one at a time from the global `train_images`."""
    yield from train_images


def test_generator():
    """Yield the held-out images one at a time from the global `test_images`."""
    yield from test_images
        
def _scaled_image_dataset(generator):
    """Batch images from `generator` as float32 scaled into [0, 1].

    `from_generator` converts each yielded image to float32 but does not
    rescale it. The encoder's `preprocess` (`tf.image.convert_image_dtype`)
    only rescales integer inputs, and `compute_loss` uses the batch directly as
    Bernoulli targets, which must lie in [0, 1]. Scale here, matching the
    `/ 255.` applied to images in the inference cells (raw pixels presumably
    uint8 in [0, 255] — TODO confirm against the pickled data).
    """
    return (tf.data.Dataset.from_generator(generator, tf.float32)
            .batch(BATCH_SIZE)
            .map(lambda batch: batch / 255.))


train_dataset = _scaled_image_dataset(train_generator)
test_dataset = _scaled_image_dataset(test_generator)

In [None]:
def preprocess(x):
    """Convert image(s) to float32 and concatenate them along the last axis.

    `x` may be a single tensor or a nested structure of tensors. Integer
    inputs are rescaled by `tf.image.convert_image_dtype` (e.g. uint8 into
    [0, 1]); floating-point inputs are converted without rescaling.
    """
    def _to_float32(image):
        return tf.image.convert_image_dtype(image, tf.float32)

    flat_images = nest.flatten(nest.map_structure(_to_float32, x))
    return tf.concat(flat_images, axis=-1)


class CVAE(tf.keras.Model):
    """Convolutional variational autoencoder with a diagonal-Gaussian posterior.

    The encoder maps images to the mean and log-variance of q(z|x). The decoder
    maps a latent code z to per-pixel logits (or sigmoid probabilities).

    Args:
        input_shape: (height, width, channels) of the input images. The
            decoder upsamples a (height // 8, width // 8) feature map three
            times by 2, so height and width should be divisible by 8 for the
            reconstruction to match the input size.
        latent_dim: Dimensionality of the latent code z.
    """

    def __init__(self, input_shape=(32, 32, 3), latent_dim=64):
        super(CVAE, self).__init__()
        self.latent_dim = latent_dim

        # Derive the decoder's starting feature map and output channels from
        # `input_shape` instead of hard-coding a 32x32x3 output. The default
        # input still yields the original 4x4x32 start and 3 output channels.
        base_height = input_shape[0] // 8
        base_width = input_shape[1] // 8
        out_channels = input_shape[-1]

        self.encoder = tf.keras.Sequential([
            tfkl.InputLayer(input_shape=input_shape),
            tfkl.Lambda(preprocess),
            tfkl.Conv2D(
                filters=64, kernel_size=3, strides=(2, 2), activation=tfkl.LeakyReLU()),
            tfkl.Conv2D(
                filters=64, kernel_size=3, strides=(2, 2), activation=tfkl.LeakyReLU()),
            tfkl.Conv2D(
                filters=32, kernel_size=3, strides=(2, 2), activation=tfkl.LeakyReLU()),
            tfkl.Flatten(),
            # No activation: first half is the mean, second half the log-variance.
            tfkl.Dense(latent_dim + latent_dim)])

        self.decoder = tf.keras.Sequential([
            tfkl.InputLayer(input_shape=(latent_dim,)),
            tfkl.Dense(units=base_height * base_width * 32, activation=tf.nn.relu),
            tfkl.Reshape(target_shape=(base_height, base_width, 32)),
            tfkl.Conv2DTranspose(
                filters=64,
                kernel_size=3,
                strides=(2, 2),
                padding="SAME",
                activation=tfkl.LeakyReLU()),
            tfkl.Conv2DTranspose(
                filters=64,
                kernel_size=3,
                strides=(2, 2),
                padding="SAME",
                activation=tfkl.LeakyReLU()),
            tfkl.Conv2DTranspose(
                filters=32,
                kernel_size=3,
                strides=(2, 2),
                padding="SAME",
                activation=tfkl.LeakyReLU()),
            # No activation: per-pixel logits.
            tfkl.Conv2DTranspose(
                filters=out_channels, kernel_size=3, strides=(1, 1), padding="SAME")])

    @tf.function
    def sample(self, eps=None):
        """Decode `eps` (default: 100 standard-normal latents) into probabilities."""
        if eps is None:
            eps = tf.random.normal(shape=(100, self.latent_dim))
        return self.decode(eps, apply_sigmoid=True)

    def encode(self, x):
        """Return (mean, logvar), the two halves of the encoder output along axis 1."""
        mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """Draw z = mean + exp(logvar / 2) * eps with eps ~ N(0, I)."""
        # Dynamic `tf.shape` rather than the static `mean.shape`, so this also
        # works when the batch dimension is unknown (e.g. during graph tracing).
        eps = tf.random.normal(shape=tf.shape(mean))
        return eps * tf.exp(logvar * .5) + mean

    def decode(self, z, apply_sigmoid=False):
        """Decode `z` into logits, or sigmoid probabilities if `apply_sigmoid`."""
        logits = self.decoder(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits

    def call(self, inputs):
        """Encode, sample a latent code, and return the decoded probabilities.

        Implemented as Keras' `call` (not an override of `__call__`) so that
        `model(x)` still goes through Keras' layer bookkeeping.
        """
        mean, logvar = self.encode(inputs)
        z = self.reparameterize(mean, logvar)
        x_reconstruct = self.decode(z, apply_sigmoid=True)
        return x_reconstruct

In [None]:
# Adam optimizer (learning rate 1e-4) passed to `compute_apply_gradients`.
optimizer = tf.keras.optimizers.Adam(1e-4)

def log_normal_pdf(sample, mean, logvar, raxis=1):
    """Log-density of `sample` under a diagonal Gaussian, summed over `raxis`.

    Args:
        sample: Points at which to evaluate the density.
        mean: Gaussian mean (broadcastable to `sample`).
        logvar: Log-variance (broadcastable to `sample`).
        raxis: Axis (or axes) the per-dimension log-densities are summed over.
    """
    log_two_pi = tf.math.log(2. * np.pi)
    squared_error = (sample - mean) ** 2.
    per_dimension = -.5 * (squared_error * tf.exp(-logvar) + logvar + log_two_pi)
    return tf.reduce_sum(per_dimension, axis=raxis)

@tf.function
def compute_loss(model, x):
    """Single-sample Monte-Carlo estimate of the negative ELBO, batch-averaged.

    `x` is used directly as Bernoulli targets for the decoder logits.
    """
    mean, logvar = model.encode(x)
    z = model.reparameterize(mean, logvar)
    logits = model.decode(z)

    # log p(x|z): negated sigmoid cross-entropy summed over H, W and C.
    pixel_xent = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=x)
    log_px_given_z = -tf.reduce_sum(pixel_xent, axis=[1, 2, 3])
    # log p(z) under a standard normal, and log q(z|x) under the posterior.
    log_pz = log_normal_pdf(z, 0., 0.)
    log_qz_given_x = log_normal_pdf(z, mean, logvar)

    elbo = log_px_given_z + log_pz - log_qz_given_x
    return -tf.reduce_mean(elbo)

@tf.function
def compute_apply_gradients(model, x, optimizer):
    """Take one `optimizer` step on the negative ELBO of batch `x`."""
    with tf.GradientTape() as tape:
        loss = compute_loss(model, x)
    variables = model.trainable_variables
    grads_and_vars = zip(tape.gradient(loss, variables), variables)
    optimizer.apply_gradients(grads_and_vars)

In [None]:
# Training configuration.
epochs = 500
# Size of the latent code; passed to `CVAE` below.
latent_dim = 64
# Number of latent vectors decoded for the progress image grid.
num_examples_to_generate = 16

# keeping the random vector constant for generation (prediction) so
# it will be easier to see the improvement.
random_vector_for_generation = tf.random.normal(
    shape=[num_examples_to_generate, latent_dim])
model = CVAE(input_shape=image_shape, latent_dim=latent_dim)

In [None]:
def generate_and_save_images(
        model, epoch, test_input,
        output_dir='/home/justinvyu/dev/softlearning-vice/notebooks/vae_invisible_claw_images'):
    """Decode `test_input` with `model.sample`, plot a square grid, and save it.

    Args:
        model: Object whose `sample(eps)` returns a batch of images.
        epoch: Epoch number, embedded in the saved file name.
        test_input: Latent vectors to decode (one image per row).
        output_dir: Directory the PNG is written to; created if missing.
    """
    predictions = model.sample(test_input)
    num_images = int(predictions.shape[0])
    # Smallest square grid that fits every image (4x4 for 16 samples).
    grid_size = int(np.ceil(np.sqrt(num_images)))

    fig = plt.figure(figsize=(4, 4))
    for i in range(num_images):
        plt.subplot(grid_size, grid_size, i + 1)
        plt.imshow(predictions[i, :, :, :])
        plt.axis('off')

    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    plt.savefig(os.path.join(output_dir, 'image_at_epoch_{:04d}.png'.format(epoch)))
    plt.show()
    # Release the figure; this is called once per epoch.
    plt.close(fig)

In [None]:
generate_and_save_images(model, 0, random_vector_for_generation)

# Evaluate on the held-out set (and refresh the sample grid) every
# `eval_interval` epochs; 1 means every epoch.
eval_interval = 1

for epoch in range(1, epochs + 1):
    start_time = time.time()
    for train_x in train_dataset:
        compute_apply_gradients(model, train_x, optimizer)
    end_time = time.time()

    if epoch % eval_interval != 0:
        continue

    # Mean of the per-batch negative ELBOs (each batch weighted equally).
    test_loss = tf.keras.metrics.Mean()
    for test_x in test_dataset:
        test_loss(compute_loss(model, test_x))
    elbo = -test_loss.result()
    display.clear_output(wait=False)
    print('Epoch: {}, Test set ELBO: {}, '
          'time elapse for current epoch {}'.format(epoch,
                                                    elbo,
                                                    end_time - start_time))
    generate_and_save_images(
        model, epoch, random_vector_for_generation)

In [None]:
# Stand-alone encoder/decoder pair for 64x64x3 inputs (the decoder starts from
# an 8x8 map), inspected with `summary()` below.
encoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=(64, 64, 3)),
    tfkl.Conv2D(filters=64, kernel_size=3, strides=(2, 2),
                activation=tfkl.LeakyReLU()),
    tfkl.Conv2D(filters=64, kernel_size=3, strides=(2, 2),
                activation=tfkl.LeakyReLU()),
    tfkl.Conv2D(filters=32, kernel_size=3, strides=(2, 2),
                activation=tfkl.LeakyReLU()),
    tfkl.Flatten(),
    # Linear head: mean and log-variance, `latent_dim` values each.
    tfkl.Dense(latent_dim + latent_dim),
])

decoder = tfk.Sequential([
    tfkl.InputLayer(input_shape=(latent_dim,)),
    tfkl.Dense(units=8 * 8 * 32, activation=tf.nn.relu),
    tfkl.Reshape(target_shape=(8, 8, 32)),
    tfkl.Conv2DTranspose(filters=64, kernel_size=3, strides=(2, 2),
                         padding="SAME", activation=tfkl.LeakyReLU()),
    tfkl.Conv2DTranspose(filters=64, kernel_size=3, strides=(2, 2),
                         padding="SAME", activation=tfkl.LeakyReLU()),
    tfkl.Conv2DTranspose(filters=32, kernel_size=3, strides=(2, 2),
                         padding="SAME", activation=tfkl.LeakyReLU()),
    # Linear output layer: per-pixel logits for 3 channels.
    tfkl.Conv2DTranspose(filters=3, kernel_size=3, strides=(1, 1),
                         padding="SAME"),
])

In [None]:
# Per-layer output shapes and parameter counts of the 64x64 encoder.
encoder.summary()

In [None]:
# Per-layer output shapes and parameter counts of the 64x64 decoder.
decoder.summary()

In [None]:
# Persist the trained encoder and decoder weights separately (`.h5` paths are
# relative to the notebook's working directory) for reloading below.
model.encoder.save_weights('inference_weights.h5')
model.decoder.save_weights('generative_weights.h5')

In [None]:
# Rebuild a CVAE with the same architecture it was trained with before loading
# its weights. Pass both arguments by keyword: `CVAE(latent_dim)` would bind 64
# to `input_shape`.
model = CVAE(input_shape=image_shape, latent_dim=latent_dim)
model.encoder.load_weights('inference_weights.h5')
model.decoder.load_weights('generative_weights.h5')

In [None]:
# Pick one image and scale it to float32 — presumably raw pixels are uint8 in
# [0, 255], giving values in [0, 1]. Index 1 falls in the held-out slice
# `images[:split_index]` whenever split_index > 1.
test_image = images[1]
test_image = (test_image / 255.).astype(np.float32)

In [None]:
# Run the single image (with a batch axis of 1 added) through the model by hand.
# Note: with `apply_sigmoid=True`, `x_logit` holds probabilities, not logits.
mean, logvar = model.encode(test_image[None, ...])
z = model.reparameterize(mean, logvar)
x_logit = model.decode(z, apply_sigmoid=True)

In [None]:
# Posterior mean and log-variance of the latent code.
mean, logvar

In [None]:
# The sampled latent code.
z

In [None]:
# Decoded per-pixel probabilities (sigmoid already applied).
x_logit

In [None]:
# Show the reconstruction, then the input it was computed from.
decoded = x_logit.numpy()
for shown in (decoded[0], test_image):
    plt.imshow(shown)
    plt.show()

In [None]:
# Euclidean distance between reconstruction and input over all pixels.
# `decoded` carries a leading batch axis of 1, which broadcasts against
# `test_image`.
l2_loss = np.linalg.norm(decoded - test_image)
l2_loss

In [None]:
anim_file = 'cvae.gif'
# Read frames from the same directory `generate_and_save_images` writes to.
frames_glob = '/home/justinvyu/dev/softlearning-vice/notebooks/vae_invisible_claw_images/image*.png'

with imageio.get_writer(anim_file, mode='I') as writer:
    filenames = sorted(glob.glob(frames_glob))
    last = -1
    for i, filename in enumerate(filenames):
        # Keep a frame only when round(2 * sqrt(i)) increases, so kept frames
        # get sparser over time and early epochs are sampled more densely.
        frame = 2 * (i ** 0.5)
        if round(frame) > round(last):
            last = frame
        else:
            continue
        image = imageio.imread(filename)
        writer.append_data(image)
    # Always end on the final epoch's image (it may have been skipped above).
    if filenames:
        image = imageio.imread(filenames[-1])
        writer.append_data(image)

import IPython
if IPython.version_info >= (6, 2, 0, ''):
    # Not the cell's last expression, so it has to be displayed explicitly.
    display.display(display.Image(filename=anim_file))

In [None]:
def plot_side_by_side(img1, img2, title1='', title2='', figsize=(4, 2)):
    """Show `img1` and `img2` next to each other in a single figure.

    Args:
        img1, img2: Images accepted by `plt.imshow`.
        title1, title2: Titles for the left and right panels.
        figsize: Figure size in inches.
    """
    fig = plt.figure(figsize=figsize)
    panels = ((img1, title1), (img2, title2))
    for position, (img, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(title)
        plt.imshow(img)
        plt.axis('off')
    plt.show()

# Compare the single test image with its reconstruction.
plot_side_by_side(test_image, decoded[0], 'Ground Truth', 'VAE Reconstruction')

In [None]:
# Plot stochastic reconstructions (a fresh latent sample per call) for a fixed
# range of images, scaled to float32 the same way as `test_image`.
for raw in images[10500:11000]:
    image = (raw / 255.).astype(np.float32)
    reconstruction = model(image[None, ...])
    plot_side_by_side(image, reconstruction[0], figsize=(2, 1))