In [1]:
import math
import os
from typing import Optional, Tuple
import numpy as np

import torch
from torch import nn
from torch.nn import CrossEntropyLoss

In [2]:
import sys
sys.path.append("..\\parser")
import conll04_parser

In [3]:
from transformers import BertModel
from transformers.modeling_outputs import ModelOutput

In [5]:
class MreOutput(ModelOutput):
    """
    Base class for outputs of token classification models.

    Args:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
            Classification loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, max_entities ** 2, num_labels)`):
            Classification scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

In [6]:
class BertForMre(nn.Module):

    def __init__(
        self, 
        num_labels,  
        model_name = "bert-base-uncased"
    ):
        super(BertForMre, self).__init__()
        self.num_labels = num_labels
        self.bert = BertModel.from_pretrained(model_name)
        self.bert.train() # Set BERT to training mode
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        self.classifier = nn.Linear(self.bert.config.hidden_size * 2, num_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        e1_mask=None,
        e2_mask=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,
    ):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        
        logits = None
        if e1_mask is not None and e2_mask is not None:
            num_relation = e1_mask.shape[0]
            seq_length = e1_mask.shape[1]
            sequence_output = torch.stack([sequence_output] * num_relation, dim=1)
            
            e1_mask = torch.reshape(e1_mask, [-1, num_relation, seq_length, 1])
            e1 = torch.mul(sequence_output, e1_mask.float())
            e1 = torch.sum(e1, dim=-2) / torch.clamp(torch.sum(e1_mask.float(), dim=-2), min=1.0)
            e1 = torch.reshape(e1, [-1, self.bert.config.hidden_size])
            
            e2_mask = torch.reshape(e2_mask, [-1, num_relation, seq_length, 1])
            e2 = torch.mul(sequence_output, e2_mask.float())
            e2 = torch.sum(e2, dim=-2) / torch.clamp(torch.sum(e2_mask.float(), dim=-2), min=1.0)
            e2 = torch.reshape(e2, [-1, self.bert.config.hidden_size])
            
            sequence_output = torch.cat([e1, e2], dim=-1)
            logits = self.classifier(sequence_output)

        loss = None
        if logits is not None and labels is not None:
            # print(logits.type(), labels.type())
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.long().view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:] if logits is not None else outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return MreOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [7]:
model = BertForMre(7)

In [8]:
model

BertForMre(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
 

In [9]:
# Freeze all layers except for the last classifier layer on top
for param in model.parameters():
    param.requires_grad = False
model.classifier.weight.requires_grad = True
model.classifier.bias.requires_grad = True

In [10]:
for param in model.parameters():
    print("size:", param.shape)
    print(param.requires_grad)

size: torch.Size([30522, 768])
False
size: torch.Size([512, 768])
False
size: torch.Size([2, 768])
False
size: torch.Size([768])
False
size: torch.Size([768])
False
size: torch.Size([768, 768])
False
size: torch.Size([768])
False
size: torch.Size([768, 768])
False
size: torch.Size([768])
False
size: torch.Size([768, 768])
False
size: torch.Size([768])
False
size: torch.Size([768, 768])
False
size: torch.Size([768])
False
size: torch.Size([768])
False
size: torch.Size([768])
False
size: torch.Size([3072, 768])
False
size: torch.Size([3072])
False
size: torch.Size([768, 3072])
False
size: torch.Size([768])
False
size: torch.Size([768])
False
size: torch.Size([768])
False
size: torch.Size([768, 768])
False
size: torch.Size([768])
False
size: torch.Size([768, 768])
False
size: torch.Size([768])
False
size: torch.Size([768, 768])
False
size: torch.Size([768])
False
size: torch.Size([768, 768])
False
size: torch.Size([768])
False
size: torch.Size([768])
False
size: torch.Size([768])
False
si

In [11]:
def generate_entity_mask(sequence_length, entity_position, relations):
    relation_count = len(entity_position) * (len(entity_position) - 1)
    e1_mask = torch.zeros((relation_count, sequence_length))
    e2_mask = torch.zeros((relation_count, sequence_length))
    labels = torch.zeros(relation_count)
    i = 0
    for e1 in entity_position:
        for e2 in entity_position:
            if e1 != e2:
                l1, h1 = entity_position[e1]
                l2, h2 = entity_position[e2]
                e1_mask[i, l1:h1] = 1
                e2_mask[i, l2:h2] = 1
                for relation in relations:
                    if relations[relation]["source"] == e1 and relations[relation]["target"] == e2:
                        labels[i] = relations[relation]["type"]
                i += 1
    return e1_mask, e2_mask, labels

In [12]:
# Test generate_entity_mask()
docs = conll04_parser.get_docs("train")
extracted_doc = conll04_parser.extract_doc(docs[0])
assert extracted_doc["document"] == "1024"
e1_mask, e2_mask, labels = generate_entity_mask(
    extracted_doc["data_frame"].shape[0], 
    extracted_doc["entity_position"], 
    extracted_doc["relations"]
)

assert np.array_equal(e1_mask, np.array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
         0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], 
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
         0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], 
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], 
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
         0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], 
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], 
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]))

assert np.array_equal(e2_mask, np.array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]))

assert np.array_equal(labels, np.array([0., 2., 0., 2., 0., 0.]))

In [13]:
from transformers import AdamW
optimizer = AdamW(model.parameters(), lr=1e-5)

In [14]:
outputs = model(
    torch.tensor([extracted_doc["data_frame"]["token_ids"]]), 
    e1_mask=e1_mask, 
    e2_mask=e2_mask, 
    labels=labels
)

In [15]:
loss = outputs.loss
loss.backward()
optimizer.step()

In [16]:
loss

tensor(2.1971, grad_fn=<NllLossBackward>)