# Hugging Face SageMaker - Text-Classification
### Sentiment Analysis with `DistilBERT` and `imdb` dataset

---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.

![ This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/sagemaker-script-mode|pytorch-sagemaker-huggingface|huggingface_text_classification.ipynb)

---

1. [Introduction](#Introduction)  
2. [Environment and Permissions](#Environment-and-Permissions)
3. [Preprocess - Tokenization of the dataset](#Preprocessing)   
4. [Creating an Estimator and start a training job](#Creating-an-Estimator-and-start-a-training-job)  
5. [Deploying the endpoint](#Deploying-the-endpoint)  

# Introduction

Welcome to our end-to-end binary Text-Classification example. This demo uses the Hugging Face `transformers` and `datasets` libraries together with a custom Amazon SageMaker SDK extension to fine-tune a pre-trained transformer for binary text classification. The pre-trained model will be fine-tuned using the `imdb` dataset. The following diagram illustrates what we will do:

![ Architecture](./files/architecture.png)

_**NOTE: You can run this demo in SageMaker Studio, on your local machine, or on SageMaker Notebook Instances.**_

# Environment and Permissions 

In [None]:
!pip install datasets

In [7]:
import sagemaker
import boto3

sess = sagemaker.Session()
# SageMaker session bucket -> used for uploading data, models and logs
# SageMaker will automatically create this bucket if it does not exist
sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="sagemaker_execution_role")["Role"]["Arn"]

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

sagemaker role arn: arn:aws:iam::394934205700:role/service-role/AmazonSageMaker-ExecutionRole-20220201T182213
sagemaker bucket: sagemaker-us-east-1-394934205700
sagemaker session region: us-east-1


## Visualizing our data
We are using the `datasets` library to download the `imdb` dataset. The [imdb](http://ai.stanford.edu/~amaas/data/sentiment/) dataset consists of 25,000 highly polar movie reviews for training and 25,000 for testing.
Let's see what our dataset looks like.

In [None]:
from datasets import load_dataset

train_dataset, test_dataset = load_dataset("imdb", split=["train", "test"])

In [9]:
train_dataset, test_dataset

(Dataset({
     features: ['text', 'label'],
     num_rows: 25000
 }),
 Dataset({
     features: ['text', 'label'],
     num_rows: 25000
 }))

In [32]:
train_dataset[0]

{'text': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far be

# Preprocessing

Before you can train a model on a dataset, the data needs to be preprocessed into the model's expected input format. Whether your data is text, images, or audio, it needs to be converted and assembled into batches of tensors.
For text, use a [Tokenizer](https://huggingface.co/docs/transformers/main_classes/tokenizer) to convert the text into a sequence of tokens, create a numerical representation of the tokens, and assemble them into tensors.
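
To make these three steps concrete, here is a minimal, library-free sketch. It is a toy illustration only: the whitespace splitting, the tiny `vocab`, the special-token ids, and the `encode` helper are all hypothetical and are not what the Hugging Face tokenizer does internally (DistilBERT uses subword tokenization with a much larger vocabulary). It only shows how text becomes fixed-length `input_ids` and an `attention_mask`.

```python
# Toy illustration only: whitespace tokenizer with a tiny, hypothetical vocabulary.
# Real tokenizers split text into subwords and use a far larger vocabulary.
PAD_ID, UNK_ID, CLS_ID, SEP_ID = 0, 1, 2, 3
vocab = {
    "[PAD]": PAD_ID, "[UNK]": UNK_ID, "[CLS]": CLS_ID, "[SEP]": SEP_ID,
    "great": 4, "movie": 5, "terrible": 6, "a": 7,
}


def encode(text, max_length):
    """Convert text into fixed-length input_ids and attention_mask lists."""
    # Step 1: split the text into tokens.
    tokens = text.lower().split()
    # Truncate, reserving two positions for the [CLS] and [SEP] special tokens.
    tokens = tokens[: max_length - 2]
    # Step 2: map each token to its numerical id (unknown words map to [UNK]).
    ids = [CLS_ID] + [vocab.get(t, UNK_ID) for t in tokens] + [SEP_ID]
    attention_mask = [1] * len(ids)
    # Step 3: pad to max_length so every sequence in a batch has the same shape.
    padding = max_length - len(ids)
    ids += [PAD_ID] * padding
    attention_mask += [0] * padding
    return {"input_ids": ids, "attention_mask": attention_mask}


batch = [encode(t, max_length=6) for t in ["A great movie", "Terrible"]]
for row in batch:
    print(row)
```

The `attention_mask` marks real tokens with 1 and padding with 0, so the model can ignore the padded positions.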

## Tokenization 

In [None]:
from sagemaker import get_execution_role
from sagemaker.pytorch.processing import PyTorchProcessor

pytorch_processor = PyTorchProcessor(
    role=role,
    instance_type="ml.m5.xlarge",
    instance_count=1,
    framework_version="1.13",
    py_version="py39",
)

In [None]:
from sagemaker.processing import ProcessingInput, ProcessingOutput
from time import gmtime, strftime

s3_prefix = "huggingface-text-classification"
processing_job_name = "{}-{}".format(s3_prefix, strftime("%d-%H-%M-%S", gmtime()))
output_destination = "s3://{}/{}".format(sess.default_bucket(), s3_prefix)

pytorch_processor.run(
    code="preprocessing.py",
    source_dir="scripts/preprocess",
    job_name=processing_job_name,
    outputs=[
        ProcessingOutput(
            output_name="train",
            destination="{}/train".format(output_destination),
            source="/opt/ml/processing/train",
        ),
        ProcessingOutput(
            output_name="test",
            destination="{}/test".format(output_destination),
            source="/opt/ml/processing/test",
        ),
    ],
)

In [None]:
preprocessing_job_description = pytorch_processor.jobs[-1].describe()
preprocessing_job_description

### Visualizing our processed dataset
Let's load our tokenized dataset and see what it looks like.

In [43]:
from datasets import load_from_disk

s3_prefix = "huggingface-text-classification"
processed_train_dataset = load_from_disk(
    "s3://{}/{}/train".format(sess.default_bucket(), s3_prefix)
)
processed_train_dataset

Dataset({
    features: ['labels', 'input_ids', 'attention_mask'],
    num_rows: 25000
})

In [44]:
processed_train_dataset[0]

{'labels': tensor(0),
 'input_ids': tensor([  101,  1045, 12524,  1045,  2572,  8025,  1011,  3756,  2013,  2026,
          2678,  3573,  2138,  1997,  2035,  1996,  6704,  2008,  5129,  2009,
          2043,  2009,  2001,  2034,  2207,  1999,  3476,  1012,  1045,  2036,
          2657,  2008,  2012,  2034,  2009,  2001,  8243,  2011,  1057,  1012,
          1055,  1012,  8205,  2065,  2009,  2412,  2699,  2000,  4607,  2023,
          2406,  1010,  3568,  2108,  1037,  5470,  1997,  3152,  2641,  1000,
          6801,  1000,  1045,  2428,  2018,  2000,  2156,  2023,  2005,  2870,
          1012,  1026,  7987,  1013,  1028,  1026,  7987,  1013,  1028,  1996,
          5436,  2003,  8857,  2105,  1037,  2402,  4467,  3689,  3076,  2315,
         14229,  2040,  4122,  2000,  4553,  2673,  2016,  2064,  2055,  2166,
          1012,  1999,  3327,  2016,  4122,  2000,  3579,  2014,  3086,  2015,
          2000,  2437,  2070,  4066,  1997,  4516,  2006,  2054,  1996,  2779,
         25430, 1

## Creating an Estimator and start a training job

In [21]:
from sagemaker.huggingface import HuggingFace

training_input_path = "s3://{}/{}/train".format(sess.default_bucket(), s3_prefix)
test_input_path = "s3://{}/{}/test".format(sess.default_bucket(), s3_prefix)

# hyperparameters, which are passed into the training job
hyperparameters = {
    "epochs": 1,
    "train_batch_size": 32,
    "model_name": "distilbert-base-uncased",
    "learning_rate": 0.00003,
}
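
In script mode, SageMaker passes these hyperparameters to the entry point as command-line arguments and exposes the input channels and the model directory through environment variables such as `SM_CHANNEL_TRAIN`, `SM_CHANNEL_TEST`, and `SM_MODEL_DIR`. The contents of `train.py` are not shown in this notebook, so the following is only a sketch of how such a script might read them. The `parse_args` helper, the argument names `--training_dir` and `--test_dir`, and the default values are assumptions.

```python
import argparse
import os


def parse_args(argv=None):
    """Hypothetical argument parsing for a SageMaker training entry point."""
    parser = argparse.ArgumentParser()

    # Hyperparameters arrive as command-line flags, e.g. `--epochs 1`.
    # The defaults here are placeholders, not values taken from the notebook.
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--train_batch_size", type=int, default=32)
    parser.add_argument("--model_name", type=str, default="distilbert-base-uncased")
    parser.add_argument("--learning_rate", type=float, default=5e-5)

    # Data channels and the model output directory come from environment
    # variables set inside the training container.
    parser.add_argument("--model_dir", type=str,
                        default=os.environ.get("SM_MODEL_DIR", "/opt/ml/model"))
    parser.add_argument("--training_dir", type=str,
                        default=os.environ.get("SM_CHANNEL_TRAIN", "/opt/ml/input/data/train"))
    parser.add_argument("--test_dir", type=str,
                        default=os.environ.get("SM_CHANNEL_TEST", "/opt/ml/input/data/test"))

    # parse_known_args ignores any extra flags the script does not define.
    args, _ = parser.parse_known_args(argv)
    return args


# Simulate the flags the training job would receive for the dictionary above.
args = parse_args([
    "--epochs", "1",
    "--train_batch_size", "32",
    "--model_name", "distilbert-base-uncased",
    "--learning_rate", "3e-05",
])
print(args.epochs, args.train_batch_size, args.model_name, args.learning_rate)
```

The channel names `train` and `test` in the environment variables correspond to the keys of the dictionary passed to `fit()`.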

In [22]:
huggingface_estimator = HuggingFace(
    entry_point="train.py",
    source_dir="./scripts/train",
    instance_type="ml.p3.2xlarge",
    instance_count=1,
    role=role,
    transformers_version="4.26",
    pytorch_version="1.13",
    py_version="py39",
    hyperparameters=hyperparameters,
)

In [None]:
# start the training job with our uploaded datasets as input
huggingface_estimator.fit({"train": training_input_path, "test": test_input_path})

## Deploying the endpoint

To deploy our endpoint, we call `deploy()` on our HuggingFace estimator object, passing in our desired number of instances and instance type.

In [None]:
predictor = huggingface_estimator.deploy(initial_instance_count=1, instance_type="ml.g4dn.xlarge")

Then, we use the returned predictor object to call the endpoint.

In [27]:
sentiment_input = {
    "inputs": "This is the best movie ever made in history, an absolute sculpted work of art that depicts every emotion of human existence, from suffering, to courage to love, in front of the background of political astuteness and socio-hierarchal analysis."
}

predictor.predict(sentiment_input)

[{'label': 'LABEL_1', 'score': 0.984623908996582}]

In [28]:
sentiment_input = {
    "inputs": "Another bloated film that gets all the history wrong, turns all of the characters into stick figures and makes piles of money for the star."
}

predictor.predict(sentiment_input)

[{'label': 'LABEL_0', 'score': 0.9781230092048645}]
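
The endpoint returns generic label names. Assuming the model keeps the `imdb` label encoding (0 for negative, 1 for positive), which is consistent with the two responses above, a small helper can make the output easier to read. The `LABEL_NAMES` mapping and the `readable` function are hypothetical and not part of the SageMaker or Hugging Face APIs.

```python
# Assumption: LABEL_0 / LABEL_1 follow the imdb encoding (0 = negative, 1 = positive).
LABEL_NAMES = {"LABEL_0": "negative", "LABEL_1": "positive"}


def readable(predictions):
    """Replace generic label names with sentiment names, keeping the scores."""
    return [
        # Unknown labels are passed through unchanged.
        {"label": LABEL_NAMES.get(p["label"], p["label"]), "score": p["score"]}
        for p in predictions
    ]


# Example using the response shape returned by predictor.predict(...) above.
response = [{"label": "LABEL_1", "score": 0.984623908996582}]
print(readable(response))
```

Alternatively, setting `id2label` in the model configuration during training would make the endpoint return the sentiment names directly.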

Finally, we delete the model and the endpoint so that they no longer incur charges.

In [None]:
predictor.delete_model()
predictor.delete_endpoint()

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.


![ This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/sagemaker-script-mode|pytorch-sagemaker-huggingface|huggingface_text_classification.ipynb)

![ This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/sagemaker-script-mode|pytorch-sagemaker-huggingface|huggingface_text_classification.ipynb)

![ This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/sagemaker-script-mode|pytorch-sagemaker-huggingface|huggingface_text_classification.ipynb)

![ This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/sagemaker-script-mode|pytorch-sagemaker-huggingface|huggingface_text_classification.ipynb)

![ This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/sagemaker-script-mode|pytorch-sagemaker-huggingface|huggingface_text_classification.ipynb)

![ This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/sagemaker-script-mode|pytorch-sagemaker-huggingface|huggingface_text_classification.ipynb)

![ This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/sagemaker-script-mode|pytorch-sagemaker-huggingface|huggingface_text_classification.ipynb)

![ This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/sagemaker-script-mode|pytorch-sagemaker-huggingface|huggingface_text_classification.ipynb)

![ This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/sagemaker-script-mode|pytorch-sagemaker-huggingface|huggingface_text_classification.ipynb)

![ This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/sagemaker-script-mode|pytorch-sagemaker-huggingface|huggingface_text_classification.ipynb)

![ This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/sagemaker-script-mode|pytorch-sagemaker-huggingface|huggingface_text_classification.ipynb)

![ This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/sagemaker-script-mode|pytorch-sagemaker-huggingface|huggingface_text_classification.ipynb)

![ This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/sagemaker-script-mode|pytorch-sagemaker-huggingface|huggingface_text_classification.ipynb)

![ This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/sagemaker-script-mode|pytorch-sagemaker-huggingface|huggingface_text_classification.ipynb)

![ This badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/sagemaker-script-mode|pytorch-sagemaker-huggingface|huggingface_text_classification.ipynb)
