# Zero Shot Text Classification Example

In this tutorial we walk through the text classification pipeline in Forte, using a zero-shot text classification model from Hugging Face. We also explore ways to improve the model's accuracy with simple modifications.

## Introduction
Zero-shot text classification is an extreme example of transfer learning, where the model makes predictions without any fine-tuning on the target task. This kind of model predicts a class based on the similarity between the input text and the label names, so it does not require any labeled training data. Well-defined label names are sufficient to achieve a reasonably performing model. We will use [valhalla/distilbart-mnli-12-1](https://huggingface.co/valhalla/distilbart-mnli-12-1) as our text classification model, which was trained on the [Multi-NLI](https://huggingface.co/datasets/multi_nli) dataset, to make predictions on a subset of the Amazon review sentiment ([ARS](https://s3.amazonaws.com/fast-ai-nlp/amazon_review_polarity_csv.tgz)) dataset and on the [Banking77](https://raw.githubusercontent.com/PolyAI-LDN/task-specific-datasets/master/banking_data/test.csv) dataset. We will also see how the Forte pipeline works seamlessly with third-party tools like [NLTK](https://www.nltk.org/) and [Hugging Face](https://huggingface.co/) to make our life easier.
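To make the idea more concrete, the sketch below shows one common way a model trained on natural language inference can be used for zero-shot classification: each label name is turned into a hypothesis sentence, and the model scores how strongly the input text entails it. The template string mirrors the default of the Hugging Face zero-shot pipeline. The logits are made-up numbers standing in for real model output, and `build_hypotheses` and `entailment_probability` are illustrative helpers, not Forte or Hugging Face APIs.

```python
import math


def build_hypotheses(labels, template="This example is {}."):
    """Turn each candidate label into a hypothesis sentence."""
    return [template.format(label) for label in labels]


def entailment_probability(contradiction_logit, entailment_logit):
    """Softmax over (contradiction, entailment); returns the entailment share."""
    m = max(contradiction_logit, entailment_logit)  # subtract max for stability
    e_c = math.exp(contradiction_logit - m)
    e_e = math.exp(entailment_logit - m)
    return e_e / (e_c + e_e)


labels = ["Negative", "Positive"]
hypotheses = build_hypotheses(labels)

# Hypothetical (contradiction, entailment) logits for the input text
# "Would Definitely use this vendor again!!". A real model would produce these.
mock_logits = {"Negative": (3.0, -2.5), "Positive": (-3.0, 4.0)}

# Each label is scored independently, so the scores need not sum to 1.
scores = {label: entailment_probability(*mock_logits[label]) for label in labels}
best_label = max(scores, key=scores.get)

print(hypotheses)
print(best_label)
```

Scoring each label independently is only one possible choice. It is used here because the per-sentence scores printed later in this tutorial do not sum to 1 across labels.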

In [1]:
import os
from termcolor import colored
from forte.data.readers import ClassificationDatasetReader
from fortex.huggingface import ZeroShotClassifier
from forte.pipeline import Pipeline
from fortex.nltk import NLTKSentenceSegmenter
from ft.onto.base_ontology import Sentence
from ft.onto.base_ontology import Document
import pandas as pd
from sklearn.metrics import accuracy_score
from collections import Counter
from pprint import pprint
# show up to 150 characters per column when displaying DataFrames
pd.set_option('display.max_colwidth', 150)

## Pipeline Initialization
The code below sets the hyperparameters and initializes the pipeline.
- csv_path : The path of the input csv file. In this example it is a subset of the ARS test data.
- class_names : The most important hyperparameter in zero-shot text classification. Because the model matches the text against the label names themselves, we need to choose label names that are meaningful.
- index2class : A dictionary that maps each numerical label index to its class name.
- [ClassificationDatasetReader_config](https://github.com/asyml/forte/blob/63577ce88fd5a42c3c5930f27181167c92ab2057/forte/data/readers/classification_reader.py) : The reader config, which takes three essential fields.
    - forte_data_fields : The field assigned to each column of the csv file, listed in column order.
    - index2class : The index2class dictionary defined above.
    - text_fields : The columns whose text the pipeline uses for making the prediction.
- GPU : The index of the GPU on which the model should run. Set it to -1 to use the CPU only.
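As a rough illustration of how the numeric `label` column relates to `index2class`, the sketch below maps raw csv labels to class names. It assumes the reader shifts one-based labels (1 and 2 in the ARS sample) down to zero-based keys before the lookup. `resolve_label` is a hypothetical helper written for this example and is not part of Forte.

```python
def resolve_label(raw_label, index2class, one_based=True):
    """Map a raw numeric label from the csv file to its class name."""
    index = int(raw_label) - 1 if one_based else int(raw_label)
    if index not in index2class:
        raise KeyError(f"label {raw_label!r} has no entry in index2class")
    return index2class[index]


index2class = {0: "Negative", 1: "Positive"}

# Labels of the four sample reviews (1 = negative, 2 = positive).
raw_labels = [1, 2, 1, 1]
resolved = [resolve_label(raw, index2class) for raw in raw_labels]
print(resolved)  # ['Negative', 'Positive', 'Negative', 'Negative']
```

If the order of `class_names` and the keys of `index2class` disagree, the ground truth is silently swapped, so it is worth checking a few rows by hand.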

In [2]:
# enter the path of the input csv file
csv_path = "../../data_samples/amazon_review_polarity_csv/amazon_sample_4.csv"

# label names
Label_1='Negative'
Label_2='Positive'
class_names = [Label_1, Label_2]

#numerical representation of class names
index2class={0:Label_1,1: Label_2}

#config for ClassificationDatasetReader()
ClassificationDatasetReader_config = {
    "forte_data_fields": [
        "label",
        "ft.onto.base_ontology.Title",
        "ft.onto.base_ontology.Body",
    ],
    "index2class": index2class,
    "text_fields": [
        "ft.onto.base_ontology.Body"
    ],
}

#cuda_device (-1 for cpu usage)
GPU=0 

#pipeline components
pl = Pipeline()
pl.set_reader(ClassificationDatasetReader(), config=ClassificationDatasetReader_config)
pl.add(NLTKSentenceSegmenter()) # to segment each sentence from text.
pl.add(ZeroShotClassifier(), config={"candidate_labels": class_names,"cuda_device":GPU})
pl.initialize();

[nltk_data] Downloading package punkt to
[nltk_data]     /home/bhaskar.rao/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## Data Visualization
We will use the pandas library to read the csv file and inspect the input data.

In [3]:
df=pd.read_csv(csv_path)
df.head()

| | label | Title | description |
|---|---|---|---|
| 0 | 1 | THIS IS A USED BATTERY | States this in a new battery but IT IS USED! Contacts are worn, dints in the plastic and scratches.AccessoryOneCondition Seller Information Ready ... |
| 1 | 2 | Wow - What Service! | I received this product less than 24 hours after I ordered it - in Excellent condition. Would Definitely use this vendor again!! Thanks. |
| 2 | 1 | Very boring-gave up on it | Cook reintroduced some old characters using the same story line descriptions. I felt I was rereading a previously read novel. Editors missed many ... |
| 3 | 1 | not VISTA compatible | the cord works but the driver does NOT work for windows Vista. So, this was a waste of money for me. I only bought it for the computer cord. Can't... |


In [4]:
# first row description text
df['description'][0]

'States this in a new battery but IT IS USED! Contacts are worn, dints in the plastic and scratches.AccessoryOneCondition Seller Information Ready to buy?$6.96+ $2.98shippingNew'

## Prediction
We will now run the pipeline on the description column. NLTKSentenceSegmenter() splits the text in each row of the input into individual sentences, and ZeroShotClassifier() then makes a prediction for each sentence. The number beside each label name is the confidence score the model assigned to that label.

In [5]:
for pack in pl.process_dataset(csv_path):
    for sent in pack.get(Sentence):
        sent_text = sent.text
        print(colored("Sentence:", "red"), sent_text)
        print(colored("Prediction:", "blue"), sent.classification, "\n")
    print('-----------------------------------------------------------')

[31mSentence:[0m States this in a new battery but IT IS USED!
[34mPrediction:[0m {'Positive': 0.0789, 'Negative': 0.0529} 

[31mSentence:[0m Contacts are worn, dints in the plastic and scratches.AccessoryOneCondition Seller Information Ready to buy?$6.96+ $2.98shippingNew
[34mPrediction:[0m {'Negative': 0.084, 'Positive': 0.013} 

-----------------------------------------------------------
[31mSentence:[0m I received this product less than 24 hours after I ordered it - in Excellent condition.
[34mPrediction:[0m {'Positive': 0.9996, 'Negative': 0.0008} 

[31mSentence:[0m Would Definitely use this vendor again!!
[34mPrediction:[0m {'Positive': 0.978, 'Negative': 0.0003} 

[31mSentence:[0m Thanks.
[34mPrediction:[0m {'Positive': 0.2642, 'Negative': 0.0941} 

-----------------------------------------------------------
[31mSentence:[0m Cook reintroduced some old characters using the same story line descriptions.
[34mPrediction:[0m {'Positive': 0.0302, 'Negative': 0.0

In [6]:
def get_class(prediction):
    '''Finds the class with the highest score.
    input: classification output dict containing class probabilities for the input sentence
    output: the class with the highest probability
    '''
    return pd.DataFrame.from_dict(prediction, orient='index')[0].idxmax()

def aggregate_class_score(pack):
    '''Finds the most frequently predicted class among the sentences of a pack.
    input: pack object
    output: (majority class, fraction of sentences that voted for it)
    '''
    prediction_list = []
    for sent in pack.get(Sentence):
        predicted_class = get_class(sent.classification)
        prediction_list.append(predicted_class)

    cnt = Counter(prediction_list)
    # on a tie, most_common() keeps the class that was encountered first
    predicted_class = cnt.most_common()[0][0]
    predicted_score = cnt[predicted_class] / len(prediction_list)

    return (predicted_class, predicted_score)

## Aggregate Sentences
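In the block above we wrote two functions, get_class and aggregate_class_score, to aggregate the predictions of the individual sentences in the text of one row. get_class picks the highest-scoring label for a sentence, and aggregate_class_score takes a majority vote over the sentences, which gives us a single prediction for the whole text in a row. The confidence score it returns is the fraction of sentences that voted for the winning class.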
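The majority vote can be traced by hand on the first review, using the two sentence scores printed in the prediction output above. The sketch below reproduces the vote with plain dictionaries instead of a Forte pack. It also includes `mean_score_vote`, an alternative aggregation that is not part of this tutorial and is shown only as one option to experiment with. Both function names are illustrative.

```python
from collections import Counter


def majority_vote(sentence_scores):
    """Return (label, share), where label wins the most sentences."""
    votes = [max(scores, key=scores.get) for scores in sentence_scores]
    # On a tie, the label that was seen first is returned.
    label, count = Counter(votes).most_common(1)[0]
    return label, count / len(votes)


def mean_score_vote(sentence_scores):
    """Return the label with the highest score averaged over all sentences."""
    totals = Counter()
    for scores in sentence_scores:
        totals.update(scores)  # adds each label's score to its running total
    n = len(sentence_scores)
    means = {label: total / n for label, total in totals.items()}
    return max(means, key=means.get)


# Sentence-level scores of the first review, copied from the output above.
first_review = [
    {"Positive": 0.0789, "Negative": 0.0529},
    {"Negative": 0.084, "Positive": 0.013},
]

print(majority_vote(first_review))    # ('Positive', 0.5)
print(mean_score_vote(first_review))  # Negative
```

The two sentences disagree, so the vote is a 1-1 tie that is resolved in favor of the first sentence. Averaging the scores instead would pick Negative for this review, though whether that helps across a whole dataset would need to be measured.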

In [7]:
y_true=[]
y_pred=[]
for pack in pl.process_dataset(csv_path):
    yt=next(pack.get(Document)).document_class[0]
    yp,ys=aggregate_class_score(pack)    
    y_true.append(yt)
    y_pred.append(yp)
    print(pack.text,'\n')
    print(colored('ground_truth : ','green'),
          yt,colored('     predicted_class : ','blue'),
          yp,colored(' confidence score : ','red'),ys)
    print('-----------------------------------------------------------------------------------','\n')

States this in a new battery but IT IS USED! Contacts are worn, dints in the plastic and scratches.AccessoryOneCondition Seller Information Ready to buy?$6.96+ $2.98shippingNew 

ground_truth :  Negative      predicted_class :  Positive  confidence score :  0.5
----------------------------------------------------------------------------------- 

I received this product less than 24 hours after I ordered it - in Excellent condition. Would Definitely use this vendor again!! Thanks. 

ground_truth :  Positive      predicted_class :  Positive  confidence score :  1.0
----------------------------------------------------------------------------------- 

Cook reintroduced some old characters using the same story line descriptions. I felt I was rereading a previously read novel. Editors missed many writing errors which made for awkward sentences. I finished only half the book and gave up 

ground_truth :  Negative      predict

## Accuracy
We'll use the sklearn library to evaluate accuracy. In the previous block we stored the ground truth and the predictions in two lists. accuracy_score from sklearn takes these as input and returns the accuracy of the model, i.e. the fraction of rows whose predicted class matches the ground truth.

In [8]:
accuracy_score(y_true, y_pred)

0.5
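For reference, accuracy on class labels is just the share of positions where the two lists agree. The sketch below computes it in plain Python. The prediction list is illustrative only and is not the actual output of the run above, and `simple_accuracy` is a hypothetical stand-in for `accuracy_score`.

```python
def simple_accuracy(y_true, y_pred):
    """Fraction of positions where the prediction equals the ground truth."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if not y_true:
        raise ValueError("cannot compute accuracy on empty input")
    matches = sum(t == p for t, p in zip(y_true, y_pred))
    return matches / len(y_true)


# Illustrative labels only: two of the four predictions match.
y_true_demo = ["Negative", "Positive", "Negative", "Negative"]
y_pred_demo = ["Positive", "Positive", "Negative", "Positive"]

print(simple_accuracy(y_true_demo, y_pred_demo))  # 0.5
```

With only four rows, a single changed prediction moves the accuracy by 25 percentage points, which is why a larger sample is needed before comparing settings.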

## Experiment
Here we will explore a way to improve the accuracy of our model on a given dataset. As we know, the zero-shot model makes predictions based on the similarity between the input sentence and the class names. We can therefore experiment with different class names that are close in meaning to the original ones, and find out which class names work best for our dataset.

In [9]:
def accuracy(csv_path,Label_1='negative',Label_2='positive'):
    '''This function unifies the initialization, prediction and accuracy evaluation.
    input : csv_path and class names
    output : accuracy'''
    class_names = [Label_1, Label_2]
    index2class={0:Label_1,1: Label_2}
    ClassificationDatasetReader_config = {
        "forte_data_fields": [
            "label",
            "ft.onto.base_ontology.Title",
            "ft.onto.base_ontology.Body",
        ],
        "index2class": index2class,
        "text_fields": [
            "ft.onto.base_ontology.Body"
        ],
        
    }
    # uses the notebook-level GPU setting chosen above (-1 runs on the CPU)
    pl = Pipeline()
    pl.set_reader(ClassificationDatasetReader(), config=ClassificationDatasetReader_config)
    pl.add(NLTKSentenceSegmenter())
    pl.add(ZeroShotClassifier(), config={"candidate_labels": class_names,"cuda_device":GPU})
    pl.initialize()

    y_true=[]
    y_pred=[]
    for pack in pl.process_dataset(csv_path):
        yt=next(pack.get(Document)).document_class[0]
        yp,ys=aggregate_class_score(pack)    
        y_true.append(yt)
        y_pred.append(yp)
    return accuracy_score(y_true, y_pred)

Let's define a list of similar class names grouped into tuples. Keep the order of the class names consistent as well, i.e. the negative word first and then the positive word, so that each name lines up with the correct label index. Because we want to compare the accuracy of different class names, we should increase the number of samples to get stable accuracy results.

In [10]:
class_name_list=[('negative','positive'),
                ('bad','good'),
                ('unsatisfied','satisfied'),
                ('unhappy','happy')]

# if using the CPU, we keep the small sample data used previously, which keeps the runtime low.
if GPU != -1:
    csv_path = "../../data_samples/amazon_review_polarity_csv/amazon_sample_10k.csv"

In [11]:
class_name=[]
accuracy_list=[]
for i in class_name_list:
    class_name.append(i[0]+'_'+i[1])
    accuracy_list.append(accuracy(csv_path,Label_1=i[0],Label_2=i[1]))
    #print(i,' accuracy : ',accuracy(csv_path,Label_1=i[0],Label_2=i[1]),'\n')
da=pd.DataFrame()
da['class_name']=class_name
da['accuracy']=accuracy_list
da

[nltk_data] Downloading package punkt to
[nltk_data]     /home/bhaskar.rao/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


| | class_name | accuracy |
|---|---|---|
| 0 | negative_positive | 0.8603 |
| 1 | bad_good | 0.765 |
| 2 | unsatisfied_satisfied | 0.8651 |
| 3 | unhappy_happy | 0.8022 |


An accuracy of 86.03% on unseen data, without any task-specific training, is strong performance from the Hugging Face model. Zero-shot classification is a powerful tool for problems with little labeled data. With a careful selection of class names we can improve further: the class names ('unsatisfied', 'satisfied') reached 86.51%, slightly better than the 86.03% of ('negative', 'positive'), while ('bad', 'good') dropped to 76.5%.
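The comparison can also be done programmatically. The sketch below takes the accuracies from the table above and reports the best label pair, its gain over the ('negative', 'positive') baseline, and the spread between the best and worst pair. The variable names are illustrative.

```python
# Accuracies copied from the results table above.
results = {
    ("negative", "positive"): 0.8603,
    ("bad", "good"): 0.765,
    ("unsatisfied", "satisfied"): 0.8651,
    ("unhappy", "happy"): 0.8022,
}

baseline = ("negative", "positive")
best_pair = max(results, key=results.get)

# Differences expressed in percentage points.
gain_points = round((results[best_pair] - results[baseline]) * 100, 2)
spread_points = round((max(results.values()) - min(results.values())) * 100, 2)

print(best_pair)      # ('unsatisfied', 'satisfied')
print(gain_points)    # 0.48
print(spread_points)  # 10.01
```

The gain over the baseline is under half a percentage point, while the gap to the worst pair is about ten points, so in this experiment avoiding a poor label pair matters more than fine-tuning a good one.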

# Multiclass Classification
In this section we will see how to do multiclass classification, using the same model and the Banking77 dataset.

In [12]:
csv_path = "../../data_samples/banking77/sample.csv"
url='https://raw.githubusercontent.com/PolyAI-LDN/task-specific-datasets/master/banking_data/test.csv'
df=pd.read_csv(url)
df.sample(5)

| | text | category |
|---|---|---|
| 2798 | Can I exchange money from abroad without additional costs? | exchange_charge |
| 2613 | how secure is a disposable virtual card | get_disposable_virtual_card |
| 360 | I've tried my card a bunch of times and it never worked. | card_not_working |
| 1316 | Is Visa or Mastercard available? | visa_or_mastercard |
| 2393 | The top-up verification code is missing | verify_top_up |


Extracting class names from the banking77 test dataset.

In [13]:
class_names=df['category'].unique()

In [14]:
index2class = dict(enumerate(class_names))

Here the reader config takes two additional parameters:
- digit_label : Set to False, because the labels in this dataset are category names (text) rather than digits.
- one_based_index_label : Indicates whether numeric labels start from one. It is set to False here.
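When the labels are text, the index mapping has to be derived from the category names themselves. The sketch below does this in plain Python, keeping the order in which categories first appear, which is also what `dict(enumerate(df['category'].unique()))` produces. `build_label_maps` is a hypothetical helper for this example, and the inverse map is included only to show how a name can be turned back into an index.

```python
def build_label_maps(categories):
    """Build index->class and class->index maps, keeping first-seen order."""
    unique = list(dict.fromkeys(categories))  # de-duplicate, preserve order
    index2class = dict(enumerate(unique))
    class2index = {name: idx for idx, name in index2class.items()}
    return index2class, class2index


# A few category names taken from the Banking77 sample shown above.
sample_categories = [
    "exchange_charge",
    "get_disposable_virtual_card",
    "card_not_working",
    "exchange_charge",
]

index2class_demo, class2index_demo = build_label_maps(sample_categories)
print(index2class_demo)
print(class2index_demo["card_not_working"])  # 2
```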

In [15]:
this_reader_config = {
    "forte_data_fields": [
        "ft.onto.base_ontology.Body",
        "label",
    ],
    "index2class": index2class,
    "text_fields": [
        "ft.onto.base_ontology.Body"
    ],
    "digit_label": False,
    "one_based_index_label": False,
}
GPU=1

In [16]:
pl = Pipeline()
pl.set_reader(ClassificationDatasetReader(), config=this_reader_config)
pl.add(NLTKSentenceSegmenter())
pl.add(ZeroShotClassifier(), config={"candidate_labels": class_names,"cuda_device":GPU})
pl.initialize();

[nltk_data] Downloading package punkt to
[nltk_data]     /home/bhaskar.rao/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


The code below scores each sentence against the 77 class names that we provided to the model.

In [17]:
for pack in pl.process_dataset(csv_path):
    for sent in pack.get(Sentence):
        sent_text = sent.text
        print(colored("Sentence:", "red"), sent_text, "\n")
        print(colored("Prediction:", "blue"),'\n')
        pprint(sent.classification)

[31mSentence:[0m How do I locate my card? 

[34mPrediction:[0m 

{'Refund_not_showing_up': 0.0081,
 'activate_my_card': 0.0576,
 'age_limit': 0.0048,
 'apple_pay_or_google_pay': 0.0019,
 'atm_support': 0.0182,
 'automatic_top_up': 0.003,
 'balance_not_updated_after_bank_transfer': 0.0136,
 'balance_not_updated_after_cheque_or_cash_deposit': 0.0112,
 'beneficiary_not_allowed': 0.0008,
 'cancel_transfer': 0.0035,
 'card_about_to_expire': 0.0424,
 'card_acceptance': 0.1594,
 'card_arrival': 0.0375,
 'card_delivery_estimate': 0.0261,
 'card_linking': 0.2052,
 'card_not_working': 0.0616,
 'card_payment_fee_charged': 0.016,
 'card_payment_not_recognised': 0.0552,
 'card_payment_wrong_exchange_rate': 0.0074,
 'card_swallowed': 0.0604,
 'cash_withdrawal_charge': 0.0062,
 'cash_withdrawal_not_recognised': 0.0372,
 'change_pin': 0.0082,
 'compromised_card': 0.3402,
 'contactless_not_working': 0.0591,
 'country_support': 0.0072,
 'declined_card_payment': 0.0293,
 'declined_cash_withdrawal': 0

[31mSentence:[0m My card has not arrived yet. 

[34mPrediction:[0m 

{'Refund_not_showing_up': 0.6876,
 'activate_my_card': 0.1168,
 'age_limit': 0.054,
 'apple_pay_or_google_pay': 0.0092,
 'atm_support': 0.0245,
 'automatic_top_up': 0.0121,
 'balance_not_updated_after_bank_transfer': 0.4729,
 'balance_not_updated_after_cheque_or_cash_deposit': 0.3386,
 'beneficiary_not_allowed': 0.0127,
 'cancel_transfer': 0.0779,
 'card_about_to_expire': 0.0313,
 'card_acceptance': 0.3379,
 'card_arrival': 0.2362,
 'card_delivery_estimate': 0.925,
 'card_linking': 0.1376,
 'card_not_working': 0.5276,
 'card_payment_fee_charged': 0.0591,
 'card_payment_not_recognised': 0.6136,
 'card_payment_wrong_exchange_rate': 0.0501,
 'card_swallowed': 0.1379,
 'cash_withdrawal_charge': 0.0338,
 'cash_withdrawal_not_recognised': 0.2648,
 'change_pin': 0.0375,
 'compromised_card': 0.067,
 'contactless_not_working': 0.4622,
 'country_support': 0.0132,
 'declined_card_payment': 0.5725,
 'declined_cash_withdrawal'

[31mSentence:[0m Is it normal to have to wait over a week for my new card? 

[34mPrediction:[0m 

{'Refund_not_showing_up': 0.0033,
 'activate_my_card': 0.012,
 'age_limit': 0.0103,
 'apple_pay_or_google_pay': 0.0008,
 'atm_support': 0.0063,
 'automatic_top_up': 0.0007,
 'balance_not_updated_after_bank_transfer': 0.0061,
 'balance_not_updated_after_cheque_or_cash_deposit': 0.0069,
 'beneficiary_not_allowed': 0.0007,
 'cancel_transfer': 0.0052,
 'card_about_to_expire': 0.0116,
 'card_acceptance': 0.0915,
 'card_arrival': 0.0157,
 'card_delivery_estimate': 0.111,
 'card_linking': 0.0193,
 'card_not_working': 0.0807,
 'card_payment_fee_charged': 0.0073,
 'card_payment_not_recognised': 0.0132,
 'card_payment_wrong_exchange_rate': 0.0046,
 'card_swallowed': 0.038,
 'cash_withdrawal_charge': 0.0021,
 'cash_withdrawal_not_recognised': 0.0075,
 'change_pin': 0.0119,
 'compromised_card': 0.0461,
 'contactless_not_working': 0.0274,
 'country_support': 0.0075,
 'declined_card_payment': 0.0196

[31mSentence:[0m still waiting on my new card 

[34mPrediction:[0m 

{'Refund_not_showing_up': 0.0056,
 'activate_my_card': 0.0353,
 'age_limit': 0.0252,
 'apple_pay_or_google_pay': 0.0034,
 'atm_support': 0.0093,
 'automatic_top_up': 0.0035,
 'balance_not_updated_after_bank_transfer': 0.0512,
 'balance_not_updated_after_cheque_or_cash_deposit': 0.0361,
 'beneficiary_not_allowed': 0.0012,
 'cancel_transfer': 0.0077,
 'card_about_to_expire': 0.0103,
 'card_acceptance': 0.187,
 'card_arrival': 0.0461,
 'card_delivery_estimate': 0.2011,
 'card_linking': 0.025,
 'card_not_working': 0.0328,
 'card_payment_fee_charged': 0.0254,
 'card_payment_not_recognised': 0.0548,
 'card_payment_wrong_exchange_rate': 0.0112,
 'card_swallowed': 0.0425,
 'cash_withdrawal_charge': 0.0069,
 'cash_withdrawal_not_recognised': 0.0236,
 'change_pin': 0.0133,
 'compromised_card': 0.0425,
 'contactless_not_working': 0.03,
 'country_support': 0.01,
 'declined_card_payment': 0.052,
 'declined_cash_withdrawal': 0.

 'beneficiary_not_allowed': 0.0014,
 'cancel_transfer': 0.0041,
 'card_about_to_expire': 0.0241,
 'card_acceptance': 0.1494,
 'card_arrival': 0.0362,
 'card_delivery_estimate': 0.2114,
 'card_linking': 0.0592,
 'card_not_working': 0.3066,
 'card_payment_fee_charged': 0.0123,
 'card_payment_not_recognised': 0.0724,
 'card_payment_wrong_exchange_rate': 0.0084,
 'card_swallowed': 0.109,
 'cash_withdrawal_charge': 0.0031,
 'cash_withdrawal_not_recognised': 0.0246,
 'change_pin': 0.0096,
 'compromised_card': 0.0908,
 'contactless_not_working': 0.0378,
 'country_support': 0.0084,
 'declined_card_payment': 0.0291,
 'declined_cash_withdrawal': 0.0087,
 'declined_transfer': 0.033,
 'direct_debit_payment_not_recognised': 0.0388,
 'disposable_card_limits': 0.017,
 'edit_personal_details': 0.0174,
 'exchange_charge': 0.0147,
 'exchange_rate': 0.0023,
 'exchange_via_app': 0.0129,
 'extra_charge_on_statement': 0.0041,
 'failed_transfer': 0.0961,
 'fiat_currency_support': 0.0153,
 'get_disposable_vir