# Introduction to JumpStart - Zero-Shot Text Classification


---

This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. 

![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/introduction_to_amazon_algorithms|jumpstart_zstc|Amazon_JumpStart_Zero_Shot_Text_Classification.ipynb)

---


---

Welcome to Amazon SageMaker JumpStart! You can use JumpStart to solve many machine learning tasks with one click in SageMaker Studio, or through the SageMaker JumpStart API. In this demo notebook, we demonstrate how to use the JumpStart API for zero-shot text classification.

In supervised classification, natural language processing (NLP) models can only classify text that belongs to classes present in the training data. Zero-shot classification is a paradigm in which a model classifies new, unseen examples into classes that were not present in the training data. For example, a text classification model trained to classify New Year's resolution tweets into two classes, ‘career’ and ‘health’, can also be used to assign resolutions to a category such as ‘finance’ that the model was not trained on.

In this notebook, you will learn how to deploy a pre-trained model for zero-shot text classification, run inference, and clean up resources. Furthermore, we benchmark a zero-shot text classification model, BART Large MNLI, on the [New Year's Resolutions dataset](https://data.world/crowdflower/2015-new-years-resolutions).

---




1. [Set Up](#1.-Set-Up)
2. [Select a model](#2.-Select-a-model)
3. [Retrieve Artifacts & Deploy an Endpoint](#3.-Retrieve-Artifacts-&-Deploy-an-Endpoint)
4. [Query endpoint and parse response](#4.-Query-endpoint-and-parse-response)
5. [Benchmarking](#5.-Benchmarking)
6. [Clean up the endpoint](#6.-Clean-up-the-endpoint)

Note: This notebook was tested on an ml.t3.medium instance in Amazon SageMaker Studio with the Python 3 (Data Science) kernel and on an Amazon SageMaker Notebook instance with the conda_python3 kernel.

Note: After you’re done running the notebook, make sure to delete all the resources you created so that billing stops. The code in [Clean up the endpoint](#6.-Clean-up-the-endpoint) deletes the model and endpoint that are created.

### 1. Set Up

---
Before running the notebook, complete the initial setup steps below.

---

In [None]:
!pip install ipywidgets==7.0.0 --quiet
!pip install --upgrade sagemaker

#### Permissions and environment variables

---
To host on Amazon SageMaker, we need to set up and authenticate the use of AWS services. Here, we use the execution role associated with the current notebook as the AWS account role with SageMaker access. 

---

In [None]:
import sagemaker, boto3, json
from sagemaker import get_execution_role

aws_role = get_execution_role()
aws_region = boto3.Session().region_name
sess = sagemaker.Session()

### 2. Select a model

***
You can continue with the default model, or choose a different model from the dropdown generated by the optional cells below. A complete list of SageMaker pre-trained models is also available at [SageMaker pre-trained Models](https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html#). The [default model](https://huggingface.co/facebook/bart-large-mnli) is the base BART Large model further trained on the [MNLI](https://huggingface.co/datasets/multi_nli) dataset.
***

In [None]:
(
    model_id,
    model_version,
) = (
    "huggingface-zstc-facebook-bart-large-mnli",
    "*",
)

***
[Optional] Here, we list all the available zero-shot text classification models and select one for inference.
***

In [None]:
from ipywidgets import Dropdown
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models

filter_value = "task == zstc"
zstc_models = list_jumpstart_models(filter=filter_value)

# display the model-ids in a dropdown to select a model for inference.
model_dropdown = Dropdown(
    options=zstc_models,
    value=model_id,
    description="Select a model",
    style={"description_width": "initial"},
    layout={"width": "max-content"},
)

#### Choose a model for inference

In [None]:
display(model_dropdown)

In [None]:
# model_version="*" fetches the latest version of the model
model_id, model_version = model_dropdown.value, "*"

### 3. Retrieve Artifacts & Deploy an Endpoint

***

Using SageMaker, we can perform inference on the pre-trained model, even without fine-tuning it first on a new dataset. We start by retrieving the `deploy_image_uri`, `deploy_source_uri`, and `model_uri` for the pre-trained model. To host the pre-trained model, we create an instance of [`sagemaker.model.Model`](https://sagemaker.readthedocs.io/en/stable/api/inference/model.html) and deploy it. This may take a few minutes.

***

In [None]:
from sagemaker import image_uris, model_uris, script_uris, hyperparameters, instance_types
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base


endpoint_name = name_from_base(f"jumpstart-example-infer-{model_id}")

# Retrieve the default inference instance type. You can replace it with other instance types compatible with the model.
inference_instance_type = instance_types.retrieve_default(
    region=None, model_id=model_id, model_version=model_version, scope="inference"
)


# Retrieve the inference docker container uri. This is the base HuggingFace container image for the default model above.
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=model_id,
    model_version=model_version,
    instance_type=inference_instance_type,
)

# Retrieve the inference script uri. This includes all dependencies and scripts for model loading, inference handling etc.
deploy_source_uri = script_uris.retrieve(
    model_id=model_id, model_version=model_version, script_scope="inference"
)

# Retrieve the model uri. This includes the pre-trained model and parameters.
model_uri = model_uris.retrieve(
    model_id=model_id, model_version=model_version, model_scope="inference"
)

# Create the SageMaker model instance
model = Model(
    image_uri=deploy_image_uri,
    source_dir=deploy_source_uri,
    model_data=model_uri,
    entry_point="inference.py",  # entry point file in source_dir and present in deploy_source_uri
    role=aws_role,
    predictor_cls=Predictor,
    name=endpoint_name,
)

# Deploy the model. We pass the Predictor class when deploying through the Model class
# so that we can run inference through the SageMaker API.
model_predictor = model.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    predictor_cls=Predictor,
    endpoint_name=endpoint_name,
)

### 4. Query endpoint and parse response

---
The input to the endpoint is a sequence and a set of candidate labels to choose from, passed as a JSON object encoded in `utf-8`. The output of the endpoint is a JSON object containing the predicted labels and their scores.

---

Next, we write helper functions for querying the endpoint and parsing its response.

In [None]:
import matplotlib.pyplot as plt
import numpy as np

sequence, candidate_labels = "one day I will see the world", ["travel", "cooking", "dancing"]

newline = "\n"
bold = "\033[1m"
unbold = "\033[0m"


def query_endpoint(input_json):
    """Query the model predictor."""
    query_response = model_predictor.predict(
        json.dumps(input_json).encode("utf-8"),
        {
            "ContentType": "application/json",
            "Accept": "application/json",
        },
    )
    return query_response


def parse_response(query_response):
    """Parse response and return the labels, scores and the predicted label (highest score)."""
    model_predictions = json.loads(query_response)
    scores = model_predictions["scores"]
    labels = model_predictions["labels"]
    predicted_label_index = np.argmax(scores)
    predicted_label = labels[predicted_label_index]
    return labels, scores, predicted_label


input_query = {"sequences": sequence, "candidate_labels": candidate_labels}
query_response = query_endpoint(input_query)
labels, scores, predicted_label = parse_response(query_response)

print(
    f"Inference:{newline}"
    f"Sequence: {bold}{sequence}{unbold}{newline}"
    f"Labels: {bold}{labels}{unbold}{newline}"
    f"Scores: {bold}{scores}{unbold}{newline}"
    f"Predicted Label: {bold}{predicted_label}{unbold}{newline}"
)
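To make the request and response formats concrete without calling the endpoint, the following sketch builds the same JSON payload and parses a hand-written response. The `mock_response` bytes and their scores are made up for illustration and are not actual model output; only the `labels` and `scores` keys are taken from the parsing code above, and `parse_mock_response` is a hypothetical helper that mirrors `parse_response` without NumPy.

```python
import json

# Build the request body the same way query_endpoint does.
payload = {
    "sequences": "one day I will see the world",
    "candidate_labels": ["travel", "cooking", "dancing"],
}
request_body = json.dumps(payload).encode("utf-8")

# Hypothetical response body; the scores are invented for illustration.
mock_response = json.dumps(
    {
        "labels": ["travel", "dancing", "cooking"],
        "scores": [0.9, 0.06, 0.04],
    }
).encode("utf-8")


def parse_mock_response(body):
    """Return (labels, scores, top_label) from a JSON response body."""
    predictions = json.loads(body)
    labels, scores = predictions["labels"], predictions["scores"]
    # Pick the label whose score is highest.
    top_index = max(range(len(scores)), key=scores.__getitem__)
    return labels, scores, labels[top_index]


mock_labels, mock_scores, mock_top_label = parse_mock_response(mock_response)
print(mock_top_label)
```

Because the predicted label is looked up by position, the `labels` and `scores` lists in a response are assumed to be aligned element by element.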

---
Zero-shot text classification models also support multi-label classification, i.e., predicting multiple labels for a single input. To predict more than one label, set `multi_class` to `True`. By default, this parameter is `False`.

---

In [None]:
ml_sequence, ml_candidate_labels = "one day I will see the world", [
    "travel",
    "cooking",
    "dancing",
    "exploration",
]

ml_input_query = {
    "sequences": ml_sequence,
    "candidate_labels": ml_candidate_labels,
    "multi_class": True,
}
ml_query_response = query_endpoint(ml_input_query)
ml_labels, ml_scores, _ = parse_response(ml_query_response)

# Print the labels returned by the endpoint so that they line up with the scores.
print(
    f"Inference:{newline}"
    f"Sequence: {bold}{ml_sequence}{unbold}{newline}"
    f"Labels: {bold}{ml_labels}{unbold}{newline}"
    f"Multi-label scores: {bold}{ml_scores}{unbold}{newline}"
)
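The difference between the two modes can be illustrated offline. NLI-based zero-shot classifiers such as BART Large MNLI typically score each candidate label by how strongly the input entails a hypothesis about that label. The sketch below assumes the behaviour of the Hugging Face zero-shot pipeline: in single-label mode the entailment logits are normalized across labels with a softmax, while in multi-label mode each label is scored independently as entailment versus contradiction. The logits are made-up numbers, and `softmax` and `nli_logits` are hypothetical names introduced only for this illustration.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# Invented (contradiction, entailment) logits per candidate label.
nli_logits = {
    "travel": (-2.0, 3.0),
    "cooking": (2.5, -1.5),
    "exploration": (-1.0, 2.0),
}
label_names = list(nli_logits)

# Single-label: softmax over the entailment logits of all labels; scores sum to 1.
single_scores = softmax([nli_logits[name][1] for name in label_names])

# Multi-label: softmax over (contradiction, entailment) for each label separately;
# each score is an independent probability, so the scores need not sum to 1.
multi_scores = [softmax(list(nli_logits[name]))[1] for name in label_names]

for name, single, multi in zip(label_names, single_scores, multi_scores):
    print(f"{name:12s} single={single:.3f} multi={multi:.3f}")
```

Under these assumptions, two related labels such as "travel" and "exploration" can both receive high scores in multi-label mode, whereas in single-label mode they compete for the same probability mass.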

### 5. Benchmarking
---
In this section, we will benchmark the BART LARGE MNLI model on the [New Year's Resolutions dataset](https://data.world/crowdflower/2015-new-years-resolutions). We classify each resolution as one of the following categories:

- Health
- Humor
- Personal Growth
- Philanthropy
- Leisure
- Career
- Finance
- Education
- Time Management
---

#### 5.1. Data download and inspection

In [None]:
region = boto3.Session().region_name
s3_bucket = f"sagemaker-solutions-prod-{region}"
key_prefix = "0.2.0/Zero-shot-text-classification/1.0.0/artifacts/data"
sample_tweets_file_name = "jumpstart-soln-zero-shot-text-clf-data.csv"
s3 = boto3.client("s3")

s3.download_file(s3_bucket, f"{key_prefix}/{sample_tweets_file_name}", sample_tweets_file_name)

# Get an overview of the dataset.
# Resolution_Category: the ground-truth label of the tweet
# text: the raw tweet text

import pandas as pd

sample_tweets = pd.read_csv(sample_tweets_file_name)
sample_tweets

In [None]:
# Unique categories
list(sample_tweets["Resolution_Category"].unique())

---
Remap provided categories to the newly defined categories.

---

In [None]:
category_remap = {
    "Health & Fitness": "Health",
    "Recreation & Leisure": "Leisure",
    "Philanthropic": "Philanthropy",
    "Time Management/Organization": "Time Management",
    "Education/Training": "Education",
}

sample_tweets["Resolution_Category"] = sample_tweets["Resolution_Category"].replace(category_remap)

---
Drop the rows labeled Family/Friends/Relationships, because this original category cannot be mapped to a single new category.

---

In [None]:
sample_tweets = sample_tweets.loc[
    sample_tweets["Resolution_Category"] != "Family/Friends/Relationships"
]
sample_tweets = sample_tweets.reset_index(drop=True)
sample_tweets

In [None]:
# Unique categories
classification_categories = list(sample_tweets["Resolution_Category"].unique())
classification_categories
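As a quick sanity check of the remap-and-drop logic above, here is a small pure-Python sketch on a made-up list of raw categories. The list `raw_categories` is illustrative and not taken from the dataset, and `remap_demo` simply repeats the mapping defined earlier.

```python
# Same mapping as category_remap above, repeated so the sketch is self-contained.
remap_demo = {
    "Health & Fitness": "Health",
    "Recreation & Leisure": "Leisure",
    "Philanthropic": "Philanthropy",
    "Time Management/Organization": "Time Management",
    "Education/Training": "Education",
}
dropped_category = "Family/Friends/Relationships"

# Made-up raw labels for illustration.
raw_categories = [
    "Health & Fitness",
    "Career",
    "Family/Friends/Relationships",
    "Philanthropic",
    "Humor",
]

# Drop the unmappable category, then rename the rest (unmapped names pass through).
cleaned_categories = [
    remap_demo.get(category, category)
    for category in raw_categories
    if category != dropped_category
]
print(cleaned_categories)
```

Categories without an entry in the mapping, such as "Career" and "Humor", are kept unchanged, which matches how `DataFrame.replace` treats values that are not keys of the dictionary.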

#### 5.2. Dataset Clean up
---
Before running inference, the text is cleaned by removing links, emojis, and media labels.

---

In [None]:
import re


class TweetPreprocessor:
    @staticmethod
    def remove_links(tweet):
        """Takes a string and removes web links from it"""
        tweet = re.sub(r"http\S+", "", tweet)  # remove http links
        tweet = re.sub(r"bit\.ly/\S+", "", tweet)  # remove bitly links
        tweet = re.sub(r"pic\.twitter\S+", "", tweet)  # remove pic.twitter links
        return tweet

    @staticmethod
    def remove_users(tweet):
        """Takes a string and removes retweet and @user information"""
        tweet = re.sub(r"(RT\s@[A-Za-z]+[A-Za-z0-9-_]+):*", "", tweet)  # remove retweet
        tweet = re.sub(r"(@[A-Za-z]+[A-Za-z0-9-_]+):*", "", tweet)  # remove tweeted at
        return tweet

    @staticmethod
    def remove_hashtags(tweet):
        """Takes a string and removes any hashtags"""
        tweet = re.sub(r"(#[A-Za-z]+[A-Za-z0-9-_]+)", "", tweet)  # remove hashtags
        return tweet

    @staticmethod
    def remove_av(tweet):
        """Takes a string and removes AUDIO/VIDEO tags or labels"""
        # Match case-insensitively because preprocess() lowercases the tweet first.
        tweet = re.sub(r"VIDEO:", "", tweet, flags=re.IGNORECASE)  # remove 'VIDEO:' labels
        tweet = re.sub(r"AUDIO:", "", tweet, flags=re.IGNORECASE)  # remove 'AUDIO:' labels
        return tweet

    @staticmethod
    def preprocess(tweet):
        # tweet = tweet.encode('latin1', 'ignore').decode('latin1')
        tweet = tweet.lower()
        # tweet = TweetPreprocessor.remove_users(tweet)
        tweet = TweetPreprocessor.remove_links(tweet)
        # tweet = TweetPreprocessor.remove_hashtags(tweet)
        tweet = TweetPreprocessor.remove_av(tweet)
        tweet = " ".join(tweet.split())  # Remove extra spaces
        return tweet.strip()

    @staticmethod
    def get_hash_tags(tweet):
        return re.findall(r"#(\w+)", tweet)
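To see what the cleaning does to a single tweet, the following standalone sketch condenses three of the active steps of `preprocess` (lowercasing, link removal, and whitespace collapsing) into one hypothetical function, `clean_tweet_demo`, and applies it to a made-up tweet. It is an illustration of the idea rather than a replacement for the class above.

```python
import re


def clean_tweet_demo(tweet):
    """Lowercase a tweet, strip web links, and collapse repeated whitespace."""
    tweet = tweet.lower()
    tweet = re.sub(r"http\S+", "", tweet)  # remove http/https links
    tweet = re.sub(r"bit\.ly/\S+", "", tweet)  # remove bitly links
    tweet = re.sub(r"pic\.twitter\S+", "", tweet)  # remove pic.twitter links
    return " ".join(tweet.split())  # collapse runs of whitespace


# Made-up example tweet.
demo_tweet = "My #NewYearsResolution:   read MORE books! http://t.co/abc123 pic.twitter.com/xyz"
demo_clean = clean_tweet_demo(demo_tweet)
demo_hashtags = re.findall(r"#(\w+)", demo_clean)

print(demo_clean)
print(demo_hashtags)
```

Hashtags are kept in the cleaned text here, in line with `preprocess`, where the hashtag-removal step is commented out.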

In [None]:
!pip install demoji

In [None]:
import demoji
import boto3
import os
import json

sample_tweets["text_clean"] = sample_tweets["text"].map(
    TweetPreprocessor.preprocess
)  # Preprocess text.
sample_tweets["text_clean"] = sample_tweets["text_clean"].map(demoji.replace)  # Remove emojis.
sample_tweets

#### 5.3. Run inference

---
Select the number of samples you want to use for inference. The estimated time depends on the instance type you choose. On `ml.p3.2xlarge`, running inference on 1000 samples takes roughly 5 minutes.

---

In [None]:
n_samples = 1000
# Copy the slice so that adding the prediction column later does not modify a view.
sample_tweets = sample_tweets.iloc[:n_samples].copy()

In [None]:
sequences = sample_tweets["text_clean"].tolist()

predicted_labels = []
for tweet in sequences:
    endpoint_response = query_endpoint(
        {
            "sequences": tweet,
            "candidate_labels": classification_categories,
        }
    )
    _, _, predicted_label = parse_response(endpoint_response)
    predicted_labels.append(predicted_label)

In [None]:
sample_tweets["zero-shot-class"] = predicted_labels
sample_tweets

#### 5.4. Compute metrics

---

Next, we compare the predicted labels with the ground-truth labels.

---

In [None]:
from sklearn.metrics import classification_report

print(classification_report(sample_tweets["Resolution_Category"], sample_tweets["zero-shot-class"]))
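`classification_report` prints per-class precision, recall, F1 score, and support. As a reminder of what those columns mean, here is a minimal pure-Python sketch that computes them on a handful of made-up labels. The lists `toy_true` and `toy_pred` and the helper `per_class_metrics` are hypothetical and unrelated to the benchmark results.

```python
from collections import Counter


def per_class_metrics(y_true, y_pred):
    """Compute precision, recall, F1, and support for every class."""
    true_pos, false_pos, false_neg = Counter(), Counter(), Counter()
    for truth, pred in zip(y_true, y_pred):
        if truth == pred:
            true_pos[truth] += 1
        else:
            false_pos[pred] += 1  # predicted class gained a wrong hit
            false_neg[truth] += 1  # true class missed one of its samples

    report = {}
    for label in sorted(set(y_true) | set(y_pred)):
        predicted = true_pos[label] + false_pos[label]
        actual = true_pos[label] + false_neg[label]
        precision = true_pos[label] / predicted if predicted else 0.0
        recall = true_pos[label] / actual if actual else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        report[label] = {
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "support": actual,
        }
    return report


# Made-up ground truth and predictions for illustration.
toy_true = ["Health", "Health", "Career", "Finance", "Career", "Health"]
toy_pred = ["Health", "Career", "Career", "Finance", "Health", "Health"]

toy_report = per_class_metrics(toy_true, toy_pred)
for label, metrics in toy_report.items():
    print(label, {name: round(value, 3) for name, value in metrics.items()})
```

Precision answers "of the tweets predicted as this class, how many were right?", while recall answers "of the tweets that truly belong to this class, how many were found?"; support is the number of ground-truth samples in the class.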

### 6. Clean up the endpoint

In [None]:
# Delete the SageMaker model and endpoint
model_predictor.delete_model()
model_predictor.delete_endpoint()

## Notebook CI Test Results

This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.

![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/introduction_to_amazon_algorithms|jumpstart_zstc|Amazon_JumpStart_Zero_Shot_Text_Classification.ipynb)

![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/introduction_to_amazon_algorithms|jumpstart_zstc|Amazon_JumpStart_Zero_Shot_Text_Classification.ipynb)

![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/introduction_to_amazon_algorithms|jumpstart_zstc|Amazon_JumpStart_Zero_Shot_Text_Classification.ipynb)

![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/introduction_to_amazon_algorithms|jumpstart_zstc|Amazon_JumpStart_Zero_Shot_Text_Classification.ipynb)

![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/introduction_to_amazon_algorithms|jumpstart_zstc|Amazon_JumpStart_Zero_Shot_Text_Classification.ipynb)

![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/introduction_to_amazon_algorithms|jumpstart_zstc|Amazon_JumpStart_Zero_Shot_Text_Classification.ipynb)

![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/introduction_to_amazon_algorithms|jumpstart_zstc|Amazon_JumpStart_Zero_Shot_Text_Classification.ipynb)

![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/introduction_to_amazon_algorithms|jumpstart_zstc|Amazon_JumpStart_Zero_Shot_Text_Classification.ipynb)

![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/introduction_to_amazon_algorithms|jumpstart_zstc|Amazon_JumpStart_Zero_Shot_Text_Classification.ipynb)

![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/introduction_to_amazon_algorithms|jumpstart_zstc|Amazon_JumpStart_Zero_Shot_Text_Classification.ipynb)

![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/introduction_to_amazon_algorithms|jumpstart_zstc|Amazon_JumpStart_Zero_Shot_Text_Classification.ipynb)

![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/introduction_to_amazon_algorithms|jumpstart_zstc|Amazon_JumpStart_Zero_Shot_Text_Classification.ipynb)

![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/introduction_to_amazon_algorithms|jumpstart_zstc|Amazon_JumpStart_Zero_Shot_Text_Classification.ipynb)

![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/introduction_to_amazon_algorithms|jumpstart_zstc|Amazon_JumpStart_Zero_Shot_Text_Classification.ipynb)

![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/introduction_to_amazon_algorithms|jumpstart_zstc|Amazon_JumpStart_Zero_Shot_Text_Classification.ipynb)
