In [25]:

import numpy as np
import random as rn

import os
from skimage import io
from skimage.transform import resize, rotate
from skimage.util import pad
import matplotlib.pyplot as plt
from keras.utils import to_categorical, plot_model

import itertools

from keras.models import Sequential, Model
from keras.layers import Conv2D, MaxPooling2D, BatchNormalization, Dropout
from keras.layers import Dropout, Flatten, Dense, Input, Concatenate, Reshape, Flatten
from keras.applications.inception_v3 import InceptionV3
from keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger

from sklearn.metrics import confusion_matrix
# import tensorflow as tf

import datetime
import time
import threading

# Root of the data tree: the value of the `data` environment variable when it
# is set, otherwise the relative directory "data".
datafolder = os.environ.get('data', 'data')
# `outputfolder` receives per-run training artefacts; `imgfolder` is the image
# directory handed to the batch generators. Both live directly under the root.
outputfolder, imgfolder = (os.path.join(datafolder, sub) for sub in ('output', 'test'))

In [35]:
def fit_and_save_model(name, model, train_gen,
                       steps_per_epoch = 10,  epochs=10, img_input=True,
                       validation_gen=None, validation_steps=None):
    """Train `model` from a generator, checkpointing the best epoch and logging history.

    Artefacts are written to `<outputfolder>/<name>/<YYYY-mm-dd-HH-MM>/`:
    `model_fitted.h5` (best model so far) and `model_history.csv` (per-epoch log).

    Args:
        name: Sub-directory name used to group the runs of this model.
        model: Compiled Keras model exposing `fit_generator`.
        train_gen: Generator yielding training batches.
        steps_per_epoch: Number of generator batches that make up one epoch.
        epochs: Number of epochs to train for.
        img_input: Unused; kept so existing callers keep working.
        validation_gen: Optional generator yielding validation batches. When
            given, it is passed to `fit_generator` and the checkpoint tracks
            `val_loss`; otherwise the checkpoint tracks the training `loss`.
        validation_steps: Number of validation batches per epoch; forwarded
            to `fit_generator` together with `validation_gen`.

    Returns:
        The history object returned by `model.fit_generator`.
    """
    modeloutputfolder = os.path.join(outputfolder, name, datetime.datetime.now().strftime("%Y-%m-%d-%H-%M"))
    if not os.path.exists(modeloutputfolder):
        os.makedirs(modeloutputfolder)

    # `save_best_only` compares a monitored metric between epochs. The default
    # metric is `val_loss`, which is only reported when validation data is
    # supplied; without it no checkpoint would ever be written, so fall back to
    # the training loss in that case.
    monitor = 'val_loss' if validation_gen is not None else 'loss'

    cbs=[ModelCheckpoint(filepath=os.path.join(modeloutputfolder,'model_fitted.h5'),
                         monitor=monitor, verbose=1, save_best_only=True),
         CSVLogger(os.path.join(modeloutputfolder,'model_history.csv'), separator=',', append=False)]

    # Only forward the validation arguments when a validation generator exists,
    # so the call without validation is exactly what it was before.
    validation_kwargs = {}
    if validation_gen is not None:
        validation_kwargs['validation_data'] = validation_gen
        validation_kwargs['validation_steps'] = validation_steps

    history = model.fit_generator(train_gen,
                                  steps_per_epoch=steps_per_epoch,
                                  epochs=epochs,
                                  callbacks=cbs,
                                  **validation_kwargs)
    return history

In [22]:
# Code from:
# https://github.com/keras-team/keras/issues/1638

class threadsafe_iter:
    """Iterator wrapper that lets only one thread at a time advance the
    wrapped iterator/generator.

    Every advance happens while holding `self.lock`; the lock is released
    again whether the wrapped iterator returns a value or raises.
    """
    def __init__(self, it):
        self.lock = threading.Lock()
        self.it = it

    def __iter__(self):
        # The wrapper is its own iterator.
        return self

    def __next__(self): # Py3
        self.lock.acquire()
        try:
            return next(self.it)
        finally:
            self.lock.release()

    def next(self):     # Py2
        self.lock.acquire()
        try:
            return self.it.next()
        finally:
            self.lock.release()


def threadsafe_generator(f):
    """A decorator that takes a generator function and makes it thread-safe.

    Calling the decorated function calls `f` with the same arguments and
    returns the resulting generator wrapped in `threadsafe_iter`. The
    wrapper keeps the name and docstring of `f`.
    """
    # Imported locally so this definition does not depend on the import cell.
    import functools

    @functools.wraps(f)
    def g(*a, **kw):
        return threadsafe_iter(f(*a, **kw))
    return g

In [37]:
class MyGenerators(object):
    """Batch generators over a directory that holds one sub-directory per label.

    For every label the file list is sorted by name and split by index:
    indices `[0, test_split)` are test images, `[test_split, val_split)` are
    validation images and `[val_split, label_len)` are training images. Both
    split points are derived from the smallest label so that every label
    contributes the same number of test and validation images.

    Images are assumed to be 2-D (grayscale) arrays -- TODO confirm against
    the dataset; `make_square` only pads the first two axes.
    """
    def __init__(self, path, val_fac=0.1, batch_size=16, validation_size=64, target_size=(224,224)):
        """Index the dataset under `path` and compute the split points.

        Args:
            path: Directory containing one sub-directory per label.
            val_fac: Fraction of the smallest label used for the test split,
                and again for the validation split.
            batch_size: Number of images per generated batch.
            validation_size: Stored on the instance; not used by this class.
            target_size: (height, width) every image is resized to.

        Raises:
            ValueError: If `path` contains no label sub-directories.
        """
        self.path = path
        self.val_fac = val_fac
        self.batch_size = batch_size
        self.validation_size = validation_size
        self.target_size = target_size

        # os.listdir returns entries in arbitrary order. Sorting keeps the
        # label -> index mapping stable between runs, and ignoring plain
        # files avoids treating stray entries as labels.
        self.labels = sorted(entry for entry in os.listdir(path)
                             if os.path.isdir(os.path.join(path, entry)))
        if not self.labels:
            raise ValueError("no label sub-directories found in %r" % (path,))
        self.number_labels = len(self.labels)
        self.number_label_elements = {}
        self.label_paths = {}
        self.label_len = {}
        self.images = {}
        self.labels_to_classify = []
        for i, l in enumerate(self.labels):
            self.label_paths[i] = os.path.join(self.path, l)
            # Sorted for the same reason as the labels: the splits below are
            # index ranges into this list, so its order must be reproducible.
            self.images[i] = sorted(os.listdir(self.label_paths[i]))
            self.label_len[i] = len(self.images[i])
        smallest_label_len = min(self.label_len.values())
        self.test_split = int(val_fac*smallest_label_len)
        self.val_split = int(2*val_fac*smallest_label_len)
        # Number of full batches needed to see every validation image once.
        self.val_steps = int(self.number_labels*(self.val_split-self.test_split)/self.batch_size)
        # Every (label, image index) pair handed out by the selectors is
        # appended here; the lists are never trimmed.
        self.val_log = []
        self.training_log = []

    def make_square(self, image, mode='constant'):
        """Pad `image` to a square and resize it to `self.target_size`.

        Args:
            image: 2-D image array.
            mode: Padding mode, also passed on to `resize`. With 'constant'
                the padding value is 255.

        Returns:
            The resized image as returned by `resize`.
        """
        max_dim = max(image.shape)
        pads = []
        for axis in (0, 1):
            missing = max_dim - image.shape[axis]
            before = missing // 2
            # Put the odd pixel after the image so the result is exactly square.
            pads.append((before, missing - before))
        pads = tuple(pads)
        if (mode=='constant'):
            image = pad(image, pads, mode=mode, constant_values=255)
        else:
            image = pad(image, pads, mode=mode)
        return resize(image, tuple(self.target_size), mode=mode)

    def augment_image(self, image):
        """Randomly mirror `image` along axis 1, then square and resize it."""
        if(np.random.choice([True, False])):
            image = np.flip(image, axis=1)
        return self.make_square(image, mode='constant')

    def get_statistics(self, image):
        """Return (largest dim / 1024, smallest dim / 1024, fraction of pixels equal to 255)."""
        return (max(image.shape)/1024.0, min(image.shape)/1024.0, np.sum(image==255)/float(image.size))

    def base_generator(self, augment_image, get_random_label_and_image):
        """Build one batch of images and one-hot labels.

        Args:
            augment_image: Callable applied to every loaded image; its result
                is reshaped to `target_size + (1,)`.
            get_random_label_and_image: Iterator yielding
                (label index, image index) pairs.

        Returns:
            Tuple (images, labels): `images` stacks `batch_size` processed
            images, `labels` is the `to_categorical` encoding of their label
            indices with `number_labels` classes.
        """
        output_list = []
        output_labels = []
        while len(output_list) < self.batch_size:
            random_label, random_image = next(get_random_label_and_image)
            output_labels.append(random_label)
            output_list.append(os.path.join(self.label_paths[random_label], self.images[random_label][random_image]))
        # Add a trailing channel axis of size 1.
        image_shape = tuple(self.target_size) + (1,)
        output_augmented_images = [augment_image(io.imread(fp)).reshape(image_shape) for fp in output_list]
        return (np.stack(output_augmented_images),
              to_categorical(np.array(output_labels), num_classes = self.number_labels))

    def training_image_selector(self):
        """Endlessly yield a random (label, image index) pair from the training range."""
        while True:
            random_label = np.random.choice( self.number_labels)
            # Uniform over [val_split, label_len) without building the range.
            random_image = np.random.randint(self.val_split, self.label_len[random_label])
            self.training_log.append((random_label, random_image))
            yield (random_label, random_image)

    @threadsafe_generator
    def training_generator(self):
        """Endlessly yield augmented training batches."""
        it = self.training_image_selector()
        while True:
            yield self.base_generator(self.augment_image, it)

    def val_image_selector(self):
        """Endlessly cycle through every validation image, label by label."""
        while True:
            for random_label in range(self.number_labels):
                for random_image in range(self.test_split, self.val_split):
                    self.val_log.append((random_label, random_image))
                    yield (random_label, random_image)

    @threadsafe_generator
    def validation_generator(self):
        """Endlessly yield validation batches (squared and resized, not augmented)."""
        it = self.val_image_selector()
        while True:
            yield self.base_generator(self.make_square, it)

    def test_image_selector(self):
        """Endlessly cycle through every test image, label by label."""
        while True:
            for random_label in range(self.number_labels):
                for random_image in range(self.test_split):
                    yield (random_label, random_image)

    @threadsafe_generator
    def test_generator(self):
        """Endlessly yield test batches (squared and resized, not augmented)."""
        it = self.test_image_selector()
        while True:
            yield self.base_generator(self.make_square, it)


@threadsafe_generator
def make_image_generator(gen):
    """Re-yield every batch of `gen`, keeping only element 0 of the batch inputs.

    Each item pulled from `gen` is indexed as `item[0][0]` and `item[1]`.
    NOTE(review): this presumably expects `gen` to yield
    `([images, statistics], labels)` -- confirm against the generator passed in.
    """
    while True:
        batch = next(gen)
        inputs = batch[0]
        yield inputs[0], batch[1]

@threadsafe_generator
def make_stat_generator(gen):
    """Re-yield every batch of `gen`, keeping only element 1 of the batch inputs.

    Each item pulled from `gen` is indexed as `item[0][1]` and `item[1]`.
    NOTE(review): this presumably expects `gen` to yield
    `([images, statistics], labels)` -- confirm against the generator passed in.
    """
    while True:
        batch = next(gen)
        inputs = batch[0]
        yield inputs[1], batch[1]

In [38]:
# Index the image folder and create the batch generators: 8 images per batch,
# with the test and validation split points derived from val_fac=0.01.
dategen = MyGenerators(path=imgfolder, val_fac=0.01, batch_size=8)
# Number of label sub-directories found, i.e. the number of classes.
NUMBER_LABELS = dategen.number_labels
train_gen, validation_gen = dategen.training_generator(), dategen.validation_generator()
#visualize_generator(dategen)

In [40]:
# ImageNet-pretrained InceptionV3 without its classification head. The input
# shape matches the 224x224x3 tensor produced by the projection layer below.
pretrained_model = InceptionV3(include_top=False, weights='imagenet', input_shape=(224,224,3))

# The generators yield single-channel 224x224 images.
inputs_image = Input(shape=(224,224,1))
# Trainable 3x3 convolution mapping the single channel to the 3 channels
# InceptionV3 expects; padding='same' keeps the 224x224 spatial size.
x = Conv2D(3, kernel_size=(3,3), padding='same', activation='relu')(inputs_image)
x = pretrained_model(x)
# 1x1 convolution collapsing the InceptionV3 feature maps to one channel
# before flattening.
x = Conv2D(1, (1,1), activation='relu')(x)
outputs_pretrained = Flatten()(x)

x = Dropout(0.2)(outputs_pretrained)
x = Dense(1024, activation='relu')(x)
# One output unit per label: the generators one-hot encode labels with
# NUMBER_LABELS classes, so the output width must follow the dataset.
predictions = Dense(NUMBER_LABELS, activation='softmax', name="output")(x)

composite_model_w_pretrained_model = Model(inputs=[inputs_image], outputs=predictions)

composite_model_w_pretrained_model.compile(optimizer='rmsprop', loss="categorical_crossentropy", metrics=["accuracy"])

In [41]:
# Train the composite model for 10 epochs of 20 generator batches each and
# keep the returned history object.
hist = fit_and_save_model(
    "composite_model_w_pretrained_model",
    composite_model_w_pretrained_model,
    train_gen,
    steps_per_epoch=20,
    epochs=10,
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
