# Sketch Classifier for "How Do Humans Sketch Objects?"

A sketch classifier built on the dataset from the paper <a href='http://cybertron.cg.tu-berlin.de/eitz/projects/classifysketch/'>How Do Humans Sketch Objects?</a>, in which the authors collected 20,000 unique sketches evenly distributed over 250 object categories. We use a convolutional neural network (CNN), built with Keras, to classify each sketch.

<img src='http://cybertron.cg.tu-berlin.de/eitz/projects/classifysketch/teaser_siggraph.jpg'/>

In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
import numpy as np
from scipy.misc import imresize
import os

In [None]:
# Render matplotlib figures inline in the notebook, at high resolution on retina displays.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

# Use the ggplot visual style for all subsequent plots.
plt.style.use('ggplot')

In [None]:
import keras 
# Display the installed Keras version (the cell's output) for reproducibility.
keras.__version__

In [None]:
from keras import layers
from keras import models
from keras import optimizers
from keras import callbacks

from keras.utils import plot_model

from keras import preprocessing
from keras.preprocessing import image

## Trained on FloydHub

In [None]:
# Root of the pre-split dataset; the training cell reads its
# 'training' and 'validation' sub-directories.
DEST_SKETCH_DIR = '/sketches_training_data/'

# (height, width) in pixels that sketches are resized to when loaded.
TARGET_SIZE = (256, 256)

# Number of sketch categories present in the training data.
CATEGORIES_COUNT = 199

# Image counts per split (199 * 64 and 199 * 16, i.e. an 80/20 split).
TRAINING_SAMPLES = 12736
VALIDATION_SAMPLES = 3184

In [None]:
# Shell escape: list the dataset directory (the training cell expects 'training' and 'validation' inside it).
!ls /sketches_training_data

## Create model 

In [None]:
def plot_accuracy_loss(history):
    """
    Plot training vs. validation accuracy and loss from a Keras History.

    Accuracy is drawn in the current figure and loss in a new figure;
    both are then shown.

    Args:
        history: the History object returned by `fit`/`fit_generator`.
            Its `.history` dict must contain the accuracy and loss series
            for both training and validation.
    """
    metrics = history.history

    # Keras < 2.3 logs metrics=['accuracy'] under 'acc'/'val_acc'; newer
    # releases use 'accuracy'/'val_accuracy'. Accept either naming.
    acc_key = 'acc' if 'acc' in metrics else 'accuracy'
    acc = metrics[acc_key]
    val_acc = metrics['val_' + acc_key]
    loss = metrics['loss']
    val_loss = metrics['val_loss']

    # Number epochs from 1 so the x-axis matches Keras' "Epoch 1/N" logs.
    epochs = range(1, len(acc) + 1)

    plt.plot(epochs, acc, 'bo', label='Training acc')
    plt.plot(epochs, val_acc, 'b', label='Validation acc')
    plt.title('Training and validation accuracy')
    plt.xlabel('Epoch')
    plt.legend()

    plt.figure()

    plt.plot(epochs, loss, 'bo', label='Training loss')
    plt.plot(epochs, val_loss, 'b', label='Validation loss')
    plt.title('Training and validation loss')
    plt.xlabel('Epoch')
    plt.legend()

    plt.show()

In [None]:
def train(model, 
          training_dir,
          validation_dir,
          target_size=TARGET_SIZE, 
          training_samples=TRAINING_SAMPLES, 
          validation_samples=VALIDATION_SAMPLES,
          epochs=1000, 
          batch_size=512, 
          load_previous_weights=True,
          model_weights_file=None):
    """
    Compile `model` and train it on grayscale sketches streamed from disk.

    Images are read with `flow_from_directory` (one sub-folder per class).
    Training images are randomly augmented. Validation images are only
    rescaled, so the monitored `val_loss` (which drives checkpointing and
    early stopping) reflects the real data rather than random distortions.

    Args:
        model: a Keras model; it is (re)compiled here with categorical
            cross-entropy, RMSprop and an accuracy metric.
        training_dir: root directory of the training images.
        validation_dir: root directory of the validation images.
        target_size: (height, width) to resize images to. If it disagrees
            with the model's input height/width, the model's size is used
            instead (otherwise the batches could not be fed to the model).
        training_samples: number of training images (used to size an epoch).
        validation_samples: number of validation images.
        epochs: maximum number of epochs; early stopping may end sooner.
        batch_size: number of images per batch.
        load_previous_weights: if True and `model_weights_file` exists,
            resume from the weights stored there.
        model_weights_file: path where the best weights (lowest `val_loss`)
            are saved. Required.

    Returns:
        A `(history, model)` tuple: the Keras History object and the model.

    Raises:
        ValueError: if `model_weights_file` is None.
    """
    import math

    if model_weights_file is None:
        # Raising a plain string is itself a TypeError in Python 3.
        raise ValueError("No model weights file set")

    # The generator must yield images of exactly the size the model expects.
    # model.input_shape is (batch, height, width, channels) for the
    # channels-last models built in this notebook: index 0 is the batch
    # dimension (always None), so height/width are at indices 1 and 2.
    model_input_shape = getattr(model, 'input_shape', None)
    if isinstance(model_input_shape, tuple) and len(model_input_shape) == 4:
        model_hw = tuple(model_input_shape[1:3])
        if None not in model_hw and model_hw != tuple(target_size):
            print("Target size {} does not match model input {}; using {}".format(
                target_size, model_hw, model_hw))
            target_size = model_hw

    print("Training STARTED - target size {}, batch size {}".format(
        target_size, 
        batch_size))

    if load_previous_weights and os.path.isfile(model_weights_file):
        print("Loading weights from file {}".format(model_weights_file))
        model.load_weights(model_weights_file)

    model.compile(
        loss='categorical_crossentropy', 
        optimizer='rmsprop', 
        metrics=['accuracy'])

    # Training data: rescale to [0, 1] plus random augmentation.
    # See the official documentation for details: https://keras.io/preprocessing/image/
    train_datagen = preprocessing.image.ImageDataGenerator(
        rescale=1./255., # rescaling factor applied by multiplying the data by this value
        rotation_range=10, # value in degrees (0-180), a range within which to randomly rotate pictures
        width_shift_range=0.1, # range (as a fraction of total width) to randomly translate pictures
        height_shift_range=0.1, # range (as a fraction of total height) to randomly translate pictures
        shear_range=0.1, # randomly apply shearing transformations
        zoom_range=0.1, # randomly zoom inside pictures
        horizontal_flip=True, # randomly flip half of the images horizontally
        fill_mode='nearest') # strategy used for filling in newly created pixels

    # Validation data: rescale only, so validation metrics are measured on
    # the unmodified sketches.
    validation_datagen = preprocessing.image.ImageDataGenerator(rescale=1./255.)

    # Iterator over the training data.
    train_generator = train_datagen.flow_from_directory(
        training_dir,
        target_size=target_size,
        batch_size=batch_size, 
        color_mode='grayscale')

    # Iterator over the validation data.
    validation_generator = validation_datagen.flow_from_directory(
        validation_dir,
        target_size=target_size,
        batch_size=batch_size, 
        color_mode='grayscale')

    # Keep only the best weights (by val_loss), checking every 2 epochs.
    checkpoint = callbacks.ModelCheckpoint(model_weights_file, 
                                           monitor='val_loss', 
                                           verbose=0, 
                                           save_best_only=True, 
                                           save_weights_only=True, 
                                           mode='auto', 
                                           period=2)

    early_stopping = callbacks.EarlyStopping(monitor='val_loss', patience=10)

    # Each training epoch draws ~2.5x the training set in augmented batches.
    # float() keeps true division under Python 2 as well.
    data_augmentation_multiplier = 2.5
    steps_per_epoch = max(1, int((training_samples / float(batch_size)) * data_augmentation_multiplier))
    # Validation is not augmented, so a single full pass is sufficient.
    validation_steps = max(1, int(math.ceil(validation_samples / float(batch_size))))

    history = model.fit_generator(
        train_generator,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_data=validation_generator,
        validation_steps=validation_steps, 
        callbacks=[checkpoint, early_stopping]) 

    print("Training FINISHED - target size {}, batch size {}".format(
        target_size, 
        batch_size))

    return history, model

In [None]:
def create_model(input_shape=(128,128,1), classes=250, is_training=True):
    """
    Build (but do not compile) the sketch-classification CNN.

    The network has four Conv2D + MaxPooling2D stages followed by a
    512-unit dense layer and a softmax output layer named 'output'.

    Args:
        input_shape: shape of one input image, e.g. (height, width, 1).
        classes: number of output categories (softmax units).
        is_training: when False, the Dropout layers are left out.

    Returns:
        A `keras.models.Sequential` model.
    """
    net = models.Sequential()

    # One entry per convolutional stage: (filters, kernel size, dropout rate
    # applied after that stage's pooling, or None for no dropout).
    conv_stages = [
        (32, (7, 7), None),
        (64, (5, 5), 0.125),
        (128, (3, 3), None),
        (128, (3, 3), None),
    ]
    for stage_idx, (n_filters, kernel, drop_rate) in enumerate(conv_stages):
        conv_kwargs = dict(kernel_size=kernel, padding='same', activation='relu')
        if stage_idx == 0:
            # Only the first convolution is strided and declares the input shape.
            conv_kwargs.update(strides=(2, 2), input_shape=input_shape)
        net.add(layers.Conv2D(n_filters, **conv_kwargs))
        net.add(layers.MaxPooling2D(2, 2))
        if is_training and drop_rate is not None:
            net.add(layers.Dropout(drop_rate))

    # Classifier head.
    net.add(layers.Flatten())
    net.add(layers.Dense(512, activation='relu'))
    if is_training:
        net.add(layers.Dropout(0.5))
    net.add(layers.Dense(classes, activation='softmax', name='output'))

    return net

In [None]:
# Build the CNN to match the training data. create_model's defaults
# (128x128x1 input, 250 classes) disagree with the 256x256 images and the
# CATEGORIES_COUNT categories used here. A mismatched softmax size breaks the
# categorical cross-entropy against the generator's one-hot labels.
# Assumes CATEGORIES_COUNT equals the number of class sub-folders in the
# training directory.
model = create_model(input_shape=TARGET_SIZE + (1,), classes=CATEGORIES_COUNT)
model.summary()

In [None]:
# Train on the data under DEST_SKETCH_DIR; the best weights are
# checkpointed to /output/cnn_sketch_weights_1.h5 (resumed from if present).
history, model = train(
    model,
    training_dir=os.path.join(DEST_SKETCH_DIR, 'training'),
    validation_dir=os.path.join(DEST_SKETCH_DIR, 'validation'),
    model_weights_file="/output/cnn_sketch_weights_1.h5",
    load_previous_weights=True,
    target_size=(256, 256),
    batch_size=512,
    epochs=1000,
)

In [None]:
# Visualise the learning curves recorded during training.
plot_accuracy_loss(history=history)