In [None]:
import os
import random
import numpy as np
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Reshape, BatchNormalization, Activation, Conv2D, Conv2DTranspose, LeakyReLU, Flatten, Input, Concatenate, UpSampling2D
from tensorflow.keras.initializers import TruncatedNormal
from tensorflow.keras.optimizers import Adam
from tensorflow.data import Dataset
from IPython import display
from tqdm import tqdm
import time
import pandas as pd

In [None]:
# Ask TensorFlow to grow GPU memory on demand instead of reserving it all at
# start-up. `tf.config.gpu.set_per_process_memory_growth` is a pre-release
# TF 2.0 API; when it is absent, fall back to the per-device
# `tf.config.experimental.set_memory_growth` call so the cell still runs.
if hasattr(tf.config, 'gpu'):
    tf.config.gpu.set_per_process_memory_growth(True)
else:
    for gpu_device in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu_device, True)

In [None]:
# Sentinel handed to tf.data (`num_parallel_calls`, `prefetch`) so the runtime
# picks the degree of parallelism / buffer depth itself.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Geometry every training image is resized to: height, width, colour channels.
HEIGHT = 128
WIDTH = 128
CHANNEL = 3

# Number of images per training batch.
BATCH_SIZE = 64

In [None]:
class AnimeDRAGAN:
    """GAN with a DRAGAN-style gradient penalty producing 128x128 RGB images.

    Creating an instance builds the generator and discriminator Keras models,
    one Adam optimizer for each, the tf.data input pipeline over the
    ``data_huge`` directory, and the output directories for weights and
    sample images. Call :meth:`train` to run the optimisation loop.
    """

    def __init__(self):
        """Build the networks, optimizers, dataset iterator and output dirs."""
        self.noise_len = 100
        self.noise_dim = (self.noise_len, )
        self.in_dim = (128, 128, 3)
        self.gen = self.get_gen()
        self.dis = self.get_dis()
        self.gen_opt = Adam(2e-4, .5, .9)
        self.dis_opt = Adam(2e-4, .5, .9)
        # NOTE(review): these two counters are stored but train_step never
        # reads them; both networks are updated once per batch.
        self.dis_iters = 5
        self.gen_iters = 1
        # Use this class's own loader (the original referenced an undefined
        # `WGAN` name here, which raised NameError on construction).
        self.ds, self.ds_len = self.process_data()
        self.dsi = iter(self.ds)
        self.epochs = int(6e4)
        # The dataset repeats before batching, so a full batch is always
        # available; run at least one batch per epoch even for tiny datasets.
        self.num_batches = max(1, self.ds_len // BATCH_SIZE)
        self.weight_dir = 'weights/'
        self.weight_pref = 'animedragan'
        os.makedirs(self.weight_dir, exist_ok=True)
        self.img_dir = 'animedragan_gen_imgs/'
        os.makedirs(self.img_dir, exist_ok=True)
        # One (generator_loss, discriminator_loss) pair of floats per epoch.
        self.loss = []

    def gen_conv_block(self, x):
        """Generator residual block: conv-BN-ReLU-conv-BN plus identity skip.

        Both convolutions use 64 filters of size 3x3 with 'same' padding, so
        ``x`` needs 64 channels for the skip addition to be shape-compatible.
        """
        m = Conv2D(64, (3, 3), (1, 1), 'same', kernel_initializer=TruncatedNormal(0, .02))(x)
        m = BatchNormalization(-1, .9, 1e-5)(m)
        m = Activation('relu')(m)
        m = Conv2D(64, (3, 3), (1, 1), 'same', kernel_initializer=TruncatedNormal(0, .02))(m)
        m = BatchNormalization(-1, .9, 1e-5)(m)
        return m + x

    def dis_conv_block(self, x, n):
        """Discriminator residual block: conv-LeakyReLU-conv, skip, LeakyReLU.

        ``n`` is the filter count of both 3x3 convolutions; it should match
        the channel count of ``x`` so the skip addition is shape-compatible.
        """
        m = Conv2D(n, (3, 3), (1, 1), 'same', kernel_initializer=TruncatedNormal(0, .02))(x)
        m = LeakyReLU(.2)(m)
        m = Conv2D(n, (3, 3), (1, 1), 'same', kernel_initializer=TruncatedNormal(0, .02))(m)
        m = m + x
        return LeakyReLU(.2)(m)

    def upsampling(self, x):
        """Double the spatial size: 3x3 conv (256 filters), 2x upsample, BN, ReLU."""
        m = Conv2D(256, (3, 3), (1, 1), 'same', kernel_initializer=TruncatedNormal(0, .02))(x)
        m = UpSampling2D(2)(m)
        m = BatchNormalization(-1, .9, 1e-5)(m)
        return Activation('relu')(m)

    def get_gen(self):
        """Return the generator model mapping a noise vector to an image.

        The noise is projected to a 16x16x64 feature map, passed through 16
        residual blocks with a long skip connection, upsampled three times
        (16 -> 128 pixels per side) and reduced to 3 channels by a 9x9
        convolution with a sigmoid activation.
        """
        inp = Input(shape=self.noise_dim)
        m = Dense(16384, kernel_initializer=TruncatedNormal(0, .02))(inp)
        m = Reshape((16, 16, 64))(m)
        m = BatchNormalization(-1, .9, 1e-5)(m)
        m = Activation('relu')(m)
        x = m  # kept for the long skip connection around the residual stack
        for _ in range(16):
            m = self.gen_conv_block(m)
        m = BatchNormalization(-1, .9, 1e-5)(m)
        m = Activation('relu')(m)
        m = m + x
        for _ in range(3):
            m = self.upsampling(m)
        out = Conv2D(3, (9, 9), (1, 1), 'same', activation='sigmoid', kernel_initializer=TruncatedNormal(0, .02))(m)
        return Model(inputs=inp, outputs=out)

    def get_dis(self):
        """Return the discriminator model mapping an image to a single logit.

        Four stride-2 5x5 convolutions (64, 128, 256, 512 filters) shrink the
        input; a residual block follows each of the first three. The final
        Dense layer has no activation because the losses in train_step
        consume raw logits.
        """
        inp = Input(shape=self.in_dim)
        m = Conv2D(64, (5, 5), (2, 2), 'same', kernel_initializer=TruncatedNormal(0, .02))(inp)
        m = LeakyReLU(.2)(m)
        m = self.dis_conv_block(m, 64)
        m = Conv2D(128, (5, 5), (2, 2), 'same', kernel_initializer=TruncatedNormal(0, .02))(m)
        m = LeakyReLU(.2)(m)
        m = self.dis_conv_block(m, 128)
        m = Conv2D(256, (5, 5), (2, 2), 'same', kernel_initializer=TruncatedNormal(0, .02))(m)
        m = LeakyReLU(.2)(m)
        m = self.dis_conv_block(m, 256)
        m = Conv2D(512, (5, 5), (2, 2), 'same', kernel_initializer=TruncatedNormal(0, .02))(m)
        m = LeakyReLU(.2)(m)
        m = Flatten()(m)
        out = Dense(1, kernel_initializer=TruncatedNormal(0, .02))(m)
        return Model(inputs=inp, outputs=out)

    def train(self):
        """Run ``self.epochs`` epochs of training with periodic checkpoints.

        After every epoch a grid of 64 generated samples is shown. Every 10th
        epoch the weights of both networks are saved and the loss history is
        written to ``train_loss.png``.
        """
        seed = None  # stays None only when self.epochs == 0
        for epoch in range(self.epochs):
            start = time.time()
            dis_loss, gen_loss = self.train_step()
            # Store plain floats so the history does not hold on to tensors
            # and can be plotted directly.
            self.loss.append((float(gen_loss), float(dis_loss)))
            display.clear_output(wait=True)
            seed = tf.random.normal([64, self.noise_len])
            self.gen_save(epoch + 1, seed)
            print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
            print('Epoch: %d, Dis. Loss: %.4f, Gen. Loss: %.4f' % (epoch+1, dis_loss, gen_loss))
            if (epoch + 1) % 10 == 0:
                self.gen.save_weights(self.weight_dir + self.weight_pref + '_gen.h5')
                self.dis.save_weights(self.weight_dir + self.weight_pref + '_dis.h5')
                # Draw on a dedicated figure so the curve cannot land on
                # leftover sample-grid axes, and close it afterwards.
                fig, ax = plt.subplots()
                ax.plot(self.loss)
                ax.set_title('Training Loss for Anime DRAGAN')
                ax.set_ylabel('Training Loss')
                ax.set_xlabel('Epoch')
                # 'upper left' (with a space) is the valid matplotlib location.
                ax.legend(['Generator', 'Discriminator'], loc='upper left')
                fig.savefig('train_loss.png', bbox_inches='tight')
                plt.close(fig)
        display.clear_output(wait=True)
        if seed is not None:
            self.gen_save(self.epochs, seed)

    def train_step(self):
        """Train both networks on ``self.num_batches`` batches.

        Returns:
            (dis_loss, gen_loss) tensors computed on the last batch.
        """
        for _ in range(self.num_batches):
            noise = tf.random.normal([BATCH_SIZE, self.noise_len])
            imgs = next(self.dsi)
            # Perturb the real images with noise scaled by the per-position
            # standard deviation across the batch.
            std = tf.sqrt(tf.nn.moments(imgs, [0])[1])
            U = tf.random.uniform(self.in_dim, 0, .5)
            imgs_p = imgs + std * U
            # Random per-sample interpolation between real and perturbed images.
            a = tf.random.uniform([BATCH_SIZE, 1, 1, 1], 0, 1)
            imgs_h = a * imgs + (1 - a) * imgs_p
            with tf.GradientTape() as gen_tape, tf.GradientTape() as dis_tape:
                real_out = self.dis(imgs, training=True)
                gen_imgs = self.gen(noise, training=True)
                fake_out = self.dis(gen_imgs, training=True)
                # The inner gradient is taken while dis_tape is recording so
                # the penalty itself is differentiable w.r.t. the
                # discriminator weights.
                with tf.GradientTape() as inner_tape:
                    inner_tape.watch(imgs_h)
                    realh_out = self.dis(imgs_h, training=True)
                dish_grad = inner_tape.gradient(realh_out, [imgs_h])[0]
                dis_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                                labels=tf.ones_like(real_out),
                                logits=real_out
                            ))
                dis_loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                                labels=tf.zeros_like(fake_out),
                                logits=fake_out
                            ))
                # Gradient penalty: push the per-sample gradient norm to 1.
                slope = tf.sqrt(tf.reduce_sum(tf.square(dish_grad), axis=[1, 2, 3]))
                dis_loss += 10 * tf.reduce_mean((slope-1.)**2)
                gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                                labels=tf.ones_like(fake_out),
                                logits=fake_out
                            ))
            gen_grad = gen_tape.gradient(gen_loss, self.gen.trainable_variables)
            dis_grad = dis_tape.gradient(dis_loss, self.dis.trainable_variables)
            self.gen_opt.apply_gradients(zip(gen_grad, self.gen.trainable_variables))
            self.dis_opt.apply_gradients(zip(dis_grad, self.dis.trainable_variables))
        return dis_loss, gen_loss

    def gen_save(self, epoch, test_inp):
        """Show generated samples for ``test_inp``; save them every 50th epoch.

        Args:
            epoch: 1-based epoch number used in the saved file name.
            test_inp: batch of noise vectors; the 8x8 grid holds up to 64.
        """
        pred = self.gen(test_inp, training=False)
        fig = plt.figure(figsize=(8,8))
        for i in range(pred.shape[0]):
            plt.subplot(8, 8, i+1)
            plt.imshow(pred[i])
            plt.axis('off')
        if epoch % 50 == 0:
            plt.savefig(self.img_dir + 'img_at_ep_{:06d}.png'.format(epoch), bbox_inches='tight')
        plt.show()
        # This runs once per epoch; release the figure so they do not pile up.
        plt.close(fig)

    @staticmethod
    def process_data():
        """Build the shuffled, repeating, batched image dataset.

        Every entry of ``<cwd>/data_huge`` is treated as a JPEG file. Images
        are randomly flipped and brightness/contrast-jittered, resized to
        HEIGHT x WIDTH and scaled to [0, 1].

        Returns:
            (dataset, num_images): the tf.data pipeline yielding batches of
            BATCH_SIZE images, and the number of files found.

        Raises:
            ValueError: if the directory contains no entries.
        """
        data_dir = os.path.join(os.getcwd(), 'data_huge')
        images = [os.path.join(data_dir, each) for each in os.listdir(data_dir)]
        if not images:
            # Fail early with a clear message instead of an opaque tf.data error.
            raise ValueError('No images found in ' + data_dir)
        image_paths = Dataset.from_tensor_slices(images)

        def load_preprocess(path):
            """Read, augment, resize and normalise a single image file."""
            content = tf.io.read_file(path)
            image = tf.image.decode_jpeg(content, channels=CHANNEL)
            image = tf.image.random_flip_left_right(image)
            image = tf.image.random_brightness(image, max_delta=.1)
            image = tf.image.random_contrast(image, lower=.9, upper=1.1)
            size = [HEIGHT, WIDTH]
            image = tf.image.resize(image, size)
            image.set_shape([HEIGHT,WIDTH,CHANNEL])
            image = tf.cast(image, tf.float32)
            image /= 255.
            return image

        all_images = image_paths.map(load_preprocess, num_parallel_calls=AUTOTUNE)
        num_images = len(images)
        all_images = all_images.shuffle(buffer_size=num_images)
        # Repeat before batching so every batch is full.
        all_images = all_images.repeat()
        all_images = all_images.batch(BATCH_SIZE)
        all_images = all_images.prefetch(buffer_size=AUTOTUNE)
        return all_images, num_images

In [None]:
# Build the model: constructs both networks, the data pipeline over
# ./data_huge and the output directories for weights and sample images.
agan = AnimeDRAGAN()

In [None]:
# Print the layer-by-layer summary of the generator network.
agan.gen.summary()

In [None]:
# Print the layer-by-layer summary of the discriminator network.
agan.dis.summary()

In [None]:
# Run the training loop for agan.epochs epochs; weights and the loss plot are
# written every 10th epoch.
agan.train()