In [1]:
import os
import pathlib

import tensorflow as tf
import tensorflow_ranking as tfr

In [2]:
# Spec handed to the input creator as the *context* features: the query tokens,
# parsed as a variable-length (ragged) list of strings.
context_feature_spec = {'query_tokens': tf.io.RaggedFeature(dtype=tf.string)}

# Spec handed to the input creator as the *example* features: the document
# tokens, likewise a ragged list of strings.
example_feature_spec = {'document_tokens': tf.io.RaggedFeature(dtype=tf.string)}

# Label spec: one int64 relevance value that falls back to -1 when absent
# (presumably so missing/padded entries can be told apart from real labels --
# confirm against the dataset pipeline that consumes it).
label_spec = {
    'relevance': tf.io.FixedLenFeature(
        shape=(1,),
        dtype=tf.int64,
        default_value=-1,
    )
}

In [3]:
# Callable that turns the two feature specs into Keras inputs for the model.
input_creator = tfr.keras.model.FeatureSpecInputCreator(
    context_feature_spec=context_feature_spec,
    example_feature_spec=example_feature_spec,
)

In [4]:
# Invoke the creator to inspect the (context, example) dictionaries of Keras inputs.
input_creator()

({'query_tokens': <KerasTensor: type_spec=RaggedTensorSpec(TensorShape([None, None]), tf.string, 1, tf.int64) (created by layer 'query_tokens')>},
 {'document_tokens': <KerasTensor: type_spec=RaggedTensorSpec(TensorShape([None, None, None]), tf.string, 2, tf.int64) (created by layer 'document_tokens')>})

In [11]:
class LookUpTablePreprocessor(tfr.keras.model.Preprocessor):
    """Preprocessor that embeds string tokens and averages them per feature.

    Every feature tensor goes through a vocabulary lookup (string -> integer
    id), an embedding layer, and a mean over the second-to-last axis of the
    embedded tensor (the token axis, since the embedding appends the last one).

    Args:
        vocab_file: Vocabulary handed to `StringLookup`; the notebook passes a
            path to a text file.
        vocab_size: Used both as `max_tokens` of the lookup layer and as
            `input_dim` of the embedding layer.
        embedding_dim: Length of each token embedding vector (`output_dim`).
    """

    def __init__(self, vocab_file, vocab_size, embedding_dim):
        self._vocab_file = vocab_file
        self._vocab_size = vocab_size
        self._embedding_dim = embedding_dim

    def __call__(self, context_inputs, example_inputs, mask):
        """Embed and average the tokens of every context and example feature.

        Args:
            context_inputs: Dict mapping feature name to a string token tensor.
            example_inputs: Dict mapping feature name to a string token tensor.
            mask: Accepted to satisfy the preprocessor call signature; it is
                not used by this implementation.

        Returns:
            A `(context_features, example_features)` tuple of dicts with the
            same keys as the inputs, each value being the mean token embedding.
        """
        # New layers are created on every call; within one call the same lookup
        # and embedding instances are shared by context and example features.
        lookup = tf.keras.layers.StringLookup(
            max_tokens=self._vocab_size,
            vocabulary=self._vocab_file,
            mask_token=None)
        embedding = tf.keras.layers.Embedding(
            input_dim=self._vocab_size,
            output_dim=self._embedding_dim,
            embeddings_initializer=None,
            embeddings_constraint=None)

        def embed_and_average(tokens):
            # Token strings -> ids -> embeddings, then mean over the token axis.
            return tf.reduce_mean(embedding(lookup(tokens)), axis=-2)

        context_features = {
            key: embed_and_average(value)
            for key, value in context_inputs.items()
        }
        example_features = {
            key: embed_and_average(value)
            for key, value in example_inputs.items()
        }
        return context_features, example_features

In [13]:
# Location of the vocabulary file. Set the ANTIQUE_VOCAB_FILE environment
# variable to point at another copy; when it is unset or empty the original
# local path is used, so existing setups keep working unchanged.
_VOCAB_FILE = (
    os.environ.get('ANTIQUE_VOCAB_FILE')
    or '/home/guhangsong/Data/Antique/vocab.txt'
)
# Number of whitespace-separated tokens in the vocabulary file. The encoding is
# explicit so the count does not depend on the platform's default locale.
_VOCAB_SIZE = len(pathlib.Path(_VOCAB_FILE).read_text(encoding='utf-8').split())
# Length of each token embedding vector.
_EMBEDDING_DIM = 20

preprocessor = LookUpTablePreprocessor(_VOCAB_FILE, _VOCAB_SIZE, _EMBEDDING_DIM)

In [8]:
# Feed-forward scorer: hidden layers of 64, 32 and 16 units with ReLU
# activation and batch normalization enabled, ending in a single output unit.
scorer = tfr.keras.model.DNNScorer(
    hidden_layer_dims=[64, 32, 16],
    output_units=1,
    activation=tf.nn.relu,
    use_batch_norm=True,
)

In [15]:
# Wire the input creator, preprocessor and scorer into one model builder.
# 'example_list_mask' is the name given to the builder's mask feature.
model_builder = tfr.keras.model.ModelBuilder(
    name='antique_model',
    mask_feature_name='example_list_mask',
    input_creator=input_creator,
    preprocessor=preprocessor,
    scorer=scorer,
)

In [18]:
# Build the Keras model from the configured builder.
model = model_builder.build()
# Render the model graph as the cell output. plot_model needs both pydot and
# graphviz installed; without them it only reports the missing dependencies.
tf.keras.utils.plot_model(model, expand_nested=True)

('You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) ', 'for plot_model/model_to_dot to work.')
