In [14]:
# TF Hub handles for the two pretrained pieces of the model: the uncased
# English BERT text preprocessor and the small-BERT encoder it feeds.
# The "L-10_H-128_A-2" part of the encoder handle presumably encodes the
# encoder size (layers / hidden width / attention heads) -- confirm on TF Hub.
_TFHUB_ROOT = "https://tfhub.dev/tensorflow"
preprocessor_url = f"{_TFHUB_ROOT}/bert_en_uncased_preprocess/3"
encoder_url = f"{_TFHUB_ROOT}/small_bert/bert_en_uncased_L-10_H-128_A-2/2"

In [15]:
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text

  from pkg_resources import parse_version


In [16]:
# Model entry point: every example is a single raw string (scalar shape).
text_input = tf.keras.Input(shape=(), dtype=tf.string)

# Load the BERT preprocessing model from TF Hub and apply it to the raw text.
# Per the model summary, the result is a dict holding 'input_word_ids',
# 'input_mask' and 'input_type_ids', each of shape (None, 128).
preprocessor = hub.KerasLayer(handle=preprocessor_url)
encoder_inputs = preprocessor(text_input)

In [17]:
# Load the pretrained encoder from TF Hub. trainable=True marks its weights
# as trainable, so they are fine-tuned along with the classification head.
encoder = hub.KerasLayer(encoder_url, trainable=True)
outputs = encoder(encoder_inputs)

# Pull the two encoder outputs out of the result dict by key.
# Only pooled_output is consumed by the cells shown here; sequence_output is
# kept available for any later use.
pooled_output, sequence_output = (
    outputs[key] for key in ("pooled_output", "sequence_output")
)

In [18]:
# Binary classification head on top of the encoder's pooled output:
# dropout (rate 0.1) followed by a single sigmoid unit.
dropout_output = tf.keras.layers.Dropout(rate=0.1, name='dropout')(pooled_output)
layer = tf.keras.layers.Dense(
    units=1,
    activation='sigmoid',
    name='output',
)(dropout_output)

# End-to-end model: raw text string in, sigmoid output out.
model = tf.keras.Model(inputs=[text_input], outputs=[layer])

In [19]:
# Print the layer-by-layer summary (output shapes, parameter counts and
# connections) to check that the model was wired as intended.
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None,)]            0           []                               
                                                                                                  
 keras_layer (KerasLayer)       {'input_word_ids':   0           ['input_1[0][0]']                
                                (None, 128),                                                      
                                 'input_mask': (Non                                               
                                e, 128),                                                          
                                 'input_type_ids':                                                
                                (None, 128)}                                                  

In [20]:
from tensorflow import keras
from tensorflow.keras import callbacks, metrics 

In [21]:
# Metrics reported during training/validation, each under an explicit short
# name so the corresponding history/log keys are predictable.
METRICS = [
    metric_cls(name=label)
    for metric_cls, label in (
        (metrics.BinaryAccuracy, 'accuracy'),
        (metrics.Precision, 'precision'),
        (metrics.Recall, 'recall'),
    )
]

# Configure training: Adam with a learning rate of 3e-5, and binary
# cross-entropy as the loss for the model's single sigmoid output.
optimizer = keras.optimizers.Adam(learning_rate=3e-5)
model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=METRICS)

# Early-stopping settings: watch the validation loss (lower is better), treat
# changes smaller than 1e-4 as no improvement, allow 3 epochs without
# improvement before stopping, and restore the best weights seen.
early_stopping_options = {
    'monitor': 'val_loss',
    'mode': 'min',
    'min_delta': 1e-4,
    'patience': 3,
    'verbose': 1,
    'restore_best_weights': True,
}
early_stopping = callbacks.EarlyStopping(**early_stopping_options)

In [23]:
# Train the model. X_train / y_train / X_val / y_val are expected to be
# defined elsewhere in the notebook (not shown in this section).
# Runs for at most 5 epochs in batches of 32, evaluating on the validation
# split and applying the early-stopping callback; the returned History
# object is kept in `history`.
fit_options = dict(
    validation_data=(X_val, y_val),
    batch_size=32,
    epochs=5,
    verbose=1,
    callbacks=[early_stopping],
)
history = model.fit(X_train, y_train, **fit_options)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
