In [None]:
#!git clone https://github.com/huggingface/transformers.git
!pip install transformers
!pip install datasets

## Distilling Zero Shot Classification

This notebook demonstrates how to use a script that provides a way to improve the speed and memory performance of a zero-shot classifier by training a more efficient student model from the zero-shot teacher's predictions over an unlabeled dataset.

For a given sequence, the zero-shot classification pipeline requires each possible label to be fed through the large NLI model separately. This requirement slows results considerably, particularly for tasks with a large number of classes K.

We'll use the dbpedia14's entity classification dataset for this example.

In [1]:
from datasets import load_dataset
train, test = load_dataset('dbpedia_14', split=['train', 'test'])
train[0]

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2183.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1288.0, style=ProgressStyle(description…


Downloading and preparing dataset d_bpedia14/dbpedia_14 (download: 65.18 MiB, generated: 191.44 MiB, post-processed: Unknown size, total: 256.62 MiB) to /Users/arian/.cache/huggingface/datasets/d_bpedia14/dbpedia_14/2.0.0/a70413e39e7a716afd0e90c9e53cb053691f56f9ef5fe317bd07f2c368e8e897...


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Downloading', max=1.0, style=ProgressSt…




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Dataset d_bpedia14 downloaded and prepared to /Users/arian/.cache/huggingface/datasets/d_bpedia14/dbpedia_14/2.0.0/a70413e39e7a716afd0e90c9e53cb053691f56f9ef5fe317bd07f2c368e8e897. Subsequent calls will reuse this data.


{'content': ' Abbott of Farnham E D Abbott Limited was a British coachbuilding business based in Farnham Surrey trading under that name from 1929. A major part of their output was under sub-contract to motor vehicle manufacturers. Their business closed in 1972.',
 'label': 0,
 'title': 'E. D. Abbott Ltd'}

In [2]:
import re

original_labels = train.info.features["label"].names
zeroshot_labels = [re.sub(r"(\w)([A-Z])", r"\1 \2", l) for l in original_labels]


print(original_labels)
print()
print(zeroshot_labels)


['Company', 'EducationalInstitution', 'Artist', 'Athlete', 'OfficeHolder', 'MeanOfTransportation', 'Building', 'NaturalPlace', 'Village', 'Animal', 'Plant', 'Album', 'Film', 'WrittenWork']

['Company', 'Educational Institution', 'Artist', 'Athlete', 'Office Holder', 'Mean Of Transportation', 'Building', 'Natural Place', 'Village', 'Animal', 'Plant', 'Album', 'Film', 'Written Work']


In [3]:
train[1]

{'content': " Schwan-STABILO is a German maker of pens for writing colouring and cosmetics as well as markers and highlighters for office use. It is the world's largest manufacturer of highlighter pens Stabilo Boss.",
 'label': 0,
 'title': 'Schwan-Stabilo'}

In [10]:
import pandas as pd

df_train = pd.DataFrame(train)

In [40]:
df_train.groupby(["label"]).count()

Unnamed: 0_level_0,content,title
label,Unnamed: 1_level_1,Unnamed: 2_level_1
0,40000,40000
1,40000,40000
2,40000,40000
3,40000,40000
4,40000,40000
5,40000,40000
6,40000,40000
7,40000,40000
8,40000,40000
9,40000,40000


### 🤗 Zero-shot classification pipeline

The [zero-shot classification pipeline](https://huggingface.co/transformers/main_classes/pipelines.html#transformers.ZeroShotClassificationPipeline) is a tool withing 🤗 Transformers that can be used to classify text sequences out of the box, provided only a list of possible class names:

In [4]:
from transformers import pipeline


zero_shot_classifier = pipeline('zero-shot-classification', model="roberta-large-mnli", device=0)

Some weights of the model checkpoint at roberta-large-mnli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
sequence = "Schwan-STABILO is a German maker of pens for writing colouring and cosmetics as well as markers and highlighters for office use. It is the world's largest manufacturer of highlighter pens Stabilo Boss."
class_names = zeroshot_labels
zero_shot_classifier(sequence, class_names)

{'sequence': "Schwan-STABILO is a German maker of pens for writing colouring and cosmetics as well as markers and highlighters for office use. It is the world's largest manufacturer of highlighter pens Stabilo Boss.",
 'labels': ['Company',
  'Building',
  'Village',
  'Mean Of Transportation',
  'Album',
  'Artist',
  'Film',
  'Plant',
  'Written Work',
  'Office Holder',
  'Educational Institution',
  'Natural Place',
  'Animal',
  'Athlete'],
 'scores': [0.26675543189048767,
  0.16022440791130066,
  0.08964039385318756,
  0.06799466907978058,
  0.06777624785900116,
  0.05430350452661514,
  0.050630487501621246,
  0.044264644384384155,
  0.04133811220526695,
  0.040396228432655334,
  0.034614574164152145,
  0.0286851953715086,
  0.028418727219104767,
  0.02495739236474037]}

This method serves as a convenient out-of-the-box classifier. Unfortunately, the method is by necessity somewhat slow. This is partially due to the large underlying model being used, but more important is the fact that for this method to work, every possible sequence / class name pair must be fed through the model together. So in order to classify `N` sequences into `K` classes, the model has to be called `N*K` times (whereas a typical classifier would only be called `N` times). This makes the method comparatively slow, especially for settings with a large number of classes.

In [6]:
# classify 1600 examples with K=4 classes
%time
for _ in range(100):
    zero_shot_classifier([sequence] * 16, class_names)

CPU times: user 7 µs, sys: 4 µs, total: 11 µs
Wall time: 24.8 µs


In [6]:
# # classify 1600 examples with K=8 classes
# %time
# expanded_class_names = class_names + ["politics", "health", "food", "weather"]
# for _ in range(100):
#     zero_shot_classifier([sequence] * 16, expanded_class_names)

CPU times: user 7 µs, sys: 1 µs, total: 8 µs
Wall time: 16.2 µs


As we can see, increasing the number of classes from `K=4` to `K=8` approximately doubles the inference time. This classification method is extremely useful, but ideally we'd like to speed up inference.

### Distilling a more efficient student model

The best way to speed up inference is to **train a more efficient student model on the zero-shot classifier's predictions** over an unlabeled dataset. This can be done with the [`distill_classifier.py`](https://github.com/huggingface/transformers/blob/master/examples/research_projects/zero-shot-distillation/distill_classifier.py) script provided in the `transformers` repo.

Given (1) an unlabeled corpus and (2) a set of `K` class names, this script allows a user to train a standard classification head with `K` output dimensions. The script generates a softmax distribution for the provided data & class names, and a student classifier is then fine-tuned on these proxy labels. The resulting student model can be used for classifying novel text instances into the previously specified `K` classes with an order-of-magnitude boost in inference speed plus decreased memory usage.

Let's see how to do this with the [AG's News](https://huggingface.co/datasets/ag_news) topic classification dataset. The first thing we need is an unlabeled dataset (in reality AG's News is annotated of course, but we'll pretend and ignore the annotations for the sake of example). Let's put the sequences from the train set into a `txt` file:

In [7]:
#!pip install wandb
#!wandb init

[32m[1mLet's setup this directory for W&B![0m
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter: 
Aborted!


In [None]:
# setup WandB and connect to team workspace
import wandb
wandb.init(project="weak-supervision-case-study", entity="")

In [48]:
!mkdir dbmedia14
with open("dbmedia14/train_unlabeled.txt", 'w') as f:
    for seq in train["content"]:
        f.write(seq + '\n')

In [3]:
!wc -l dbmedia14/train_unlabeled.txt

560000 dbmedia14/train_unlabeled.txt


In [4]:
!wc -l dbmedia14/class_names.txt

14 dbmedia14/class_names.txt


The other thing the script needs is the names of the classes. We'll put these into their own newline-delimitted `txt` as well:

In [49]:
with open("dbmedia14/class_names.txt", 'w') as f:
    for label in class_names:
        f.write(label + '\n')

In [9]:
!pwd

/home/dock/workspace/transformers/examples/research_projects/zero-shot-distillation


Now we can run the script. First the zero-shot model will loop through the data and generate (soft) proxy-labels, and then a student `DistilBert` model will be fine-tuned on these predictions. The student will then be saved in `./distilbert-base-uncased-agnews-student`. See the [script readme](https://github.com/huggingface/transformers/blob/master/examples/research_projects/zero-shot-distillation/README.md) for more information about the available script arguments.

On a single P100, this will take about ~2 hours with the full training set of 130K examples. On a V100 with mixed precision (just pass `--fp16`), it will take ~30 minutes.

In [2]:
import gc
import torch
gc.collect()

torch.cuda.empty_cache()

# TODO

1. It fails to use the entire dataset. Need to run with subsets 1k, 10k of the training set. I will split the annotation script and the training script in too steps so it won't happen.  

2. It takes about 8 hours to annotate the labels using zero shot distillation. The resulting training set with the labels fail to run the training phase because of memory I think. Need to debug that.

In [1]:
!python distill_classifier.py \
--data_file ./dbmedia14/train_unlabeled_10k.txt \
--class_names_file ./dbmedia14/class_names.txt \
--hypothesis_template "This text is about {}." \
--student_name_or_path distilbert-base-uncased \
--output_dir ./distilbert-base-uncased-dbmedia14-student \
--fp16 True

Traceback (most recent call last):
  File "distill_classifier.py", line 8, in <module>
    from datasets import Dataset
ModuleNotFoundError: No module named 'datasets'


### Using the student model

The resulting model can now be loaded and used like any other pre-trained model:

(you can also use `"joeddav/distilbert-base-uncased-agnews-student"` to download this model from the hub if you want to try it without running whole script above)

In [None]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("./distilbert-base-uncased-dbmedia14-student")
model = AutoModelForSequenceClassification.from_pretrained("./distilbert-base-uncased-dbmedia14-student")
model.config

and even used trivially with a `TextClassificationPipeline`:

In [None]:
from transformers import TextClassificationPipeline
distilled_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer, return_all_scores=True, device=0)
print(sequence)
distilled_classifier(sequence)

Let's compare the speed & accuracy of the two methods.

Original zero-shot model:

In [16]:
import numpy as np
from time import time
from tqdm.auto import tqdm

start = time()
batch_size = 32
hypothesis_template = "This text is about {}."
preds = []
for i in tqdm(range(0, len(test), batch_size)):
    examples = test[i:i+batch_size]['text']
    outputs = zero_shot_classifier(examples, class_names, hypothesis_template=hypothesis_template)
    preds += [class_names.index(o['labels'][0]) for o in outputs]
accuracy = np.mean(np.array(preds) == np.array(test['label']))
print(f"Teacher model accuracy: {accuracy*100:0.2f}%")
print(f"Runtime: {time() - start : 0.2f} seconds")

HBox(children=(FloatProgress(value=0.0, max=238.0), HTML(value='')))


Teacher model accuracy: 69.33%
Runtime:  221.37 seconds


Distilled student model:

In [17]:
start = time()
batch_size = 128 # larger batch size bc distilled model is more memory efficient
distilled_classifier.return_all_scores = False
preds = []
groundtruth = []
for i in tqdm(range(0, len(test), batch_size)):
    examples = test[i:i+batch_size]['text']
    outputs = distilled_classifier(examples)
    preds += [class_names.index(o['label']) for o in outputs]
    groundtruth += test[i:i+batch_size]['label']
accuracy = np.mean(np.array(preds) == np.array(test['label']))
print(f"Distilled model accuracy: {accuracy*100:0.2f}%")
print(f"Runtime: {time() - start : 0.2f} seconds")

HBox(children=(FloatProgress(value=0.0, max=60.0), HTML(value='')))


Distilled model accuracy: 70.79%
Runtime:  11.95 seconds


Error: IPyKernel not installed into interpreter Python 3.9.1 64-bit:/usr/local/bin/python3

In [None]:
from sklearn.metrics import precision_recall_fscore_support

scores = {
    "macro":{}, "micro":{}, "weighted":{}
}

for metric in scores:
    scores[metric]["precision"], scores[metric]["recall"], scores[metric]["f1"], _ = \
    precision_recall_fscore_support(groundtruth, preds, average=metric)

from pprint import pprint
pprint(scores)

In [None]:
from sklearn import metrics
confusion_matrix = metrics.confusion_matrix(groundtruth, preds)
confusion_matrix

In [None]:
from sklearn.metrics import classification_report

print(classification_report(groundtruth, preds, target_names=class_names))

As you can see, **the disitlled model gets similar accuracy on a held-out test set while running in 1/20th the time**. 

Lastly, you can share the distilled model with the community by [uploading it to the 🤗 Hub](https://huggingface.co/transformers/model_sharing.html). We've uploaded the distilled model from this notebook at [joeddav/distilbert-base-uncased-agnews-student](https://huggingface.co/joeddav/distilbert-base-uncased-agnews-student).