<a href="https://colab.research.google.com/github/JiaSunDeepLearning/AutoEncoder/blob/master/autoencoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Mount Google Drive at /content/drive. Later cells read and write files under
# /content/drive/My Drive/Autoencodercheckpoint/ (sanity file and model
# checkpoints). Mounting prompts for an authorization code (see output below).
from google.colab import drive
drive.mount('/content/drive')


Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


In [3]:
with open('/content/drive/My Drive/Autoencodercheckpoint/foo.txt', 'w') as f:
  f.write('Hello Google Drive!')
!cat /content/drive/My\ Drive/Autoencodercheckpoint/foo.txt

Hello Google Drive!

In [0]:
# Ask Colab for TensorFlow 2.x. The magic is Colab-specific; outside Colab the
# call raises and the error is deliberately ignored (best effort).
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass

In [5]:
from __future__ import absolute_import, division, print_function, unicode_literals

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt

print(tf.__version__)
import os
import time
import glob
import imageio
from IPython import display
import pathlib
import random
import logging
logger = tf.get_logger()
logger.setLevel(logging.ERROR)
import PIL
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator

_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
zip_dir = tf.keras.utils.get_file('cats_and_dogs_filterted.zip', origin=_URL, extract=True)

zip_dir_base = os.path.dirname(zip_dir)
!find $zip_dir_base -type d -print
print(zip_dir_base)

base_dir = os.path.join(os.path.dirname(zip_dir), 'cats_and_dogs_filtered')
path = os.path.join(base_dir, 'train/')
print(path)

2.0.0-beta1
Downloading data from https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
/root/.keras/datasets
/root/.keras/datasets/cats_and_dogs_filtered
/root/.keras/datasets/cats_and_dogs_filtered/train
/root/.keras/datasets/cats_and_dogs_filtered/train/cats
/root/.keras/datasets/cats_and_dogs_filtered/train/dogs
/root/.keras/datasets/cats_and_dogs_filtered/validation
/root/.keras/datasets/cats_and_dogs_filtered/validation/cats
/root/.keras/datasets/cats_and_dogs_filtered/validation/dogs
/root/.keras/datasets
/root/.keras/datasets/cats_and_dogs_filtered/train/


In [0]:
# Images per batch for the tf.data pipelines built in data_load() and main().
BATCH_SIZE = 50


class CVAE(tf.keras.Model):
    """Convolutional autoencoder for RGB images.

    The class name and the commented-out mean/logvar code suggest it was adapted
    from a convolutional VAE. In the active code path, though, the encoder
    (`inference_net`) maps an image batch to a deterministic 2048-d code (ReLU
    output), with no mean/log-variance split. The decoder (`generative_net`) maps
    a 2048-d code to a 128x128x3 image: a 16x16x512 map upsampled by three
    stride-2 transposed convolutions.

    Args:
        latent_dim: stored as `self.latent_dim`; it does not size any layer.
    """

    def __init__(self, latent_dim):
        super(CVAE, self).__init__()
        self.latent_dim = latent_dim
        # NOTE: both layer lists are kept exactly as trained. Adding, removing or
        # reordering layers (even weight-free ones) may break loading existing
        # checkpoints.
        self.inference_net = tf.keras.Sequential(
            [
                tf.keras.layers.Conv2D(
                    filters=64, kernel_size=(5, 5), strides=(2, 2), activation='relu', data_format='channels_last',
                    padding='valid'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Conv2D(
                    filters=128, kernel_size=(5, 5), strides=(2, 2), activation='relu', padding='valid'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Conv2D(
                    filters=256, kernel_size=(5, 5), strides=(2, 2), activation='relu', padding='valid'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Conv2D(
                    filters=512, kernel_size=(5, 5), strides=(2, 2), activation='relu', padding='valid'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dropout(rate=0.3),
                # Fully-connected bottleneck stage (ReLU).
                tf.keras.layers.Dense(1024, activation='relu'),

                # Redundant: the input is already flat, so this Flatten is a no-op.
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dropout(rate=0.3),
                # 2048-d code (ReLU); must match generative_net's input width.
                tf.keras.layers.Dense(2048, activation='relu'),

            ]
        )

        self.generative_net = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(2048,)),
                tf.keras.layers.Dropout(rate=0.3),
                tf.keras.layers.Dense(units=16 * 16 * 512, activation=tf.nn.relu),
                tf.keras.layers.Reshape(target_shape=(16, 16, 512)),
                tf.keras.layers.Conv2DTranspose(
                    filters=256,
                    kernel_size=(5, 5),
                    strides=(2, 2),
                    padding="SAME",
                    activation='relu'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Conv2DTranspose(
                    filters=128,
                    kernel_size=(5, 5),
                    strides=(2, 2),
                    padding="SAME",
                    activation='relu'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Conv2DTranspose(
                    filters=32,
                    kernel_size=(5, 5),
                    strides=(2, 2),
                    padding="SAME",
                    activation='relu'),
                tf.keras.layers.BatchNormalization(),
                # Output layer: 3 channels with ReLU, so decoder outputs are >= 0
                # and not bounded above.
                tf.keras.layers.Conv2DTranspose(
                    filters=3, kernel_size=(5, 5), strides=(1, 1), padding="SAME", activation='relu'),
            ]
        )

    @tf.function
    def sample(self, eps=None):
        """Decode `eps` and pass the result through a sigmoid.

        Args:
            eps: batch of codes. If None, 100 standard-normal codes are drawn,
                sized to the decoder's input width (2048). `latent_dim` did not
                match that width, so the default used to fail.

        Returns:
            sigmoid(decoder output). Because the output layer is ReLU, the
            values lie in [0.5, 1).
        """
        if eps is None:
            code_dim = self.generative_net.input_shape[-1]
            eps = tf.random.normal(shape=(100, code_dim))
        return self.decode(eps, apply_sigmoid=True)

    def encode(self, x):
        """Run the encoder on `x` and return its 2048-d code."""
        # mean, logvar = tf.split(self.inference_net(x), num_or_size_splits=2, axis=1)
        y = self.inference_net(x)
        return y

    def reparameterize(self, mean, logvar):
        """Return mean + exp(logvar / 2) * eps with eps ~ N(0, 1) (VAE reparameterization trick).

        Uses the dynamic `tf.shape(mean)` so it also works when the batch
        dimension is unknown, e.g. inside tf.function.
        """
        eps = tf.random.normal(shape=tf.shape(mean))
        return eps * tf.exp(logvar * .5) + mean

    def decode(self, z, apply_sigmoid=False):
        """Run the decoder on codes `z`.

        Args:
            z: batch of 2048-d codes.
            apply_sigmoid: if True, return tf.sigmoid of the decoder output
                instead of the raw (non-negative, ReLU) output.
        """
        logits = self.generative_net(z)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs

        return logits

    def sample_forward(self, x):
        """Run the encoder on `x` (same as `encode`)."""
        y = self.inference_net(x)
        return y


def log_normal_pdf(sample, mean, logvar, raxis=1):
    """Log-density of `sample` under a diagonal Gaussian, summed over axis `raxis`.

    Computes sum(-0.5 * ((sample - mean)**2 * exp(-logvar) + logvar + log(2*pi))).
    In this notebook it is only referenced from commented-out VAE loss code.
    """
    log_two_pi = tf.math.log(2. * np.pi)
    squared_dev = (sample - mean) ** 2.
    per_element = -.5 * (squared_dev * tf.exp(-logvar) + logvar + log_two_pi)
    return tf.reduce_sum(per_element, axis=raxis)


@tf.function
def compute_loss(model, x, training=True):
    """Mean-squared reconstruction error of `model` on the image batch `x`.

    Args:
        model: a CVAE; its `inference_net` and `generative_net` are used directly
            so the `training` flag can be forwarded to them.
        x: batch of images, compared directly against the raw decoder output.
            The input pipeline scales pixels to [0, 1].
        training: forwarded to both sub-networks. It defaults to True because this
            is the training objective. Keras layers called without a `training`
            flag run in inference mode, which would disable the Dropout layers
            and keep BatchNormalization on its moving statistics. Pass False for
            validation.

    Returns:
        Scalar MSE tensor.
    """
    code = model.inference_net(x, training=training)
    reconstruction = model.generative_net(code, training=training)
    return tf.keras.losses.MeanSquaredError()(x, reconstruction)


@tf.function
def compute_apply_gradients(model, x, optimizer):
    """Take one optimizer step on the batch `x` and return that batch's loss.

    The loss comes from compute_loss(model, x). Its gradients with respect to
    `model.trainable_variables` are applied with `optimizer`.
    """
    with tf.GradientTape() as tape:
        batch_loss = compute_loss(model, x)
    params = model.trainable_variables
    grads = tape.gradient(batch_loss, params)
    optimizer.apply_gradients(zip(grads, params))
    return batch_loss


def generate_and_save_images(model, epoch, test_input):
    """Decode `test_input`, plot up to 16 results on a 4x4 grid, save and show it.

    The figure is written to 'image_at_epoch_{epoch:04d}.png' in the current
    working directory.

    Args:
        model: a CVAE.
        epoch: epoch number used in the output file name.
        test_input: batch of codes accepted by `model.decode`.
    """
    # The decoder is trained with MSE against [0, 1] images and has no sigmoid
    # (see compute_loss), so show its raw output clipped to the displayable range.
    # `model.sample` would add a sigmoid and squash every pixel into [0.5, 1).
    predictions = tf.clip_by_value(model.decode(test_input), 0.0, 1.0)
    fig = plt.figure(figsize=(4, 4))

    # A 4x4 grid holds at most 16 images; plt.subplot raises for index 17+.
    n_shown = min(int(predictions.shape[0]), 16)
    for i in range(n_shown):
        ax = fig.add_subplot(4, 4, i + 1)
        ax.imshow(predictions[i])
        ax.axis('off')

    # tight_layout minimizes the overlap between sub-plots
    fig.tight_layout()
    fig.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()


def valid_model(model, image_input, mean=125, stddev=100.0):
    """Decode one random code, display the image and print its pixel values.

    Args:
        model: a CVAE.
        image_input: unused; kept so existing call sites keep working.
        mean: mean of the normal distribution the code is drawn from.
        stddev: standard deviation of that distribution.
    """
    # Size the noise to the decoder's input width (2048 for CVAE). The former
    # hard-coded 512 did not match, and generative_net rejected it.
    code_dim = model.generative_net.input_shape[-1]
    z = tf.random.normal(shape=[1, code_dim], mean=mean, stddev=stddev)
    # Decoder output is trained without a sigmoid, so clip it for display
    # instead of squashing it through `model.sample`.
    predictions = tf.clip_by_value(model.decode(z), 0.0, 1.0)
    plt.imshow(predictions[0])
    plt.show()
    print(predictions[0])


def display_image(epoch_no):
    """Open and return the PNG saved for `epoch_no` by generate_and_save_images."""
    filename = 'image_at_epoch_{:04d}.png'.format(epoch_no)
    return PIL.Image.open(filename)


def preprocess_image(image):
    """Decode JPEG bytes into a 128x128x3 float tensor with values in [0, 1]."""
    decoded = tf.image.decode_jpeg(image, channels=3)
    resized = tf.image.resize(decoded, [128, 128])
    # Pixel bytes are 0..255; scale to [0, 1].
    return resized / 255.0


def load_and_preprocess_image(path):
    """Read the image file at `path` and return preprocess_image() of its bytes."""
    return preprocess_image(tf.io.read_file(path))


def data_load(path, batch_size=None):
    """Build shuffled, batched image and label datasets from a class-per-folder tree.

    Expects `path/<class_name>/<image file>`. Labels are each class folder's
    index in sorted name order. File paths are shuffled once, so both datasets
    share the same order.

    Args:
        path: root directory, e.g. the dataset's train/ folder.
        batch_size: items per batch. Defaults to the module-level BATCH_SIZE.

    Returns:
        (image_ds, label_ds, image_count):
        - image_ds yields batches of preprocessed images.
        - label_ds yields the matching int64 label batches, with the same order
          and batch size, so the two can be zipped.
        - image_count is the number of image files found.

    Raises:
        ValueError: if no files are found under `path/*/`.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE

    data_root = pathlib.Path(path)
    all_image_paths = [str(p) for p in data_root.glob('*/*') if p.is_file()]
    if not all_image_paths:
        raise ValueError('No image files found under {}'.format(data_root))
    random.shuffle(all_image_paths)

    image_count = len(all_image_paths)

    label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
    label_to_index = {name: index for index, name in enumerate(label_names)}
    all_image_labels = [label_to_index[pathlib.Path(p).parent.name]
                        for p in all_image_paths]

    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
    image_ds = path_ds.map(load_and_preprocess_image,
                           num_parallel_calls=tf.data.experimental.AUTOTUNE).batch(batch_size)

    # Batch labels exactly like images. Unbatched labels would not line up
    # element-wise with image batches.
    label_ds = tf.data.Dataset.from_tensor_slices(
        tf.cast(all_image_labels, tf.int64)).batch(batch_size)
    return image_ds, label_ds, image_count


def main(epochs=160, latent_dim=512, save_every=20,
         checkpoint_path='/content/drive/My Drive/Autoencodercheckpoint/ckpt'):
    """Train the autoencoder on the images under the global `path`.

    Loads batches with data_load(path) and minimizes the reconstruction loss
    with Adam (lr 1e-4). After each epoch it prints the elapsed time and the
    mean per-batch loss. Weights are saved every `save_every` epochs and after
    the last epoch.

    Args:
        epochs: number of passes over the training set.
        latent_dim: forwarded to CVAE.
        save_every: checkpoint period in epochs.
        checkpoint_path: prefix passed to Model.save_weights.
    """
    image_ds, _, _ = data_load(path)

    optimizer = tf.keras.optimizers.Adam(1e-4)
    model = CVAE(latent_dim)

    for epoch in range(1, epochs + 1):
        n_batches = 0
        loss_sum = 0.0
        start_time = time.time()
        for train_x in image_ds:
            loss_sum += compute_apply_gradients(model, train_x, optimizer)
            n_batches += 1
        elapsed = time.time() - start_time

        # Average over the batches actually processed; max() guards an empty dataset.
        epoch_mse = float(loss_sum) / max(n_batches, 1)
        print('Epoch:{}, time elapse: {}, epoch_mse: {}'.format(epoch, elapsed, epoch_mse))

        # Also save after the final epoch, so a run whose length is not a
        # multiple of `save_every` still persists its last weights.
        if epoch % save_every == 0 or epoch == epochs:
            model.save_weights(checkpoint_path)



def evaluate_my_model(checkpoint_path='/content/drive/My Drive/Autoencodercheckpoint/ckpt',
                      num_images=4, max_batches=None):
    """Restore the trained autoencoder and display reconstructions of training images.

    For each batch from data_load(path), the first `num_images` reconstructions
    are shown, one figure each.

    Args:
        checkpoint_path: prefix previously written by Model.save_weights.
        num_images: reconstructions shown per batch (fewer if the batch is smaller).
        max_batches: stop after this many batches; None means all batches.

    Returns:
        0 (kept from the original interface).
    """
    image_ds, _, _ = data_load(path)
    if max_batches is not None:
        image_ds = image_ds.take(max_batches)

    new_model = CVAE(512)
    new_model.load_weights(checkpoint_path)

    for image in image_ds:
        z = new_model.encode(image)
        # Training minimizes MSE between decode(z) (no sigmoid) and [0, 1] images,
        # so show decode(z) clipped to [0, 1]. `sample` would add a sigmoid and
        # wash every pixel into [0.5, 1).
        reconstruction = tf.clip_by_value(new_model.decode(z), 0.0, 1.0)
        # The last batch may hold fewer than `num_images` images.
        for i in range(min(num_images, int(reconstruction.shape[0]))):
            plt.imshow(reconstruction[i])
            plt.grid(False)
            plt.show()
    return 0


if __name__ == '__main__':
    # True: train from scratch and write checkpoints.
    # False: restore the saved checkpoint and show reconstructions.
    train = False
    entry_point = main if train else evaluate_my_model
    entry_point()