<a href="https://colab.research.google.com/github/ItalianPepper/VAE-GAN/blob/master/easy_vaegan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Some references
- GANs trained by a Two Time-Scale Update Rule (TTUR): https://arxiv.org/abs/1706.08500
- Keras issue with the `trainable` flag: https://github.com/keras-team/keras/issues/9589
- VAE-GAN reference implementation: https://github.com/crmaximo/VAEGAN
- Original VAE-GAN paper: https://arxiv.org/pdf/1512.09300.pdf

In [None]:
# Mount Google Drive at /content/drive so files stored there can be read
# from this Colab session.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Unzip only a limited number of files (the first 10001 archive entries)
!unzip -Z1 "/content/drive/My Drive/GAN/img_align_celeba.zip" | head -10001 | sed 's| |\\ |g' | xargs unzip "/content/drive/My Drive/GAN/img_align_celeba.zip" -d ./

In [22]:
!mkdir ./res_gan_autoencoder

In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt
import os
import numpy as np
import cv2 as cv
from random import shuffle
from tensorflow.keras import metrics, backend as K
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.layers import *


def get_dataset(path, img_shape):
    """Load every image file found directly inside ``path`` into one array.

    Files are selected by extension (case-insensitive). Each selected file is
    loaded with ``load_img`` resized to ``img_shape[:2]``, converted with
    ``img_to_array`` and divided by 255.0.

    Args:
        path: Directory to scan; sub-directories are not visited.
        img_shape: Target shape; only its first two entries are used as the
            resize target.

    Returns:
        ``np.ndarray`` built from the list of loaded images, in
        ``os.listdir`` order. It is empty when no file matches.
    """
    accepted_suffixes = [".jpg", ".png", ".tiff", ".bmp", ".gif"]

    # First pass: keep only the entries whose extension is accepted.
    image_files = [
        os.path.join(path, entry)
        for entry in os.listdir(path)
        if os.path.splitext(entry)[1].lower() in accepted_suffixes
    ]

    # Second pass: load, resize and rescale each image.
    return np.asarray([
        img_to_array(load_img(image_file, target_size=img_shape[:2])) / 255.0
        for image_file in image_files
    ])

  
def get_noise(n_sample=1, nlatent_dim=512):
    """Draw latent vectors from a normal distribution with mean 0 and std 1.

    Args:
        n_sample: Number of vectors (rows) to draw.
        nlatent_dim: Length of each vector (columns).

    Returns:
        ``np.ndarray`` of shape ``(n_sample, nlatent_dim)``.
    """
    return np.random.normal(loc=0, scale=1, size=(n_sample, nlatent_dim))
  

def plot_generated_images(noise, nsample=1, path_save=None, epoch=0):
    """Decode ``noise`` with the module-level ``decoder`` and save a PNG.

    The first ``nsample`` decoded images are drawn side by side in a single
    row and written to ``<path_save>/<epoch>.png``.

    Args:
        noise: Latent vectors passed to ``decoder.predict``.
        nsample: Maximum number of decoded images to draw.
        path_save: Output directory; the current directory is used when it
            is ``None``.
        epoch: Value used as the file name (without extension).

    Raises:
        ValueError: If there is no image to draw (``nsample`` < 1 or the
            decoder returned nothing).
    """
    imgs = decoder.predict(noise)
    selected = imgs[:nsample]
    if len(selected) == 0:
        raise ValueError("no generated image to plot")

    # Network outputs are not guaranteed to lie in [0, 1]. Casting an
    # out-of-range float to uint8 wraps around (e.g. a slightly negative
    # value turns into a bright pixel), so clip before rescaling.
    selected = (np.clip(selected, 0.0, 1.0) * 255).astype(np.uint8)

    out_dir = "." if path_save is None else path_save
    path_name = os.path.join(out_dir, str(epoch) + ".png")

    fig = plt.figure(figsize=(40, 10))
    fig.patch.set_visible(False)

    for position, img in enumerate(selected, start=1):
        ax = fig.add_subplot(1, len(selected), position)
        ax.imshow(img)
        ax.axis("off")

    fig.savefig(path_name,
                bbox_inches='tight',
                pad_inches=0)

    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

def build_encoder(img_shape):
    """Build the convolutional encoder that maps an image to a 512-d code.

    Three Conv2D -> BatchNormalization -> LeakyReLU -> MaxPooling2D stages
    (32, 128 and 256 filters) are followed by Flatten and a Dense layer with
    512 units and tanh activation.

    Args:
        img_shape: Shape of one input image, passed to ``Input``.

    Returns:
        An uncompiled Keras ``Model``.
    """
    base_filters = 32
    encoder_input = Input(shape=img_shape, name="input_encoder")

    features = encoder_input
    # Filter counts grow as 32, 128, 256; each stage ends with a max-pool.
    for multiplier in (1, 4, 8):
        features = Conv2D(base_filters * multiplier,
                          kernel_size=(3, 3),
                          padding="same")(features)
        features = BatchNormalization(epsilon=1e-5)(features)
        features = LeakyReLU(alpha=0.2)(features)
        features = MaxPooling2D(padding='same')(features)

    flat = Flatten()(features)
    code = Dense(512, activation="tanh")(flat)

    return Model(inputs=encoder_input, outputs=code)
  
def build_decoder_gen(en_shape=(512,)):
    """Build the decoder / generator that maps a latent code to an image.

    The code is projected by a Dense layer and reshaped to an 8x8x512 map,
    then passed through three Conv2D -> BatchNormalization -> LeakyReLU ->
    UpSampling2D stages (128, 64 and 32 filters). A final 3-filter
    convolution with tanh activation produces the image.

    Args:
        en_shape: Shape of one latent code, passed to ``Input``.

    Returns:
        An uncompiled Keras ``Model``.
    """
    # Generator setup
    latent_input = Input(shape=en_shape, name="input_decoder")
    seed_channels = 32 * 16
    base_filters = 32

    # Project the code and fold it into an 8x8 feature map.
    features = Dense(seed_channels * 8 * 8)(latent_input)
    features = Reshape((8, 8, seed_channels))(features)
    features = BatchNormalization(epsilon=1e-5)(features)
    features = LeakyReLU(alpha=0.2)(features)

    # Decoding: each stage ends with an upsampling layer.
    for filters in (base_filters * 4, base_filters * 2, base_filters):
        features = Conv2D(filters, kernel_size=(3, 3), padding="same")(features)
        features = BatchNormalization(epsilon=1e-5)(features)
        features = LeakyReLU(alpha=0.2)(features)
        features = UpSampling2D()(features)

    image = Conv2D(3, kernel_size=(3, 3), padding="same")(features)
    image = Activation('tanh')(image)

    return Model(inputs=latent_input, outputs=image)
  
def build_discriminator(img_shape):
    """Build the discriminator that scores an image with a single sigmoid.

    Four stride-2 convolutions (32, 64, 128 and 128 filters), each followed
    by LeakyReLU, feed a Flatten and a one-unit sigmoid Dense layer. Every
    convolution except the first is followed by BatchNormalization.

    Args:
        img_shape: Shape of one input image, passed to ``Input``.

    Returns:
        An uncompiled Keras ``Model``.
    """
    disc_input = Input(shape=img_shape, name="input_discriminator")
    base_filters = 32

    # First stage: no batch normalisation on the raw image features.
    features = Conv2D(base_filters, kernel_size=(3, 3),
                      padding='same', strides=2)(disc_input)
    features = LeakyReLU(alpha=0.2)(features)

    # Remaining stages share the same Conv -> BN -> LeakyReLU layout.
    for filters in (base_filters * 2, base_filters * 4, 128):
        features = Conv2D(filters, kernel_size=(3, 3),
                          padding='same', strides=2)(features)
        features = BatchNormalization(epsilon=1e-5)(features)
        features = LeakyReLU(alpha=0.2)(features)

    flat = Flatten()(features)
    score = Dense(1, activation='sigmoid')(flat)

    return Model(inputs=disc_input, outputs=score)
  
  
def train(encoder, decoder, discriminator, enc_dec, x_train,
          dir_result="./", epochs=10000, batch_size=512):
    """Run the alternating training loop and return the logged history.

    Every epoch shuffles ``x_train`` in place and walks it in consecutive
    slices of ``batch_size`` (the last slice may be shorter). For each slice:

    1. reconstructions (``decoder`` applied to ``encoder`` output) and
       noise-decoded images are predicted before any weight update;
    2. ``enc_dec`` is fitted to reproduce the batch;
    3. ``decoder`` alone is fitted to map fresh noise onto the batch;
    4. ``discriminator`` is fitted on the real batch and on the
       reconstructions with target 1, then on the noise-decoded images
       with target 0.

    Args:
        encoder: Model used to encode the real batch.
        decoder: Model used to decode codes and noise; also trained directly.
        discriminator: Model trained on real / reconstructed / fake images.
            Its ``train_on_batch`` result is indexed with ``[0]`` and ``[1]``.
        enc_dec: Autoencoder model trained with the batch as its own target.
        x_train: Training images; reordered in place by the shuffle.
        dir_result: Directory handed to ``plot_generated_images``.
        epochs: Number of epochs.
        batch_size: Slice length used to cut ``x_train``.

    Returns:
        List of dicts, one per checkpoint epoch, holding the values returned
        by the last batch of that epoch.
    """
    history = []

    # Log ten times per run for long trainings, every epoch otherwise.
    checkpoint = int(epochs / 10) if epochs >= 100 else 1

    for epoch in range(epochs):
        epoch_started = time.time()

        np.random.shuffle(x_train)

        for batch_start in range(0, len(x_train), batch_size):
            x_batch = x_train[batch_start:batch_start + batch_size]
            n_in_batch = len(x_batch)

            # Encoder and decoder applied to real images
            codes = encoder.predict(x_batch)
            reconstructed = decoder.predict(codes)

            # Images generated from noise
            noise = get_noise(n_sample=n_in_batch)
            generated = decoder.predict(noise)

            # Encoder-decoder training
            enc_dec_loss = enc_dec.train_on_batch(x_batch, x_batch)

            # Decoder-only training
            decoder_noise = get_noise(n_sample=n_in_batch)
            decoder_noise_loss = decoder.train_on_batch(decoder_noise, x_batch)

            target_real = np.ones((n_in_batch, 1))
            target_fake = np.zeros((n_in_batch, 1))

            # Discriminator training
            d_real = discriminator.train_on_batch(x_batch, target_real)
            d_recon = discriminator.train_on_batch(reconstructed, target_real)
            d_fake = discriminator.train_on_batch(generated, target_fake)

        is_checkpoint = epoch > 0 and epoch % checkpoint == 0
        if is_checkpoint:
            # Values come from the last batch processed in this epoch.
            record = {
                "Epoch:": epoch,
                "Discriminator_loss_true:": d_real[0],
                "Discriminator_loss_true_enc:": d_recon[0],
                "Discriminator_loss_fake": d_fake[0],
                "Discriminator_acc_fake": d_fake[1],
                "Encoder-Decoder:": enc_dec_loss,
                "Decoder_noise_loss:": decoder_noise_loss,
            }
            history.append(record)

            noise = get_noise(n_sample=1)
            plot_generated_images(noise, 1, dir_result, epoch)

        elapsed = time.time() - epoch_started
        print("Ended Epoch {:0.0f}/{:1.0f} in: {:2.4f} s".format(epoch + 1, epochs, elapsed))

    return history

def gan_loss(real_val, generated_val):
    """Return a Keras-style loss closure over two discriminator outputs.

    The returned function has the ``(y_true, y_pred)`` signature but ignores
    both arguments: it always evaluates
    ``log(real_val) + log(1 - generated_val)`` on the captured values.

    Args:
        real_val: Captured value whose log is the first term.
        generated_val: Captured value used in the ``log(1 - x)`` term.

    Returns:
        A function ``loss(y_true, y_pred)``.
    """
    def loss(y_true, y_pred):
        real_term = K.log(real_val)
        generated_term = K.log(1 - generated_val)
        return real_term + generated_term

    return loss

# ---------------------------------------------------------------------------
# Script section: build (or reload) the models, train them, then write the
# training log and the model files into the result directory.
# ---------------------------------------------------------------------------
start = time.time()

img_shape = (64, 64, 3)
noise_shape = (512,)
optimizer_std = SGD(0.0003)
# The discriminator gets its own, slightly larger learning rate.
optimizer_disc = SGD(0.0004)

pre_trained = False
path_load_models = "./drive/My Drive/GAN/train_2_img_celeba/models/"

if not pre_trained:

    encoder = build_encoder(img_shape)
    decoder = build_decoder_gen(noise_shape)
    discriminator = build_discriminator(img_shape)

    encoder.compile(loss="mse", optimizer=optimizer_std)
    decoder.compile(loss="mse", optimizer=optimizer_std)

    discriminator.compile(loss="mse",
                          optimizer=optimizer_disc,
                          metrics=["accuracy"])

    # Autoencoder: image -> encoder -> decoder -> image
    x_img = Input(shape=img_shape)
    enc_img = encoder(x_img)
    dec_img = decoder(enc_img)

    enc_dec = Model(inputs=[x_img], outputs=[dec_img])
    enc_dec.compile(loss="mse", optimizer=optimizer_std)

    # Wrapping first model (enc_dec) in the second (gan)
    x_1 = Input(shape=img_shape)
    x_2 = Input(shape=noise_shape)

    d_img = enc_dec([x_1])
    g_img = decoder([x_2])

    true_el = discriminator(x_1)
    # Score of the reconstruction; computed but not part of the gan outputs.
    true_gen = discriminator(d_img)
    fake_img = discriminator(g_img)

    gan = Model(inputs=[x_1, x_2], outputs=[true_el, fake_img])

    gan.compile(loss=gan_loss(true_el, fake_img), optimizer=optimizer_std)

    gan.summary()

else:
    # load_model is not among the names imported at the top of the file, so
    # bring it into scope here; without it this branch raises NameError.
    from tensorflow.keras.models import load_model

    encoder = load_model(path_load_models + "encoder.h5")
    decoder = load_model(path_load_models + "decoder.h5")
    discriminator = load_model(path_load_models + "discriminator.h5")
    enc_dec = load_model(path_load_models + "enc_dec.h5")
    # The gan is compiled with the closure returned by gan_loss, which cannot
    # be rebuilt from the saved file; load the architecture and weights only.
    gan = load_model(path_load_models + "gan_autoencoder.h5", compile=False)


x_train = get_dataset("./img_align_celeba", img_shape)

path = "./res_gan_autoencoder/"
# Make sure the output directory exists before anything is written to it.
os.makedirs(path, exist_ok=True)

history = train(encoder, decoder, discriminator, enc_dec, x_train,
                dir_result=path, epochs=101, batch_size=128)

with open(path + "log_gan.txt", "w") as f:
    f.writelines(str(item) + "\n" for item in history)

discriminator.trainable = True

# The ".h5" suffix already selects the HDF5 format; passing save_format="tf"
# alongside it is contradictory, so the format is left to the file name.
models_to_save = (
    ("encoder.h5", encoder),
    ("decoder.h5", decoder),
    ("discriminator.h5", discriminator),
    ("enc_dec.h5", enc_dec),
    ("gan_autoencoder.h5", gan),
)
for file_name, model in models_to_save:
    model.save(path + file_name)

end = time.time()
end = end - start
print("Finished in: " + str(end) + " s")

Model: "functional_165"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_59 (InputLayer)           [(None, 512)]        0                                            
__________________________________________________________________________________________________
input_58 (InputLayer)           [(None, 64, 64, 3)]  0                                            
__________________________________________________________________________________________________
functional_159 (Functional)     (None, 64, 64, 3)    17496003    input_59[0][0]                   
__________________________________________________________________________________________________
functional_161 (Functional)     (None, 1)            244161      input_58[0][0]                   
                                                                 functional_159[2][0]

In [None]:
!rm -r ./img_align_celeba

In [21]:
!rm -r ./res_gan_autoencoder