In [1]:
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Lambda, Dense, Activation, Concatenate
from transformers import TFBertModel
# Record the TensorFlow build this notebook was executed with.
print('TensorFlow:', tf.version.VERSION)

I1221 21:22:02.144697 4524287424 file_utils.py:35] PyTorch version 1.2.0 available.
I1221 21:22:02.145673 4524287424 file_utils.py:48] TensorFlow version 2.0.0 available.


TensorFlow: 2.0.0


In [2]:
def get_random_title(seq_len=512, vocab_size=200):
    """Return a random stand-in for a tokenized title.

    Args:
        seq_len: Number of token ids to draw (length of the 1-D result).
        vocab_size: Exclusive upper bound for the sampled ids; ids are drawn
            uniformly from ``[0, vocab_size)``.

    Returns:
        An int32 tensor of shape ``[seq_len]``.
    """
    return tf.random.uniform(shape=[seq_len], maxval=vocab_size, dtype=tf.int32)

def get_random_citation(dim=768):
    """Return a random stand-in for one citation embedding.

    Args:
        dim: Length of the embedding vector.

    Returns:
        A float32 tensor of shape ``[dim]`` with values drawn uniformly
        from ``[-1, 1)``.
    """
    return tf.random.uniform(shape=[dim], minval=-1, maxval=1, dtype=tf.float32)


def generate_positive_sample(num_citations=4):
    """Build one synthetic (title, positive citations) pair.

    Args:
        num_citations: How many random citation embeddings to stack.

    Returns:
        A tuple ``(title, posCitations)`` where ``title`` comes from
        ``get_random_title()`` and ``posCitations`` stacks ``num_citations``
        results of ``get_random_citation()`` along a new leading axis.
    """
    title = get_random_title()
    citations = [get_random_citation() for _ in range(num_citations)]
    posCitations = tf.stack(citations, axis=0)
    return title, posCitations

def generate_negative_sample(num_citations=16):
    """Build one synthetic (title, negative citations) pair.

    Args:
        num_citations: How many random citation embeddings to stack.

    Returns:
        A tuple ``(title, negCitations)`` where ``title`` comes from
        ``get_random_title()`` and ``negCitations`` stacks ``num_citations``
        results of ``get_random_citation()`` along a new leading axis.
    """
    title = get_random_title()
    citations = [get_random_citation() for _ in range(num_citations)]
    negCitations = tf.stack(citations, axis=0)
    return title, negCitations

def merge_datasets(pos, neg):
    """Combine one positive and one negative sample into a training example.

    Args:
        pos: Tuple whose first element is the title and whose second element
            is the stacked positive citations.
        neg: Tuple whose second element is the stacked negative citations.
            Its first element (the negative sample's own title) is ignored.

    Returns:
        ``((title, posCite, negCite), labels)`` where ``labels`` is a float32
        vector holding one ``1.0`` per positive citation followed by one
        ``0.0`` per negative citation.
    """
    title = pos[0]
    posCite = pos[1]
    negCite = neg[1]

    # Derive the label counts from the citation tensors themselves (one label
    # per row) instead of hard-coding 4 and 16, so the labels cannot drift out
    # of sync with the sample generators.
    posLabels = tf.ones_like(posCite[:, 0], dtype=tf.float32)
    negLabels = tf.zeros_like(negCite[:, 0], dtype=tf.float32)
    labels = tf.concat([posLabels, negLabels], axis=0)

    return (title, posCite, negCite), labels

In [3]:
def create_model(seq_len=512, num_pos=4, num_neg=16, cite_dim=768, hidden_units=768):
    """Build the title-vs-citation scoring model.

    The title token ids are encoded with a pretrained BERT, mean-pooled over
    the token axis and projected by a tanh Dense layer. The positive and
    negative citation embeddings are concatenated along the citation axis and
    projected by a second tanh Dense layer. The score for each citation is the
    dot product between the title projection and that citation's projection.

    Args:
        seq_len: Number of token ids per title.
        num_pos: Number of positive citations per example.
        num_neg: Number of negative citations per example.
        cite_dim: Size of each incoming citation embedding.
        hidden_units: Output size of both Dense projections (they must match
            for the dot product).

    Returns:
        A ``tf.keras.Model`` taking ``[textInputs, posCite, negCite]`` and
        producing logits of shape ``(batch, num_pos + num_neg)``.
    """
    textInputs = tf.keras.Input(shape=(seq_len,), dtype=tf.int32)
    posCite = tf.keras.Input(shape=(num_pos, cite_dim))
    negCite = tf.keras.Input(shape=(num_neg, cite_dim))

    bert_model = TFBertModel.from_pretrained('bert-base-uncased')

    # Only the first BERT output is consumed here; the pooler weights are not
    # on the path to the loss, which is why Keras warns that gradients do not
    # exist for the pooler variables during training.
    textOut = bert_model(textInputs)
    textOutMean = tf.reduce_mean(textOut[0], axis=1)
    textSim = Dense(units=hidden_units, activation='tanh', name='DenseText')(textOutMean)
    # Add a singleton citation axis; tf.multiply broadcasts it across every
    # citation, so no explicit tile with a hard-coded count is needed.
    textSim = tf.expand_dims(textSim, axis=1)

    Cite = tf.concat([posCite, negCite], axis=1)
    CiteSim = Dense(units=hidden_units, activation='tanh', name='DenseCite')(Cite)

    logits = tf.reduce_sum(tf.multiply(textSim, CiteSim), axis=2)
    return tf.keras.Model(inputs=[textInputs, posCite, negCite], outputs=[logits])

In [5]:
batch_size = 3

# Build the network first, then attach the training configuration.
model = create_model()

# The model emits raw dot-product scores, so the loss applies the sigmoid
# itself (from_logits=True). Adam is used with its default settings.
loss_fn = tf.losses.BinaryCrossentropy(from_logits=True)
optimizer = tf.optimizers.Adam()
model.compile(loss=loss_fn, optimizer=optimizer)

In [6]:
def _random_sample_dataset(sample_fn, num_samples=5000, shuffle_buffer=512):
    """Return a dataset yielding ``num_samples`` results of ``sample_fn()``.

    The underlying range only drives how many elements are produced; the
    index itself is discarded and every element comes from a fresh call to
    ``sample_fn``.
    """
    indices = tf.data.Dataset.range(num_samples)
    indices = indices.shuffle(shuffle_buffer)
    return indices.map(lambda _: sample_fn())


posDataset = _random_sample_dataset(generate_positive_sample)
print(tf.data.experimental.get_structure(posDataset))

negDataset = _random_sample_dataset(generate_negative_sample)
print(tf.data.experimental.get_structure(negDataset))


# Pair each positive sample with a negative one, merge them into a single
# ((title, posCite, negCite), labels) example, then batch.
dataset = tf.data.Dataset.zip((posDataset, negDataset))
dataset = dataset.map(merge_datasets)
dataset = dataset.batch(batch_size, drop_remainder=True)
print(tf.data.experimental.get_structure(dataset))

(TensorSpec(shape=(512,), dtype=tf.int32, name=None), TensorSpec(shape=(4, 768), dtype=tf.float32, name=None))
(TensorSpec(shape=(512,), dtype=tf.int32, name=None), TensorSpec(shape=(16, 768), dtype=tf.float32, name=None))
((TensorSpec(shape=(3, 512), dtype=tf.int32, name=None), TensorSpec(shape=(3, 4, 768), dtype=tf.float32, name=None), TensorSpec(shape=(3, 16, 768), dtype=tf.float32, name=None)), TensorSpec(shape=(3, 20), dtype=tf.float32, name=None))


In [7]:
# Short smoke-test run: 5 epochs of 20 batches each.
model.fit(
    dataset,
    epochs=5,
    steps_per_epoch=20,
)

Train for 20 steps
Epoch 1/5


W1221 21:22:30.223130 4524287424 optimizer_v2.py:1029] Gradients do not exist for variables ['tf_bert_model/bert/pooler/dense/kernel:0', 'tf_bert_model/bert/pooler/dense/bias:0'] when minimizing the loss.
W1221 21:22:35.219227 4524287424 optimizer_v2.py:1029] Gradients do not exist for variables ['tf_bert_model/bert/pooler/dense/kernel:0', 'tf_bert_model/bert/pooler/dense/bias:0'] when minimizing the loss.




KeyboardInterrupt: 