# 20 Newsgroups text classification with BERT fine-tuning

In this notebook, we'll use a pre-trained [BERT](https://arxiv.org/abs/1810.04805) model for text classification using TensorFlow 2 / Keras and HuggingFace's [Transformers](https://github.com/huggingface/transformers). This notebook is based on ["Predicting Movie Review Sentiment with BERT on TF Hub"](https://github.com/google-research/bert/blob/master/predicting_movie_reviews_with_bert_on_tf_hub.ipynb) by Google and ["BERT Fine-Tuning Tutorial with PyTorch"](https://mccormickml.com/2019/07/22/BERT-fine-tuning/) by Chris McCormick.

**Note that using a GPU with this notebook is highly recommended.**

First, the needed imports.

In [None]:
%matplotlib inline

import tensorflow as tf
from tensorflow.keras.utils import plot_model
from tensorflow.keras.callbacks import TensorBoard

from transformers import BertTokenizer, TFBertForSequenceClassification
from transformers import __version__ as transformers_version

# distutils was removed in Python 3.12; packaging offers an equivalent version parser
from packaging.version import Version as LV

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

import io, sys, os, datetime

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

print('Using TensorFlow version:', tf.__version__,
      'Keras version:', tf.keras.__version__,
      'Transformers version:', transformers_version)
assert(LV(tf.__version__) >= LV("2.3.0"))

if len(tf.config.list_physical_devices('GPU')):
    from tensorflow.python.client import device_lib
    for d in device_lib.list_local_devices():
        if d.device_type == 'GPU':
            print('GPU:', d.physical_device_desc)
else:
    print('No GPU, using CPU instead.')

## 20 Newsgroups data set

Next we'll load the [20 Newsgroups](http://www.cs.cmu.edu/afs/cs.cmu.edu/project/theo-20/www/data/news20.html) data set. 

The dataset contains 20000 messages collected from 20 different Usenet newsgroups (1000 messages from each group):

|[]()|[]()|[]()|[]()|
| --- | --- |--- | --- |
| alt.atheism           | soc.religion.christian   | comp.windows.x     | sci.crypt |               
| talk.politics.guns    | comp.sys.ibm.pc.hardware | rec.autos          | sci.electronics |              
| talk.politics.mideast | comp.graphics            | rec.motorcycles    | sci.space |                   
| talk.politics.misc    | comp.os.ms-windows.misc  | rec.sport.baseball | sci.med |                     
| talk.religion.misc    | comp.sys.mac.hardware    | rec.sport.hockey   | misc.forsale |

In [None]:
TEXT_DATA_DIR = "/media/data/20_newsgroup"

print('Processing text dataset')

texts = []  # list of text samples
labels_index = {}  # dictionary mapping label name to numeric id
labels = []  # list of label ids
for name in sorted(os.listdir(TEXT_DATA_DIR)):
    path = os.path.join(TEXT_DATA_DIR, name)
    if os.path.isdir(path):
        label_id = len(labels_index)
        labels_index[name] = label_id
        for fname in sorted(os.listdir(path)):
            if fname.isdigit():
                fpath = os.path.join(path, fname)
                args = {} if sys.version_info < (3,) else {'encoding': 'latin-1'}
                with open(fpath, **args) as f:
                    t = f.read()
                    i = t.find('\n\n')  # skip header
                    if 0 < i:
                        t = t[i:]
                    texts.append(t)
                labels.append(label_id)

labels = np.array(labels)
print('Found %s texts.' % len(texts))
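
The header-stripping step in the loop above can be tried in isolation. The following is a minimal sketch on a made-up message; `strip_header` and the sample text are hypothetical and introduced only for illustration.

```python
def strip_header(message):
    """Drop everything before the first blank line, as the loading loop does."""
    i = message.find('\n\n')  # the first blank line separates header and body
    # Keep the text from the blank line onwards; if no blank line is found
    # (or it sits at position 0), return the message unchanged.
    return message[i:] if 0 < i else message

# Made-up message with two header lines and a one-line body.
sample = "From: someone@example.com\nSubject: hello\n\nThis is the body.\n"
body = strip_header(sample)
print(repr(body))
```

Note that the result still begins with the blank line itself, because the slice starts at index `i`.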

We split the data into training, validation, and test sets using scikit-learn's `train_test_split()`.

In [None]:
TEST_SET = 4000

(texts_train, texts_test,
 labels_train, labels_test) = train_test_split(texts, labels, 
                                               test_size=TEST_SET,
                                               shuffle=True, random_state=42)

(texts_train, texts_valid,
 labels_train, labels_valid) = train_test_split(texts_train, labels_train, 
                                                shuffle=False,
                                                test_size=0.1)

print('Length of training texts:', len(texts_train), 'labels:', len(labels_train))
print('Length of validation texts:', len(texts_valid), 'labels:', len(labels_valid))
print('Length of test texts:', len(texts_test), 'labels:', len(labels_test))
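
The resulting split sizes can be worked out by hand. The sketch below assumes exactly 20000 messages, as stated in the dataset description (the count printed by the loading cell may differ slightly), and assumes that a fractional `test_size` is rounded up to a whole number of samples.

```python
import math

N_TOTAL = 20000        # assumption: count taken from the description above
TEST_SET = 4000        # absolute number of test messages
VALID_FRACTION = 0.1   # share of the remaining data used for validation

n_test = TEST_SET
n_remaining = N_TOTAL - n_test
n_valid = math.ceil(VALID_FRACTION * n_remaining)  # assumption: rounded up
n_train = n_remaining - n_valid

print('train:', n_train, 'validation:', n_valid, 'test:', n_test)
```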

## BERT

Next we specify the pre-trained BERT model we are going to use. The model `"bert-base-uncased"` is the lowercased "base" model (12-layer, 768-hidden, 12-heads, 110M parameters).

### Tokenization

We load the vocabulary that belongs to the chosen BERT model and use the BERT tokenizer to convert the messages into tokens that match the data the model was pre-trained on.

In [None]:
BERTMODEL='bert-base-uncased'
CACHE_DIR='/media/data/transformers-cache/'

tokenizer = BertTokenizer.from_pretrained(BERTMODEL,
                                          do_lower_case=True,
                                          cache_dir=CACHE_DIR)

Next we tokenize all three datasets. `MAX_LEN_TRAIN` sets the maximum sequence length for the training and validation messages and `MAX_LEN_TEST` for the test messages; longer messages are truncated. The BERT model used here supports at most 512 tokens.

In [None]:
%%time

MAX_LEN_TRAIN, MAX_LEN_TEST = 128, 512

data_train = tokenizer(texts_train, padding=True, truncation=True,
                       return_tensors="tf", max_length=MAX_LEN_TRAIN)
data_valid = tokenizer(texts_valid, padding=True, truncation=True,
                       return_tensors="tf", max_length=MAX_LEN_TRAIN)
data_test = tokenizer(texts_test, padding=True, truncation=True,
                      return_tensors="tf", max_length=MAX_LEN_TEST)

Let us look at the token ids of the first training message after truncation and padding.

In [None]:
data_train["input_ids"][0]

We can also convert the token ids back to tokens. `[CLS]` and `[SEP]` are special tokens required by BERT.

In [None]:
tokenizer.decode(data_train["input_ids"][0])
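
What `truncation=True`, `max_length` and `padding=True` do to each message can be illustrated without the tokenizer. The following is a simplified sketch, not the Transformers implementation: `encode_batch` is a hypothetical helper and the content token ids are made up. The ids 101, 102 and 0 for `[CLS]`, `[SEP]` and `[PAD]` are those of the `bert-base-uncased` vocabulary.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # special token ids in bert-base-uncased

def encode_batch(token_id_lists, max_length):
    """Truncate, add [CLS]/[SEP], and pad every sequence to the batch maximum."""
    encoded = []
    for ids in token_id_lists:
        # Leave room for the two special tokens when truncating.
        ids = ids[:max_length - 2]
        encoded.append([CLS_ID] + ids + [SEP_ID])
    longest = max(len(ids) for ids in encoded)
    input_ids, attention_mask = [], []
    for ids in encoded:
        n_pad = longest - len(ids)
        input_ids.append(ids + [PAD_ID] * n_pad)
        # 1 marks real tokens, 0 marks padding that the model should ignore.
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Made-up content token ids for two short "messages".
toy = encode_batch([[7592, 2088], [2023, 2003, 1037, 2146, 4471, 2182]],
                   max_length=6)
print(toy["input_ids"])
print(toy["attention_mask"])
```

The first sequence is padded up to the batch maximum, while the second loses its last two content tokens to make room for `[CLS]` and `[SEP]`.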

### TF Datasets

Let's now define our TF `Dataset`s for training, validation, and test data. A batch size of 16 or 32 is often recommended for fine-tuning BERT on a specific task.

In [None]:
BATCH_SIZE = 32

dataset_train = tf.data.Dataset.from_tensor_slices((data_train.data, labels_train))
dataset_train = dataset_train.shuffle(len(dataset_train)).batch(BATCH_SIZE)
dataset_valid = tf.data.Dataset.from_tensor_slices((data_valid.data, labels_valid))
dataset_valid = dataset_valid.batch(BATCH_SIZE)
dataset_test = tf.data.Dataset.from_tensor_slices((data_test.data, labels_test))
dataset_test = dataset_test.batch(BATCH_SIZE)
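
The effect of `shuffle(...)` followed by `batch(BATCH_SIZE)` can be sketched in plain Python. This illustrates the idea only and is not how `tf.data` is implemented; `iterate_batches` is a hypothetical helper and the sample count of 100 is arbitrary.

```python
import math
import random

def iterate_batches(samples, batch_size, shuffle=False, seed=None):
    """Yield consecutive batches; the last one may be smaller than batch_size."""
    order = list(range(len(samples)))
    if shuffle:
        # Shuffle over the full dataset, comparable to shuffle(len(dataset)).
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [samples[i] for i in order[start:start + batch_size]]

samples = list(range(100))   # stand-in for (features, label) pairs
batches = list(iterate_batches(samples, batch_size=32, shuffle=True, seed=42))
n_expected = math.ceil(len(samples) / 32)  # number of batches per epoch

print(len(batches), n_expected, [len(b) for b in batches])
```

Every sample appears exactly once per pass, and only the final batch is smaller than the batch size.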

### Model initialization

We now load a pretrained BERT model with a single linear classification layer added on top. 

In [None]:
model = TFBertForSequenceClassification.from_pretrained(BERTMODEL,
                                                        cache_dir=CACHE_DIR,
                                                        num_labels=20)

We use Adam as the optimizer and sparse categorical crossentropy as the loss (the labels are integer class ids and the model outputs logits), and then compile the model.

`LR` is the learning rate for the Adam optimizer (2e-5 to 5e-5 recommended for BERT finetuning).

In [None]:
LR = 2e-5

optimizer = tf.keras.optimizers.Adam(learning_rate=LR, epsilon=1e-08, clipnorm=1.0)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')

model.compile(optimizer=optimizer, loss=loss, metrics=[metric])

model.summary()  # summary() prints by itself; wrapping it in print() also prints "None"
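
To make the loss concrete, the following NumPy sketch computes sparse categorical crossentropy from logits by hand. It is a minimal re-derivation for illustration, not the Keras implementation; `sparse_ce_from_logits` is a hypothetical helper name.

```python
import numpy as np

def sparse_ce_from_logits(logits, true_ids):
    """Mean negative log-probability of the true class, given raw logits."""
    logits = np.asarray(logits, dtype=np.float64)
    # Numerically stable log-softmax: subtract the row maximum first.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for each sample.
    picked = log_probs[np.arange(len(true_ids)), true_ids]
    return -picked.mean()

# With identical logits for all 20 classes, every class has probability 1/20.
uniform_logits = np.zeros((4, 20))
true_ids = np.array([0, 5, 12, 19])
loss = sparse_ce_from_logits(uniform_logits, true_ids)
print(loss, np.log(20))
```

A classifier that has not learned anything yet would be expected to start near this value, ln 20, which is roughly 3.0, so a training loss well below it suggests that the model is learning.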

## Learning

In [None]:
logdir = os.path.join(os.getcwd(), "logs",
                      "20ng-bert-"+datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
print('TensorBoard log directory:', logdir)
os.makedirs(logdir)
callbacks = [TensorBoard(log_dir=logdir)]

For fine-tuning BERT on a specific task, 2-4 epochs is often recommended.

In [None]:
%%time

EPOCHS = 4

history = model.fit(dataset_train, validation_data=dataset_valid,
                    epochs=EPOCHS, verbose=2, callbacks=callbacks)

Let's take a look at loss and accuracy for train and validation sets:

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10,3))

ax1.plot(history.epoch,history.history['loss'], label='training')
ax1.plot(history.epoch,history.history['val_loss'], label='validation')
ax1.set_title('loss')
ax1.set_xlabel('epoch')
ax1.legend(loc='best')

ax2.plot(history.epoch,history.history['accuracy'], label='training')
ax2.plot(history.epoch,history.history['val_accuracy'], label='validation')
ax2.set_title('accuracy')
ax2.set_xlabel('epoch')
ax2.legend(loc='best');

## Inference

For a better measure of the quality of the model, let's see the model accuracy for the test messages.

In [None]:
%%time

test_scores = model.evaluate(dataset_test, verbose=2)
print("Test set %s: %.2f%%" % (model.metrics_names[1], test_scores[1]*100))

We can also look at classification accuracies separately for each newsgroup, and compute a confusion matrix to see which newsgroups get mixed the most:

In [None]:
test_predictions = model.predict(dataset_test)

cm=confusion_matrix(labels_test,
                    np.argmax(test_predictions[0], axis=1),
                    labels=list(range(20)))

print('Classification accuracy for each newsgroup:'); print()
# Use a separate name so that the `labels` array of class ids is not overwritten.
label_names = [l[0] for l in sorted(labels_index.items(), key=lambda x: x[1])]
for i, acc in enumerate(cm.diagonal()/cm.sum(axis=1)):
    print("%s: %.4f" % (label_names[i].ljust(26), acc))
print()

print('Confusion matrix (rows: true newsgroup; columns: predicted newsgroup):'); print()
np.set_printoptions(linewidth=9999)
print(cm); print()

plt.figure(figsize=(10,10))
plt.imshow(cm, cmap="gray", interpolation="none")
plt.title('Confusion matrix (rows: true newsgroup; columns: predicted newsgroup)')
plt.grid(False)
tick_marks = np.arange(len(label_names))
plt.xticks(tick_marks, label_names, rotation=90)
plt.yticks(tick_marks, label_names);
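
To list which newsgroups get mixed up most often, the off-diagonal entries of the confusion matrix can be ranked. The following is a minimal sketch on a made-up 3x3 matrix with hypothetical class names; `most_confused` is a hypothetical helper. With the notebook's variables, one would pass `cm` and the list of newsgroup names instead.

```python
import numpy as np

def most_confused(cm, names, top=3):
    """Return (true, predicted, count) for the largest off-diagonal entries."""
    off_diag = cm.copy()
    np.fill_diagonal(off_diag, 0)            # ignore correct predictions
    flat_order = np.argsort(off_diag, axis=None)[::-1][:top]
    rows, cols = np.unravel_index(flat_order, off_diag.shape)
    return [(names[r], names[c], int(off_diag[r, c])) for r, c in zip(rows, cols)]

# Made-up confusion matrix (rows: true class, columns: predicted class).
toy_cm = np.array([[50,  2,  8],
                   [ 1, 45,  4],
                   [ 9,  3, 40]])
toy_names = ["class_a", "class_b", "class_c"]

per_class_accuracy = toy_cm.diagonal() / toy_cm.sum(axis=1)
pairs = most_confused(toy_cm, toy_names, top=2)
print(per_class_accuracy)
print(pairs)
```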