In [None]:
# import libraries
import keras
from utils.preprocess import *
from config import *

In [None]:
# load dataset; `no_cap` (not defined in this notebook, presumably from
# `from config import *`) selects which variant is loaded
dataset = load_unipen_dataset(no_cap)


def normalize(data, label):
    """Cast `data` to float32 and divide it by 255; `label` passes through unchanged.

    Assumes 8-bit pixel values (0-255) so the result lies in [0, 1] — TODO confirm
    against `load_unipen_dataset`.
    NOTE(review): `tf` is never imported in this notebook; it is presumably
    re-exported by `from utils.preprocess import *` — consider importing it explicitly.
    """
    scaled = tf.cast(data, tf.float32) / 255.0
    return scaled, label


# optionally rescale pixel values (flag presumably comes from config)
dataset = dataset.map(normalize) if normalize_dataset else dataset

# data augmentation: one shared RandomRotation layer, applied by `augment`.
# Per the Keras docs, `factor` is a fraction of a full turn (0.01 -> about +/-3.6
# degrees), and fill_mode='constant' pads exposed corners with `fill_value`
# (default 0.0).
augmentLayer = keras.layers.RandomRotation(factor=0.01, fill_mode='constant')


def augment(data, label):
    """Randomly rotate `data` with `augmentLayer`; `label` is returned unchanged."""
    rotated = augmentLayer(data)
    return rotated, label

# Shuffle ONCE into a fixed order before splitting. With the default
# reshuffle_each_iteration=True every pass over `dataset` sees a new
# permutation, so take()/skip() would hand different examples to the train
# and test splits on each iteration (train/test leakage).
dataset = dataset.shuffle(shuffle_buffer_size, reshuffle_each_iteration=False)

# split dataset (train_size counts examples: the split is made before batching)
num_examples = dataset.cardinality().numpy()
if num_examples < 0:
    # tf.data reports infinite (-1) / unknown (-2) cardinality as negative
    # sentinels; a proportional split size would be meaningless.
    raise ValueError(
        f"cannot split a dataset with infinite or unknown cardinality ({num_examples})"
    )
train_size = int(train_prop * num_examples)

# Training split: reshuffled every epoch, batched, then augmented.
# Augmentation is applied to the training split only, so evaluation
# measures accuracy on the original (un-rotated) images.
train_dataset = (
    dataset.take(train_size)
    .shuffle(shuffle_buffer_size)
    .batch(batch_size)
    .map(augment)
)
test_dataset = dataset.skip(train_size).batch(batch_size)

In [None]:
# build model
# A small CNN over 32x32 single-channel images. Layer names are kept as-is
# because they are stored in the saved .h5 file.
#
# Spatial size per layer ('valid' 3x3 convs shrink each side by 2, 2x2 pooling halves):
#   32 -conv1-> 30 -maxpool1-> 15 -conv2-> 13 -maxpool2-> 6
#      -conv3-> 4 -conv4-> 2 -maxpool4-> 1   => Flatten yields 1*1*128 = 128 features.

# 96 output classes; presumably one per ASCII code 32..127 — TODO confirm
# against the label encoding in utils.preprocess.
num_classes = 128 - 32

model = keras.Sequential([
    # Declaring the input with keras.Input replaces the `input_shape=` layer
    # kwarg, which Keras 3 deprecates (it warns when passed to a layer).
    keras.Input(shape=(32, 32, 1)),

    keras.layers.Conv2D(16, (3, 3), name="conv1"),
    keras.layers.ReLU(name="relu16"),
    keras.layers.MaxPooling2D((2, 2), name="maxpool1"),

    keras.layers.Conv2D(32, (3, 3), name="conv2"),
    keras.layers.ReLU(name="relu32"),
    keras.layers.MaxPooling2D((2, 2), name="maxpool2"),

    keras.layers.Conv2D(64, (3, 3), name="conv3"),
    keras.layers.ReLU(name="relu64"),

    keras.layers.Conv2D(128, (3, 3), name="conv4"),
    keras.layers.ReLU(name="relu128"),
    keras.layers.MaxPooling2D((2, 2), name="maxpool4"),

    keras.layers.Dropout(0.25),
    keras.layers.Flatten(name="flatten"),

    # No activation: this layer emits raw logits, matching from_logits=True below.
    # (The "dense128" name is historical; the layer has num_classes units.)
    keras.layers.Dense(num_classes, name="dense128"),
])

model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# For Conv2D layers: Param # = ((kersize ** 2) * in + 1) * out
model.summary()

In [None]:
# train on the training split; `epochs` is presumably provided by `from config import *`
model.fit(x=train_dataset, epochs=epochs)

In [None]:
# evaluate the trained model on the held-out split
test_loss, test_acc = model.evaluate(test_dataset)

# a blank line, then one labelled line per metric
print()
for caption, value in (('Test loss:    ', test_loss),
                       ('Test accuracy:', test_acc)):
    print(caption, value)

In [None]:
# save model; the no_cap variant is written to its own file name
if no_cap:
    target_model = "unipen_no_cap_model"
else:
    target_model = "unipen_model"
model.save("data/" + target_model + ".h5")