In [None]:
import os
import random
import numpy as np
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Reshape, BatchNormalization, Activation, Conv2D, Conv2DTranspose, LeakyReLU, Flatten, Input, Concatenate
from tensorflow.keras.initializers import TruncatedNormal
from tensorflow.keras.optimizers import RMSprop
from tensorflow.data import Dataset
from IPython import display
from tqdm import tqdm
import time
import pandas as pd

In [None]:
# Ask TensorFlow to grow GPU memory on demand instead of reserving it all up
# front. `tf.config.gpu` only exists in the early TF 2.0 pre-release this
# notebook was written against; later releases expose a per-device switch.
if hasattr(tf.config, 'gpu'):
    tf.config.gpu.set_per_process_memory_growth(True)
else:
    for _gpu in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(_gpu, True)

# Let tf.data pick parallelism / prefetch depth at runtime.
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Size every training image is resized to, and its channel count.
HEIGHT, WIDTH, CHANNEL = 128, 128, 3
# Images per batch; also the number of conditions sampled per generator step.
BATCH_SIZE = 64

In [None]:
# Types the generator can be conditioned on; the list order fixes each type's
# slot in the condition vector.
types = ['grass', 'fire', 'water', 'electric', 'psychic', 'dragon', 'flying']

types_idx = {typ: idx for idx, typ in enumerate(types)}
n = len(types)


def make_cond(r):
    """Build the multi-hot condition vector for one row of the type table.

    Args:
        r: a row exposing ``'type1'`` and ``'type2'`` by label.

    Returns:
        A length-``n`` float array with 1 in the slot of every listed type
        found in ``type1`` / ``type2`` and 0 elsewhere. Values that are not
        in ``types`` (including a missing second type) leave the vector
        untouched.
    """
    z = np.zeros(n)
    # Label access: integer keys on a string-labelled row rely on a
    # positional fallback that pandas has deprecated.
    for typ in (r['type1'], r['type2']):
        if typ in types:
            z[types_idx[typ]] = 1
    return z


df = pd.read_csv('pokemon.csv')[['pokedex_number', 'type1', 'type2']]
# Keep only rows with at least one supported type. The explicit copy makes the
# column assignment below write to an independent frame instead of a slice.
df = df[df['type1'].isin(types) | df['type2'].isin(types)].copy()
df['condition'] = df.apply(make_cond, axis=1)
df = df.set_index('pokedex_number')
# Lookup table: pokedex number -> condition vector (single 'condition' column).
COND = df.drop(['type1', 'type2'], axis=1)

In [None]:
class CGAN:
    """Conditional GAN trained with a Wasserstein-style critic loss.

    The generator maps (noise, type condition) to a 128x128x3 image in [0, 1];
    the discriminator maps (image, type condition) to one unbounded score.
    Relies on the notebook-level ``types``, ``COND``, ``BATCH_SIZE``,
    ``HEIGHT``, ``WIDTH``, ``CHANNEL`` and ``AUTOTUNE`` from earlier cells.
    """

    # The preview figure is an 8x8 grid, so at most 64 samples can be drawn.
    _PREVIEW_MAX = 64

    def __init__(self):
        """Build both networks, their optimizers and the input pipeline."""
        self.noise_len = 100
        self.noise_dim = (self.noise_len, )
        self.in_dim = (128, 128, 3)
        self.cond_len = len(types)
        self.cond_dim = (self.cond_len, )
        self.gen = self.get_gen()
        self.dis = self.get_dis()
        # `learning_rate` is the supported keyword; `lr` is a deprecated alias.
        self.gen_opt = RMSprop(learning_rate=2e-4)
        # NOTE(review): `clipvalue` clips gradients element-wise, not the
        # weights. If WGAN-style weight clipping was intended, this does not
        # provide it -- TODO confirm.
        self.dis_opt = RMSprop(learning_rate=2e-4, clipvalue=.01)
        self.dis_iters = 5
        self.gen_iters = 1
        self.ds, self.ds_len = self.process_data()
        self.dsi = iter(self.ds)
        self.epochs = int(6e4)
        # At least one batch per epoch so train_step always has losses to return.
        self.num_batches = max(1, self.ds_len // BATCH_SIZE)
        # One (gen_loss, dis_loss) pair of floats per finished epoch.
        self.loss = []

    @staticmethod
    def _kernel_init():
        """Return a fresh TruncatedNormal(mean=0, stddev=0.02) initializer."""
        return TruncatedNormal(mean=0., stddev=.02)

    @staticmethod
    def _batch_norm():
        """Return the BatchNormalization layer configuration used by both nets."""
        return BatchNormalization(axis=-1, momentum=.9, epsilon=1e-5)

    def get_gen(self):
        """Build the generator.

        Noise and condition are concatenated, projected to a 4x4x512 map and
        upsampled by five stride-2 transposed convolutions to 128x128x3 with a
        sigmoid output.

        Returns:
            A Keras ``Model`` taking ``[noise, condition]``.
        """
        inp = Input(shape=self.noise_dim)
        cond = Input(shape=self.cond_dim)
        m = Concatenate()([inp, cond])
        m = Dense(8192, kernel_initializer=self._kernel_init())(m)
        m = Reshape((4, 4, 512))(m)
        m = self._batch_norm()(m)
        m = Activation('relu')(m)
        for filters in (256, 128, 64, 32):
            m = Conv2DTranspose(filters, (5, 5), strides=(2, 2), padding='same',
                                kernel_initializer=self._kernel_init())(m)
            m = self._batch_norm()(m)
            m = Activation('relu')(m)
        out = Conv2DTranspose(3, (5, 5), strides=(2, 2), padding='same',
                              activation='sigmoid',
                              kernel_initializer=self._kernel_init())(m)
        return Model(inputs=[inp, cond], outputs=out)

    def get_dis(self):
        """Build the discriminator (critic).

        Four stride-2 convolutions reduce the image, the flattened features
        are concatenated with the condition and mapped through a 512-unit
        ReLU layer to a single linear output.

        Returns:
            A Keras ``Model`` taking ``[image, condition]``.
        """
        inp = Input(shape=self.in_dim)
        cond = Input(shape=self.cond_dim)
        # The first convolution is not followed by batch normalization.
        x = Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                   kernel_initializer=self._kernel_init())(inp)
        x = LeakyReLU(.2)(x)
        for filters in (128, 256, 512):
            x = Conv2D(filters, (5, 5), strides=(2, 2), padding='same',
                       kernel_initializer=self._kernel_init())(x)
            x = self._batch_norm()(x)
            x = LeakyReLU(.2)(x)
        x = Flatten()(x)
        m = Concatenate()([x, cond])
        m = Dense(512, activation='relu')(m)
        out = Dense(1, kernel_initializer=self._kernel_init())(m)
        return Model(inputs=[inp, cond], outputs=out)

    def _preview_seed(self):
        """Sample a ``[noise, conditions]`` pair for the preview grid."""
        # get_conds() yields BATCH_SIZE rows; keep noise and conditions the
        # same length and within the 8x8 grid.
        count = min(self._PREVIEW_MAX, BATCH_SIZE)
        noise = tf.random.uniform([count, self.noise_len], -1., 1.)
        return [noise, self.get_conds()[:count]]

    def _save_loss_plot(self):
        """Write the loss history collected so far to ``train_loss.png``."""
        fig, ax = plt.subplots()
        ax.plot(self.loss)
        ax.set_title('Training Loss for Conditional WGAN')
        ax.set_ylabel('Training Loss')
        ax.set_xlabel('Epoch')
        ax.legend(['Generator', 'Discriminator'], loc='upper left')
        fig.savefig('train_loss.png', bbox_inches='tight')
        plt.close(fig)

    def train(self):
        """Run ``self.epochs`` epochs, previewing samples after each one.

        Both models are saved to ``./weights`` every 2nd epoch and the loss
        curve is refreshed every 100th epoch.
        """
        os.makedirs('./weights', exist_ok=True)
        seed = None
        for epoch in range(self.epochs):
            start = time.time()
            dis_loss, gen_loss = self.train_step()
            # Store plain floats so the history is directly plottable.
            dis_loss, gen_loss = float(dis_loss), float(gen_loss)
            self.loss.append((gen_loss, dis_loss))
            display.clear_output(wait=True)
            seed = self._preview_seed()
            self.gen_save(epoch + 1, seed)
            print('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
            print('Epoch: %d, Dis. Loss: %.4f, Gen. Loss: %.4f' % (epoch+1, dis_loss, gen_loss))
            if (epoch + 1) % 2 == 0:
                self.gen.save('./weights/gen.h5')
                self.dis.save('./weights/dis.h5')
            if (epoch + 1) % 100 == 0:
                self._save_loss_plot()
        display.clear_output(wait=True)
        if seed is None:
            # No epoch ran, so there is no seed to reuse.
            seed = self._preview_seed()
        self.gen_save(self.epochs, seed)

    def train_step(self):
        """Train for one epoch of ``self.num_batches`` batches.

        Each batch runs ``dis_iters`` discriminator updates on fresh real
        batches, then ``gen_iters`` generator updates on sampled conditions.
        Runs eagerly: batches come from ``next()`` on a Python iterator and
        conditions from Python-level randomness, which ``tf.function`` would
        only execute while tracing.

        Returns:
            ``(dis_loss, gen_loss)`` from the last update of the last batch.
        """
        for _ in range(self.num_batches):
            # train discriminator
            noise = tf.random.uniform([BATCH_SIZE, self.noise_len], -1., 1.)
            for _ in range(self.dis_iters):
                imgs, conds = next(self.dsi)
                # Generated outside the tape: only the discriminator is updated here.
                gen_imgs = self.gen([noise, conds], training=True)
                with tf.GradientTape() as dis_t:
                    real_out = self.dis([imgs, conds], training=True)
                    fake_out = self.dis([gen_imgs, conds], training=True)
                    dis_loss = tf.reduce_mean(fake_out) - tf.reduce_mean(real_out)
                dis_grad = dis_t.gradient(dis_loss, self.dis.trainable_variables)
                self.dis_opt.apply_gradients(zip(dis_grad, self.dis.trainable_variables))
            # train generator
            for _ in range(self.gen_iters):
                conds = self.get_conds()
                with tf.GradientTape() as gen_t:
                    gen_imgs = self.gen([noise, conds], training=True)
                    fake_out = self.dis([gen_imgs, conds], training=True)
                    gen_loss = -tf.reduce_mean(fake_out)
                gen_grad = gen_t.gradient(gen_loss, self.gen.trainable_variables)
                self.gen_opt.apply_gradients(zip(gen_grad, self.gen.trainable_variables))
        return dis_loss, gen_loss

    def get_conds(self):
        """Sample a batch of random multi-hot type conditions.

        Each call draws either two type indices per row (when
        ``random.random() > 0.4``) or one; two equal indices collapse into a
        single active type. Deliberately not wrapped in ``tf.function``: the
        Python-level choice would be evaluated once at trace time and then
        stay fixed.

        Returns:
            A ``(BATCH_SIZE, cond_len)`` tensor of 0/1 values.
        """
        nr = 2 if random.random() > 0.4 else 1
        idx = tf.random.uniform([BATCH_SIZE, nr], 0, self.cond_len, dtype=tf.int32)
        oh = tf.one_hot(idx, self.cond_len)
        return tf.clip_by_value(tf.reduce_sum(oh, 1), 0, 1)

    def gen_save(self, epoch, test_inp):
        """Show a grid of generated images and save it every 50th epoch.

        Args:
            epoch: 1-based epoch number, used for the save schedule and the
                output file name.
            test_inp: ``[noise, conditions]`` fed to the generator; at most 64
                rows fit in the 8x8 grid.
        """
        pred = self.gen(test_inp, training=False)
        fig = plt.figure(figsize=(8, 8))
        for i in range(pred.shape[0]):
            ax = fig.add_subplot(8, 8, i + 1)
            ax.imshow(pred[i])
            ax.axis('off')
        if epoch % 50 == 0:
            os.makedirs('gen_imgs', exist_ok=True)
            fig.savefig('gen_imgs/img_at_ep_{:06d}.png'.format(epoch), bbox_inches='tight')
        plt.show()
        # One figure is created per call; close it so they do not pile up.
        plt.close(fig)

    @staticmethod
    def _pokedex_id(filename):
        """Return the integer stem of a ``<number>.jpg`` name, else ``None``."""
        stem, ext = os.path.splitext(filename)
        if ext != '.jpg':
            return None
        try:
            return int(stem)
        except ValueError:
            # A .jpg whose name is not a number cannot be matched to COND.
            return None

    @staticmethod
    def process_data():
        """Build the training dataset from ``./data_big``.

        Uses every ``<pokedex_number>.jpg`` whose number is in ``COND``. Each
        image is randomly flipped and brightness/contrast jittered, resized
        to HEIGHT x WIDTH and scaled to [0, 1], then paired with its
        condition vector.

        Returns:
            ``(dataset, num_images)``: a shuffled, endlessly repeating dataset
            of ``(images, conditions)`` batches of size ``BATCH_SIZE``, and
            the number of distinct images behind it.

        Raises:
            ValueError: if no usable image is found.
        """
        pokemon_dir = os.path.join(os.getcwd(), 'data_big')
        images, conds = [], []
        for each in os.listdir(pokemon_dir):
            number = CGAN._pokedex_id(each)
            if number is None or number not in COND.index:
                continue
            images.append(os.path.join(pokemon_dir, each))
            conds.append(tf.convert_to_tensor(COND.loc[number, 'condition'], dtype=tf.float32))
        num_images = len(images)
        if num_images == 0:
            raise ValueError('no usable .jpg images found in {}'.format(pokemon_dir))

        def load_preprocess(path, cond):
            content = tf.io.read_file(path)
            image = tf.image.decode_jpeg(content, channels=CHANNEL)
            image = tf.image.random_flip_left_right(image)
            image = tf.image.random_brightness(image, max_delta=.1)
            image = tf.image.random_contrast(image, lower=.9, upper=1.1)
            image = tf.image.resize(image, [HEIGHT, WIDTH])
            image.set_shape([HEIGHT, WIDTH, CHANNEL])
            image = tf.cast(image, tf.float32)
            image /= 255.
            return image, cond

        # repeat() precedes batch(), so every batch holds exactly BATCH_SIZE items.
        all_images = (
            Dataset.from_tensor_slices((images, conds))
            .map(load_preprocess, num_parallel_calls=AUTOTUNE)
            .shuffle(buffer_size=num_images)
            .repeat()
            .batch(BATCH_SIZE)
            .prefetch(buffer_size=AUTOTUNE)
        )
        return all_images, num_images

In [None]:
# Instantiate the model: builds generator, discriminator, optimizers and the image dataset.
cgan = CGAN()

In [None]:
# Run the training loop for cgan.epochs epochs (previews samples and saves weights as it goes).
cgan.train()