<a href="https://colab.research.google.com/github/T-Yamaguch/PatchWGAN/blob/master/PatchWGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab-only setup: mount Google Drive at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Concatenate, Conv2D, \
MaxPooling2D, Activation, ReLU, LeakyReLU, UpSampling2D, BatchNormalization, \
Dropout, Dense, Flatten, Add, LayerNormalization, GaussianNoise, Reshape, Lambda
from keras.regularizers import l2

class conv_block(Model):
  """Upsampling stage: Conv2D -> BatchNorm -> LeakyReLU -> 2x upsample -> Gaussian noise.

  Args:
    filter_num: number of output filters of the convolution.
    kernel_size: kernel size passed to Conv2D.
    kernel_regularizer: regularizer applied to the convolution kernel.
  """
  def __init__(self, filter_num, kernel_size, kernel_regularizer= l2(0.001)):
    super().__init__()
    # 'same' padding keeps the spatial size; only the upsampling changes it.
    self.conv = Conv2D(
        filter_num,
        kernel_size,
        padding='same',
        kernel_regularizer=kernel_regularizer,
    )
    self.norm = BatchNormalization(trainable=True)
    self.act = LeakyReLU()
    self.up = UpSampling2D((2,2))
    self.noise = GaussianNoise(0.2)

  def call(self, x):
    """Run the input through every sub-layer in order and return the result."""
    pipeline = (self.conv, self.norm, self.act, self.up, self.noise)
    for stage in pipeline:
      x = stage(x)
    return x

class res_block(Model):
  """Residual block: two Conv2D -> BatchNorm -> LeakyReLU stages added to the input.

  Args:
    filter_num: number of filters of both convolutions.
    kernel_size: kernel size passed to both Conv2D layers.
    kernel_regularizer: regularizer applied to both convolution kernels.
  """
  def __init__(self, filter_num, kernel_size, kernel_regularizer= l2(0.001)):
    super().__init__()
    conv_kwargs = dict(padding='same', kernel_regularizer=kernel_regularizer)
    self.conv1 = Conv2D(filter_num, kernel_size, **conv_kwargs)
    self.conv2 = Conv2D(filter_num, kernel_size, **conv_kwargs)
    self.norm1 = BatchNormalization(trainable=True)
    self.norm2 = BatchNormalization(trainable=True)
    self.act1 = LeakyReLU()
    self.act2 = LeakyReLU()
    self.add = Add()

  def call(self, x):
    """Return x plus the output of the two-stage convolutional branch."""
    branch = self.act1(self.norm1(self.conv1(x)))
    branch = self.act2(self.norm2(self.conv2(branch)))
    return self.add([x, branch])

class disc_block(Model):
  """Downsampling stage: Conv2D -> LeakyReLU -> 2x2 max-pool -> Dropout(0.3).

  A BatchNormalization layer is created but deliberately not applied in call().

  Args:
    filter_num: number of output filters of the convolution.
    kernel_size: kernel size passed to Conv2D.
    kernel_regularizer: regularizer applied to the convolution kernel.
  """
  def __init__(self, filter_num, kernel_size, kernel_regularizer= l2(0.001)):
    super().__init__()
    self.conv = Conv2D(
        filter_num,
        kernel_size,
        padding='same',
        kernel_regularizer=kernel_regularizer,
    )
    self.norm = BatchNormalization(trainable=True)
    self.act = LeakyReLU()
    self.pooling = MaxPooling2D((2,2), strides=(2,2))
    self.drop = Dropout(0.3)

  def call(self, x):
    """Apply convolution, activation, pooling and dropout in that order."""
    # self.norm is skipped on purpose: reportedly it is better not to put
    # normalization in the discriminator.
    features = self.act(self.conv(x))
    return self.drop(self.pooling(features))

class dense_block(Model):
  """Fully connected stage: Dense -> BatchNorm -> LeakyReLU.

  Args:
    filter_num: number of units of the Dense layer.
    kernel_regularizer: regularizer applied to the Dense kernel.
  """
  def __init__(self, filter_num, kernel_regularizer= l2(0.001)):
    super().__init__()
    self.dense = Dense(filter_num, kernel_regularizer=kernel_regularizer)
    self.norm = BatchNormalization(trainable=True)
    self.act = LeakyReLU()

  def call(self, x):
    """Return the normalized, activated Dense projection of x."""
    return self.act(self.norm(self.dense(x)))

class dense_block_wo_norm(Model):
  """Fully connected stage without normalization: Dense -> LeakyReLU.

  Args:
    filter_num: number of units of the Dense layer.
    kernel_regularizer: regularizer applied to the Dense kernel.
  """
  def __init__(self, filter_num, kernel_regularizer= l2(0.001)):
    super().__init__()
    self.dense = Dense(filter_num, kernel_regularizer=kernel_regularizer)
    self.act = LeakyReLU()

  def call(self, x):
    """Return the activated Dense projection of x."""
    projected = self.dense(x)
    return self.act(projected)




In [3]:
class Generator():
  """Builder for the generator network (latent vector -> 3-channel image).

  The hyper-parameters are fixed in __init__; model() assembles the Keras
  functional model from them. With the defaults the latent vector has 8
  entries and the 4x4 seed is upsampled `layer_num` times (4 -> 64 pixels).
  """
  def __init__(self):
    self.channel_num = 1024   # channels of the 4x4 seed feature map
    self.layer_num = 4        # number of upsampling conv_blocks
    self.res_num = 0          # res_blocks inserted before each conv_block and at the end
    self.latent_num = 8       # length of the input noise vector
    # shape must be a tuple: `(n)` is just the int n, `(n,)` is a 1-tuple.
    self.inputs = Input(shape=(self.latent_num,))
    self.kernel_size = (5, 5)
    self.name = 'generator'
    self.kernel_regularizer= None

  def model(self):
    """Assemble and return the generator as a `tensorflow.keras.Model`.

    Returns:
      Model mapping `self.inputs` to a sigmoid-activated 3-channel image.
    """
    x = self.inputs

    final_size = 4*4*self.channel_num
    data_size = self.latent_num

    # Widen the dense stack by a factor of 64 per layer while it stays below
    # the size needed for the 4x4 seed feature map.
    while data_size*64 < final_size:
      data_size *= 64
      x = dense_block(data_size, kernel_regularizer= self.kernel_regularizer)(x)

    x = dense_block(final_size, kernel_regularizer= self.kernel_regularizer)(x)
    x = Reshape((4, 4, self.channel_num))(x)

    filter_num = self.channel_num

    for _ in range(self.layer_num):
      for _ in range(self.res_num):
        x = res_block(filter_num, self.kernel_size, kernel_regularizer= self.kernel_regularizer)(x)
      # Integer division keeps the filter count an int (`/=` would make it a float).
      filter_num //= 2
      x = conv_block(filter_num, self.kernel_size, kernel_regularizer= self.kernel_regularizer)(x)

    for _ in range(self.res_num):
      x = res_block(filter_num, self.kernel_size, kernel_regularizer= self.kernel_regularizer)(x)

    # Project to 3 channels; sigmoid keeps pixel values in [0, 1].
    x = Conv2D(3, self.kernel_size, padding = 'same', kernel_regularizer= self.kernel_regularizer)(x)
    outputs = Activation('sigmoid')(x)
    return Model(inputs = self.inputs, outputs = outputs, name = self.name)

# Build the generator once and print its layer summary.
g = Generator()
g.model().summary()

Model: "generator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 8)]               0         
_________________________________________________________________
dense_block (dense_block)    (None, 512)               6656      
_________________________________________________________________
dense_block_1 (dense_block)  (None, 16384)             8470528   
_________________________________________________________________
reshape (Reshape)            (None, 4, 4, 1024)        0         
_________________________________________________________________
conv_block (conv_block)      (None, 8, 8, 512)         13109760  
_________________________________________________________________
conv_block_1 (conv_block)    (None, 16, 16, 256)       3278080   
_________________________________________________________________
conv_block_2 (conv_block)    (None, 32, 32, 128)       81

In [4]:
class Discriminator():
  """Builder for the critic network (image -> single unbounded score).

  model() produces the sum of two scalar heads: one computed from the
  intermediate feature map after `layer_num` downsampling blocks, and one
  computed after further downsampling to a 4x4 map.
  """
  def __init__(self):
    self.channel_num = 16     # filters of the first disc_block; doubled after each block
    self.layer_num = 3        # downsampling blocks shared by both heads
    self.latent_num = 8       # not used by model()
    self.input_shape = (64, 64, 3)
    self.inputs = Input(shape=self.input_shape)
    self.kernel_size = (4, 4)
    self.name = 'discriminator'
    self.kernel_regularizer= None

  def model(self):
    """Assemble and return the critic as a `tensorflow.keras.Model`.

    Returns:
      Model mapping `self.inputs` to a tensor of shape (batch, 1) with no
      final activation.
    """
    x = self.inputs

    filter_num = self.channel_num
    for _ in range(self.layer_num):
      x = disc_block(filter_num, self.kernel_size, kernel_regularizer= self.kernel_regularizer)(x)
      filter_num *= 2

    # Head 1: convolutions over the intermediate feature map, reduced to one score.
    y = Conv2D(filter_num, self.kernel_size, padding='same', kernel_regularizer= self.kernel_regularizer)(x)
    y = LeakyReLU()(y)
    y = Conv2D(1, self.kernel_size, padding='same', kernel_regularizer= self.kernel_regularizer)(y)
    y = Flatten()(y)
    y = Dense(1)(y)

    # Head 2: keep halving the feature map until it is 4x4 or smaller.
    # `>` instead of `!=` guarantees termination for input sizes whose
    # feature map never lands exactly on 4.
    while x.shape[1] > 4:
      x = disc_block(filter_num, self.kernel_size, kernel_regularizer= self.kernel_regularizer)(x)
      filter_num *= 2
    x = Flatten()(x)
    x = dense_block_wo_norm(64, kernel_regularizer= self.kernel_regularizer)(x)
    x = Dense(1)(x)

    outputs = Add()([x, y])

    return Model(inputs = self.inputs, outputs = outputs, name = self.name)

# Build the discriminator once and print its layer summary.
d = Discriminator()
d.model().summary()

Model: "discriminator"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 64, 64, 3)]  0                                            
__________________________________________________________________________________________________
disc_block (disc_block)         (None, 32, 32, 16)   784         input_2[0][0]                    
__________________________________________________________________________________________________
disc_block_1 (disc_block)       (None, 16, 16, 32)   8224        disc_block[0][0]                 
__________________________________________________________________________________________________
disc_block_2 (disc_block)       (None, 8, 8, 64)     32832       disc_block_1[0][0]               
______________________________________________________________________________________

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import os
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.losses import binary_crossentropy, MSE
import glob
import time
import random
import sys

class WGAN():
  """Wasserstein GAN trainer that uses weight clipping on the critic.

  Construction builds the generator and discriminator, collects the training
  image paths and restores the latest checkpoint if one exists. Calling the
  instance runs the training loop, periodically saving sample images, a loss
  plot and checkpoints under 'drive/My Drive/PatchWGAN/'.

  Args:
    img_size: side length the training images are resized to.
    code_num: length of the uniform noise vector fed to the generator.
    batch_size: number of images per batch.
    train_epochs: last epoch number (the loop runs start_epoch..train_epochs).
    train_steps: number of batches drawn per epoch.
    checkpoint_epochs: a checkpoint is saved every this many epochs.
    image_epochs: sample images and the loss plot are saved every this many epochs.
    start_epoch: first epoch number.
    optimizer: Keras optimizer used for both networks; when None an
      `Adam(learning_rate=1e-4)` is created for this instance.
    n_critics: the generator is updated on steps where `step % n_critics == 0`.
  """
  def __init__(self, 
               img_size=128, 
               code_num = 2048,
               batch_size = 16, 
               train_epochs = 100, 
               train_steps = 8, 
               checkpoint_epochs = 25, 
               image_epochs = 1, 
               start_epoch = 1,
               optimizer = None,
               n_critics = 8
               ):
    
    self.batch_size = batch_size
    self.train_epochs =  train_epochs
    self.train_steps = train_steps
    self.checkpoint_epochs = checkpoint_epochs
    self.image_epochs = image_epochs
    self.start_epoch = start_epoch
    self.code_num = code_num
    self.img_size = img_size
    self.n_critics = n_critics

    # Create the default per instance rather than once at definition time,
    # so separate WGAN objects never share optimizer state by accident.
    if optimizer is None:
      optimizer = Adam(learning_rate = 1e-4)

    # NOTE(review): generator and discriminator share ONE optimizer object
    # (and therefore its step counter). Separate instances may be preferable,
    # but confirm existing checkpoints still restore before changing this.
    self.gen_optimizer = optimizer
    self.disc_optimizer = optimizer

    g = Generator()
    self.gen = g.model()
    
    d = Discriminator()
    self.disc = d.model()

    checkpoint_dir = "drive/My Drive/PatchWGAN/checkpoint"
    checkpoint = tf.train.Checkpoint(gen_optimizer = self.gen_optimizer,
                                     disc_optimizer = self.disc_optimizer,
                                     gen = self.gen,
                                     disc = self.disc,
                                     )

    self.manager = tf.train.CheckpointManager(checkpoint, directory=checkpoint_dir, max_to_keep=2)

    train_image_path = 'drive/My Drive/samples/image'
    
    self.train_filenames = glob.glob(train_image_path + '/*.jpg') 

    # latest_checkpoint is None when nothing has been saved yet.
    checkpoint.restore(self.manager.latest_checkpoint)

    # Per-epoch mean losses, appended by update_loss_history().
    self.g_history = []
    self.d_history = []
    # Per-step losses of the current epoch; also reset at the start of __call__.
    self.d_temp = []
    self.g_temp = []

  def preprocess_image(self, image):
    """Decode JPEG bytes to a float32 RGB image of img_size x img_size in [0, 1]."""
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, [self.img_size, self.img_size] )
    image = image/255  # normalize to [0,1] range
    return tf.cast(image, tf.float32)

  def load_and_preprocess_image(self, path):
    """Read the file at `path` and return the preprocessed image tensor."""
    image = tf.io.read_file(path)
    return self.preprocess_image(image)

  def dataset(self, paths, batch_size):
    """Return a batched tf.data.Dataset of preprocessed images for `paths`."""
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    path_ds = tf.data.Dataset.from_tensor_slices(paths)
    img_ds = path_ds.map(self.load_and_preprocess_image, num_parallel_calls=AUTOTUNE)
    img_ds = img_ds.batch(batch_size)
    return img_ds

  def image_preparation(self, filenames, batch_size, steps):
    """Yield datasets of `steps` batches forever, reshuffling on every pass.

    `filenames` is shuffled in place. Paths left over at the end of a pass are
    carried into the next one, so every yielded dataset holds exactly
    steps*batch_size images.

    Raises:
      ValueError: if `filenames` is empty (raised on the first `next()`),
        since the loop below would otherwise spin forever without yielding.
    """
    if not filenames:
      raise ValueError('filenames is empty: no training images to load')

    img_batch = []
    while True:
      random.shuffle(filenames)
      for path in filenames:
        img_batch.append(path)
        if len(img_batch) == steps*batch_size:
          imgs = self.dataset(img_batch, batch_size)
          img_batch = []
          yield imgs

  def discriminator_loss(self, original_outputs, generated_outputs):
    """Binary cross-entropy critic loss (not used by the training loop of this class)."""
    real_loss = binary_crossentropy(tf.ones_like(original_outputs), original_outputs)
    generated_loss = binary_crossentropy(tf.zeros_like(generated_outputs), generated_outputs)
    loss_d = tf.math.reduce_mean(real_loss + generated_loss)
    return loss_d

  def generator_loss(self, generated_outputs):
    """Binary cross-entropy generator loss (not used by the training loop of this class)."""
    loss_g = tf.math.reduce_mean(binary_crossentropy(tf.ones_like(generated_outputs), generated_outputs))
    return loss_g

  def mse_loss(self, true, pred):
    """Mean squared error between `true` and `pred`, averaged to a scalar."""
    loss =  tf.math.reduce_mean(MSE(true, pred))
    return loss

  def wasserstein_loss(self, ori_outputs, gen_outputs):
    """Return (d_loss, g_loss) from the critic scores of real and generated images.

    d_loss = mean(gen_outputs) - mean(ori_outputs); g_loss = -mean(gen_outputs).
    """
    d_loss = -tf.reduce_mean(ori_outputs) + tf.reduce_mean(gen_outputs)
    g_loss = -tf.reduce_mean(gen_outputs)
    return d_loss, g_loss

  def g_train(self, imgs):
    """Run one generator update and record its loss in self.g_temp."""
    noise =tf.random.uniform([self.batch_size, self.code_num], minval=0, maxval=1, dtype=tf.dtypes.float32)

    with tf.GradientTape() as gen_tape:
      gen_imgs = self.gen(noise, training=True)

      # The critic is evaluated in inference mode; only the generator is updated.
      ori_outputs = self.disc(imgs, training=False)
      gen_outputs = self.disc(gen_imgs, training=False)

      _, g_loss = self.wasserstein_loss(ori_outputs, gen_outputs)
      self.g_temp.append(g_loss)

    gradients_of_gen = gen_tape.gradient(g_loss, self.gen.trainable_variables)
    self.gen_optimizer.apply_gradients(zip(gradients_of_gen, self.gen.trainable_variables))

  def d_train(self, imgs):
    """Run one critic update and record its loss in self.d_temp."""
    noise =tf.random.uniform([self.batch_size, self.code_num], minval=0, maxval=1, dtype=tf.dtypes.float32)

    with tf.GradientTape() as disc_tape:
      # The generator is evaluated in inference mode; only the critic is updated.
      gen_imgs = self.gen(noise, training=False)

      ori_outputs = self.disc(imgs, training=True)
      gen_outputs = self.disc(gen_imgs, training=True)
      
      d_loss, _ = self.wasserstein_loss(ori_outputs, gen_outputs)
      self.d_temp.append(d_loss)

    gradients_of_disc = disc_tape.gradient(d_loss, self.disc.trainable_variables)    
    self.disc_optimizer.apply_gradients(zip(gradients_of_disc, self.disc.trainable_variables))

  def visualise_batch(self, s_1, epoch):
    """Generate images from noise `s_1` and save them as a 4x6 grid PNG.

    The grid has 24 cells, so `s_1` is expected to hold at most 24 codes.
    """
    gen_img = self.gen(s_1)  
    gen_img = (np.array(gen_img*255, np.uint8))
    fig, axes = plt.subplots(4, 6)
    for idx, img in enumerate(gen_img):
      p, q = divmod(idx, 6)
      axes[p, q].imshow(img)
      axes[p, q].axis('off')
    
    save_name = 'drive/My Drive/PatchWGAN/generated_image/'+'image_at_epoch_{:04d}.png'
    plt.savefig(save_name.format(epoch), dpi=200)
    # plt.pause(0.1)
    plt.close('all')

  def loss_vis(self):
    """Plot the per-epoch generator (blue) and critic (red) losses to a PNG."""
    plt.plot(self.g_history, 'b', self.d_history, 'r')
    plt.title('blue: g  red: d')
    plt.savefig('drive/My Drive/PatchWGAN/loss/gan_loss.png')
    plt.close('all')

  def update_loss_history(self):
    """Append the mean of this epoch's step losses to the histories and reset them."""
    self.d_history.append(sum(self.d_temp)/len(self.d_temp))
    self.g_history.append(sum(self.g_temp)/len(self.g_temp))
    self.d_temp = []
    self.g_temp = []

  def _clip_disc_weights(self, clip_value=0.01):
    """Clip every discriminator variable to [-clip_value, clip_value] in place."""
    for w in self.disc.variables:
      w.assign(tf.clip_by_value(w, -clip_value, clip_value))

  def __call__(self):
    """Run the training loop from start_epoch through train_epochs."""
    # Fixed noise so the saved sample grids are comparable across epochs.
    sample_noise =tf.random.uniform([24, self.code_num], minval=0, maxval=1, dtype=tf.dtypes.float32)
    image_loader = self.image_preparation(self.train_filenames, self.batch_size, self.train_steps)
    self.d_temp = []
    self.g_temp = []
    # Clip once up front so the critic starts inside the clipping range.
    self._clip_disc_weights()

    for epoch in range(self.start_epoch, self.train_epochs+1):
      print ('\nepochs {}'.format(epoch))
      imgs_ds = next(image_loader)

      for steps, imgs in enumerate(imgs_ds):
        print("\r" + 'steps{}'.format(steps+1), end="")
        sys.stdout.flush()

        self.d_train(imgs)
        self._clip_disc_weights()

        if steps % self.n_critics == 0:
          self.g_train(imgs)
        
      self.update_loss_history()
                               
      if epoch % self.image_epochs == 0:
        self.visualise_batch(sample_noise, epoch)
        self.loss_vis()

      if epoch % self.checkpoint_epochs == 0:
        print ('\nSaving checkpoint at epoch{}\n\n'.format(epoch))
        self.manager.save()
      
if __name__ == '__main__':
  # Train on 64x64 images with an 8-dimensional latent code; the generator is
  # updated on every step (n_critics = 1).
  a = WGAN(img_size = 64,
           code_num = 8,
           batch_size = 256,
           train_epochs = 10000, 
           train_steps = 8, 
           checkpoint_epochs = 100, 
           image_epochs = 10, 
           start_epoch = 1,
           # `learning_rate` is the supported keyword; `lr` is deprecated.
           optimizer = RMSprop(learning_rate=5E-7),
           n_critics = 1
           )
  a()

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
epochs 3457
steps8
epochs 3458
steps8
epochs 3459
steps8
epochs 3460
steps8
epochs 3461
steps8
epochs 3462
steps8
epochs 3463
steps8
epochs 3464
steps8
epochs 3465
steps8
epochs 3466
steps8
epochs 3467
steps8
epochs 3468
steps8
epochs 3469
steps8
epochs 3470
steps8
epochs 3471
steps8
epochs 3472
steps8
epochs 3473
steps8
epochs 3474
steps8
epochs 3475
steps8
epochs 3476
steps8
epochs 3477
steps8
epochs 3478
steps8
epochs 3479
steps8
epochs 3480
steps8
epochs 3481
steps8
epochs 3482
steps8
epochs 3483
steps8
epochs 3484
steps8
epochs 3485
steps8
epochs 3486
steps8
epochs 3487
steps8
epochs 3488
steps8
epochs 3489
steps8
epochs 3490
steps8
epochs 3491
steps8
epochs 3492
steps8
epochs 3493
steps8
epochs 3494
steps8
epochs 3495
steps8
epochs 3496
steps8
epochs 3497
steps8
epochs 3498
steps8
epochs 3499
steps8
epochs 3500
steps8
Saving checkpoint at epoch3500



epochs 3501
steps8
epochs 3502
steps8
epochs 3503
steps8
epochs 3