# Zero Shot Text Classification Model
**Model Description** - BART-large with a sequence-classification head fine-tuned on MNLI (`facebook/bart-large-mnli`).

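Each tweet is posed as the NLI premise and each topic label is turned into an NLI hypothesis, e.g. business -> "This text is about business." A label's score is derived from how strongly the model predicts that the premise entails that hypothesis.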

In [5]:
import string

import eland as ed
from eland.conftest import *  # NOTE(review): wildcard import of eland's test fixtures; prefer explicit imports
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import preprocessor as prep
import seaborn as sns
from tqdm.auto import tqdm
from transformers import BartForSequenceClassification, BartTokenizer, pipeline

# Show full tweet text in DataFrame previews. None means "no truncation"; the old -1
# sentinel was deprecated in pandas 1.0 and is rejected by newer versions.
pd.set_option('display.max_colwidth', None)

## Importing the Data from Elasticsearch

In [2]:
ed_df = ed.DataFrame('localhost', 'twitter', columns=['full_text_processed'])

# Keep only records where both is_retweet and is_quote_status are false
# (presumably: original tweets only). These are exact-match term conditions, not a
# full-text search, and neither needs relevance scoring, so both go in the bool
# query's filter context (same hits as must+filter, no scoring work).
query_unique = {
    "bool": {
        "filter": [
            {"term": {"is_retweet": "false"}},
            {"term": {"is_quote_status": "false"}},
        ]
    }
}

# Run the query server-side through Eland, then materialise the matching rows in pandas.
df_ed = ed_df.es_query(query_unique)
df_tweets = df_ed.to_pandas()

In [3]:
# Sanity check: (rows, columns) returned by the Elasticsearch query.
df_tweets.shape

(108364, 1)

In [6]:
# Preview the first retrieved tweets.
df_tweets.head()

Unnamed: 0,full_text_processed
1262961673708675072,live cyclone amphan map tracking storm’s path
1262961660932894720,nyt live cyclone amphan map tracking storm’s path
1262961652359729152,live news update super cyclone amphan amphanupdate cycloneamphan amphancyclone cycloneamphanupdate 120 km nearly south paradip odisha 200 km southsouthwest digha west bengal 360 km southsouthwest khepupara bangladesh
1262960808742522880,cyclone ampan came closer live super cyclone amphan update pradip odisha 120 km orissa digha west bengal 200 km west bengal khepupara bangladesh 360 km bangladesh bangladesh bangla west bengal cyclone amphan
1262937945214005248,live news update super cyclone amphan amphanupdate cycloneamphan amphancyclone cycloneamphanupdate 125 km nearly south paradip odisha 225 km southsouthwest digha west bengal 380 km southsouthwest khepupara bangladesh


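## Basic Tweet Preprocessing
- Remove URLs and reserved words (RTs)
- Remove # and @ symbols
- Remove tweets with 4 or fewer tokens (i.e. keep only tweets with at least 5 tokens)

Note: the first two steps are implemented by `clean_tweet` below, but the cell that applies it is commented out; only the token-length filter is run, on `full_text_processed`.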


In [7]:
## Set options for the tweet-preprocessor: restrict what prep.clean() removes to
## URLs, reserved words (e.g. RT), emojis and smileys.
prep.set_options(prep.OPT.URL, prep.OPT.RESERVED, prep.OPT.EMOJI, prep.OPT.SMILEY)

## Clean text and remove #,@ symbols
# Translation table that deletes '#' and '@'; built once rather than on every call.
_HASH_AT_TABLE = str.maketrans('', '', '#@')


def clean_tweet(text):
    """Clean a tweet with tweet-preprocessor, then strip '#' and '@' characters.

    Args:
        text: raw tweet text.

    Returns:
        The result of ``prep.clean(text)`` (using whatever options were configured
        via ``prep.set_options``) with every '#' and '@' character removed.
    """
    return prep.clean(text).translate(_HASH_AT_TABLE)

In [8]:
# NOTE(review): tweet cleaning is disabled. As written it would also fail, because the
# Elasticsearch load only fetches 'full_text_processed' (there is no 'full_text' column).
#df_tweets['full_text'] = df_tweets['full_text'].apply(lambda x: clean_tweet(x))

In [9]:
# Count whitespace-separated tokens per tweet and keep tweets with more than 4 tokens,
# i.e. at least MIN_TOKENS. `.assign` builds a new frame, so re-running this cell never
# writes into an already-filtered slice (no SettingWithCopyWarning). Non-string values
# get a NaN length and are dropped instead of raising.
MIN_TOKENS = 5
df_tweets = df_tweets.assign(length=df_tweets['full_text_processed'].str.split().str.len())
df_tweets = df_tweets[df_tweets['length'] >= MIN_TOKENS]

In [10]:
# Preview the filtered tweets together with their token counts.
df_tweets.head()

Unnamed: 0,full_text_processed,length
1262961673708675072,live cyclone amphan map tracking storm’s path,7
1262961660932894720,nyt live cyclone amphan map tracking storm’s path,8
1262961652359729152,live news update super cyclone amphan amphanupdate cycloneamphan amphancyclone cycloneamphanupdate 120 km nearly south paradip odisha 200 km southsouthwest digha west bengal 360 km southsouthwest khepupara bangladesh,27
1262960808742522880,cyclone ampan came closer live super cyclone amphan update pradip odisha 120 km orissa digha west bengal 200 km west bengal khepupara bangladesh 360 km bangladesh bangladesh bangla west bengal cyclone amphan,32
1262937945214005248,live news update super cyclone amphan amphanupdate cycloneamphan amphancyclone cycloneamphanupdate 125 km nearly south paradip odisha 225 km southsouthwest digha west bengal 380 km southsouthwest khepupara bangladesh,27


## GPU approach using Transformers Pipeline

In [15]:
from transformers import pipeline

In [16]:
# Zero-shot classification pipeline on the first GPU (device=0). The checkpoint is
# pinned explicitly (it is the one the cell output shows being loaded), so results
# do not silently change if the library's default model for this task changes.
classifier = pipeline('zero-shot-classification', model='facebook/bart-large-mnli', device=0)

loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json from cache at /home/ubuntu/.cache/torch/transformers/a35b79dc26c2f371a0e19eae44d91c0a0281a5db09044517d2675703791ee3c5.746d7ef19ade685cd3ee03f131a96fab513947c26179546289ddf02a6ac683ce
Model config BartConfig {
  "_num_labels": 3,
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "dropout": 0.1,
  "encoder_attention_heads": 16,
  "encoder_ffn_dim": 4096,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 12,
  "eos_token_id": 2,
  "extra_pos_embeddings": 2,
  "force_bos_token_to_be_generated": false,
  "id2label": {
    "0": "contradiction",
    "1": "neutral",
    "2": "entailment"
  },
  "init_std":

In [11]:
# Candidate topic labels passed to the zero-shot classifier (one label per line;
# order kept as originally defined).
TERMS = [
    'sympathy',
    'complaint',
    'hope',
    'job',
    'relief measures',
    'compensation',
    'evacuation',
    'income',
    'ecosystem',
    'government',
    'corruption',
    'news updates',
    'volunteers',
    'donation',
    'mobile network',
    'housing',
    'farm',
    'utilities',
    'water supply',
    'power supply',
    'food supply',
    'medical assistance',
    'coronavirus',
    'petition',
    'poverty',
]

In [18]:
%timeit classifier(df_tweets['full_text_processed'][0], TERMS, multi_class=True)

46.4 ms ± 946 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [19]:
def get_all_labels(x, terms=None, threshold=None, model=None):
    """Score a text against candidate topic labels with the zero-shot classifier.

    Args:
        x: text to classify.
        terms: candidate labels; defaults to ``TERMS`` (looked up at call time, so a
            redefined ``TERMS`` is picked up without re-running this cell).
        threshold: optional minimum score. When given, only labels whose raw
            (unrounded) score is >= threshold are returned; ``None`` keeps all labels.
        model: zero-shot callable taking ``(text, labels, multi_class=...)`` and
            returning a dict with ``'labels'`` and ``'scores'``; defaults to the
            notebook's ``classifier`` pipeline.

    Returns:
        List of ``(label, score)`` tuples in the order returned by the model, with
        each score rounded to 2 decimals.
    """
    if terms is None:
        terms = TERMS
    if model is None:
        model = classifier

    # multi_class=True: each label is scored on its own, so a tweet can match several topics.
    result = model(x, terms, multi_class=True)

    # Builtin round() on the scalar score: no dependency on a wildcard-imported `np`.
    return [
        (label, round(score, 2))
        for label, score in zip(result['labels'], result['scores'])
        if threshold is None or score >= threshold
    ]

In [22]:
from pathlib import Path

# Output file for the per-tweet labels (relative to the notebook's working directory).
LABELS_PATH = Path('../models/zstc_labels.json')
# Create the output folder *before* the long inference run, so a missing directory
# cannot make to_json fail after every tweet has already been classified.
LABELS_PATH.parent.mkdir(parents=True, exist_ok=True)

# Register .progress_apply on pandas objects, then classify every tweet.
tqdm.pandas()
tweet_labels = df_tweets['full_text_processed'].progress_apply(lambda x: get_all_labels(x, TERMS))

# Keep the result in memory (tweet_labels) and persist it keyed by the Series index.
tweet_labels.to_json(str(LABELS_PATH), orient='index')

HBox(children=(FloatProgress(value=0.0, max=74289.0), HTML(value='')))


