In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import os
# Order GPUs by PCI bus ID and expose only GPU 1 to this notebook
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Zero Shot Learning Using Natural Language Inference

In this notebook, we demonstrate **zero-shot** topic classification.  **Zero-Shot Learning (ZSL)** is the ability to solve a task without having received any training examples for that task.  The `ZeroShotClassifier` class in *ktrain* performs topic classification with no training examples.  The technique is based on **Natural Language Inference (NLI)**, as described in [this interesting blog post](https://joeddav.github.io/blog/2020/05/29/ZSL.html) by Joe Davison.
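
To make the NLI framing concrete, here is a minimal sketch of the idea rather than ktrain's actual implementation. The document acts as the premise, and each candidate topic is turned into a hypothesis sentence. The template string, the helper names, and the choice to apply a softmax over only the contradiction and entailment logits are assumptions made for illustration, and the logits passed in at the end are made-up numbers, not model output.

```python
import math

# Assumed hypothesis template, used here for illustration only.
TEMPLATE = "This text is about {}."

def build_nli_pairs(doc, topic_strings, template=TEMPLATE):
    """Pair the document (premise) with one hypothesis per candidate topic."""
    return [(doc, template.format(topic)) for topic in topic_strings]

def entailment_probability(contradiction_logit, entailment_logit):
    """Softmax over the contradiction and entailment logits; neutral is ignored."""
    m = max(contradiction_logit, entailment_logit)  # subtract max for stability
    e_c = math.exp(contradiction_logit - m)
    e_e = math.exp(entailment_logit - m)
    return e_e / (e_c + e_e)

doc = "What is your favorite sitcom of all time?"
pairs = build_nli_pairs(doc, ["politics", "television"])
for premise, hypothesis in pairs:
    print(hypothesis)

# Hypothetical logits, only to show how a topic probability could be derived.
print(entailment_probability(contradiction_logit=-2.0, entailment_logit=3.0))
```

Under this framing, each topic is scored by its own premise/hypothesis pair, so adding a topic requires no retraining, only one more hypothesis.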

## STEP 1: Set Up the Zero-Shot Classifier and Describe Topics

We first instantiate the zero-shot classifier and then describe the topic labels for our classifier as strings.

In [2]:
from ktrain import text 

In [3]:
zsl = text.ZeroShotClassifier()
topic_strings = ['politics', 'elections', 'sports', 'films', 'television']

## STEP 2: Predict

There is no training involved here, as we are using **zero-shot learning**.  We simply supply the document to be classified and the `topic_strings` defined earlier. The `predict` method uses Natural Language Inference (NLI) to infer the topic probabilities.

In [4]:
doc = 'I am extremely dissatisfied with the President and will definitely vote in 2020.'
zsl.predict(doc, topic_strings=topic_strings, include_labels=True)

[('politics', 0.9791899),
 ('elections', 0.98745817),
 ('sports', 0.0005765463),
 ('films', 0.0022924456),
 ('television', 0.0010546101)]

As you can see, the model correctly assigns the highest probabilities to `politics` and `elections`, since the supplied text pertains to both topics.

Let's try some other examples.

#### document about `television`

In [5]:
doc = 'What is your favorite sitcom of all time?'
zsl.predict(doc, topic_strings=topic_strings, include_labels=True)

[('politics', 0.00015667638),
 ('elections', 0.00032881147),
 ('sports', 0.00013884966),
 ('films', 0.075576425),
 ('television', 0.9813269)]

#### document about both `politics` and `television`

In [6]:
doc = """
President Donald Trump's senior adviser and son-in-law, Jared Kushner, praised 
the administration's response to the coronavirus pandemic as a \"great success story\" on Wednesday -- 
less than a day after the number of confirmed coronavirus cases in the United States topped 1 million. 
Kushner painted a rosy picture for \"Fox and Friends\" Wednesday morning, 
saying that \"the federal government rose to the challenge and 
this is a great success story and I think that that's really what needs to be told.\"
"""
zsl.predict(doc, topic_strings=topic_strings, include_labels=True)

[('politics', 0.8049428),
 ('elections', 0.01889327),
 ('sports', 0.0055048335),
 ('films', 0.05876928),
 ('television', 0.8776824)]

#### document about `sports`, `television`, and `films`

In [7]:
doc = "The Last Dance is a 2020 American basketball documentary miniseries co-produced by ESPN Films and Netflix."
zsl.predict(doc, topic_strings=topic_strings, include_labels=True)

[('politics', 0.0005349868),
 ('elections', 0.0007852868),
 ('sports', 0.98488265),
 ('films', 0.9576993),
 ('television', 0.94114333)]
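
In the outputs above, the scores do not sum to 1, and several topics can score highly for the same document. One way to turn such scores into a set of assigned topics is to apply a cutoff. The sketch below works on the `(label, probability)` pairs copied from the last output. The `select_topics` helper and the `0.5` threshold are hypothetical choices for illustration, not part of ktrain.

```python
def select_topics(predictions, threshold=0.5):
    """Keep every label whose probability meets the (assumed) threshold."""
    return [label for label, prob in predictions if prob >= threshold]

# Output of the last prediction above, copied as-is.
predictions = [('politics', 0.0005349868),
               ('elections', 0.0007852868),
               ('sports', 0.98488265),
               ('films', 0.9576993),
               ('television', 0.94114333)]

selected = select_topics(predictions)
ranked = sorted(predictions, key=lambda pair: pair[1], reverse=True)

print(selected)
print(ranked[0])
```

A suitable threshold depends on the application, so it is worth checking it against a handful of labeled documents before relying on it.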

## Prediction Time and Batch Size

The `predict` method of `ZeroShotClassifier` generates a separate NLI prediction for each topic in `topic_strings`, so prediction time grows with `len(topic_strings)`.  **You can speed up predictions by increasing the `batch_size`.**  The default `batch_size` is currently set conservatively at 8.  The two timings below use `batch_size=4` and `batch_size=64` for comparison.

#### Predicting 800 topics takes ~8 seconds on a TITAN V GPU using `batch_size=4`

In [8]:
%%time
doc = 'I am extremely dissatisfied with the President and will definitely vote in 2020.'
predictions = zsl.predict(doc, topic_strings=topic_strings*160, include_labels=True, batch_size=4)

CPU times: user 14.9 s, sys: 20.7 ms, total: 15 s
Wall time: 7.5 s


#### Predicting 800 topics takes less than 2 seconds on a TITAN V GPU using `batch_size=64`

In [20]:
%%time
doc = 'I am extremely dissatisfied with the President and will definitely vote in 2020.'
predictions = zsl.predict(doc, topic_strings=topic_strings*160, include_labels=True, batch_size=64)

CPU times: user 1.87 s, sys: 385 ms, total: 2.26 s
Wall time: 1.68 s
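
A rough way to see why the larger batch size helps is to count how many batches the 800 topic hypotheses are split into. The sketch below assumes simple fixed-size chunking, with one model call per chunk. The `chunk` helper is hypothetical, and ktrain's internal batching may differ.

```python
def chunk(items, batch_size):
    """Yield consecutive slices of at most batch_size items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

# Same 800 topic strings used in the timing cells above.
topics = ['politics', 'elections', 'sports', 'films', 'television'] * 160

n_batches_4 = sum(1 for _ in chunk(topics, 4))
n_batches_64 = sum(1 for _ in chunk(topics, 64))

print(len(topics), n_batches_4, n_batches_64)
```

Fewer, larger batches generally make better use of the GPU, at the cost of more memory per batch, so the largest usable `batch_size` depends on the hardware.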
