In [3]:
# import os

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_text 

import tensorflow_hub as hub
# import tensorflow_datasets as tfds
# tfds.disable_progress_bar()

from official.modeling import tf_utils
from official import nlp
from official.nlp import bert

# Load the required submodules
import official.nlp.optimization
import official.nlp.bert.bert_models
import official.nlp.bert.configs
import official.nlp.bert.run_classifier
import official.nlp.bert.tokenization
import official.nlp.data.classifier_data_lib
import official.nlp.modeling.losses
import official.nlp.modeling.models
import official.nlp.modeling.networks

In [2]:
# TF Hub handle for the uncased_L-12_H-768_A-12 BERT encoder.
# NOTE(review): this pins version /3, while the cells below load version /4
# through inline URLs and never read this variable -- confirm which is intended.
hub_url_bert = "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/3"

# Cloud Storage folder holding the matching checkpoint files.
# To inspect its contents: tf.io.gfile.listdir(gs_folder_bert)
gs_folder_bert = "gs://cloud-tpu-checkpoints/bert/v3/uncased_L-12_H-768_A-12"

In [6]:
# Symbolic graph: raw strings -> preprocessing layer -> BERT encoder.
# NOTE: `preprocessor`, `encoder` and `encoder_inputs` are rebound to different
# objects by later cells, so this cell must run before them.
text_input = tf.keras.layers.Input(shape=(), dtype=tf.string)

preprocessor = hub.KerasLayer("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")
encoder_inputs = preprocessor(text_input)

encoder = hub.KerasLayer(
    "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4",
    trainable=False,  # Encoder weights stay frozen.
)
outputs = encoder(encoder_inputs)

# pooled_output:   [batch_size, 768]
# sequence_output: [batch_size, seq_length, 768]
pooled_output, sequence_output = outputs["pooled_output"], outputs["sequence_output"]



In [11]:
# Keras model mapping a batch of raw strings to the pooled BERT output.
embedding_model = tf.keras.Model(inputs=text_input, outputs=pooled_output)
sentences = tf.constant(["Alice used to die a lot"])
# Uncomment to inspect the resulting embedding:
# print(embedding_model(sentences))

## MLM

In [4]:
# Load the encoder with hub.load (rather than hub.KerasLayer) so that its
# `mlm` sub-object can be wrapped as its own layer.
# NOTE: this rebinds `encoder`, which an earlier cell set to a hub.KerasLayer.
encoder = hub.load(
    "https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4"
)
mlm = hub.KerasLayer(encoder.mlm, trainable=False)

In [5]:
# Load the preprocessing model with hub.load; later cells use its `tokenize`
# and `bert_pack_inputs` attributes separately.
# NOTE: this rebinds `preprocessor`, which an earlier cell set to a hub.KerasLayer.
preprocessor = hub.load("https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3")

In [29]:
# seq_length: number of token positions per example (also passed to
#             bert_pack_inputs in the packing cell).
# num_predict: number of masked positions scored per example.
seq_length = 16
num_predict = 4

# One int32 symbolic input per feature consumed by the MLM layer.
mlm_inputs = {
    feature: tf.keras.layers.Input(shape=(width,), dtype=tf.int32)
    for feature, width in (
        ("input_word_ids", seq_length),
        ("input_mask", seq_length),
        ("input_type_ids", seq_length),
        ("masked_lm_positions", num_predict),
    )
}

mlm_outputs = mlm(mlm_inputs)
# mlm_logits: [batch_size, num_predict, vocab_size]
mlm_logits = mlm_outputs["mlm_logits"]
# Other outputs, currently unused:
# pooled_output = mlm_outputs["pooled_output"]     # [batch_size, 768].
# sequence_output = mlm_outputs["sequence_output"] # [batch_size, seq_length, 768].


In [30]:
# Step 1: tokenize batches of text inputs.
# One string Input per text segment. This SavedModel accepts up to 2 text
# inputs; only a single segment is used here.
text_inputs = [tf.keras.layers.Input(shape=(), dtype=tf.string)]
tokenize = hub.KerasLayer(preprocessor.tokenize)
tokenized_inputs = list(map(tokenize, text_inputs))

In [40]:
# Exploratory: list the attributes of the loaded preprocessing model. Other
# cells in this notebook rely on its `tokenize` and `bert_pack_inputs` attributes.
dir(preprocessor)

['__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_add_trackable_child',
 '_add_variable_with_custom_getter',
 '_checkpoint_dependencies',
 '_default_save_signature',
 '_deferred_dependencies',
 '_delete_tracking',
 '_deserialize_from_proto',
 '_gather_saveables_for_checkpoint',
 '_handle_deferred_dependencies',
 '_is_hub_module_v1',
 '_list_extra_dependencies_for_serialization',
 '_list_functions_for_serialization',
 '_lookup_dependency',
 '_map_resources',
 '_maybe_initialize_trackable',
 '_name_based_attribute_restore',
 '_name_based_restores',
 '_no_dependency',
 '_object_identifier',
 '_preload_simple_restoration',
 '_restore_from_checkpoint_position',

In [31]:
# Step 2 (optional): modify tokenized inputs.
# Intentionally a no-op: `tokenized_inputs` is handed to the packing step
# unchanged.
pass

In [32]:
# Step 3: pack input sequences for the Transformer encoder.
# NOTE: this rebinds `encoder_inputs`, which an earlier cell set to the output
# of the all-in-one preprocessing layer.
bert_pack_inputs = hub.KerasLayer(
    preprocessor.bert_pack_inputs,
    arguments={"seq_length": seq_length},  # Optional argument.
)
encoder_inputs = bert_pack_inputs(tokenized_inputs)

# mlm_outputs = mlm(encoder_inputs)
# mlm_logits = mlm_outputs["mlm_logits"]  # [batch_size, num_predict, vocab_size]

In [36]:
def prep(text, masked_lm_positions=None):
    """Build the MLM model's input dict for a single piece of text.

    The text is wrapped in a batch of size one and run through the notebook's
    preprocessing graph (`text_inputs` -> `encoder_inputs`); the positions to
    predict are then added under "masked_lm_positions".

    Args:
        text: The string to preprocess.
        masked_lm_positions: Optional token positions to predict, given as a
            nested list / array / tensor with one row per example. When
            omitted, [[3, 5, 7, 9]] is used (the value that used to be
            hard-coded). Note that for short inputs some of these default
            positions point at padding: in the recorded "Hello my dude"
            example, input_mask is 0 at positions 5, 7 and 9.

    Returns:
        A dict holding the outputs of the preprocessing graph plus the
        int32 "masked_lm_positions" tensor.
    """
    if masked_lm_positions is None:
        positions = tf.constant([[3, 5, 7, 9]])
    else:
        # Accept lists, arrays or tensors and normalise to int32, the dtype
        # declared for the model's "masked_lm_positions" input.
        positions = tf.cast(masked_lm_positions, tf.int32)

    # Built on every call from the current notebook globals, so re-running the
    # tokenize/pack cells takes effect without redefining this function.
    prep_model = tf.keras.Model(text_inputs, encoder_inputs)
    sentences = tf.constant([text])
    return dict(
        **prep_model(sentences),
        masked_lm_positions=positions,
    )

In [37]:
prep("Hello my dude")

{'input_mask': <tf.Tensor: shape=(1, 16), dtype=int32, numpy=array([[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])>,
 'input_type_ids': <tf.Tensor: shape=(1, 16), dtype=int32, numpy=array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])>,
 'input_word_ids': <tf.Tensor: shape=(1, 16), dtype=int32, numpy=
 array([[  101,  7592,  2026, 12043,   102,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0]])>,
 'masked_lm_positions': <tf.Tensor: shape=(1, 4), dtype=int32, numpy=array([[3, 5, 7, 9]])>}

In [38]:
# Wrap the MLM inputs/logits in a Keras model and score one example.
mlm_model = tf.keras.Model(inputs=mlm_inputs, outputs=mlm_logits)
# NOTE: `sentences` is reused here for the dict of tensors returned by prep(),
# not a string tensor as in the earlier cell.
sentences = prep("Hello my dude")
print(mlm_model(sentences))

tf.Tensor(
[[[-11.770448  -12.12576   -12.188773  ... -10.236775  -11.455725
    -7.4752703]
  [ -8.606209   -8.841853   -8.818768  ...  -8.891097   -7.7307897
    -5.18248  ]
  [ -7.917266   -8.057146   -8.08658   ...  -8.0188     -7.3871493
    -4.688094 ]
  [ -8.401155   -8.482981   -8.588973  ...  -8.461361   -7.946863
    -4.342584 ]]], shape=(1, 4, 30522), dtype=float32)
