# Text Classification using BERT
An example of fine-tuning a BERT model with TensorFlow to classify movie reviews from the [IMDB dataset](http://ai.stanford.edu/~amaas/data/sentiment/) as positive or negative.


# Dependencies 

- **TensorFlow Text for Natural Language Processing (NLP) tasks**
    ```bash
    !pip install "tensorflow-text==2.11.*"
    ```
    TensorFlow Text provides text preprocessing and tokenization ops for TensorFlow. It is required here because the BERT preprocessing model loaded from TensorFlow Hub relies on these ops.

- **TensorFlow/models for utilizing the AdamW optimizer**
    ```bash
    !pip install "tf-models-official==2.11.0"
    ```
    TensorFlow/models (`tf-models-official`) provides official model implementations and training utilities. This notebook uses its `official.nlp.optimization` module to build an AdamW optimizer, a variant of Adam with decoupled weight decay.

- **TensorFlow Hub to access a wealth of pretrained models**
    ```bash
    !pip install "tensorflow_hub"
    ```
    TensorFlow Hub hosts pretrained models, embeddings, and modules for tasks such as image recognition and text understanding. Both the BERT encoder and its matching preprocessing model are loaded from there, so nothing has to be trained from scratch.

## Installation

In [2]:
# Tensorflow Text for Natural Language Processing (NLP) tasks
! pip install "tensorflow-text==2.11.*"

# Tensorflow/models for utilizing the powerful AdamW optimizer
! pip install "tf-models-official==2.11.0"

# Tensorflow Hub to unlock a plethora of state-of-the-art pretrained models
! pip install "tensorflow_hub"

/bin/bash: /home/wasin/miniconda3/envs/tf/lib/libtinfo.so.6: no version information available (required by /bin/bash)
/bin/bash: /home/wasin/miniconda3/envs/tf/lib/libtinfo.so.6: no version information available (required by /bin/bash)
/bin/bash: /home/wasin/miniconda3/envs/tf/lib/libtinfo.so.6: no version information available (required by /bin/bash)


## Importing the libraries

In [3]:
# os and shutil for navigating directories and saving files
import os
import shutil

# TensorFlow libraries for deep learning
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text  # registers the ops used by the BERT preprocessing model
from official.nlp import optimization  # provides create_optimizer (AdamW)

# Matplotlib for plotting
import matplotlib.pyplot as plt

# Mixed precision training
from tensorflow.keras import mixed_precision
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

2023-07-26 00:54:43.557501: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-07-26 00:54:44.906591: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/wasin/miniconda3/envs/tf/lib/:/home/wasin/miniconda3/envs/tf/lib/python3.9/site-packages/nvidia/cudnn/lib
2023-07-26 00:54:44.909064: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/wasin/mini

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA GeForce RTX 2070 with Max-Q Design, compute capability 7.5


# Download and Process Dataset

### Download the dataset using tf.keras.utils.get_file()

In [3]:
dataset = tf.keras.utils.get_file(
    'aclImdb.tar.gz',
    'http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz',
    cache_dir='.', cache_subdir='', untar=True
)

Downloading data from http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz


### Define path to the dataset for easy access

In [4]:
dataset_dir = os.path.join(os.path.dirname(dataset), 'aclImdb')
train_dir = os.path.join(dataset_dir, 'train')

# Remove the unlabeled 'unsup' folder so that only the two labeled classes remain (binary classification)
remove_dir = os.path.join(train_dir, 'unsup')
shutil.rmtree(remove_dir)

### Create dataset
Use `text_dataset_from_directory` to create `tf.data.Dataset` objects for the training, validation, and test splits.

In [5]:
AUTOTUNE = tf.data.AUTOTUNE
batch_size = 32
seed = 69

In [6]:
# Load the training dataset
raw_train_dataset = tf.keras.utils.text_dataset_from_directory(
    'aclImdb/train',        # directory of the dataset
    batch_size=batch_size,  # batch size
    validation_split=0.2,   # 20% of the dataset will be used for validation
    subset='training',      # return the training portion of the split
    seed=seed               # seed for reproducibility
)
class_names = raw_train_dataset.class_names
train_dataset = raw_train_dataset.cache().prefetch(buffer_size=AUTOTUNE)

# Load the validation dataset
validation_dataset = tf.keras.utils.text_dataset_from_directory(
    'aclImdb/train',
    batch_size=batch_size,
    validation_split=0.2,
    subset='validation',
    seed=seed
)
validation_dataset = validation_dataset.cache().prefetch(buffer_size=AUTOTUNE)

# Load the test dataset
test_dataset = tf.keras.utils.text_dataset_from_directory(
    'aclImdb/test',
    batch_size=batch_size
)
test_dataset = test_dataset.cache().prefetch(buffer_size=AUTOTUNE)

Found 25000 files belonging to 2 classes.
Using 20000 files for training.


2023-07-25 17:28:27.933221: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-07-25 17:28:28.666794: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1613] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38245 MB memory:  -> device: 0, name: NVIDIA A100-PCIE-40GB, pci bus id: 0000:61:00.0, compute capability: 8.0


Found 25000 files belonging to 2 classes.
Using 5000 files for validation.
Found 25000 files belonging to 2 classes.
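
The two `text_dataset_from_directory` calls on `aclImdb/train` share the same `seed` and `validation_split`; if they did not, the training and validation subsets could overlap. The following is a simplified pure-Python sketch of that idea, not Keras's actual implementation: `split_files` and the file names are hypothetical, and the real utility may shuffle differently.

```python
import random


def split_files(files, validation_split, subset, seed):
    """Shuffle deterministically, then cut off the tail as validation data."""
    shuffled = list(files)
    random.Random(seed).shuffle(shuffled)  # same seed -> same order every call
    n_val = int(validation_split * len(shuffled))
    if subset == "training":
        return shuffled[:-n_val]
    return shuffled[-n_val:]


# Hypothetical stand-ins for the 25,000 labeled training reviews
files = [f"review_{i}.txt" for i in range(25000)]

train_files = split_files(files, 0.2, "training", seed=69)
val_files = split_files(files, 0.2, "validation", seed=69)

print(len(train_files), len(val_files))
print("overlap:", len(set(train_files) & set(val_files)))
```

Because both calls shuffle with the same seed, the two subsets partition the files without overlap, which matches the 20,000 / 5,000 split reported above.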


# Data Visualization

In [7]:
for text, labels in train_dataset.take(1):
    for i in range(3):
        print(f'Review: {text.numpy()[i]}')
        label = labels.numpy()[i]
        print(f'Label : {label} ({class_names[label]})')

Review: b"This is a very sad movie. Really. Nothing happens in this movie. The Script is bad!!! I guess they've just copy-paste the first 15 pages to 90 pages. The Producers must have thought let's create a Hollywood movie here in Belgium. They didn't succeed. Now in the third week it is only running in Antwerp and Brussels at 22h45 or something. In the past we have had really good movies in Belgium, like Daens. Shades is a waste of your time. Maybe you could sneak in the theater after you've seen a real movie. If you've seen 10 minutes of Shades, you've seen it all. It was advertised to death on local radio and TV. I hope it will disappear in the Shades soon."
Label : 0 (neg)
Review: b'Dick Foran and Peggy Moran, who were so good together in THE MUMMY\'S HAND, return for this very minor Universal Horror offering. But this time, instead of having Wallace Ford as the comedic sidekick "Babe," we get Fuzzy Knight substituting as a silly buddy named "Stuff". But the results are nowhere nea

2023-07-25 17:28:31.325828: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


# Loading BERT from TensorFlow Hub
To fine-tune BERT, first choose one of the BERT models available on TensorFlow Hub:
1. BERT-Base, Uncased and seven more models with trained weights released by the original BERT authors.
2. Small BERTs have the same general architecture but fewer and/or smaller Transformer blocks, which lets you explore tradeoffs between speed, size and quality.
3. ALBERT: four different sizes of "A Lite BERT" that reduces model size (but not computation time) by sharing parameters between layers.
4. BERT Experts: eight models that all have the BERT-base architecture but offer a choice between different pre-training domains, to align more closely with the target task.
5. Electra has the same architecture as BERT (in three different sizes), but gets pre-trained as a discriminator in a set-up that resembles a Generative Adversarial Network (GAN).
6. BERT with Talking-Heads Attention and Gated GELU [base, large] has two improvements to the core of the Transformer architecture.

In [4]:
# Switch this name to the model you want to use
bert_model_name = 'bert_en_uncased_L-12_H-768_A-12' 

map_name_to_handle = {
    'bert_en_uncased_L-12_H-768_A-12':
        'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3',
    'bert_en_cased_L-12_H-768_A-12':
        'https://tfhub.dev/tensorflow/bert_en_cased_L-12_H-768_A-12/3',
    'bert_multi_cased_L-12_H-768_A-12':
        'https://tfhub.dev/tensorflow/bert_multi_cased_L-12_H-768_A-12/3',
    'small_bert/bert_en_uncased_L-2_H-128_A-2':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-128_A-2/1',
    'small_bert/bert_en_uncased_L-2_H-256_A-4':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-256_A-4/1',
    'small_bert/bert_en_uncased_L-2_H-512_A-8':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-512_A-8/1',
    'small_bert/bert_en_uncased_L-2_H-768_A-12':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-768_A-12/1',
    'small_bert/bert_en_uncased_L-4_H-128_A-2':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-128_A-2/1',
    'small_bert/bert_en_uncased_L-4_H-256_A-4':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-256_A-4/1',
    'small_bert/bert_en_uncased_L-4_H-512_A-8':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1',
    'small_bert/bert_en_uncased_L-4_H-768_A-12':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-768_A-12/1',
    'small_bert/bert_en_uncased_L-6_H-128_A-2':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-128_A-2/1',
    'small_bert/bert_en_uncased_L-6_H-256_A-4':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-256_A-4/1',
    'small_bert/bert_en_uncased_L-6_H-512_A-8':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-512_A-8/1',
    'small_bert/bert_en_uncased_L-6_H-768_A-12':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-6_H-768_A-12/1',
    'small_bert/bert_en_uncased_L-8_H-128_A-2':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-128_A-2/1',
    'small_bert/bert_en_uncased_L-8_H-256_A-4':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-256_A-4/1',
    'small_bert/bert_en_uncased_L-8_H-512_A-8':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-512_A-8/1',
    'small_bert/bert_en_uncased_L-8_H-768_A-12':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-8_H-768_A-12/1',
    'small_bert/bert_en_uncased_L-10_H-128_A-2':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-128_A-2/1',
    'small_bert/bert_en_uncased_L-10_H-256_A-4':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-256_A-4/1',
    'small_bert/bert_en_uncased_L-10_H-512_A-8':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-512_A-8/1',
    'small_bert/bert_en_uncased_L-10_H-768_A-12':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-10_H-768_A-12/1',
    'small_bert/bert_en_uncased_L-12_H-128_A-2':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-128_A-2/1',
    'small_bert/bert_en_uncased_L-12_H-256_A-4':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-256_A-4/1',
    'small_bert/bert_en_uncased_L-12_H-512_A-8':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-512_A-8/1',
    'small_bert/bert_en_uncased_L-12_H-768_A-12':
        'https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-12_H-768_A-12/1',
    'albert_en_base':
        'https://tfhub.dev/tensorflow/albert_en_base/2',
    'electra_small':
        'https://tfhub.dev/google/electra_small/2',
    'electra_base':
        'https://tfhub.dev/google/electra_base/2',
    'experts_pubmed':
        'https://tfhub.dev/google/experts/bert/pubmed/2',
    'experts_wiki_books':
        'https://tfhub.dev/google/experts/bert/wiki_books/2',
    'talking-heads_base':
        'https://tfhub.dev/tensorflow/talkheads_ggelu_bert_en_base/1',
}

map_model_to_preprocess = {
    'bert_en_uncased_L-12_H-768_A-12':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'bert_en_cased_L-12_H-768_A-12':
        'https://tfhub.dev/tensorflow/bert_en_cased_preprocess/3',
    'small_bert/bert_en_uncased_L-2_H-128_A-2':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-2_H-256_A-4':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-2_H-512_A-8':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-2_H-768_A-12':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-4_H-128_A-2':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-4_H-256_A-4':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-4_H-512_A-8':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-4_H-768_A-12':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-6_H-128_A-2':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-6_H-256_A-4':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-6_H-512_A-8':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-6_H-768_A-12':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-8_H-128_A-2':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-8_H-256_A-4':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-8_H-512_A-8':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-8_H-768_A-12':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-10_H-128_A-2':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-10_H-256_A-4':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-10_H-512_A-8':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-10_H-768_A-12':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-12_H-128_A-2':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-12_H-256_A-4':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-12_H-512_A-8':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'small_bert/bert_en_uncased_L-12_H-768_A-12':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'bert_multi_cased_L-12_H-768_A-12':
        'https://tfhub.dev/tensorflow/bert_multi_cased_preprocess/3',
    'albert_en_base':
        'https://tfhub.dev/tensorflow/albert_en_preprocess/3',
    'electra_small':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'electra_base':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'experts_pubmed':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'experts_wiki_books':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
    'talking-heads_base':
        'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3',
}

tfhub_handle_encoder = map_name_to_handle[bert_model_name]
tfhub_handle_preprocess = map_model_to_preprocess[bert_model_name]

print(f'BERT model selected           : {tfhub_handle_encoder}')
print(f'Preprocess model auto-selected: {tfhub_handle_preprocess}')

BERT model selected           : https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3
Preprocess model auto-selected: https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3


# Create Preprocessing model
Before text can be fed to BERT, it must be converted to numeric token IDs and arranged into the tensors BERT expects. Doing this by hand is tedious. Fortunately, each BERT encoder on TensorFlow Hub has a matching preprocessing model, built on TensorFlow Text ops, that transforms raw text inputs into BERT-format tensors.

In [5]:
preprocess_model = hub.KerasLayer(tfhub_handle_preprocess)

In [6]:
# Try out the preprocessing model
text_test = ['this is the shittest movie I have ever digest into my brain']
text_preprocessed = preprocess_model(text_test)

print(f'Keys       \t: {list(text_preprocessed.keys())}')
print(f'Shape      \t: {text_preprocessed["input_word_ids"].shape}')
print(f'Word Ids   \t: {text_preprocessed["input_word_ids"][0, :12]}')
print(f'Input Mask \t: {text_preprocessed["input_mask"][0, :12]}')
print(f'Type Ids   \t: {text_preprocessed["input_type_ids"][0, :12]}')

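# The BERT encoder expects three inputs:
# 1. input_word_ids : token ids of the input sequence
# 2. input_mask     : 1 for real tokens, 0 for padding
# 3. input_type_ids : segment index of each token (all 0 for single-sentence input)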

Keys       	: ['input_word_ids', 'input_type_ids', 'input_mask']
Shape      	: (1, 128)
Word Ids   	: [  101  2023  2003  1996  4485 22199  3185  1045  2031  2412 17886  2046]
Input Mask 	: [1 1 1 1 1 1 1 1 1 1 1 1]
Type Ids   	: [0 0 0 0 0 0 0 0 0 0 0 0]
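
To make the three tensors concrete, here is a minimal pure-Python sketch of how a token-id sequence is packed into BERT's input format. `pack_bert_inputs` is a hypothetical helper, not part of TensorFlow Hub. The ids 101 (`[CLS]`), 102 (`[SEP]`), and 0 (padding) follow the uncased BERT vocabulary, and the three token ids are copied from the output above.

```python
def pack_bert_inputs(token_ids, seq_length, cls_id=101, sep_id=102, pad_id=0):
    """Wrap token ids in [CLS]/[SEP] and pad everything to seq_length."""
    body = token_ids[: seq_length - 2]  # leave room for [CLS] and [SEP]
    word_ids = [cls_id] + body + [sep_id]
    n_real = len(word_ids)
    n_pad = seq_length - n_real
    return {
        "input_word_ids": word_ids + [pad_id] * n_pad,
        "input_mask": [1] * n_real + [0] * n_pad,  # 1 = real token, 0 = padding
        "input_type_ids": [0] * seq_length,        # single segment -> all zeros
    }


# Token ids that follow [CLS] (101) in the "Word Ids" output above
packed = pack_bert_inputs([2023, 2003, 1996], seq_length=8)
for key, value in packed.items():
    print(f"{key:15s}: {value}")
```

The real preprocessing model also handles tokenization itself and pads to a sequence length of 128, as the `(1, 128)` shape above shows.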


# Create BERT model

In [7]:
bert = hub.KerasLayer(tfhub_handle_encoder)

In [8]:
bert_results = bert(text_preprocessed)

print(f'Loaded BERT: {tfhub_handle_encoder}')
print(f'Pooled Outputs Shape:{bert_results["pooled_output"].shape}')
print(f'Pooled Outputs Values:{bert_results["pooled_output"][0, :12]}')
print(f'Sequence Outputs Shape:{bert_results["sequence_output"].shape}')
print(f'Sequence Outputs Values:{bert_results["sequence_output"][0, :12]}')

# The BERT model returns three outputs:
# 1. pooled_output   : a single vector representing the entire input sequence
# 2. sequence_output : a contextual embedding for every token in the input sequence
# 3. encoder_outputs : the intermediate activations of each Transformer block; the last one equals sequence_output

Loaded BERT: https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3
Pooled Outputs Shape:(1, 768)
Pooled Outputs Values:[-0.7612924  -0.15849979  0.5720426   0.43245652 -0.33294767 -0.08109941
  0.5674548   0.10577297  0.29199526 -0.99840224  0.46006808  0.1845455 ]
Sequence Outputs Shape:(1, 128, 768)
Sequence Outputs Values:[[ 0.1378846   0.21694613  0.17518272 ... -0.00650308  0.22636311
   0.22305849]
 [-0.43794012  0.2332033  -0.17401683 ... -0.12302686  1.3747523
   0.35195723]
 [-0.132808    0.36113238  0.26082492 ...  0.17491016  0.78390354
   0.50536287]
 ...
 [ 0.26680347  0.12717715 -0.45306003 ... -0.2467128   0.27201653
   0.36186197]
 [-0.06213828  0.09613115 -0.06121171 ...  0.13682327  0.19195457
   0.39893985]
 [-0.3652673   0.42820388 -0.32194316 ...  0.13237154  1.1129215
   0.2600068 ]]


# Create & Fine-tune Classifier model

## Create the classifier model

In [9]:
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string, name='RAW_TEXT_INPUT_LAYER')
preprocessing_layer = hub.KerasLayer(tfhub_handle_preprocess, name='BERT_PREPROCESSING_LAYER')
encoder_inputs = preprocessing_layer(text_input)
encoder = hub.KerasLayer(tfhub_handle_encoder, trainable=True, name='BERT_ENCODER_LAYER')
outputs = encoder(encoder_inputs)

# Custom Network
nn = outputs['pooled_output']
nn = tf.keras.layers.Dropout(0.1)(nn)
nn = tf.keras.layers.Dense(1, activation=None, name='CLASSIFIER')(nn)

model = tf.keras.Model(text_input, nn)

Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089


In [14]:
# Test run the model
print(f'Input  : {text_test}')
print(f'Output : {tf.sigmoid(model(tf.constant(text_test)))}')

# Q: Why is the output close to 0.5 for such a clearly negative review?
# A: The classifier head is not trained yet, so the output is essentially random

Input  : ['this is the shittest movie I have ever digest into my brain']
Output : [[0.4663]]
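
The classifier's final `Dense` layer has no activation, so the model emits a raw logit and `tf.sigmoid` turns it into a probability. The sketch below reproduces that mapping in plain Python. `predict_label` and `prob_to_logit` are hypothetical helpers, and the class order (index 0 is `neg`, index 1 is `pos`) is an assumption based on the labels printed in the Data Visualization output.

```python
import math

CLASS_NAMES = ["neg", "pos"]  # assumed order: 0 -> neg, 1 -> pos


def sigmoid(logit):
    """Map a real-valued logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logit))


def prob_to_logit(prob):
    """Inverse of sigmoid: recover the logit behind a probability."""
    return math.log(prob / (1.0 - prob))


def predict_label(logit, threshold=0.5):
    """Return the class name for a logit by thresholding its probability."""
    return CLASS_NAMES[int(sigmoid(logit) >= threshold)]


# The untrained model printed 0.4663, i.e. a logit very close to zero
untrained_logit = prob_to_logit(0.4663)
print(f"logit behind 0.4663: {untrained_logit:.3f}")
print("untrained prediction:", predict_label(untrained_logit))
print("confident logit 4.0 :", predict_label(4.0))
```

A logit near zero means the untrained head has no real preference between the two classes.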


In [15]:
# Plot the model architecture with Keras's plot_model utility
tf.keras.utils.plot_model(model)

You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.


## Compile the model

In [16]:
# Define Loss Function

loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)
metrics = tf.metrics.BinaryAccuracy()
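
`from_logits=True` tells the loss that the model outputs raw logits rather than probabilities, which matches the activation-free `Dense` layer in the classifier. Below is a minimal sketch of why this matters, using a standard numerically stable formulation of binary cross-entropy. The function names are illustrative, and Keras's internal implementation may differ in detail.

```python
import math


def sigmoid(logit):
    return 1.0 / (1.0 + math.exp(-logit))


def bce_from_prob(prob, label):
    """Textbook binary cross-entropy on a probability."""
    return -(label * math.log(prob) + (1 - label) * math.log(1 - prob))


def bce_from_logit(logit, label):
    """Stable form: max(x, 0) - x*y + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


# For moderate logits both forms agree
loss_logit = bce_from_logit(2.0, 1)
loss_prob = bce_from_prob(sigmoid(2.0), 1)
print(loss_logit, loss_prob)

# For an extreme logit, sigmoid saturates and log(1 - p) would fail,
# while the logit form stays finite
extreme_loss = bce_from_logit(1000.0, 0)
print(extreme_loss)
```

Working on logits avoids taking the log of a saturated probability, which is one reason to keep the sigmoid out of the model and apply it only at inference time.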

In [17]:
epochs = 5
steps_per_epoch = tf.data.experimental.cardinality(train_dataset).numpy()
num_train_steps = steps_per_epoch * epochs
num_warmup_steps = int(0.1*num_train_steps)

In [18]:
# Define the optimizer: AdamW, the optimizer BERT was originally trained with

optimizer = optimization.create_optimizer(
    init_lr=3e-5, # Recommended by Tensorflow 5e-5, 3e-5, 2e-5
    num_train_steps=num_train_steps,
    num_warmup_steps=num_warmup_steps,
    optimizer_type='adamw'
)
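
`create_optimizer` pairs AdamW with a learning-rate schedule that warms up over `num_warmup_steps` and then decays. The sketch below illustrates linear warmup followed by linear decay to zero in plain Python, using the 20,000 training files reported earlier. `lr_at_step` is a hypothetical helper, and the exact schedule inside `official.nlp.optimization` may differ in detail.

```python
import math


def lr_at_step(step, init_lr, num_train_steps, num_warmup_steps):
    """Linear warmup to init_lr, then linear decay to zero (illustrative)."""
    if step < num_warmup_steps:
        return init_lr * step / num_warmup_steps
    remaining = max(num_train_steps - step, 0)
    return init_lr * remaining / num_train_steps


train_examples = 20000  # "Using 20000 files for training."
batch_size = 32
epochs = 5
init_lr = 3e-5

steps_per_epoch = math.ceil(train_examples / batch_size)
num_train_steps = steps_per_epoch * epochs
num_warmup_steps = int(0.1 * num_train_steps)
print(steps_per_epoch, num_train_steps, num_warmup_steps)

# Sample the schedule at a few points during training
for step in (0, num_warmup_steps // 2, num_warmup_steps, num_train_steps):
    lr = lr_at_step(step, init_lr, num_train_steps, num_warmup_steps)
    print(f"step {step:5d}: lr = {lr:.2e}")
```

The warmup keeps early updates small while the freshly initialized classifier head is still producing noisy gradients.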

In [11]:
# Define Callbacks

# Tensorboard Callback
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir='logs',
    histogram_freq=1,
    profile_batch=0
)

# Checkpoint Callback
checkpoint_path = 'bert_imdb_sentiment_classifier.h5'
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='val_binary_accuracy',
    save_best_only=True,
    save_weights_only=True
)
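
With `save_best_only=True`, the checkpoint file is overwritten only when the monitored metric improves. Here is a rough sketch of that rule, assuming higher is better for `val_binary_accuracy`. `epochs_that_save` is a hypothetical helper, and the accuracy values are made up for illustration; they are not results from this notebook.

```python
def epochs_that_save(metric_per_epoch):
    """Return the 1-based epochs at which a 'best only' checkpoint is written."""
    best = float("-inf")
    saved = []
    for epoch, value in enumerate(metric_per_epoch, start=1):
        if value > best:  # strict improvement required
            best = value
            saved.append(epoch)
    return saved


# Made-up validation accuracies, one per epoch (illustration only)
made_up_val_accuracy = [0.90, 0.93, 0.92, 0.94, 0.94]
saved_epochs = epochs_that_save(made_up_val_accuracy)
print(saved_epochs)
```

Loading the checkpoint after training therefore restores the weights from the best-scoring epoch rather than the last one.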

In [20]:
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

## Fine-tune the model

In [None]:
# Takes about 2 min per epoch on a PCIe NVIDIA A100 💀
# initial_epoch=5 makes Keras start counting at epoch 6, as when resuming an earlier run
history = model.fit(initial_epoch=5, x=train_dataset, validation_data=validation_dataset, epochs=10, callbacks=[tensorboard_callback, checkpoint_callback])

Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10

In [12]:
# Load the best model
model.load_weights(checkpoint_path)

In [23]:
import time

# Test the model
movie_reviews = [
    "Wow! This movie was a rollercoaster of emotions. The acting was superb, the plot was engaging, and the cinematography was breathtaking. I couldn't take my eyes off the screen!",
    "What a disappointment! The film had a promising premise, but it fell flat with weak execution. The characters were one-dimensional, and the plot was predictable. I expected more from such a hyped movie.",
    "This movie exceeded all my expectations. The storyline was gripping, and the performances were Oscar-worthy. It's a must-watch for anyone who loves intense dramas.",
    "I regret wasting my time on this film. The plot was confusing, and the pacing was all over the place. I struggled to connect with any of the characters, making it difficult to care about what was happening on screen.",
    "Incredible! This movie was a visual masterpiece. The special effects were mind-blowing, and the action sequences were adrenaline-pumping. I can't wait to see it again!",
    "I don't understand the hype around this movie. The dialogue was cheesy, and the acting was subpar. It felt like a generic, forgettable film that lacked any originality.",
    "A heartwarming and touching story that stayed with me long after the credits rolled. The performances were heartful and sincere, making it an emotionally enriching experience.",
    "This movie was a disaster. The plot was riddled with holes, and the ending was unsatisfying. I found myself checking my watch repeatedly, wishing it would end sooner.",
    "A brilliant combination of humor and heart. The witty dialogue had me laughing out loud, and the characters were relatable and endearing. A feel-good movie at its finest!",
    "I can't believe I wasted money on this film. The acting was cringe-worthy, and the plot was a jumbled mess. I left the theater feeling frustrated and cheated."
]


for text in movie_reviews:
    s = time.time()
    output = tf.sigmoid(model(tf.constant([text])))
    t = time.time() - s
    print(f'Input  : {text}')
    print(f'RawOut: {output}')
    print(f'Output : {tf.where(output < 0.5, 0, 1)[0][0]}')
    print(f'Time   : {t}')
    print("============================================================")

Input  : Wow! This movie was a rollercoaster of emotions. The acting was superb, the plot was engaging, and the cinematography was breathtaking. I couldn't take my eyes off the screen!
RawOut: [[0.998]]
Output : 1
Time   : 0.03455495834350586
Input  : What a disappointment! The film had a promising premise, but it fell flat with weak execution. The characters were one-dimensional, and the plot was predictable. I expected more from such a hyped movie.
RawOut: [[9.245e-05]]
Output : 0
Time   : 0.05763697624206543
Input  : This movie exceeded all my expectations. The storyline was gripping, and the performances were Oscar-worthy. It's a must-watch for anyone who loves intense dramas.
RawOut: [[1.]]
Output : 1
Time   : 0.0315401554107666
Input  : I regret wasting my time on this film. The plot was confusing, and the pacing was all over the place. I struggled to connect with any of the characters, making it difficult to care about what was happening on screen.
RawOut: [[0.0002632]]
Output :