# Zero Shot Text Classification Model
**Model Description** - Bart with a classification head trained on MNLI.

Sequences are posed as NLI premises and topic labels are turned into hypotheses, e.g. business -> This text is about business.

In [1]:
import string

import eland as ed
from eland.conftest import *  # star import kept as-is; the explicit imports below take precedence
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import preprocessor as prep
import seaborn as sns
from tqdm.auto import tqdm
from transformers import BartForSequenceClassification, BartTokenizer

# Never truncate column contents when displaying frames, so whole tweets stay readable.
# `None` is the supported "no limit" value; the former `-1` is deprecated in pandas 1.x
# and rejected by later releases.
pd.set_option('display.max_colwidth', None)

PyTorch version 1.6.0 available.
TensorFlow version 2.3.0 available.


## Importing the Data from Elasticsearch

In [2]:
# Eland view of the `twitter` index on the Elasticsearch host `localhost`,
# restricted to the processed tweet text column.
ed_df = ed.DataFrame('localhost', 'twitter', columns=['full_text_processed'])

# Elasticsearch bool query with two term clauses: keep only documents where
# is_retweet is false and is_quote_status is false.
query_unique = {
    "bool": {
        "must": {"term": {"is_retweet": "false"}},
        "filter": {"term": {"is_quote_status": "false"}},
    }
}

# Apply the query through Eland, then convert the result to a pandas DataFrame.
df_ed = ed_df.es_query(query_unique)
df_tweets = df_ed.to_pandas()

In [3]:
# (rows, columns) of the frame returned by the Elasticsearch query.
df_tweets.shape

(108364, 1)

In [4]:
# Preview the first rows of the retrieved tweets.
df_tweets.head()

Unnamed: 0,full_text_processed
1264253979002843136,effect amphan south 24 parganas
1264253959918632960,field experience dealing cyclone ampan
1264253893632016384,dukkhor bishoye onek khoti hoyeche hoping best possible recovery asap shokol ke bolchi ektu patience rakhun shob thik hoye jabe government ha taken control seriously ei বাংলা amar hanshe abar amphan bengalfightsamphan bengaltweet kolkata
1264253882580045824,economic prblm covid19 arising increasing nd kolkata amphan ha devastated wht do👏👏
1264253658763612160,drsjaishankar meaindia many day indian citizen resident west bengal suffer amp remain get stuck uk sirji flight 25 may london indiaplease help start london kolkata member huge crisis due cyclone amphan


## Basic Tweet Preprocessing
- Remove URLs and reserved words (RTs)
- Remove # and @ symbols
- Remove tweets with 4 or fewer tokens (only tweets longer than 4 tokens are kept)


In [5]:
## Set options for the tweet-preprocessor
# The flags passed here (URL, RESERVED, EMOJI, SMILEY) configure what
# subsequent `prep.clean` calls act on.
prep.set_options(prep.OPT.URL, prep.OPT.RESERVED, prep.OPT.EMOJI, prep.OPT.SMILEY)

def clean_tweet(text):
    """Clean a tweet with tweet-preprocessor, then delete every '#' and '@'.

    `prep.clean` is applied first (using whatever options were configured via
    `prep.set_options`); afterwards only the '#' and '@' characters themselves
    are removed from the cleaned string.
    """
    cleaned = prep.clean(text)
    # Three-argument maketrans: no substitutions, just characters to delete.
    return cleaned.translate(str.maketrans('', '', '#@'))

In [6]:
#df_tweets['full_text'] = df_tweets['full_text'].apply(lambda x: clean_tweet(x))

In [7]:
# Number of whitespace-separated tokens in each processed tweet.
df_tweets['length'] = df_tweets['full_text_processed'].apply(lambda text: len(text.split()))
# Keep only tweets with more than 4 tokens.
df_tweets = df_tweets.loc[df_tweets['length'] > 4]

In [8]:
# Preview the filtered tweets together with their new `length` column.
df_tweets.head()

Unnamed: 0,full_text_processed,length
1264253979002843136,effect amphan south 24 parganas,5
1264253959918632960,field experience dealing cyclone ampan,5
1264253893632016384,dukkhor bishoye onek khoti hoyeche hoping best possible recovery asap shokol ke bolchi ektu patience rakhun shob thik hoye jabe government ha taken control seriously ei বাংলা amar hanshe abar amphan bengalfightsamphan bengaltweet kolkata,34
1264253882580045824,economic prblm covid19 arising increasing nd kolkata amphan ha devastated wht do👏👏,12
1264253658763612160,drsjaishankar meaindia many day indian citizen resident west bengal suffer amp remain get stuck uk sirji flight 25 may london indiaplease help start london kolkata member huge crisis due cyclone amphan,31


## GPU approach using Transformers Pipeline

In [9]:
from transformers import pipeline

In [10]:
# Zero-shot classification pipeline placed on CUDA device 0.
# The checkpoint is pinned explicitly (it is the one the unpinned call resolved to
# in the load log below) so results do not depend on the library's default model.
classifier = pipeline('zero-shot-classification', model='facebook/bart-large-mnli', device=0)

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json from cache at /home/ubuntu/.cache/torch/transformers/a35b79dc26c2f371a0e19eae44d91c0a0281a5db09044517d2675703791ee3c5.746d7ef19ade685cd3ee03f131a96fab513947c26179546289ddf02a6ac683ce
Model config BartConfig {
  "_num_labels": 3,
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "dropout": 0.1,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "extra_pos_embeddings": 2,
  "force_bos_token_to_be_generated": false,
  "id2label": {
    "0": "contradiction",
    "1": "neutral",
    "2": "entailment"
  },
  "init_std":

In [11]:
# Candidate topic labels passed to the zero-shot classifier for every tweet.
TERMS = [
    'sympathy',
    'complaint',
    'hope',
    'job',
    'relief measures',
    'compensation',
    'evacuation',
    'income',
    'ecosystem',
    'government',
    'corruption',
    'news updates',
    'volunteers',
    'donation',
    'mobile network',
    'housing',
    'farm',
    'utilities',
    'water supply',
    'power supply',
    'food supply',
    'medical assistance',
    'coronavirus',
    'petition',
    'poverty',
]

In [12]:
%timeit classifier(df_tweets['full_text_processed'][0], TERMS, multi_class=True)

44.2 ms ± 1.43 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
def get_all_labels(x, terms=TERMS):
    """Score one text against every candidate topic label.

    Runs the zero-shot `classifier` on `x` with `multi_class=True` and returns
    all labels with their scores. No threshold is applied here; any filtering
    by score is left to the caller.

    Parameters
    ----------
    x
        Text to classify (the notebook passes one processed tweet at a time).
    terms
        Candidate labels handed to the classifier. Defaults to `TERMS`.

    Returns
    -------
    list of tuple
        One `(label, score)` pair per entry of the classifier result, in the
        order the classifier returned them, with the score rounded to two
        decimals.
    """
    result = classifier(x, terms, multi_class=True)

    # The result exposes parallel 'labels' and 'scores' sequences; pair them up.
    return [
        (label, np.round(score, 2))
        for label, score in zip(result['labels'], result['scores'])
    ]

In [14]:
# Register tqdm's `progress_apply` on pandas objects so the labelling pass
# over all tweets reports progress.
tqdm.pandas()

# Label every tweet and write the per-tweet (label, score) lists to JSON,
# keyed by the frame's index.
(
    df_tweets['full_text_processed']
    .progress_apply(lambda tweet: get_all_labels(tweet, TERMS))
    .to_json('../models/zstc_labels.json', orient='index')
)

HBox(children=(FloatProgress(value=0.0, max=102314.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed


