# Registering a TF/Keras (IMDB Movie Review Classification) model on Verta

Within Verta, a "Model" can be any arbitrary function: a traditional ML model (e.g., sklearn, PyTorch, TF); a function (e.g., squaring a number or making a DB call); or a mixture of the above (e.g., pre-processing code, a DB call, and then a model application). See more [here](https://docs.verta.ai/verta/registry/concepts).

This notebook provides an example of how to catalog a Keras model on Verta as a Verta Standard Model using Verta's convenience functions.

Updated for Verta version: 0.21.0

This notebook walks through training an RNN classification model with Keras and cataloging it on the Verta platform.

<a href="https://colab.research.google.com/github/VertaAI/examples/blob/registry_examples/registry/tensorflow/keras-imdb-model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 0. Imports

In [None]:
!pip install verta
!pip install tensorflow
!pip install tensorflow-datasets

In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf

## 1. Register Model

### 1.1 (Optional) Model Training 

A model has to exist before we can register it, so we will train one here in the notebook.

If you already have a Keras model saved in a file, you can skip this step and register it directly in the catalog.

#### 1.1.1 Load Training Data

The IMDB large movie review dataset is a *binary classification* dataset—all the reviews have either a *positive* or *negative* sentiment.

Download the dataset using [TFDS](https://www.tensorflow.org/datasets). The dataset comes with an inbuilt subword tokenizer.

In [None]:
# loading the dataset

dataset, info = tfds.load('imdb_reviews/subwords8k', with_info=True,
                          as_supervised=True)
train_dataset, test_dataset = dataset['train'], dataset['test']

#### 1.1.2 Train/Test code

**Model Info**

Build a `tf.keras.Sequential` model and start with an embedding layer. An embedding layer stores one vector per word. When called, it converts the sequences of word indices to sequences of vectors. These vectors are trainable. After training (on enough data), words with similar meanings often have similar vectors.
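
To make the lookup concrete, here is a minimal plain-Python sketch of what an embedding layer does when called. It is an illustration only, not how `tf.keras.layers.Embedding` is implemented; the helper names `make_embedding_table` and `embed` are hypothetical, and the table values are random rather than trained.

```python
import random


def make_embedding_table(vocab_size, dim, seed=0):
    """Build a vocab_size x dim table: one vector (row) per token index."""
    rng = random.Random(seed)
    return [[rng.uniform(-0.05, 0.05) for _ in range(dim)]
            for _ in range(vocab_size)]


def embed(table, token_ids):
    """Replace each token index with its row from the table."""
    return [table[i] for i in token_ids]


table = make_embedding_table(vocab_size=10, dim=4)
vectors = embed(table, [3, 7, 3])

# Three token indices go in, three 4-dimensional vectors come out.
print(len(vectors), len(vectors[0]))  # 3 4
```

The same index always maps to the same vector, which is why training can move the vectors of related words closer together.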

A recurrent neural network (RNN) processes sequence input by iterating through the elements. At each timestep, the RNN combines the current element with its own output from the previous timestep, so information is carried forward along the sequence.

The `tf.keras.layers.Bidirectional` wrapper can also be used with an RNN layer. This propagates the input forward and backwards through the RNN layer and then concatenates the output. This helps the RNN to learn long-range dependencies.
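
The recurrence and the bidirectional wrapper can be sketched in a few lines of plain Python. This is a toy one-unit tanh RNN, not an LSTM, and the function names and weights (`rnn_pass`, `bidirectional`, `w_in`, `w_rec`) are assumptions made for illustration. The real wrapper also trains separate weights for the backward pass, whereas this sketch reuses the same ones.

```python
import math


def rnn_pass(inputs, w_in=0.5, w_rec=0.8):
    """Run a one-unit tanh RNN over a list of floats; return the final state."""
    state = 0.0
    for x in inputs:
        # The previous state feeds back in alongside the current input.
        state = math.tanh(w_in * x + w_rec * state)
    return state


def bidirectional(inputs):
    """Read the sequence in both directions and concatenate the two results."""
    forward = rnn_pass(inputs)
    backward = rnn_pass(list(reversed(inputs)))
    return [forward, backward]


output = bidirectional([1.0, 0.0, -1.0])
print(len(output))  # 2
```

With the default concatenation, wrapping `LSTM(64)` in `Bidirectional` therefore yields a 128-dimensional output (64 from each direction).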

In [None]:
# Define hyperparameters

BUFFER_SIZE = 10000
BATCH_SIZE = 64

train_dataset = train_dataset.shuffle(BUFFER_SIZE)
# TF 2.2+ infers the padded shapes; Dataset.output_shapes is not available on TF 2.x datasets
train_dataset = train_dataset.padded_batch(BATCH_SIZE)

test_dataset = test_dataset.padded_batch(BATCH_SIZE)

tokenizer = info.features['text'].encoder

hyperparams = {
    'num_epochs': 5,
    'optimizer': 'adam',
    'loss': 'binary_crossentropy',
    'vocab_size': tokenizer.vocab_size,
    'metrics': 'accuracy'
}
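
The reviews have different lengths, so `padded_batch` pads every sequence in a batch up to the length of the longest one in that batch (with zeros by default). A minimal plain-Python sketch of that idea, using a hypothetical helper `pad_batch` rather than the `tf.data` implementation:

```python
def pad_batch(sequences, pad_value=0):
    """Right-pad every sequence to the length of the longest one."""
    longest = max(len(seq) for seq in sequences)
    return [list(seq) + [pad_value] * (longest - len(seq))
            for seq in sequences]


batch = pad_batch([[5, 8, 2], [7], [4, 1]])
print(batch)  # [[5, 8, 2], [7, 0, 0], [4, 1, 0]]
```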

In [None]:
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(tokenizer.vocab_size, 64),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

model.compile(loss=hyperparams['loss'],
              optimizer=hyperparams['optimizer'],
              metrics=[hyperparams['metrics']])

history = model.fit(train_dataset,
                    epochs=hyperparams['num_epochs'],
                    validation_data=test_dataset)

As a sanity check, we evaluate the trained model on the test set to confirm that it produces sensible predictions.

In [None]:
test_loss, test_acc = model.evaluate(test_dataset, verbose=0)

print('Test Loss: {}'.format(test_loss))
print('Test Accuracy: {}'.format(test_acc))
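
Because the final layer is a single sigmoid unit, each prediction is a number between 0 and 1, and in this dataset label 1 means positive. The sketch below shows one way to turn such a score into a label; the helper name `to_label` and the 0.5 cut-off are assumptions for illustration, not part of the notebook's model.

```python
def to_label(probability, threshold=0.5):
    """Map a sigmoid output in [0, 1] to a sentiment label."""
    return "positive" if probability >= threshold else "negative"


# Hypothetical scores, standing in for values returned by model.predict(...)
scores = [0.91, 0.12, 0.50]
print([to_label(s) for s in scores])  # ['positive', 'negative', 'positive']
```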

### 1.2 Register Model to Verta Model Catalog

Now that the model is in good shape, we can register it on the Verta platform.

We'll connect to Verta through the [Verta Python Client](https://verta.readthedocs.io/en/main/_autogen/verta.Client.html), create a [registered model](https://verta.readthedocs.io/en/master/_autogen/verta.registry.entities.RegisteredModel.html) for our binary text classification model, and then create a [version](https://verta.readthedocs.io/en/master/_autogen/verta.registry.entities.RegisteredModelVersion.html) to associate this particular trained model with.

All of these can be viewed in the Verta web app once they are created.

Note: If your model uses Keras custom objects, you have to register it in the catalog from a serialized version of the model, by extending the [VertaModelBase](https://verta.readthedocs.io/en/master/_autogen/verta.registry.VertaModelBase.html?highlight=VertaModelBase#verta.registry.VertaModelBase) class.

In [None]:
# Paste your credentials in this cell or anywhere above this along with the code snippet to connect to Verta Platform

from verta import Client

client = Client(
        #   host="app.verta.ai",
        #   email="user@verta.ai",
        #   dev_key="a765b2de-786d-466c-b2d8-thiye06f80d5",
        )
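
As an alternative to pasting credentials into the notebook, the Verta client can read them from the `VERTA_HOST`, `VERTA_EMAIL`, and `VERTA_DEV_KEY` environment variables when `Client()` is called without arguments. The sketch below only checks that those variables are set before you connect. The helper `missing_verta_env` is hypothetical, and the variable names should be confirmed against the client documentation for your Verta version.

```python
import os

# Environment variables the Verta client is assumed to read (verify for your version).
REQUIRED_VARS = ("VERTA_HOST", "VERTA_EMAIL", "VERTA_DEV_KEY")


def missing_verta_env(environ=None):
    """Return the names of required variables that are unset or empty."""
    environ = os.environ if environ is None else environ
    return [name for name in REQUIRED_VARS if not environ.get(name)]


missing = missing_verta_env()
if missing:
    print("Set these before creating the client:", ", ".join(missing))
```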

In [None]:
# Create/Get a Verta registered model

registered_model = client.get_or_create_registered_model(
    "IMDB-Review-clf",
    desc="Binary text classifier to classify movie reviews as positive or negative",
    labels=["NLP", "Classification", "Neural Net"],
)

In [None]:
from verta.environment import Python

# Uncomment the lines below to load the model from a saved file (HDF5 or TF SavedModel)
# import tensorflow as tf
# model = tf.keras.models.load_model("<file_to_model.h5> or <path_to_tf_saved_model>")

model_version_v1 = registered_model.create_standard_model_from_keras(
    model, # The loaded model object, can be the one trained in the same file or loaded from keras load_model function
    environment=Python(requirements=["tensorflow"]),  # Libraries the model needs in order to run
    name="v1-rnn", # Name to identify the version in the model versions tab
)

And that's it. You should now be able to see your model in your catalog.