Based on https://www.kaggle.com/michalashkenazi/covid19-text-classification-nlp-bert-tensorflow

In [3]:
pip install tensorflow-text

In [4]:
pip install -q tf-models-official

In [8]:
# -- Import Libraries -- 
import os
import numpy as np
import random

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn import model_selection
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix

from official.nlp import optimization
from nltk.corpus import stopwords

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text
from tensorflow.keras import layers, losses, preprocessing

tf.get_logger().setLevel('ERROR')

In [6]:
# Single source of truth for every random seed used in this notebook.
SEED: int = 42

In [9]:
# -- Reproducibility: seed every RNG the pipeline touches --
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here
# only affects subprocesses spawned from this kernel, not the kernel itself.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
# TensorFlow keeps its own global seed; without this, weight initialisation,
# dropout and Keras shuffling differ from run to run.
tf.random.set_seed(SEED)

In [10]:
!wget https://github.com/Pittawat2542/krathu-500/raw/main/labeled/comments-small-labeled.csv

In [11]:
# -- Global Variables --
# Training hyper-parameters.
EPOCHS = 16
BATCH_SIZE = 128
# Small gradient steps to prevent forgetting in transfer learning.
LEARNING_RATE = 1e-5

# TF-Hub handles: the preprocessing model that turns raw strings into encoder
# inputs, and the small BERT encoder that is fine-tuned.
tfhub_handle_preprocess = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3'
tfhub_handle_encoder = 'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1'

In [13]:
# -- Load Data --
# Read the labelled comments CSV from the working directory into a DataFrame.
train_data = pd.read_csv(filepath_or_buffer='comments-small-labeled.csv')

# Preview the first five rows as this cell's output.
train_data.head(n=5)

In [15]:
# -- Split Data into train and test --
# 15% of the rows are held out for the final test; both the text and the
# class label are cast to `str`. Only two splits are produced here -- the
# validation set is carved out later by `validation_split` in `model.fit`.
train_X, test_X, train_y, test_y = model_selection.train_test_split(
    train_data['text'].astype(str),
    train_data['class_label'].astype(str),
    test_size=0.15,
    random_state=SEED,
)

In [16]:
# -- convert labels to one hot --
# Fit the encoder on the TRAINING labels only, then reuse that exact mapping
# for the test labels. Refitting on the test labels would silently produce a
# different label->index mapping (and a narrower one-hot matrix) whenever the
# test split happens to miss a class.
# NOTE: this cell overwrites train_y / test_y, so re-run the split cell
# before running it a second time.
label_encoder = LabelEncoder()

vec = label_encoder.fit_transform(train_y)
# Fix the one-hot width to the number of classes seen in training so that
# train and test targets always have the same number of columns.
num_classes = len(label_encoder.classes_)
train_y = tf.keras.utils.to_categorical(vec, num_classes=num_classes)

# `transform` (not `fit_transform`) raises on labels unseen during training
# instead of quietly re-indexing them.
vec = label_encoder.transform(test_y)
test_y = tf.keras.utils.to_categorical(vec, num_classes=num_classes)

In [20]:
# -- Creating the Model for Fine Tuning --
def bert_text_classification(num_classes: int = 3, dropout_rate: float = 0.2) -> tf.keras.Model:
    """Build the (uncompiled) BERT fine-tuning classifier.

    The model takes one raw string per example, passes it through the TF-Hub
    preprocessing layer (``tfhub_handle_preprocess``) and the trainable
    TF-Hub encoder (``tfhub_handle_encoder``), then applies dropout and a
    softmax dense layer to the encoder's ``pooled_output``.

    Args:
        num_classes: Number of units in the softmax output layer.
            Defaults to 3.
        dropout_rate: Rate given to the dropout layer that sits between the
            encoder output and the classifier. Defaults to 0.2.

    Returns:
        A ``tf.keras.Model`` mapping the string input named ``'text'`` to
        the output of the ``'classifier'`` softmax layer.
    """
    # - text input: a scalar string per example -
    text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='text')

    # - preprocessing layer: raw text -> encoder inputs -
    preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='preprocessing')
    encoder_inputs = preprocessing_layer(text_input)

    # - encoding: trainable=True so the encoder weights are fine-tuned -
    encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_encoder')
    outputs = encoder(encoder_inputs)

    # - classifier head on top of the pooled output -
    net = outputs['pooled_output']
    net = tf.keras.layers.Dropout(dropout_rate)(net)
    net = tf.keras.layers.Dense(num_classes, activation='softmax', name='classifier')(net)

    return tf.keras.Model(text_input, net)


# Size the output layer from the one-hot training targets so the model always
# matches the number of label columns it will be trained against.
model = bert_text_classification(num_classes=train_y.shape[1])

In [21]:
# -- Loss --
# Targets are one-hot encoded, hence categorical (not sparse) cross-entropy.
loss = tf.keras.losses.CategoricalCrossentropy()

# -- Optimizer --
# AdamW with a learning-rate schedule that warms up and then decays over
# `num_train_steps`, as built by `optimization.create_optimizer`.
train_data_size = len(train_X)
steps_per_epoch = int(train_data_size/BATCH_SIZE)
num_train_steps = steps_per_epoch * EPOCHS
# Warm up over the first 10% of all training steps. (Dividing by BATCH_SIZE
# again here would shrink the warmup to ~0 steps.)
num_warmup_steps = int(0.1*num_train_steps)
# NOTE(review): `steps_per_epoch` is derived from all of train_X, while the
# fit call holds part of it out via `validation_split`; the schedule is
# therefore somewhat longer than the steps actually run -- confirm intent.

optimizer = optimization.create_optimizer(init_lr=LEARNING_RATE,
                                          num_train_steps=num_train_steps,
                                          num_warmup_steps=num_warmup_steps,
                                          optimizer_type='adamw')

# -- compile the model --
model.compile(optimizer=optimizer,
              loss=loss,
              metrics=['accuracy'])

In [None]:
# -- Fine Tuning the Model --
# Train on the raw text / one-hot label pairs; `validation_split=0.2` asks
# Keras to hold out a fifth of the training data for per-epoch validation.
# The returned History object is kept in `history`.
history = model.fit(
    x=train_X,
    y=train_y,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_split=0.2,
    verbose=1,
)

In [None]:
# -- Testing --
# Score the fine-tuned model on the held-out test split. The model was
# compiled with a single 'accuracy' metric, so the result unpacks into
# (loss, accuracy). Note that `loss` is rebound to a float here.
loss, acc = model.evaluate(test_X, test_y)
print("test loss: ", loss, ", test acc: ", acc * 100, "%")