# Final Project: Disaster Tweet Classification

In [10]:
# Notebook metadata: authors and the course/term this project belongs to.
__author__ = "Kevin Guo, Pranav Sriram, Raymond Yao"
__version__ = "CS224u, Stanford, Spring 2021"

## Data Pre-Processing

In [18]:
import re

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from datasets import Dataset
from datasets import load_dataset
#from datasets import train_test_split
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AdamW
from transformers import AutoTokenizer
from transformers import BertModel, BertTokenizer
from transformers import Trainer
from transformers import TrainingArguments
from transformers import get_scheduler

import utils
from torch_shallow_neural_classifier import TorchShallowNeuralClassifier

In [19]:
# Set random seeds before any stochastic step runs.
# NOTE(review): which RNGs get seeded is decided inside utils.fix_random_seeds
# (called here with its defaults) — confirm it covers NumPy and PyTorch.
utils.fix_random_seeds()

In [20]:
# Single checkpoint name passed to both the tokenizer and the model loaders
# below, so the vocabulary and the weights come from the same checkpoint.
weights_name = 'bert-base-cased'

In [21]:
# Tokenizer for the checkpoint named by `weights_name`; bert_phi uses it to
# turn raw text into input ids.
bert_tokenizer = BertTokenizer.from_pretrained(weights_name)

In [22]:
# Pretrained BERT encoder for the same checkpoint. bert_phi only ever calls
# it under torch.no_grad(), i.e. as a feature extractor.
bert_model = BertModel.from_pretrained(weights_name)

In [23]:
def bert_phi(text):
    """Return BERT's final-layer hidden states for a single text.

    Parameters
    ----------
    text : str
        Raw text to encode.

    Returns
    -------
    np.ndarray
        2-D array with one row per input id (special tokens included);
        the batch dimension of size one is squeezed away.
    """
    # truncation=True clips over-long inputs to the tokenizer's configured
    # maximum length instead of handing the model a sequence that is longer
    # than it supports. Short inputs are unaffected.
    input_ids = bert_tokenizer.encode(
        text, add_special_tokens=True, truncation=True)
    X = torch.tensor([input_ids])  # batch of one example
    # Feature extraction only, so skip gradient tracking.
    with torch.no_grad():
        reps = bert_model(X)
    return reps.last_hidden_state.squeeze(0).numpy()

In [24]:
def bert_classifier_phi(text):
    """Featurize one text as a single fixed-size vector.

    Runs `bert_phi` and keeps only the representation at the first
    position (the slot of the leading special token, since `bert_phi`
    encodes with special tokens added).
    """
    token_reps = bert_phi(text)
    # Alternative pooling that is also a good, easy option:
    #   token_reps.mean(axis=0)
    first_position_rep = token_reps[0]
    return first_position_rep

In [26]:
# Load the tweets and strip URLs (anything starting with http/https up to the
# next whitespace) from the text column.
all_data = pd.read_csv('tweets_mod_copy.csv')
all_data['text'] = all_data['text'].apply(lambda x: re.sub(r'https?\S+', '', x))

# Shuffle once with a fixed seed, then cut 80% / 10% / 10% into
# train / dev / test. Explicit .iloc slices select the same rows that
# np.split(...) did, without routing a DataFrame through np.split (which
# relies on the deprecated DataFrame.swapaxes).
shuffled_data = all_data.sample(frac=1, random_state=42)
train_end = int(.8 * len(all_data))
dev_end = int(.9 * len(all_data))
train = shuffled_data.iloc[:train_end]
dev = shuffled_data.iloc[train_end:dev_end]
test = shuffled_data.iloc[dev_end:]

# The three splits must partition the data exactly.
assert len(train) + len(dev) + len(test) == len(all_data)

In [27]:
# Eyeball the shuffled training split (the `text` and `labels` columns are
# the ones read from it further down).
print(train)

                                                    text  labels
3495   How many illegal buildings should be demolishe...       0
5461                     Who’s fatality is this tho ????       0
9794   #OnThisDay 2018 Chinese state media confirmed ...       1
11105  With any luck you will miss the windstorm on e...       0
1803   Inferno on Black Friday 1939: 71 deaths, 3,700...       1
...                                                  ...     ...
2196   go ahead and make a playlist with your name. g...       0
8561   Ruckelshaus, Sweeney and DDT – rescued from th...       0
11236  😂We learned a long time ago why all major bank...       0
4285   5,000 feral camels culled in drought-hit Austr...       1
8569   Another rescued mumma koala with her little ne...       1

[9096 rows x 2 columns]


In [29]:
# Pull the raw tweet strings and their labels out of each split as arrays,
# printing the number of examples per split as a quick size check.
X_str_train, y_train = train['text'].values, train['labels'].values
print(len(X_str_train))

X_str_dev, y_dev = dev['text'].values, dev['labels'].values
print(len(X_str_dev))

9096
1137


In [30]:
# Featurize every training tweet with bert_classifier_phi, one call per
# tweet; %time reports how long the pass takes.
%time X_train = [bert_classifier_phi(text) for text in X_str_train]

In [31]:
# Same featurization for the dev tweets, so train and dev features come from
# the identical function.
%time X_dev = [bert_classifier_phi(text) for text in X_str_dev]

[array([ 2.92384684e-01,  8.42855051e-02, -3.28263521e-01, -1.88189179e-01,
        -2.57606447e-01, -1.11518152e-01,  2.74954259e-01, -1.31438494e-01,
        -9.31590348e-02, -1.23238420e+00, -1.14544034e-01,  3.60367239e-01,
        -2.00517222e-01, -9.22048986e-02, -3.27670008e-01,  1.07467011e-01,
         2.33953536e-01, -7.33315572e-03, -5.22653833e-02, -1.46168500e-01,
        -4.71140370e-02, -5.92244864e-02,  6.87659562e-01, -1.76097810e-01,
         2.68324733e-01,  2.44254053e-01,  2.30396837e-01,  3.92883062e-01,
        -1.04674958e-01,  8.26826096e-02,  1.21654958e-01,  2.91069653e-02,
        -1.80184364e-01,  6.77715838e-02, -1.36508659e-01, -4.65901159e-02,
        -9.28417146e-02, -4.06948477e-01, -9.13233384e-02, -3.34055156e-01,
        -4.42021817e-01,  1.23129174e-01,  5.51900983e-01,  2.16495246e-05,
         2.91171312e-01, -3.88045281e-01, -1.25025526e-01, -9.48438793e-02,
        -5.87927252e-02,  2.04913065e-01,  8.65895599e-02,  2.45752543e-01,
        -9.3

In [33]:
# Shallow neural classifier trained on the BERT feature vectors:
# 300 hidden units, with early stopping switched on.
model = TorchShallowNeuralClassifier(hidden_dim=300, early_stopping=True)

In [35]:
# Train on the BERT features; the return value of fit is discarded via `_`
# and %time reports the training cost.
%time _ = model.fit(X_train, y_train)

Stopping after epoch 20. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 0.34260573983192444

CPU times: user 165 ms, sys: 23.8 ms, total: 189 ms
Wall time: 165 ms


In [36]:
# Predict a label for each dev-set feature vector.
preds = model.predict(X_dev)

In [39]:
# Per-class precision / recall / F1 on the dev split, to three decimals.
# `classification_report` is imported in the notebook's import cell at the top
# rather than here, so all dependencies are visible in one place.
# NOTE(review): the saved report below lists a total support of 250, while the
# dev split printed earlier has 1137 rows — the stored output looks stale.
# Re-run from a fresh kernel (Restart & Run All) before quoting these numbers.
print(classification_report(y_dev, preds, digits=3))

              precision    recall  f1-score   support

           0      0.846     0.985     0.910       201
           1      0.812     0.265     0.400        49

    accuracy                          0.844       250
   macro avg      0.829     0.625     0.655       250
weighted avg      0.840     0.844     0.810       250

