# BERT for disaster tweets classification

This notebook is meant to be run in Google collab on the gpu because of the computational cost of fine-tuning the model.

# Setup

## Installations

In [1]:
!pip install -q -U "tensorflow-text==2.9.*"

[K     |████████████████████████████████| 4.6 MB 15.2 MB/s 
[?25h

In [2]:
!pip install -q tf-models-official

[K     |████████████████████████████████| 2.4 MB 13.6 MB/s 
[K     |████████████████████████████████| 118 kB 75.6 MB/s 
[K     |████████████████████████████████| 238 kB 75.5 MB/s 
[K     |████████████████████████████████| 43 kB 2.1 MB/s 
[K     |████████████████████████████████| 1.1 MB 59.5 MB/s 
[K     |████████████████████████████████| 662 kB 74.3 MB/s 
[K     |████████████████████████████████| 352 kB 80.2 MB/s 
[K     |████████████████████████████████| 2.3 MB 52.2 MB/s 
[K     |████████████████████████████████| 588.3 MB 19 kB/s 
[K     |████████████████████████████████| 38.2 MB 1.1 MB/s 
[K     |████████████████████████████████| 5.8 MB 56.1 MB/s 
[K     |████████████████████████████████| 1.3 MB 63.1 MB/s 
[K     |████████████████████████████████| 6.0 MB 67.5 MB/s 
[K     |████████████████████████████████| 439 kB 74.8 MB/s 
[K     |████████████████████████████████| 1.7 MB 61.2 MB/s 
[?25h  Building wheel for seqeval (setup.py) ... [?25l[?25hdone


In [3]:
!pip install textacy

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting textacy
  Downloading textacy-0.12.0-py3-none-any.whl (208 kB)
[K     |████████████████████████████████| 208 kB 13.8 MB/s 
Collecting pyphen>=0.10.0
  Downloading pyphen-0.13.2-py3-none-any.whl (2.0 MB)
[K     |████████████████████████████████| 2.0 MB 68.9 MB/s 
Collecting jellyfish>=0.8.0
  Downloading jellyfish-0.9.0.tar.gz (132 kB)
[K     |████████████████████████████████| 132 kB 73.9 MB/s 
[?25hCollecting cytoolz>=0.10.1
  Downloading cytoolz-0.12.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)
[K     |████████████████████████████████| 1.8 MB 68.2 MB/s 
Building wheels for collected packages: jellyfish
  Building wheel for jellyfish (setup.py) ... [?25l[?25hdone
  Created wheel for jellyfish: filename=jellyfish-0.9.0-cp38-cp38-linux_x86_64.whl size=70639 sha256=de62fa79b9f0cbc4c377134e9c2c9a495a29e3e79e22e281216349acfc3afc18
  Stored in directory:

## Imports

In [4]:
import os

import numpy as np
import matplotlib.pyplot as plt

import pandas as pd
import html 
import re
from textacy import preprocessing
from functools import partial

import tensorflow as tf
import tensorflow_models as tfm
import tensorflow_hub as hub
import tensorflow_datasets as tfds
tfds.disable_progress_bar()

In [5]:
gs_folder_bert = "gs://cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12"
tf.io.gfile.listdir(gs_folder_bert)

['bert_config.json',
 'bert_model.ckpt.data-00000-of-00001',
 'bert_model.ckpt.index',
 'vocab.txt']

# Load data

In [6]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [7]:
train = pd.read_csv("/content/drive/MyDrive/Colab Notebooks/disaster_tweets_data/train.csv")
test = pd.read_csv("/content/drive/MyDrive/Colab Notebooks/disaster_tweets_data/test.csv")

# Pre-process train data

In [8]:
# Replace missing values with an empty string.

train["location"] = train["location"].fillna("")
train["keyword"] = train["keyword"].fillna("")

In [9]:
# define noise removal function

# define clean function
# add / remove any line if necessary
def clean(text):
    # convert html escapes like &amp; by their plain-text representation
    text = html.unescape(text) 
    
    # subsitute tags like <tab> by spaces in the specified text or remove them
    text = re.sub(r'<[^<>]*>', ' ', text)
    
    # subsitute markdown URLs like [Some text](https://....)
    text = re.sub(r'\[([^\[\]]*)\]\([^\(\)]*\)', r'\1', text)
    
    # subsitute text or code in brackets like [0]
    text = re.sub(r'\[[^\[\]]*\]', ' ', text)
    
    # subsitute standalone sequences of specials, matches &# but NOT #hashtag
    text = re.sub(r'(?:^|\s)[&#<>{}\[\]+|\\:-]{1,}(?:\s|$)', ' ', text)
    
    # subsitute standalone sequences of hyphens like --- or ==
    text = re.sub(r'(?:^|\s)[\-=\+]{2,}(?:\s|$)', ' ', text)
    
    # sequences of white spaces
    text = re.sub(r'\s+', ' ', text)
    
    # remove stock market tickers like $GE
    text = re.sub(r'\$\w*', '', text)  
    
    # remove old style retweet text "RT"
    text = re.sub(r'RT[\s]+', '', text)        
    text = re.sub(r'DT[\s]+', '', text)   
    
    # remove hashtags
    text = re.sub(r'#', '', text)
    
    return text.strip()

In [10]:
# Create cleaning pipeline.
preproc = preprocessing.make_pipeline(
    
    # join words split by a hyphen or line break
    preprocessing.normalize.hyphenated_words,
    
    # subsitute fancy quatation marks with an ASCII equivalent
    preprocessing.normalize.quotation_marks,
    
    # normalize unicode characters in text into canonical forms
    preprocessing.normalize.unicode,
    
    # remove any accents character in text by replacing them with ASCII equivalents or removing them entirely
    preprocessing.remove.accents,
    
    
    # remove all email addresses in text 
    partial(preprocessing.replace.emails, repl= ""), # or _EMAIL_
    
    # remove all phone numbers in text 
    partial(preprocessing.replace.phone_numbers, repl=""), # or _PhoneNumber_
    
    # remove all URLs in text 
    partial(preprocessing.replace.urls, repl= ""), # or _URL_
    
    # remove all (Twitter-style) user handles in text 
    partial(preprocessing.replace.user_handles, repl=""), # or _HANDLE_
    
    # Replace all hashtags in text with repl.
    #partial(preprocessing.replace.hashtags, repl="_HASTAG_"),
    
    ### TEST ### Enable it only before generating tokens for word clouds
    partial(preprocessing.replace.numbers, repl=""),
    
    # remove HTML tags from text
    preprocessing.remove.html_tags,
    
    # remove text within curly {}, square [], and/or round () brackets
    preprocessing.remove.brackets,

    # replace specific set of punctuation marks with whitespace
    partial(preprocessing.remove.punctuation, only=[ ",", ":", ";", "/", " ","(","@"]),
    
    # Replace all currency symbols in text with repl
    preprocessing.replace.currency_symbols,
    
    # replace all emoji and pictographs in text with repl.
    preprocessing.replace.emojis,
    
 )

In [11]:
train['text_c'] = train['text'].apply(clean)
train["clean_text"] = train["text_c"].apply(preproc)

train['keyword_c'] = train['keyword'].apply(clean)
train["clean_keyword"] = train["keyword_c"].apply(preproc)

train['location_c'] = train['location'].apply(clean)
train["clean_location"] = train["location_c"].apply(preproc)

train.sample(5, random_state=42)

Unnamed: 0,id,keyword,location,text,target,text_c,clean_text,keyword_c,clean_keyword,location_c,clean_location
2644,3796,destruction,,So you have a new weapon that can cause un-ima...,1,So you have a new weapon that can cause un-ima...,So you have a new weapon that can cause un-ima...,destruction,destruction,,
2227,3185,deluge,,The f$&amp;@ing things I do for #GISHWHES Just...,0,The f&@ing things I do for GISHWHES Just got s...,The f& things I do for GISHWHES Just got soake...,deluge,deluge,,
5448,7769,police,UK,DT @georgegalloway: RT @Galloway4Mayor: ÛÏThe...,1,@georgegalloway: @Galloway4Mayor: ÛÏThe CoL p...,UIThe CoL police can catch a pickpocket in L...,police,police,UK,UK
132,191,aftershock,,Aftershock back to school kick off was great. ...,0,Aftershock back to school kick off was great. ...,Aftershock back to school kick off was great. ...,aftershock,aftershock,,
6845,9810,trauma,"Montgomery County, MD",in response to trauma Children of Addicts deve...,0,in response to trauma Children of Addicts deve...,in response to trauma Children of Addicts deve...,trauma,trauma,"Montgomery County, MD",Montgomery County MD


In [12]:
train['clean_joined_features'] = train["clean_text"] + train["clean_keyword"] + train["clean_location"]

## Formatting the data for BERT

In [13]:
# BERT tokenizer.

tokenizer = tfm.nlp.layers.FastWordpieceBertTokenizer(
    vocab_file=os.path.join(gs_folder_bert, "vocab.txt"),
    lower_case=True)

In [14]:
# BERT packer for formatting the inputs.

max_seq_length = 128

packer = tfm.nlp.layers.BertPackInputs(
    seq_length=max_seq_length,
    special_tokens_dict = tokenizer.get_special_tokens_dict())

In [15]:
# Class for pre-processing the text data

class BertInputProcessor(tf.keras.layers.Layer):
  def __init__(self, tokenizer, packer):
    super().__init__()
    self.tokenizer = tokenizer
    self.packer = packer

  def call(self, text):
    text = self.tokenizer(text)

    packed = self.packer([text])
    return packed

In [16]:
bert_input_processor = BertInputProcessor(tokenizer, packer)

In [17]:
# Example inputs for example prediction.
example_inputs = bert_input_processor(list(train['clean_joined_features'])[:10])

In [18]:
inputs = bert_input_processor(list(train['clean_joined_features']))

In [19]:
train_dataset = tf.data.Dataset.from_tensor_slices((dict(inputs),
                                                    list(train['target'])))

# Fine-tuning BERT

In [20]:
import json

bert_config_file = os.path.join(gs_folder_bert, "bert_config.json")
config_dict = json.loads(tf.io.gfile.GFile(bert_config_file).read())
config_dict

{'attention_probs_dropout_prob': 0.1,
 'hidden_act': 'gelu',
 'hidden_dropout_prob': 0.1,
 'hidden_size': 768,
 'initializer_range': 0.02,
 'intermediate_size': 3072,
 'max_position_embeddings': 512,
 'num_attention_heads': 12,
 'num_hidden_layers': 12,
 'type_vocab_size': 2,
 'vocab_size': 30522}

In [21]:
encoder_config = tfm.nlp.encoders.EncoderConfig({
    'type':'bert',
    'bert': config_dict
})

In [22]:
bert_encoder = tfm.nlp.encoders.build_encoder(encoder_config)
bert_encoder

<official.nlp.modeling.networks.bert_encoder.BertEncoder at 0x7f4d9c43f1f0>

In [23]:
bert_classifier = tfm.nlp.models.BertClassifier(network=bert_encoder, num_classes=2)

In [24]:
# Example logit predictions.
bert_classifier(
    example_inputs, training=True).numpy()[:10]

array([[ 0.23425016, -0.3779763 ],
       [-0.4125249 , -0.70295995],
       [-0.38736963, -0.5819932 ],
       [ 0.06214715, -0.91500294],
       [-1.0778369 , -1.131032  ],
       [-0.73366123, -1.0632837 ],
       [-0.7215802 , -0.8319583 ],
       [-0.6494328 , -0.9795061 ],
       [-0.7649241 , -0.3588312 ],
       [-0.17873794, -1.1142362 ]], dtype=float32)

## BERT setup

In [25]:
checkpoint = tf.train.Checkpoint(encoder=bert_encoder)
checkpoint.read(
    os.path.join(gs_folder_bert, 'bert_model.ckpt')).assert_consumed()

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7f4d9c164d30>

In [26]:
# Set up epochs and steps
epochs = 2
batch_size = 32
eval_batch_size = 32

train_data_size = len(train['clean_joined_features'])
steps_per_epoch = int(train_data_size / batch_size)
num_train_steps = steps_per_epoch * epochs
warmup_steps = int(0.1 * num_train_steps)
initial_learning_rate=2e-5

In [27]:
linear_decay = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=initial_learning_rate,
    end_learning_rate=0,
    decay_steps=num_train_steps)

In [28]:
warmup_schedule = tfm.optimization.lr_schedule.LinearWarmup(
    warmup_learning_rate = 0,
    after_warmup_lr_sched = linear_decay,
    warmup_steps = warmup_steps
)

In [29]:
optimizer = tf.keras.optimizers.experimental.Adam(
    learning_rate = warmup_schedule)

In [30]:
metrics = [tf.keras.metrics.SparseCategoricalAccuracy('accuracy', dtype=tf.float32)]
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

bert_classifier.compile(
    optimizer=optimizer,
    loss=loss,
    metrics=metrics)

In [31]:
bert_classifier.fit(
      train_dataset.shuffle(len(train_dataset)).batch(batch_size),
      batch_size=32,
      epochs=epochs)

Epoch 1/2
Epoch 2/2


<keras.callbacks.History at 0x7f4d9c0db280>

# Get predictions

## Pre-processing the test data

In [32]:
test["location"] = test["location"].fillna("")
test["keyword"] = test["keyword"].fillna("")

In [33]:
test['text_c'] = test['text'].apply(clean)
test["clean_text"] = test["text_c"].apply(preproc)

test['keyword_c'] = test['keyword'].apply(clean)
test["clean_keyword"] = test["keyword_c"].apply(preproc)

test['location_c'] = test['location'].apply(clean)
test["clean_location"] = test["location_c"].apply(preproc)

train.sample(5, random_state=42)

Unnamed: 0,id,keyword,location,text,target,text_c,clean_text,keyword_c,clean_keyword,location_c,clean_location,clean_joined_features
2644,3796,destruction,,So you have a new weapon that can cause un-ima...,1,So you have a new weapon that can cause un-ima...,So you have a new weapon that can cause un-ima...,destruction,destruction,,,So you have a new weapon that can cause un-ima...
2227,3185,deluge,,The f$&amp;@ing things I do for #GISHWHES Just...,0,The f&@ing things I do for GISHWHES Just got s...,The f& things I do for GISHWHES Just got soake...,deluge,deluge,,,The f& things I do for GISHWHES Just got soake...
5448,7769,police,UK,DT @georgegalloway: RT @Galloway4Mayor: ÛÏThe...,1,@georgegalloway: @Galloway4Mayor: ÛÏThe CoL p...,UIThe CoL police can catch a pickpocket in L...,police,police,UK,UK,UIThe CoL police can catch a pickpocket in L...
132,191,aftershock,,Aftershock back to school kick off was great. ...,0,Aftershock back to school kick off was great. ...,Aftershock back to school kick off was great. ...,aftershock,aftershock,,,Aftershock back to school kick off was great. ...
6845,9810,trauma,"Montgomery County, MD",in response to trauma Children of Addicts deve...,0,in response to trauma Children of Addicts deve...,in response to trauma Children of Addicts deve...,trauma,trauma,"Montgomery County, MD",Montgomery County MD,in response to trauma Children of Addicts deve...


In [34]:
test['clean_joined_features'] = test["clean_text"] + test["clean_keyword"] + test["clean_location"]

In [35]:
test_inputs = bert_input_processor(list(test['clean_joined_features']))

In [36]:
test_dataset = tf.data.Dataset.from_tensor_slices(dict(test_inputs))

## Get predictions

In [37]:
predictions = bert_classifier.predict(test_dataset.batch(batch_size))



In [38]:
predictions

array([[-1.3821484 ,  0.9218454 ],
       [-2.1851223 ,  1.8749597 ],
       [-2.0271132 ,  1.7456344 ],
       ...,
       [-2.7556977 ,  2.787077  ],
       [-1.1667815 ,  0.65156657],
       [-0.81265354,  0.9463611 ]], dtype=float32)

In [39]:
def get_prediction_labels(predictions):
  labels = []
  for logit in predictions:
    if logit[0] > logit[1]:
      labels.append(0)
    else:
      labels.append(1)
  return labels

In [40]:
labels = get_prediction_labels(predictions)

In [41]:
preds_df_bert = pd.DataFrame()
preds_df_bert['id'] = test['id']
preds_df_bert['target'] = labels

In [42]:
preds_df_bert.to_csv('submission.csv', index=False)

# References
* https://colab.research.google.com/github/tensorflow/text/blob/master/docs/tutorials/classify_text_with_bert.ipynb
* https://www.tensorflow.org/tfmodels/nlp/fine_tune_bert#train_the_model
* https://www.kaggle.com/code/romannowak/nlp-with-disaster-tweets-cleaning-tf-idf-and-bert/edit