In [2]:
!pip install transformers
%tensorflow_version 2.x

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/cd/38/c9527aa055241c66c4d785381eaf6f80a28c224cae97daa1f8b183b5fabb/transformers-2.9.0-py3-none-any.whl (635kB)
[K     |████████████████████████████████| 645kB 3.0MB/s 
Collecting tokenizers==0.7.0
[?25l  Downloading https://files.pythonhosted.org/packages/14/e5/a26eb4716523808bb0a799fcfdceb6ebf77a18169d9591b2f46a9adb87d9/tokenizers-0.7.0-cp36-cp36m-manylinux1_x86_64.whl (3.8MB)
[K     |████████████████████████████████| 3.8MB 8.1MB/s 
Collecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/98/2c/8df20f3ac6c22ac224fff307ebc102818206c53fc454ecd37d8ac2060df5/sentencepiece-0.1.86-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)
[K     |████████████████████████████████| 1.0MB 35.7MB/s 
[?25hCollecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |█████

In [3]:
import os

import tensorflow as tf
from transformers import TFRobertaModel

# Reuse TensorFlow's own logger and record the runtime version next to the outputs.
logger = tf.get_logger()
logger.info('%s', tf.__version__)

INFO:tensorflow:2.2.0


### Parsing tfrecords + tf.data.Dataset

In [0]:
class TextDataset(tf.data.Dataset):
    """Factory for a batched ``tf.data`` pipeline of (anchor, 4 related papers) samples.

    ``TextDataset(tfrecords_pattern, epochs, batch_size)`` does not produce a
    ``TextDataset`` instance: ``__new__`` assembles and returns an ordinary
    ``tf.data`` pipeline (interleaved TFRecord reads -> parse -> sample ->
    batch -> repeat -> prefetch).

    Each serialized example is parsed into a 2-D tensor with one row per paper
    and ``NUM_TOKENS`` token ids per row. One row is drawn as the anchor and
    four rows are drawn (with replacement) from the rows that differ from it.
    """

    # Number of token positions stored per paper; each position is its own feature.
    NUM_TOKENS = 512

    # Parsing spec: 'dim_0' .. 'dim_511', each a variable-length int64 feature.
    # Built with a comprehension so no loop variable leaks into the class namespace.
    feature = {'dim_' + str(i): tf.io.VarLenFeature(tf.int64) for i in range(NUM_TOKENS)}

    @staticmethod
    def _parse_example(example_proto):
        """Decode one serialized example into a ``(num_papers, NUM_TOKENS)`` int64 tensor.

        Args:
            example_proto: Scalar string tensor holding a serialized ``tf.train.Example``.

        Returns:
            Dense tensor whose row ``p`` holds the token ids of paper ``p``.
        """
        parsed = tf.io.parse_single_example(example_proto, TextDataset.feature)
        columns = [tf.sparse.to_dense(parsed['dim_' + str(i)])
                   for i in range(TextDataset.NUM_TOKENS)]
        # Every column has one entry per paper; stacking on axis 1 puts papers on rows.
        return tf.stack(columns, axis=1)

    @staticmethod
    def _construct_inputs(input_ids):
        """Draw a random anchor row and four rows that differ from it.

        Args:
            input_ids: ``(num_papers, NUM_TOKENS)`` integer tensor of token ids.

        Returns:
            Tuple ``(input_a, input_b, input_c, input_d, input_e)`` of 1-D int32
            tensors; ``input_a`` is the anchor, the others are sampled with
            replacement from the rows not identical to the anchor.
        """
        input_ids = tf.cast(input_ids, dtype=tf.int32)
        num_papers = tf.shape(input_ids)[0]
        idx_a = tf.random.uniform(shape=[], minval=0, maxval=num_papers, dtype=tf.int32)
        input_a = tf.gather(input_ids, idx_a)

        # Keep every row that is not token-for-token identical to the anchor.
        differs_from_anchor = tf.logical_not(
            tf.reduce_all(tf.equal(input_ids, input_a), axis=-1))
        all_related_papers = tf.gather(input_ids, tf.where(differs_from_anchor)[:, 0])

        # Sample over the rows that actually survived the filter. Using
        # num_papers - 1 here would index out of range whenever duplicates of
        # the anchor were removed as well.
        num_related = tf.shape(all_related_papers)[0]
        tf.debugging.assert_greater(
            num_related, 0,
            message='Example has no paper that differs from the sampled anchor.')
        # Uniform logits -> uniform sampling with replacement.
        idx = tf.random.categorical(tf.zeros([1, num_related], dtype=tf.float32),
                                    num_samples=4)[0]
        input_b, input_c, input_d, input_e = tf.unstack(
            tf.gather(all_related_papers, idx), num=4, axis=0)
        return input_a, input_b, input_c, input_d, input_e

    @staticmethod
    def _parse_and_create_sample(example_proto, batch_size):
        """Turn one serialized example into ``(sample, labels)`` for Keras.

        Args:
            example_proto: Scalar string tensor holding a serialized example.
            batch_size: Batch size of the surrounding pipeline; sets the width
                of the per-example negative label vector so that, once batched,
                it is ``(batch_size, batch_size)``.

        Returns:
            ``(sample, (negative_labels, positive_labels))`` where ``sample`` is
            the 5-tuple from ``_construct_inputs``. Both label tensors are
            all-ones placeholders.
        """
        input_ids = TextDataset._parse_example(example_proto)
        sample = TextDataset._construct_inputs(input_ids)
        positive_labels = tf.ones([1])
        negative_labels = tf.ones([batch_size])
        return sample, (negative_labels, positive_labels)

    def __new__(cls, tfrecords_pattern, epochs, batch_size):
        """Build the input pipeline.

        Args:
            tfrecords_pattern: Glob pattern matching the TFRecord files to read.
            epochs: Number of times the batched data is repeated.
            batch_size: Examples per batch; incomplete final batches are dropped.

        Returns:
            A prefetching ``tf.data`` dataset yielding ``(sample, labels)`` batches.
        """
        # Allow elements to be produced out of order.
        _options = tf.data.Options()
        _options.experimental_deterministic = False

        tfrecords = tf.data.Dataset.list_files(tfrecords_pattern)
        dataset = tfrecords.interleave(tf.data.TFRecordDataset,
                                       cycle_length=4,
                                       block_length=16,
                                       num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.with_options(_options)
        # Bind the constructor's batch_size explicitly instead of relying on a
        # notebook-level global of the same name.
        dataset = dataset.map(
            lambda example_proto: cls._parse_and_create_sample(example_proto, batch_size),
            num_parallel_calls=tf.data.experimental.AUTOTUNE)
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.repeat(epochs)
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
        return dataset

In [0]:
# Hyper-parameters read by the model, loss and training cells.
batch_size = 4
epochs = 2
lr = 1e-5

# Input pipeline over every file matching the glob.
dataset = TextDataset(tfrecords_pattern='tfrecords/*',
                      epochs=epochs,
                      batch_size=batch_size)

In [0]:
def dot_product(x, y, pairwise=False, name=None):
  """Sum of element-wise products of `x` and `y` over the last axis.

  Args:
    x: First operand.
    y: Second operand, broadcast against `x`.
    pairwise: If True, `x` gets an extra axis at position 1 before the
      multiplication so that it broadcasts against `y`, and the reduced axis
      is dropped. If False, the reduced axis is kept with size 1.
    name: Optional name for the final reduce-sum op.

  Returns:
    The reduced tensor.
  """
  lhs = tf.expand_dims(x, axis=1) if pairwise else x
  products = tf.multiply(lhs, y)
  keep_last_axis = not pairwise
  return tf.reduce_sum(products, axis=-1, keepdims=keep_last_axis, name=name)

# Pretrained encoder shared by all five inputs; from_pt=True loads it from the PyTorch weights.
base_model = TFRobertaModel.from_pretrained('allenai/biomed_roberta_base', from_pt=True)

# One token-id input per paper: 'a' is the anchor, 'b'..'e' are its sampled related papers.
inputs = [tf.keras.Input(shape=[512], dtype=tf.int32, name='input_{}'.format(i), batch_size=batch_size) for i in ['a', 'b', 'c', 'd', 'e']]
# Mean-pool the encoder's first output over the sequence axis -> (batch_size, 768) per paper.
outputs = [tf.reduce_mean(base_model(x)[0], axis=1) for x in inputs]

# Projection applied to every paper embedding with the same weights.
ff1 = tf.keras.layers.Dense(768, activation='tanh', name='ff1')
ff1_outputs = [ff1(x) for x in outputs]

# Average the four related-paper embeddings across papers -> (batch_size, 768).
# tf.stack creates a dedicated paper axis to average over. Concatenating the
# (batch_size, 768) tensors on axis 1 and averaging that axis would instead
# collapse all features into a single scalar per row.
mean_related_papers = tf.reduce_mean(tf.stack(ff1_outputs[1:], axis=1), axis=1)
ff2 = tf.keras.layers.Dense(768, activation='tanh', name='ff2')
ff2_output = ff2(mean_related_papers)

# (batch_size, batch_size) similarity of every anchor with every anchor in the batch.
# NOTE(review): both operands are the anchor embeddings — confirm that in-batch
# anchors (rather than the related-paper summary) are the intended negatives.
negative_outputs = dot_product(ff1_outputs[0], ff1_outputs[0], pairwise=True, name='negative')
# (batch_size, 1) similarity of each anchor with the summary of its own related papers.
# Scoring ff2_output against itself would only measure its squared norm and
# never connect the anchor to its related papers.
positive_outputs = dot_product(ff1_outputs[0], ff2_output, pairwise=False, name='positive')

# Output order (negative, positive) matches the label tuple produced by the dataset.
model = tf.keras.Model(inputs=inputs, outputs=[negative_outputs, positive_outputs])

[<tf.Tensor 'Mean:0' shape=(4, 768) dtype=float32>,
 <tf.Tensor 'Mean_1:0' shape=(4, 768) dtype=float32>,
 <tf.Tensor 'Mean_2:0' shape=(4, 768) dtype=float32>,
 <tf.Tensor 'Mean_3:0' shape=(4, 768) dtype=float32>,
 <tf.Tensor 'Mean_4:0' shape=(4, 768) dtype=float32>]

In [0]:
def positive_loss(_, y_pred):
  """Binary cross-entropy that pushes every positive-pair score towards label 1.

  Args:
    _: Ignored; the target is always all ones.
    y_pred: Raw (pre-sigmoid) scores.

  Returns:
    The per-example loss returned by `tf.losses.binary_crossentropy`.
  """
  # Derive the targets from the predictions so the loss does not depend on the
  # notebook-level `batch_size` variable.
  y_true = tf.ones_like(y_pred)
  # from_logits=True lets the loss apply the sigmoid itself in a numerically
  # stable way instead of taking the log of an already-saturated probability.
  return tf.losses.binary_crossentropy(y_true, y_pred, from_logits=True)

def negative_loss(_, y_pred):
  """Softmax cross-entropy whose correct class for row i is column i.

  Args:
    _: Ignored; the target is always the identity matrix.
    y_pred: Raw (pre-softmax) 2-D score matrix with one row per example.

  Returns:
    The per-row loss returned by `tf.losses.categorical_crossentropy`.
  """
  # Size the identity target from the predictions so the loss does not depend
  # on the notebook-level `batch_size` variable.
  score_shape = tf.shape(y_pred)
  y_true = tf.eye(score_shape[0], num_columns=score_shape[-1], dtype=y_pred.dtype)
  # from_logits=True fuses softmax and log for numerical stability.
  return tf.losses.categorical_crossentropy(y_true, y_pred, from_logits=True)

In [0]:
# Map each model output to its loss by layer name.
# NOTE(review): the 'tf_op_layer_' prefix is presumably how Keras auto-names the
# layers wrapping the ops called 'negative' / 'positive' — confirm with model.summary().
loss_dict = dict(
    tf_op_layer_negative=negative_loss,
    tf_op_layer_positive=positive_loss,
)

optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
model.compile(optimizer=optimizer, loss=loss_dict)

In [0]:
# Save weights only, one checkpoint per epoch, named by the zero-padded epoch number.
# NOTE(review): `model_dir` is not defined in the cells visible here — it must exist before this cell runs.
checkpoint_path = model_dir + '/weights.{epoch:02d}'
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, save_weights_only=True)
callbacks_list = [checkpoint_callback]

# NOTE(review): `epochs` is also passed to TextDataset, which repeats the data that
# many times — confirm the total number of passes over the data is what is intended.
model.fit(dataset, callbacks=callbacks_list, epochs=epochs)