# Introduction to JumpStart - Text Classification

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. 

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/introduction_to_amazon_algorithms|jumpstart_text_classification|Amazon_JumpStart_Text_Classification.ipynb)

---
Welcome to Amazon [SageMaker JumpStart](https://docs.aws.amazon.com/sagemaker/latest/dg/studio-jumpstart.html)! You can use JumpStart to solve many machine learning tasks with one click in SageMaker Studio, or through the [SageMaker JumpStart API](https://sagemaker.readthedocs.io/en/stable/overview.html#use-prebuilt-models-with-sagemaker-jumpstart).

In this demo notebook, we demonstrate how to use the JumpStart API for Text Classification. Text Classification refers to assigning an input sentence to one of the class labels of the training dataset. We demonstrate the following use case of Text Classification models:

* How to fine-tune a pre-trained Transformer model to a custom dataset, and then run inference on the fine-tuned model.

Note: This notebook was tested on an ml.t3.medium instance in Amazon SageMaker Studio with the Python 3 (Data Science) kernel and in an Amazon SageMaker Notebook instance with the conda_python3 kernel.

---

1. [Set Up](#1.-Set-Up)
2. [Select a pre-trained model](#2.-Select-a-pre-trained-model)
3. [Finetune the pre-trained model on a custom dataset](#3.-Finetune-the-pre-trained-model-on-a-custom-dataset)
    * [Set Training parameters](#Set-Training-parameters)
    * [Start Training](#Start-Training)
4. [Deploy & run Inference on the fine-tuned model](#Deploy-&-run-Inference-on-the-fine-tuned-model)

## 1. Set Up
***
Before executing the notebook, there are some initial steps required for setup. This notebook requires the latest versions of sagemaker and ipywidgets.
***

In [None]:
!pip install sagemaker ipywidgets --upgrade --quiet

In [None]:
import sagemaker, boto3, json
from sagemaker import get_execution_role

aws_role = get_execution_role()
aws_region = boto3.Session().region_name
sess = sagemaker.Session()

## 2. Select a pre-trained model
***
You can continue with the default model, or choose a different model from the dropdown generated by running the next cell. A complete list of JumpStart models can also be accessed at [JumpStart Models](https://sagemaker.readthedocs.io/en/stable/doc_utils/jumpstart.html#).
***

In [None]:
model_id = "tensorflow-tc-bert-en-uncased-L-12-H-768-A-12-2"

In [None]:
import IPython
from ipywidgets import Dropdown

# download JumpStart model_manifest file.
boto3.client("s3").download_file(
    f"jumpstart-cache-prod-{aws_region}", "models_manifest.json", "models_manifest.json"
)
with open("models_manifest.json", "rb") as json_file:
    model_list = json.load(json_file)

# Filter the Text Classification models out of the manifest list. The manifest
# lists one entry per model version, so drop duplicate model IDs while
# preserving their original order.
tc_models = list(
    dict.fromkeys(model["model_id"] for model in model_list if "-tc-" in model["model_id"])
)

# display the model-ids in a dropdown, for user to select a model.
dropdown = Dropdown(
    value=model_id,
    options=tc_models,
    description="JumpStart Text Classification Models:",
    style={"description_width": "initial"},
    layout={"width": "max-content"},
)
display(IPython.display.Markdown("## Select a JumpStart pre-trained model from the dropdown below"))
display(dropdown)

In [None]:
model_id = dropdown.value

## 3. Finetune the pre-trained model on a custom dataset
***
We discuss how a model can be fine-tuned on a custom dataset with any number of classes.

The Text Embedding model can be fine-tuned on any text classification dataset in the same way that the
model available for inference was fine-tuned on the SST2 movie review dataset.

The model available for fine-tuning attaches a classification layer to the Text Embedding model
and initializes the layer parameters to random values.
The output dimension of the classification layer is determined by the number of classes
detected in the input data. The fine-tuning step updates all the model
parameters to minimize prediction error on the input data and returns the fine-tuned model,
which can then be deployed for inference.
Below are the instructions for how the training data should be formatted for input to the model.


- **Input:** A directory containing a 'data.csv' file. 
    - Each row of the first column of 'data.csv' should have an integer class label between 0 and the number of classes minus 1.
    - Each row of the second column should have the corresponding text. 
- **Output:** A trained model that can be deployed for inference. 
 
Below is an example of a 'data.csv' file showing values in its first two columns. Note that the file should not have a header.

|   |   |
|---|---|
|0|hide new secretions from the parental units|
|0|contains no wit , only labored gags|
|1|that loves its characters and communicates something rather beautiful about human nature|
|...|...|
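
As a rough illustration of the format described above, the following sketch writes a small 'data.csv' with Python's standard `csv` module and reads it back to count the classes. The helper names `write_training_csv` and `count_classes`, the temporary directory, and the sample rows are assumptions made for this example; they are not part of the JumpStart API.

```python
import csv
import tempfile
from pathlib import Path

# Hypothetical sample rows: (integer class label, text).
rows = [
    (0, "contains no wit , only labored gags"),
    (1, "that loves its characters and communicates something rather beautiful about human nature"),
]


def write_training_csv(rows, directory):
    """Write (label, text) pairs to <directory>/data.csv without a header row."""
    path = Path(directory) / "data.csv"
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        for label, text in rows:
            # First column: integer label. Second column: the text.
            writer.writerow([int(label), text])
    return path


def count_classes(path):
    """Read data.csv back and return the number of distinct labels."""
    with open(path, newline="", encoding="utf-8") as f:
        labels = [int(record[0]) for record in csv.reader(f) if record]
    if min(labels) < 0:
        raise ValueError("class labels must be non-negative integers")
    return len(set(labels))


data_dir = tempfile.mkdtemp()  # stand-in for your local dataset directory
csv_path = write_training_csv(rows, data_dir)
num_classes = count_classes(csv_path)
print(csv_path, num_classes)
```

The `csv` writer quotes any text that contains commas, so sentences such as the first sample row survive a round trip through the file.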
 
source: [TensorFlow Hub](model_url). License: [Apache 2.0 License](https://jumpstart-cache-alpha-us-west-2.s3-us-west-2.amazonaws.com/licenses/Apache-License/LICENSE-2.0.txt).

The SST2 dataset is downloaded from [TensorFlow](https://www.tensorflow.org/datasets/catalog/glue#gluesst2). [Apache 2.0 License](https://jumpstart-cache-prod-us-west-2.s3-us-west-2.amazonaws.com/licenses/Apache-License/LICENSE-2.0.txt). [Dataset Homepage](https://nlp.stanford.edu/sentiment/index.html).
***
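
To make the description of the classification layer more concrete, here is a minimal pure-Python sketch of a randomly initialized layer whose output dimension equals the number of classes found in the labels. It is not the JumpStart training script. The embedding size of 768 is an assumption (suggested by `H-768` in the default model ID), the Gaussian initialization is an arbitrary choice, and the constant vector only stands in for a real text embedding.

```python
import math
import random


def init_classification_layer(embedding_dim, num_classes, seed=0):
    """Return randomly initialized weights (num_classes x embedding_dim) and biases."""
    rng = random.Random(seed)
    weights = [
        [rng.gauss(0.0, 0.02) for _ in range(embedding_dim)]
        for _ in range(num_classes)
    ]
    biases = [rng.gauss(0.0, 0.02) for _ in range(num_classes)]
    return weights, biases


def classify(embedding, weights, biases):
    """Apply the linear layer, then a numerically stable softmax."""
    logits = [
        sum(w * x for w, x in zip(row, embedding)) + b
        for row, b in zip(weights, biases)
    ]
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


labels = [0, 0, 1]              # first column of a hypothetical data.csv
num_classes = len(set(labels))  # output dimension follows the data

weights, biases = init_classification_layer(embedding_dim=768, num_classes=num_classes)
embedding = [0.1] * 768         # stand-in for the Text Embedding model's output
probabilities = classify(embedding, weights, biases)
print(probabilities)
```

Before fine-tuning, these probabilities carry no information about the text. Training adjusts both this layer and the underlying embedding model.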

### Set Training parameters
Now that the setup is complete, we are ready to fine-tune our model. To begin, let us create a [``sagemaker.estimator.Estimator``](https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html) object. This estimator will launch the training job.

There are two kinds of parameters that need to be set for training. 

The first kind are the parameters for the training job. These include the training data path, which is the S3 folder in which the input data is stored.
***
The second kind are the algorithm-specific training hyperparameters.
***

In [None]:
import json
from sagemaker.jumpstart.estimator import JumpStartEstimator
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket


# Sample training data is available in this bucket
training_data_bucket = get_jumpstart_content_bucket()
training_data_prefix = "training-datasets/SST/"

training_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}"

### Start Training
***
We start by creating the estimator object with all the required assets and then launch the training job. Since default hyperparameter values are model-specific, inspect `estimator.hyperparameters()` to view the default values for your selected model.
***
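
One way to override only a few hyperparameters while keeping the rest of the defaults is sketched below. The helper `merge_hyperparameters` and the default values shown are hypothetical; in practice the defaults would come from `estimator.hyperparameters()` for your selected model. Values are converted to strings because the estimator cell in this notebook passes them that way, and rejecting unknown names is a design choice of this sketch rather than JumpStart behavior.

```python
def merge_hyperparameters(defaults, overrides):
    """Return defaults updated with overrides, with all values as strings."""
    unknown = set(overrides) - set(defaults)
    if unknown:
        # Catch typos such as "epoch" before a training job is launched.
        raise KeyError(f"unknown hyperparameters: {sorted(unknown)}")
    merged = dict(defaults)
    merged.update(overrides)
    return {name: str(value) for name, value in merged.items()}


# Hypothetical defaults, for illustration only.
defaults = {"epochs": "5", "batch_size": "32", "learning_rate": "2e-05"}

hyperparameters = merge_hyperparameters(defaults, {"epochs": 1, "batch_size": 64})
print(hyperparameters)
```

The resulting dictionary could then be passed as the `hyperparameters` argument when constructing the estimator.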

In [None]:
estimator = JumpStartEstimator(
    model_id=model_id,
    hyperparameters={"epochs": "1", "batch_size": "64"},
)

In [None]:
# You can now fit the estimator by providing training data to the train channel

estimator.fit({"training": training_dataset_s3_path}, logs=True)

## Deploy & run Inference on the fine-tuned model
***
A trained model does nothing on its own. We now want to use the model to perform inference. For this example, that means predicting the class label of an input sentence.
***

In [None]:
# You can deploy the fine-tuned model to an endpoint directly from the estimator.
predictor = estimator.deploy()

---
Next, we input example sentences for running inference.
These examples are taken from the SST2 dataset downloaded from [TensorFlow](https://www.tensorflow.org/datasets/catalog/glue#gluesst2). [Apache 2.0 License](https://www.apache.org/licenses/LICENSE-2.0). [Dataset Homepage](https://nlp.stanford.edu/sentiment/index.html).

---

In [None]:
text1 = "astonishing ... ( frames ) profound ethical and philosophical questions in the form of dazzling pop entertainment"
text2 = "simply stupid , irrelevant and deeply , truly , bottomlessly cynical "

In [None]:
for text in [text1, text2]:
    query_response = predictor.predict(text)
    print(query_response)
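
The printed response can be reduced to a single predicted label. The sketch below assumes the response is a dictionary with a `"probabilities"` list holding one score per class; that key and the sample values are assumptions, so check the actual output printed above for your model before relying on it. The helper name `top_class` is hypothetical.

```python
def top_class(response):
    """Return (class index, probability) for the highest-scoring class."""
    probabilities = response["probabilities"]
    best_index = max(range(len(probabilities)), key=probabilities.__getitem__)
    return best_index, probabilities[best_index]


# Hypothetical response, standing in for the output of predictor.predict(text).
sample_response = {"probabilities": [0.08, 0.92]}

label, confidence = top_class(sample_response)
print(label, confidence)
```

The returned index corresponds to the integer class label used in the first column of 'data.csv'.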

---
Next, we clean up the deployed endpoint.

---

In [None]:
# Delete the SageMaker endpoint and the attached resources
predictor.delete_predictor()

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.

![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/introduction_to_amazon_algorithms|jumpstart_text_classification|Amazon_JumpStart_Text_Classification.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/introduction_to_amazon_algorithms|jumpstart_text_classification|Amazon_JumpStart_Text_Classification.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/introduction_to_amazon_algorithms|jumpstart_text_classification|Amazon_JumpStart_Text_Classification.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/introduction_to_amazon_algorithms|jumpstart_text_classification|Amazon_JumpStart_Text_Classification.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/introduction_to_amazon_algorithms|jumpstart_text_classification|Amazon_JumpStart_Text_Classification.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/introduction_to_amazon_algorithms|jumpstart_text_classification|Amazon_JumpStart_Text_Classification.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/introduction_to_amazon_algorithms|jumpstart_text_classification|Amazon_JumpStart_Text_Classification.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/introduction_to_amazon_algorithms|jumpstart_text_classification|Amazon_JumpStart_Text_Classification.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/introduction_to_amazon_algorithms|jumpstart_text_classification|Amazon_JumpStart_Text_Classification.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/introduction_to_amazon_algorithms|jumpstart_text_classification|Amazon_JumpStart_Text_Classification.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/introduction_to_amazon_algorithms|jumpstart_text_classification|Amazon_JumpStart_Text_Classification.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/introduction_to_amazon_algorithms|jumpstart_text_classification|Amazon_JumpStart_Text_Classification.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/introduction_to_amazon_algorithms|jumpstart_text_classification|Amazon_JumpStart_Text_Classification.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/introduction_to_amazon_algorithms|jumpstart_text_classification|Amazon_JumpStart_Text_Classification.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/introduction_to_amazon_algorithms|jumpstart_text_classification|Amazon_JumpStart_Text_Classification.ipynb)
