# IMDB BERT

A simple BERT model trained from scratch using the IMDB dataset

## Environment prep and data download

In [None]:
!apt install --allow-change-held-packages libcudnn8=8.1.0.77-1+cuda11.2
!pip uninstall -y -q tensorflow keras tensorflow-estimator tensorflow-text
!pip install -q tensorflow_datasets
!pip install -q -U tensorflow-text tensorflow

In [2]:
!curl -O https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
!tar -xf aclImdb_v1.tar.gz

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 80.2M  100 80.2M    0     0  5703k      0  0:00:14  0:00:14 --:--:-- 12.3M


## Data processing

Import all the libraries needed

In [1]:
import os
import tempfile
from typing import Tuple
import re
import glob
import functools
from dataclasses import dataclass

import tensorflow as tf
import tensorflow_text as text
from tensorflow_text.tools.wordpiece_vocab import bert_vocab_from_dataset as bert_vocab
from tensorflow.keras.callbacks import Callback

from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd

In [2]:
def get_text_list_from_files(files: list) -> list:
    """
    Read in data from files
    :param files: File list
    :return: Text list
    """
    text_list = []
    for name in files:
        with open(name) as f:
            for line in f:
                text_list.append(line)
    return text_list


def get_data_from_text_files(folder_name: str) -> pd.DataFrame:
    """
    Extract data from downloaded files
    :param folder_name: Directory name
    :return: Dataframe containing text
    """
    pos_files = glob.glob("aclImdb/" + folder_name + "/pos/*.txt")
    pos_texts = get_text_list_from_files(pos_files)
    neg_files = glob.glob("aclImdb/" + folder_name + "/neg/*.txt")
    neg_texts = get_text_list_from_files(neg_files)
    df = pd.DataFrame(
        {
            "review": pos_texts + neg_texts,
            "sentiment": [0] * len(pos_texts) + [1] * len(neg_texts),
        }
    )
    df = df.sample(len(df)).reset_index(drop=True)
    return df

train_df = get_data_from_text_files("train")
test_df = get_data_from_text_files("test")

all_data = train_df.append(test_df)
all_data.head()

Unnamed: 0,review,sentiment
0,"Hoo boy, this was a real trial to get through....",1
1,I've noticed that a lot of people who post on ...,0
2,I know John Singleton's a smart guy 'coz he ma...,1
3,This was a less than exciting short film I saw...,1
4,I would not like to comment on how good the mo...,0


In [3]:
def cln_text(text: str) -> bytes:
    """
    Remove ALL punctuation and html tags from text
    :param text: Text to clean
    :return: Cleansed text
    """
    lowercase = tf.strings.lower(text)
    stripped_html = tf.strings.regex_replace(lowercase, "<br />", " ")
    return tf.strings.regex_replace(
            stripped_html, "[%s]" % re.escape("!#$%&'()*+-/:;<=>?@\^_`{|}~.,"), "").numpy()

all_data['review_cln'] = all_data['review'].map(cln_text)
all_data.head()

Unnamed: 0,review,sentiment,review_cln
0,"Hoo boy, this was a real trial to get through....",1,b'hoo boy this was a real trial to get through...
1,I've noticed that a lot of people who post on ...,0,b'ive noticed that a lot of people who post on...
2,I know John Singleton's a smart guy 'coz he ma...,1,b'i know john singletons a smart guy coz he ma...
3,This was a less than exciting short film I saw...,1,b'this was a less than exciting short film i s...
4,I would not like to comment on how good the mo...,0,b'i would not like to comment on how good the ...


### Generate TensorFlow Datasets

Save data as TensorFlow datasets for use with the model later

In [4]:
X_train, X_test = train_test_split(all_data['review_cln'], test_size=0.33, random_state=42)

train_dataset = tf.data.Dataset.from_tensor_slices(X_train)
test_dataset = tf.data.Dataset.from_tensor_slices(X_test)
next(iter(train_dataset))

<tf.Tensor: shape=(), dtype=string, numpy=b'"kaabee" depicts the hardship of a woman in pre and during wwii raising her kids alone after her husband imprisoned for "thought crime" this movie was directed by yamada youji and as expected the atmosphere of this movie is really wonderful although the historical correctness of some scenes most notably the beach scene is a suspect  the acting in this movie is absolutely incredible i am baffled at how they managed to gather this allstar cast for a 2008 film yoshinaga sayuri possibly the most decorated stillactive actress in japan will undoubtedly win more individual awards for her performance in this film shoufukutei tsurube in a supporting role was really nice as well it was asano tadanobu though who delivered the most impressive performance perfectly portraying the wittiness of his character and the difficult situation he was in  films with prewar setting is not my thing but thanks to wonderful directing and acting i was totally absorbed by

In [5]:
try:
  os.mkdir('data')
except FileExistsError:
  print('Directory already exisits')

def save_tf_data(tf_data: tf.data.Dataset, dir_name: str) -> None:
  """
  Saves the dataset
  """
  path = os.path.join('data', dir_name)
  tf.data.Dataset.save(tf_data, path)

save_tf_data(train_dataset, 'train_data')
save_tf_data(test_dataset, 'test_data')

In [6]:
def load_tf_data(dir_name: str) -> tf.data.Dataset:
    """
    Loads the tensorflow dataset
    """
    path = os.path.join('data', dir_name)
    return tf.data.Dataset.load(path)

train_dataset = load_tf_data('train_data')
test_dataset = load_tf_data('test_data')

### Create tokeniser

BERT requires tokens as input to the model. To create a custom tokeniser follow the steps below. Alternatively, a pre-trained model can be used instead.

In [24]:
@dataclass
class Config:
    def __init__(self):
        self.START_TOKEN = None
        self.END_TOKEN = None
        self.MASK_TOKEN = None
        self.UNK_TOKEN = None
    
    def load_vocab(self, vocab):
        self.START_TOKEN = vocab.index("[CLS]")
        self.END_TOKEN = vocab.index("[SEP]")
        self.MASK_TOKEN = vocab.index("[MASK]")
        self.UNK_TOKEN = vocab.index("[UNK]")
        
    MAX_SEQ_LEN = 256
    MAX_PREDICTIONS_PER_BATCH = 5
    VOCAB_SIZE = 30000
    BATCH_SIZE = 32
    EMBED_DIM = 256  # Dimensionality of embeddings
    NUM_HEAD = 4  # No. of attention heads
    FF_DIM = 512  # Dimensionality of feed forward network
    NUM_LAYERS = 4  # No. of layers
    DROPOUT = 0.1

config = Config()

In [25]:
bert_tokenizer_params=dict(lower_case=True)
reserved_tokens=["[PAD]", "[UNK]", "[MASK]", "[CLS]", "[SEP]"]

bert_vocab_args = dict(
    # The target vocabulary size
    vocab_size = config.VOCAB_SIZE,
    # Reserved tokens that must be included in the vocabulary
    reserved_tokens=reserved_tokens,
    # Arguments for `text.BertTokenizer`
    bert_tokenizer_params=bert_tokenizer_params,
    # Arguments for `wordpiece_vocab.wordpiece_tokenizer_learner_lib.learn`
    learn_params={},
)

In [17]:
%%time
imdb_vocab = bert_vocab.bert_vocab_from_dataset(
    train_dataset,
    **bert_vocab_args
)

CPU times: user 13min 8s, sys: 6 s, total: 13min 14s
Wall time: 13min 51s


In [91]:
with open('imdb_vocab.txt', 'w') as f:
  for line in imdb_vocab:
        f.write(f"{line}\n")

In [8]:
with open('imdb_vocab.txt', 'r') as f:
  imdb_vocab = f.read().splitlines() 

In [27]:
# Add vocab to config dataset
config.load_vocab(imdb_vocab)

In [29]:
lookup_table = tf.lookup.StaticVocabularyTable(
    tf.lookup.KeyValueTensorInitializer(
      keys=imdb_vocab,
      key_dtype=tf.string,
      values=tf.range(
          tf.size(imdb_vocab, out_type=tf.int64), dtype=tf.int64),
          value_dtype=tf.int64
        ),
      num_oov_buckets=1
)


In [30]:
tokenizer = text.BertTokenizer(
      lookup_table,
      token_out_type=tf.int64)

In [42]:
tokenizer.tokenize('Hello world')

<tf.RaggedTensor [[[1],
  [258]]]>

### Mask inputs

imdbBERT is trained on a masked language modelling task. This requires a portion of the input tokens to be masked. The model is then trained to predict the correct token to replace the mask.

In [32]:
@tf.function
def bert_pretrain_preprocess(vocab_table, feature):
    # Input is a string Tensor of documents, shape [batch, 1].
    # Tokenize segments to shape [num_sentences, (num_words)] each.
    tokenizer = text.BertTokenizer(
        vocab_table,
        token_out_type=tf.int64)

    segments = tokenizer.tokenize(feature).merge_dims(1, -1)

    # Truncate inputs to a maximum length.
    trimmer = text.RoundRobinTrimmer(max_seq_length=config.MAX_SEQ_LEN)
    trimmed_segments = trimmer.trim([segments])

    # Combine segments, get segment ids and add special tokens.
    segments_combined, segment_ids = text.combine_segments(
        trimmed_segments,
        start_of_sequence_id=config.START_TOKEN,
        end_of_segment_id=config.END_TOKEN)

    random_selector = text.RandomItemSelector(
        max_selections_per_batch=config.MAX_PREDICTIONS_PER_BATCH,
        selection_rate=0.2,
        unselectable_ids=[config.START_TOKEN, config.END_TOKEN, config.UNK_TOKEN]
    )

    mask_values_chooser = text.MaskValuesChooser(config.VOCAB_SIZE, config.MASK_TOKEN, 0.8)

    # Apply dynamic masking task.
    masked_input_ids, masked_lm_positions, masked_lm_ids = (
        text.mask_language_model(
            segments_combined,
            random_selector,
            mask_values_chooser,
        )
    )

    padded_inputs, _ = text.pad_model_inputs(
        segments_combined, max_seq_length=config.MAX_SEQ_LEN)

    # Prepare and pad combined segment inputs
    masked_word_ids, input_mask = text.pad_model_inputs(
        masked_input_ids, max_seq_length=config.MAX_SEQ_LEN)

    return masked_word_ids, padded_inputs

In [33]:
def make_batches(ds: tf.data.Dataset, lk_up, BUFFER_SIZE: int = 20000, BATCH_SIZE: int = 64):
    """
    It tokenizes the text, and filters out the sequences that are too long. (The batch/unbatch is included because the
    tokenizer is much more efficient on large batches). The cache method ensures that that work is only executed once.
    Then shuffle and, dense_to_ragged_batch randomize the order and assemble batches of examples. Finally, prefetch runs
    the dataset in parallel with the model to ensure that data is available when needed. See Better performance with the
    tf.data for details.
    :param ds: Tensorflow dataset
    :param BUFFER_SIZE: Size of buffer (randomly samples elements from buffer)
    :param BATCH_SIZE: No. of elements within a batch
    :return:
    """
    return (
        ds
        .shuffle(BUFFER_SIZE)
        .batch(BATCH_SIZE)
        .map(functools.partial(bert_pretrain_preprocess, lookup_table))
        .prefetch(buffer_size=tf.data.AUTOTUNE))


train_dataset_masked = make_batches(train_dataset, lookup_table)
test_dataset_masked = make_batches(test_dataset, lookup_table)

## Model Definition

In [35]:
class PositionalEmbedding(tf.keras.layers.Layer):
    """
    PositionEmbedding layer that looks-up a token's embedding vector and adds the position vector
    """

    def __init__(self, vocab_size, d_model):
        super().__init__()
        self.d_model = d_model
        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
        self.pos_encoding = self.positional_encoding(length=2048, depth=d_model)

    @staticmethod
    def positional_encoding(length: int, depth: int):
        """
        A Transformer adds a "Positional Encoding" to the embedding vectors. It uses a set of sines and cosines at
        different frequencies (across the sequence).

        :param length: Length of position embedding vector
        :param depth: Size of feed forward neural network
        :return: Positional encoding vector
        """
        depth = depth / 2

        positions = np.arange(length)[:, np.newaxis]  # (seq, 1)
        depths = np.arange(depth)[np.newaxis, :] / depth  # (1, depth)

        angle_rates = 1 / (10000 ** depths)  # (1, depth)
        angle_rads = positions * angle_rates  # (pos, depth)

        pos_encoding = np.concatenate(
            [np.sin(angle_rads), np.cos(angle_rads)],
            axis=-1)

        return tf.cast(pos_encoding, dtype=tf.float32)

    def compute_mask(self, *args, **kwargs):
        return self.embedding.compute_mask(*args, **kwargs)

    def call(self, x):
        length = tf.shape(x)[1]
        x = self.embedding(x)
        # This factor sets the relative scale of the embedding and positional_encoding.
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x = x + self.pos_encoding[tf.newaxis, :length, :]
        return x


class BaseAttention(tf.keras.layers.Layer):
    """
    Attention layers are used throughout the model. These are all identical except for how the attention is configured.
    Each one contains a layers.MultiHeadAttention, a layers.LayerNormalization and a layers.Add.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()
        self.last_attn_scores = None



class GlobalSelfAttention(BaseAttention):
    """
    This layer is responsible for processing the context sequence, and propagating information along its length.
    Since the context sequence is fixed while the translation is being generated, information is allowed to flow in
    both directions.

    To implement this layer you just need to pass the target sequence, x, as both the query, and value
    arguments to the mha layer:
    """

    def call(self, x):
        """

        :param x: Target sequence
        :return:
        """
        attn_output = self.mha(
            query=x,
            value=x,
            key=x)
        x = self.add([x, attn_output])
        x = self.layernorm(x)
        return x



class FeedForward(tf.keras.layers.Layer):
    """
    The transformer also includes this point-wise feed-forward network in both the encoder and decoder.
    The network consists of two linear layers (tf.keras.layers.Dense) with a ReLU activation in-between, and a
    dropout layer.
    As with the attention layers, the code here also includes the residual connection and normalization.

    """

    def __init__(self, d_model, dff, dropout_rate=0.1):
        super().__init__()
        self.seq = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation='relu'),
            tf.keras.layers.Dense(d_model),
            tf.keras.layers.Dropout(dropout_rate)
        ])
        self.add = tf.keras.layers.Add()
        self.layer_norm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        x = self.add([x, self.seq(x)])
        x = self.layer_norm(x)
        return x


class EncoderLayer(tf.keras.layers.Layer):
    """
    The encoder contains a stack of N encoder layers. Where each EncoderLayer contains a GlobalSelfAttention
    and FeedForward layer.
    """

    def __init__(self, *, d_model, num_heads, dff, dropout_rate=0.1):
        super().__init__()

        self.self_attention = GlobalSelfAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        self.ffn = FeedForward(d_model, dff)

    def call(self, x):
        x = self.self_attention(x)
        x = self.ffn(x)
        return x


class Encoder(tf.keras.layers.Layer):
    """
    The encoder consists of:
        - A PositionalEmbedding layer at the input.
        - A stack of EncoderLayer layers.
    """

    def __init__(self, *, num_layers, d_model, num_heads,
                 dff, vocab_size, dropout_rate=0.1):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(
            vocab_size=vocab_size, d_model=d_model)

        self.enc_layers = [
            EncoderLayer(d_model=d_model,
                         num_heads=num_heads,
                         dff=dff,
                         dropout_rate=dropout_rate)
            for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x):
        # `x` is token-IDs shape: (batch, seq_len)
        x = self.pos_embedding(x)  # Shape `(batch_size, seq_len, d_model)`.

        # Add dropout.
        x = self.dropout(x)

        for i in range(self.num_layers):
            x = self.enc_layers[i](x)

        return x  # Shape `(batch_size, seq_len, d_model)`.


class imdbBERT(tf.keras.Model):
    """
    Put Encoder and Decoder together and add a final linear (Dense) layer which converts the resulting vector at
    each location into output token probabilities.
    """

    def __init__(self, *, num_layers, d_model, num_heads, dff,
                 input_vocab_size, target_vocab_size, dropout_rate=0.1):
        super().__init__()
        self.encoder = Encoder(num_layers=num_layers, d_model=d_model,
                               num_heads=num_heads, dff=dff,
                               vocab_size=input_vocab_size,
                               dropout_rate=dropout_rate)


        self.final_layer = tf.keras.layers.Dense(target_vocab_size)


    @tf.function
    def call(self, inputs):
        # To use a Keras model with `.fit` you must pass all your inputs in the
        # first argument.
        encoded = self.encoder(inputs)  # (batch_size, context_len, d_model)        

        # Final linear layer output.
        logits = self.final_layer(encoded)  # (batch_size, target_len, target_vocab_size)

        try:
            # Drop the keras mask, so it doesn't scale the losses/metrics.
            del logits._keras_mask
        except AttributeError:
            pass

        # Return the final output and the attention weights.
        return logits

### Model training functions

A custom schedule is used to vary the learning rate during training. Masked loss/accuracy as used to determine model progress

In [36]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()

        self.d_model = d_model
        self.d_model = tf.cast(self.d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, dtype=tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
    
    def get_config(self):
        config = {
        'd_model': self.d_model,
        'warmup_steps': self.warmup_steps,

        }
        return config


def masked_loss(label, pred):
    mask = label != 0
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    loss = loss_object(label, pred)

    mask = tf.cast(mask, dtype=loss.dtype)
    loss *= mask

    loss = tf.reduce_sum(loss) / tf.reduce_sum(mask)
    return loss


def masked_accuracy(label, pred):
    pred = tf.argmax(pred, axis=2)
    label = tf.cast(label, pred.dtype)
    match = label == pred

    mask = label != 0

    match = match & mask

    match = tf.cast(match, dtype=tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    return tf.reduce_sum(match) / tf.reduce_sum(mask)

This callback outputs how the model is progressing against a static task. It is predicting output tokens for the phrase:

*I have watched this **[MASK]** and it was awesome* 

In [37]:
class MaskedTextGenerator(Callback):
    def __init__(self, sample_tokens, vocab, top_k=5):
        self.id2token = dict(enumerate(vocab))
        self.token2id = {y: x for x, y in self.id2token.items()}
        self.sample_tokens = sample_tokens
        self.k = top_k
        self.mask_token_id = self.token2id.get('[MASK]')

    def decode(self, tokens):
        return " ".join([self.id2token[t] for t in tokens if t != 0])

    def convert_ids_to_tokens(self, id):
        return self.id2token[id]

    def on_epoch_end(self, epoch, logs=None):
        prediction = self.model.predict(self.sample_tokens)

        masked_index = np.where(self.sample_tokens == self.mask_token_id)
        masked_index = masked_index[1]
        mask_prediction = prediction[0][masked_index]

        top_indices = mask_prediction[0].argsort()[-self.k:][::-1]
        values = mask_prediction[0][top_indices]

        for i in range(len(top_indices)):
            p = top_indices[i]
            v = values[i]
            tokens = np.copy(self.sample_tokens[0])
            tokens[masked_index[0]] = p
            result = {
                "input_text": self.decode(self.sample_tokens[0].numpy()),
                "prediction": self.decode(tokens),
                "probability": v,
                "predicted mask token": self.convert_ids_to_tokens(p),
            }
            print(result)


### Compile Model

In [38]:
transformer = imdbBERT(
    num_layers=config.NUM_LAYERS,
    d_model=config.EMBED_DIM,
    num_heads=config.NUM_HEAD,
    dff=config.FF_DIM,
    input_vocab_size=config.VOCAB_SIZE,
    target_vocab_size=config.VOCAB_SIZE,
    dropout_rate=config.DROPOUT)

learning_rate = CustomSchedule(config.EMBED_DIM)

optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                     epsilon=1e-9)

transformer.compile(
    loss=masked_loss,
    optimizer=optimizer,
    metrics=[masked_accuracy])

In [41]:
#  Text to test with: "i have watched this [MASK] and it was awesome"
sample_tokens = tf.ragged.constant([[28, 100, 365, 86, 2, 80, 85, 88, 1223]])

masked_token_ids, _ = text.pad_model_inputs(
    sample_tokens, max_seq_length=config.MAX_SEQ_LEN)

generator_callback = MaskedTextGenerator(masked_token_ids, imdb_vocab) 

In [40]:
tokenizer.detokenize([[28, 100, 365, 86, 2, 80, 85, 88, 1223]])

<tf.RaggedTensor [[b'i', b'have', b'watched', b'this', b'[MASK]', b'and', b'it', b'was',
  b'awesome']]>

In [None]:
# tf.gather(imdb_vocab, masked_token_ids)

### Model Training

In [None]:
transformer.fit(train_dataset_masked,
                epochs=3,
                validation_data=test_dataset_masked,
                callbacks=[generator_callback])

In [105]:
transformer.summary()

Model: "imdb_bert"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder (Encoder)           multiple                  12943360  
                                                                 
 dense_8 (Dense)             multiple                  7710000   
                                                                 
Total params: 20,653,360
Trainable params: 20,653,360
Non-trainable params: 0
_________________________________________________________________


### Save/Load Model

The custom learning rate schedule means we need to save the model weights instead of the whole model. The checkpoints can then be reloaded into a new model later on for inference

In [128]:
transformer.save_weights('/checkpoints/my_checkpoint')

In [129]:
loaded_model = imdbBERT(
    num_layers=config.NUM_LAYERS,
    d_model=config.EMBED_DIM,
    num_heads=config.NUM_HEAD,
    dff=config.FF_DIM,
    input_vocab_size=config.VOCAB_SIZE,
    target_vocab_size=config.VOCAB_SIZE,
    dropout_rate=config.DROPOUT)

In [130]:
loaded_model.load_weights('/checkpoints/my_checkpoint')

<tensorflow.python.checkpoint.checkpoint.CheckpointLoadStatus at 0x7fbf60814150>

In [131]:
prediction = loaded_model.predict(masked_token_ids)
masked_index = np.where(masked_token_ids == 2)
masked_index = masked_index[1]
mask_prediction = prediction[0][masked_index]
top_indices = mask_prediction[0].argsort()[-5 :][::-1]
values = mask_prediction[0][top_indices]
values



array([0.53967685, 0.48042318, 0.47334057, 0.47148484, 0.4674449 ],
      dtype=float32)