# Read this!

This notebook is mostly a copy of `07_ZeroShotTopicClassificationCORD19.ipynb`.
The main difference is that this notebook implements the zero-shot NLI method with a smaller pre-trained BERT model, whereas the original notebook uses a large pre-trained BART model.

# Zero Shot Topic Classification on CORD-19

## Introduction

In this notebook we'll build a Zero Shot Topic Classifier on the COVID-19 Open Research Dataset (CORD-19, Wang et al., 2020).
Essentially, we aim to build a web application capable of receiving natural language questions, such as "what do we know about vaccines and therapeutics?", and then displaying the most relevant research literature regarding the specific question.
This dataset has received wide attention from the data mining and natural language processing communities, which are developing tools to help health workers stay up to date with the latest and most relevant research on the current pandemic.

Recent advances in NLP, such as OpenAI's GPT-3 (Brown et al., 2020), have shown that large language models can achieve competitive performance on downstream tasks with less task-specific data than smaller models require.
However, GPT-3 is currently difficult to use in real-world applications because of its size of ~175 billion parameters.

Recent experiments at HuggingFace (Davison, 2020) explored the potential of using Sentence-BERT (Reimers and Gurevych, 2020) to separately embed sentences and never-seen-before topic labels.
The candidate topics for a sentence are then ranked by the cosine distance between the sentence vector and each label vector (Veeranna et al., 2016), with promising results.
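
The ranking step can be sketched in a few lines. This is a minimal illustration, not the code used in the cited experiments: the 3-dimensional vectors below are hypothetical stand-ins for real Sentence-BERT embeddings, and `cosine_similarity` and `rank_labels` are names introduced only for this example. Ranking by highest cosine similarity is equivalent to ranking by lowest cosine distance.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def rank_labels(sentence_vec, label_vecs):
    """Return (label, similarity) pairs, most similar label first."""
    scores = {
        label: cosine_similarity(sentence_vec, vec)
        for label, vec in label_vecs.items()
    }
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)


# Hypothetical toy embeddings; a real setup would embed the sentence and
# the label names with the same sentence encoder.
sentence_vec = [0.9, 0.1, 0.0]
label_vecs = {
    "vaccines": [1.0, 0.0, 0.0],
    "economics": [0.0, 1.0, 0.0],
    "transmission": [0.5, 0.5, 0.0],
}

ranking = rank_labels(sentence_vec, label_vecs)
for label, score in ranking:
    print(f"{label}: {score:.3f}")
```

Because the labels are embedded with the same encoder as the sentences, no labeled training data is needed for topics that were never seen before.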

In another experiment, they use a pre-trained natural language inference (NLI) sequence-pair classifier as an out-of-the-box zero-shot text classifier, as proposed by Yin et al. (2019).
Using a pre-trained BART model (Lewis et al., 2019) fine-tuned on the Multi-Genre NLI corpus, they obtained an F1 score of 53.7 on the Yahoo Answers dataset.
The dataset has 10 classes, and the reported state of the art for supervised models is an accuracy of 77.62.

In [None]:
%env CUDA_VISIBLE_DEVICES=6

env: CUDA_VISIBLE_DEVICES=6


In [None]:
# default_exp zero_shot_nli_small

In [None]:
# all_flag

## Natural Language Inference (NLI) method

In this approach we use a pruned BERT classifier fine-tuned on the Multi-Genre NLI (MultiNLI, Williams et al., 2018) corpus as the base model, in place of the BART classifier (Lewis et al., 2019) used in the original notebook.

Given research interests expressed in natural language, we pose the problem of recovering relevant research from the CORD-19 dataset (Wang et al., 2020) as a Zero Shot Topic Classification task (Yin et al., 2019).
Leveraging the Natural Language Inference task framework, we assess each paper's relevance by feeding the model the paper's title and abstract as the premise and a research interest as the hypothesis.

Finally, we use the model's entailment inference values as proxy relevance scores for each paper.
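
As a rough sketch of this scoring scheme, the snippet below builds a premise/hypothesis pair and converts a pair of logits into a 0-100 relevance score with a two-way softmax over contradiction and entailment. The helper names (`make_pair`, `entailment_score`) and the logit values are hypothetical and only illustrate the idea; real logits would come from the NLI model.

```python
import math


def make_pair(title, abstract, topic):
    """Premise is the title plus abstract; hypothesis states the topic."""
    premise = f"{title}. {abstract}"
    hypothesis = f"This paper is about {topic}."
    return premise, hypothesis


def entailment_score(contradiction_logit, entailment_logit):
    """Two-way softmax over (contradiction, entailment), scaled to 0-100."""
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(contradiction_logit, entailment_logit)
    exp_c = math.exp(contradiction_logit - m)
    exp_e = math.exp(entailment_logit - m)
    return 100.0 * exp_e / (exp_c + exp_e)


premise, hypothesis = make_pair("A toy title", "A toy abstract",
                                "vaccines and therapeutics")

# Hypothetical (contradiction, entailment) logits for two papers.
toy_logits = {"paper_a": (-2.0, 2.0), "paper_b": (1.0, -1.0)}
scores = {
    paper_id: entailment_score(c, e)
    for paper_id, (c, e) in toy_logits.items()
}
print(hypothesis)
print(scores)
```

Dropping the neutral class and renormalizing over the remaining two logits yields a score that can be read as "how much more entailed than contradicted" the hypothesis is for a given paper.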

In [None]:
# export

import math
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def get_nli_model(name="facebook/bart-large-mnli"):
    model = AutoModelForSequenceClassification.from_pretrained(name)
    tokenizer = AutoTokenizer.from_pretrained(name)
    return model, tokenizer

In [None]:
from risotto.artifacts import load_papers_artifact

try:
    papers = load_papers_artifact()
    model, tokenizer = get_nli_model(name="huggingface/prunebert-base-uncased-6-finepruned-w-distil-mnli")
except FileNotFoundError:
    print('Data artifacts not ready.')

In [None]:
# export

import numpy as np
import pandas as pd  # needed for pd.Series and pd.read_hdf below
from fastprogress.fastprogress import progress_bar


def build_tokenized_papers_artifact(papers,
                                    tokenizer,
                                    should_dump=True,
                                    dump_path=None,
                                    batch_size=128):
    num_batches = math.ceil(len(papers) / batch_size)
    tokenized_series = pd.Series([], dtype=object, name="tokenized_papers")

    for batch_idx in progress_bar(range(num_batches)):
        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size
        papers_batch = papers.iloc[start_idx:end_idx]

        title_abstract = (papers_batch.title + ". " +
                          papers_batch.abstract).fillna(" ").values.tolist()
        tokenized_batch = tokenizer.batch_encode_plus(
            title_abstract, max_length=tokenizer.model_max_length)

        for i, (paper_idx, _) in enumerate(papers_batch.iterrows()):
            tokenized_series.at[paper_idx] = tokenized_batch["input_ids"][i]

    if should_dump:
        tokenized_series.to_hdf(dump_path, key="tokenized_papers")

    return tokenized_series


def load_tokenized_papers_artifact(artifacts_path):
    return pd.read_hdf(artifacts_path, key="tokenized_papers")

In [None]:
# Build tokenized papers
tokenized_papers = build_tokenized_papers_artifact(
    papers=papers,
    tokenizer=tokenizer,
    dump_path="artifacts/nli_bert_artifacts.hdf",
)
tokenized_papers.head()


ug7v899j    [101, 6612, 2838, 1997, 3226, 1011, 10003, 202...
02tnwd4m    [101, 9152, 12412, 15772, 1024, 1037, 4013, 10...
ejv2xln0    [101, 14175, 18908, 4630, 5250, 1011, 1040, 19...
2b73a28n    [101, 2535, 1997, 2203, 14573, 18809, 1011, 10...
9785vg6d    [101, 4962, 3670, 1999, 4958, 8939, 24587, 444...
Name: tokenized_papers, dtype: object

In [None]:
# Load tokenized papers
tokenized_papers = load_tokenized_papers_artifact(
    "artifacts/nli_bert_artifacts.hdf")
tokenized_papers.head()

ug7v899j    [0, 20868, 1575, 9, 2040, 12, 32012, 1308, 438...
02tnwd4m    [0, 19272, 4063, 30629, 35, 10, 1759, 12, 3382...
ejv2xln0    [0, 6544, 24905, 927, 8276, 12, 495, 8, 34049,...
2b73a28n    [0, 21888, 9, 253, 15244, 2614, 12, 134, 11, 1...
9785vg6d    [0, 13120, 8151, 11, 22201, 44828, 4590, 11, 1...
Name: tokenized_papers, dtype: object

In [None]:
# export

import torch


def build_entailments_artifact(tokenized_papers,
                               query_tokenized,
                               batch_size=64,
                               device="cuda",
                               should_dump=True,
                               dump_path=None):
    query_encoded = [*query_tokenized[1:]]

    model.eval()
    model.to(device)

    num_batches = math.ceil(len(tokenized_papers) / batch_size)
    entail_probs = []

    for batch_idx in progress_bar(range(num_batches)):
        start_idx = batch_idx * batch_size
        end_idx = start_idx + batch_size
        tokenized_papers_batch = tokenized_papers.iloc[
            start_idx:end_idx].tolist()
        max_length = float("-inf")
        for i in range(len(tokenized_papers_batch)):
            tokenized_paper = tokenized_papers_batch[i]
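            # Ugly hack: drop the paper's last token (its closing special
            # token), truncate the paper so the query still fits within
            # model_max_length, then append the query tokens (the query's
            # leading special token was already removed above).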
            tokenized_papers_batch[i] = [
                *tokenized_paper[:-1][:(tokenizer.model_max_length -
                                        len(query_encoded))], *query_encoded
            ]
            tokenized_paper = tokenized_papers_batch[i]
            if len(tokenized_paper) > max_length:
                max_length = len(tokenized_paper)
        masks = []
        token_type_ids = []
        for tokenized_paper in tokenized_papers_batch:
            paper_length = len(tokenized_paper)
            delta = max_length - paper_length

            paper_type_ids = [
                0 for _ in range(paper_length - len(query_encoded))
            ]
            paper_type_ids += [1 for _ in range(len(query_encoded))]
            paper_type_ids += [0 for _ in range(delta)]
            token_type_ids.append(paper_type_ids)

            tokenized_paper += [0 for _ in range(delta)]

            mask = [1 for _ in range(paper_length)]
            mask += [0 for _ in range(delta)]
            masks.append(mask)

        input_ids = torch.tensor(tokenized_papers_batch).to(device)
        token_type_ids = torch.tensor(token_type_ids).to(device)
        attention_mask = torch.tensor(masks).to(device)

        with torch.no_grad():
            outputs = model(input_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask)[0]

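        # Keep only two of the NLI logits and renormalize over them.
        # This assumes the checkpoint's label order puts contradiction at
        # index 0 and entailment at index 1; check model.config.id2label.
        entail_contradiction_logits = outputs[:, [0, 1]]
        probs = entail_contradiction_logits.softmax(dim=1)
        # Entailment probability scaled to a 0-100 relevance score.
        entail_probs += (probs[:, 1] * 100).tolist()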

    entail_series = pd.Series(entail_probs,
                              index=tokenized_papers.index,
                              name="entailments")

    if should_dump:
        entail_series.to_hdf(dump_path, key="entailments")

    return entail_series


def load_entailments_artifact(artifacts_path):
    return pd.read_hdf(artifacts_path, key="entailments")
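
The input construction inside `build_entailments_artifact` is easier to follow on a toy example. The sketch below mirrors the same steps in plain Python: concatenating a paper with the query, assigning segment ids, then padding and masking a batch. The token ids are made up for illustration, and `build_pair_inputs` and `pad_batch` are hypothetical helper names that do not exist in the notebook.

```python
def build_pair_inputs(paper_ids, query_ids, max_length):
    """Join a tokenized paper and a tokenized query into one sequence."""
    query_tail = query_ids[1:]             # drop the query's leading token
    budget = max_length - len(query_tail)  # room left for the paper
    paper_head = paper_ids[:-1][:budget]   # drop last token, then truncate
    input_ids = paper_head + query_tail
    # Segment 0 marks paper tokens, segment 1 marks query tokens.
    token_type_ids = [0] * len(paper_head) + [1] * len(query_tail)
    return input_ids, token_type_ids


def pad_batch(rows, pad_value=0):
    """Right-pad rows to the longest one and build attention masks."""
    longest = max(len(row) for row in rows)
    padded = [row + [pad_value] * (longest - len(row)) for row in rows]
    masks = [[1] * len(row) + [0] * (longest - len(row)) for row in rows]
    return padded, masks


# Made-up ids: 101 and 102 stand in for the opening and closing tokens.
long_paper = [101, 7, 8, 9, 102]
short_paper = [101, 7, 102]
query = [101, 5, 6, 102]

ids_a, types_a = build_pair_inputs(long_paper, query, max_length=6)
ids_b, types_b = build_pair_inputs(short_paper, query, max_length=6)
batch_ids, batch_masks = pad_batch([ids_a, ids_b])
print(batch_ids)
print(batch_masks)
```

Truncating the paper rather than the query keeps the full hypothesis in every input, at the cost of cutting the end of long abstracts.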

In [None]:
query_tokenized = tokenizer.encode(
    "This paper is about vaccines and therapeutics.")
build_entailments_artifact(batch_size=256,
                           tokenized_papers=tokenized_papers,
                           query_tokenized=query_tokenized,
                           dump_path="artifacts/nli_bert_artifacts.hdf")

ug7v899j    34.904873
02tnwd4m     3.709159
ejv2xln0    83.782410
2b73a28n     0.407605
9785vg6d    83.543762
              ...    
2upc2spn    72.468544
48kealmj    21.194389
7goz1agp    90.725761
twp49jg3    78.526169
wtoj53xy    10.045611
Name: entailments, Length: 77304, dtype: float64
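
To surface the most relevant papers for a query, the entailment scores can be sorted in descending order. The sketch below does this with the ten scores visible in the output above, using a plain dictionary instead of the full pandas series; `top_k_papers` is a hypothetical helper introduced for this example.

```python
import heapq

# The ten (paper id, entailment score) pairs shown in the output above.
entailments = {
    "ug7v899j": 34.904873,
    "02tnwd4m": 3.709159,
    "ejv2xln0": 83.782410,
    "2b73a28n": 0.407605,
    "9785vg6d": 83.543762,
    "2upc2spn": 72.468544,
    "48kealmj": 21.194389,
    "7goz1agp": 90.725761,
    "twp49jg3": 78.526169,
    "wtoj53xy": 10.045611,
}


def top_k_papers(scores, k=3):
    """Return the k (paper_id, score) pairs with the highest scores."""
    return heapq.nlargest(k, scores.items(), key=lambda item: item[1])


top_papers = top_k_papers(entailments, k=3)
for paper_id, score in top_papers:
    print(f"{paper_id}: {score:.2f}")
```

With the pandas series returned by `build_entailments_artifact`, sorting with `sort_values(ascending=False)` and taking `head(k)` gives the same ordering.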

## References

Brown, T. B., Mann, B., Ryder, N., Subbiah, M., Kaplan, J., Dhariwal, P., Neelakantan, A., Shyam, P., Sastry, G., Askell, A., Agarwal, S., Herbert-Voss, A., Krueger, G., Henighan, T., Child, R., Ramesh, A., Ziegler, D. M., Wu, J., Winter, C., … Amodei, D. (2020). Language Models are Few-Shot Learners. https://arxiv.org/abs/2005.14165

Davison, J. (2020). Zero-Shot Learning in Modern NLP. https://joeddav.github.io/blog/2020/05/29/ZSL.html

Lewis, M., Liu, Y., Goyal, N., Ghazvininejad, M., Mohamed, A., Levy, O., Stoyanov, V., & Zettlemoyer, L. (2019). BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension. http://arxiv.org/abs/1910.13461

Reimers, N., & Gurevych, I. (2020). Sentence-BERT: Sentence embeddings using siamese BERT-networks. EMNLP-IJCNLP 2019 - 2019 Conference on Empirical Methods in Natural Language Processing and 9th International Joint Conference on Natural Language Processing, Proceedings of the Conference, 3982–3992. https://doi.org/10.18653/v1/d19-1410

Veeranna, S. P., Nam, J., Mencía, E. L., & Fürnkranz, J. (2016). Using semantic similarity for multi-label zero-shot classification of text documents. ESANN 2016 - 24th European Symposium on Artificial Neural Networks, April, 423–428.

Wang, L. L., Lo, K., Chandrasekhar, Y., Reas, R., Yang, J., Eide, D., Funk, K., Kinney, R., Liu, Z., Merrill, W., Mooney, P., Murdick, D., Rishi, D., Sheehan, J., Shen, Z., Stilson, B., Wade, A. D., Wang, K., Wilhelm, C., … Kohlmeier, S. (2020). CORD-19: The Covid-19 Open Research Dataset. https://arxiv.org/abs/2004.10706

Williams, A., Nangia, N., & Bowman, S. R. (2018). A Broad-Coverage Challenge Corpus for Sentence Understanding through Inference. Proceedings of the 2018 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long Papers), 1112–1122. http://aclweb.org/anthology/N18-1101

Yin, W., Hay, J., & Roth, D. (2019). Benchmarking zero-shot text classification: Datasets, evaluation and entailment approach. EMNLP-IJCNLP 2019 - 2019 Conference on Empirical Methods in Natural Language Processing and 9th International Joint Conference on Natural Language Processing, Proceedings of the Conference, 3914–3923. https://doi.org/10.18653/v1/d19-1404

---