<a href="https://colab.research.google.com/github/JiaSunDeepLearning/AutoEncoder/blob/master/myautoencoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
# Ask the Colab runtime for TensorFlow 2.x. Any exception raised by the magic
# is deliberately ignored, so this cell never stops the notebook.
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass

TensorFlow 2.x selected.


In [3]:
from __future__ import absolute_import, division, print_function, unicode_literals

# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt

# Show which TensorFlow version this runtime actually loaded.
print(tf.__version__)

2.0.0-beta1


In [4]:
# Download (and unpack) the cats-vs-dogs archive through the Keras file cache.
# `zip_dir` is the path returned for the downloaded archive; later cells derive
# the dataset directory from its parent folder.
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
zip_dir = tf.keras.utils.get_file(
    fname='cats_and_dogs_filterted.zip',
    origin=_URL,
    extract=True,
)

Downloading data from https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip


In [5]:
import os
# Directory that contains the downloaded archive.
zip_dir_base = os.path.dirname(zip_dir)
# List every directory below it to inspect the extracted layout.
!find $zip_dir_base -type d -print
print(zip_dir_base)

/root/.keras/datasets
/root/.keras/datasets/cats_and_dogs_filtered
/root/.keras/datasets/cats_and_dogs_filtered/train
/root/.keras/datasets/cats_and_dogs_filtered/train/cats
/root/.keras/datasets/cats_and_dogs_filtered/train/dogs
/root/.keras/datasets/cats_and_dogs_filtered/validation
/root/.keras/datasets/cats_and_dogs_filtered/validation/cats
/root/.keras/datasets/cats_and_dogs_filtered/validation/dogs
/root/.keras/datasets


In [0]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import os
import time
import numpy as np
import glob
import matplotlib.pyplot as plt
# plt.switch_backend('agg')
import PIL
import imageio
from IPython import display
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
import pathlib
import random
import logging

# Only let TensorFlow log records of level ERROR or above through.
logger = tf.get_logger()
logger.setLevel(logging.ERROR)

In [7]:
# Extracted dataset folder, located next to the downloaded archive.
base_dir = os.path.join(os.path.dirname(zip_dir), 'cats_and_dogs_filtered')
# Training split. NOTE: main() reads this module-level `path` directly, so this
# cell must have run before main() is called.
path = os.path.join(base_dir, 'train/')
print(path)

/root/.keras/datasets/cats_and_dogs_filtered/train/


In [0]:
# Number of images per training batch; used by main() to batch the dataset.
BATCH_SIZE = 20
# NOTE(review): not referenced by any code visible in this notebook, and
# preprocess_image() resizes to 128x128 rather than 256 -- confirm whether this
# constant is still needed or should be 128.
IMG_SHAPE = 256


class CVAE(tf.keras.Model):
    """Convolutional autoencoder made of an encoder and a decoder network.

    ``inference_net`` (encoder) maps an image batch to a ``CODE_DIM``-wide code
    and ``generative_net`` (decoder) maps such a code to a 128x128x3 image
    (16x16 feature map upsampled three times by a factor of two).

    The class still carries the VAE-style helper ``reparameterize``, but
    ``encode`` returns a single code tensor (no mean / log-variance split), so
    the model behaves as a deterministic autoencoder.

    Args:
        latent_dim: Stored on the instance as ``latent_dim``. It does not size
            any layer; the code width is fixed by ``CODE_DIM``.
    """

    # Width of the code vector exchanged between encoder and decoder. Both
    # networks are hard-wired to this size, independent of `latent_dim`.
    CODE_DIM = 2048

    def __init__(self, latent_dim):
        super(CVAE, self).__init__()
        self.latent_dim = latent_dim

        # Encoder: four stride-2 "valid" convolutions, each followed by batch
        # normalisation, then two dropout + dense stages.
        self.inference_net = tf.keras.Sequential(
            [
                tf.keras.layers.Conv2D(
                    filters=64, kernel_size=(5, 5), strides=(2, 2), activation='relu', data_format='channels_last',
                    padding='valid'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Conv2D(
                    filters=128, kernel_size=(5, 5), strides=(2, 2), activation='relu', padding='valid'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Conv2D(
                    filters=256, kernel_size=(5, 5), strides=(2, 2), activation='relu', padding='valid'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Conv2D(
                    filters=512, kernel_size=(5, 5), strides=(2, 2), activation='relu', padding='valid'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dropout(rate=0.3),
                tf.keras.layers.Dense(1024, activation='relu'),

                # Redundant (the Dense output above is already flat); kept so
                # the layer stack is unchanged.
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dropout(rate=0.3),
                # ReLU on the code layer: codes are non-negative.
                tf.keras.layers.Dense(self.CODE_DIM, activation='relu'),

            ]
        )

        # Decoder: dense projection to a 16x16x512 map, three stride-2
        # transposed convolutions (32 -> 64 -> 128 pixels) and a final
        # stride-1 transposed convolution down to 3 channels.
        self.generative_net = tf.keras.Sequential(
            [
                tf.keras.layers.InputLayer(input_shape=(self.CODE_DIM,)),
                tf.keras.layers.Dropout(rate=0.3),
                tf.keras.layers.Dense(units=16 * 16 * 512, activation=tf.nn.relu),
                tf.keras.layers.Reshape(target_shape=(16, 16, 512)),
                tf.keras.layers.Conv2DTranspose(
                    filters=256,
                    kernel_size=(5, 5),
                    strides=(2, 2),
                    padding="SAME",
                    activation='relu'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Conv2DTranspose(
                    filters=128,
                    kernel_size=(5, 5),
                    strides=(2, 2),
                    padding="SAME",
                    activation='relu'),
                tf.keras.layers.BatchNormalization(),
                tf.keras.layers.Conv2DTranspose(
                    filters=32,
                    kernel_size=(5, 5),
                    strides=(2, 2),
                    padding="SAME",
                    activation='relu'),
                tf.keras.layers.BatchNormalization(),
                # The output layer keeps a ReLU, so the decoder emits
                # non-negative pixel intensities rather than logits.
                tf.keras.layers.Conv2DTranspose(
                    filters=3, kernel_size=(5, 5), strides=(1, 1), padding="SAME", activation='relu'),
            ]
        )

    @tf.function
    def sample(self, eps=None, apply_sigmoid=False):
        """Decode codes into images.

        Args:
            eps: Code batch of shape ``(batch, CODE_DIM)``. When None, 100
                codes are drawn from a standard normal distribution.
            apply_sigmoid: Forwarded to ``decode``. Defaults to False because
                the decoder already ends in a ReLU and the training loss in
                this file compares its raw output with the input image;
                squashing that output through a sigmoid would map every pixel
                into [0.5, 1).

        Returns:
            The decoded image batch.
        """
        if eps is None:
            # Must match the decoder's input width, not `latent_dim`.
            eps = tf.random.normal(shape=(100, self.CODE_DIM))
        return self.decode(eps, apply_sigmoid=apply_sigmoid)

    def encode(self, x):
        """Return the encoder output (the code) for image batch ``x``."""
        return self.inference_net(x)

    def reparameterize(self, mean, logvar):
        """Draw ``mean + eps * exp(logvar / 2)`` with ``eps`` standard normal.

        Not used by the deterministic training path in this file.
        """
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * .5) + mean

    def decode(self, z, apply_sigmoid=False):
        """Run the decoder on code batch ``z``.

        Args:
            z: Code batch accepted by ``generative_net``.
            apply_sigmoid: When True, return ``sigmoid`` of the decoder output
                instead of the raw output.
        """
        logits = self.generative_net(z)
        if apply_sigmoid:
            return tf.sigmoid(logits)
        return logits

    def sample_forward(self, x):
        """Alias of ``encode``: return the encoder output for ``x``."""
        return self.encode(x)


def log_normal_pdf(sample, mean, logvar, raxis=1):
    """Sum the log-density of a diagonal Gaussian over axis ``raxis``.

    Args:
        sample: Points at which the density is evaluated.
        mean: Mean of the Gaussian.
        logvar: Log-variance of the Gaussian.
        raxis: Axis that is reduced by the sum.

    Returns:
        ``sum(-0.5 * ((sample - mean)^2 * exp(-logvar) + logvar + log(2*pi)))``
        taken along ``raxis``.
    """
    log_two_pi = tf.math.log(2. * np.pi)
    squared_deviation = (sample - mean) ** 2.
    per_element = -.5 * (squared_deviation * tf.exp(-logvar) + logvar + log_two_pi)
    return tf.reduce_sum(per_element, axis=raxis)


@tf.function
def compute_loss(model, x, training=True):
    """Return the mean squared reconstruction error of ``model`` on batch ``x``.

    Args:
        model: Object exposing ``inference_net`` (encoder) and
            ``generative_net`` (decoder) Keras models.
        x: Batch of input images; also the reconstruction target.
        training: Passed to both networks as the Keras ``training`` flag.
            It defaults to True because this function is the loss used by the
            training step: without it the Dropout layers are inactive and the
            BatchNormalization layers never use or update batch statistics.
            Pass False to score a model in inference mode.

    Returns:
        Scalar tensor with the mean squared error between ``x`` and its
        reconstruction.
    """
    # The sub-networks are called directly because encode()/decode() do not
    # expose the `training` flag.
    z = model.inference_net(x, training=training)
    reconstruction = model.generative_net(z, training=training)
    mse = tf.keras.losses.MeanSquaredError()
    return mse(x, reconstruction)


@tf.function
def compute_apply_gradients(model, x, optimizer):
    """Run one optimisation step on batch ``x`` and return the batch loss.

    The loss from ``compute_loss`` is recorded under a gradient tape, its
    gradients with respect to ``model.trainable_variables`` are computed, and
    ``optimizer`` applies them.
    """
    with tf.GradientTape() as recorder:
        batch_loss = compute_loss(model, x)
    # Read the variables only after the forward pass above has run.
    variables = model.trainable_variables
    grads = recorder.gradient(batch_loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return batch_loss


def generate_and_save_images(model, epoch, test_input):
    """Decode ``test_input`` and save the resulting images as one grid figure.

    Args:
        model: Model whose ``sample`` method turns ``test_input`` into images.
        epoch: Integer used in the output file name
            ``image_at_epoch_XXXX.png``.
        test_input: Code batch passed to ``model.sample``.

    The figure is written to the current working directory and then shown.
    """
    predictions = model.sample(test_input)
    count = int(predictions.shape[0])

    # 4x4 as before for up to 16 images; grow the square grid for larger
    # batches so plt.subplot never receives an out-of-range index.
    grid = 4
    while grid * grid < count:
        grid += 1

    plt.figure(figsize=(grid, grid))
    for i in range(count):
        plt.subplot(grid, grid, i + 1)
        # Plot the i-th prediction (the old code always drew predictions[0]).
        plt.imshow(predictions[i])
        plt.axis('off')

    # Save before showing: showing inside the loop flushed the figure, so the
    # saved file did not contain the full grid.
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()


def valid_model(model, image_input):
    """Decode one random code with ``model`` and display the result.

    Args:
        model: Model exposing ``generative_net`` and ``sample``.
        image_input: Unused; kept so existing callers keep working.

    The decoded image is shown with matplotlib and its tensor is printed.
    """
    # Size the random code from the decoder's own input width; the previous
    # hard-coded 512 did not match the decoder and raised a shape error.
    code_dim = model.generative_net.input_shape[-1]
    # NOTE(review): mean=125 / stddev=100 are carried over unchanged; confirm
    # they reflect the range of codes the encoder actually produces.
    z = tf.random.normal(
        shape=[1, code_dim],
        mean=125,
        stddev=100.0)
    predictions = model.sample(z)
    plt.imshow(predictions[0])
    plt.show()
    print(predictions[0])


def display_image(epoch_no):
    """Open the image file ``image_at_epoch_XXXX.png`` for epoch ``epoch_no``.

    Returns:
        The opened ``PIL.Image`` object.
    """
    filename = 'image_at_epoch_{:04d}.png'.format(epoch_no)
    return PIL.Image.open(filename)


def preprocess_image(image):
    """Decode JPEG bytes into a 128x128 three-channel image divided by 255.

    Args:
        image: Encoded JPEG contents (as returned by ``tf.io.read_file``).

    Returns:
        The decoded image, resized to 128x128 and scaled by 1/255.
    """
    decoded = tf.image.decode_jpeg(image, channels=3)
    resized = tf.image.resize(decoded, [128, 128])
    # normalize to [0,1] range
    return resized / 255.0


def load_and_preprocess_image(path):
    """Read the file at ``path`` and pass its contents to ``preprocess_image``."""
    return preprocess_image(tf.io.read_file(path))


def data_load(path, batch_size=5):
    """Build image and label datasets from a ``<class>/<image>`` directory tree.

    Args:
        path: Directory whose immediate sub-directories are the classes.
        batch_size: Number of images per batch of the image dataset. Defaults
            to 5, the value that used to be hard-coded.

    Returns:
        A tuple ``(image_ds, label_ds, image_count)``:
        * ``image_ds`` -- batched dataset of preprocessed images, in a
          shuffled file order fixed when this function is called.
        * ``label_ds`` -- dataset of int64 class indices in the same file
          order. It is NOT batched, so batch it the same way before zipping
          it with ``image_ds``.
        * ``image_count`` -- number of files found.

    Raises:
        FileNotFoundError: If no file matches ``<path>/*/*``.
    """
    data_root = pathlib.Path(path)
    all_image_paths = [str(image_path) for image_path in data_root.glob('*/*')]
    if not all_image_paths:
        # Fail early with a readable message instead of a dtype error raised
        # later from inside the tf.data pipeline.
        raise FileNotFoundError('no images found under {!r}'.format(str(path)))
    random.shuffle(all_image_paths)

    image_count = len(all_image_paths)

    # Class index = position of the directory name in sorted order.
    label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
    label_to_index = {name: index for index, name in enumerate(label_names)}
    all_image_labels = [label_to_index[pathlib.Path(image_path).parent.name]
                        for image_path in all_image_paths]

    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
    image_ds = path_ds.map(
        load_and_preprocess_image,
        num_parallel_calls=tf.data.experimental.AUTOTUNE).batch(batch_size)

    label_ds = tf.data.Dataset.from_tensor_slices(tf.cast(all_image_labels, tf.int64))
    return image_ds, label_ds, image_count


def main():
    """Train the autoencoder on the images under the module-level ``path``.

    Builds a batched image dataset from ``<path>/*/*``, trains a ``CVAE`` for
    a fixed number of epochs with Adam while printing the per-batch and
    per-epoch mean squared error, then writes the weights to
    ``./my_checkpoint/checkpoint`` and opens that checkpoint with a fresh model.
    """
    data_root = pathlib.Path(path)
    all_image_paths = [str(image_path) for image_path in data_root.glob('*/*')]
    # The file order is shuffled once here and reused for every epoch.
    random.shuffle(all_image_paths)

    image_count = len(all_image_paths)
    # Ceiling division: a trailing partial batch still counts as a batch, and
    # the progress line shows an integer instead of a float such as 100.0.
    batches_per_epoch = -(-image_count // BATCH_SIZE)

    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
    image_ds = path_ds.map(
        load_and_preprocess_image,
        num_parallel_calls=tf.data.experimental.AUTOTUNE).batch(BATCH_SIZE)

    optimizer = tf.keras.optimizers.Adam(1e-4)

    epochs = 500
    latent_dim = 512

    model = CVAE(latent_dim)
    for epoch in range(1, epochs + 1):
        batches_seen = 0
        loss_sum = 0
        start_time = time.time()
        for train_x in image_ds:
            loss = compute_apply_gradients(model, train_x, optimizer)
            batches_seen += 1
            loss_sum = loss_sum + loss
            print('Time: {}/{}, Epoch: {}, Batch_mse: {}'.format(
                batches_seen, batches_per_epoch, epoch, loss))
        end_time = time.time()

        # Divide by the number of batches actually processed. The previous
        # counter started at 1, so the mean was taken over one batch too many.
        epoch_mse = loss_sum / max(batches_seen, 1)
        print('Epoch:{}, time elapse: {}, epoch_mse: {}'.format(
            epoch, end_time - start_time, epoch_mse))

    model.save_weights('./my_checkpoint/checkpoint')

    # Open the checkpoint just written with a fresh model, as before.
    new_model = CVAE(latent_dim)
    new_model.load_weights('./my_checkpoint/checkpoint')


def evaluate_my_model():
    """Show reconstructions of the first image batch from the saved checkpoint.

    Loads the weights from ``./my_checkpoint/checkpoint``, encodes and decodes
    the first batch produced by ``data_load`` for the module-level ``path``
    (the same directory ``main`` trains on), prints the value range of the
    codes and of the reconstructions, and displays up to five reconstructed
    images.

    Returns:
        0, as before.
    """
    # Use the same dataset location as main(); the former hard-coded relative
    # path './cats_and_dogs_filtered/train/' is not where the archive is
    # extracted in this notebook.
    image_ds, _, _ = data_load(path)

    new_model = CVAE(512)
    new_model.load_weights('./my_checkpoint/checkpoint')

    for image in image_ds:
        z = new_model.encode(image)
        print(tf.reduce_max(z), tf.reduce_min(z))
        y = new_model.sample(z)
        print(tf.reduce_max(y), tf.reduce_min(y))
        for i in range(min(5, int(y.shape[0]))):
            # Float images are displayed on a 0..1 scale, so they must not be
            # multiplied by 255 (that clipped every pixel to white).
            plt.imshow(y[i])
            plt.grid(False)
            plt.show()
        # Only the first batch is inspected.
        break
    return 0


if __name__ == '__main__':
    # Set to False to run evaluate_my_model() against the saved checkpoint
    # instead of training.
    train = True
    if not train:
        evaluate_my_model()
    else:
        main()


Time: 1/100.0, Epoch: 1, Batch_mse: 0.25316089391708374
Time: 2/100.0, Epoch: 1, Batch_mse: 0.2622780501842499
Time: 3/100.0, Epoch: 1, Batch_mse: 0.22908854484558105
Time: 4/100.0, Epoch: 1, Batch_mse: 0.2560398578643799
Time: 5/100.0, Epoch: 1, Batch_mse: 0.26998263597488403
Time: 6/100.0, Epoch: 1, Batch_mse: 0.28410154581069946
Time: 7/100.0, Epoch: 1, Batch_mse: 0.2649073004722595
Time: 8/100.0, Epoch: 1, Batch_mse: 0.2154305875301361
Time: 9/100.0, Epoch: 1, Batch_mse: 0.24646787345409393
Time: 10/100.0, Epoch: 1, Batch_mse: 0.24384717643260956
Time: 11/100.0, Epoch: 1, Batch_mse: 0.1779196560382843
Time: 12/100.0, Epoch: 1, Batch_mse: 0.1271274834871292
Time: 13/100.0, Epoch: 1, Batch_mse: 0.11102068424224854
Time: 14/100.0, Epoch: 1, Batch_mse: 0.11516094207763672
Time: 15/100.0, Epoch: 1, Batch_mse: 0.08273544162511826
Time: 16/100.0, Epoch: 1, Batch_mse: 0.10293938219547272
Time: 17/100.0, Epoch: 1, Batch_mse: 0.10208338499069214
Time: 18/100.0, Epoch: 1, Batch_mse: 0.0833538