<a href="https://colab.research.google.com/github/Diksha227/AIF360/blob/master/Zero_Shot_Text_Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Zero-Shot Text Classification

In this notebook, we run zero-shot classification with the [transformers](https://github.com/huggingface/transformers) library and evaluate it on datasets loaded with the [datasets](https://github.com/huggingface/datasets) library.

In [None]:
!pip install transformers datasets

In [None]:
from transformers import pipeline
from tqdm.notebook import tqdm

In [None]:
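# device=0 runs the model on the first GPU; the default checkpoint is facebook/bart-large-mnli.
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", device=0)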

Some weights of the model checkpoint at facebook/bart-large-mnli were not used when initializing BartForSequenceClassification: ['model.encoder.version', 'model.decoder.version']
- This IS expected if you are initializing BartForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BartForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Examples

In [None]:
candidate_labels = ["positive", "negative"]

In [None]:
text1 = "This movie is terrible but it has some good effects."
classifier(text1, candidate_labels)
# expected label: negative

{'labels': ['negative', 'positive'],
 'scores': [0.993061900138855, 0.0069380709901452065],
 'sequence': 'This movie is terrible but it has some good effects.'}

In [None]:
text2 = "This movie will always be a Broadway and Movie classic, as long as there are still people who sing, dance, and act."
classifier(text2, candidate_labels)
# expected label: positive

{'labels': ['positive', 'negative'],
 'scores': [0.9604027271270752, 0.03959730640053749],
 'sequence': 'This movie will always be a Broadway and Movie classic, as long as there are still people who sing, dance, and act.'}

In [None]:
candidate_labels = ["world", "sports", "business", "sci/tech"]

text1 = "Veteran inventor in market float Trevor Baylis, the veteran inventor famous for creating the Freeplay clockwork radio, is planning to float his company on the stock market."
classifier(text1, candidate_labels)
# expected label: business

{'labels': ['business', 'science and tech', 'world', 'sports'],
 'scores': [0.6816723942756653,
  0.18734510242938995,
  0.12677989900112152,
  0.004202586133033037],
 'sequence': 'Veteran inventor in market float Trevor Baylis, the veteran inventor famous for creating the Freeplay clockwork radio, is planning to float his company on the stock market.'}

In [None]:
text2 = "This Date in Baseball - Aug. 17 (AP) AP - 1904  #151; Jesse Tannehill of the Boston Red Sox pitched a no-hitter, beating the Chicago White Sox 6-0."
classifier(text2, candidate_labels)
# expected label: sports

{'labels': ['sports', 'world', 'business', 'science and tech'],
 'scores': [0.9873999953269958,
  0.010431364178657532,
  0.0015704811085015535,
  0.0005981583963148296],
 'sequence': 'This Date in Baseball - Aug. 17 (AP) AP - 1904  #151; Jesse Tannehill of the Boston Red Sox pitched a no-hitter, beating the Chicago White Sox 6-0.'}

In [None]:
candidate_labels = ["anger", "fear", "joy", "love", "sadness", "surprise"]
text = "i didnt feel humiliated"
classifier(text, candidate_labels)
# expected label: sadness (the model predicts "surprise" here)

{'labels': ['surprise', 'joy', 'love', 'sadness', 'fear', 'anger'],
 'scores': [0.66361004114151,
  0.1976112276315689,
  0.04634414240717888,
  0.03801531344652176,
  0.03516925126314163,
  0.01925000175833702],
 'sequence': 'i didnt feel humiliated'}

In [None]:
candidate_labels = ["anger", "fear", "joy", "love", "sadness", "surprise"]
text = "i am feeling grouchy"
classifier(text, candidate_labels)
# expected label: anger

{'labels': ['anger', 'surprise', 'sadness', 'fear', 'joy', 'love'],
 'scores': [0.7041721940040588,
  0.11884286999702454,
  0.11590170115232468,
  0.04772496595978737,
  0.009454275481402874,
  0.003903944045305252],
 'sequence': 'i am feeling grouchy'}
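Each result above is a dictionary with parallel `labels` and `scores` lists. As a minimal sketch, the predicted label can be read from such a dictionary as shown below. The helper name `top_label` is hypothetical, and the scores are copied from the last example and rounded.

```python
def top_label(result):
    """Return the (label, score) pair with the highest score."""
    pairs = zip(result["labels"], result["scores"])
    return max(pairs, key=lambda pair: pair[1])


# Result copied from the "i am feeling grouchy" example (scores rounded).
result = {
    "labels": ["anger", "surprise", "sadness", "fear", "joy", "love"],
    "scores": [0.704, 0.119, 0.116, 0.048, 0.009, 0.004],
    "sequence": "i am feeling grouchy",
}

label, score = top_label(result)
print(label, score)
```

Taking the maximum explicitly does not rely on the order of the lists, although in the outputs above the labels already appear sorted by descending score.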

Zero-shot text classification performs well on most of these examples, which come from the IMDB, AG_News, and emotion datasets. However, a handful of examples does not tell us how it performs in practice. To find out, we will use the datasets library and evaluate on the test split of each dataset.

## Evaluation

In [None]:
from datasets import load_dataset
import numpy as np
import math
from sklearn.metrics import classification_report

##### **IMDB: Sentiment Analysis**

In [None]:
dataset = load_dataset('imdb')

In [None]:
print(dataset["train"][0])

{'label': 1, 'text': 'Bromwell High is a cartoon comedy. It ran at the same time as some other programs about school life, such as "Teachers". My 35 years in the teaching profession lead me to believe that Bromwell High\'s satire is much closer to reality than is "Teachers". The scramble to survive financially, the insightful students who can see right through their pathetic teachers\' pomp, the pettiness of the whole situation, all remind me of the schools I knew and their students. When I saw the episode in which a student repeatedly tried to burn down the school, I immediately recalled ......... at .......... High. A classic line: INSPECTOR: I\'m here to sack one of your teachers. STUDENT: Welcome to Bromwell High. I expect that many adults of my age think that Bromwell High is far fetched. What a pity that it isn\'t!'}


In [None]:
candidate_labels = ["positive", "negative"]
batch_size = 16
test_set = dataset["test"]
predictions = []
for start in tqdm(range(0, len(test_set), batch_size)):
    # Slicing a Dataset returns a dict of columns, so ["text"] is a list of strings.
    texts = test_set[start:start + batch_size]["text"]
    preds = classifier(texts, candidate_labels)
    # Take the highest-scoring label for each text.
    pred_labels = [pred["labels"][np.argmax(pred["scores"])] for pred in preds]
    # Map label names to IMDB class ids: 0 = negative, 1 = positive.
    predictions.extend([0 if pred_label == "negative" else 1 for pred_label in pred_labels])


In [None]:
print(classification_report([x["label"] for x in dataset["test"]], predictions))

              precision    recall  f1-score   support

           0       0.85      0.93      0.89     12500
           1       0.92      0.84      0.88     12500

    accuracy                           0.88     25000
   macro avg       0.89      0.88      0.88     25000
weighted avg       0.89      0.88      0.88     25000



##### **AG_News: News Categorization**

In [None]:
dataset = load_dataset('ag_news')

In [None]:
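candidate_labels = ["world", "sports", "business", "sci/tech"]
batch_size = 16
test_set = dataset["test"]
predictions = []
for start in tqdm(range(0, len(test_set), batch_size)):
    # Slicing a Dataset returns a dict of columns, so ["text"] is a list of strings.
    texts = test_set[start:start + batch_size]["text"]
    preds = classifier(texts, candidate_labels)
    # Take the highest-scoring label for each text.
    pred_labels = [pred["labels"][np.argmax(pred["scores"])] for pred in preds]
    # candidate_labels is ordered like the AG_News class ids, so the list index is the class id.
    predictions.extend([candidate_labels.index(pred_label) for pred_label in pred_labels])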


In [None]:
print(classification_report([x["label"] for x in dataset["test"]], predictions))

              precision    recall  f1-score   support

           0       0.54      0.82      0.65      1900
           1       0.94      0.86      0.90      1900
           2       0.53      0.69      0.60      1900
           3       0.76      0.20      0.32      1900

    accuracy                           0.64      7600
   macro avg       0.69      0.64      0.62      7600
weighted avg       0.69      0.64      0.62      7600



##### **Emotion: Emotion Classification**

In [None]:
dataset = load_dataset('emotion')

In [None]:
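candidate_labels = ["anger", "fear", "joy", "love", "sadness", "surprise"]
batch_size = 16
test_set = dataset["test"]
predictions = []
for start in tqdm(range(0, len(test_set), batch_size)):
    # Slicing a Dataset returns a dict of columns, so ["text"] is a list of strings.
    texts = test_set[start:start + batch_size]["text"]
    preds = classifier(texts, candidate_labels)
    # Take the highest-scoring label for each text.
    pred_labels = [pred["labels"][np.argmax(pred["scores"])] for pred in preds]
    # Keep the label names as strings; they are compared with the dataset labels in the report.
    predictions.extend(pred_labels)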


In [None]:
print(classification_report([x["label"] for x in dataset["test"]], predictions))

              precision    recall  f1-score   support

       anger       0.71      0.35      0.47       275
        fear       0.55      0.36      0.43       224
         joy       0.81      0.35      0.49       695
        love       0.37      0.21      0.27       159
     sadness       0.73      0.41      0.53       581
    surprise       0.06      0.86      0.11        66

    accuracy                           0.38      2000
   macro avg       0.54      0.43      0.38      2000
weighted avg       0.69      0.38      0.46      2000
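`classification_report` expects the true and predicted labels to be of the same type. The report above was produced with string labels on both sides. If the installed version of the emotion dataset stores `label` as integer class ids instead (an assumption worth checking by printing one example), the ids can be converted to names first. The sketch below uses a hypothetical helper `to_label_names` and a hard-coded name list for illustration only.

```python
def to_label_names(labels, names):
    """Convert integer class ids to names; leave string labels unchanged."""
    return [names[label] if isinstance(label, int) else label for label in labels]


# Hypothetical name order for illustration; read the real order from the dataset.
names = ["sadness", "joy", "love", "anger", "fear", "surprise"]

converted = to_label_names([0, 3, "joy"], names)
print(converted)
```

When the column is a `ClassLabel` feature, the name list should be available as `dataset["test"].features["label"].names`, which avoids hard-coding the order.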



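We can see that zero-shot classification works without any task-specific training, although its quality varies by task.

For sentiment analysis on IMDB, it reaches an F1-score of 0.88.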

For news categorization, the macro-averaged F1-score is 0.62. The worst category is sci/tech, with a recall of only 0.20, and it may improve with more descriptive label names.
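One way to try more descriptive labels while keeping the evaluation code unchanged is to map each descriptive label back to its AG_News class id. The sketch below is an untested assumption: the label wordings and the names `LABEL_TO_ID` and `to_class_ids` are hypothetical, and no result is claimed for them.

```python
# Hypothetical descriptive labels for AG_News; the wording is an untested assumption.
LABEL_TO_ID = {
    "world news and politics": 0,
    "sports": 1,
    "business and finance": 2,
    "science and technology": 3,
}

# These strings would be passed to the classifier instead of the short names.
candidate_labels = list(LABEL_TO_ID)


def to_class_ids(pred_labels):
    """Map predicted descriptive labels back to integer class ids."""
    return [LABEL_TO_ID[label] for label in pred_labels]


class_ids = to_class_ids(["sports", "science and technology"])
print(class_ids)
```

Using a dictionary instead of `candidate_labels.index(...)` keeps the mapping explicit, so the label wording can change without affecting the class ids.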

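However, performance on the emotion dataset is rather poor, with a macro-averaged F1-score of 0.38. This may be due to the similarity between classes: it is very hard to distinguish joy, love, and surprise without any labeled data. The model also over-predicts surprise, which has a recall of 0.86 but a precision of only 0.06.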
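The per-class F1-scores in the reports are the harmonic mean of precision and recall, which is why a class with one very low value scores poorly overall. As a worked check, the sketch below recomputes two of them from the rounded precision and recall values printed above, so small rounding differences are possible in general.

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall (0.0 when both are zero)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# Rounded values taken from the classification reports above.
surprise_f1 = f1_score(0.06, 0.86)  # emotion dataset, class "surprise"
sci_tech_f1 = f1_score(0.76, 0.20)  # AG_News, class 3 (sci/tech)

print(round(surprise_f1, 2), round(sci_tech_f1, 2))  # prints 0.11 0.32
```

Both values match the reports: surprise is dragged down by low precision, and sci/tech by low recall.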