# Zero Shot Text Classification Model
**Model Description** - Bart with a classification head trained on MNLI.

Sequences are posed as NLI premises and topic labels are turned into premises, i.e. business -> This text is about business.

In [1]:
import eland as ed
from eland.conftest import *
import pandas as pd
import seaborn as sns
import preprocessor as prep
import matplotlib.pyplot as plt
import string
from tqdm.auto import tqdm
from transformers import BartForSequenceClassification, BartTokenizer

pd.options.display.max_columns = None
pd.options.display.max_rows = None

## Importing the Data from Elasticsearch

In [2]:
ed_df = ed.DataFrame('localhost', 'twitter', columns=['full_text', 'user_id', 'verified', 'name', 'location', 'entities.hashtags.text', 'entities.user_mentions.name'])

# defining the full-text query we need: Retrieving records for full_text_processed with the condition is_retweet=False and is_quote_status=False
query_unique = {
    "bool": {
        "must": {
            "term":{"is_retweet":"false"},
        },
        "filter": {
            "term":{"is_quote_status":"false"},
            "term":{"lang.keyword":"en"}
        },
    }
}
# using full-text search capabilities with Eland:
df_ed = ed_df.es_query(query_unique)
df_tweets = df_ed.to_pandas()

In [6]:
df_tweets.shape

(68701, 7)

In [4]:
df_tweets.head()

Unnamed: 0,full_text,user_id,verified,name,location,entities.hashtags.text,entities.user_mentions.name
1264160647002103808,Praying for everyone affected by #AmphanSuperC...,1256622599364214786,False,The Meraaki,"Ahmadabad City, India",AmphanSuperCyclone,
1264160569038442496,"Political differences was there, they still ex...",219183608,False,Sujatro Ghosh,"Berlin, Deutschland",Amphan,
1264132382296289280,Amid #coronavirus crisis cyclone Amphan pummel...,219183608,False,Sujatro Ghosh,"Berlin, Deutschland",coronavirus,
1264160529100288000,West Bengal calls for Indian Army's support to...,18071358,True,Zee News English,India,,
1264083313699901440,Don’t send Shramik Special trains till May 26 ...,18071358,True,Zee News English,India,,


## Basic Tweet Preprocessing
- Remove URLs and reserved words (RTs)
- Remove # and @ symbols
- Remove tweets less than 4 tokens in length


In [5]:
## Set options for the tweet-preprocessor
prep.set_options(prep.OPT.URL, prep.OPT.RESERVED, prep.OPT.EMOJI, prep.OPT.SMILEY)

## Clean text and remove #,@ symbols
def clean_tweet(text):
    text = prep.clean(text)
    table = str.maketrans('','','#@')
    return text.translate(table)

In [6]:
df_tweets['full_text'] = df_tweets['full_text'].apply(lambda x: clean_tweet(x))

In [7]:
df_tweets['length'] = df_tweets['full_text'].apply(lambda x: len([w for w in x.split()]))
df_tweets = df_tweets[df_tweets['length']>4]

In [8]:
df_tweets.head()

Unnamed: 0,full_text,user_id,verified,name,location,entities.hashtags.text,entities.user_mentions.name,length
1264160647002103808,Praying for everyone affected by AmphanSuperCy...,1256622599364214786,False,The Meraaki,"Ahmadabad City, India",AmphanSuperCyclone,,24
1264160569038442496,"Political differences was there, they still ex...",219183608,False,Sujatro Ghosh,"Berlin, Deutschland",Amphan,,11
1264132382296289280,Amid coronavirus crisis cyclone Amphan pummele...,219183608,False,Sujatro Ghosh,"Berlin, Deutschland",coronavirus,,36
1264160529100288000,West Bengal calls for Indian Army's support to...,18071358,True,Zee News English,India,,,19
1264083313699901440,Dont send Shramik Special trains till May 26 i...,18071358,True,Zee News English,India,,,18


## Initialising the Model

In [9]:
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')

Some weights of the model checkpoint at facebook/bart-large-mnli were not used when initializing BartForSequenceClassification: ['model.encoder.version', 'model.decoder.version']
- This IS expected if you are initializing BartForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BartForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Example on how to use the Model
The sequence (text) is posed as a NLI Premise and the label as a hypothesis

In [10]:
# pose sequence as a NLI premise and label (politics) as a hypothesis
premise = 'Who are you voting for in 2020?'
hypothesis = 'This text is about politics.'

# run through model pre-trained on MNLI
input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')
logits = model(input_ids)[0]

# we throw away "neutral" (dim 1) and take the probability of
# "entailment" (2) as the probability of the label being true 
entail_contradiction_logits = logits[:,[0,2]]
probs = entail_contradiction_logits.softmax(dim=1)
true_prob = probs[:,1].item() * 100
print(f'Probability that the label is true: {true_prob:0.2f}%')

Probability that the label is true: 98.08%


## Defining the Labels and Threshold for classification

In [19]:
TERMS = ['resource availability', 'volunteers', 'power supply', 'relief measures', 'food supply', 'infrastructure', 'medical assistance', 'rescue', 'shelter', 'utilities', 'water supply', 'evacuation', 'government', 'crime violence', 'mobile network', 'sympathy', 'news updates', 'internet', 'grievance']

HYPOTHESES = ['This text is about '+x for x in TERMS]
THRESHOLD = 50

In [12]:
'''Method to get the labels for a tweet based on threshold specified'''
def get_labels(premise, threshold=THRESHOLD):
    topics = []
    for idx, hypothesis in enumerate(HYPOTHESES):
        # run through model pre-trained on MNLI
        input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')
        logits = model(input_ids)[0]

        # we throw away "neutral" (dim 1) and take the probability of
        # "entailment" (2) as the probability of the label being true 
        entail_contradiction_logits = logits[:,[0,2]]
        probs = entail_contradiction_logits.softmax(dim=1)
        true_prob = probs[:,1].item() * 100

        if true_prob>=threshold:
            topics.append((TERMS[idx], np.round(true_prob,2)))
    
    return topics

In [13]:
tqdm.pandas()
df_tweets['labels'] = df_tweets['full_text'].progress_apply(lambda x: get_labels(x, THRESHOLD))