# Text Classification Fine-Tuning Using TensorFlow and the Intel® Transfer Learning Tool API

This notebook uses the `tlt` library to fine-tune a pretrained TF Hub model for text classification.

## 1. Import dependencies and set up parameters

This notebook assumes that you have already followed the instructions to set up a TensorFlow environment with all the dependencies required to run it.

In [None]:
import os
import pandas as pd
import tensorflow as tf

# tlt imports
from tlt.models import model_factory
from tlt.datasets import dataset_factory

# Directory where the dataset will be downloaded (override with the DATASET_DIR environment variable)
dataset_dir = os.environ.get("DATASET_DIR", os.path.join(os.path.expanduser("~"), "dataset"))

# Directory for training output (override with the OUTPUT_DIR environment variable)
output_dir = os.environ.get("OUTPUT_DIR", os.path.join(os.path.expanduser("~"), "output"))

print("Dataset directory:", dataset_dir)
print("Output directory:", output_dir)

## 2. Get the model

In this step, we call the TLT model factory to list the supported TensorFlow text classification models. These are pretrained models from TF Hub that have been tested with this API. Optionally, add the `verbose=True` argument to the `print_supported_models` call to get more information about each model (such as links to TF Hub, the original dataset, etc.).

In [None]:
# See a list of available text classification models
model_factory.print_supported_models(use_case='text_classification', framework='tensorflow')

Use the TLT model factory to get one of the models listed in the previous cell. The `get_model` function returns a TLT model object that will later be used for training.

In [None]:
model_name = "small_bert/bert_en_uncased_L-2_H-128_A-2"
framework = "tensorflow"

model = model_factory.get_model(model_name, framework)

print("Model name:", model.model_name)
print("Framework:", model.framework)
print("Use case:", model.use_case)

## 3. Get the dataset

In this step, we get a dataset from the [TensorFlow Datasets catalog](https://www.tensorflow.org/datasets/catalog/overview) to use for fine-tuning. The dataset factory currently supports the following text classification datasets: [imdb_reviews](https://www.tensorflow.org/datasets/catalog/imdb_reviews), [glue/sst2](https://www.tensorflow.org/datasets/catalog/glue#gluesst2), and [glue/cola](https://www.tensorflow.org/datasets/catalog/glue#gluecola_default_config).

In [None]:
# Supported datasets: imdb_reviews, glue/sst2, glue/cola
dataset_name = "imdb_reviews"
dataset = dataset_factory.get_dataset(dataset_dir, model.use_case, model.framework, dataset_name,
                                      dataset_catalog="tf_datasets", shuffle_files=True)

print(dataset.info)

print("\nClass names:", str(dataset.class_names))

In [None]:
# Batch the dataset and create splits for training and validation
dataset.preprocess(batch_size=32)
dataset.shuffle_split(train_pct=0.75, val_pct=0.25)
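As a rough illustration of what the two calls above set up, the plain-Python sketch below shuffles example indices with a fixed seed, splits them 75/25, and counts how many batches of 32 each subset yields. It is a sketch of the general technique, not the `tlt` implementation: the helper names `shuffle_split_indices` and `num_batches` are hypothetical, and splitting at the example level (rather than at the batch level) is an assumption.

```python
import math
import random


def shuffle_split_indices(num_examples, train_pct=0.75, val_pct=0.25, seed=42):
    """Hypothetical helper: shuffle example indices and split them into train/validation lists."""
    if train_pct + val_pct > 1.0 + 1e-9:
        raise ValueError("train_pct + val_pct must not exceed 1.0")
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    train_size = int(num_examples * train_pct)
    val_size = int(num_examples * val_pct)
    return indices[:train_size], indices[train_size:train_size + val_size]


def num_batches(num_examples, batch_size=32):
    """Number of batches when the final partial batch is kept rather than dropped."""
    return math.ceil(num_examples / batch_size)


train_idx, val_idx = shuffle_split_indices(1000)
print(len(train_idx), len(val_idx))  # 750 250
print(num_batches(len(train_idx)), num_batches(len(val_idx)))  # 24 8
```

Keeping the last partial batch means 750 training examples produce 23 full batches plus one batch of 14.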

## 4. Fine-tuning

The TLT model's `train` function is called with the dataset that was just prepared, an output directory for checkpoints, and the number of training epochs.

Mixed precision uses both 16-bit and 32-bit floating-point types to make training run faster and use less memory. We recommend enabling auto mixed precision on platforms that support bfloat16 (for example, 4th Gen Intel® Xeon® Scalable processors, or 3rd Gen processors with the AVX-512 BF16 extension). Enabling it on a platform without bfloat16 support can degrade training performance.
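To make the bfloat16 trade-off concrete: bfloat16 keeps float32's sign bit and 8-bit exponent (so it covers the same numeric range) but only 7 explicit mantissa bits, so it halves the memory per value at the cost of precision. The sketch below emulates a float32 to bfloat16 round trip in plain Python by keeping the upper 16 bits of the float32 bit pattern with round-to-nearest-even. It only illustrates the number format; it is not how TensorFlow or the hardware performs the cast.

```python
import struct


def float32_bits(x):
    """Return the 32-bit IEEE 754 bit pattern of x after rounding it to float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def to_bfloat16(x):
    """Emulate float32 -> bfloat16 -> float32 (round to nearest even; NaN/inf not handled)."""
    bits = float32_bits(x)
    lsb = (bits >> 16) & 1  # lowest bit that bfloat16 keeps, used for ties-to-even
    rounded = (bits + 0x7FFF + lsb) & 0xFFFF0000  # round, then drop the low 16 bits
    return struct.unpack("<f", struct.pack("<I", rounded))[0]


for value in (1.0, 3.141592653589793, 1.00390625, 1.0078125):
    print(value, "->", to_bfloat16(value))
# 1.0 -> 1.0
# 3.141592653589793 -> 3.140625
# 1.00390625 -> 1.0        (2**-8 is below bfloat16's resolution near 1.0)
# 1.0078125 -> 1.0078125   (1 + 2**-7 is exactly representable)

# float16 would overflow here (its maximum is 65504); bfloat16 keeps the float32 range
print(1e30, "->", to_bfloat16(1e30))
```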

In [None]:
# If enable_auto_mixed_precision is set to None, auto mixed precision will be automatically enabled when running 
# with Intel fourth generation Xeon processors, and disabled for other platforms.
enable_auto_mixed_precision = None

history = model.train(dataset, output_dir, epochs=1, enable_auto_mixed_precision=enable_auto_mixed_precision)
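The notebook does not show how `tlt` decides whether to enable auto mixed precision when `enable_auto_mixed_precision=None`. One way to check for bfloat16 support yourself on Linux x86 is to look for the `avx512_bf16` or `amx_bf16` CPU flags in `/proc/cpuinfo`. The sketch below is a hypothetical helper, not the `tlt` detection logic; assuming the parameter accepts an explicit boolean, you could pass its result instead of `None`.

```python
def cpu_flags(cpuinfo_text):
    """Collect the CPU feature flags listed on the 'flags' lines of /proc/cpuinfo-style text."""
    flags = set()
    for line in cpuinfo_text.splitlines():
        key, sep, value = line.partition(":")
        if sep and key.strip() == "flags":
            flags.update(value.split())
    return flags


def supports_bfloat16(cpuinfo_text):
    """Hypothetical check: True if the CPU advertises AVX-512 BF16 or AMX BF16 instructions."""
    return bool(cpu_flags(cpuinfo_text) & {"avx512_bf16", "amx_bf16"})


def read_cpuinfo(path="/proc/cpuinfo"):
    """Return the contents of /proc/cpuinfo, or an empty string if it cannot be read (e.g. non-Linux)."""
    try:
        with open(path, encoding="utf-8") as f:
            return f.read()
    except OSError:
        return ""


print("bfloat16 supported:", supports_bfloat16(read_cpuinfo()))
```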

## 5. Evaluate

Next, we evaluate the fine-tuned model. The model's `evaluate` function returns a list of metrics calculated on the dataset's validation subset.

In [None]:
metrics = model.evaluate(dataset)

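# Print each evaluation metric next to its name, taken from the underlying Keras model
# (note that model._model is a private attribute of the TLT model object)
for metric_name, metric_value in zip(model._model.metrics_names, metrics):
    print("{}: {}".format(metric_name, metric_value))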
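If the reported metrics include an accuracy value, then for a binary task such as `imdb_reviews` it is the fraction of examples whose thresholded score matches the label. The plain-Python sketch below computes that quantity, assuming each score is a probability in [0, 1]; if the model outputs logits instead, the equivalent threshold is 0. The helper `binary_accuracy` is hypothetical, since Keras computes its metrics internally.

```python
def binary_accuracy(scores, labels, threshold=0.5):
    """Hypothetical helper: fraction of examples whose thresholded score equals the 0/1 label."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not scores:
        raise ValueError("at least one example is required")
    # A score strictly above the threshold counts as a prediction of class 1
    correct = sum(int(score > threshold) == int(label) for score, label in zip(scores, labels))
    return correct / len(scores)


print(binary_accuracy([0.9, 0.2, 0.6, 0.4], [1, 0, 0, 0]))  # 0.75
```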

## 6. Predict

The model's predict function can be called with a batch of data from the dataset.

In [None]:
# Get a single batch from the dataset object
data_batch, labels = dataset.get_batch()

# Call predict using the batch
batch_predictions = model.predict(data_batch)

# Maximum number of rows to show in the data frame
max_items = 10

# Collect the sentence text, score, and actual label for the batch
prediction_list = []
for i, (text, actual_label) in enumerate(zip(data_batch, labels)):
    sentence = text.numpy().decode('utf-8')
    score = batch_predictions[i]
    prediction_list.append([sentence,
                            tf.get_static_value(score)[0],
                            dataset.get_str_label(int(actual_label.numpy()))])
    if i + 1 >= max_items:
        break

# Display the results using a data frame
result_df = pd.DataFrame(prediction_list, columns=["Input Text", "Score", "Actual Label"])
result_df.style.hide(axis="index")

Raw text can also be passed to the predict function.

In [None]:
result = model.predict("Awesome movie")

print("Predicted score:", float(result))
print("Predicted label:", dataset.get_str_label(float(result)))
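If you need to turn a raw score into a class name yourself, for example to add a "Predicted Label" column to the data frame above, a common approach for a single-output binary classifier is to apply a sigmoid when the score is a logit and then compare against 0.5. This notebook does not state whether the model's scores are logits or probabilities, so the `from_logits` flag below is an assumption to verify. The helper `score_to_label` is hypothetical, and the class names are illustrative (`imdb_reviews` uses `neg` and `pos`).

```python
import math


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def score_to_label(score, class_names, from_logits=False, threshold=0.5):
    """Hypothetical helper: map one binary-classifier score to a class name."""
    probability = sigmoid(score) if from_logits else score
    return class_names[1] if probability > threshold else class_names[0]


class_names = ["neg", "pos"]
print(score_to_label(0.93, class_names))                    # pos
print(score_to_label(-1.2, class_names, from_logits=True))  # neg
```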

## 7. Export the saved model

Lastly, we call the TLT model's `export` function to generate a `saved_model.pb`. The model is saved in a format that is ready to use with [TensorFlow Serving](https://github.com/tensorflow/serving). Each export creates a new numbered directory, which allows TensorFlow Serving to pick up the latest model version.

In [None]:
saved_model_dir = model.export(output_dir)
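TensorFlow Serving treats each integer-named subdirectory of a model base path as a version and, by default, serves the highest one. The sketch below shows how you might locate the most recent export from Python, assuming the numbered directories sit directly inside the directory you pass in. `latest_version_dir` is a hypothetical helper; check the exact layout `model.export` produces (for example, whether it adds a model-name subdirectory) against the returned `saved_model_dir`.

```python
import os
import tempfile


def latest_version_dir(base_dir):
    """Return the path of the highest integer-named subdirectory of base_dir, or None if there is none."""
    candidates = [
        (int(name), name)
        for name in os.listdir(base_dir)
        if name.isdecimal() and os.path.isdir(os.path.join(base_dir, name))
    ]
    if not candidates:
        return None
    # Compare numerically so that version 10 sorts after version 2
    return os.path.join(base_dir, max(candidates)[1])


# Demo with a throwaway directory standing in for an export location
with tempfile.TemporaryDirectory() as base:
    for version in ("1", "2", "10"):
        os.makedirs(os.path.join(base, version))
    print(os.path.basename(latest_version_dir(base)))  # 10
```

Comparing versions as integers rather than strings matters here: a plain string sort would rank `"2"` above `"10"`.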

## Citations

```
@InProceedings{maas-EtAl:2011:ACL-HLT2011,
  author    = {Maas, Andrew L.  and  Daly, Raymond E.  and  Pham, Peter T.  and  Huang, Dan  and  Ng, Andrew Y.  and  Potts, Christopher},
  title     = {Learning Word Vectors for Sentiment Analysis},
  booktitle = {Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics: Human Language Technologies},
  month     = {June},
  year      = {2011},
  address   = {Portland, Oregon, USA},
  publisher = {Association for Computational Linguistics},
  pages     = {142--150},
  url       = {http://www.aclweb.org/anthology/P11-1015}
}

@inproceedings{wang2019glue,
  title={{GLUE}: A Multi-Task Benchmark and Analysis Platform for Natural Language Understanding},
  author={Wang, Alex and Singh, Amanpreet and Michael, Julian and Hill, Felix and Levy, Omer and Bowman, Samuel R.},
  note={In the Proceedings of ICLR.},
  year={2019}
}
```