# Zero Shot Text Classification Example

In this tutorial, we walk through the text classification pipeline in Forte using a zero-shot text classification model from Hugging Face. We also explore simple ways to improve the model's accuracy.

## Introduction
Zero-shot text classification is an extreme form of transfer learning: the model makes predictions without any fine-tuning on the target task. Such a model predicts a class from how well the input text matches each label name (for an NLI model like the one used here, whether the text entails a statement built from the label), so it needs no labeled data. Well-defined label names are enough to get a reasonably performing model. We use [valhalla/distilbart-mnli-12-1](https://huggingface.co/valhalla/distilbart-mnli-12-1), a model trained on the [Multi-NLI](https://huggingface.co/datasets/multi_nli) dataset, as our text classifier, and apply it to a subset of the Amazon review sentiment ([ARS](https://s3.amazonaws.com/fast-ai-nlp/amazon_review_polarity_csv.tgz)) dataset and to the [Banking77](https://raw.githubusercontent.com/PolyAI-LDN/task-specific-datasets/master/banking_data/test.csv) dataset. We will also see how the Forte pipeline works seamlessly with third-party tools such as [NLTK](https://www.nltk.org/) and [Hugging Face](https://huggingface.co/) to make our life easier.
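To make the matching step concrete: an NLI-based zero-shot classifier pairs the input text (the premise) with one hypothesis per candidate label and asks the NLI model how strongly the premise entails each hypothesis. The sketch below only builds those pairs and calls no model. It assumes the Hugging Face default template `"This example is {}."`; the helper name `build_nli_pairs` is hypothetical.

```python
def build_nli_pairs(text, candidate_labels, template="This example is {}."):
    """Pair the input text (premise) with one hypothesis per candidate label.

    Each (premise, hypothesis) pair would be scored by an NLI model such as
    valhalla/distilbart-mnli-12-1; the label whose hypothesis is most strongly
    entailed wins. This sketch only constructs the pairs.
    """
    return [(text, template.format(label)) for label in candidate_labels]


pairs = build_nli_pairs("Would Definitely use this vendor again!!", ["Negative", "Positive"])
for premise, hypothesis in pairs:
    print(premise, "=>", hypothesis)
```

This is also why label wording matters so much: the label is literally inserted into a sentence that the NLI model reads.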

In [1]:
```python
import os
from termcolor import colored
from forte.data.readers import ClassificationDatasetReader
from fortex.huggingface import ZeroShotClassifier
from forte.pipeline import Pipeline
from fortex.nltk import NLTKSentenceSegmenter
from ft.onto.base_ontology import Sentence
from ft.onto.base_ontology import Document
import pandas as pd
from sklearn.metrics import accuracy_score
from collections import Counter
```



## Pipeline Initialization
The code below initializes the hyperparameters and the pipeline.
- `csv_path`: the path of the input CSV file. In this example, it is a small subset of the ARS test data.
- `class_names`: the candidate label names. This is the most important hyperparameter in zero-shot text classification, so choose label names that are meaningful.
- `index2class`: maps each numerical class index to its class name.
- [`ClassificationDatasetReader_config`](https://github.com/asyml/forte/blob/63577ce88fd5a42c3c5930f27181167c92ab2057/forte/data/readers/classification_reader.py): the reader configuration, which takes three essential fields:
    - `forte_data_fields`: the columns of the CSV file, in order, each given either as `label` or as the Forte ontology type that stores it.
    - `index2class`: the `index2class` dictionary defined above.
    - `text_fields`: the column(s) whose text the pipeline uses to make the prediction.
- `GPU`: the index of the GPU on which the model runs. Set it to -1 to run on the CPU only.
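To make the three reader fields concrete, here is a plain-Python sketch of what they describe for one CSV row of the ARS sample. It is an illustration, not Forte's implementation: `map_row` is a hypothetical helper, and the one-based label offset is an assumption based on ARS using labels 1 and 2, with label 1 printed as Negative in the outputs of this notebook.

```python
import csv
import io

# Same values as in the reader config above.
forte_data_fields = ["label", "ft.onto.base_ontology.Title", "ft.onto.base_ontology.Body"]
text_fields = ["ft.onto.base_ontology.Body"]
index2class = {0: "Negative", 1: "Positive"}


def map_row(row, one_based=True):
    """Hypothetical helper: pair CSV columns with field names by position,
    resolve the numeric label through index2class, and keep only the text
    that is used for prediction."""
    fields = dict(zip(forte_data_fields, row))
    index = int(fields["label"]) - (1 if one_based else 0)  # assumed 1-based labels
    text = " ".join(fields[name] for name in text_fields)
    return index2class[index], text


sample = (
    "label,Title,description\n"
    "2,Wow - What Service!,I received this product less than 24 hours after "
    "I ordered it - in Excellent condition.\n"
)
reader = csv.reader(io.StringIO(sample))
next(reader)  # skip the header line
for row in reader:
    print(map_row(row))
```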

In [2]:
```python
# path of the input CSV file
csv_path = "../../data_samples/amazon_review_polarity_csv/amazon_sample_4.csv"

# label names
Label_1 = 'Negative'
Label_2 = 'Positive'
class_names = [Label_1, Label_2]

# numerical representation of class names
index2class = {0: Label_1, 1: Label_2}

# config for ClassificationDatasetReader()
ClassificationDatasetReader_config = {
    "forte_data_fields": [
        "label",
        "ft.onto.base_ontology.Title",
        "ft.onto.base_ontology.Body",
    ],
    "index2class": index2class,
    "text_fields": [
        "ft.onto.base_ontology.Body"
    ],
}

# CUDA device index (-1 for CPU)
GPU = 0

# pipeline components
pl = Pipeline()
pl.set_reader(ClassificationDatasetReader(), config=ClassificationDatasetReader_config)
pl.add(NLTKSentenceSegmenter())  # splits each text into sentences
pl.add(ZeroShotClassifier(), config={"candidate_labels": class_names, "cuda_device": GPU})
pl.initialize();
```

[nltk_data] Downloading package punkt to
[nltk_data]     /home/bhaskar.rao/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## Data Visualization
We use the pandas library to read the CSV file and inspect the input data.

In [3]:
```python
df = pd.read_csv(csv_path)
df.head()
```

| | label | Title | description |
|---|---|---|---|
| 0 | 1 | THIS IS A USED BATTERY | States this in a new battery but IT IS USED! C... |
| 1 | 2 | Wow - What Service! | I received this product less than 24 hours aft... |
| 2 | 1 | Very boring-gave up on it | Cook reintroduced some old characters using th... |
| 3 | 1 | not VISTA compatible | the cord works but the driver does NOT work fo... |


In [4]:
```python
# description text of the first row
df['description'][0]
```

'States this in a new battery but IT IS USED! Contacts are worn, dints in the plastic and scratches.AccessoryOneCondition Seller Information Ready to buy?$6.96+ $2.98shippingNew'

## Prediction
We now run the pipeline on the sample and print a prediction for each sentence of the description column. `NLTKSentenceSegmenter()` splits the text of each row into individual sentences, and `ZeroShotClassifier()` then makes a prediction for each sentence. The number beside each label name is the model's confidence score for that label.
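The scores in the output below do not sum to 1, which suggests that each label is scored independently (what Hugging Face calls multi-label mode) rather than normalized across labels. A minimal sketch of both normalizations, assuming an NLI model that returns (contradiction, neutral, entailment) logits for each (text, label) pair; the logit values are made up for illustration.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical NLI logits per label: (contradiction, neutral, entailment).
logits = {
    "Positive": (-2.0, 0.5, 3.0),
    "Negative": (2.5, 0.3, -1.5),
}

# Independent (multi-label) scoring: for each label, softmax over
# [contradiction, entailment] and keep the entailment probability.
independent = {label: softmax([c, e])[1] for label, (c, _, e) in logits.items()}

# Mutually exclusive (single-label) scoring: softmax over the entailment
# logits of all labels, so the scores sum to 1.
exclusive = dict(zip(logits, softmax([e for _, _, e in logits.values()])))

print({k: round(v, 4) for k, v in independent.items()})
print({k: round(v, 4) for k, v in exclusive.items()})
```

With independent scoring, a sentence can get low scores for every label (as with "States this in a new battery but IT IS USED!" below), which a normalized score would hide.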

In [5]:
```python
for pack in pl.process_dataset(csv_path):
    for sent in pack.get(Sentence):
        sent_text = sent.text
        print(colored("Sentence:", "red"), sent_text)
        print(colored("Prediction:", "blue"), sent.classification, "\n")
    print('-----------------------------------------------------------')
```

```
Sentence: States this in a new battery but IT IS USED!
Prediction: {'Positive': 0.0789, 'Negative': 0.0529}

Sentence: Contacts are worn, dints in the plastic and scratches.AccessoryOneCondition Seller Information Ready to buy?$6.96+ $2.98shippingNew
Prediction: {'Negative': 0.084, 'Positive': 0.013}

-----------------------------------------------------------
Sentence: I received this product less than 24 hours after I ordered it - in Excellent condition.
Prediction: {'Positive': 0.9996, 'Negative': 0.0008}

Sentence: Would Definitely use this vendor again!!
Prediction: {'Positive': 0.978, 'Negative': 0.0003}

Sentence: Thanks.
Prediction: {'Positive': 0.2642, 'Negative': 0.0941}

-----------------------------------------------------------
Sentence: Cook reintroduced some old characters using the same story line descriptions.
Prediction: {'Positive': 0.0302, 'Negative': 0.0
```

In [6]:
```python
def get_class(prediction):
    '''Take a dictionary mapping class names to probability scores
    and return the class name with the highest score.'''
    return max(prediction, key=prediction.get)


def aggregate_class_score(pack):
    '''Predict a class for every sentence in the pack, then return the most
    frequent class and its score: the fraction of sentences that predicted it.
    Ties go to the class that appears first.'''
    prediction_list = [get_class(sent.classification) for sent in pack.get(Sentence)]

    cnt = Counter(prediction_list)
    predicted_class = cnt.most_common(1)[0][0]
    predicted_score = cnt[predicted_class] / len(prediction_list)

    return (predicted_class, predicted_score)
```

## Aggregate Sentences
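In the cell above, we define two functions, `get_class` and `aggregate_class_score`, that combine the predictions for the individual sentences of a row into one prediction for the whole text: each sentence votes for its highest-scoring class, the most frequent class wins, and its confidence score is the fraction of sentences that voted for it.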
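The same majority vote can be exercised without a Forte pipeline by feeding it the per-sentence score dictionaries printed earlier. The sketch below reimplements it on plain dicts (`vote` and `mean_vote` are hypothetical names). It also shows a subtlety: with a 1-1 tie, `Counter.most_common` keeps the label seen first, which is why the first review below is labeled Positive with confidence 0.5. `mean_vote` is a swapped-in alternative that this tutorial does not use: it averages each label's score across sentences instead of counting votes.

```python
from collections import Counter


def vote(sentence_scores):
    """Majority vote over per-sentence predictions.

    sentence_scores: list of {label: score} dicts, one per sentence.
    Returns (winning label, fraction of sentences that chose it).
    """
    winners = [max(scores, key=scores.get) for scores in sentence_scores]
    label, count = Counter(winners).most_common(1)[0]  # ties: first seen wins
    return label, count / len(winners)


def mean_vote(sentence_scores):
    """Alternative aggregation: pick the label with the highest mean score."""
    labels = sentence_scores[0].keys()
    means = {
        label: sum(scores[label] for scores in sentence_scores) / len(sentence_scores)
        for label in labels
    }
    return max(means, key=means.get)


# Scores printed for the two sentences of the first review.
first_review = [
    {"Positive": 0.0789, "Negative": 0.0529},
    {"Negative": 0.084, "Positive": 0.013},
]
print(vote(first_review))       # ('Positive', 0.5)
print(mean_vote(first_review))  # 'Negative'
```

For this one review, averaging would have matched the Negative ground truth, but a single example says nothing about which aggregation is better overall.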

In [7]:
```python
y_true = []
y_pred = []
for pack in pl.process_dataset(csv_path):
    yt = next(pack.get(Document)).document_class[0]
    yp, ys = aggregate_class_score(pack)
    y_true.append(yt)
    y_pred.append(yp)
    print(pack.text, '\n')
    print(colored('ground_truth : ', 'green'),
          yt, colored('     predicted_class : ', 'blue'),
          yp, colored(' confidence score : ', 'red'), ys)
    print('-----------------------------------------------------------------------------------', '\n')
```

```
States this in a new battery but IT IS USED! Contacts are worn, dints in the plastic and scratches.AccessoryOneCondition Seller Information Ready to buy?$6.96+ $2.98shippingNew

ground_truth :  Negative      predicted_class :  Positive  confidence score :  0.5
-----------------------------------------------------------------------------------

I received this product less than 24 hours after I ordered it - in Excellent condition. Would Definitely use this vendor again!! Thanks.

ground_truth :  Positive      predicted_class :  Positive  confidence score :  1.0
-----------------------------------------------------------------------------------

Cook reintroduced some old characters using the same story line descriptions. I felt I was rereading a previously read novel. Editors missed many writing errors which made for awkward sentences. I finished only half the book and gave up

ground_truth :  Negative      predict
```

## Accuracy
We use scikit-learn to evaluate accuracy. The previous cell stored the ground truths and the predictions in two lists; `accuracy_score` takes these two lists as input and returns the fraction of predictions that match the ground truth.
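For reference, this is all `accuracy_score` computes for two plain label lists without sample weights: the fraction of positions where they agree. A dependency-free sketch; `simple_accuracy` is a hypothetical name and the example lists are made up.

```python
def simple_accuracy(y_true, y_pred):
    """Fraction of predictions that match the ground truth."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    matches = sum(t == p for t, p in zip(y_true, y_pred))
    return matches / len(y_true)


# Hypothetical labels: two of the four predictions are correct.
print(simple_accuracy(["Negative", "Positive", "Negative", "Negative"],
                      ["Positive", "Positive", "Negative", "Positive"]))  # 0.5
```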

In [8]:
```python
accuracy_score(y_true, y_pred)
```

0.5

## Experiment
Here we explore one way to improve the model's accuracy on a given dataset. Because the zero-shot model makes predictions based on how well the input sentence matches each class name, we can try different class names with meanings similar to the original ones and find out which names work best for our dataset.

In [9]:
```python
def accuracy(csv_path, Label_1='negative', Label_2='positive', cuda_device=GPU):
    '''Build the pipeline for the given class names, run it on csv_path,
    and return the accuracy of the aggregated predictions.
    cuda_device defaults to the GPU value set earlier (-1 for CPU).'''
    class_names = [Label_1, Label_2]
    index2class = {0: Label_1, 1: Label_2}
    reader_config = {
        "forte_data_fields": [
            "label",
            "ft.onto.base_ontology.Title",
            "ft.onto.base_ontology.Body",
        ],
        "index2class": index2class,
        "text_fields": [
            "ft.onto.base_ontology.Body"
        ],
    }
    pl = Pipeline()
    pl.set_reader(ClassificationDatasetReader(), config=reader_config)
    pl.add(NLTKSentenceSegmenter())
    pl.add(ZeroShotClassifier(), config={"candidate_labels": class_names, "cuda_device": cuda_device})
    pl.initialize()

    y_true = []
    y_pred = []
    for pack in pl.process_dataset(csv_path):
        y_true.append(next(pack.get(Document)).document_class[0])
        y_pred.append(aggregate_class_score(pack)[0])
    return accuracy_score(y_true, y_pred)
```

In [10]:
```python
# A list of similar class names, grouped into (negative, positive) tuples.
# Keep the order consistent, negative word first and positive word second,
# because the first name is mapped to class index 0 (Negative).
# Since we compare accuracy across class names, we use a larger sample
# to get more stable accuracy estimates.
class_name_list = [('negative', 'positive'),
                   ('bad', 'good'),
                   ('unsatisfied', 'satisfied'),
                   ('unhappy', 'happy')]

# On a GPU, switch to the 10k-review sample; on the CPU, keep the small
# sample used above so the runtime stays low.
if GPU != -1:
    csv_path = "../../data_samples/amazon_review_polarity_csv/amazon_sample_10k.csv"
```


In [11]:
```python
for negative, positive in class_name_list:
    print((negative, positive), ' accuracy : ',
          accuracy(csv_path, Label_1=negative, Label_2=positive), '\n')
```



('negative', 'positive')  accuracy :  0.8603 





('bad', 'good')  accuracy :  0.765 





('unsatisfied', 'satisfied')  accuracy :  0.8651 





('unhappy', 'happy')  accuracy :  0.8022 



An accuracy of 86.03% on unseen data, with no fine-tuning at all, is strong performance for a zero-shot Hugging Face model, which makes zero-shot classification a useful tool when little labeled data is available. Choosing class names carefully can improve it further: ('unsatisfied', 'satisfied') scores slightly higher (86.51%) than ('negative', 'positive'), while ('bad', 'good') drops to 76.5%.
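Because the best two label pairs differ by less than one percentage point, it helps to translate the gap into review counts. A short sketch using the accuracies reported above, assuming the 10k sample holds exactly 10,000 reviews (inferred from the file name). The binomial standard error is only a rough guide, since every run is scored on the same reviews.

```python
import math

# Accuracies reported above for each (negative, positive) label pair.
results = {
    ("negative", "positive"): 0.8603,
    ("bad", "good"): 0.765,
    ("unsatisfied", "satisfied"): 0.8651,
    ("unhappy", "happy"): 0.8022,
}
n_reviews = 10_000  # assumption: size of amazon_sample_10k.csv

ranked = sorted(results.items(), key=lambda item: item[1], reverse=True)
for labels, acc in ranked:
    print(labels, f"{acc:.2%}")

best_labels, best_acc = ranked[0]
baseline = results[("negative", "positive")]
extra_correct = round((best_acc - baseline) * n_reviews)
std_err = math.sqrt(baseline * (1 - baseline) / n_reviews)
print(f"{best_labels} gets about {extra_correct} more reviews right;"
      f" one standard error is about {std_err:.2%}")
```

The gain is roughly 1.4 standard errors, so it is suggestive rather than conclusive; a paired comparison on the same reviews (for example McNemar's test) would be a sounder check.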

# Multiclass Classification
In this section, we perform multiclass classification with the same model on the Banking77 dataset.

In [12]:
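```python
# local Banking77 sample that the pipeline will classify
csv_path = "../../data_samples/banking77/sample.csv"

# full Banking77 test set, used to inspect the data and collect the class names
url = 'https://raw.githubusercontent.com/PolyAI-LDN/task-specific-datasets/master/banking_data/test.csv'
df = pd.read_csv(url)
df.sample(5)
```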

| | text | category |
|---|---|---|
| 1764 | There's a refund missing from my statement | Refund_not_showing_up |
| 1082 | I saw a payment i did not do | card_payment_not_recognised |
| 3007 | Is there a specific type you need for identity... | verify_my_identity |
| 1441 | Which ATMs accept this card? | atm_support |
| 1440 | Do all ATMs take this card? | atm_support |


We extract the class names, the unique values of the `category` column, from the Banking77 test set.
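The Banking77 category names are snake_case identifiers such as `card_delivery_estimate`, and at least one ends in a stray `?` (`reverted_card_payment?`). Since the zero-shot model reads the label wording, one variation worth trying, in the spirit of the class-name experiment above, is to turn them into plain phrases and map predictions back afterwards. This is a hypothetical preprocessing step that the tutorial does not evaluate; `humanize` is an illustrative name.

```python
def humanize(label):
    """Turn a snake_case category into a plain lowercase phrase,
    e.g. 'card_arrival' -> 'card arrival'."""
    return label.replace("_", " ").rstrip("?").strip().lower()


categories = ["Refund_not_showing_up", "card_delivery_estimate", "reverted_card_payment?"]
phrase2category = {humanize(c): c for c in categories}
print(list(phrase2category))

# After classifying with the phrases as candidate labels, map the winner back:
print(phrase2category["card delivery estimate"])
```

If two categories ever collapsed to the same phrase, the dictionary would silently keep only one, so it is worth checking that `len(phrase2category)` equals the number of categories.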

In [13]:
```python
# unique category names, converted from a NumPy array to a plain list
class_names = df['category'].unique().tolist()
```

In [14]:
```python
# map each class index to its category name
index2class = dict(enumerate(class_names))
```

Here the reader config sets two additional parameters:
- `digit_label`: set to `False`, because the labels in this dataset are category names rather than numbers.
- `one_based_index_label`: whether numeric labels start from 1 instead of 0; set to `False` here.
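The two flags decide how a raw label cell becomes a class name. The sketch below is our reading of the flags as described above, not Forte's code; `resolve_label` is a hypothetical helper. It covers both the numeric, 1-based ARS labels and the text labels of Banking77.

```python
def resolve_label(raw, index2class, digit_label=True, one_based_index_label=True):
    """Hypothetical helper: convert a raw CSV label cell into a class name."""
    if not digit_label:
        # Text labels (e.g. Banking77 categories) are already class names.
        if raw not in index2class.values():
            raise ValueError(f"unknown class name: {raw!r}")
        return raw
    # Numeric labels: shift to 0-based before looking up the class name.
    index = int(raw) - (1 if one_based_index_label else 0)
    return index2class[index]


ars = {0: "Negative", 1: "Positive"}
print(resolve_label("1", ars))  # ARS label 1 -> index 0 -> 'Negative'

banking = dict(enumerate(["card_arrival", "atm_support"]))
print(resolve_label("atm_support", banking, digit_label=False, one_based_index_label=False))
```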

In [15]:
```python
this_reader_config = {
    "forte_data_fields": [
        "ft.onto.base_ontology.Body",
        "label",
    ],
    "index2class": index2class,
    "text_fields": [
        "ft.onto.base_ontology.Body"
    ],
    "digit_label": False,
    "one_based_index_label": False,
}

# CUDA device index (-1 for CPU)
GPU = 1
```

In [16]:
```python
pl = Pipeline()
pl.set_reader(ClassificationDatasetReader(), config=this_reader_config)
pl.add(NLTKSentenceSegmenter())
pl.add(ZeroShotClassifier(), config={"candidate_labels": class_names, "cuda_device": GPU})
pl.initialize();
```



The code below scores each sentence against all 77 class names that we provided to the model.
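With 77 labels, each score dictionary is long, and usually only the top few candidates matter. A small sketch that ranks a prediction dictionary, using the first few scores printed for "When will I get my card?" (the remaining labels are omitted); `top_k` is a hypothetical helper.

```python
import heapq


def top_k(prediction, k=3):
    """Return the k highest-scoring (label, score) pairs, best first."""
    return heapq.nlargest(k, prediction.items(), key=lambda item: item[1])


# First entries of the prediction printed for "When will I get my card?".
prediction = {
    "card_delivery_estimate": 0.8556,
    "pending_card_payment": 0.7886,
    "card_arrival": 0.5471,
    "lost_or_stolen_card": 0.4822,
}
for label, score in top_k(prediction):
    print(f"{label:<25} {score:.4f}")
```

Showing the top three makes close runners-up, such as `pending_card_payment` here, easy to spot.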

In [17]:
```python
for pack in pl.process_dataset(csv_path):
    for sent in pack.get(Sentence):
        sent_text = sent.text
        print(colored("Sentence:", "red"), sent_text, "\n")
        print(colored("Prediction:", "blue"), sent.classification, '\n')
```

```
Sentence: How do I locate my card?

Prediction: {'lost_or_stolen_card': 0.4794, 'compromised_card': 0.3402, 'get_physical_card': 0.3204, 'card_linking': 0.2052, 'card_acceptance': 0.1594, 'passcode_forgotten': 0.1419, 'getting_spare_card': 0.139, 'order_physical_card': 0.1389, 'reverted_card_payment?': 0.1365, 'getting_virtual_card': 0.1057, 'virtual_card_not_working': 0.0658, 'why_verify_identity': 0.0624, 'card_not_working': 0.0616, 'get_disposable_virtual_card': 0.0608, 'card_swallowed': 0.0604, 'supported_cards_and_currencies': 0.0602, 'contactless_not_working': 0.0591, 'activate_my_card': 0.0576, 'verify_my_identity': 0.0558, 'card_payment_not_recognised': 0.0552, 'card_about_to_expire': 0.0424, 'visa_or_mastercard': 0.0393, 'pending_card_payment': 0.0383, 'card_arrival': 0.0375, 'cash_withdrawal_not_recognised': 0.0372, 'unable_to_verify_identity': 0.0322, 'declined_card_payment': 0.0293, 'card_delivery_estimate': 0.0261, 'declined_cash_withdrawal': 0.025, 'dir

Sentence: When will I get my card?

Prediction: {'card_delivery_estimate': 0.8556, 'pending_card_payment': 0.7886, 'card_arrival': 0.5471, 'lost_or_stolen_card': 0.4822, 'card_acceptance': 0.3835, 'reverted_card_payment?': 0.3699, 'pending_transfer': 0.3254, 'get_physical_card': 0.2969, 'order_physical_card': 0.2072, 'compromised_card': 0.1655, 'pin_blocked': 0.1645, 'pending_cash_withdrawal': 0.1485, 'transfer_timing': 0.1464, 'getting_spare_card': 0.1148, 'getting_virtual_card': 0.0957, 'card_linking': 0.0879, 'virtual_card_not_working': 0.0861, 'declined_transfer': 0.0786, 'card_not_working': 0.0716, 'pending_top_up': 0.0708, 'visa_or_mastercard': 0.0702, 'passcode_forgotten': 0.0682, 'activate_my_card': 0.0627, 'contactless_not_working': 0.062, 'verify_my_identity': 0.0609, 'card_swallowed': 0.0601, 'declined_cash_withdrawal': 0.0576, 'get_disposable_virtual_card': 0.0565, 'failed_transfer': 0.054, 'supported_cards_and_currencies': 0.0504, 'declined_card_payment'

Sentence: I still don't have my card after 2 weeks.

Prediction: {'card_not_working': 0.9266, 'failed_transfer': 0.8903, 'lost_or_stolen_card': 0.7495, 'card_delivery_estimate': 0.5822, 'passcode_forgotten': 0.5741, 'pending_card_payment': 0.4939, 'declined_transfer': 0.4904, 'card_swallowed': 0.4864, 'pin_blocked': 0.4746, 'virtual_card_not_working': 0.4405, 'compromised_card': 0.434, 'contactless_not_working': 0.3996, 'top_up_failed': 0.3858, 'get_physical_card': 0.3381, 'Refund_not_showing_up': 0.3354, 'getting_spare_card': 0.3091, 'card_linking': 0.2732, 'disposable_card_limits': 0.2366, 'pending_transfer': 0.2267, 'reverted_card_payment?': 0.2226, 'card_acceptance': 0.2209, 'transfer_not_received_by_recipient': 0.2012, 'getting_virtual_card': 0.1718, 'declined_card_payment': 0.1544, 'cancel_transfer': 0.1506, 'unable_to_verify_identity': 0.1461, 'order_physical_card': 0.1403, 'card_payment_not_recognised': 0.14, 'balance_not_updated_after_cheque_or_cash_deposit'

Sentence: Can the card be mailed and used in Europe?

Prediction: {'card_acceptance': 0.4682, 'compromised_card': 0.1534, 'card_delivery_estimate': 0.1319, 'order_physical_card': 0.1169, 'passcode_forgotten': 0.0983, 'card_linking': 0.0975, 'country_support': 0.0774, 'reverted_card_payment?': 0.0706, 'card_swallowed': 0.062, 'getting_spare_card': 0.0609, 'get_disposable_virtual_card': 0.0533, 'visa_or_mastercard': 0.0524, 'topping_up_by_card': 0.0476, 'get_physical_card': 0.0468, 'pin_blocked': 0.0433, 'getting_virtual_card': 0.0412, 'declined_transfer': 0.0351, 'fiat_currency_support': 0.0344, 'disposable_card_limits': 0.0334, 'card_not_working': 0.033, 'atm_support': 0.0306, 'card_about_to_expire': 0.0302, 'supported_cards_and_currencies': 0.0294, 'why_verify_identity': 0.0291, 'age_limit': 0.0231, 'edit_personal_details': 0.023, 'lost_or_stolen_card': 0.0224, 'activate_my_card': 0.0212, 'verify_my_identity': 0.021, 'top_up_limits': 0.0196, 'exchange_charge': 0.019
```