## **Training a Classification Model Using Labelchecker Data**
Here, we walk through the process of training a classification model using the small example dataset provided, demonstrating the essential steps. Keep in mind that the model’s performance may be limited by the small amount of training data, but you can easily adapt our example to your own dataset.

Here’s an overview of what we’ll cover:

1. **Data Download**: Obtain the example data.
2. **Data Preparation**: Detail the necessary processing steps before training.
3. **Model Building**: Construct the classification model.
4. **Data Loading**: Set up data loaders for model training.
5. **Model Training**: Train the model.
6. **Model Evaluation**: Assess its performance.
7. **Model Serialization**: Save the trained model for future use.

Feel free to replace our example data with your own to train a model tailored to your specific needs 😎. Let’s get started!

## 0 **Import Libraries**

In [None]:
# import libraries
import sys
import cv2
import shutil
import json
from pathlib import Path
import requests
import zipfile
from rich import print
from tqdm import tqdm
from typing import Tuple
from enum import StrEnum, auto
import pandas as pd
import numpy as np

from plotly import express as px
from plotly import graph_objects as go
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import SparseCategoricalAccuracy
from tensorflow.keras.optimizers import Adam

import matplotlib.pyplot as plt

## 1. **Data Import**
Let's import some data and start exploring it!

You can also do this with the [example data](https://www.dropbox.com/s/4p1e5j9p9v8xj2s/data.zip?dl=1); extract it so that the `LabelChecker_*.csv` files end up somewhere below `data/example`.


In [None]:
# all datasets live below ./data; this notebook uses the "example" dataset
dataset_name = "example"
data_path = Path("data").joinpath(dataset_name)

# create the dataset folder (and ./data itself) if it does not exist yet
data_path.mkdir(parents=True, exist_ok=True)

In [None]:
# fetch all LabelChecker data files below the data directory (searched recursively).
# glob() yields files in filesystem-dependent order; sorting makes the concatenated
# training data, and therefore the seeded train/test split, reproducible across machines.
data_files = sorted(data_path.glob("**/LabelChecker_*.csv"))
print(f"Found {len(data_files)} data files")

## 2. **Data Preparation and data cleaning**
The goal is to streamline the process and ensure consistency across all data files before we train the model.  

To clean and prepare the data we:
1. keep only rows that have a `LabelTrue` value
2. drop columns that contain only `missing values`
3. drop columns with `default values` (the same value in every row)
4. set the full `image paths`
5. drop columns with `object` (text) data
6. remove labels with fewer than N examples
7. `encode` the labels as integers

All of this is wrapped in one function, `load_training_data`.

In [None]:
# detect columns that hold one single value (e.g. a default filled in for every row)
def is_default(series: pd.Series) -> bool:
    """Return True when ``series`` contains exactly one distinct value.

    Missing values count as a value of their own, so a column mixing one value
    with NaN is *not* considered a default column.
    """
    return series.nunique(dropna=False) == 1


# drop all object columns except for LabelTrue function
def is_object(
    series: pd.Series,
    columns_to_keep: list[str] = ["LabelTrue", "ImageFilename", "CollageFile"],
) -> bool:
    if series.name in columns_to_keep:
        return False
    return series.dtype == "object"

# keep only the labels that have enough examples to train (and stratify) on
def drop_labels_with_less_than_examples(data: pd.DataFrame, min_examples: int) -> pd.DataFrame:
    """Return the rows of ``data`` whose ``LabelTrue`` class has >= ``min_examples`` rows.

    The surviving rows keep their original order and index.
    """

    def has_enough_examples(label_group: pd.DataFrame) -> bool:
        return len(label_group) >= min_examples

    grouped_by_label = data.groupby("LabelTrue")
    return grouped_by_label.filter(has_enough_examples)

# build image paths
def build_image_path(df: pd.DataFrame, directory: Path) -> Tuple[bool, list[str]]:
    """
    Builds a list of image paths based on the given DataFrame and directory.

    Individual images are expected at ``directory/<Name>/<ImageFilename>``;
    collage images at ``directory/<CollageFile>``. When a file references both,
    the individual images are used, because ``decode_image`` also prefers the
    ``ImageFilename`` column when it is present. Exactly one path per row is
    returned, so the result can be assigned back as a DataFrame column.

    Args:
        df (pd.DataFrame): The DataFrame containing the image filenames and names.
        directory (Path): The directory where the images are located.

    Returns:
        Tuple[bool, list[str]]: A tuple containing a boolean value indicating whether the image paths are for collage files,
        and a list of image paths. The list is empty (with ``True``) when the
        DataFrame references no images.

    Raises:
        FileNotFoundError: If any of the image files are missing.
    """
    directory = Path(directory)

    has_individual_images = (
        "ImageFilename" in df.columns
        and "Name" in df.columns
        and not df["ImageFilename"].isnull().all()
        and not df["Name"].isnull().all()
    )
    if has_individual_images:
        image_paths = []
        for name, filename in zip(df["Name"], df["ImageFilename"]):
            image_path = directory.joinpath(name, filename)
            if not image_path.exists():
                raise FileNotFoundError(f"file {filename} not found")
            image_paths.append(image_path.as_posix())
        return False, image_paths

    if "CollageFile" in df.columns and not df["CollageFile"].isnull().all():
        image_paths = []
        for collage_file in df["CollageFile"]:
            image_path = directory.joinpath(collage_file)
            if not image_path.exists():
                raise FileNotFoundError(f"file {collage_file} not found")
            image_paths.append(image_path.as_posix())
        return True, image_paths

    return True, []


def load_training_data(
    data_files: list[Path], 
    encoder: LabelEncoder, 
    min_examples: int = 5,
) -> Tuple[pd.DataFrame, LabelEncoder]:
    """
    Load the training data from the data files, preprocess the data, and encode the labels.

    Steps, in order:
      1. read every CSV and replace its image references with full paths
         (``build_image_path``);
      2. drop rows without a (non-empty) ``LabelTrue`` value;
      3. drop columns that are entirely missing;
      4. drop columns holding a single value (``is_default``), except the label and
         image-path columns the data loader needs later;
      5. drop object/string columns except the label and image-path columns (``is_object``);
      6. drop labels with fewer than ``min_examples`` rows;
      7. drop the ``ProbabilityScore`` column if it is still present;
      8. fit ``encoder`` on ``LabelTrue`` and replace the labels with their integer codes.

    Args:
        data_files (list[Path]): A list of file paths to the training data files.
        encoder (LabelEncoder): An instance of the LabelEncoder class used for label encoding.
        min_examples (int, optional): The minimum number of examples required for each label. Defaults to 5.

    Returns:
        Tuple[pd.DataFrame, LabelEncoder]: A tuple containing the preprocessed training data as a DataFrame
        (with a fresh 0..n-1 index) and the fitted label encoder object.

    Raises:
        FileNotFoundError: If a data file or an image it references does not exist.
        ValueError: If ``data_files`` is empty, or no rows remain after filtering.
    """
    # Columns consumed later by the tf.data loader. They must survive the
    # "single value" filter, e.g. when every object in a file comes from one collage.
    required_columns = ["LabelTrue", "ImageFilename", "CollageFile"]

    if not data_files:
        raise ValueError("No data files were provided")

    frames = []
    for data_file in data_files:
        if not data_file.exists():
            raise FileNotFoundError(f"File {data_file} not found")
        df = pd.read_csv(data_file)

        # Build the image paths
        is_collage, image_paths = build_image_path(df, data_file.parent)
        if image_paths:
            df["CollageFile" if is_collage else "ImageFilename"] = image_paths
        frames.append(df)
    # every CSV starts its index at 0; ignore_index avoids duplicate index labels
    data = pd.concat(frames, ignore_index=True)

    # Drop rows with missing or empty LabelTrue values
    data = data.dropna(subset=["LabelTrue"])
    data = data.loc[data["LabelTrue"].astype(str).str.len() > 0]
    if data.empty:
        raise ValueError("No rows with a LabelTrue value were found")

    # Drop columns with all missing values
    data = data.dropna(axis=1, how="all")

    # Drop columns with default values, but never the columns the loader needs
    keep_columns = ~data.apply(is_default) | data.columns.isin(required_columns)
    data = data.loc[:, keep_columns]

    # Drop all object columns except for the label and image-path columns
    data = data.loc[:, ~data.apply(is_object)]

    # Drop labels with less than N examples
    data = drop_labels_with_less_than_examples(data, min_examples=min_examples)
    if data.empty:
        raise ValueError(f"No label has at least {min_examples} examples")

    # Drop ProbabilityScore column; it may already be gone (e.g. all-missing or
    # single-valued columns are removed above)
    data = data.drop(columns="ProbabilityScore", errors="ignore")

    # Encode the labels
    data["LabelTrue"] = encoder.fit_transform(data["LabelTrue"])
    return (data, encoder)

In [None]:
# helper function to show the number of samples per label
def print_label_counts(data: pd.DataFrame, class_names: list[str]):
    """Draw a horizontal plotly bar chart with the number of samples per class.

    Args:
        data (pd.DataFrame): Data whose ``LabelTrue`` column holds encoded labels.
        class_names (list[str]): Maps an encoded label (index) to its class name.

    Note: nothing is printed; the counts are rendered as a chart, with the bars
    sorted by ascending count.
    """
    counts_by_name = {
        class_names[encoded_label]: n_samples
        for encoded_label, n_samples in data["LabelTrue"].value_counts().items()
    }
    ordered = sorted(counts_by_name.items(), key=lambda pair: pair[1])

    bar_labels = [name for name, _ in ordered]
    bar_counts = [n_samples for _, n_samples in ordered]

    figure = px.bar(
        x=bar_counts,
        y=bar_labels,
        title="Label Counts",
        orientation="h",
        labels={"x": "Count", "y": "Label"},
        width=800,
        height=1200,
    )
    figure.show()

In [None]:
# an unfitted encoder; load_training_data fits it on the LabelTrue column
encoder: LabelEncoder = LabelEncoder()

In [None]:
# load, clean and encode the training data, then summarise the result
training_data, encoder = load_training_data(data_files, encoder)

print(f"the training data contains {len(training_data)} samples")
print(
    f"the training data contains the following columns: {list(training_data.columns)}"
)
print(
    f"the training data contains these labels: {encoder.classes_}; \na total of {len(encoder.classes_)} labels"
)

# bar chart of how many samples each label has
print_label_counts(training_data, encoder.classes_)

## 3. **Model Building**
We're going to train a model that uses the object images and features to classify the object class. 

Let's start by designing the model.

In [None]:
# create model
def create_model(
    features_input_shape: tuple[int, ...],
    image_input_shape: tuple[int, int, int],
    nr_classes: int,
    optimizer: tf.keras.optimizers.Optimizer | None = None,
    loss: tf.keras.losses.Loss | None = None,
    metric: tf.keras.metrics.Metric | None = None,
    features_normalization: layers.Normalization | None = None,
    image_augmentation: tf.keras.Sequential | None = None,
) -> tf.keras.Model:
    """Build and compile a two-branch classifier (feature MLP + image CNN).

    The model takes two inputs, in this order: ``features`` (numeric object
    features) and ``image``. The two branches are concatenated and followed by a
    softmax over ``nr_classes`` classes.

    Args:
        features_input_shape: Shape of one feature vector, e.g. ``(n_features,)``.
        image_input_shape: Shape of one image, ``(height, width, channels)``.
        nr_classes: Number of output classes.
        optimizer: Optimizer to compile with; a new ``Adam()`` when omitted.
        loss: Loss to compile with; a new ``SparseCategoricalCrossentropy()`` when omitted.
        metric: Metric to track; a new ``SparseCategoricalAccuracy()`` when omitted
            (its name, ``sparse_categorical_accuracy``, is what the callbacks monitor).
        features_normalization: Optional layer applied to the features first.
        image_augmentation: Optional model applied to the rescaled images.

    Returns:
        The compiled ``tf.keras.Model``.
    """
    # Create fresh objects per call. Instances written as default arguments are
    # built once at definition time and would be shared by every model created.
    if optimizer is None:
        optimizer = Adam()
    if loss is None:
        loss = SparseCategoricalCrossentropy()
    if metric is None:
        metric = SparseCategoricalAccuracy()

    # Multi layer perceptron model
    features_input = layers.Input(shape=features_input_shape, name="features")
    x1 = features_input
    if features_normalization is not None:
        x1 = features_normalization(x1)
    x1 = layers.Dense(344, activation="relu", name="dense_10")(x1)
    x1 = layers.Dropout(0.2, name="dropout_10")(x1)
    x1 = layers.Dense(172, activation="relu", name="dense_11")(x1)
    x1 = layers.Dropout(0.15, name="dropout_11")(x1)
    x1 = layers.Dense(86, activation="relu", name="dense_12")(x1)

    # Convolution Neural Network model
    # NOTE(review): kernel size 1 convolutions mix channels per pixel only, and
    # MaxPooling2D(pool_size=1) leaves the tensor unchanged; confirm this small
    # receptive field is intended.
    image_input = layers.Input(shape=image_input_shape, name="image")
    x2 = layers.Rescaling(1.0 / 255)(image_input)
    if image_augmentation is not None:
        x2 = image_augmentation(x2)
    x2 = layers.Conv2D(16, 1, activation="relu", padding="same", name="conv2d_10")(x2)
    x2 = layers.MaxPooling2D(pool_size=1, padding="same", name="max_pooling2d_10")(x2)
    x2 = layers.Conv2D(32, 1, activation="relu", padding="same", name="conv2d_11")(x2)
    x2 = layers.Conv2D(64, 1, activation="relu", padding="same", name="conv2d_12")(x2)
    x2 = layers.GlobalAveragePooling2D(name="global_average_pooling2d_10")(x2)

    # concatenate MLP and CNN models
    x = layers.concatenate([x1, x2], name="concatenate_20")
    x = layers.Dense(500, activation="relu", name="dense_20")(x)
    x = layers.Dropout(0.1, name="dropout_20")(x)
    output = layers.Dense(nr_classes, activation="softmax", name="output")(x)

    # create the model; the input order must match get_data's (features, image)
    model = tf.keras.Model(inputs=[features_input, image_input], outputs=output)

    # compile the model
    model.compile(optimizer=optimizer, loss=loss, metrics=[metric])
    return model

## 4. **Data loader**
Now that we’ve loaded and prepared the data for training, we need to set up a data loader that loads each image, retrieves its features, and obtains its label.


In [None]:
# read the object image of one row
def decode_image(row: pd.Series, image_size: tuple[int, int, int]) -> tf.Tensor:
    """Return the object image referenced by ``row`` as a uint8 tensor.

    * If the row has an ``ImageFilename`` entry, that PNG file is read and decoded
      with ``image_size[-1]`` channels.
    * Otherwise the collage at ``CollageFile`` is read through OpenCV
      (``read_tiff``), trimmed to ``image_size[-1]`` channels, and the object is
      cropped out with the row's ImageX/ImageY/ImageW/ImageH values.
    """
    if "ImageFilename" not in row:
        collage_path = tf.strings.as_string(row["CollageFile"])
        collage = tf.numpy_function(read_tiff, [collage_path], tf.uint8)
        # numpy_function output has no static shape; read_tiff returns 3-channel RGB
        collage.set_shape([None, None, 3])
        collage = remove_alpha_channel(collage, image_size=image_size)
        return crop_image(row, collage)

    png_bytes = tf.io.read_file(row["ImageFilename"])
    return tf.io.decode_png(png_bytes, channels=image_size[-1])

# read collage images (used through tf.numpy_function)
def read_tiff(path_tensor: tf.Tensor):
    """Read an image file with OpenCV and return it as an RGB ``uint8`` array.

    When called via ``tf.numpy_function`` the path arrives as bytes; it is decoded
    as UTF-8 here. The image is loaded with ``cv2.IMREAD_COLOR`` and converted
    from OpenCV's BGR channel order to RGB.

    Raises:
        ValueError: If OpenCV cannot read the file.
    """
    path = path_tensor.decode("utf-8")
    bgr_image = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr_image is None:
        raise ValueError(f"Image not found at path: {path}")
    rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    return rgb_image.astype(np.uint8)

# keep only the channels the model expects
def remove_alpha_channel(image, image_size: tuple[int, int, int]) -> tf.Tensor:
    """Return ``image`` limited to its first ``image_size[-1]`` channels (e.g. RGBA -> RGB)."""
    n_channels = image_size[-1]
    trimmed = image[:, :, :n_channels]
    return tf.convert_to_tensor(trimmed)


# crop out the object image from the collage
def crop_image(row: pd.Series, image):
    """Crop one object out of a collage image.

    The crop box comes from the row's ``ImageX`` (left), ``ImageY`` (top),
    ``ImageW`` (width) and ``ImageH`` (height) values; fractional values are
    truncated toward zero.

    ``tf.cast`` is used instead of Python ``int()``. Inside ``Dataset.map`` the
    row values are symbolic tensors, and ``int()`` on them only works when
    AutoGraph rewrites the call. ``tf.cast`` works with or without AutoGraph and
    in eager mode.
    """
    top = tf.cast(tf.squeeze(row["ImageY"]), tf.int32)
    left = tf.cast(tf.squeeze(row["ImageX"]), tf.int32)
    height = tf.cast(tf.squeeze(row["ImageH"]), tf.int32)
    width = tf.cast(tf.squeeze(row["ImageW"]), tf.int32)
    return image[top : top + height, left : left + width]


def resize_image(image, image_size: tuple[int, int, int]) -> tf.Tensor:
    """Resize ``image`` to the height and width in ``image_size`` (channels untouched)."""
    target_height, target_width = image_size[0], image_size[1]
    return tf.image.resize(image, [target_height, target_width])


# full image pipeline for one row: decode, then resize
def get_image(row: pd.Series, image_size: tuple[int, int, int]) -> tf.Tensor:
    """Return the row's object image resized to ``image_size``."""
    return resize_image(decode_image(row, image_size=image_size), image_size=image_size)

In [None]:
# object features
def get_features(row: pd.Series, feature_names: list[str]) -> tf.Tensor:
    """Return the row's ``feature_names`` values as a 1-D float64 tensor.

    Each value is converted with ``tf.cast(..., tf.float64)`` instead of Python
    ``float()``. Inside ``Dataset.map`` the values are symbolic tensors, and
    ``float()`` only works there through AutoGraph's rewrite, which casts to
    float32 and so loses float64 precision before packing.

    Args:
        row: One dataset row (a dict-like of scalars/tensors).
        feature_names: Names of the numeric feature columns, in model input order.

    Returns:
        A float64 tensor of shape ``(len(feature_names),)``.
    """
    if not feature_names:
        # tf.stack cannot infer anything from an empty list
        return tf.zeros([0], dtype=tf.float64)
    return tf.stack([tf.cast(row[name], tf.float64) for name in feature_names])

In [None]:
# labels
def get_label(row: pd.Series):
    """Return the encoded ``LabelTrue`` value of ``row``.

    The row is read, not modified. ``pop`` would remove the key from the
    caller's row as a side effect.
    """
    return row["LabelTrue"]

In [None]:
def get_data(
    row: pd.Series,
    image_size: Tuple[int, int, int],
    feature_names: list[str],
):
    """Turn one dataset row into ``((features, image), label)`` for model.fit.

    The inner tuple order (features first, then image) must match the input
    order of the model built by ``create_model``.
    """
    image = get_image(row, image_size=image_size)
    features = get_features(row, feature_names=feature_names)
    label = get_label(row)
    model_inputs = (features, image)  # order matters: matches the model inputs
    return model_inputs, label

### 4.1 *Data loader parameters*
To initialize the dataloader we have to set a few parameters.

In [None]:
# images are resized to height x width and decoded with this many channels
image_size = (32, 32, 3)

# model features: every non-object column except the target and the
# identifier / bookkeeping / crop-coordinate columns listed below
feature_names = [
    column
    for column in training_data.select_dtypes(exclude="object").columns
    if column
    not in {
        "LabelTrue",  # target column; integer-encoded by now, so no longer an "object" column
        "Id",
        "CalImage",
        "ElapsedTime",
        "ImageY",
        "ImageX",
        "ImageW",
        "ImageH",
        "IntensityCalimage",
        "SrcX",
        "SrcY",
        "SrcImage",
    }
]
feature_size = len(feature_names)
nr_of_classes = len(encoder.classes_)

# show which features were selected
print(f"Selected features: {feature_names}")
print(f"Number of selected features: {feature_size}")

### 4.2 *Train-test split*
We need to split the data into a training set and a test set. We use the test set to validate the model during training and to detect overfitting.

In [None]:
# 80/20 split, stratified on the encoded label so both parts keep similar class
# proportions; the fixed random_state makes the split reproducible
X_train, X_test = train_test_split(
    training_data,
    test_size=0.2,
    stratify=training_data["LabelTrue"],
    random_state=42,
)

In [None]:
batch_size = 22

# create the training dataset.
# Shuffle the lightweight CSV rows *before* decoding images: the shuffle buffer
# then holds rows instead of decoded images, and a buffer as large as the
# training set gives a full (not windowed) shuffle, redone every epoch.
train_ds = (
    tf.data.Dataset.from_tensor_slices(dict(X_train))
    .shuffle(buffer_size=max(len(X_train), 1), reshuffle_each_iteration=True)
    .map(
        lambda x: get_data(x, image_size=image_size, feature_names=feature_names),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    .batch(batch_size=batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# create the test dataset: same per-row decoding as training, but not shuffled,
# so batches come out in X_test order
test_ds = (
    tf.data.Dataset.from_tensor_slices(dict(X_test))
    .map(
        lambda x: get_data(x, image_size=image_size, feature_names=feature_names),
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    .batch(batch_size=batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Inspect the first training batch. A tf.data.Dataset is always truthy (so
# `if train_ds:` can never detect emptiness); instead, try to read one batch.
first_train_batch = next(iter(train_ds.take(1)), None)
if first_train_batch is not None:
    (features, images), labels = first_train_batch
    print(f"image shape: {images.shape}")
    print(f"features shape: {features.shape}")
    print(f"label shape: {labels.shape}")

    # plot up to 9 images; the batch may hold fewer than 9 samples
    plt.figure(figsize=(10, 10), frameon=False)
    for i in range(min(9, images.shape[0])):
        plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(encoder.classes_[int(labels[i])], color="lightgreen")
        plt.axis("off")
else:
    print("train_ds is empty")

In [None]:
# Inspect the first test batch. A tf.data.Dataset is always truthy, so
# emptiness is detected by trying to read one batch instead of `if test_ds:`.
first_test_batch = next(iter(test_ds.take(1)), None)
if first_test_batch is not None:
    (features, image), label = first_test_batch
    print(f"image shape: {image.shape}")
    print(f"features shape: {features.shape}")
    print(f"label shape: {label.shape}")
else:
    print("test_ds is empty")

## 5. **Model Training**

### 5.1. *Initialize feature normalization layer*
Before training, the normalization layer has to learn the mean and variance of every feature; we adapt it on the training split only.

In [None]:
# initialize and adapt the feature normalization layer
def create_normalization_layer(
    data: pd.DataFrame, features: list[str]
) -> layers.Normalization:
    """Return a Keras ``Normalization`` layer adapted to ``data[features]``.

    The selected columns are converted to a float64 tensor and passed to
    ``adapt``, which computes the per-feature statistics the layer uses.
    """
    layer = layers.Normalization()
    feature_matrix = tf.convert_to_tensor(data[features].to_numpy(), dtype=tf.float64)
    layer.adapt(feature_matrix)
    return layer

In [None]:
# adapt the normalization layer on the training split, then run it on one row
# as a quick check (the output is shown as the cell result)
features_normalization = create_normalization_layer(data=X_train, features=feature_names)
features_normalization(X_train[feature_names].to_numpy()[:1])

### 5.2. *Initialize the model*

In [None]:
# create the model; gradient clipping (clipnorm=1.0) is enabled on every
# platform except Windows
# NOTE(review): the reason for the Windows exception is not documented here — confirm
model = create_model(
    features_input_shape=(feature_size,),
    image_input_shape=image_size,
    nr_classes=nr_of_classes,
    optimizer=(
        Adam(learning_rate=0.001)
        if sys.platform == "win32"
        else Adam(learning_rate=0.001, clipnorm=1.0)
    ),
    features_normalization=features_normalization,
)

In [None]:
# Draw the model graph (layers with their names and output shapes).
# Kept as the cell's last expression so the notebook can display the result.
tf.keras.utils.plot_model(model, show_shapes=True, show_layer_names=True)
# note: you need to install Graphviz for the plot to work (https://graphviz.gitlab.io/download/)

### 5.3. *Callbacks*
A callback performs actions at various stages of training. We're going to use two callbacks:
- **ModelCheckpoint**: saves the model weights whenever the validation accuracy improves; and
- **EarlyStopping**: stops training when the validation accuracy has not improved (by at least 0.001) for 6 epochs

In [None]:
def get_model_version(models_dir: Path, model_name: str) -> Path:
    """
    Return the directory for the next version of a model.

    Versions are consecutive integers starting at 1
    (``models_dir/model_name/1``, ``.../2``, ...). The first number whose
    directory does not exist yet is used; the directory is not created here.

    Args:
        models_dir (Path): The directory where the models are stored.
        model_name (str): The name of the model.

    Returns:
        Path: ``models_dir / model_name / str(version)`` for the lowest unused version.
    """
    version = 1
    while (candidate := Path.joinpath(models_dir, model_name, str(version))).exists():
        version += 1
    return candidate

We'll be saving the models directly into the Classification service, found at this location:

```bash
|--src
    |-- services
        |--classification
            |--ObjectClassification
                |--models
                    |--<model_name>
                        |--<model_version>
                            |--config.json
                            |--<serialized_model>
                |--...
    |--....
|--main.py
```

In [None]:
# the classification service lives in src/services/classification/<service_name>
service_name = "ObjectClassification"
path_to_service = Path("src", "services", "classification")

# models are stored as models/<model_name>/<version>; take the next free version
models_dir = path_to_service / service_name / "models"
model_name = "Example"
model_dir = get_model_version(models_dir, model_name)
print(f"Model directory: {model_dir}")

In [None]:
# Model checkpoint callback: keep only the weights of the epoch with the best
# validation accuracy
checkpoint_filepath = model_dir / "checkpoint" / "checkpoint.weights.h5"
checkpoint_dir = checkpoint_filepath.parent
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath.as_posix(),
    monitor="val_sparse_categorical_accuracy",
    mode="max",
    save_best_only=True,
    save_weights_only=True,
)
print(f"Model checkpoint directory: {checkpoint_dir}")

In [None]:
# early stopping: stop once validation accuracy has not improved by at least
# 0.001 for 6 consecutive epochs, restoring the weights of the best epoch
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor="val_sparse_categorical_accuracy",
    mode="max",
    min_delta=0.001,
    patience=6,
    restore_best_weights=True,
)

### 5.4 *Train the model*

In [None]:
epochs = 25

# train on train_ds and validate on test_ds after every epoch
history = model.fit(
    train_ds,
    epochs=epochs,
    validation_data=test_ds,
    callbacks=[model_checkpoint_callback, early_stopping_callback],
    verbose=1,
)

# note: make sure your computer doesn't go into sleep-mode while training! The process will stop!

## 6. **Model evaluation**

In [None]:
# plot the training history
def plot_training_history(history: tf.keras.callbacks.History):
    """Plot training vs. validation curves for the loss and the accuracy.

    Produces two pandas line plots from ``history.history``: one titled "Loss"
    and one titled "Accuracy".
    """
    per_epoch = pd.DataFrame(history.history)
    curves = (
        (["loss", "val_loss"], "Loss"),
        (["sparse_categorical_accuracy", "val_sparse_categorical_accuracy"], "Accuracy"),
    )
    for columns, title in curves:
        per_epoch[columns].plot(title=title)


plot_training_history(history)

In [None]:
# Load the weights written by the ModelCheckpoint callback. Because it uses
# save_best_only=True, this is the epoch with the highest
# val_sparse_categorical_accuracy seen during training.
# NOTE(review): EarlyStopping(restore_best_weights=True) may already have restored
# these weights; loading the checkpoint makes the choice explicit either way.
model.load_weights(checkpoint_filepath)

In [None]:
# evaluate the restored weights on the held-out test split
loss, acc = model.evaluate(test_ds, verbose=2)
print(f"Trained model, accuracy: {100 * acc:5.2f}%")

In [None]:
# class probabilities for every test sample (test_ds is not shuffled, so the
# rows line up with X_test), then the most likely class index per sample
predictions = model.predict(test_ds)
predicted_labels = predictions.argmax(axis=1)

In [None]:
# plot confusion matrix
def plot_confusion_matrix(
    true_labels: np.ndarray,
    predicted_labels: np.ndarray,
    class_names: list[str],
    text_size: int = 10,
    normalize: bool = True,
    width: int = 1000,
    height: int = 1000,
):
    """Show the confusion matrix as a plotly heatmap.

    Rows are true classes and columns are predicted classes; class ``i`` is
    ``class_names[i]``.

    Args:
        true_labels: Encoded true labels.
        predicted_labels: Encoded predicted labels.
        class_names: Name for every encoded class index.
        text_size: Font size of the cell annotations.
        normalize: If True, normalize each row over the true labels
            (sklearn ``normalize="true"``); otherwise show raw counts.
        width: Figure width in pixels.
        height: Figure height in pixels.
    """
    # Fix the label set to 0..len(class_names)-1. Without this, sklearn only uses
    # labels that occur in y_true/y_pred, and a missing class would shrink the
    # matrix and shift every row/column off its class name.
    cm = confusion_matrix(
        y_true=true_labels,
        y_pred=predicted_labels,
        labels=np.arange(len(class_names)),
        normalize="true" if normalize else None,
    )

    fig = go.Figure(
        data=go.Heatmap(
            z=cm,
            x=class_names,
            y=class_names,
            colorscale="Viridis",
            showscale=False,
            text=cm,
            # fractions when normalized, plain integer counts otherwise
            texttemplate="%{text:.2f}" if normalize else "%{text}",
            textfont={"size": text_size},
        )
    )

    fig.update_layout(
        title="Confusion Matrix",
        title_x=0.5,
        xaxis_title="Predicted",
        yaxis_title="True",
        autosize=False,
        width=width,
        height=height,
    )

    fig.show()

In [None]:
# row-normalized confusion matrix of the test split
plot_confusion_matrix(
    true_labels=X_test["LabelTrue"],
    predicted_labels=predicted_labels,
    class_names=encoder.classes_,
    text_size=10,
    normalize=True,
    width=1000,
    height=1000,
)

In [None]:
# print classification report. Every encoded class is passed explicitly via
# `labels`, so the report still matches `target_names` when a class is missing
# from the test labels and the predictions (sklearn would otherwise raise a
# "number of classes does not match target_names" error).
print(
    classification_report(
        X_test["LabelTrue"],
        predicted_labels,
        labels=np.arange(len(encoder.classes_)),
        target_names=encoder.classes_,
        zero_division=0,
    )
)

## 7. **Model Serialization**
We've trained and evaluated the model; now it's time to save (serialize) it.  
We save:
- the model itself
- the model configuration
- the evaluation output (not yet saved by this notebook — TODO)

Let's start with the model configuration.

### 7.1. *Model configuration*
Often the first model we train is not the one we end up using. We might want to add more data, test specific features, increase the image size, increase model size, etc. Whatever we change, we want our model to work, so we need to save these choices in a `configuration` file. To do this we'll use a python class which we save as a .json file.

The python class is called `ModelConfig` and in this configuration class we save the following information:
- **name**: name of the model
- **version**: which version of the model
- **framework**: e.g. Tensorflow
- **Class_names**: the encoder classes to be able to translate predicted numbers back to labels
- **Input_shape**: input shape of the image
- **Features**: list of features we used in training

In [None]:
# ModelConfig class
class ModelConfig:
    """Configuration stored next to a trained model (written to config.json).

    The instance ``__dict__`` is dumped as JSON, so the attribute names below are
    the JSON keys; do not add other instance attributes.
    """

    def __init__(
        self,
        name: str,
        version: str,
        framework: str,
        class_names: list[str],
        input_shape: Tuple[int, int, int],
        features: list[str],
    ) -> None:
        # Check for missing values before converting anything, so e.g.
        # input_shape=None raises this ValueError instead of a TypeError.
        if any(
            value is None
            for value in (name, version, framework, class_names, input_shape, features)
        ):
            raise ValueError("Not all values have been initialized")

        self.Name: str = name
        self.Version: str = version
        self.Framework: str = framework
        # Copy the sequences and unwrap numpy scalars (e.g. entries of
        # LabelEncoder.classes_ or numpy ints). This keeps the config
        # JSON-serializable even when numpy arrays are passed, and independent
        # of later changes to the caller's objects.
        self.Class_names: list[str] = [self._to_builtin(c) for c in class_names]
        self.Input_shape: list[int] = [self._to_builtin(d) for d in input_shape]
        self.Features: list[str] = [self._to_builtin(f) for f in features]

    @staticmethod
    def _to_builtin(value):
        """Return ``value`` as a plain Python object if it is a numpy scalar."""
        return value.item() if isinstance(value, np.generic) else value

    # representation of the class
    def __repr__(self) -> str:
        return f"ModelConfig(Name={self.Name},\n Version={self.Version},\n Framework={self.Framework},\n Class_names={self.Class_names},\n Input_shape={self.Input_shape},\n Features={self.Features})"

In [None]:
class Frameworks(StrEnum):
    TENSORFLOW = auto()
    PYTORCH = auto()
    ONNX = auto() 
    SKOPS = auto() # scikit-learn pipeline

In [None]:
# collect everything needed to use the model later; the version is the name of
# the version directory chosen by get_model_version
model_configurations = ModelConfig(
    framework=Frameworks.TENSORFLOW,
    name=model_name,
    version=model_dir.name,
    input_shape=image_size,
    features=feature_names,
    class_names=list(encoder.classes_),
)
print(model_configurations)

In [None]:
# write the configuration as config.json into the model version directory
model_dir.mkdir(exist_ok=True, parents=True)
model_config_file = model_dir / "config.json"
model_config_file.write_text(json.dumps(model_configurations.__dict__, indent=4))

### 7.2. *Save model*

The model and its configuration file are saved in the version folder (`<model_name>/<model_version>`) of the models directory shown above.

In [None]:
def save_model(
    model: tf.keras.Model,
    model_dir: Path,
    model_suffix: str,
    checkpoint_path: Path | None = None,
):
    """Save ``model`` as ``model_dir/model.<model_suffix>``.

    Args:
        model: The trained model to save.
        model_dir: Version directory to save into; created if needed.
        model_suffix: ``"keras"`` (native format) or the legacy ``"h5"``/``"hdf5"``.
        checkpoint_path: Weights file copied next to a legacy-format model as
            ``model.weights.<suffix>``. Defaults to the notebook's global
            ``checkpoint_filepath`` for backward compatibility.

    Raises:
        ValueError: If ``model_suffix`` is not supported. This is checked before
            anything is written to disk.
    """
    legacy_suffixes = ("h5", "hdf5")
    if model_suffix != "keras" and model_suffix not in legacy_suffixes:
        raise ValueError(f"Model suffix {model_suffix} not supported")

    # ensure the model directory exists
    model_dir.mkdir(parents=True, exist_ok=True)

    model_path = Path.joinpath(model_dir, f"model.{model_suffix}")
    print(f"Model path: {model_path}")
    model.save(model_path.as_posix())

    if model_suffix in legacy_suffixes:
        if checkpoint_path is None:
            checkpoint_path = checkpoint_filepath  # notebook-level global
        # needed for legacy formats and must be loaded in after loading the model for classification
        shutil.copy(
            checkpoint_path,
            Path.joinpath(model_dir, f"model.weights.{model_suffix}"),
        )

In [None]:
# "keras" is the native Keras format; "h5"/"hdf5" are the legacy alternatives
model_suffix = "keras"
save_model(model, model_dir, model_suffix=model_suffix)

Additional resources:
Much of this notebook was inspired by these TensorFlow tutorials:
1. [Load and preprocess images](https://www.tensorflow.org/tutorials/load_data/images)
2. [Load a Pandas dataframe](https://www.tensorflow.org/tutorials/load_data/pandas_dataframe)
