# Final Project: Disaster Tweet Classification

In [9]:
__author__ = "Kevin Guo, Pranav Sriram, Raymond Yao"
__version__ = "CS224u, Stanford, Spring 2021"

## Data Pre-Processing

In [10]:
import numpy as np
import pandas as pd
import re
from transformers import BertModel, BertTokenizer
from transformers import AutoTokenizer
import utils
import torch
import torch.nn as nn
from torch_shallow_neural_classifier import TorchShallowNeuralClassifier
from datasets import Dataset
from datasets import load_dataset
#from datasets import train_test_split
from transformers import TrainingArguments
from transformers import Trainer
from torch.utils.data import DataLoader
from transformers import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split

In [11]:
# Set random seeds
utils.fix_random_seeds()

In [12]:
weights_name = 'bert-base-cased'

In [13]:
bert_tokenizer = BertTokenizer.from_pretrained(weights_name)

In [14]:
bert_model = BertModel.from_pretrained(weights_name)

In [15]:
def bert_phi(text):
    input_ids = bert_tokenizer.encode(text, add_special_tokens=True)
    X = torch.tensor([input_ids])
    with torch.no_grad():
        reps = bert_model(X)
        return reps.last_hidden_state.squeeze(0).numpy()

In [16]:
def bert_classifier_phi(text):
    reps = bert_phi(text)
    #return reps.mean(axis=0)  # Another good, easy option.
    return reps[0]

In [17]:
all_data = pd.read_csv('tweets_mod_copy.csv')
all_data['text'] = all_data['text'].apply(lambda x: re.sub(r'https?\S+', '', x))
train, dev, test = np.split(all_data.sample(frac=1, random_state=42), [int(.8*len(all_data)), int(.9*len(all_data))])

In [18]:
print(train)

                                                    text  labels
3495   How many illegal buildings should be demolishe...       0
5461                     Who’s fatality is this tho ????       0
9794   #OnThisDay 2018 Chinese state media confirmed ...       1
11105  With any luck you will miss the windstorm on e...       0
1803   Inferno on Black Friday 1939: 71 deaths, 3,700...       1
...                                                  ...     ...
2196   go ahead and make a playlist with your name. g...       0
8561   Ruckelshaus, Sweeney and DDT – rescued from th...       0
11236  😂We learned a long time ago why all major bank...       0
4285   5,000 feral camels culled in drought-hit Austr...       1
8569   Another rescued mumma koala with her little ne...       1

[9096 rows x 2 columns]


In [19]:
X_str_train = train.text.values
print(len(X_str_train))
y_train = train.labels.values

X_str_dev = dev.text.values
print(len(X_str_dev))
y_dev = dev.labels.values

9096
1137


In [35]:
%time X_train = [bert_classifier_phi(text) for text in X_str_train]

CPU times: user 13min 14s, sys: 16.3 s, total: 13min 31s
Wall time: 13min 24s


In [36]:
%time X_dev = [bert_classifier_phi(text) for text in X_str_dev]

CPU times: user 1min 41s, sys: 2.39 s, total: 1min 44s
Wall time: 1min 43s


In [37]:
class LogisticRegression(torch.nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = torch.nn.Linear(1, 1)
        
    def forward(self, x):
        y_pred = nn.Sigmoid(self.linear(x))
        return y_pred
    
model = LogisticRegression()

In [34]:
# %time _ = model.fit(X_train, y_train)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.BCELoss(size_average=True)
X_str_train = torch.tensor(X_train)

for epoch in range(20):
    model.train()
    optimizer.zero_grad()    # Forward pass
    y_pred = model(X_str_train)    # Compute Loss
    loss = criterion(y_pred, y_train)    # Backward pass
    loss.backward()
    optimizer.step()

NameError: name 'X_train' is not defined

In [94]:
preds = model.predict(X_dev)

In [95]:
from sklearn.metrics import classification_report
print(classification_report(y_dev, preds, digits=3))

              precision    recall  f1-score   support

           0      0.920     0.950     0.935       920
           1      0.754     0.650     0.698       217

    accuracy                          0.893      1137
   macro avg      0.837     0.800     0.816      1137
weighted avg      0.888     0.893     0.890      1137

