![JohnSnowLabs](https://nlp.johnsnowlabs.com/assets/images/logo.png)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/JohnSnowLabs/spark-nlp-workshop/blob/master/tutorials/Certification_Trainings/Public/5.4_ZeroShot_Text_Classification.ipynb)

# Zero-Shot Text Classification
**State-of-the-art NLP models for text classification without annotated data**

Natural language processing is a very exciting field right now. In recent years, the community has developed effective methods for learning from the enormous amounts of unlabeled text available on the internet. Transfer learning from models pretrained on this unlabeled data has pushed results past nearly all earlier benchmarks on downstream supervised learning tasks. As new model architectures and unsupervised learning objectives keep appearing, "state of the art" remains a rapidly moving target for many tasks where large amounts of labeled data are available.

## Zero-Shot Learning (ZSL)
Traditionally, zero-shot learning (ZSL) most often referred to a fairly specific type of task: train a classifier on one set of labels and then evaluate it on a different set of labels that it has never seen before. Recently, especially in NLP, the term has been used much more broadly to mean getting a model to do something it wasn't explicitly trained to do. A well-known example is the [GPT-2 paper](https://pdfs.semanticscholar.org/9405/cc0d6169988371b2755e573cc28650d14dfe.pdf), where the authors evaluate a language model on downstream tasks such as machine translation without fine-tuning on those tasks directly.


## Colab Setup

In [None]:
! pip install -q pyspark==3.4.1 spark-nlp==5.1.2

In [2]:
import sparknlp

from sparknlp.base import *
from sparknlp.annotator import *

from pyspark.ml import Pipeline, PipelineModel
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
import pandas as pd

spark = sparknlp.start()

print("Spark NLP version: ", sparknlp.version())
print("Apache Spark version: ", spark.version)

spark

Spark NLP version:  5.1.2
Apache Spark version:  3.4.1


## BERT Zero-Shot Classification

This model is intended to be used for zero-shot text classification, especially in English. It is fine-tuned on XNLI using the BERT Base Cased model.

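`BertForZeroShotClassification` uses a `ModelForSequenceClassification` trained on NLI (natural language inference) tasks. It is the equivalent of `BertForSequenceClassification` models, but it doesn't require a hardcoded number of potential classes: the classes can be chosen at runtime. Each combination of input text and candidate label is posed to the model as a premise/hypothesis pair, so the model loops through all provided labels and takes longer the more labels you pass. This usually makes it slower than a regular classifier, but much more flexible.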

We used `TFBertForSequenceClassification` to train this model and the `BertForZeroShotClassification` annotator in Spark NLP for prediction at scale!
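An NLI model can be turned into a zero-shot classifier by treating the input text as the premise and each candidate label as a hypothesis. The framework-free sketch below illustrates that structure under stated assumptions and is not Spark NLP's implementation: the hypothesis template `"This example is {}."` is borrowed from the Hugging Face zero-shot pipeline default, and `toy_entailment_logit` is a hypothetical stand-in for a real NLI model that only counts keywords. What it does show is the shape of the computation: one scored pair per candidate label (so cost grows with the number of labels), followed by a softmax across labels for a multi-class answer.

```python
import math
from typing import Dict, List, Tuple

# Assumption: hypothesis template borrowed from the Hugging Face zero-shot
# pipeline default; Spark NLP may phrase its hypotheses differently.
HYPOTHESIS_TEMPLATE = "This example is {}."

# Hypothetical keyword lists that drive the toy scorer below.
TOY_KEYWORDS: Dict[str, List[str]] = {
    "mobile": ["phone", "iphone", "ios", "app"],
    "travel": ["airport", "flight", "visit", "trip"],
    "movie": ["movie", "movies", "film", "horror"],
}


def build_nli_pairs(text: str, labels: List[str]) -> List[Tuple[str, str]]:
    """Pair the input text (premise) with one hypothesis per candidate label."""
    return [(text, HYPOTHESIS_TEMPLATE.format(label)) for label in labels]


def toy_entailment_logit(premise: str, hypothesis: str) -> float:
    """Hypothetical stand-in for an NLI model's entailment logit.

    A real model reads premise and hypothesis jointly; this toy only counts
    keywords of whichever label is named in the hypothesis.
    """
    words = premise.lower().replace("!", " ").replace(".", " ").replace("?", " ").split()
    score = 0.0
    for label, keywords in TOY_KEYWORDS.items():
        if label in hypothesis.lower():
            score += sum(words.count(k) for k in keywords)
    return score


def zero_shot_classify(text: str, labels: List[str]) -> Dict[str, float]:
    """Score every (text, label) pair, then softmax across labels."""
    pairs = build_nli_pairs(text, labels)  # one model call per candidate label
    logits = [toy_entailment_logit(premise, hyp) for premise, hyp in pairs]
    exps = [math.exp(x) for x in logits]
    total = sum(exps)
    return {label: e / total for label, e in zip(labels, exps)}


labels = ["mobile", "travel", "movie"]
scores = zero_shot_classify("I have a problem with my iphone and the app keeps crashing.", labels)
print(max(scores, key=scores.get))  # mobile
```

Swapping `toy_entailment_logit` for a real NLI model's entailment score is the only change the overall loop would need; the candidate labels stay free-form strings chosen at runtime.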

### Zero-Shot Pipeline


Let's see how easy it is to use any set of labels our trained model has never seen, via the `setCandidateLabels()` parameter:

In [3]:
document_assembler = DocumentAssembler() \
    .setInputCol("text") \
    .setOutputCol("document")

tokenizer = Tokenizer()\
    .setInputCols("document")\
    .setOutputCol("token")

zero_shot_classifier = BertForZeroShotClassification.pretrained("bert_base_cased_zero_shot_classifier_xnli", "en")\
    .setInputCols(["document", "token"]) \
    .setOutputCol("class") \
    .setCandidateLabels(["urgent", "mobile", "travel", "movie", "music", "sport", "weather", "technology"])

pipeline = Pipeline(stages=[
    document_assembler,
    tokenizer,
    zero_shot_classifier
])

zero_shot_bert = pipeline.fit(spark.createDataFrame([[""]]).toDF("text"))

bert_base_cased_zero_shot_classifier_xnli download started this may take some time.
Approximate size to download 387.7 MB
[OK!]


In [4]:
zero_shot_classifier.extractParamMap()

{Param(parent='BERT_FOR_ZERO_SHOT_CLASSIFICATION_e4205e7cf10f', name='activation', doc='Whether to calculate logits via Softmax or Sigmoid. Default is Softmax'): 'softmax',
 Param(parent='BERT_FOR_ZERO_SHOT_CLASSIFICATION_e4205e7cf10f', name='batchSize', doc='Size of every batch'): 8,
 Param(parent='BERT_FOR_ZERO_SHOT_CLASSIFICATION_e4205e7cf10f', name='coalesceSentences', doc="Instead of 1 class per sentence (if inputCols is '''sentence''') output 1 class per document by averaging probabilities in all sentences."): False,
 Param(parent='BERT_FOR_ZERO_SHOT_CLASSIFICATION_e4205e7cf10f', name='lazyAnnotator', doc='Whether this AnnotatorModel acts as lazy in RecursivePipelines'): False,
 Param(parent='BERT_FOR_ZERO_SHOT_CLASSIFICATION_e4205e7cf10f', name='multilabel', doc='Whether to calculate logits via Multiclass(softmax) or Multilabel(sigmoid). Default is False i.e. Multiclass'): False,
 Param(parent='BERT_FOR_ZERO_SHOT_CLASSIFICATION_e4205e7cf10f', name='threshold', doc='Choose the th
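The `coalesceSentences` parameter in the map above describes outputting one class per document, instead of one per sentence, by averaging the probabilities of all sentences. A minimal sketch of that averaging step in plain Python, using made-up per-sentence probabilities purely for illustration (this is not Spark NLP's code):

```python
from typing import Dict, List


def coalesce_sentence_probs(sentence_probs: List[Dict[str, float]]) -> Dict[str, float]:
    """Average per-sentence label probabilities into one document-level distribution."""
    if not sentence_probs:
        raise ValueError("need at least one sentence")
    labels = list(sentence_probs[0])
    n = len(sentence_probs)
    return {label: sum(p[label] for p in sentence_probs) / n for label in labels}


def document_class(sentence_probs: List[Dict[str, float]]) -> str:
    """Pick the label with the highest averaged probability."""
    averaged = coalesce_sentence_probs(sentence_probs)
    return max(averaged, key=averaged.get)


# Made-up per-sentence probabilities for a two-sentence document.
probs = [
    {"mobile": 0.7, "travel": 0.2, "movie": 0.1},  # sentence 1 alone -> mobile
    {"mobile": 0.4, "travel": 0.5, "movie": 0.1},  # sentence 2 alone -> travel
]
print(document_class(probs))  # mobile (averaged: mobile ~0.55, travel ~0.35, movie ~0.1)
```

Per sentence, this document would get two different labels; after averaging it gets a single one.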

In [5]:
text = [["I have a problem with my iphone that needs to be resolved asap!!"],
        ["Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app."],
        ["I have a phone and I love it!"],
        ["I really want to visit Germany and I am planning to go there next year."],
        ["Let's watch some movies tonight! I am in the mood for a horror movie."],
        ["Have you watched the match yesterday? It was a great game!"],
        ["We need to harry up and get to the airport. We are going to miss our flight!"]]

# create a DataFrame in PySpark
inputDataset = spark.createDataFrame(text, ["text"])
predictionDF = zero_shot_bert.transform(inputDataset)

In [6]:
predictionDF.select("document.result", "class.result").show(10, False)

+----------------------------------------------------------------------------------------------------------------+--------+
|result                                                                                                          |result  |
+----------------------------------------------------------------------------------------------------------------+--------+
|[I have a problem with my iphone that needs to be resolved asap!!]                                              |[mobile]|
|[Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app.]|[mobile]|
|[I have a phone and I love it!]                                                                                 |[mobile]|
|[I really want to visit Germany and I am planning to go there next year.]                                       |[travel]|
|[Let's watch some movies tonight! I am in the mood for a horror movie.]                                         |[movie] |
|[Have y

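### Using LightPipeline

`LightPipeline` wraps the fitted `PipelineModel` and runs it on plain Python strings on the driver, without building a Spark DataFrame, which is convenient for single documents or small batches: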

In [7]:
sample_text = "Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app."

light_pipeline = LightPipeline(zero_shot_bert)

results = light_pipeline.annotate(sample_text)

results

{'document': ['Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app.'],
 'token': ['Last',
  'week',
  'I',
  'upgraded',
  'my',
  'iOS',
  'version',
  'and',
  'ever',
  'since',
  'then',
  'my',
  'phone',
  'has',
  'been',
  'overheating',
  'whenever',
  'I',
  'use',
  'your',
  'app',
  '.'],
 'class': ['mobile']}

In [8]:
results["class"]

['mobile']
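`annotate()` returns only the winning label. To inspect per-label scores, `LightPipeline.fullAnnotate()` returns `Annotation` objects with a `metadata` dictionary. The sketch below assumes (check this against your Spark NLP version) that the classifier stores each candidate label's score in that metadata as a string, next to non-score keys such as `sentence`; `rank_label_scores` is a hypothetical helper and the metadata dictionary is an invented example, not output from the model above.

```python
from typing import Dict, List, Tuple


def rank_label_scores(metadata: Dict[str, str], labels: List[str]) -> List[Tuple[str, float]]:
    """Return (label, score) pairs sorted from most to least likely.

    Only keys listed in `labels` are treated as scores; other metadata
    entries (for example a sentence index) are ignored.
    """
    scores = [(label, float(metadata[label])) for label in labels if label in metadata]
    return sorted(scores, key=lambda item: item[1], reverse=True)


# Hypothetical metadata for one annotation (values are illustrative only).
example_metadata = {"sentence": "0", "mobile": "0.81", "technology": "0.12", "urgent": "0.07"}
candidate_labels = ["urgent", "mobile", "technology"]
print(rank_label_scores(example_metadata, candidate_labels))
# [('mobile', 0.81), ('technology', 0.12), ('urgent', 0.07)]
```

Filtering by the candidate label list, rather than by "everything except `sentence`", keeps the helper robust if the metadata carries other keys.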

### Multi-Label vs. Multi-Class

We can use the `activation` parameter to choose whether the result is multi-class (the probabilities across all labels sum to `1.0`) or multi-label (each label gets its own probability between `0.0` and `1.0`):

- multi-class: `softmax` (default)
- multi-label: `sigmoid`
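A small, framework-free sketch of the difference between the two activations, applied to made-up per-label scores. The 0.5 cut-off is an assumption (a `threshold` parameter appears in the parameter map earlier, but its description and default are cut off there), and the code illustrates the two activations rather than reproducing Spark NLP's internal logic.

```python
import math
from typing import Dict, List


def softmax(logits: List[float]) -> List[float]:
    """Multi-class: probabilities across all labels sum to 1.0."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(x: float) -> float:
    """Multi-label: each label gets an independent probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def multi_class(labels: List[str], logits: List[float]) -> Dict[str, float]:
    return dict(zip(labels, softmax(logits)))


def multi_label(labels: List[str], logits: List[float], threshold: float = 0.5) -> List[str]:
    # Assumption: a 0.5 cut-off; keep every label whose own probability clears it.
    return [label for label, x in zip(labels, logits) if sigmoid(x) >= threshold]


labels = ["space & cosmos", "scientific discovery", "politics"]
logits = [2.0, 1.5, -3.0]  # made-up scores, not model output

mc = multi_class(labels, logits)
print(max(mc, key=mc.get))          # space & cosmos
print(multi_label(labels, logits))  # ['space & cosmos', 'scientific discovery']
```

With softmax, raising one label's score necessarily lowers the others; with sigmoid, two related labels can both score high, which is how a single text can end up with more than one label.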

In [9]:
zero_shot_classifier\
    .setCandidateLabels(["space & cosmos", "scientific discovery", "microbiology", "robots", "archeology", "politics"])\
    .setActivation("sigmoid") # multi-label

pipeline = Pipeline(stages=[
    document_assembler,
    tokenizer,
    zero_shot_classifier
])

input_text3 = [
    ["""Learn about the presidential election process, including the Electoral College, caucuses and primaries, and the national conventions."""],
    ["""In a new book, Sean Carroll brings together physics and philosophy while advocating for "poetic naturalism." Ramin Skibba, Contributor. Space ..."""],
    ["""Who are you voting for in 2024?"""]]

# create a DataFrame in PySpark
inputDataset = spark.createDataFrame(input_text3, ["text"])
model = pipeline.fit(inputDataset)
predictionDF = model.transform(inputDataset)

predictionDF.select("document.result", "class.result").show(3, False)

+---------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------+
|result                                                                                                                                             |result                                |
+---------------------------------------------------------------------------------------------------------------------------------------------------+--------------------------------------+
|[Learn about the presidential election process, including the Electoral College, caucuses and primaries, and the national conventions.]            |[politics]                            |
|[In a new book, Sean Carroll brings together physics and philosophy while advocating for "poetic naturalism." Ramin Skibba, Contributor. Space ...]|[space & cosmos, scientific discovery]|
|[Who are you voting for in 2024?]                     

Let's look at the other zero-shot classification models.

## RoBERTa Zero-Shot Classification

This model is intended to be used for zero-shot text classification, especially in English. It is fine-tuned on NLI using the RoBERTa Base model.

`RoBertaForZeroShotClassification` uses a `ModelForSequenceClassification` trained on NLI (natural language inference) tasks. It is the equivalent of `RoBertaForSequenceClassification` models, but it doesn't require a hardcoded number of potential classes; they can be chosen at runtime. This usually makes it slower, but much more flexible.

We used `TFRobertaForSequenceClassification` to train this model and the `RoBertaForZeroShotClassification` annotator in Spark NLP for prediction at scale!

In [10]:
zero_shot_classifier = RoBertaForZeroShotClassification.pretrained("roberta_base_zero_shot_classifier_nli", "en")\
  .setInputCols(["document",'token'])\
  .setOutputCol("class")\
  .setCaseSensitive(True)\
  .setMaxSentenceLength(512)\
  .setCandidateLabels(["movie","mobile", "music", "travel", "sport", "computer"])

pipeline = Pipeline(stages=[
    document_assembler,
    tokenizer,
    zero_shot_classifier
])

zero_shot_roberta = pipeline.fit(spark.createDataFrame([[""]]).toDF("text"))

roberta_base_zero_shot_classifier_nli download started this may take some time.
Approximate size to download 444.8 MB
[OK!]


In [11]:
text = [["I have a problem with my iphone that needs to be resolved asap!!"],
        ["We need to harry up and get to the airport. We are going to miss our flight!"]]

# create a DataFrame in PySpark
inputDataset = spark.createDataFrame(text, ["text"])
predictionDF = zero_shot_roberta.transform(inputDataset)

In [12]:
predictionDF.select("document.result", "class.result").show(10, False)

+------------------------------------------------------------------------------+--------+
|result                                                                        |result  |
+------------------------------------------------------------------------------+--------+
|[I have a problem with my iphone that needs to be resolved asap!!]            |[mobile]|
|[We need to harry up and get to the airport. We are going to miss our flight!]|[travel]|
+------------------------------------------------------------------------------+--------+



## DistilBERT Zero-Shot Classification

This model is intended to be used for zero-shot text classification, especially in English. It is fine-tuned on MNLI using the DistilBERT Base Uncased model.

`DistilBertForZeroShotClassification` uses a `ModelForSequenceClassification` trained on NLI (natural language inference) tasks. It is the equivalent of `DistilBertForSequenceClassification` models, but it doesn't require a hardcoded number of potential classes; they can be chosen at runtime. This usually makes it slower, but much more flexible.

We used `TFDistilBertForSequenceClassification` to train this model and the `DistilBertForZeroShotClassification` annotator in Spark NLP for prediction at scale!

In [13]:
zero_shot_classifier = DistilBertForZeroShotClassification.pretrained("distilbert_base_zero_shot_classifier_uncased_mnli", "en")\
    .setInputCols(["document", "token"]) \
    .setOutputCol("class") \
    .setCandidateLabels(["urgent", "mobile", "travel", "movie", "music", "sport", "technology"])

pipeline = Pipeline(stages=[
    document_assembler,
    tokenizer,
    zero_shot_classifier
])

zero_shot_distilbert = pipeline.fit(spark.createDataFrame([[""]]).toDF("text"))

distilbert_base_zero_shot_classifier_uncased_mnli download started this may take some time.
Approximate size to download 238.1 MB
[OK!]


In [14]:
text = [["I have a problem with my iphone that needs to be resolved asap!!"],
        ["Last week I upgraded my iOS version and ever since then my phone has been overheating whenever I use your app."],
        ["I have a phone and I love it!"],
        ["I really want to visit Germany and I am planning to go there next year."],
        ["Have you watched the match yesterday? It was a great game!"],
        ["We need to harry up and get to the airport. We are going to miss our flight!"]]

# create a DataFrame in PySpark
inputDataset = spark.createDataFrame(text, ["text"])

zero_shot_distilbert.transform(inputDataset).select("class.result").show()

+------------+
|      result|
+------------+
|    [mobile]|
|[technology]|
|    [mobile]|
|    [travel]|
|     [sport]|
|    [urgent]|
+------------+



## BART Zero-Shot Classification

This model is intended to be used for zero-shot text classification, especially in English. It is fine-tuned on MNLI using the BART Large model.

`BartForZeroShotClassification` uses a `ModelForSequenceClassification` trained on MNLI tasks. It is the equivalent of `BartForSequenceClassification` models, but it doesn't require a hardcoded number of potential classes; they can be chosen at runtime. This usually makes it slower, but much more flexible.

We used `TFBartForSequenceClassification` to train this model and the `BartForZeroShotClassification` annotator in Spark NLP 🚀 for prediction at scale!
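An MNLI model produces three logits per (premise, hypothesis) pair: contradiction, neutral, and entailment. One common recipe for turning them into zero-shot scores, used by the Hugging Face zero-shot pipeline and shown here as an illustration rather than as Spark NLP's exact implementation, is: for a single label, softmax the entailment logits across all candidate labels; for multi-label, softmax each label's [contradiction, entailment] pair on its own and keep the entailment probability, dropping neutral. The logits below are made up, and the index order of the three outputs is an assumption (each model's configuration defines its own).

```python
import math
from typing import Dict, List

# Assumed index order of the three NLI logits; real models define it in their config.
CONTRADICTION, NEUTRAL, ENTAILMENT = 0, 1, 2


def _softmax(xs: List[float]) -> List[float]:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def single_label_scores(nli_logits: Dict[str, List[float]]) -> Dict[str, float]:
    """Softmax of the entailment logits across labels (scores sum to 1.0)."""
    labels = list(nli_logits)
    probs = _softmax([nli_logits[label][ENTAILMENT] for label in labels])
    return dict(zip(labels, probs))


def multi_label_scores(nli_logits: Dict[str, List[float]]) -> Dict[str, float]:
    """Per label: softmax over [contradiction, entailment], keep P(entailment)."""
    return {
        label: _softmax([logits[CONTRADICTION], logits[ENTAILMENT]])[1]
        for label, logits in nli_logits.items()
    }


# Hypothetical [contradiction, neutral, entailment] logits, one triple per label.
example = {
    "travel": [-2.0, 0.5, 3.0],
    "music": [2.5, 0.3, -1.5],
    "weather": [0.0, 1.0, 0.2],
}
single = single_label_scores(example)
multi = multi_label_scores(example)
print(max(single, key=single.get))  # travel
print({label: round(p, 3) for label, p in multi.items()})
```

In the multi-label variant the scores are independent, so they no longer sum to 1.0, and a label the model finds merely neutral (here `weather`) can still land near 0.5.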

In [15]:
document_assembler = DocumentAssembler() \
    .setInputCol('text') \
    .setOutputCol('document')

tokenizer = Tokenizer() \
    .setInputCols(['document']) \
    .setOutputCol('token')

zeroShotClassifier = BartForZeroShotClassification \
    .pretrained('bart_large_zero_shot_classifier_mnli', 'en') \
    .setInputCols(['token', 'document']) \
    .setOutputCol('class') \
    .setCaseSensitive(True) \
    .setMaxSentenceLength(512) \
    .setCandidateLabels(["urgent", "mobile", "travel", "movie", "music", "sport", "weather", "technology"])

pipeline = Pipeline(stages=[
    document_assembler,
    tokenizer,
    zeroShotClassifier
])

zero_shot_bart = pipeline.fit(spark.createDataFrame([[""]]).toDF("text"))

bart_large_zero_shot_classifier_mnli download started this may take some time.
Approximate size to download 445.4 MB
[OK!]


In [16]:
text = [["Last summer, I embarked on an unforgettable journey to explore the ancient ruins of Machu Picchu, surrounded by breathtaking landscapes and rich cultural history."]]

# create a DataFrame in PySpark
inputDataset = spark.createDataFrame(text, ["text"])
predictionDF = zero_shot_bart.transform(inputDataset)

In [17]:
predictionDF.select("document.result", "class.result").show(10, False)

+--------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------+
|result                                                                                                                                                              |result  |
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------+
|[Last summer, I embarked on an unforgettable journey to explore the ancient ruins of Machu Picchu, surrounded by breathtaking landscapes and rich cultural history.]|[travel]|
+--------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------+

