In [1]:
from keras.models import Sequential
from keras.layers import Dense,Activation,Conv2D,MaxPooling2D
from keras.layers import UpSampling2D,Flatten,Reshape,BatchNormalization
from keras.optimizers import SGD
from keras.datasets import mnist
import numpy as np
from PIL import Image
import argparse
import math

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
def generator():
    """Build the generator: a 100-d noise vector -> a 28x28x1 image in [-1, 1].

    Returns an uncompiled ``Sequential`` model.
    """
    model = Sequential()
    # Keras 2 signature: units are positional; ``output_dim`` is the old Keras 1
    # keyword and triggers an "update your Dense call" warning.
    model.add(Dense(1024, input_dim=100))
    model.add(Activation('tanh'))
    # Project to enough units to be reshaped into a 7x7x128 feature map.
    model.add(Dense(128 * 7 * 7))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    # Not the first layer, so no input_shape is needed here.
    model.add(Reshape((7, 7, 128)))
    # 7x7 -> 14x14
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(64, (5, 5), padding='same'))
    model.add(Activation('tanh'))
    # 14x14 -> 28x28, collapsed to a single channel.
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(1, (5, 5), padding='same'))
    # tanh keeps pixels in [-1, 1], the same range train() scales MNIST into.
    model.add(Activation('tanh'))
    return model

In [11]:
def discriminator():
    """Build the CNN discriminator: a 28x28x1 image -> one sigmoid probability.

    Returns an uncompiled ``Sequential`` model. train() fits it with label 1
    for real MNIST digits and 0 for generated images.
    """
    stack = [
        # 28x28x1 -> 28x28x64 ('same' padding) -> pooled to 14x14x64
        Conv2D(64, (5, 5), padding='same', input_shape=(28, 28, 1)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        # 14x14x64 -> 10x10x128 (Keras default 'valid' padding) -> pooled to 5x5x128
        Conv2D(128, (5, 5)),
        Activation('tanh'),
        MaxPooling2D(pool_size=(2, 2)),
        # Fully connected head down to a single real/fake score.
        Flatten(),
        Dense(1024),
        Activation('tanh'),
        Dense(1),
        Activation('sigmoid'),
    ]
    return Sequential(stack)

In [12]:
def generator_given_discriminator(generator_model, discriminator_model):
    """Stack generator -> discriminator into one model used to train the generator.

    Side effect: sets ``discriminator_model.trainable = False``. Keras reads
    ``trainable`` when a model is compiled, so compiling the returned stack only
    updates the generator's weights. A caller that also trains the
    discriminator on its own must set ``trainable`` back to True before
    compiling it.

    Args:
        generator_model: model mapping noise to images.
        discriminator_model: model mapping images to a real/fake score.

    Returns:
        An uncompiled ``Sequential`` whose layers are the two models, in order.
    """
    # Freeze the model that was passed in, not a global name. Referencing a
    # global here raised NameError on a fresh kernel, or froze a stale object.
    discriminator_model.trainable = False
    model = Sequential()
    model.add(generator_model)
    model.add(discriminator_model)
    return model

In [13]:
def combine_images(generated_images):
    """Tile a batch of images into one 2-D mosaic, row by row.

    Args:
        generated_images: array of shape (N, H, W, C). Only channel 0 is used.

    Returns:
        Array of shape (rows*H, cols*W) with the input's dtype. It uses
        cols = floor(sqrt(N)) and rows = ceil(N / cols). Cells with no image
        remain zero.
    """
    count = generated_images.shape[0]
    cols = int(math.sqrt(count))
    rows = int(math.ceil(float(count) / cols))
    tile_h = generated_images.shape[1]
    tile_w = generated_images.shape[2]
    mosaic = np.zeros((rows * tile_h, cols * tile_w), dtype=generated_images.dtype)
    for position in range(count):
        # Fill left-to-right, then top-to-bottom.
        row, col = divmod(position, cols)
        top = row * tile_h
        left = col * tile_w
        mosaic[top:top + tile_h, left:left + tile_w] = generated_images[position, :, :, 0]
    return mosaic

In [14]:
def train(BATCH_SIZE, epochs=30, output_dir='./GAN'):
    """Train the GAN on MNIST, alternating discriminator and generator updates.

    For each of ``epochs`` passes over the training set (any trailing partial
    batch is skipped), every batch:
      1. trains D on BATCH_SIZE real images (label 1) and BATCH_SIZE generated
         images (label 0);
      2. trains G through the frozen D on fresh noise labelled 1.

    Every 100th batch a mosaic of the current fakes is written to
    ``output_dir``. Weights are written to the files 'generator' and
    'discriminator' at batch indices ending in 09 (9, 109, ...) and at the end
    of every epoch.

    Args:
        BATCH_SIZE: number of real (and of generated) images per update.
        epochs: number of passes over the training set (default 30, as before).
        output_dir: directory for the sample mosaics; created if missing.
    """
    import os

    (X_train, _), _ = mnist.load_data()
    # Scale uint8 pixels from [0, 255] to [-1, 1] to match the generator's tanh output.
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = X_train[:, :, :, None]  # add the channel axis expected by Conv2D

    d = discriminator()
    g = generator()
    # Freeze *this* D before building/compiling the stack so the G step only
    # moves G's weights (Keras captures `trainable` at compile time).
    d.trainable = False
    d_on_g = generator_given_discriminator(g, d)

    d_optim = SGD(lr=0.001, momentum=0.9, nesterov=True)
    g_optim = SGD(lr=0.001, momentum=0.9, nesterov=True)

    g.compile(loss='binary_crossentropy', optimizer='SGD')
    d_on_g.compile(loss='binary_crossentropy', optimizer=g_optim)

    # Unfreeze before compiling D on its own; the two compiled models keep
    # their own view of which weights are trainable, so no per-batch toggling
    # is needed afterwards.
    d.trainable = True
    d.compile(loss='binary_crossentropy', optimizer=d_optim)

    os.makedirs(output_dir, exist_ok=True)
    num_batches = X_train.shape[0] // BATCH_SIZE
    # Labels never change between batches: build them once, as numpy arrays.
    y_disc = np.array([1] * BATCH_SIZE + [0] * BATCH_SIZE)
    y_gen = np.ones(BATCH_SIZE)

    for epoch in range(epochs):
        print('Epoch is', epoch)
        print('Number of batches', num_batches)

        for index in range(num_batches):
            noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))
            image_batch = X_train[index * BATCH_SIZE:(index + 1) * BATCH_SIZE]
            generated_images = g.predict(noise, verbose=0)
            if index % 100 == 0:
                image = combine_images(generated_images)
                image = image * 127.5 + 127.5  # back to [0, 255]
                Image.fromarray(image.astype(np.uint8)).save(
                    os.path.join(output_dir, '%d_%d.jpg' % (epoch, index)))

            # Discriminator step: real images first, then fakes (matches y_disc).
            X = np.concatenate((image_batch, generated_images))
            d_loss = d.train_on_batch(X, y_disc)
            print('batch %d d_loss : %f' % (index, d_loss))

            # Generator step: push D(G(noise)) towards "real".
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
            g_loss = d_on_g.train_on_batch(noise, y_gen)
            print('batch %d g_loss : %f' % (index, g_loss))

            if index % 100 == 9:
                g.save_weights('generator', overwrite=True)
                d.save_weights('discriminator', overwrite=True)

        # Also persist at epoch end so the last batches of the epoch are not lost.
        g.save_weights('generator', overwrite=True)
        d.save_weights('discriminator', overwrite=True)

In [15]:
def generate(BATCH_SIZE, nice=False, oversample=20, output_dir='./GAN'):
    """Sample images from the saved generator and write them as one mosaic.

    Loads weights from the file 'generator'. The mosaic is saved to
    ``<output_dir>/generated_image.jpg``; the directory is created if missing.

    Args:
        BATCH_SIZE: number of images in the mosaic.
        nice: if True, draw ``BATCH_SIZE * oversample`` samples, score them with
            the discriminator loaded from the file 'discriminator', and keep the
            BATCH_SIZE highest-scoring ones.
        oversample: candidate multiplier used when ``nice`` is True (default 20,
            as before).
        output_dir: directory for the output image.
    """
    import os

    g = generator()
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    g.load_weights('generator')
    if nice:
        d = discriminator()
        d.compile(loss='binary_crossentropy', optimizer="SGD")
        d.load_weights('discriminator')
        noise = np.random.uniform(-1, 1, (BATCH_SIZE * oversample, 100))
        generated_images = g.predict(noise, verbose=1)
        scores = d.predict(generated_images, verbose=1)[:, 0]
        # Highest score first; mergesort is stable, so tied scores keep their
        # generation order, as the previous list.sort(reverse=True) did.
        best = np.argsort(-scores, kind='mergesort')[:BATCH_SIZE]
        nice_images = generated_images[best, :, :, :1].astype(np.float32)
        image = combine_images(nice_images)
    else:
        noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
        generated_images = g.predict(noise, verbose=0)
        image = combine_images(generated_images)
    image = image * 127.5 + 127.5  # tanh range [-1, 1] -> pixel range [0, 255]
    os.makedirs(output_dir, exist_ok=True)
    Image.fromarray(image.astype(np.uint8)).save(
        os.path.join(output_dir, "generated_image.jpg"))

In [17]:
# Images per D/G update during training; also the number of tiles in each saved mosaic.
BATCH_SIZE = 132
train(BATCH_SIZE=BATCH_SIZE)
generate(BATCH_SIZE=BATCH_SIZE)

  This is separate from the ipykernel package so we can avoid doing imports until


Epoch is 0
Number of batches 454
batch 0 d_loss : 0.667026
batch 0 g_loss : 0.707065
batch 1 d_loss : 0.640115
batch 1 g_loss : 0.697564
batch 2 d_loss : 0.613985
batch 2 g_loss : 0.687107
batch 3 d_loss : 0.582559
batch 3 g_loss : 0.682326
batch 4 d_loss : 0.560834
batch 4 g_loss : 0.673343
batch 5 d_loss : 0.544283
batch 5 g_loss : 0.665048
batch 6 d_loss : 0.521645
batch 6 g_loss : 0.668004
batch 7 d_loss : 0.493725
batch 7 g_loss : 0.662798
batch 8 d_loss : 0.475911
batch 8 g_loss : 0.673772
batch 9 d_loss : 0.479765
batch 9 g_loss : 0.673942
batch 10 d_loss : 0.455973
batch 10 g_loss : 0.685020
batch 11 d_loss : 0.430585
batch 11 g_loss : 0.679927
batch 12 d_loss : 0.418536
batch 12 g_loss : 0.705454
batch 13 d_loss : 0.430059
batch 13 g_loss : 0.706780
batch 14 d_loss : 0.414130
batch 14 g_loss : 0.722976
batch 15 d_loss : 0.406213
batch 15 g_loss : 0.745810
batch 16 d_loss : 0.408876
batch 16 g_loss : 0.749358
batch 17 d_loss : 0.403421
batch 17 g_loss : 0.776455
batch 18 d_loss

KeyboardInterrupt: 