# Zero Shot Text Classification Model
**Model Description** - Bart with a classification head trained on MNLI.

Sequences are posed as NLI premises and topic labels are turned into premises, i.e. business -> This text is about business.

In [1]:
import eland as ed
from eland.conftest import *
import pandas as pd
import seaborn as sns
import preprocessor as prep
import matplotlib.pyplot as plt
import string
from tqdm.auto import tqdm
from transformers import BartForSequenceClassification, BartTokenizer

pd.options.display.max_columns = None
pd.options.display.max_rows = None

PyTorch version 1.6.0 available.
TensorFlow version 2.3.0 available.


## Importing the Data from Elasticsearch

In [2]:
ed_df = ed.DataFrame('localhost', 'twitter', columns=['full_text', 'user_id', 'verified', 'name', 'location', 'entities.hashtags.text', 'entities.user_mentions.name'])

# defining the full-text query we need: Retrieving records for full_text_processed with the condition is_retweet=False and is_quote_status=False
query_unique = {
    "bool": {
        "must": {
            "term":{"is_retweet":"false"},
        },
        "filter": {
            "term":{"is_quote_status":"false"},
            "term":{"lang.keyword":"en"}
        },
    }
}
# using full-text search capabilities with Eland:
df_ed = ed_df.es_query(query_unique)
df_tweets = df_ed.to_pandas()

In [3]:
df_tweets.shape

(75523, 7)

In [4]:
df_tweets.head()

Unnamed: 0,full_text,user_id,verified,name,location,entities.hashtags.text,entities.user_mentions.name
1264253979002843136,The effect of Amphan in South 24 parganas http...,1256290967344168960,False,Avishek Pradhan,West Bengal,,
1264253893632016384,"Dukkhor bishoye, onek khoti hoyeche. Hoping fo...",1249801865518145536,False,Aryan M.,Kolkata | Delhi,,
1264253882580045824,"Economic is in prblm, covid19 is arising it's ...",868107356,False,tk.sinha,"Kolkata, India",,
1264253658763612160,@DrSJaishankar @MEAIndia How many days more In...,130462260,False,Mohammad Khalilullah,,,"[Dr. S. Jaishankar, Anurag Srivastava]"
1264253569525592064,@HCI_London @MEAIndia How many days more India...,130462260,False,Mohammad Khalilullah,,,"[India in the UK, Anurag Srivastava]"


## Basic Tweet Preprocessing
- Remove URLs and reserved words (RTs)
- Remove # and @ symbols
- Remove tweets less than 4 tokens in length


In [5]:
## Set options for the tweet-preprocessor
prep.set_options(prep.OPT.URL, prep.OPT.RESERVED, prep.OPT.EMOJI, prep.OPT.SMILEY)

## Clean text and remove #,@ symbols
def clean_tweet(text):
    text = prep.clean(text)
    table = str.maketrans('','','#@')
    return text.translate(table)

In [6]:
df_tweets['full_text'] = df_tweets['full_text'].apply(lambda x: clean_tweet(x))

In [7]:
df_tweets['length'] = df_tweets['full_text'].apply(lambda x: len([w for w in x.split()]))
df_tweets = df_tweets[df_tweets['length']>4]

In [8]:
df_tweets.head()

Unnamed: 0,full_text,user_id,verified,name,location,entities.hashtags.text,entities.user_mentions.name,length
1264253979002843136,The effect of Amphan in South 24 parganas,1256290967344168960,False,Avishek Pradhan,West Bengal,,,8
1264253893632016384,"Dukkhor bishoye, onek khoti hoyeche. Hoping fo...",1249801865518145536,False,Aryan M.,Kolkata | Delhi,,,37
1264253882580045824,"Economic is in prblm, covid19 is arising it's ...",868107356,False,tk.sinha,"Kolkata, India",,,22
1264253658763612160,DrSJaishankar MEAIndia How many days more Indi...,130462260,False,Mohammad Khalilullah,,,"[Dr. S. Jaishankar, Anurag Srivastava]",48
1264253569525592064,HCI_London MEAIndia How many days more Indian ...,130462260,False,Mohammad Khalilullah,,,"[India in the UK, Anurag Srivastava]",48


## Initialising the Model

In [9]:
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')

loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json from cache at /home/ubuntu/.cache/torch/transformers/1ae1f5b6e2b22b25ccc04c000bb79ca847aa226d0761536b011cf7e5868f0655.ef00af9e673c7160b4d41cfda1f48c5f4cba57d5142754525572a846a1ab1b9b
loading file https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt from cache at /home/ubuntu/.cache/torch/transformers/f8f83199a6270d582d6245dc100e99c4155de81c9745c6248077018fe01abcfb.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda
loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json from cache at /home/ubuntu/.cache/torch/transformers/a35b79dc26c2f371a0e19eae44d91c0a0281a5db09044517d2675703791ee3c5.746d7ef19ade685cd3ee03f131a96fab513947c26179546289ddf02a6ac683ce
Model config BartConfig {
  "_num_labels": 3,
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_

## Example on how to use the Model
The sequence (text) is posed as a NLI Premise and the label as a hypothesis

In [10]:
# pose sequence as a NLI premise and label (politics) as a hypothesis
premise = 'Who are you voting for in 2020?'
hypothesis = 'This text is about politics.'

# run through model pre-trained on MNLI
input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')
logits = model(input_ids)[0]

# we throw away "neutral" (dim 1) and take the probability of
# "entailment" (2) as the probability of the label being true 
entail_contradiction_logits = logits[:,[0,2]]
probs = entail_contradiction_logits.softmax(dim=1)
true_prob = probs[:,1].item() * 100
print(f'Probability that the label is true: {true_prob:0.2f}%')

Probability that the label is true: 98.08%


## Defining the Labels and Threshold for classification

In [11]:
TERMS = ['resource availability', 'volunteers', 'power supply', 'relief measures', 'food supply', 'infrastructure', 'medical assistance', 'rescue', 'shelter', 'utilities', 'water supply', 'evacuation', 'government', 'crime violence', 'mobile network', 'sympathy', 'news updates', 'internet', 'grievance']

HYPOTHESES = ['This text is about '+x for x in TERMS]
THRESHOLD = 0.5

In [12]:
'''Method to get the labels for a tweet based on threshold specified'''
def get_labels(premise, threshold=THRESHOLD):
    topics = []
    for idx, hypothesis in enumerate(HYPOTHESES):
        # run through model pre-trained on MNLI
        input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')
        logits = model(input_ids)[0]

        # we throw away "neutral" (dim 1) and take the probability of
        # "entailment" (2) as the probability of the label being true 
        entail_contradiction_logits = logits[:,[0,2]]
        probs = entail_contradiction_logits.softmax(dim=1)
        true_prob = probs[:,1].item() * 100

        if true_prob>=threshold:
            topics.append((TERMS[idx], np.round(true_prob,2)))
    
    return topics

In [None]:
tqdm.pandas()
df_tweets['labels'] = df_tweets['full_text'].progress_apply(lambda x: get_labels(x, THRESHOLD))

## GPU approach using Transformers Pipeline

In [15]:
from transformers import pipeline

In [16]:
classifier = pipeline('zero-shot-classification', device=0)

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json from cache at /home/ubuntu/.cache/torch/transformers/a35b79dc26c2f371a0e19eae44d91c0a0281a5db09044517d2675703791ee3c5.746d7ef19ade685cd3ee03f131a96fab513947c26179546289ddf02a6ac683ce
Model config BartConfig {
  "_num_labels": 3,
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "dropout": 0.1,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "extra_pos_embeddings": 2,
  "force_bos_token_to_be_generated": false,
  "id2label": {
    "0": "contradiction",
    "1": "neutral",
    "2": "entailment"
  },
  "init_std":

In [17]:
TERMS = ['resource availability', 'volunteers', 'power supply', 'relief measures', 
         'food supply', 'infrastructure', 'medical assistance', 'rescue', 'shelter', 
         'utilities', 'water supply', 'evacuation', 'government', 'crime violence', 
         'mobile network', 'sympathy', 'news updates', 'internet', 'grievance', 
         'livelihood', 'income', 'ecosystem', 'biodiversity', 'agriculture']

In [18]:
%timeit classifier(df_tweets['full_text'][0], TERMS, multi_class=True)

46.4 ms ± 946 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [19]:
'''Method to get the labels for a tweet based on threshold specified'''
def get_all_labels(x, terms=TERMS):
    
    # Run model
    result = classifier(x, terms, multi_class=True)
    
    topics = []
    for label, score in zip(result['labels'], result['scores']):

        topics.append((label, np.round(score,2)))
            
    return topics

In [22]:
tqdm.pandas()
df_tweets['full_text'].progress_apply(lambda x: get_all_labels(x, TERMS)).to_json('../models/zstc_labels.json', orient='index')

HBox(children=(FloatProgress(value=0.0, max=74289.0), HTML(value='')))


