# PokéGAN
This work is based on *forcecore*'s work, available at https://github.com/forcecore/Keras-GAN-Animeface-Character

In [1]:
# Imports

## Data import related imports
import os
import glob
import h5py
import numpy as np
import scipy.misc
import random
import cv2

import tensorflow as tf
tf.logging.set_verbosity(tf.logging.ERROR)

## Keras packages
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input, InputSpec, Layer, Dense, Activation, Flatten, Reshape, Dropout, BatchNormalization, Conv2D, Conv2DTranspose, UpSampling2D, LeakyReLU
from tensorflow.keras import initializers, regularizers, constraints
from tensorflow.keras import models
from tensorflow.keras.optimizers import Adam, Adagrad, Adadelta, Adamax, SGD
from tensorflow.keras.callbacks import CSVLogger

  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)


In [2]:
# Hyperparameters and other important variables

# Dataset size: a positive number samples that many files (with replacement)
# from the matched files; zero or a negative number uses every matched file.
dataset_sz = -1

# Directory where the periodic generator dumps are archived for animating later.
anim_dir = "anim"

# Side length of the square images we work on; image tensors are (sz, sz, 4).
sz = 64

# Negative slope (alpha) of the LeakyReLU activations in the D and G networks.
alpha_D = 0.2
alpha_G = 0.2

# Batch size used during training.
batch_sz = 64

# Shape of the noise tensor the images are generated from
# (latent space z): 1x1 spatial extent with 100 channels.
noise_shape = (1, 1, 100)

# GAN training can be ruined at any moment if not careful.
# Weight snapshots are archived in this directory.
snapshot_dir = "./snapshots"

# Dropout probability.
# NOTE(review): not referenced by any code visible in this notebook.
dropout = 0.3

# Magnitude of the noise applied to the training labels (soft labels) and
# to the binary latent noise.
label_noise = 0.1

# Number of past (reals, fakes) batches to keep. Slower training but higher quality.
history_sz = 12

# File names used when saving the generator / discriminator weights.
genw = "gen.hdf5"
discw = "disc.hdf5"

# Weight initialization function.
#kernel_initializer = 'Orthogonal'
#kernel_initializer = 'RandomNormal'
# Same as the default in Keras, but good for GANs, according to
# https://github.com/gheinrich/DIGITS-GAN/blob/master/examples/weight-init/README.md#experiments-with-lenet-on-mnist
kernel_initializer = 'glorot_uniform'

# beta_1 of the Adam optimizers; the DCGAN paper suggests 0.5.
adam_beta = 0.5

# Momentum passed to every BatchNormalization layer.
bn_momentum = 0.3

In [3]:
# Functions used to generate h5 file containing data

def normalize4gan(im):
    '''
    Rescale pixel values for GAN training, as recommended in ganhacks.

    The input is copied to float32, divided by 128 and shifted down by 1,
    so 8-bit values 0..255 end up in [-1, 0.9921875]. The argument itself
    is left untouched.
    '''
    as_float = im.astype(np.float32)
    return as_float / 128.0 - 1.0

def denormalize4gan(im):
    '''
    Map values from [-1, 1] back to 8-bit pixel values.

    The array is shifted to [0, 2] and scaled by 127, i.e. the result lies
    in [0, 254], and is returned as a new uint8 array.
    Warning: the shift and the scaling are applied to ``im`` in place, so the
    caller's array is modified!
    '''
    np.add(im, 1.0, out=im)        # now in [0, 2]
    np.multiply(im, 127.0, out=im) # now in [0, 254]
    return im.astype(np.uint8)

def make_hdf5(ofname, wildcard):
    '''
    Preprocess the files matched by ``wildcard`` and save them in the hdf5
    file ``ofname`` as a float dataset called "pokemons" with shape
    (number of images, sz, sz, 4).

    When ``dataset_sz`` is positive, that many files are drawn at random
    (with replacement) from the matches; otherwise every match is used.

    Raises ValueError when ``wildcard`` matches no file, instead of failing
    later with an obscure error or silently writing an empty dataset.
    '''
    pool = glob.glob(wildcard)
    if not pool:
        raise ValueError("No files match wildcard {!r}".format(wildcard))

    if dataset_sz <= 0:
        fnames = pool
    else:
        # possible duplicates but don't care
        fnames = [random.choice(pool) for _ in range(dataset_sz)]

    with h5py.File(ofname, "w") as f:
        pokemons = f.create_dataset("pokemons", (len(fnames), sz, sz, 4), dtype='f')

        for i, fname in enumerate(fnames):
            print(fname)
            # NOTE(review): scipy.misc.imread/imresize are deprecated since
            # SciPy 1.0.0 and removed in 1.2.0 (see the warnings this cell emits).
            im = scipy.misc.imread(fname, mode='RGBA') # some have alpha channel
            im = scipy.misc.imresize(im, (sz, sz))
            pokemons[i] = normalize4gan(im)
            
def test(hdff):
    '''
    Open the hdf file read-only, print the minimum and maximum of each of the
    first three channels of the "pokemons" dataset followed by the dataset
    size, then assert that every stored value lies within [-1, 1].
    '''
    with h5py.File(hdff, "r") as f:
        X = f.get("pokemons")
        for channel in range(3):
            plane = X[:, :, :, channel]
            print(np.min(plane))
            print(np.max(plane))
        print("Dataset size:", len(X))
        assert np.max(X) <= 1.0
        assert np.min(X) >= -1.0

In [4]:
# Create "data.hdf5" from every PNG matched inside the 'training_data' directory
make_hdf5("data.hdf5", "./training_data/*.png")

# Check consistency of the file: prints per-channel min/max and the dataset
# size, and asserts that all values are within [-1, 1]
test("data.hdf5")

`imread` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imread`` instead.
`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``skimage.transform.resize`` instead.


./training_data/718_f3.png
./training_data/462.png
./training_data/758.png
./training_data/019.png
./training_data/718.png
./training_data/025.png
./training_data/435.png
./training_data/325.png
./training_data/289.png
./training_data/755.png
./training_data/539.png
./training_data/252.png
./training_data/164.png
./training_data/645.png
./training_data/399.png
./training_data/501.png
./training_data/457.png
./training_data/517.png
./training_data/277.png
./training_data/508.png
-1.0
0.9921875
-1.0
0.9921875
-1.0
0.9921875
Dataset size: 20


_____

In [5]:
class MinibatchDiscrimination(Layer):
    """Concatenates to each sample information about how different the input
    features for that sample are from features of other samples in the same
    minibatch, as described in Salimans et. al. (2016). Useful for preventing
    GANs from collapsing to a single output. When using this layer, generated
    samples and reference samples should be in separate batches.
    # Example
    ```python
        # apply a convolution 1d of length 3 to a sequence with 10 timesteps,
        # with 64 output filters
        model = Sequential()
        model.add(Conv1D(64, 3, padding='same', input_shape=(10, 32)))
        # now model.output_shape == (None, 10, 64)
        # flatten the output so it can be fed into a minibatch discrimination layer
        model.add(Flatten())
        # now model.output_shape == (None, 640)
        # add the minibatch discrimination layer
        model.add(MinibatchDiscrimination(5, 3))
        # now model.output_shape = (None, 645)
    ```
    # Arguments
        nb_kernels: Number of discrimination kernels to use
            (dimensionality concatenated to output).
        kernel_dim: The dimensionality of the space where closeness of samples
            is calculated.
        init: initializer for the weights of the layer; anything accepted by
            `initializers.get` (for example the name of an initializer).
        weights: list of numpy arrays to set as initial weights. They are
            applied once the layer has been built.
        W_regularizer: regularizer applied to the main weights tensor;
            anything accepted by `regularizers.get`.
        activity_regularizer: regularizer applied to the network output;
            anything accepted by `regularizers.get`.
        W_constraint: constraint applied to the main weights tensor;
            anything accepted by `constraints.get`.
        input_dim: Number of channels/dimensions in the input.
            Either this argument or the keyword argument `input_shape` must be
            provided when using this layer as the first layer in a model.
    # Input shape
        2D tensor with shape: `(samples, input_dim)`.
    # Output shape
        2D tensor with shape: `(samples, input_dim + nb_kernels)`.
    # References
        - [Improved Techniques for Training GANs](https://arxiv.org/abs/1606.03498)
    """

    def __init__(self, nb_kernels, kernel_dim, init='glorot_uniform', weights=None,
                 W_regularizer=None, activity_regularizer=None,
                 W_constraint=None, input_dim=None, **kwargs):
        if input_dim:
            kwargs['input_shape'] = (input_dim,)
        # Initialise the base layer first so that it cannot overwrite the
        # attributes assigned below (input_spec, activity_regularizer, ...).
        super(MinibatchDiscrimination, self).__init__(**kwargs)

        self.init = initializers.get(init)
        self.nb_kernels = nb_kernels
        self.kernel_dim = kernel_dim
        self.input_dim = input_dim

        self.W_regularizer = regularizers.get(W_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)

        self.W_constraint = constraints.get(W_constraint)

        self.initial_weights = weights
        self.input_spec = [InputSpec(ndim=2)]

    def build(self, input_shape):
        """Create the (nb_kernels, input_dim, kernel_dim) weight tensor."""
        assert len(input_shape) == 2

        # int() accepts a plain int as well as a shape-dimension object.
        input_dim = int(input_shape[1])
        self.input_spec = [InputSpec(dtype=K.floatx(),
                                     shape=(None, input_dim))]

        self.W = self.add_weight(shape=(self.nb_kernels, input_dim, self.kernel_dim),
            initializer=self.init,
            name='kernel',
            regularizer=self.W_regularizer,
            trainable=True,
            constraint=self.W_constraint)

        # Honour the `weights` constructor argument (applied only once).
        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            self.initial_weights = None

        # Set built to true.
        super(MinibatchDiscrimination, self).build(input_shape)

    def call(self, x, mask=None):
        """Append nb_kernels minibatch-similarity features to every sample."""
        # (samples, nb_kernels, kernel_dim)
        activation = K.reshape(K.dot(x, self.W), (-1, self.nb_kernels, self.kernel_dim))
        # Pairwise differences between all samples of the batch:
        # (samples, nb_kernels, kernel_dim, 1) - (1, nb_kernels, kernel_dim, samples).
        # `activation` has rank 3, so the new trailing axis is axis 3.
        diffs = K.expand_dims(activation, 3) - K.expand_dims(K.permute_dimensions(activation, [1, 2, 0]), 0)
        # L1 distance over kernel_dim -> (samples, nb_kernels, samples)
        abs_diffs = K.sum(K.abs(diffs), axis=2)
        # Sum of exp(-distance) over the other samples -> (samples, nb_kernels)
        minibatch_features = K.sum(K.exp(-abs_diffs), axis=2)
        return K.concatenate([x, minibatch_features], 1)

    def compute_output_shape(self, input_shape):
        """Return (samples, input_dim + nb_kernels)."""
        assert input_shape and len(input_shape) == 2
        return input_shape[0], input_shape[1]+self.nb_kernels

    def get_config(self):
        """Return the constructor arguments in a form `get` can deserialize."""
        config = {'nb_kernels': self.nb_kernels,
                  'kernel_dim': self.kernel_dim,
                  # Initializer objects have no __name__; serialize them instead.
                  'init': initializers.serialize(self.init),
                  'W_regularizer': regularizers.serialize(self.W_regularizer) if self.W_regularizer else None,
                  'activity_regularizer': regularizers.serialize(self.activity_regularizer) if self.activity_regularizer else None,
                  'W_constraint': constraints.serialize(self.W_constraint) if self.W_constraint else None,
                  'input_dim': self.input_dim}
        base_config = super(MinibatchDiscrimination, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

In [6]:
# Models

def build_enc(shape) :
    '''
    Build the encoder variant of the discriminator network for inputs of
    the given shape (delegates to build_discriminator with build_disc=False).
    '''
    encoder = build_discriminator(shape, build_disc=False)
    return encoder

def build_discriminator( shape, build_disc=True ) :
    '''
    Build the discriminator: four stride-2 4x4 convolutions (64, 128, 256
    and 512 filters) followed by a single sigmoid unit.

    With build_disc=False the dense head is replaced by a 4x4 tanh
    convolution with noise_shape[2] filters, which gives an encoder network
    for testing the encoding/discrimination capability with an autoencoder.
    '''
    def downsample( tensor, filters, batchnorm=True ) :
        # Stride-2 convolution halves the spatial size ('same' padding).
        tensor = Conv2D( filters, (4, 4), strides=(2, 2),
            padding='same',
            kernel_initializer=kernel_initializer )( tensor )
        if batchnorm :
            tensor = BatchNormalization(momentum=bn_momentum)( tensor )
        return LeakyReLU(alpha=alpha_D)( tensor )

    pokemon = Input( shape=shape )

    # Warning: Don't batchnorm the first set of Conv2D.
    features = downsample( pokemon, 64, batchnorm=False )

    # Each step halves the resolution again
    # (32x32 -> 16x16 -> 8x8 -> 4x4 for a 64x64 input).
    for filters in (128, 256, 512) :
        features = downsample( features, filters )

    if not build_disc :
        # build encoder.
        code = Conv2D(noise_shape[2], (4, 4), activation='tanh')( features )
        return models.Model( inputs=pokemon, outputs=code )

    flat = Flatten()( features )
    # 1 when "real", 0 when "fake".
    verdict = Dense(1, activation='sigmoid',
        kernel_initializer=kernel_initializer)( flat )
    return models.Model( inputs=pokemon, outputs=verdict )
    
def build_gen( shape ) :
    '''
    Build the generator: maps a noise tensor of shape noise_shape to a
    4-channel image through transposed convolutions, ending in tanh.

    The shape argument is accepted but not used by the network definition.
    '''
    def norm_act( tensor ) :
        # Batch normalisation followed by the generator's leaky ReLU.
        tensor = BatchNormalization(momentum=bn_momentum)( tensor )
        return LeakyReLU(alpha=alpha_G)( tensor )

    def upsample( tensor, filters ) :
        # Stride-2 transposed convolution doubles the spatial size.
        tensor = Conv2DTranspose( filters, (4, 4), padding='same',
            strides=(2, 2), kernel_initializer=kernel_initializer )( tensor )
        return norm_act( tensor )

    noise = Input( shape=noise_shape )

    # Unpadded 4x4 transposed convolution: 1x1 noise -> 4x4 feature map.
    feat = Conv2DTranspose( 512, (4, 4),
        kernel_initializer=kernel_initializer )( noise )
    feat = norm_act( feat )

    # 4x4 -> 8x8 -> 16x16 -> 32x32
    for filters in (256, 128, 64) :
        feat = upsample( feat, filters )

    # Extra layer, resolution unchanged (32x32)
    feat = Conv2D( 64, (3, 3), padding='same',
        kernel_initializer=kernel_initializer )( feat )
    feat = norm_act( feat )

    # Final upsampling to 64x64 with 4 output channels in [-1, 1].
    image = Conv2DTranspose( 4, (4, 4), padding='same', activation='tanh',
        strides=(2, 2), kernel_initializer=kernel_initializer )( feat )

    return models.Model(inputs=noise, outputs=image)

____

In [7]:
def sample_pokemon( pokemons ):
    '''
    Draw batch_sz entries from pokemons uniformly at random (with
    replacement) and return them stacked in one numpy array.
    '''
    picks = [ pokemons[ random.randrange( len(pokemons) ) ]
              for _ in range( batch_sz ) ]
    return np.array(picks)

def binary_noise(cnt):
    '''
    Return cnt noise tensors of shape noise_shape whose entries are close to
    either -1 or +1.

    Every entry is a random bit (0 or 1) plus a jitter drawn uniformly from
    [-label_noise/2, label_noise/2), rescaled from the 0/1 range to -1/+1.
    '''
    # Note about noise range.
    # 0, 1 noise vs -1, 1 noise. -1, 1 seems to be better and stable.

    noise = label_noise * np.random.ranf((cnt,) + noise_shape) # [0, label_noise)
    # Centre the jitter around zero. Derived from label_noise rather than
    # hard-coded, so that it stays centred when label_noise is changed.
    noise -= label_noise / 2
    noise += np.random.randint(0, 2, size=((cnt,) + noise_shape))

    # Map {0, 1} (+ jitter) to {-1, +1} (+ twice the jitter).
    noise -= 0.5
    noise *= 2
    return noise

def sample_fake( gen ) :
    '''
    Generate one batch (batch_sz) of images with gen.
    Returns (generated images, the noise they were generated from).
    '''
    latent = binary_noise(batch_sz)
    return gen.predict(latent), latent

def dump_batch(imgs, cnt, ofname):
    '''
    Tile the first cnt x cnt images of imgs into one big image (row by row),
    convert it back to 8-bit pixels and save it as ofname.
    Use the command
    $ feh dump.png --reload 1
    to refresh the image periodically during training!
    '''
    assert batch_sz >= cnt * cnt

    # One horizontal strip of cnt images per grid row.
    strips = [
        np.concatenate( [ imgs[k] for k in range(cnt*r, cnt*r+cnt) ], axis=1 )
        for r in range( cnt )
    ]

    # Stack the strips vertically; the concatenated copy is what gets
    # denormalized in place, not the caller's images.
    sheet = denormalize4gan( np.concatenate( strips, axis=0 ) )
    scipy.misc.imsave(ofname, sheet)
    
def build_networks():
    '''
    Build and compile the three models used for GAN training.

    Returns (gen, disc, gan) where gan is the generator stacked with the
    discriminator and is used to train the generator.
    '''
    shape = (sz, sz, 4)

    dopt = Adam(lr=0.0002, beta_1=adam_beta)
    opt  = Adam(lr=0.0001, beta_1=adam_beta)

    # generator part
    gen = build_gen( shape )
    # loss function doesn't seem to matter for this one, as it is not directly trained
    gen.compile(optimizer=opt, loss='binary_crossentropy')
    gen.summary()

    # discriminator part (compiled while trainable, so training disc directly
    # updates its weights)
    disc = build_discriminator( shape )
    disc.compile(optimizer=dopt, loss='binary_crossentropy')
    disc.summary()

    # GAN stack
    # https://ctmakro.github.io/site/on_learning/fast_gan_in_keras.html is the faster way.
    # Here, for simplicity, I use slower way (slower due to duplicate computation).
    #
    # Keras takes a snapshot of the trainable weights when a model is
    # compiled; flipping `trainable` afterwards has no effect until the model
    # is compiled again. The discriminator therefore has to be frozen *before*
    # the stacked model is compiled, otherwise every generator step would also
    # update the discriminator.
    disc.trainable = False
    noise = Input( shape=noise_shape )
    gened = gen( noise )
    result = disc( gened )
    gan = models.Model( inputs=noise, outputs=result )
    gan.compile(optimizer=opt, loss='binary_crossentropy')
    gan.summary()
    # Hand the discriminator back in its usual (trainable) state.
    disc.trainable = True

    return gen, disc, gan

def train_autoenc( dataf ):
    '''
    Train an autoencoder first to see if your network is large enough.

    The encoder (discriminator body) and the generator are chained and
    trained for 200 batches to reconstruct real samples read from the
    'pokemons' dataset of the hdf5 file dataf. Every 10 batches the current
    reconstructions and inputs are dumped to "fakes.png" / "reals.png".
    Finally the generator weights are saved to genw and the encoder weights
    to discw.
    '''
    opt = Adam(lr=0.001)

    shape = (sz, sz, 4)
    enc = build_enc( shape )
    enc.compile(optimizer=opt, loss='mse')
    enc.summary()

    # generator part
    gen = build_gen( shape )
    # generator is not directly trained. Optimizer and loss doesn't matter too much.
    gen.compile(optimizer=opt, loss='mse')
    gen.summary()

    pokemon = Input( shape=shape )
    vector = enc(pokemon)
    recons = gen(vector)
    autoenc = models.Model( inputs=pokemon, outputs=recons )
    autoenc.compile(optimizer=opt, loss='mse')

    # The context manager closes the data file even when training is
    # interrupted or fails.
    with h5py.File(dataf, 'r') as f:
        pokemons = f.get('pokemons')

        epoch = 0
        while epoch < 200 :
            for i in range(10) :
                reals = sample_pokemon(pokemons)
                loss = autoenc.train_on_batch( reals, reals )
                epoch += 1
                print(epoch, loss)
            # Reconstructions of the last batch, for visual inspection.
            fakes = autoenc.predict(reals)
            dump_batch(fakes, 4, "fakes.png")
            dump_batch(reals, 4, "reals.png")

    gen.save_weights(genw)
    enc.save_weights(discw)
    print("Saved", genw, discw)

def load_weights(model, wf):
    '''
    Load the weights stored in file wf into model.

    I find error message in load_weights hard to understand sometimes, so
    on failure a short hint (followed by the original error) is printed to
    stderr and the program exits with status 1 (SystemExit).
    '''
    # Imported here because the notebook's import cell does not import sys.
    import sys

    try:
        model.load_weights(wf)
    except Exception as err:
        # `except Exception` lets KeyboardInterrupt / SystemExit pass through.
        print("failed to load weight, network changed or corrupt hdf5", wf, file=sys.stderr)
        print(repr(err), file=sys.stderr)
        sys.exit(1)
        
def train_gan( dataf, n_batches=5000 ) :
    '''
    Build the networks and train the GAN on the 'pokemons' dataset stored
    in the hdf5 file dataf.

    n_batches is the number of training batches to run (default 5000).
    '''
    gen, disc, gan = build_networks()

    # Uncomment these, if you want to continue training from some snapshot.
    # (or load pretrained generator weights)
    #load_weights(gen, genw)
    #load_weights(disc, discw)

    logger = CSVLogger('loss.csv') # yeah, you can use callbacks independently
    logger.on_train_begin() # initialize csv file
    try:
        with h5py.File( dataf, 'r' ) as f :
            pokemons = f.get( 'pokemons' )
            run_batches(gen, disc, gan, pokemons, logger, range(n_batches))
    finally:
        # Always finalize the logger, also when training is interrupted
        # (e.g. with Ctrl+C) or fails.
        logger.on_train_end()
    
def run_batches(gen, disc, gan, pokemons, logger, itr_generator):
    '''
    Run the GAN training loop, one batch per value yielded by itr_generator.

    Each batch trains the discriminator on one real and one generated batch
    (with soft labels); from batch 20 on the generator is trained as well,
    through the stacked gan model. Every 10th batch (except batch 0) images
    and weight snapshots are written by end_of_batch_task.

    logger is accepted for CSV logging but is currently not written to.
    '''
    history = [] # need this to prevent G from shifting from mode to mode to trick D.
    for batch in itr_generator:
        # Using soft labels here.
        lbl_fake = label_noise * np.random.ranf(batch_sz)
        lbl_real = 1 - label_noise * np.random.ranf(batch_sz)

        fakes, noises = sample_fake( gen )
        reals = sample_pokemon( pokemons )
        # Add noise...
        # My dataset works without this.
        #reals += 0.5 * np.exp(-batch/100) * np.random.normal( size=reals.shape )

        if batch % 10 == 0 :
            history.append( (reals, fakes) )
            if len(history) > history_sz:
                history.pop(0) # evict oldest, keeping at most history_sz entries

        gen.trainable = False
        #for reals, fakes in history:
        d_loss1 = disc.train_on_batch( reals, lbl_real )
        d_loss0 = disc.train_on_batch( fakes, lbl_fake )
        gen.trainable = True

        #if d_loss1 > 15.0 or d_loss0 > 15.0 :
        # artificial training of one of G or D based on
        # statistics is not good at all.

        # pretrain train discriminator only
        if batch < 20 :
            print( batch, "d0:{} d1:{}".format( d_loss0, d_loss1 ) )
            continue

        disc.trainable = False
        g_loss = gan.train_on_batch( noises, lbl_real ) # try to trick the classifier.
        disc.trainable = True

        # To escape this loop, both D and G should be trained so that
        # D begins to mark everything that's wrong that G has done.
        # Otherwise G will only change locally and fail to escape the minima.

        print( batch, "d0:{} d1:{}   g:{}".format( d_loss0, d_loss1, g_loss ) )

        # save weights every 10 batches
        if batch % 10 == 0 and batch != 0 :
            end_of_batch_task(batch, gen, disc, reals, fakes)
            # CSV logging of d_loss0 / d_loss1 / g_loss through
            # logger.on_epoch_end is disabled.
            
# Fixed noise batch: the animation frames are always generated from the same
# input, so consecutive frames are comparable.
_bits = binary_noise(batch_sz)
def end_of_batch_task(batch, gen, disc, reals, fakes):
    '''
    Periodic bookkeeping during training: dump sample images and save
    weight snapshots.

    Writes "reals.png", "fakes.png", "frame.png" and a numbered frame inside
    anim_dir, then saves the generator and discriminator weights inside
    snapshot_dir, cycling through the prefixes "0." to "9.".
    If interrupted with Ctrl+C, the whole task is run again before the
    KeyboardInterrupt is re-raised.
    '''
    # Computed before the try block so that the interrupt handler below can
    # always refer to `serial`.
    step = batch // 10
    serial = step % 10
    try :
        # Dump how the generator is doing.
        # Animation dump
        dump_batch(reals, 4, "reals.png")
        dump_batch(fakes, 4, "fakes.png") # to check how noisy the image is
        frame = gen.predict(_bits)
        animf = os.path.join(anim_dir, "frame_{:08d}.png".format(step))
        dump_batch(frame, 4, animf)
        dump_batch(frame, 4, "frame.png")

        prefix = os.path.join(snapshot_dir, str(serial) + ".")

        print("Saving weights", serial)
        gen.save_weights(prefix + genw)
        disc.save_weights(prefix + discw)
    except KeyboardInterrupt :
        print("Saving, don't interrupt with Ctrl+C!", serial)
        # recursion to surely save everything haha
        end_of_batch_task(batch, gen, disc, reals, fakes)
        raise
        
def generate(genw,cnt):
    '''
    Generate cnt images with a generator whose weights are loaded from the
    file genw, and save them as "0000.png", "0001.png", ... in the current
    directory.
    '''
    shape = (sz, sz, 4)
    gen = build_gen(shape)
    gen.compile(optimizer='sgd', loss='mse')
    load_weights(gen, genw)

    if cnt <= 0:
        return

    # Generate exactly cnt images. Generating a fixed batch of batch_sz images
    # would make the loop below fail as soon as cnt exceeds batch_sz.
    generated = gen.predict(binary_noise(cnt))
    # Unoffset, in batch.
    # Must convert back to uint8 to stop color distortion.
    generated = denormalize4gan(generated)

    for i, image in enumerate(generated):
        ofname = "{:04d}.png".format(i)
        scipy.misc.imsave(ofname, image)

In [8]:
# Make sure the output directories exist before training writes into them.
# exist_ok avoids the check-then-create race and makes the cell re-runnable.
for out_dir in (snapshot_dir, anim_dir):
    os.makedirs(out_dir, exist_ok=True)

train_gan( "data.hdf5" )

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 1, 1, 100)         0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 4, 4, 512)         819712    
_________________________________________________________________
batch_normalization (BatchNo (None, 4, 4, 512)         2048      
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 4, 4, 512)         0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 8, 8, 256)         2097408   
_________________________________________________________________
batch_normalization_1 (Batch (None, 8, 8, 256)         1024      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 8, 8, 256)         0         
__________

`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.


21 d0:1.8590009212493896 d1:0.3443704843521118   g:6.792106628417969
22 d0:0.3906201720237732 d1:0.7925941348075867   g:1.3193854093551636
23 d0:3.641796827316284 d1:0.6691110730171204   g:9.67149543762207
24 d0:0.5376254320144653 d1:0.5493295788764954   g:10.408368110656738
25 d0:0.3834400177001953 d1:0.3876746594905853   g:5.999385833740234
26 d0:0.4735194444656372 d1:0.3094353973865509   g:5.386129379272461
27 d0:0.46801501512527466 d1:0.234207421541214   g:6.462993144989014
28 d0:0.22067303955554962 d1:0.2752486765384674   g:3.839236259460449
29 d0:0.9214929938316345 d1:0.31115788221359253   g:9.444129943847656
30 d0:0.3560725450515747 d1:0.393726110458374   g:6.910520553588867
Saving weights 3
31 d0:0.2147250473499298 d1:0.3370794653892517   g:3.4323973655700684
32 d0:1.3456591367721558 d1:0.39064645767211914   g:9.741103172302246
33 d0:0.4307938814163208 d1:0.4005632996559143   g:8.491881370544434
34 d0:0.2533150315284729 d1:0.260436475276947   g:4.456714153289795
35 d0:0.6575123

KeyboardInterrupt: 