# Interactive TensorFlow Classification Model Training Workbench

## Setup
First we set up the project, import the configuration, and load the datasets to use during training.

### Imports
Import all the Python modules that are required by default.

In [14]:
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.models import load_model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import Callback, ModelCheckpoint, BackupAndRestore

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from IPython.display import clear_output

import os
import yaml
from collections import defaultdict

# Show the installed TensorFlow version as this cell's output.
f'TensorFlow v{tf.__version__}'

'TensorFlow v2.8.0'

### Config
Load the `config.yaml` file.

In [None]:
# Parse the YAML configuration that drives the rest of the workbench.
# The encoding is explicit so parsing does not depend on the platform default.
with open('./config.yaml', encoding='utf-8') as config_file:
    config = yaml.safe_load(config_file)

# Echo the parsed configuration for a visual sanity check.
print(yaml.dump(config))

# Mandatory sections: a missing one raises KeyError right here.
dataset_cfg = config['dataset']
model_cfg = config['model']
optimizer_cfg = config['optimizer']

# Optional sections. A key that is present but left empty in YAML parses to
# None, so fall back to an empty dict in that case as well as when the key
# is absent.
callback_cfg = config.get('callbacks') or {}
training_cfg = config.get('training') or {}

### Data Generators
Define the training and validation datasets.

In [None]:
# Both generators share one batch size. 32 is used when the config does not
# set `dataset.batch`, so the training and validation generators can no longer
# disagree (previously one received None while the other raised KeyError).
batch_size = dataset_cfg.get('batch', 32)

# Training images are read from `<src>/training`.
gen_train = ImageDataGenerator().flow_from_directory(
    os.path.join(dataset_cfg['src'], 'training'),
    batch_size=batch_size,
    class_mode=dataset_cfg['class_mode'],
    # `or {}` also covers an option section left empty in YAML (parsed as None)
    **(dataset_cfg.get('train_options') or {})
)

# Validation images are read from `<src>/validation`.
gen_valid = ImageDataGenerator().flow_from_directory(
    os.path.join(dataset_cfg['src'], 'validation'),
    batch_size=batch_size,
    class_mode=dataset_cfg['class_mode'],
    **(dataset_cfg.get('valid_options') or {})
)

### Preview training frames
Preview 8 frames from each class. The second cell can be run again to preview the next set of frames each time.

In [None]:
%matplotlib inline

nrows = 2
ncols = 4
nimgs = nrows*ncols
index = 0

training_dir = dataset_cfg['src']+'/training/'

class_frames = dict()
for cls in dataset_cfg['classes']:
    class_frames[cls] = os.listdir(training_dir+cls)


def plot_images(images, title):
    """Display image files in a grid of ``nrows`` x ``ncols`` subplots.

    Args:
        images: Iterable of image file paths; each one is read with
            ``mpimg.imread`` and drawn in its own subplot, in order.
        title: Text shown as the figure's super-title.
    """
    figure = plt.gcf()
    figure.set_size_inches(ncols * 3, nrows * 3)
    figure.suptitle(title, size=20)

    # Subplot positions are 1-based.
    for position, path in enumerate(images, start=1):
        axes = plt.subplot(nrows, ncols, position)
        axes.axis('Off')
        pixels = mpimg.imread(path)
        plt.imshow(pixels)

    plt.show()

In [None]:
# Move the window forward by one grid's worth of frames on every run.
index += nimgs
window = slice(index - nimgs, index)

# Plot the current window of frames for each class, titled with the class name.
for cls in class_frames:
    class_dir = training_dir + cls
    frames = [os.path.join(class_dir, name) for name in class_frames[cls][window]]
    plot_images(frames, cls)

## Model
Here we first define the model that is going to be trained and compile it.

### Model Architecture
Import the model class and initialize the model.

In [None]:
from importlib import import_module

# Resolve the base-model class named in the config (`model.module` / `model.class`).
ModelClass = getattr(import_module(model_cfg['module']), model_cfg['class'])

# `input_shape` may be written in YAML either as a string such as
# "(224, 224, 3)" or as a plain sequence such as [224, 224, 3].
input_shape = model_cfg.get('input_shape', (224, 224, 3))
if isinstance(input_shape, str):
    input_shape = input_shape.strip().strip('()[]').split(',')
input_shape = tuple(int(dim) for dim in input_shape)

base = ModelClass(
    include_top=model_cfg.get('include_top', False),
    weights=model_cfg.get('weights'),
    input_shape=input_shape,
    # `or {}` also covers an option section left empty in YAML (parsed as None)
    **(model_cfg.get('class_options') or {})
)
# Freeze the base model so only the new classification head is trainable.
base.trainable = False

# 'binary' class mode yields one 0/1 label per sample, so a single output
# unit is enough; every other class mode needs one output unit per class.
num_classes = len(dataset_cfg['classes'])
output_units = 1 if dataset_cfg['class_mode'] == 'binary' else num_classes

model = Sequential(
    [
        base,
        GlobalAveragePooling2D(),
        Dense(
            output_units,
            activation=model_cfg.get('fc_layer_activation', 'sigmoid')
        ),
    ],
    # Use the public constructor argument instead of writing `model._name`.
    name=model_cfg.get('name', base.name)
)

model.summary()

### Optimizer
Import and initialize the optimizer.

In [None]:
from importlib import import_module

# Resolve the optimizer class named in the config
# (`optimizer.module` / `optimizer.class`).
OptimizerClass = getattr(
    import_module(optimizer_cfg['module']),
    optimizer_cfg['class']
)

# `or {}` covers both a missing `options` key and one left empty in YAML
# (parsed as None), which would otherwise fail on `**None`.
optimizer = OptimizerClass(**(optimizer_cfg.get('options') or {}))

print('optimizer:', OptimizerClass.__name__)

### Compile Model
Compile the model and set the optimizer.

In [None]:
# Gather the compile settings from the model config, with the notebook defaults.
loss_fn = model_cfg.get('loss', 'binary_crossentropy')
tracked_metrics = model_cfg.get('metrics', 'accuracy')
extra_compile_args = model_cfg.get('compile_options', dict())

# Attach the optimizer, loss and metrics to the model.
model.compile(
    optimizer=optimizer,
    loss=loss_fn,
    metrics=tracked_metrics,
    **extra_compile_args
)

## Training & Evaluation
Finally we define what callbacks to use during training and start the training. Afterwards, we evaluate the best checkpoint model saved during training.

### Callbacks
Import and instantiate the different callbacks to be used during training.

In [None]:
class PlotLearning(Callback):
    """Callback to plot the learning curves of the model during training.

    After each epoch the logged values are appended to a per-metric history
    and one subplot per training metric is redrawn, with the matching
    ``val_`` series overlaid when it was logged. When training ends, the most
    recent figure is saved as ``<checkpoints>/<base.name>.png``, using the
    same ``training_cfg`` directory as the ``ModelCheckpoint`` callback.
    """

    def on_train_begin(self, logs=None):
        """Reset the recorded history before a training run starts."""
        # metric name -> list of logged values
        self.metrics = defaultdict(list)
        # metric name -> list of 1-based epoch numbers, parallel to `metrics`
        self.epochs = defaultdict(list)
        # most recently drawn figure, saved by `on_train_end`
        self.figure = None

    def plot(self, epoch, logs=None):
        """Record the values in ``logs`` and redraw the learning curves.

        Args:
            epoch: Zero-based index of the epoch the values belong to.
            logs: Mapping of metric name to value; ``None`` is treated as
                empty.
        """
        logs = logs or {}
        for name, value in logs.items():
            self.metrics[name].append(value)
            # Storing the epoch with each value keeps x and y the same length
            # even when the first plotted epoch is not epoch 0 (resumed run)
            # or a series was not logged in every epoch.
            self.epochs[name].append(epoch + 1)

        metrics = [name for name in logs if not name.startswith('val_')]
        if not metrics:
            # Nothing to draw; plt.subplots would reject zero columns.
            return

        # squeeze=False keeps `axs` two-dimensional even for a single metric.
        fig, axs = plt.subplots(1, len(metrics), figsize=(15, 5), squeeze=False)

        clear_output(wait=True)

        for ax, metric in zip(axs[0], metrics):
            for name in (metric, 'val_' + metric):
                # Membership test, so a validation value of 0.0 is still drawn.
                if name in logs:
                    ax.plot(self.epochs[name], self.metrics[name], label=name)
            ax.legend()
            ax.grid()

        plt.tight_layout()
        self.figure = fig
        plt.show()

    def on_epoch_end(self, epoch, logs=None):
        """Redraw the curves with the values logged for the finished epoch."""
        self.plot(epoch, logs)

    def on_train_end(self, logs=None):
        """Save the last drawn figure as a PNG in the checkpoint directory."""
        if self.figure is None:
            return
        target_dir = training_cfg.get('checkpoints', 'models')
        os.makedirs(target_dir, exist_ok=True)
        self.figure.savefig(target_dir + '/' + base.name + '.png')

    # Former (misspelled) hook name, kept as an alias for any direct callers.
    on_training_end = on_train_end


# Live learning-curve plot.
plot_learning = PlotLearning()

# Save the full model whenever the validation loss improves.
checkpoint = ModelCheckpoint(
    training_cfg.get('checkpoints', 'models')+'/'+base.name+'.h5',
    monitor="val_loss",
    verbose=1,
    save_best_only=True,
    save_weights_only=False,
    mode="min",
    save_freq="epoch",
    options=None,
)

backup_and_restore = BackupAndRestore(training_cfg.get('backups', 'backup'))

callbacks = [plot_learning, checkpoint, backup_and_restore]

# Extra callbacks from the config: each key is the name of a class in
# `tf.keras.callbacks`, each value the keyword arguments for its constructor.
# `or {}` covers a section, or a single callback's options, left empty in
# YAML (parsed as None).
extra_callbacks = callback_cfg or {}
for callback, options in extra_callbacks.items():
    CallbackClass = getattr(tf.keras.callbacks, callback)
    callbacks.append(CallbackClass(**(options or {})))

print(
    'callbacks:',
    ['PlotLearning',
     'ModelCheckpoint',
     'BackupAndRestore',
     *extra_callbacks.keys()]
 )

### Training Model
Start the training process with the defined configuration.

In [None]:
# Read the training-loop settings from the config; the step counts stay None
# when they are not configured.
train_steps = training_cfg.get('training_steps_per_epoch')
valid_steps = training_cfg.get('validation_steps_per_epoch')
num_epochs = training_cfg.get('epochs', 10)
extra_fit_args = training_cfg.get('options', dict())

# Run the training loop; the value returned by `fit` is kept in `history`.
history = model.fit(
    gen_train,
    validation_data=gen_valid,
    steps_per_epoch=train_steps,
    validation_steps=valid_steps,
    epochs=num_epochs,
    callbacks=callbacks,
    **extra_fit_args
)

### Evaluate Model Checkpoint
Evaluate the best checkpoint model saved during training.

In [None]:
# Load the checkpoint from the path the ModelCheckpoint callback writes to:
# the directory comes from `training_cfg`, not `model_cfg`.
checkpoint_path = training_cfg.get('checkpoints', 'models')+'/'+base.name+'.h5'
val_model = load_model(checkpoint_path)
val_model.evaluate(gen_valid)