# In [2]:
import collections.abc

import tensorflow as tf
from transformers import *

# In [3]:
class TFBertForRE(TFBertPreTrainedModel):
    """BERT-based relation-extraction classifier.

    Concatenates three vectors and feeds them to a dense layer with
    ``config.num_labels`` outputs:

    * the pooled output of ``TFBertMainLayer`` (after dropout), and
    * one vector per entity, obtained by weighting the last-layer hidden
      states with ``e1_mask`` / ``e2_mask`` (after dropout).

    Note: the entity vector is the mask-weighted *sum* of hidden states (a
    matmul with the mask), not a mean; normalize the masks upstream if an
    average over the entity span is wanted.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.bert = TFBertMainLayer(config, name="bert")
        # Separate dropout layers for the pooled/context vector and the entity vectors.
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.ent_dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @staticmethod
    def _extract_entity(sequence_output, e_mask):
        """Return the mask-weighted sum of ``sequence_output`` over the sequence axis.

        Args:
            sequence_output: last-layer hidden states; assumed
                (batch_size, sequence_length, hidden_size).
            e_mask: per-token entity weights; assumed (batch_size, sequence_length).

        Returns:
            Tensor of shape (batch_size, hidden_size) under the shape assumptions above.
        """
        # matmul requires matching dtypes: cast integer masks (and handle
        # float16 hidden states under mixed precision).
        e_mask = tf.cast(e_mask, sequence_output.dtype)
        extended_e_mask = tf.expand_dims(e_mask, 1)  # (batch_size, 1, sequence_length)
        ext_entity = tf.matmul(extended_e_mask, sequence_output)  # (batch_size, 1, hidden_size)
        return tf.squeeze(ext_entity, [1])  # (batch_size, hidden_size)

    def call(self, inputs, e1_mask=None, e2_mask=None, **kwargs):
        """Run BERT and classify the relation between the two marked entities.

        Args:
            inputs: BERT inputs. When a mapping, it may also carry ``"e1_mask"``
                and ``"e2_mask"``; those take precedence over the keyword
                arguments and are removed before the mapping reaches BERT.
            e1_mask: entity-1 mask, used when not present in ``inputs``.
            e2_mask: entity-2 mask, used when not present in ``inputs``.
            **kwargs: forwarded to ``TFBertMainLayer`` (e.g. ``training``).

        Returns:
            Tuple ``(logits, *extra)`` where ``extra`` is whatever
            ``TFBertMainLayer`` returned after its first two outputs
            (hidden states / attentions when enabled).

        Raises:
            ValueError: if either entity mask is missing.
        """
        if isinstance(inputs, collections.abc.Mapping):
            # Copy so the caller's mapping is not mutated, then strip the
            # entity masks: they are not BERT inputs.
            inputs = dict(inputs)
            e1_mask = inputs.pop("e1_mask", e1_mask)
            e2_mask = inputs.pop("e2_mask", e2_mask)

        if e1_mask is None or e2_mask is None:
            raise ValueError(
                "TFBertForRE requires both `e1_mask` and `e2_mask` "
                "(in the inputs dict or as keyword arguments)."
            )

        training = kwargs.get("training", False)

        outputs = self.bert(inputs, **kwargs)

        pooled_output = outputs[1]  # pooled [CLS] representation, (batch_size, hidden_size)
        sequence_output = outputs[0]  # last-layer hidden states, (batch_size, sequence_length, hidden_size)

        e1_h = self.ent_dropout(self._extract_entity(sequence_output, e1_mask), training=training)
        e2_h = self.ent_dropout(self._extract_entity(sequence_output, e2_mask), training=training)
        context = self.dropout(pooled_output, training=training)

        # Classifier input: [context ; entity1 ; entity2] along the feature axis.
        features = tf.concat([context, e1_h, e2_h], -1)

        logits = self.classifier(features)

        # Keep any extra BERT outputs (hidden states / attentions) after the logits.
        return (logits,) + tuple(outputs[2:])