In [1]:
from abc import ABC, abstractmethod
from typing import Callable

# BERT Model

In [2]:
from tensorflow_hub import KerasLayer
import tensorflow as tf

from tensorflow.keras.layers import Dense, Input, LSTM
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from official.nlp import optimization

from official.nlp.bert.tokenization import FullTokenizer

## BERT Tokenizers

In [3]:
class AbstractBertTokenizer(ABC):
    """Base class that turns tweet data into padded BERT input tensors.

    Subclasses implement ``tokenize_input`` and set ``label_pattern`` there:
    one count per input sample saying how many BERT rows that sample produced,
    so ``tokenize_labels`` can repeat each label to line up with those rows.
    """
    # Rows produced per input sample by the last tokenize_input call (None until then).
    label_pattern = None

    def __init__(self, encoder, bert_input_size):
        """Build a word-piece tokenizer that matches ``encoder``'s vocabulary.

        Args:
            encoder: hub BERT layer whose ``resolved_object`` exposes
                ``vocab_file`` and ``do_lower_case``.
            bert_input_size: fixed sequence length every row is padded to.
        """
        self.bert_input_size = bert_input_size
        self.tokenizer = FullTokenizer(
            encoder.resolved_object.vocab_file.asset_path.numpy(),
            do_lower_case=encoder.resolved_object.do_lower_case.numpy()
        )

    @abstractmethod
    def tokenize_input(self, x):
        """Tokenize input data and set ``label_pattern`` accordingly."""
        return

    def tokenize_labels(self, y):
        """Expand per-sample labels into one int32 label per BERT row.

        Label ``y[k]`` is repeated ``label_pattern[k]`` times, matching the rows
        produced by the most recent ``tokenize_input`` call. Call this right
        after tokenizing the matching inputs.

        Raises:
            RuntimeError: if ``tokenize_input`` has not been called yet.
            ValueError: if ``y`` and ``label_pattern`` differ in length (zip
                would otherwise silently drop labels and misalign rows).
        """
        if self.label_pattern is None:
            raise RuntimeError("Must tokenize the input first")
        y = list(y)
        if len(y) != len(self.label_pattern):
            raise ValueError(
                "Got {} labels but the last tokenized input had {} samples".format(
                    len(y), len(self.label_pattern)))
        labels = [
            int(label)
            for label, row_count in zip(y, self.label_pattern)
            for _ in range(row_count)
        ]
        return tf.convert_to_tensor(labels, tf.int32)

    def _format_bert_tokens(self, ragged_word_ids):
        """Create, format and pad BERT's input tensors.

        Args:
            ragged_word_ids: ragged tensor of token ids, one row per BERT input.

        Returns:
            Dict with ``input_word_ids``, ``input_mask`` and ``input_type_ids``,
            each right-padded with zeros to ``bert_input_size`` columns.

        Raises:
            ValueError: if the longest row is wider than ``bert_input_size``.
        """
        # The mask is 1 wherever a real token exists; to_tensor zero-fills the rest.
        mask = tf.ones_like(ragged_word_ids).to_tensor()
        word_ids = ragged_word_ids.to_tensor()
        pad_width = self.bert_input_size - mask.shape[1]
        if pad_width < 0:
            raise ValueError(
                "Token rows of width {} exceed bert_input_size={}".format(
                    mask.shape[1], self.bert_input_size))
        padding = tf.constant([[0, 0], [0, pad_width]])
        word_ids = tf.pad(word_ids, padding, "CONSTANT")
        mask = tf.pad(mask, padding, "CONSTANT")
        # Single-segment inputs: every type id is 0.
        type_ids = tf.zeros_like(mask)

        return {
            'input_word_ids': word_ids,
            'input_mask': mask,
            'input_type_ids': type_ids,
        }

    def _format_bert_word_piece_input(self, word_piece_tokens):
        """Wrap word pieces in [CLS] ... [SEP] and convert them to vocabulary ids.

        The caller's list is left untouched.
        """
        wrapped = ['[CLS]', *word_piece_tokens, '[SEP]']
        return self.tokenizer.convert_tokens_to_ids(wrapped)

In [4]:
class BertIndividualTweetTokenizer(AbstractBertTokenizer):
    """BERT tokenizer that turns every historical tweet into its own BERT row.

    ``label_pattern`` records how many tweets (and therefore rows) each feed
    contributed, so feed-level labels can be repeated per tweet.
    """

    def tokenize_input(self, X):
        """Tokenize a collection of tweet feeds, one BERT row per tweet.

        Args:
            X: sequence of tweet feeds; each feed is a sized collection of
                tweet strings.

        Returns:
            Dict of padded BERT input tensors built by ``_format_bert_tokens``.
        """
        per_tweet_ids = []
        for feed in X:
            per_tweet_ids.extend(self._tokenize_single_tweet(text) for text in feed)
        self.label_pattern = [len(feed) for feed in X]
        return self._format_bert_tokens(tf.ragged.constant(per_tweet_ids))

    def _tokenize_single_tweet(self, tweet):
        """Word-piece tokenize one tweet and wrap it as [CLS] ... [SEP] ids.

        The body is cut to ``bert_input_size - 2`` pieces so that the two
        special tokens still fit within the model's input width.
        """
        max_body_pieces = self.bert_input_size - 2
        word_pieces = self.tokenizer.tokenize(tweet)
        return self._format_bert_word_piece_input(word_pieces[:max_body_pieces])

In [5]:
class BertTweetFeedTokenizer(AbstractBertTokenizer):
    """BERT tokenizer that joins each tweet feed and splits it into overlapping chunks.

    Each feed is joined with spaces, word-piece tokenized, and cut into windows
    of ``bert_input_size - 2`` pieces (leaving room for [CLS]/[SEP]). Every
    window becomes one BERT row, and ``label_pattern`` records how many windows
    each feed produced.
    """

    def tokenize_input(self, X, overlap=50):
        """Tokenize tweet feeds into overlapping BERT-sized chunks.

        Args:
            X: iterable of tweet feeds, each an iterable of tweet strings.
            overlap: number of word-piece tokens shared by consecutive chunks
                of the same feed; must satisfy
                ``0 <= overlap < bert_input_size - 2``.

        Returns:
            Dict of padded BERT input tensors built by ``_format_bert_tokens``.

        Raises:
            ValueError: if ``overlap`` is out of range.
        """
        tokenized_tweet_feeds = [
            self._tokenize_tweet_feed(" ".join(tweet_feed), overlap) for tweet_feed in X
        ]
        # NOTE(review): a feed with no tokens yields zero chunks, so its label is
        # dropped by tokenize_labels (repeated 0 times) - confirm that is intended.
        self.label_pattern = [len(feed_chunks) for feed_chunks in tokenized_tweet_feeds]
        flattened_feeds = [chunk for feed_chunks in tokenized_tweet_feeds for chunk in feed_chunks]
        word_ids = tf.ragged.constant(flattened_feeds)
        return self._format_bert_tokens(word_ids)

    def _tokenize_tweet_feed(self, tweet_feed, overlap):
        """Split one joined tweet feed into [CLS] ... [SEP] id chunks.

        Consecutive chunks share exactly ``overlap`` word pieces. Chunking stops
        as soon as a chunk reaches the end of the feed, so no chunk is a
        redundant tail subset of the previous one.
        """
        chunk_len = self.bert_input_size - 2
        if not 0 <= overlap < chunk_len:
            raise ValueError(
                "overlap must satisfy 0 <= overlap < bert_input_size - 2 ({}), got {}".format(
                    chunk_len, overlap))
        stride = chunk_len - overlap

        feed_tokens = self.tokenizer.tokenize(tweet_feed)
        chunks = []
        for start in range(0, len(feed_tokens), stride):
            chunks.append(feed_tokens[start:start + chunk_len])
            if start + chunk_len >= len(feed_tokens):
                break

        return [self._format_bert_word_piece_input(chunk) for chunk in chunks]

## BERT Models

In [6]:
def base_bert_model(encoder, input_size):
    """Create BERT's three named int32 input layers and run them through ``encoder``.

    Args:
        encoder: callable BERT layer that accepts a dict with keys
            ``input_word_ids``, ``input_mask`` and ``input_type_ids``.
        input_size: fixed sequence length of every input layer.

    Returns:
        Tuple ``(inputs, encoder_outputs)``: the dict of Keras ``Input`` layers
        and whatever ``encoder`` returns for them.
    """
    input_keys = ('input_word_ids', 'input_mask', 'input_type_ids')
    inputs = {
        key: Input(shape=(input_size,), dtype=tf.int32, name="inputs/" + key)
        for key in input_keys
    }
    return inputs, encoder(inputs)

In [7]:
# BERT classifier with a single Dense sigmoid unit on top of the pooled output
def dense_bert_model(encoder, input_size):
    """Build an (uncompiled) Keras model: BERT pooled output -> Dense(1, sigmoid).

    Args:
        encoder: BERT layer whose outputs include a ``'pooled_output'`` entry.
        input_size: fixed sequence length of the BERT inputs.

    Returns:
        Keras ``Model`` whose single-unit output is a sigmoid value in (0, 1).
    """
    bert_inputs, encoder_outputs = base_bert_model(encoder, input_size)
    pooled = encoder_outputs['pooled_output']
    probability = Dense(1, activation='sigmoid')(pooled)
    return Model(bert_inputs, probability)

In [8]:
# BERT model with an LSTM output layer
def create_bert_model_lstm(encoder, input_size, **kwargs):
    """Build an (uncompiled) Keras model: BERT sequence output -> LSTM(1).

    Args:
        encoder: BERT layer whose outputs include a ``'sequence_output'`` entry.
        input_size: fixed sequence length of the BERT inputs.
        **kwargs: forwarded to ``tf.keras.layers.LSTM``. ``activation``
            defaults to ``'sigmoid'`` here (instead of Keras' tanh) so the
            single-unit output stays in (0, 1), like ``dense_bert_model``.

    Returns:
        Keras ``Model`` mapping the BERT input dict to the LSTM's final output.
    """
    # base_bert_model returns (inputs, encoder_outputs); it is not a dict.
    inputs, bert_output = base_bert_model(encoder, input_size)

    # LSTM requires a 3-D (batch, time, features) input, so the 2-D pooled
    # output cannot be used. Assumes the encoder exposes a per-token
    # 'sequence_output', as TF Hub BERT encoders do - confirm for bert_url.
    # NOTE(review): padded positions are also fed to the LSTM (no mask is passed).
    sequence_output = bert_output['sequence_output']

    kwargs.setdefault('activation', 'sigmoid')
    lstm_output = LSTM(1, **kwargs)(sequence_output)

    return Model(inputs, lstm_output)

# Model Evaluation

In [9]:
class BertModelEvalHandler(ABC):
    """Bundle a hub BERT encoder, a matching tokenizer and a Keras model, and train them.

    Attributes:
        encoder: trainable ``KerasLayer`` loaded from ``bert_url``.
        tokenizer: instance of ``bert_tokenizer_class`` built from ``encoder``.
        bert: Keras model returned by ``bert_model(encoder, bert_input_size)``.
        train_history: result of ``fit`` from the last ``train_bert`` call.
    """

    def __init__(self, bert_url: str, bert_input_size: int,
                 bert_tokenizer_class: "type[AbstractBertTokenizer]",
                 bert_model: Callable):
        """
        Args:
            bert_url: TF Hub handle (URL or path) of the BERT encoder.
            bert_input_size: fixed sequence length shared by tokenizer and model.
            bert_tokenizer_class: an ``AbstractBertTokenizer`` subclass (the
                class itself, not an instance).
            bert_model: factory ``(encoder, input_size) -> keras Model``, for
                example ``dense_bert_model``.
        """
        self.encoder = KerasLayer(bert_url, trainable=True)
        self.tokenizer = bert_tokenizer_class(self.encoder, bert_input_size)
        self.bert = bert_model(self.encoder, bert_input_size)

    def train_bert(self, X, y, batch_size, epochs, X_validation, y_validation, optimizer_name, lr, checkpoint_path=None,
                   tensorboard_path=None):
        """Tokenize training/validation data, compile ``self.bert`` and fit it.

        Args:
            X, X_validation: raw inputs accepted by ``self.tokenizer.tokenize_input``.
            y, y_validation: one label per raw input sample; they are expanded
                to one label per BERT row by ``tokenize_labels``.
            batch_size: batch size passed to ``fit``.
            epochs: number of training epochs.
            optimizer_name: ``'adam'`` or ``'adamw'``.
            lr: (initial) learning rate.
            checkpoint_path: if given, save best-only weights there.
            tensorboard_path: if given, write TensorBoard logs there.

        Returns:
            The ``History`` returned by ``fit``; also stored as ``self.train_history``.

        Raises:
            ValueError: if ``optimizer_name`` is not supported.
        """
        # Labels must be tokenized right after their inputs: tokenize_labels
        # relies on the label_pattern set by the most recent tokenize_input.
        X = self.tokenizer.tokenize_input(X)
        y = self.tokenizer.tokenize_labels(y)
        X_validation = self.tokenizer.tokenize_input(X_validation)
        y_validation = self.tokenizer.tokenize_labels(y_validation)

        callbacks = self._build_callbacks(checkpoint_path, tensorboard_path)
        optimizer = self._build_optimizer(
            optimizer_name, lr, len(X['input_word_ids']), batch_size, epochs)

        self.bert.compile(
            optimizer=optimizer,
            # dense_bert_model ends in a sigmoid, so outputs are probabilities,
            # not logits; BinaryAccuracy's default 0.5 threshold assumes the same.
            loss=tf.keras.losses.BinaryCrossentropy(from_logits=False),
            metrics=[tf.metrics.BinaryAccuracy()],
        )

        self.train_history = self.bert.fit(
            x=X,
            y=y,
            batch_size=batch_size,
            epochs=epochs,
            validation_data=(X_validation, y_validation),
            callbacks=callbacks,
        )
        return self.train_history

    @staticmethod
    def _build_callbacks(checkpoint_path, tensorboard_path):
        """Return the optional checkpoint / TensorBoard callbacks for ``fit``."""
        callbacks = []
        if checkpoint_path is not None:
            callbacks.append(ModelCheckpoint(
                filepath=checkpoint_path,
                save_weights_only=True,
                save_best_only=True,
                verbose=1,
            ))
        if tensorboard_path is not None:
            callbacks.append(tf.keras.callbacks.TensorBoard(log_dir=tensorboard_path))
        return callbacks

    @staticmethod
    def _total_training_steps(num_examples, batch_size, epochs):
        """Number of optimizer steps ``fit`` runs: ceil(num_examples / batch_size) per epoch."""
        # Integer ceiling division: the last partial batch is still a step.
        steps_per_epoch = -(-num_examples // batch_size)
        return steps_per_epoch * epochs

    @staticmethod
    def _build_optimizer(optimizer_name, lr, num_examples, batch_size, epochs):
        """Return the optimizer named ``optimizer_name``.

        ``'adam'`` uses a constant learning rate. ``'adamw'`` passes the exact
        integer step count, and 10% of it as warmup, to
        ``optimization.create_optimizer``.

        Raises:
            ValueError: for any other ``optimizer_name``.
        """
        if optimizer_name == 'adam':
            return tf.keras.optimizers.Adam(lr)
        if optimizer_name == 'adamw':
            total_training_steps = BertModelEvalHandler._total_training_steps(
                num_examples, batch_size, epochs)
            return optimization.create_optimizer(
                init_lr=lr,
                num_train_steps=total_training_steps,
                num_warmup_steps=int(0.1 * total_training_steps),
                optimizer_type='adamw',
            )
        raise ValueError(
            "Unsupported optimizer_name {!r}; expected 'adam' or 'adamw'".format(optimizer_name))