In [1]:
from typing import cast
import tensorflow as tf
import keras
import retina
import matplotlib.pyplot as plt

In [42]:
def normalize_image(x, y):
  """Scale image pixel values down by 255 and pass the label through unchanged.

  Presumably x holds raw pixel intensities in [0, 255] (as produced by
  keras.utils.image_dataset_from_directory), so the result lies in [0, 1].
  """
  scaled = x / 255
  return scaled, y
def clamp_zero_one(x, y):
  """Clamp every value of x into [0, 1]; the label y is returned untouched."""
  capped_above = tf.minimum(x, 1)
  return tf.maximum(capped_above, 0), y

def apply_transformation(dataset: tf.data.Dataset, repeats: int = 30) -> tf.data.Dataset:
  """Normalize pixel values to [0, 1] and repeat the dataset.

  Args:
    dataset: dataset of (image, label) pairs, e.g. from
      keras.utils.image_dataset_from_directory.
    repeats: number of times the data is repeated per iteration of the returned
      dataset. Defaults to 30, the previously hard-coded value. Pass 1 when
      repetition is not wanted (e.g. validation data, where repeating identical
      samples only multiplies evaluation time).

  Returns:
    The normalized, repeated dataset.
  """
  return dataset.map(normalize_image).repeat(repeats)

In [None]:
# The training and validation subsets MUST share the same `seed` and
# `validation_split`: Keras shuffles the file list with `seed`, then takes the
# first (1 - split) fraction for "training" and the last `split` fraction for
# "validation". Different split values make the two subsets overlap, leaking
# training images into validation.
SPLIT_SEED = 42
VALIDATION_SPLIT = 0.2


def load_subset(subset: str) -> tf.data.Dataset:
  """Load one subset ("training" or "validation") of the grayscale face images."""
  return cast(tf.data.Dataset, keras.utils.image_dataset_from_directory(
    directory=retina.filesys.DATA_PATH,
    color_mode="grayscale",
    image_size=retina.size.FACE_DIMENSIONS.tuple,
    seed=SPLIT_SEED,
    validation_split=VALIDATION_SPLIT,
    subset=subset,
  ))


training_dataset = apply_transformation(load_subset('training'))
# Validation only needs normalization. Repeating it would re-evaluate the same
# images many times per epoch without changing the reported metrics.
validation_dataset = load_subset('validation').map(normalize_image)

In [44]:
# Light geometric + photometric augmentation, applied to training images that
# have already been scaled to [0, 1] by `normalize_image`.
data_augmentation = keras.Sequential([
  keras.layers.RandomFlip("horizontal"),
  keras.layers.RandomRotation(0.05),
  keras.layers.RandomTranslation(0.1, 0.1),
  keras.layers.RandomBrightness(0.1, value_range=(0, 1)),
  # NOTE(review): unlike RandomBrightness, RandomContrast is not given a
  # [0, 1] value range here; scaling (x - mean) by a factor > 1 can push
  # pixels slightly outside [0, 1], hence the clamp below.
  keras.layers.RandomContrast(0.1),
])

augmented_dataset = (
  training_dataset
  .map(lambda x, y: (data_augmentation(x, training=True), y))
  # Keep augmented pixels in the same [0, 1] range as the validation data.
  .map(clamp_zero_one)
  # Overlap input preparation with training steps.
  .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Visual sanity check: show one augmented training batch with its integer labels.
for batch, labels in augmented_dataset.take(1):
  # Drop only the trailing grayscale channel axis (color_mode="grayscale" gives
  # 1 channel). A bare tf.squeeze() would also drop the batch axis, or a
  # spatial axis, whenever it happened to have size 1.
  images = list(tf.squeeze(batch, axis=-1).numpy())
  str_labels = [str(label) for label in labels.numpy()]
  # NOTE(review): (4, 4) presumably sets the collage grid; confirm how
  # collage_images_plt handles batches with more than 16 images.
  subplots = retina.debug.collage_images_plt(images, str_labels, (4, 4))

In [62]:
# Number of output classes.
# NOTE(review): must equal the number of class sub-directories under
# retina.filesys.DATA_PATH; confirm.
NUM_CLASSES = 6

# Fully connected classifier over flattened grayscale face images.
model = keras.Sequential([
  # image_dataset_from_directory(color_mode="grayscale") yields images shaped
  # (height, width, 1), so the channel axis is part of the input shape.
  keras.layers.Input(shape=(*retina.size.FACE_DIMENSIONS.tuple, 1)),
  keras.layers.Flatten(),
  keras.layers.Dense(512, activation="relu"),
  keras.layers.Dropout(0.5),
  keras.layers.Dense(256, activation="relu"),
  keras.layers.Dropout(0.5),
  keras.layers.Dense(128, activation="relu"),
  keras.layers.Dense(NUM_CLASSES),
  # The model outputs probabilities, so the loss below keeps the default
  # from_logits=False.
  keras.layers.Softmax(),
])

model.compile(
  optimizer=keras.optimizers.Adam(), # type: ignore
  # Integer class labels (image_dataset_from_directory's default label_mode).
  loss=keras.losses.SparseCategoricalCrossentropy(),
  metrics=["accuracy"],
)

In [None]:
# Train on the augmented stream; evaluate on the held-out split after each epoch.
history = model.fit(
  x=augmented_dataset,
  validation_data=validation_dataset,
  epochs=100,
)

In [None]:
# Training vs. validation accuracy per epoch.
accuracy_fig, ax = plt.subplots(figsize=(6, 6))
ax.plot(history.history["accuracy"], label="Accuracy", marker='o')
ax.plot(history.history["val_accuracy"], label="Validation Accuracy", marker='o')
ax.set(title="Model accuracy", xlabel="Epoch", ylabel="Accuracy", ylim=(0, 1))
ax.legend()
# plt.show() renders via the active backend; Figure.show() only emits a
# warning under the non-GUI inline backend used by notebooks.
plt.show()