# In [3]:
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, AvgPool2D, Dropout, Flatten, Dense
from tensorflow.keras.optimizers import Adam

def load_train(path):
    """Return the training-subset image iterator for the dataset at *path*.

    Images are loaded with Keras ``flow_from_directory`` (which expects
    one subdirectory per class under ``path``). They are resized to
    150x150, pixel values are multiplied by 1/255, and batches of 16 are
    served with integer ("sparse") labels. A 25% validation split is
    declared on the generator; only the ``'training'`` subset is returned.

    Args:
        path: Root directory of the image dataset.

    Returns:
        The iterator produced by ``flow_from_directory`` for the
        training subset.
    """
    # Horizontal/vertical flip augmentation is currently disabled; only
    # rescaling and the train/validation split are configured.
    generator = ImageDataGenerator(rescale=1 / 255.0, validation_split=0.25)

    flow_options = dict(
        target_size=(150, 150),
        batch_size=16,
        class_mode='sparse',
        subset='training',
        seed=12345,  # fixed seed for reproducible shuffling
    )
    return generator.flow_from_directory(directory=path, **flow_options)


def create_model(input_shape, num_classes=12, learning_rate=0.001):
    """Build and compile a small two-stage CNN classifier.

    Architecture:
      * Conv2D(32, 3x3, 'same', ReLU) -> AvgPool 2x2 -> Dropout 0.2
      * Conv2D(64, 3x3, 'same', ReLU) -> AvgPool 2x2 -> Dropout 0.2
      * Flatten -> Dense(64, ReLU) -> Dropout 0.2
      * Dense(num_classes, softmax)

    Args:
        input_shape: Shape of one input image, passed to the first
            ``Conv2D``. Presumably ``(150, 150, 3)`` to match the loader's
            150x150 target size -- confirm the channel count.
        num_classes: Number of softmax output units. Defaults to 12. With
            sparse categorical cross-entropy the integer labels index this
            output, so it must be at least the number of classes in the data.
        learning_rate: Learning rate for the Adam optimizer (default 0.001).

    Returns:
        A compiled ``Sequential`` model using sparse categorical
        cross-entropy and tracking accuracy (``'acc'``).
    """
    model = Sequential()

    # Feature extractor: two conv stages, each halving spatial size via pooling.
    model.add(Conv2D(filters=32, kernel_size=(3, 3), input_shape=input_shape,
                     padding='same', activation='relu'))
    model.add(AvgPool2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))

    model.add(Conv2D(filters=64, kernel_size=(3, 3), activation='relu', padding='same'))
    model.add(AvgPool2D(pool_size=(2, 2)))
    model.add(Dropout(0.2))

    # Classifier head.
    model.add(Flatten())
    model.add(Dense(units=64, activation='relu'))
    model.add(Dropout(0.2))
    model.add(Dense(units=num_classes, activation='softmax'))

    # Integer labels (class_mode='sparse' upstream) -> sparse categorical CE.
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=Adam(learning_rate=learning_rate),
                  metrics=['acc'])

    return model


def train_model(model, train_data, test_data, batch_size=None, epochs=10,
               steps_per_epoch=None, validation_steps=None):
    """Fit *model* on *train_data*, validating on *test_data*; return the model.

    The scheduling arguments are forwarded unchanged to ``model.fit``. The
    call always uses ``verbose=2`` and ``shuffle=True``.

    Args:
        model: Compiled model exposing a Keras-style ``fit`` method.
        train_data: Training input, passed as the first positional
            argument of ``fit``.
        test_data: Passed to ``fit`` as ``validation_data``.
        batch_size: Forwarded to ``fit``. Keep it None when the data is a
            generator/``Sequence`` that already yields batches; Keras rejects
            an explicit batch_size for such inputs.
        epochs: Number of training epochs (default 10).
        steps_per_epoch: Forwarded to ``fit``.
        validation_steps: Forwarded to ``fit``.

    Returns:
        The same ``model`` object that was passed in, after ``fit``.
    """
    fit_kwargs = {
        'validation_data': test_data,
        'epochs': epochs,
        'batch_size': batch_size,
        'steps_per_epoch': steps_per_epoch,
        'validation_steps': validation_steps,
        'verbose': 2,
        'shuffle': True,
    }
    model.fit(train_data, **fit_kwargs)
    return model