In [1]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
import tensorflow as tf
from tensorflow.keras.layers import Dense,Flatten, Conv2D
from tensorflow.keras.layers import MaxPooling2D, Dropout
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import TensorBoard
import pickle

In [14]:
def keras_model(image_x, image_y, num_of_classes):
    model = Sequential()
    model.add(Conv2D(32, (5, 5), input_shape=(image_x,image_y,1), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'))
    model.add(Conv2D(64, (5, 5), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same'))

    model.add(Flatten())
    model.add(Dense(512, activation='relu'))
    model.add(Dropout(0.6))
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.6))
    model.add(Dense(num_of_classes, activation='softmax'))

    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    filepath = "../models/QuickDraw.keras"
    checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1, save_best_only=True, mode='max')
    callbacks_list = [checkpoint]

    return model, callbacks_list

In [9]:
def loadFromPickle():
    with open("../features", "rb") as f:
        features = np.array(pickle.load(f))
    with open("../labels", "rb") as f:
        labels = np.array(pickle.load(f))

    return features, labels

In [4]:
def augmentData(features, labels):
    features = np.append(features, features[:, :, ::-1], axis=0)
    labels = np.append(labels, -labels, axis=0)
    return features, labels

In [5]:
def prepress_labels(labels):
    labels = tf.keras.utils.to_categorical(labels)
    return labels

In [10]:
features, labels = loadFromPickle()
# features, labels = augmentData(features, labels)
features, labels = shuffle(features, labels)
labels=prepress_labels(labels)
    
# Automatically detect number of classes
num_of_classes = labels.shape[1]
print(f"Detected {num_of_classes} classes in the dataset")
    
train_x, test_x, train_y, test_y = train_test_split(features, labels, random_state=0,
                                                        test_size=0.1)

Detected 21 classes in the dataset


In [15]:
train_x = train_x.reshape(train_x.shape[0], 28, 28, 1)
test_x = test_x.reshape(test_x.shape[0], 28, 28, 1)
model, callbacks_list = keras_model(28, 28, num_of_classes)
model.summary()

In [16]:
model.fit(train_x, train_y, validation_data=(test_x, test_y), epochs=3, batch_size=64,
              callbacks=[TensorBoard(log_dir="QuickDraw")])

Epoch 1/3
[1m2954/2954[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 14ms/step - accuracy: 0.5717 - loss: 1.4126 - val_accuracy: 0.8309 - val_loss: 0.5262
Epoch 2/3
[1m2954/2954[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 14ms/step - accuracy: 0.8113 - loss: 0.6254 - val_accuracy: 0.8554 - val_loss: 0.4577
Epoch 3/3
[1m2954/2954[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 14ms/step - accuracy: 0.8372 - loss: 0.5439 - val_accuracy: 0.8650 - val_loss: 0.4285


<keras.src.callbacks.history.History at 0x23552c959d0>

In [20]:
model.save('models/QuickDraw.h5')

