# RVAI Training Example
In this example we will train an `ImageClassificationCell` (which wraps a small convolutional neural network) on the Fashion MNIST dataset.

## Prerequisites
First, let's install all the prerequisites:

In [None]:
!pip install -qqq rvai==0.6.0rc2 pygraphviz

In [None]:
# some global notebook configuration
%matplotlib inline
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

## Create a TrainableCell
Okay, let's create a cell. A cell can be used as a building block in a computation DAG, called a Pipeline in RVAI. Since our Cell should also be trainable, we select the `TrainableCell` base class.
The basic skeleton of a `TrainableCell` can be found in the [docs [1]](https://base.rvai.dev/rvai.base.html#rvai.base.cell.TrainableCell). It tells us that we need a `load_model`, `predict`, `test` and `train` method. For this tutorial we implement `load_model`, `train` and `test`; `predict` is left unimplemented, and inference will be discussed in a subsequent tutorial.

### Cell IO

But before we start implementing these methods, we should first think about the required IO 'ports' of our Cell. What are the Inputs and the Outputs of this building block? What kind of example data does the training need? Let's describe that in code first. RVAI uses [dataclasses [2]](https://docs.python.org/3/library/dataclasses.html) as a mechanism to encode that information in a convenient struct.

- [1] https://base.rvai.dev/rvai.base.html#rvai.base.cell.TrainableCell (WIP)
- [2] https://docs.python.org/3/library/dataclasses.html
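To see the underlying mechanism in isolation, here is a minimal sketch that uses only the standard library. It assumes a plain `dict` as field metadata instead of RVAI's `metadata(...)` helper, and the `ExampleInputs` class and its `image_path` field are made up for illustration.

```python
from dataclasses import dataclass, field, fields


@dataclass
class ExampleInputs:
    # `metadata` accepts any mapping. The dataclasses module does not
    # interpret it; it is only stored so that other code can read it.
    image_path: str = field(
        metadata={"name": "Image", "description": "The image to be classified."}
    )


# Introspect the declared "ports" without creating an instance.
port_info = {f.name: dict(f.metadata) for f in fields(ExampleInputs)}
print(port_info)
```

A framework can use this kind of introspection to discover the names, types and descriptions of a building block's ports from the class definition alone.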

In [None]:
from dataclasses import dataclass, field

# base classes for the structs used to describe a Cell's IO
from rvai.base.data import Inputs, Outputs, Samples, Annotations, Parameters, metadata

# some RVAI types we need for describing the fields of these IO dataclasses
from rvai.types import Image, Integer

In [None]:
# Inference mode IO

@dataclass
class ImageClassificationInputs(Inputs):
    image: Image = field(metadata=
        metadata(name="Image", description="The image to be classified."))

@dataclass
class ImageClassificationOutputs(Outputs):
    label: Integer = field(
        metadata=metadata(name="Class", description="The class of the image."))

# Training mode IO
        
@dataclass
class ImageClassificationSamples(Samples, ImageClassificationInputs):
    """Inherits from ImageClassificationInputs because the Samples this Cell expects during training are the same as its inputs."""

@dataclass
class ImageClassificationAnnotations(Annotations, ImageClassificationOutputs):
    """Inherits from ImageClassificationOutputs because the Annotations this Cell expects during training are the same as its outputs."""
    
# Parameters

@dataclass
class ImageClassificationParameters(Parameters):
    epochs: Integer = field(default=Integer(2), metadata=metadata(name="Epochs", description="The number of times the training loop should process the data."))
    batch_size: Integer = field(default=Integer(4), metadata=metadata(name="Batch Size", description="SGD mini-batch size."))

We can now go over the declared IO dataclasses in more detail.

During inference, the Cell will have:
- images in the form of `Image`s going in
- labels in the form of `Integer`s going out

During training, the Cell will:
- receive examples in the form of an image (i.e. an `Image`) and a label (i.e. an `Integer`) going in

The `Samples` and `Annotations` dataclasses inherit from the `Inputs` and `Outputs` dataclasses respectively. This is purely for convenience: in many cases the sample and annotation types will differ from their inference counterparts.
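The effect of this inheritance can be checked with plain dataclasses. The sketch below uses empty stand-in base classes (assumptions, not the RVAI ones) and a hypothetical `AugmentedSamples` class to show a training layout that differs from the inference inputs.

```python
from dataclasses import dataclass, fields


# Hypothetical stand-ins for the RVAI base classes; they carry no fields.
class Inputs:
    pass


class Samples:
    pass


@dataclass
class ClassifierInputs(Inputs):
    image: bytes


@dataclass
class ClassifierSamples(Samples, ClassifierInputs):
    """Reuses the fields of the inference inputs unchanged."""


@dataclass
class AugmentedSamples(Samples):
    """A training-only layout that differs from the inference inputs."""

    image: bytes
    mask: bytes


inherited = [f.name for f in fields(ClassifierSamples)]
distinct = [f.name for f in fields(AugmentedSamples)]
print(inherited, distinct)
```

The inheriting class gets the `image` field for free, while the second class declares its own fields and shares nothing with the inference inputs.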

Now, let's actually implement a TrainableCell!

### Cell Body

In [None]:
# necessary RVAI imports:
from rvai.base.cell import cell # used as a decorator to register a cell in RVAI
from rvai.base.cell import TrainableCell # base class, defines main functionality

# used for typing:
from rvai.base.cell import CellMode # enum, defines what mode the cell is running in
from rvai.base.data import Example, Dataset
from rvai.base.context import Context # required argument for most cell methods, do not worry about this yet
from typing import Type, Optional, Tuple, Sequence

# dependencies used in the cell body itself:
import numpy as np
import tensorflow as tf
tf.autograph.set_verbosity(1)
tf.logging.set_verbosity(tf.logging.ERROR)

from rvai.base import compat # convenience methods for integrating external ML framework (e.g. keras) code into RVAI
from rvai.base.evaluation import CellEvaluationUpdate
from rvai.base.training import Metrics

In [None]:
@cell # used for registering the ImageClassificationCell in RVAI
class ImageClassificationCell(TrainableCell):

    # A Cell's IO is declared by using class attributes.
    # These are the same dataclasses we defined before.
    inputs: Type[ImageClassificationInputs]
    outputs: Type[ImageClassificationOutputs]

    samples: Type[ImageClassificationSamples]
    annotations: Type[ImageClassificationAnnotations]

    parameters: Type[ImageClassificationParameters]
        
    @classmethod
    def load_model(
        cls,
        context: Context,
        parameters: ImageClassificationParameters,
        model_path: Optional[str],
        mode: CellMode,
    ):
        """
        Create or load a model.
        """        
        if model_path is not None:
            return tf.keras.models.load_model(model_path)
        else:
            model = tf.keras.Sequential()
            model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=2, padding='same', activation='relu', input_shape=(28,28,1)))
            model.add(tf.keras.layers.MaxPooling2D(pool_size=2))
            model.add(tf.keras.layers.Dropout(0.3))
            model.add(tf.keras.layers.Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'))
            model.add(tf.keras.layers.MaxPooling2D(pool_size=2))
            model.add(tf.keras.layers.Dropout(0.3))
            model.add(tf.keras.layers.Flatten())
            model.add(tf.keras.layers.Dense(256, activation='relu'))
            model.add(tf.keras.layers.Dropout(0.5))
            model.add(tf.keras.layers.Dense(10, activation='softmax'))
            model.compile(loss='categorical_crossentropy',
                          optimizer='adam',
                          metrics=['accuracy'])

        return model

    @classmethod
    def _unpack_example(
        cls,
        example: Example[ImageClassificationSamples, ImageClassificationAnnotations],
    ) -> Tuple[np.ndarray, int]:
        """
        Convert a RVAI Example into plain numpy/python data types.
        """

        samples: ImageClassificationSamples = example[0]
        annotations: ImageClassificationAnnotations = example[1]

        # standardize image input
        image = np.atleast_3d(samples.image)
        label = int(annotations.label)

        return image, label

    @classmethod
    def _collate_batch(
        cls,
        examples: Sequence[Tuple[np.ndarray, int]],  # matches the return type of `_unpack_example`
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Collate a bunch of unpacked Examples (see `_unpack_example`) into a batch format Keras understands.
        """
        
        x, y = zip(*examples)

        images: np.ndarray = np.stack(arrays=x, axis=0)
        labels: np.ndarray = tf.keras.utils.to_categorical(
            y=y, num_classes=10, dtype=np.float32
        )

        return images, labels

    @classmethod
    def train(
        cls,
        context: Context,
        parameters: ImageClassificationParameters,
        model,
        train_dataset: Dataset[
            ImageClassificationSamples, ImageClassificationAnnotations
        ],
        validation_dataset: Dataset[
            ImageClassificationSamples, ImageClassificationAnnotations
        ],
    ):
        """
        Train a model on a Dataset.
        """

        # Integer -> int
        batch_size = int(parameters.batch_size)

        train_generator = compat.keras.as_generator(
            train_dataset,
            batch_size=batch_size,
            process_example=cls._unpack_example,
            process_batch=cls._collate_batch,
        )

        validation_generator = compat.keras.as_generator(
            validation_dataset,
            batch_size=batch_size,
            process_example=cls._unpack_example,
            process_batch=cls._collate_batch,
        )

        nb_epochs = int(parameters.epochs)
        nb_training_batches = int(len(train_dataset) // batch_size)
        nb_validation_batches = int(len(validation_dataset) // batch_size)

        model.fit_generator(
            generator=train_generator,
            steps_per_epoch=nb_training_batches,
            validation_data=validation_generator,
            validation_steps=nb_validation_batches,
            epochs=nb_epochs,
            verbose=0,
            callbacks=[compat.keras.training_update_callback(context)],
        )

        model_path = context.training.get_model_path()

        tf.keras.models.save_model(model=model, filepath=model_path)

        return model_path

    @classmethod
    def test(
        cls,
        context: Context,
        parameters: ImageClassificationParameters,
        model,
        test_dataset: Dataset[
            ImageClassificationSamples, ImageClassificationAnnotations
        ],
    ):
        
        batch_size = int(parameters.batch_size)
        
        test_generator = compat.keras.as_generator(
            test_dataset,
            batch_size=batch_size,
            process_example=cls._unpack_example,
            process_batch=cls._collate_batch,
        )

        nb_test_batches = int(len(test_dataset) // batch_size)
    
        metrics = model.evaluate_generator(
            generator=test_generator,
            steps=nb_test_batches,
            verbose=0,
        )
        
        return Metrics(
            {name: metric for name, metric in zip(model.metrics_names, metrics)},
            performance="acc",
        )

    @classmethod
    def predict(
        cls,
        context: Context,
        parameters: ImageClassificationParameters,
        model,
        inputs: ImageClassificationInputs,
    ):
        raise NotImplementedError

Let's discuss.

**the `load_model` method**

The `load_model` method is fairly straightforward. It should always return a model object. RVAI does not care what this model looks like; it just passes whatever this method returns to all other methods that need a model (e.g. `TrainableCell::train`). This method also has an optional argument `model_path`. When a model path is given, `load_model` is expected to load and return _that_ specific model from disk. The on-disk format of the model is always something the cell understands, because it was previously produced and saved to disk by the Cell's own `train` method.

In this case the model object is a tf.keras CNN created using the Sequential API. Its on-disk representation is a TensorFlow SavedModel.
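The create-or-load contract can be illustrated without any ML framework. The sketch below is an assumption-laden toy: the "model" is a plain dict saved as JSON, and the two functions only mimic the roles of `load_model` and `train`, not their RVAI signatures.

```python
import json
import os
import tempfile
from typing import Optional


def load_model(model_path: Optional[str]) -> dict:
    """Return a fresh model when no path is given, else restore it from disk."""
    if model_path is None:
        return {"weights": [0.0, 0.0], "trained": False}
    with open(model_path) as fh:
        return json.load(fh)


def train(model: dict, model_path: str) -> str:
    """Pretend to train, save in a format `load_model` understands, return the path."""
    trained = dict(model, weights=[0.5, -0.5], trained=True)
    with open(model_path, "w") as fh:
        json.dump(trained, fh)
    return model_path


with tempfile.TemporaryDirectory() as tmp:
    path = train(load_model(None), os.path.join(tmp, "model.json"))
    restored = load_model(path)

print(restored)
```

The important property is the round trip: whatever `train` writes to disk, `load_model` must be able to read back.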

**the `_unpack_example` and `_collate_batch` methods**

These methods are not part of the `TrainableCell` API and will be discussed in the `train` section.
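As a preview of what these two helpers do to the data, the following sketch reproduces their shape handling with NumPy only. The images are fake, and the one-hot encoding uses identity-matrix indexing as a stand-in for `tf.keras.utils.to_categorical`.

```python
import numpy as np

# Two fake 28x28 grayscale images with integer class labels,
# standing in for unpacked Fashion MNIST examples.
examples = [
    (np.atleast_3d(np.zeros((28, 28), dtype=np.uint8)), 3),
    (np.atleast_3d(np.ones((28, 28), dtype=np.uint8)), 9),
]

x, y = zip(*examples)

# Stack along a new leading batch axis: (batch, height, width, channels).
images = np.stack(x, axis=0)

# One-hot encode the labels by picking rows of a 10x10 identity matrix.
labels = np.eye(10, dtype=np.float32)[list(y)]

print(images.shape, labels.shape)
```

`np.atleast_3d` adds the trailing channel axis, so each image matches the `input_shape=(28,28,1)` declared on the first convolution layer.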

**the `train` method**

The body of the `train` method looks like a normal deep learning training loop, except for some data preparation code. This code massages the data from an RVAI Dataset, which contains RVAI typed data (e.g. `rvai.types.Image`, `rvai.types.Integer`, etc., as described in the `Samples` and `Annotations` dataclasses), into something the chosen deep learning framework understands. RVAI provides some compatibility methods to aid this transformation. We use `rvai.base.compat.keras.as_generator` in this case. This function transforms a Dataset into an infinite Keras generator. We do have to supply some helper methods though:

- `process_example [=cls._unpack_example]` : a method to transform individual Dataset Examples
- `process_batch [=cls._collate_batch]` : a method to transform batches of data
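To make the roles of the two helpers concrete, here is a rough sketch of what such an infinite batch generator could look like. The `as_batches` function is a hypothetical stand-in written for this tutorial; it is not the implementation of `as_generator`, whose internals are not shown here.

```python
import itertools
from typing import Any, Callable, Iterator, List, Sequence


def as_batches(
    dataset: Sequence[Any],
    batch_size: int,
    process_example: Callable[[Any], Any],
    process_batch: Callable[[List[Any]], Any],
) -> Iterator[Any]:
    """Yield processed batches forever, restarting the dataset when exhausted."""
    indices = itertools.cycle(range(len(dataset)))
    while True:
        batch = [process_example(dataset[next(indices)]) for _ in range(batch_size)]
        yield process_batch(batch)


toy_dataset = [("a", 0), ("b", 1), ("c", 2)]
batches = as_batches(
    toy_dataset,
    batch_size=2,
    process_example=lambda ex: (ex[0].upper(), ex[1]),  # per-example transform
    process_batch=lambda exs: tuple(zip(*exs)),  # (samples, labels) per batch
)

first_three = list(itertools.islice(batches, 3))
print(first_three)
```

Because a generator like this never ends, the training loop has to be told how many batches make up one epoch, which is why `train` passes `steps_per_epoch` and `validation_steps` to Keras.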

During training, the cell can send updates about the training process either a) by `yield`-ing `TrainingUpdate`s or b) by using (in the case of a Keras training loop) a `rvai.base.compat.keras.training_update_callback`, which automatically creates and sends `TrainingUpdate`s.

At the end of training, the method should return a model path. This path should always point to a trained model on disk.
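The `yield`-based variant relies on a plain Python mechanism: a generator can both yield intermediate values and return a final one. The sketch below shows that mechanism only. `ProgressUpdate` and `run` are hypothetical stand-ins, and how RVAI itself consumes the generator is an assumption not covered by this tutorial.

```python
from dataclasses import dataclass, field
from typing import Dict, Generator, List, Tuple


@dataclass
class ProgressUpdate:
    """Hypothetical stand-in for an RVAI TrainingUpdate."""

    progress: float
    metrics: Dict[str, float] = field(default_factory=dict)


def train(epochs: int, model_path: str) -> Generator[ProgressUpdate, None, str]:
    for epoch in range(1, epochs + 1):
        # ... one epoch of real training would happen here ...
        yield ProgressUpdate(progress=epoch / epochs, metrics={"epoch": float(epoch)})
    # ... the trained model would be written to `model_path` here ...
    return model_path


def run(gen: Generator[ProgressUpdate, None, str]) -> Tuple[List[ProgressUpdate], str]:
    """Drain the generator, collecting the updates and the final return value."""
    updates: List[ProgressUpdate] = []
    while True:
        try:
            updates.append(next(gen))
        except StopIteration as stop:
            return updates, stop.value


updates, result = run(train(epochs=4, model_path="model.bin"))
print([u.progress for u in updates], result)
```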

## Creating a Pipeline

To start using the `ImageClassificationCell` we build a (single-cell) `Pipeline` with it. Next, we also create a `TrainingPipeline` which defines how the cell should be trained. This is where you would include preprocessing.

In [None]:
# necessary
from rvai.base.pipeline import Pipeline, TrainingPipeline
from rvai.base.pipeline.declarative import pipeline

# typing
from rvai.base.pipeline.declarative import DatasetAnnotations, DatasetSamples, PipelineCells, PipelineInputs, PipelineOutputs, PipelineConnections, TrainingPipelines

### Training Pipeline

In [None]:
@pipeline
class ImageClassificationTrainingPipeline(TrainingPipeline):
    
    # this class does not have to be inline
    class TrainingCells(PipelineCells): 
        classifier: ImageClassificationCell
    
    cells: TrainingCells
    train: ImageClassificationCell = cells.classifier # this attribute marks which cell you want to train with this training pipeline
    samples: DatasetSamples = (cells.classifier.samples.image,)
    annotations: DatasetAnnotations = (cells.classifier.annotations.label,)

### Inference Pipeline

In [None]:
@pipeline
class ImageClassificationPipeline(Pipeline):
    
    # this class does not have to be inline
    class InferenceCells(PipelineCells):
        classifier: ImageClassificationCell
    
    cells: InferenceCells
    inputs: PipelineInputs = (("image", cells.classifier.inputs.image),)
    outputs: PipelineOutputs = (("label", cells.classifier.outputs.label),)
    training_pipelines: TrainingPipelines = (
        (cells.classifier, ImageClassificationTrainingPipeline),
    )

### Visualize Pipelines

WIP, visualization of pipelines is very crude at the moment!

In [None]:
inference_pipeline = ImageClassificationPipeline()
training_pipeline = inference_pipeline.get_training_pipeline(inference_pipeline.cells.classifier)
inference_pipeline.show()
training_pipeline.show()

## Training

Finally we start the training process on the debug runtime.

### Dataset

In order to train our cell we also need a Dataset, of course. An RVAI compatible Dataset requires two methods to be implemented: `__getitem__` and `__len__` (cf. a [pytorch Dataset](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html#dataset-class)). In this tutorial we will use [Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist).
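These two methods are all a map-style dataset needs. As a minimal sketch, the toy class below (a made-up `SquaresDataset`, not an RVAI class) implements just that protocol and shows the batch-count arithmetic the cell's `train` method relies on.

```python
from typing import List, Tuple


class SquaresDataset:
    """A toy dataset of (number, square) pairs; not an RVAI class."""

    def __init__(self, size: int) -> None:
        self.size = size

    def __getitem__(self, index: int) -> Tuple[int, int]:
        if not 0 <= index < self.size:
            raise IndexError(index)
        return index, index * index

    def __len__(self) -> int:
        return self.size


dataset = SquaresDataset(5)
batch_size = 2

# Same arithmetic as in `train`: an incomplete final batch is not counted.
full_batches = len(dataset) // batch_size

# Iteration falls back to __getitem__ with indices 0, 1, 2, ... until IndexError.
pairs: List[Tuple[int, int]] = list(dataset)
print(full_batches, pairs)
```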

In [None]:
# required RVAI base class
from rvai.base.data import Dataset

# used for typing
from rvai.types import Image, Integer
from typing import Sequence, Tuple
import numpy as np

# actual data
from tensorflow.keras.datasets import fashion_mnist

# some imports for displaying data
from IPython.display import display, HTML
import PIL.Image  # import the submodule explicitly; `import PIL` alone does not load it

In [None]:
class FashionMNISTDataset(
    Dataset[ImageClassificationSamples, ImageClassificationAnnotations]
):
    def __init__(
        self, images: Sequence[np.ndarray], labels: Sequence[np.ndarray]
    ):
        self.images = images
        self.labels = labels

    def __getitem__(
        self, index
    ) -> Tuple[ImageClassificationSamples, ImageClassificationAnnotations]:
        return (
            ImageClassificationSamples(image=Image(self.images[index])),
            ImageClassificationAnnotations(label=Integer(self.labels[index])),
        )

    def __len__(self):
        return len(self.images)

# Class names for FashionMNIST
class_names = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
]

train_dataset, validation_dataset = (FashionMNISTDataset(images=images, labels=labels) for images, labels in fashion_mnist.load_data())

# display an example image and its label
samples, annotations = train_dataset[0]
display(PIL.Image.fromarray(samples.image))
print(class_names[annotations.label])

### Training Loop

We only print the progress in this example code, but the returned `TrainingUpdate`s contain other information as well, for example all the metrics you have defined in your Keras model.

In [None]:
from rvai.base.runtime import init, Training
from rvai.base.training import Tensorboard

In [None]:
# create a runtime, we choose the debug runtime
runtime = init("debug")

# generate a training pipeline
training_pipeline = inference_pipeline.get_training_pipeline(inference_pipeline.cells.classifier)

# configure a training task
training = Training(
    pipeline=training_pipeline,
    models={}, # no previous models yet
    parameters={"classifier": ImageClassificationParameters(epochs=1)}, # a single epoch keeps the demo short; the other defaults are fine for us
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,
)


training_loop = runtime.start_training(training)

print('Starting training')
for update in training_loop.updates():
    print(f"\r[{update.progress * 100:.3}%] - accuracy: {update.metrics.values.get('acc')}", end='')
model_path = training_loop.result()
print(f'\nTraining done. Model can be found at: {model_path}')
# Stop the training process
training_loop.stop()

## Testing

Now we can evaluate the trained model on a test dataset.

In [None]:
from rvai.base.runtime import CellEvaluation

# configure a cell evaluation task
cell_evaluation = CellEvaluation(
    trainable_cell=training_pipeline.trainable_cell, # we select the cell we have trained
    model=model_path, # path to the trained model
    parameters=ImageClassificationParameters(), # defaults are fine for us
    dataset=validation_dataset, # normally this would be a separate test dataset, but for this tutorial we re-use the validation dataset
)

evaluation_loop = runtime.start_cell_evaluation(cell_evaluation)

print('Starting cell evaluation')
for update in evaluation_loop.updates():
    print(f"\r{update}", end='')
result = evaluation_loop.result()
print(f'\nEvaluation done: {result}')
# Stop the evaluation loop
evaluation_loop.stop()

Finally, stop the runtime:

In [None]:
runtime.stop()