In [1]:
print('\nLoading libs...')

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Flatten
from keras.optimizers import SGD
from keras.datasets import mnist
import numpy as np
from PIL import Image
import argparse
import math


Loading libs...


  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [8]:
def generator_model():
    """Build the (uncompiled) generator network.

    The model takes a 100-dimensional input vector, expands it through two
    dense layers to 128*7*7 features, reshapes that to a 7x7x128 map and
    then applies two rounds of 2x upsampling followed by a 5x5 'same'
    convolution, ending in a single-channel map. Every stage, including the
    output, is followed by a tanh activation.

    Returns:
        A keras ``Sequential`` model; the caller is responsible for
        compiling it.
    """
    print('Initializing GENERATOR\n')
    layer_stack = [
        # Fully connected stem: 100 -> 1024 -> 128*7*7.
        Dense(units=1024, input_dim=100),
        Activation('tanh'),
        Dense(128*7*7),
        BatchNormalization(),
        Activation('tanh'),
        # Turn the flat feature vector into a 7x7 map with 128 channels.
        Reshape((7, 7, 128), input_shape=(128*7*7,)),
        # First upsampling stage.
        UpSampling2D(size=(2, 2)),
        Conv2D(64, (5, 5), padding='same'),
        Activation('tanh'),
        # Second upsampling stage, collapsing to one output channel.
        UpSampling2D(size=(2, 2)),
        Conv2D(1, (5, 5), padding='same'),
        Activation('tanh'),
    ]
    # Sequential adds the layers one by one in list order.
    return Sequential(layer_stack)

In [3]:
def discriminator_model():
    """Build the (uncompiled) discriminator network.

    The model accepts 28x28x1 inputs and runs them through two
    convolution + tanh + 2x2 max-pooling stages (64 filters with 'same'
    padding, then 128 filters with the layer's default padding), flattens
    the result, and finishes with a 1024-unit tanh dense layer and a single
    sigmoid output unit.

    Returns:
        A keras ``Sequential`` model; the caller is responsible for
        compiling it.
    """
    print('Initializing DISCRIMINATOR\n')
    layer_stack = [
        # Convolutional feature extractor.
        Conv2D(64, (5, 5), padding='same', input_shape=(28, 28, 1)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(128, (5, 5)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        # Dense classifier head with a single sigmoid output.
        Flatten(),
        Dense(1024),
        Activation('tanh'),
        Dense(1),
        Activation('sigmoid'),
    ]
    # Sequential adds the layers one by one in list order.
    return Sequential(layer_stack)

In [4]:
def generator_containing_discriminator(g, d):
    """Stack the generator and the discriminator into one model.

    The discriminator's ``trainable`` flag is set to False before it is
    added to the stack; note that this mutates the ``d`` object passed in.

    Args:
        g: generator model, added first.
        d: discriminator model, added second.

    Returns:
        An uncompiled keras ``Sequential`` model equivalent to ``d(g(x))``.
    """
    print('Initializing GENERATOR contatining DISCRIMINATOR\n')
    # Mark the discriminator as non-trainable before stacking it.
    d.trainable = False
    stacked = Sequential()
    for part in (g, d):
        stacked.add(part)
    return stacked

In [20]:
def combine_images(generated_images):
    """Tile a batch of images into a single 2-D grid image.

    The grid is ``int(sqrt(N))`` tiles wide and has as many rows as are
    needed to hold all ``N`` images. Images are placed left to right, top
    to bottom; unused trailing cells are left as zeros. Only channel 0 of
    each image is copied.

    Args:
        generated_images: array indexed as (N, height, width, channels).

    Returns:
        A 2-D array of shape (rows * height, cols * width) with the same
        dtype as the input. For an empty batch (N == 0) an empty array of
        shape (0, 0) is returned.
    """
    print('Combining images...\n')
    num = generated_images.shape[0]
    if num == 0:
        # int(sqrt(0)) == 0 would otherwise make the row computation below
        # divide by zero; an empty batch simply yields an empty grid.
        return np.zeros((0, 0), dtype=generated_images.dtype)
    width = int(math.sqrt(num))
    height = int(math.ceil(float(num)/width))
    tile_h, tile_w = generated_images.shape[1:3]
    image = np.zeros((height*tile_h, width*tile_w),
                     dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        # Row-major placement: row = index // width, column = index % width.
        row, col = divmod(index, width)
        image[row*tile_h:(row+1)*tile_h, col*tile_w:(col+1)*tile_w] = \
            img[:, :, 0]
    return image

In [32]:
def train(bs=8, n_epochs=100):
    """Train the generator and discriminator against each other on MNIST.

    For every batch the discriminator is first trained on ``bs`` real
    images (label 1) plus ``bs`` generated images (label 0); then the
    stacked generator+discriminator model is trained on fresh noise with
    all labels set to 1, which updates the generator.

    Side effects:
        * prints progress and per-batch losses;
        * every 20th batch writes a grid of the current generated images to
          ``<epoch>_<index>.png`` (relative path);
        * whenever ``index % 10 == 9`` writes the model weights to the
          files ``generator`` and ``discriminator``.

    Args:
        bs: number of real images (and of generated images) per batch.
        n_epochs: number of passes over the training set.
    """
    print('Training...\n')
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    # Affine rescale x -> (x - 127.5) / 127.5; for 8-bit pixel values this
    # lands in [-1, 1], the same range as the generator's tanh output.
    X_train = (X_train.astype(np.float32) - 127.5)/127.5
    # Append a trailing channel axis.
    X_train = X_train[:, :, :, None]
    # NOTE(review): X_test (like y_train / y_test) is not used again in this
    # function.
    X_test = X_test[:, :, :, None]
    # X_train = X_train.reshape((X_train.shape, 1) + X_train.shape[1:])
    d = discriminator_model()
    g = generator_model()
    # Building the stacked model sets d.trainable = False as a side effect.
    d_on_g = generator_containing_discriminator(g, d)
    d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    # g is only used through predict() below; it is trained via d_on_g.
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    # The stacked model is compiled while d.trainable is still False ...
    d_on_g.compile(loss='binary_crossentropy', optimizer=g_optim)
    # ... and the standalone discriminator is compiled with it set to True.
    # The relative order of these statements matters; do not reorder.
    d.trainable = True
    d.compile(loss='binary_crossentropy', optimizer=d_optim)
    for epoch in range(n_epochs):
        print("Epoch is", epoch)
        print("Number of batches", int(X_train.shape[0]/bs))
        # Only full batches are used; a trailing partial batch is dropped.
        for index in range(int(X_train.shape[0]/bs)):
            # Generator input: uniform noise in [-1, 1), one 100-vector per image.
            noise = np.random.uniform(-1, 1, size=(bs, 100))
            image_batch = X_train[index*bs:(index+1)*bs]
            generated_images = g.predict(noise, verbose=0)
            if index % 20 == 0:
                # Periodic snapshot: tile the fakes, undo the [-1, 1]
                # scaling and save the grid as a PNG.
                image = combine_images(generated_images)
                image = image*127.5+127.5
                Image.fromarray(image.astype(np.uint8)).save(
                    str(epoch)+"_"+str(index)+".png")
            # Discriminator step: real images first (label 1), then fakes (label 0).
            X = np.concatenate((image_batch, generated_images))
            y = [1] * bs + [0] * bs
            d_loss = d.train_on_batch(X, y)
            print("batch %d d_loss : %f" % (index, d_loss))
            # Generator step: fresh noise, all targets 1 ("real").
            noise = np.random.uniform(-1, 1, (bs, 100))
            # NOTE(review): whether toggling `trainable` after compile()
            # affects the already-compiled models depends on the Keras
            # version - confirm.
            d.trainable = False
            g_loss = d_on_g.train_on_batch(noise, [1] * bs)
            d.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss))
            if index % 10 == 9:
                # Checkpoint both networks; the second positional argument
                # is presumably Keras's `overwrite` flag - confirm.
                g.save_weights('generator', True)
                d.save_weights('discriminator', True)

In [25]:
def generate(BATCH_SIZE, nice=False):
    """Generate images with the saved generator and write them as one PNG.

    The generator weights are loaded from the file ``generator``. The
    resulting images are tiled with ``combine_images``, rescaled with
    ``x * 127.5 + 127.5`` and saved to ``GAN_keras_out/generated_image.png``.
    The output directory is created if it does not exist yet.

    Args:
        BATCH_SIZE: number of images in the saved grid.
        nice: when True, ``20 * BATCH_SIZE`` candidates are generated, the
            discriminator (weights loaded from the file ``discriminator``)
            scores them, and only the ``BATCH_SIZE`` highest-scoring
            candidates are kept. When False, exactly ``BATCH_SIZE`` images
            are generated and all of them are used.
    """
    # Local import so this function does not depend on the import cell
    # having been extended with `os`.
    import os

    print('Generating...\n')
    g = generator_model()
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    g.load_weights('generator')
    if nice:
        d = discriminator_model()
        d.compile(loss='binary_crossentropy', optimizer="SGD")
        d.load_weights('discriminator')
        noise = np.random.uniform(-1, 1, (BATCH_SIZE*20, 100))
        generated_images = g.predict(noise, verbose=1)
        d_pret = d.predict(generated_images, verbose=1)
        # Rank candidates by discriminator score, best first. A stable sort
        # on the negated scores keeps tied candidates in generation order,
        # matching a stable descending sort of (score, index) pairs.
        ranking = np.argsort(-d_pret[:, 0], kind='mergesort')
        best = ranking[:BATCH_SIZE]
        # Fancy indexing copies the selected images; keep them as float32.
        nice_images = generated_images[best].astype(np.float32)
        image = combine_images(nice_images)
    else:
        noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
        generated_images = g.predict(noise, verbose=1)
        image = combine_images(generated_images)
    image = image*127.5+127.5
    out_path = "GAN_keras_out/generated_image.png"
    # Nothing else in the notebook creates this directory, so saving would
    # fail on a fresh checkout without this.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    Image.fromarray(image.astype(np.uint8)).save(out_path)

In [33]:
# Training run with batch size 4 for 10 epochs. The recorded output below
# ends in a KeyboardInterrupt, i.e. this run was stopped by hand during
# epoch 0 rather than running to completion.
train(bs=4, n_epochs=10)

Training...

Initializing DISCRIMINATOR

Initializing GENERATOR

Initializing GENERATOR contatining DISCRIMINATOR

Epoch is 0
Number of batches 15000
Combining images...

batch 0 d_loss : 0.640239
batch 0 g_loss : 0.754873
batch 1 d_loss : 0.630616
batch 1 g_loss : 0.724089
batch 2 d_loss : 0.640635
batch 2 g_loss : 0.732655
batch 3 d_loss : 0.626306
batch 3 g_loss : 0.732605
batch 4 d_loss : 0.580333
batch 4 g_loss : 0.744703
batch 5 d_loss : 0.594675
batch 5 g_loss : 0.735786
batch 6 d_loss : 0.591691
batch 6 g_loss : 0.742045
batch 7 d_loss : 0.567855
batch 7 g_loss : 0.756632
batch 8 d_loss : 0.579205
batch 8 g_loss : 0.755142
batch 9 d_loss : 0.599971
batch 9 g_loss : 0.740223
batch 10 d_loss : 0.507722
batch 10 g_loss : 0.739184
batch 11 d_loss : 0.504797
batch 11 g_loss : 0.742996
batch 12 d_loss : 0.481896
batch 12 g_loss : 0.751688
batch 13 d_loss : 0.492403
batch 13 g_loss : 0.762809
batch 14 d_loss : 0.538577
batch 14 g_loss : 0.733812
batch 15 d_loss : 0.505899
batch 15 g_l

KeyboardInterrupt: 

In [24]:
# Scratch arithmetic; the result is not referenced by any other cell shown here.
.4 * 250

100.0