In [1]:
# Import all dependencies
from keras.models import Model, Sequential
from keras.layers import Dense, Flatten, LeakyReLU, BatchNormalization, Reshape
from keras.optimizers import Adam
from keras.engine.input_layer import Input
import numpy as np
from keras.datasets import mnist
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image

Using TensorFlow backend.


In [2]:
# Define functions to convert from human-readable images
# to images the networks can comprehend
def preprocess_input(im):
    """Map pixel values from the [0, 255] range onto [-1, 1] as float32.

    `im` must provide `astype` (e.g. a NumPy array); the result has the
    same shape as the input.
    """
    half_range = 127.5
    as_float = im.astype(np.float32)
    centered = as_float - half_range
    return centered / half_range
def postprocess_input(im):
    """Map values from the [-1, 1] range back onto [0, 255] as int32.

    The cast to int32 truncates the fractional part rather than rounding.
    """
    half_range = 127.5
    rescaled = im * half_range
    shifted = rescaled + half_range
    return shifted.astype(np.int32)

In [3]:
# This class exists to keep the entire array of
# images out of RAM, loading them in only as needed
class ImageLoader():
    """Loads image files from disk in sequential batches.

    Only the files needed for the current batch are read, so the whole
    dataset never has to be held in memory. The loader keeps a cursor into
    `files` and wraps around to the start once the end is reached.
    """

    def __init__(self, img_rows, img_cols, channels, files, func):
        """Store the target image geometry, the file list and the batch hook.

        img_rows, img_cols -- height and width every image is resized to
        channels           -- number of channels per pixel in the output
        files              -- sequence of image file paths
        func               -- function applied to each stacked batch array
                              before it is returned
        """
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channels = channels
        self.files = files
        self.c_index = 0
        self.max = len(self.files)
        self.func = func    # The post-processing function
                            # that should be applied to images

    def getNextFiles(self, num):
        """Load the next `num` images and return `func(batch_array)`.

        The cursor advances by `num`, wrapping to the first file (and
        printing a notice) whenever the end of the list is reached.

        Raises IndexError if images are requested but the file list is empty.
        """
        if num > 0 and self.max == 0:
            # Same exception type as indexing the empty list would give,
            # but with a message that points at the real problem.
            raise IndexError("ImageLoader has no image files to load")

        arr = []
        for _ in range(num):
            arr.append(self.load_image(self.files[self.c_index]))
            self.c_index += 1
            if self.c_index == self.max:
                print("Looping data now")
                self.c_index = 0
        return self.func(np.array(arr))

    def load_image(self, filename):
        """Read one image file as an (img_rows, img_cols, channels) array."""
        # The context manager closes the underlying file handle once the
        # pixel data has been extracted.
        with Image.open(filename) as src:
            if self.channels == 1:
                # Mode '1' is PIL's 1-bit (black/white) mode.
                img = src.convert('1')
            elif self.channels == 3:
                # Normalise grayscale / palette / RGBA files to 3 channels so
                # the reshape below cannot fail on the channel count.
                img = src.convert('RGB')
            else:
                img = src
            # PIL's resize takes (width, height), i.e. (cols, rows); the
            # pixel data comes back row by row, matching the reshape below.
            img = img.resize((self.img_cols, self.img_rows))
            pixels = np.array(list(img.getdata()))
        return pixels.reshape((self.img_rows, self.img_cols, self.channels))

In [4]:
class NFruit3():
    """Fully-connected GAN that generates 64x64 RGB images.

    Holds three Keras models: a discriminator, a generator, and a combined
    model (generator followed by the discriminator) used to train the
    generator. Sample image grids are written to `save_loc + <filename>`.
    """

    def __init__(self, save_loc):
        """Build and compile the discriminator, generator and combined model.

        save_loc -- path prefix prepended to every saved image filename
                    (it is concatenated as-is, so include a trailing
                    separator when it names a directory).
        """
        self.img_rows = 64
        self.img_cols = 64
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.noise_shape = (100,)
        self.save_loc = save_loc
        self.doSaveTraining = True

        # Positional arguments; presumably (learning rate, beta_1) per the
        # Keras Adam signature -- confirm against the installed version.
        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build and compile the generator
        self.generator = self.build_generator()
        self.generator.compile(loss='binary_crossentropy', optimizer=optimizer)

        # The generator takes noise as input and generates imgs
        z = Input(shape=self.noise_shape)
        img = self.generator(z)

        # For the combined model we will only train the generator.
        # NOTE(review): the flag is flipped after the discriminator was
        # compiled above; this is likely the source of the "Discrepancy
        # between trainable weights" warning in the training output.
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        valid = self.discriminator(img)

        # The combined model  (stacked generator and discriminator) takes
        # noise as input => generates images => determines validity
        self.combined = Model(z, valid)
        self.combined.summary()
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):
        """Return a Model mapping a noise vector to an image of `img_shape`.

        The final Dense layer uses tanh, so outputs lie in [-1, 1].
        """
        model = Sequential()

        model.add(Dense(256, input_shape=self.noise_shape))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(2048))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=self.noise_shape)
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        """Return a Model mapping an image to a single sigmoid validity score."""
        img_shape = (self.img_rows, self.img_cols, self.channels)

        model = Sequential()

        model.add(Flatten(input_shape=img_shape))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=256, save_interval=50, start=0,
              data_glob="fruits/fruits-360/Test/*/*.jpg"):
        """Alternate discriminator and generator updates.

        epochs        -- iteration count to stop at; each "epoch" here is a
                         single batch step, not a pass over the dataset
        batch_size    -- generator batch size; the discriminator sees half a
                         batch of real and half a batch of generated images
        save_interval -- print losses and save sample grids every this many
                         iterations
        start         -- first iteration number (for resuming)
        data_glob     -- glob pattern selecting the training image files

        Raises ValueError if `data_glob` matches no files.
        """
        # Create the image loader and give it the collection of images
        image_locs = glob(data_glob)
        if not image_locs:
            raise ValueError("No training images matched %r" % data_glob)
        # The pattern walks one class directory after another, so shuffle the
        # list to make each sequential batch a random mix of classes.
        np.random.shuffle(image_locs)
        il = ImageLoader(self.img_rows, self.img_cols, self.channels, image_locs, preprocess_input)

        half_batch = batch_size // 2
        noise_dim = self.noise_shape[0]

        # Label arrays never change, so build them once.
        real_y = np.ones((half_batch, 1))
        fake_y = np.zeros((half_batch, 1))
        # The generator wants the discriminator to label the generated samples
        # as valid (ones)
        valid_y = np.ones((batch_size, 1))

        print("Starting the training...")

        for epoch in range(start, epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Take the next half batch of (shuffled) real images
            imgs = il.getNextFiles(half_batch)

            noise = np.random.normal(0, 1, (half_batch, noise_dim))

            # Generate a half batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator; each call returns [loss, accuracy]
            d_loss_real = self.discriminator.train_on_batch(imgs, real_y)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake_y)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, noise_dim))

            # Train the generator
            g_loss = self.combined.train_on_batch(noise, valid_y)

            # If at save interval then save generated image samples
            if epoch % save_interval == 0:
                print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
                self.save_imgs(epoch)
                if self.doSaveTraining:
                    self.save_training_imgs(epoch, imgs)

        self.save_imgs(epochs)

    def _save_grid(self, imgs, filename, r=4, c=4):
        """Plot up to r*c images from `imgs` in a grid and save it under save_loc."""
        import os

        out_path = self.save_loc + filename
        out_dir = os.path.dirname(out_path)
        if out_dir:
            # Make sure the destination exists so savefig does not fail.
            os.makedirs(out_dir, exist_ok=True)

        fig, axs = plt.subplots(r, c)
        for cnt in range(r * c):
            ax = axs[cnt // c, cnt % c]
            # Cells beyond the number of available images stay blank.
            if cnt < len(imgs):
                ax.imshow(imgs[cnt, :, :, :])
            ax.axis('off')
        fig.savefig(out_path)
        # Close this specific figure so repeated saves do not accumulate.
        plt.close(fig)

    def save_training_imgs(self, epoch, training_imgs):
        """Save a 4x4 grid of the (preprocessed) training images for `epoch`."""
        imgs = postprocess_input(training_imgs)
        self._save_grid(imgs, "epoch_%d_training.png" % epoch)

    def save_imgs(self, epoch):
        """Generate 16 images from fresh noise and save them as a 4x4 grid."""
        r, c = 4, 4
        noise = np.random.normal(0, 1, (r * c, self.noise_shape[0]))
        gen_imgs = self.generator.predict(noise)

        # prepare the images for human viewing
        gen_imgs = postprocess_input(gen_imgs)

        self._save_grid(gen_imgs, "epoch_%d.png" % epoch, r, c)

    def save_models(self):
        """Write the weights of all three models to .h5 files (relative paths)."""
        self.generator.save_weights('generator.h5')
        self.discriminator.save_weights('discriminator.h5')
        self.combined.save_weights('combined.h5')

    def load_models(self):
        """Load the weights of all three models from the .h5 files written by save_models."""
        self.generator.load_weights('generator.h5')
        self.discriminator.load_weights('discriminator.h5')
        self.combined.load_weights('combined.h5')


In [6]:
# Create the models: constructing NFruit3 builds and compiles the
# discriminator, generator and combined networks (printing their summaries).
# The argument is the path prefix under which sample images are saved.
nfruit = NFruit3("generated_images/attempt_3-11/")

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_1 (Flatten)          (None, 12288)             0         
_________________________________________________________________
dense_1 (Dense)              (None, 1024)              12583936  
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 1024)              0         
_________________________________________________________________
dense_2 (Dense)              (None, 512)               524800    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_3 (LeakyReLU)    (None, 256)               0         
__________

In [7]:
# Train the networks for 10000 iterations (train() calls each single-batch
# step an "epoch"), with a batch size of 1024, printing losses and saving
# sample images every 50 iterations.
nfruit.train(epochs=10000, batch_size=1024, save_interval=50)

Starting the training...


  'Discrepancy between trainable weights and collected trainable'


0 [D loss: 0.771962, acc.: 7.03%] [G loss: 0.634349]
Looping data now
50 [D loss: 0.499506, acc.: 81.93%] [G loss: 5.265211]
Looping data now
100 [D loss: 0.559352, acc.: 65.82%] [G loss: 2.021865]
Looping data now
150 [D loss: 0.328662, acc.: 78.71%] [G loss: 1.470869]
Looping data now
Looping data now
200 [D loss: 0.384305, acc.: 73.44%] [G loss: 0.963032]
Looping data now
250 [D loss: 0.396781, acc.: 74.80%] [G loss: 0.949578]
Looping data now
300 [D loss: 0.442519, acc.: 55.18%] [G loss: 0.747975]
Looping data now
Looping data now
350 [D loss: 0.381939, acc.: 76.86%] [G loss: 0.867527]
Looping data now
400 [D loss: 1.018436, acc.: 35.84%] [G loss: 0.734381]
Looping data now
450 [D loss: 0.432208, acc.: 80.47%] [G loss: 0.946367]
Looping data now
Looping data now
500 [D loss: 0.747735, acc.: 46.09%] [G loss: 0.737277]
Looping data now
550 [D loss: 0.560114, acc.: 69.14%] [G loss: 0.876225]
Looping data now
600 [D loss: 0.626781, acc.: 66.50%] [G loss: 0.806359]
Looping data now
Loop

KeyboardInterrupt: 

In [8]:
# Save the current weights of all three models
# (generator.h5, discriminator.h5 and combined.h5, as relative paths)
nfruit.save_models()

In [19]:
# Load the previously saved weights back into the three models
# (expects the .h5 files written by save_models())
nfruit.load_models()

In [15]:
# Save an example grid of generated images; the epoch argument only
# determines the filename ("epoch_0.png" under the save location)
nfruit.save_imgs(0)