In [1]:
# Get all the imports
from keras.models import Model, Sequential
from keras.layers import Dense, Flatten, LeakyReLU, BatchNormalization, Reshape
from keras.optimizers import SGD, Adam
from keras.engine.input_layer import Input
from keras.preprocessing.image import ImageDataGenerator
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image

Using TensorFlow backend.


In [2]:
# Setup some variables
image_rows = 52
image_cols = 52
channels = 3
BS=64
generated_image_path = "generated_images/"
image_shape = (image_rows, image_cols, channels)

In [3]:
def preprocess_input(im):
    return (im.astype(np.float32) - 127.5)/127.5

def reformat_image(im):
    return ((im * 127.5) + 127.5).astype(np.int32)

In [4]:
# This class exists as to keep the entire array of
# images out of RAM and instead load them in as needed
class ImageLoader():
    def __init__(self, files, func):
        self.files = files
        self.c_index = 0
        self.max = len(self.files)
        self.func = func
    
    def getNextFiles(self, num):
        arr = []
        for i in range(num):
            arr.append(self.load_image(self.files[self.c_index]))
            self.c_index += 1
            if self.c_index == self.max:
                print("Looping data now")
                self.c_index = 0
        return self.func(np.array(arr))
        
    def load_image(self, filename):
        img = Image.open(filename)
        img = img.resize((image_rows,image_cols))
        img = list(img.getdata())
        img = np.array(img)
        return img.reshape(image_shape)

In [5]:
class NFruit():
    def __init__(self):
        self.discriminator = self.makeDiscriminator()
        self.generator = self.makeGenerator()
        
        # Compile the models
        #optimizer = SGD(0.00002,momentum=0.3, decay=0.000001, nesterov=True)
        optimizer = Adam(0.002,0.5);
        self.discriminator.compile(
            loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])
        self.generator.compile(
            loss='binary_crossentropy',
            optimizer=optimizer)
        
        # Build the combined model
        model_input = Input(shape=(80,))
        image = self.generator(model_input)
        #self.dm.trainable = False
        validifier = self.discriminator(image)
        self.combined = Model(model_input, validifier)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
        self.combined.summary()
        
        # Create the image loader and the give it the collection of images
        image_locs = glob("fruits/fruits-360/Training/*/*.jpg")
        self.il = ImageLoader(image_locs, preprocess_input)
        
    def makeDiscriminator(self):
        # Create our discriminator.
        # This is the model that attempts to find
        # real and fake fruit
        dm = Sequential()
        dm.add(Flatten(input_shape=image_shape))
        #dm.add(Dense(1024))
        #dm.add(LeakyReLU(alpha=0.25))
        dm.add(Dense(512))
        dm.add(LeakyReLU(alpha=0.25))
        dm.add(Dense(256))
        dm.add(LeakyReLU(alpha=0.25))
        dm.add(Dense(1, activation='sigmoid'))
        dm.summary()
        img = Input(shape=image_shape)
        validity = dm(img)
        dm = Model(img, validity)
        return dm
    
    def makeGenerator(self):
        # Create our generator
        # This is the model that attempts
        # to fool the discriminator
        noise_shape = (80,)
        gn = Sequential()
        gn.add(Dense(256, input_shape=noise_shape))
        gn.add(LeakyReLU(alpha=0.25))
        gn.add(BatchNormalization(momentum=0.9))
        gn.add(Dense(512))
        gn.add(LeakyReLU(alpha=0.25))
        gn.add(BatchNormalization(momentum=0.9))
        #gn.add(Dense(1024))
        #gn.add(LeakyReLU(alpha=0.25))
        #gn.add(BatchNormalization(momentum=0.9))
        gn.add(Dense(np.prod(image_shape), activation='sigmoid'))
        gn.add(Reshape(image_shape))
        gn.summary()
        n = Input(shape=noise_shape)
        img = gn(n)
        gn = Model(n, img)
        return gn
    
    def generate_image(self, count=1):
        noise = np.random.normal(0,1,(count,80))
        return self.generator.predict(noise)
    
    # Write a function to save some images to files
    def save_image_array(self, filename, file_shape):
        r, c = file_shape
        images = self.generate_image(count=np.prod(file_shape))
        f, a = plt.subplots(r,c)
        for x in range(r):
            for y in range(c):
                a[x,y].imshow(images[x*c+y,:,:,:])
                a[x,y].axis('off')
        f.savefig(filename)
        plt.close()
                  
    def train(self, epochs = 10000, print_interval = 10):
        dm_total_hist = []
        self.past_data = []
        print("Starting")
        
        for e in range(epochs):
            # train the discriminator on the batch of real data
            real_data = self.il.getNextFiles(BS//2)
            self.past_data.append(real_data)
            dm_hist_real = self.discriminator.train_on_batch(real_data, np.ones((BS//2, 1))
                                            )
            # train the discriminator on the batch of fake data
            #fake_data = generate_image(count=BS//2)
            noise = np.random.normal(0, 1, (BS//2, 80))
            gen_imgs = self.generator.predict(noise)
            
            dm_hist_fake = self.discriminator.train_on_batch(real_data, np.zeros((BS//2, 1)))
            
            # calculate total loss
            dm_loss = np.add(dm_hist_real, dm_hist_fake) / 2
            dm_total_hist.append(dm_loss)
            
            # Now train the generator
            noise = np.random.normal(0,1,(BS,80))
            valid_y = np.array([1] * BS)
            gn_loss = self.combined.train_on_batch(noise, valid_y)
            
            
            if e % print_interval == 0:
                print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (e, dm_loss[0], 100*dm_loss[1], gn_loss))
                self.save_image_array(generated_image_path + 
                                 "generated_images" + str(e) + 
                                 ".png", (5,5))
                
        return dm_total_hist

In [6]:
nfruit = NFruit()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_1 (Flatten)          (None, 8112)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               4153856   
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 256)               131328    
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 256)               0         
_________________________________________________________________
dense_3 (Dense)              (None, 1)                 257       
Total params: 4,285,441
Trainable params: 4,285,441
Non-trainable params: 0
_________________________________________________________________


In [7]:
nfruit.train(epochs=10,print_interval=1)

Starting
0 [D loss: 8.167011, acc.: 50.00%] [G loss: 16.118095]
1 [D loss: 7.971192, acc.: 50.00%] [G loss: 16.118095]
2 [D loss: 7.971192, acc.: 50.00%] [G loss: 16.118095]
3 [D loss: 7.971192, acc.: 50.00%] [G loss: 16.118095]
4 [D loss: 7.971192, acc.: 50.00%] [G loss: 16.118095]
5 [D loss: 7.971192, acc.: 50.00%] [G loss: 16.118095]
6 [D loss: 7.971192, acc.: 50.00%] [G loss: 16.118095]
7 [D loss: 7.971192, acc.: 50.00%] [G loss: 16.118095]
8 [D loss: 7.971192, acc.: 50.00%] [G loss: 16.118095]
9 [D loss: 7.971192, acc.: 50.00%] [G loss: 16.118095]


[array([8.167011, 0.5     ], dtype=float32),
 array([7.9711924, 0.5      ], dtype=float32),
 array([7.9711924, 0.5      ], dtype=float32),
 array([7.9711924, 0.5      ], dtype=float32),
 array([7.9711924, 0.5      ], dtype=float32),
 array([7.9711924, 0.5      ], dtype=float32),
 array([7.9711924, 0.5      ], dtype=float32),
 array([7.9711924, 0.5      ], dtype=float32),
 array([7.9711924, 0.5      ], dtype=float32),
 array([7.9711924, 0.5      ], dtype=float32)]

In [None]:
nfruit.save_image_array(generated_image_path + "generated_images.png", (5,5))

In [None]:
arr = np.array(nfruit.past_data)

In [None]:
arr.shape

In [None]:
np.array_equal(arr[0],arr[1])

In [None]:
dm_total_hist[:5] + dm_total_hist[-5:]

In [None]:
save_image_array(generated_image_path + "generated_images.png", (5,5))

In [None]:
image_locs = glob("fruits/fruits-360/Training/*/*.jpg")
il = ImageLoader(image_locs, preprocess_input)

In [None]:
r, c = (5,5)
images = il.getNextFiles(r*c)
f, a = plt.subplots(r,c)
for x in range(r):
    for y in range(c):
        a[x,y].imshow(images[x*c+y,:,:,:])
        a[x,y].axis('off')
plt.show()
plt.close()
images = il.getNextFiles(r*c)
f, a = plt.subplots(r,c)
for x in range(r):
    for y in range(c):
        a[x,y].imshow(images[x*c+y,:,:,:])
        a[x,y].axis('off')
plt.show()
plt.close()