# Zero Shot Text Classification Model
**Model Description** - Bart with a classification head trained on MNLI.

Sequences are posed as NLI premises and topic labels are turned into hypotheses, e.g. business -> "This text is about business."

In [1]:
import eland as ed
import pandas as pd
import numpy as np
import preprocessor as prep

from sklearn.ensemble import RandomForestClassifier
from src.models import ALZeroShotWrapper

#pd.options.display.max_columns = None
#pd.options.display.max_rows = None

# Where the tweets live: the two positional arguments handed to eland's read_es
# (presumably the Elasticsearch host and the index name — confirm against the
# installed eland version).
ES_HOST = 'localhost'
ES_INDEX = 'twitter'

ed_df = ed.read_es(ES_HOST, ES_INDEX)

## Importing Data

In [2]:
# Keep only the columns used below, convert the eland frame to pandas,
# and represent every missing value as NaN.
df = (
    ed_df[['tweet_id', 'lang', 'full_text']]
    .to_pandas()
    .fillna(np.nan)
)

In [3]:
def _tweet_id_as_str(ids):
    """Cast a tweet_id column float -> int -> str so both frames share one key format.

    NOTE(review): the float step cannot represent every integer above 2**53;
    confirm the ids in both sources survive this round-trip without colliding.
    """
    return ids.astype(float).astype(int).astype(str)


tweet_vectors = pd.read_csv('../data/results/iwmi_tweet2vec.csv')

# Normalise the join key on both sides before indexing on it.
tweet_vectors['tweet_id'] = _tweet_id_as_str(tweet_vectors['tweet_id'])
df['tweet_id'] = _tweet_id_as_str(df['tweet_id'])

# Default (left) join on the tweet_id index: every row of df is kept and the
# vector columns are attached where the id matches.
df = df.set_index('tweet_id').join(tweet_vectors.set_index('tweet_id'))

## Basic Tweet Preprocessing
- Remove URLs, reserved words (RTs), emojis and smileys
- Remove # and @ symbols
- Remove tweets with 4 or fewer tokens


In [4]:
## Set options for the tweet-preprocessor
# These are the element types prep.clean() is configured to handle:
# URLs, reserved words, emojis and smileys.
prep.set_options(
    prep.OPT.URL,
    prep.OPT.RESERVED,
    prep.OPT.EMOJI,
    prep.OPT.SMILEY,
)

## Clean text and remove #,@ symbols
# Translation table that deletes every '#' and '@'; built once instead of on each call.
_HASH_AT_TABLE = str.maketrans('', '', '#@')


def clean_tweet(text):
    """Clean one tweet with tweet-preprocessor and strip '#' and '@' characters.

    Args:
        text: Raw tweet text. Non-string values (e.g. the NaN used for missing
            text, or None) are treated as an empty tweet.

    Returns:
        The cleaned string, or '' when `text` is not a string.
    """
    # Missing text arrives as NaN/None; return an empty tweet instead of
    # handing a non-string to prep.clean / str.translate.
    if not isinstance(text, str):
        return ''
    return prep.clean(text).translate(_HASH_AT_TABLE)

In [5]:
# Replace each tweet's text with its cleaned form.
df['full_text'] = df['full_text'].apply(clean_tweet)

In [6]:
# Number of whitespace-separated tokens in each cleaned tweet.
df['length'] = df['full_text'].apply(lambda tweet: len(tweet.split()))
# Keep only tweets with more than 4 tokens.
df = df.loc[df['length'] > 4]

In [7]:
# English-language tweets only.
df_train = df.loc[df['lang'] == 'en']
# Feature matrix: columns vec_0 .. vec_9 (presumably the tweet2vec columns
# joined in earlier — confirm against the CSV header).
X = df_train[[f'vec_{dim}' for dim in range(10)]].values
# Cleaned tweet texts, row-aligned with X.
sequences = df_train['full_text'].values
# Topic labels handed to the zero-shot wrapper (order preserved).
candidate_labels = [
    'resource availability', 'volunteers', 'power supply', 'relief measures',
    'food supply', 'infrastructure', 'medical assistance', 'rescue',
    'shelter', 'utilities', 'water supply', 'evacuation',
    'government', 'crime violence', 'mobile network', 'sympathy',
    'news updates', 'internet', 'grievance',
]


## Initialising the Model

In [8]:
# NOTE(review): the output recorded under this cell ends in
# "ValueError: could not convert string to float: 'grievance'" right after
# "iter 0", so this fit call did not complete as saved. The cause is not
# visible from this notebook.
rfc = RandomForestClassifier(random_state=0)
al_zeroshot = ALZeroShotWrapper(
    rfc,
    max_iter=100,
    n_initial=10,
    increment=10,
    random_state=0,
)

al_zeroshot.fit(X, sequences, candidate_labels)

Some weights of the model checkpoint at facebook/bart-large-mnli were not used when initializing BartForSequenceClassification: ['model.encoder.version', 'model.decoder.version']
- This IS expected if you are initializing BartForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BartForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


iter 0


ValueError: could not convert string to float: 'grievance'

## Example on how to use the Model
The sequence (text) is posed as an NLI premise and the label as a hypothesis.

In [10]:
# Single worked example: score one (premise, hypothesis) pair.
# NOTE(review): `tokenizer` and `model` are not created in any cell visible in
# this notebook — they must already exist in the kernel for this cell to run.

# pose sequence as an NLI premise and label (politics) as a hypothesis
premise = 'Who are you voting for in 2020?'
hypothesis = 'This text is about politics.'

# run through model pre-trained on MNLI; the first element of the model
# output is used as the logits
input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')
logits = model(input_ids)[0]

# we throw away "neutral" (column 1), softmax over the two remaining columns,
# and take the "entailment" share (original column 2, now column 1) as the
# probability of the label being true
entail_contradiction_logits = logits[:,[0,2]]
probs = entail_contradiction_logits.softmax(dim=1)
# scale to a percentage for display
true_prob = probs[:,1].item() * 100
print(f'Probability that the label is true: {true_prob:0.2f}%')

Probability that the label is true: 98.08%


## Defining the Labels and Threshold for classification

In [19]:
# Topic terms scored for every tweet (order preserved).
TERMS = [
    'resource availability', 'volunteers', 'power supply', 'relief measures',
    'food supply', 'infrastructure', 'medical assistance', 'rescue',
    'shelter', 'utilities', 'water supply', 'evacuation',
    'government', 'crime violence', 'mobile network', 'sympathy',
    'news updates', 'internet', 'grievance',
]

# One NLI hypothesis per term, index-aligned with TERMS.
HYPOTHESES = [f'This text is about {term}' for term in TERMS]
# Minimum probability (in percent) for a term to be kept as a label.
THRESHOLD = 50

In [12]:
def get_labels(premise, threshold=None):
    """Return the topic labels for a tweet whose probability meets `threshold`.

    Each hypothesis in HYPOTHESES is paired with `premise`, encoded with the
    notebook-level `tokenizer` and scored with the notebook-level `model`.

    Args:
        premise: Tweet text to classify.
        threshold: Minimum probability, in percent, for a label to be kept.
            When omitted, the current value of THRESHOLD is used; it is read
            at call time, so re-running the cell that sets THRESHOLD takes
            effect without redefining this function.

    Returns:
        A list of (term, probability) tuples in TERMS order, where probability
        is a percentage rounded to 2 decimals. Empty if no label qualifies.
    """
    if threshold is None:
        threshold = THRESHOLD

    topics = []
    for term, hypothesis in zip(TERMS, HYPOTHESES):
        # run through model pre-trained on MNLI
        input_ids = tokenizer.encode(premise, hypothesis, return_tensors='pt')
        logits = model(input_ids)[0]

        # we throw away "neutral" (column 1) and take the share of
        # "entailment" (column 2) as the probability of the label being true
        entail_contradiction_logits = logits[:, [0, 2]]
        probs = entail_contradiction_logits.softmax(dim=1)
        true_prob = probs[:, 1].item() * 100

        if true_prob >= threshold:
            topics.append((term, np.round(true_prob, 2)))

    return topics

In [13]:
# Label every tweet, showing a progress bar while get_labels runs row by row.
# NOTE(review): neither `tqdm` nor `df_tweets` is imported or defined in the
# cells visible in this notebook (earlier cells build `df` / `df_train`), so
# this cell depends on state from elsewhere — confirm which frame is intended.
tqdm.pandas()
# Stores, per tweet, the list that get_labels returns for THRESHOLD.
df_tweets['labels'] = df_tweets['full_text'].progress_apply(lambda x: get_labels(x, THRESHOLD))