# Zero Shot Text Classification Model
**Model Description** - Bart with a classification head trained on MNLI.

Sequences are posed as NLI premises and topic labels are turned into hypotheses, e.g. business -> "This text is about business."

In [1]:
import eland as ed
import pandas as pd
import numpy as np
import preprocessor as prep

from sklearn.ensemble import RandomForestClassifier
from src.models import ALZeroShotWrapper

#pd.options.display.max_columns = None
#pd.options.display.max_rows = None

# Lazy eland handle used as the raw tweet source for the rest of the notebook.
# NOTE(review): 'localhost' and 'twitter' are presumably the Elasticsearch host and
# the index name — confirm against the eland version in use.
ed_df = ed.read_es('localhost', 'twitter')

## Importing Data

In [2]:
# Pull only the columns needed downstream into an in-memory pandas frame.
selected_columns = ['tweet_id', 'lang', 'full_text']
df = ed_df[selected_columns].to_pandas().fillna(np.nan)

In [3]:
tweet_vectors = pd.read_csv('../data/results/iwmi_tweet2vec.csv')

# Normalise the join key to the same string form on both frames.
# NOTE(review): going through float64 cannot represent 19-digit tweet IDs exactly,
# so the resulting strings are rounded and distinct tweets could collide. Confirm
# how the CSV stores `tweet_id` before switching to an exact integer/string parse.
for frame in (tweet_vectors, df):
    frame['tweet_id'] = frame['tweet_id'].astype(float).astype(int).astype(str)

# Left join: every tweet is kept and gains the columns of the vector file.
vectors_by_id = tweet_vectors.set_index('tweet_id')
df = df.set_index('tweet_id').join(vectors_by_id)

## Basic Tweet Preprocessing
- Remove URLs, reserved words (e.g. RT), emojis and smileys
- Remove # and @ symbols
- Remove tweets with 4 or fewer tokens


In [4]:
## Configure the tweet-preprocessor: these are the token types `prep.clean` will strip
prep.set_options(
    prep.OPT.URL,
    prep.OPT.RESERVED,
    prep.OPT.EMOJI,
    prep.OPT.SMILEY,
)

## Translation table that deletes '#' and '@'; built once instead of once per tweet
_SYMBOLS_TO_STRIP = str.maketrans('', '', '#@')


## Clean text and remove #,@ symbols
def clean_tweet(text):
    """Clean one tweet and drop the hashtag / mention symbols.

    The text is passed through ``prep.clean`` (which removes whatever was
    configured with ``prep.set_options``) and then every ``#`` and ``@``
    character is deleted; the words that followed those symbols are kept.

    Parameters
    ----------
    text : str
        Raw tweet text.

    Returns
    -------
    str
        The cleaned text. Non-string input (e.g. ``None`` or ``NaN`` coming
        from a missing ``full_text`` value) yields an empty string instead of
        raising, so the result can always be tokenised with ``.split()``.
    """
    # Missing values are kept as NaN when the frame is loaded; treat them as
    # empty tweets rather than letting the cleaner fail on a float.
    if not isinstance(text, str):
        return ''
    return prep.clean(text).translate(_SYMBOLS_TO_STRIP)

In [5]:
# Replace the raw tweet text with its cleaned version, row by row.
df['full_text'] = df['full_text'].apply(clean_tweet)

In [6]:
# Token count per tweet (whitespace split), then keep only tweets with more than 4 tokens.
df['length'] = df['full_text'].apply(lambda tweet: len(tweet.split()))
has_enough_tokens = df['length'] > 4
df = df[has_enough_tokens]

In [7]:
# Training subset: English tweets only.
is_english = df.lang == 'en'
df_train = df[is_english]

# Feature matrix built from the columns vec_0 ... vec_9.
vector_columns = [f'vec_{i}' for i in range(10)]
X = df_train[vector_columns].values

# Cleaned tweet texts, aligned row-for-row with X.
sequences = df_train.full_text.values

# Topic names offered to the zero-shot model as candidate classes.
candidate_labels = [
    'resource availability',
    'volunteers',
    'power supply',
    'relief measures',
    'food supply',
    'infrastructure',
    'medical assistance',
    'rescue',
    'shelter',
    'utilities',
    'water supply',
    'evacuation',
    'government',
    'crime violence',
    'mobile network',
    'sympathy',
    'news updates',
    'internet',
    'grievance',
]


## Initialising the Model

In [8]:
# Base estimator handed to the project's ALZeroShotWrapper; seeded for repeatability.
rfc = RandomForestClassifier(random_state=0)

# NOTE(review): the meaning of max_iter / n_initial / increment is defined in
# src.models and is not visible from this notebook.
al_zeroshot = ALZeroShotWrapper(
    rfc,
    max_iter=100,
    n_initial=10,
    increment=10,
    random_state=0,
)

# Fit on the tweet vectors, the tweet texts and the candidate topic names.
al_zeroshot.fit(X, sequences, candidate_labels)

Some weights of the model checkpoint at facebook/bart-large-mnli were not used when initializing BartForSequenceClassification: ['model.encoder.version', 'model.decoder.version']
- This IS expected if you are initializing BartForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BartForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


iter 0
iter 1
iter 2
iter 3
iter 4
iter 5
iter 6
iter 7
iter 8
iter 9
iter 10
iter 11
iter 12
iter 13
iter 14
iter 15
iter 16
iter 17
iter 18
iter 19
iter 20
iter 21
iter 22
iter 23
iter 24
iter 25
iter 26
iter 27
iter 28
iter 29
iter 30
iter 31
iter 32
iter 33
iter 34
iter 35
iter 36
iter 37
iter 38
iter 39
iter 40
iter 41
iter 42
iter 43
iter 44
iter 45
iter 46
iter 47
iter 48
iter 49
iter 50
iter 51
iter 52
iter 53
iter 54
iter 55
iter 56
iter 57
iter 58
iter 59
iter 60
iter 61
iter 62
iter 63
iter 64
iter 65
iter 66
iter 67
iter 68
iter 69
iter 70
iter 71
iter 72
iter 73
iter 74
iter 75
iter 76
iter 77
iter 78
iter 79
iter 80
iter 81
iter 82
iter 83
iter 84
iter 85
iter 86
iter 87
iter 88
iter 89
iter 90
iter 91
iter 92
iter 93
iter 94
iter 95
iter 96
iter 97
iter 98
iter 99


ALZeroShotWrapper(classifier=RandomForestClassifier(random_state=0),
                  increment=10, max_iter=100, n_initial=10, random_state=0)

In [26]:
# Labels stored on the wrapper after fitting (NaN where none was assigned — see the mask below).
zero_shot_labels = al_zeroshot.y

# Random-forest predictions for every training row, decoded back to topic names.
rfc_labels = al_zeroshot.label_encoder.inverse_transform(
    al_zeroshot.classifier_.predict(X).astype(int)
)

In [31]:
# Rows that actually carry a zero-shot label (non-NaN entries).
zs_mask = np.logical_not(np.isnan(zero_shot_labels))

# Decode those labels to topic names and pick the forest's predictions for the same rows.
ground_truth_labels = al_zeroshot.label_encoder.inverse_transform(
    zero_shot_labels[zs_mask].astype(int)
)
rfc_zs_preds = rfc_labels[zs_mask]

In [34]:
# Sanity check: does the random forest reproduce every pseudo-ground-truth label it was trained on?
np.all(ground_truth_labels == rfc_zs_preds)

True

## Classify the entire dataset using the trained random forest

In [35]:
# Short handles for the fitted label encoder and the fitted classifier.
le, rfc = al_zeroshot.label_encoder, al_zeroshot.classifier_

In [43]:
# Feature matrix for every tweet in df (not just the English training subset).
all_vector_columns = [f'vec_{i}' for i in range(10)]
X_all = df[all_vector_columns].values

# Predict encoded classes, then map them back to topic names.
encoded_predictions = rfc.predict(X_all).astype(int)
y_pred = le.inverse_transform(encoded_predictions)

In [44]:
# Attach the predicted topic to each tweet and show text next to prediction.
df['label_prediction'] = y_pred
df.loc[:, ['full_text', 'label_prediction']]

Unnamed: 0_level_0,full_text,label_prediction
tweet_id,Unnamed: 1_level_1,Unnamed: 2_level_1
1263737818767130624,ExSecular DD ki gaan uda di Amphan ne. Ab bhik...,grievance
1263737812576202752,Speaking on the situation in the wake of Cyclo...,relief measures
1263737811460636672,Speaking on the situation in the wake of Cyclo...,relief measures
1263737810105839616,Speaking on the situation in the wake of Cyclo...,relief measures
1263737809757749248,The damage to the Sunderbans by AmphanCyclone ...,grievance
...,...,...
1262711874832994304,Cyclone Amphan has yet to make landfall but is...,grievance
1254077710600679424,UN: 'Our fear is tremendous loss of lifewomen ...,evacuation
1262672672946900992,MANITSaysNoExams PromoteFinalYearStudents MHRD...,relief measures
1262778121700495360,I still feel jittery everytime I think of thos...,grievance


In [46]:
# Persist the labelled tweets (index included) next to the notebook.
df.to_csv(path_or_buf='provisional_topic_modelling.csv')