In [1]:
import torch
from torch import nn as nn
from transformers import BertConfig
from transformers import BertModel
from transformers import BertPreTrainedModel

In [2]:
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
import torch.nn.functional as F

In [3]:
class SpERT(BertPreTrainedModel):
    """ Span-based model to jointly extract entities and relations """

    def __init__(self, config: BertConfig, relation_types: int, entity_types: int, 
                 width_embedding_size: int, prop_drop: float, freeze_transformer: bool, max_pairs: int, 
                 is_overlapping: bool, relation_filter_threshold: float):
        super(SpERT, self).__init__(config)

        # BERT model
        self.bert = BertModel(config)

        # layers
        self.relation_classifier = nn.Linear(config.hidden_size * 3 + width_embedding_size * 2, relation_types)
        self.entity_classifier = nn.Linear(config.hidden_size * 2 + width_embedding_size, entity_types)
        self.width_embedding = nn.Embedding(100, width_embedding_size)
        self.dropout = nn.Dropout(prop_drop)

        self._hidden_size = config.hidden_size
        self._relation_types = relation_types
        self._entity_types = entity_types
        self._relation_filter_threshold = relation_filter_threshold
        self._max_pairs = max_pairs
        self._is_overlapping = is_overlapping # whether overlapping entities are allowed

        # weight initialization
        self.init_weights()

        if freeze_transformer:
            # freeze all transformer weights
            for param in self.bert.parameters():
                param.requires_grad = False
                
                        
    def _classify_entity(self, token_embedding, width_embedding, cls_embedding, entity_mask, entity_label):
        """
        INPUT:
        token_embedding.shape = (sentence_length, hidden_size)
        width_embedding.shape = (entity_count, width_embedding_size)
        cls_embedding.shape = (1, hidden_size)
        entity_mask.shape = (entity_count, sentence_length)
        entity_label.shape = (entity_count,)
        
        RETURN:
        entity_logit.shape = (entity_count, self._entity_types)
        entity_loss -> scala
        entity_pred.shape = (entity_count,)
        """
        sentence_length = token_embedding.shape[0]
        hidden_size = token_embedding.shape[1]
        entity_count = entity_mask.shape[0]
        
        entity_embedding = torch.mul(token_embedding.view(1, sentence_length, hidden_size), 
                                     entity_mask.view(entity_count, sentence_length, 1))
        
        entity_embedding = entity_embedding.max(dim=-2)[0] # maxpool
        
        entity_embedding = torch.cat([entity_embedding, 
                                      width_embedding, 
                                      cls_embedding.repeat(entity_count, 1)], dim=1)
        
        entity_logit = self.entity_classifier(entity_embedding)
        entity_loss = None
        if entity_label != None:
            # If entity labels are provided, calculate cross entropy loss and take the average over all samples
            # Refer to the paper
            loss_fct = CrossEntropyLoss(reduction='mean')
            entity_loss = loss_fct(entity_logit, entity_label)
        entity_pred = F.softmax(entity_logit, dim=-1).argmax(dim=-1).long()
        
        return entity_logit, entity_loss, entity_pred 
    
    
    def _filter_span(self, entity_mask: torch.tensor, entity_pred: torch.tensor):
        entity_count = entity_mask.shape[0]
        sentence_length = entity_mask.shape[1]
        entity_span = []
        entity_embedding = torch.zeros((sentence_length,)) if not self._is_overlapping else None
        
        for i in range(entity_count):
            if entity_pred[i] != 0:
                begin = torch.argmax(entity_mask[i]).item()
                end = sentence_length - torch.argmax(entity_mask[i].flip(0)).item()
                
                assert end > begin
                assert entity_mask[i, begin:end].sum() == end - begin
                
                if self._is_overlapping:
                    entity_span.append((begin, end, entity_pred[i].item()))
                elif not self._is_overlapping and entity_embedding[begin:end].sum() == 0:
                    entity_span.append((begin, end, entity_pred[i].item()))
                    entity_embedding[begin:end] = entity_pred[i]
        
        return entity_span, entity_embedding
    
    
    def _generate_relation_mask(self, entity_span, sentence_length):
        relation_mask = []
        for e1 in entity_span:
            for e2 in entity_span:
                c = (min(e1[1], e2[1]), max(e1[0], e2[0]))
                if c[1] > c[0]:
                    template = [0] * sentence_length
                    template[e1[0]: e1[1]] = [1] * (e1[1] - e1[0])
                    template[e2[0]: e2[1]] = [2] * (e2[1] - e2[0])
                    template[c[0]: c[1]] = [3] * (c[1] - c[0])
                    relation_mask.append(template)        
        return torch.tensor(relation_mask, dtype=torch.long)
    
    
    def _classify_relation(self, token_embedding, e1_width_embedding, e2_width_embedding, 
                           relation_mask, relation_label):
        """
        INPUT:
        token_embedding.shape = (sentence_length, hidden_size)
        e1_width_embedding.shape = (relation_count, width_embedding_size)
        e2_width_embedding.shape = (relation_count, width_embedding_size)
        relation_mask.shape = (relation_count, sentence_length)
        relation_label.shape = (relation_count,)
        
        RETURN:
        relation_logit.shape = (relation_count, self._relation_types)
        relation_loss -> scala
        relation_pred.shape = (relation_count,)
        """
        sentence_length = token_embedding.shape[0]
        hidden_size = token_embedding.shape[1]
        relation_count = relation_mask.shape[0]
        
        e1_embedding = torch.mul(token_embedding.view(1, sentence_length, hidden_size), 
                                 (relation_mask == 1).view(relation_count, sentence_length, 1))
        e1_embedding = e1_embedding.max(dim=-2)[0] # maxpool
        
        e2_embedding = torch.mul(token_embedding.view(1, sentence_length, hidden_size), 
                                 (relation_mask == 2).view(relation_count, sentence_length, 1))
        e2_embedding = e2_embedding.max(dim=-2)[0] # maxpool
        
        c_embedding = torch.mul(token_embedding.view(1, sentence_length, hidden_size), 
                                 (relation_mask == 3).view(relation_count, sentence_length, 1))
        c_embedding = c_embedding.max(dim=-2)[0] # maxpool
        
        relation_embedding = torch.cat([e1_embedding, e1_width_embedding,
                                        c_embedding,
                                        e2_embedding, e2_width_embedding], dim=1)
        
        relation_logit = self.relation_classifier(relation_embedding)
        relation_loss = None
        if relation_label != None:
            # If relation labels are provided, calculate the binary cross entropy loss 
            # and take the sum over all samples
            loss_fct = BCEWithLogitsLoss(reduction='sum')
            onehot_relation_label = F.one_hot(relation_label, num_classes=self._relation_types).float()
            relation_loss = loss_fct(relation_logit, onehot_relation_label)
            
        relation_softmax = F.softmax(relation_logit, dim=-1)
        # Filter out low confident relations
        relation_softmax[relation_softmax < self._relation_filter_threshold] = 0
        relation_pred = relation_softmax.argmax(dim=-1).long()
        
        return relation_logit, relation_loss, relation_pred 
    
    
    def _filter_relation(self, relation_mask: torch.tensor, relation_pred: torch.tensor):
        relation_count = relation_mask.shape[0]
        sentence_length = relation_mask.shape[1]
        relation_span = []
        
        for i in range(relation_count):
            if relation_pred[i] != 0:
                e1_begin = torch.argmax((relation_mask[i] == 1).long()).item()
                e1_end = sentence_length - torch.argmax((relation_mask[i].flip(0) == 1).long()).item()
                
                assert e1_end > e1_begin
                assert relation_mask[i, e1_begin:e1_end].sum() == (e1_end - e1_begin) * 1
                
                e2_begin = torch.argmax((relation_mask[i] == 2).long()).item()
                e2_end = sentence_length - torch.argmax((relation_mask[i].flip(0) == 2).long()).item()
                
                assert e2_end > e2_begin
                assert relation_mask[i, e2_begin:e2_end].sum() == (e2_end - e2_begin) * 2
                
                relation_span.append((e1_begin, e1_end, e2_begin, e2_end, relation_pred[i].item()))
        
        return relation_span
    
                
    def forward(self, input_ids: torch.tensor, attention_mask: torch.tensor, token_type_ids: torch.tensor, 
                entity_mask: torch.tensor = None, entity_label: torch.tensor = None, 
                relation_mask: torch.tensor = None, relation_label: torch.tensor = None,
                is_training: bool = True):
            
        # get the last hidden layer from BERT
        bert_embedding = self.bert(input_ids=input_ids, 
                                   attention_mask=attention_mask, 
                                   token_type_ids=token_type_ids)['last_hidden_state']
        
        # get the CLS and other tokens embedding
        bert_embedding = torch.reshape(bert_embedding, (-1, self._hidden_size))
        cls_embedding = bert_embedding[:1] # CLS is the first element
        token_embedding = bert_embedding[1:-1] # everything except CLS and SEP at both ends
        
        # get the width embedding for each entity length
        width_embedding = self.width_embedding(torch.sum(entity_mask, dim=-1))
        entity_logit, entity_loss, entity_pred \
            = self._classify_entity(token_embedding, width_embedding, cls_embedding, entity_mask, entity_label)
        
        entity_span, entity_embedding = self._filter_span(entity_mask, entity_pred)
        
        # if not relation_mask then generate them from pairs of entities
        # only for prediction and evaluation
        if not is_training or relation_mask == None:
            relation_mask = self._generate_relation_mask(entity_span, token_embedding.shape[0])
            relation_label = None
        
        # return immediately if there is no relations to predict (e.g. there are less than 2 entities)
        output = {
            "loss": entity_loss,
            "entity": {
                "logit": entity_logit,
                "pred": entity_pred,
                "span": entity_span,
                "embedding": entity_embedding
            },
            "relation": None
        }
        if relation_mask == None or torch.equal(relation_mask, torch.tensor([], dtype=torch.long)):
            return output
        
        relation_count = relation_mask.shape[0]
        relation_logit = torch.zeros((relation_count, self._relation_types))
        relation_loss = []
        relation_pred = torch.zeros((relation_count,), dtype=torch.long)
        e1_width_embedding = self.width_embedding(torch.sum(relation_mask == 1, dim=-1))
        e2_width_embedding = self.width_embedding(torch.sum(relation_mask == 2, dim=-1))
        
        # break down relation_mask (list of possible relations) to smaller chunks
        for i in range(0, relation_count, self._max_pairs):
            j = min(relation_count, i + self._max_pairs)
            logit, loss, pred = self._classify_relation(token_embedding, 
                                                        e1_width_embedding[i: j], 
                                                        e2_width_embedding[i: j], 
                                                        relation_mask[i: j], 
                                                        relation_label[i: j] if relation_label != None else None)
            relation_logit[i: j] = logit
            if loss != None:
                relation_loss.append(loss)
            relation_pred[i: j] = pred
        # relation loss is the average of binary cross entropy loss of each sample
        # refer to the paper
        relation_loss = None if len(relation_loss) == 0 else (sum(relation_loss) / float(relation_count))
        relation_span = self._filter_relation(relation_mask, relation_pred)
        # Final loss is the sum of entity_loss and relation_loss
        if relation_loss != None: 
            if output["loss"] == None: 
                output["loss"] = relation_loss
            else:
                output["loss"] += relation_loss
        output["relation"] = {
            "logit": relation_logit,
            "pred": relation_pred,
            "span": relation_span
        }
        return output

In [4]:
import conll04_constants as constants
import conll04_input_generator as input_generator

In [5]:
config = BertConfig.from_pretrained(constants.model_path)
model = SpERT.from_pretrained(constants.model_path,
                              config=config,
                              # SpERT model parameters
                              relation_types=6, 
                              entity_types=5, 
                              width_embedding_size=constants.width_embedding_size, 
                              prop_drop=constants.prop_drop, 
                              freeze_transformer=True, 
                              max_pairs=constants.max_pairs, 
                              is_overlapping=False, 
                              relation_filter_threshold=constants.relation_filter_threshold)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing SpERT: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing SpERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SpERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SpERT were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['relation_classifier.weight', 'relation_classifie

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

SpERT(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      

In [None]:
generator = input_generator.data_generator("train", device)
inputs = next(generator)
inputs

In [8]:
outputs = model(**inputs, is_training=True)
outputs

{'loss': tensor(6.0548, grad_fn=<AddBackward0>),
 'entity': {'logit': tensor([[ 2.8177e-01,  1.2669e-02, -2.2483e-02,  7.4377e-01,  6.4114e-01],
          [ 3.6579e-01,  7.8646e-04, -2.0291e-01,  7.8411e-01,  4.7328e-01],
          [-3.8452e-02, -1.8831e-01, -2.1909e-01,  8.1786e-01,  8.2508e-01],
          [ 4.6173e-01, -1.6989e-01, -3.5750e-01,  1.1920e+00,  6.0526e-01],
          [ 9.0851e-02,  2.4033e-01, -3.4296e-01,  5.8736e-01,  6.9949e-01],
          [ 2.0702e-01,  8.6041e-02, -3.4433e-01,  9.3151e-01,  8.5519e-01],
          [ 1.5012e-01, -3.1589e-02, -1.3847e-01,  7.2958e-01,  8.9096e-01],
          [ 1.3835e-01, -1.9620e-01, -4.9140e-01,  8.0954e-01,  2.5814e-01],
          [ 1.7310e-01, -1.4149e-02, -3.5881e-01,  9.3199e-01,  8.7219e-01],
          [ 1.6613e-01, -1.6769e-01, -6.4479e-01,  7.1199e-01,  5.1751e-01],
          [ 4.6173e-01,  9.6146e-02, -6.3711e-03,  8.2203e-01,  8.5123e-01],
          [ 1.7755e-01, -2.7787e-01, -4.2383e-01,  7.4441e-01,  5.2356e-01],
        

In [9]:
outputs = model(**inputs, is_training=False)
outputs

{'loss': tensor(1.7337, grad_fn=<NllLossBackward>),
 'entity': {'logit': tensor([[ 2.8177e-01,  1.2669e-02, -2.2483e-02,  7.4377e-01,  6.4114e-01],
          [ 3.6579e-01,  7.8646e-04, -2.0291e-01,  7.8411e-01,  4.7328e-01],
          [-3.8452e-02, -1.8831e-01, -2.1909e-01,  8.1786e-01,  8.2508e-01],
          [ 4.6173e-01, -1.6989e-01, -3.5750e-01,  1.1920e+00,  6.0526e-01],
          [ 9.0851e-02,  2.4033e-01, -3.4296e-01,  5.8736e-01,  6.9949e-01],
          [ 2.0702e-01,  8.6041e-02, -3.4433e-01,  9.3151e-01,  8.5519e-01],
          [ 1.5012e-01, -3.1589e-02, -1.3847e-01,  7.2958e-01,  8.9096e-01],
          [ 1.3835e-01, -1.9620e-01, -4.9140e-01,  8.0954e-01,  2.5814e-01],
          [ 1.7310e-01, -1.4149e-02, -3.5881e-01,  9.3199e-01,  8.7219e-01],
          [ 1.6613e-01, -1.6769e-01, -6.4479e-01,  7.1199e-01,  5.1751e-01],
          [ 4.6173e-01,  9.6146e-02, -6.3711e-03,  8.2203e-01,  8.5123e-01],
          [ 1.7755e-01, -2.7787e-01, -4.2383e-01,  7.4441e-01,  5.2356e-01],
     

In [10]:
del inputs["relation_mask"]
del inputs["relation_label"]
outputs = model(**inputs, is_training=False)
outputs

{'loss': tensor(1.7337, grad_fn=<NllLossBackward>),
 'entity': {'logit': tensor([[ 2.8177e-01,  1.2669e-02, -2.2483e-02,  7.4377e-01,  6.4114e-01],
          [ 3.6579e-01,  7.8646e-04, -2.0291e-01,  7.8411e-01,  4.7328e-01],
          [-3.8452e-02, -1.8831e-01, -2.1909e-01,  8.1786e-01,  8.2508e-01],
          [ 4.6173e-01, -1.6989e-01, -3.5750e-01,  1.1920e+00,  6.0526e-01],
          [ 9.0851e-02,  2.4033e-01, -3.4296e-01,  5.8736e-01,  6.9949e-01],
          [ 2.0702e-01,  8.6041e-02, -3.4433e-01,  9.3151e-01,  8.5519e-01],
          [ 1.5012e-01, -3.1589e-02, -1.3847e-01,  7.2958e-01,  8.9096e-01],
          [ 1.3835e-01, -1.9620e-01, -4.9140e-01,  8.0954e-01,  2.5814e-01],
          [ 1.7310e-01, -1.4149e-02, -3.5881e-01,  9.3199e-01,  8.7219e-01],
          [ 1.6613e-01, -1.6769e-01, -6.4479e-01,  7.1199e-01,  5.1751e-01],
          [ 4.6173e-01,  9.6146e-02, -6.3711e-03,  8.2203e-01,  8.5123e-01],
          [ 1.7755e-01, -2.7787e-01, -4.2383e-01,  7.4441e-01,  5.2356e-01],
     

In [11]:
# inputs = next(generator)
# inputs

{'input_ids': tensor([[  101,  3680,  1012,  3123,  5086, 14863,  4783,  2147,  2063,  1010,
           2028,  1997, 11154,  1005,  1055,  2087,  5182,  9559,  1998,  1037,
           2280,  2266,  1997,  1996,  2406,  1005,  1055,  2152,  2457,  1010,
           2000,  8556,  1012,   102]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
 'entity_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 'entity_label': tensor([2, 2, 1, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
 

In [13]:
# del inputs["entity_label"]
# outputs = model(**inputs, is_training=True)
# outputs

{'loss': tensor(4.6240, grad_fn=<DivBackward0>),
 'entity': {'logit': tensor([[ 0.5819,  0.0741, -0.0334,  0.3927,  0.7033],
          [ 0.2299,  0.1545,  0.0303,  0.1535,  0.7243],
          [ 0.2250,  0.0322, -0.2512,  0.0519,  0.5079],
          [ 0.4887, -0.3116, -0.2378,  0.3460,  0.3862],
          [ 0.4297, -0.1324, -0.1910,  0.2687,  0.8742],
          [ 0.3672,  0.3250,  0.0173,  0.1853,  0.7590],
          [ 0.5862,  0.2360, -0.3061,  0.3153,  0.9475],
          [ 0.5526, -0.0358, -0.2751,  0.3291,  0.5196],
          [ 0.1845, -0.1523, -0.2466,  0.3240,  0.5028],
          [ 0.4602, -0.0172, -0.1951,  0.2573,  0.5737],
          [ 0.3466, -0.2835, -0.1531,  0.2155,  0.7509],
          [ 0.3280, -0.2454, -0.0652,  0.1407,  0.8799],
          [ 0.4998, -0.1880, -0.2143,  0.2748,  0.9990],
          [ 0.4512, -0.1586, -0.2129,  0.2846,  0.8072],
          [ 0.5344, -0.0406, -0.3060,  0.3799,  0.9146],
          [ 0.4092, -0.5269, -0.3275,  0.2977,  0.5559],
          [ 0.5125, 