In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

In [None]:
# ---- Configuration: everything a reader might want to tune ----------------
SEED = 2021                # seed for NumPy and TensorFlow global RNGs
BATCH_SIZE = 32            # examples per batch (training and evaluation)
IMG_SIZE = [128, 128]      # target (height, width); kept a list so `IMG_SIZE + [3]` works
MODEL_DIR = 'saved_model'  # SavedModel directory; the .tflite file is written here too

# Seed every RNG source this notebook touches, NumPy first, then TensorFlow.
for seed_fn in (np.random.seed, tf.random.set_seed):
    seed_fn(SEED)

In [None]:
# Load both splits of rock_paper_scissors; TFDS keeps its files under `dataset/`.
# as_supervised=True makes each element an (image, label) tuple.
train_ds, test_ds = tfds.load(
    'rock_paper_scissors',
    split=['train', 'test'],
    as_supervised=True,
    data_dir='dataset',
)

In [None]:
def preprocess(image, label):
    """Resize an image to ``IMG_SIZE`` and divide its pixel values by 255.

    Intended for ``Dataset.map`` over ``(image, label)`` pairs; the label is
    returned unchanged. Assumes 8-bit pixel values, so the result lies in
    [0, 1] (TODO confirm for other input dtypes).
    """
    resized = tf.image.resize(image, size=IMG_SIZE)
    scaled = tf.cast(resized, dtype=tf.float32) / 255.0
    return scaled, label

In [None]:
# Build the batched input pipelines.
# NOTE: this cell rebinds train_ds/test_ds, so running it twice would apply
# `preprocess` (and the /255 scaling) twice; re-run the loading cell first.
#
# Dataset.shuffle's first argument is the *buffer size*, not a seed. Passing
# SEED there tied the shuffle quality to the seed value. Use an explicit buffer
# and pass the seed through the `seed` argument instead.
SHUFFLE_BUFFER = 2048  # larger buffer -> closer to a uniform shuffle of the split

train_ds = (
    train_ds
    .map(preprocess)
    .shuffle(SHUFFLE_BUFFER, seed=SEED, reshuffle_each_iteration=True)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)  # overlap input preparation with training
)
test_ds = (
    test_ds
    .map(preprocess)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Build a small CNN: three Conv2D stages (the first two followed by max
# pooling), global max pooling, then a dense head with 3 softmax outputs.
tf.keras.backend.clear_session()  # reset Keras state left by earlier runs of this cell

model = tf.keras.Sequential([
    # Conv2D defaults to a linear activation. Without ReLU, the only
    # nonlinearity in the feature extractor would be max pooling, which
    # severely limits what the convolutions can learn.
    tf.keras.layers.Conv2D(16, kernel_size=3, activation='relu',
                           input_shape=IMG_SIZE + [3]),  # 3 input channels (presumably RGB)
    tf.keras.layers.MaxPool2D(),
    tf.keras.layers.Conv2D(32, kernel_size=3, activation='relu'),
    tf.keras.layers.MaxPool2D(),
    tf.keras.layers.Conv2D(64, kernel_size=3, activation='relu'),
    tf.keras.layers.GlobalMaxPool2D(),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    # One probability per class; the sparse loss below expects integer labels.
    tf.keras.layers.Dense(3, activation='softmax'),
])
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam',
              metrics=['accuracy'])
model.summary()

In [None]:
class SaveBest(tf.keras.callbacks.Callback):
    """Save the model to ``save_model_dir`` whenever the monitored metric improves.

    "Improves" means strictly lower than the best value seen so far, so the
    monitored quantity should be loss-like. The first epoch always counts as
    an improvement, so a model is on disk after the first epoch ends.

    Args:
        save_model_dir: Path passed to ``model.save`` (``save_format="tf"``).
        monitor: Key in the epoch ``logs`` dict to track; defaults to
            ``'val_loss'``.

    Attributes:
        best: Lowest monitored value seen so far, or ``None`` before the
            first epoch ends.
    """

    def __init__(self, save_model_dir, monitor='val_loss'):
        super().__init__()
        self.model_dir = save_model_dir
        self.monitor = monitor
        # Per-instance state instead of a class attribute shared by all instances.
        self.best = None

    def on_epoch_end(self, epoch, logs=None):
        # `None` default avoids a mutable `{}` shared across calls.
        logs = logs if logs is not None else {}
        current = logs[self.monitor]  # KeyError if the metric is not logged
        # Save on the first epoch too. Previously epoch 0 only recorded its
        # value, so if no later epoch beat it nothing was ever saved and the
        # TFLite conversion from MODEL_DIR would fail.
        if self.best is None or current < self.best:
            self.best = current
            self.model.save(self.model_dir, save_format="tf")


control = SaveBest(MODEL_DIR)

In [None]:
# Train the model, evaluating on test_ds after each epoch; the `control`
# callback decides when to write the model to MODEL_DIR.
# NOTE(review): test_ds doubles as the validation set used for model
# selection, so its final metrics are optimistically biased. Consider
# holding out part of the training split instead.
EPOCHS = 10
history = model.fit(
    train_ds,
    validation_data=test_ds,
    epochs=EPOCHS,
    callbacks=[control],
    verbose=2,  # one summary line per epoch
)

In [None]:
# Convert the SavedModel in MODEL_DIR to a TFLite flatbuffer and write it
# into the same directory as model.tflite.
converter = tf.lite.TFLiteConverter.from_saved_model(MODEL_DIR)
tflite_model = converter.convert()

tflite_path = f'{MODEL_DIR}/model.tflite'
with open(tflite_path, 'wb') as tflite_file:
    tflite_file.write(tflite_model)