In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import numpy as np
import json
import re
import time
from tqdm import tqdm
from collections import deque

# Mini-batch size handed to both DataLoaders in the driver section.
batch_size = 16
# Path given to `from_pretrained` for the tokenizer and the encoder.
transformers_path = "../hfl/chinese-roberta-wwm-ext"
# Use CUDA when PyTorch reports it as available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class CustomDataset(Dataset):
    """Character-level dataset for joint entity / relation extraction.

    Every sample is a dict with the keys ``"text"``, ``"entity_list"``
    (items with ``"text"`` and ``"type"``) and ``"relation_list"`` (items with
    ``"subject"``, ``"object"`` and ``"predicate"``).

    ``__getitem__`` returns ``(input_encoding, eh_et, sh_oh, st_ot)`` where

    * ``input_encoding`` is a dict of ``torch.long`` tensors of length
      ``max_len`` (``input_ids``, ``attention_mask``, ``token_type_ids``);
    * ``eh_et`` is ``[num_entity, max_len - 2, max_len - 2]`` and marks
      (entity head, entity tail) character positions;
    * ``sh_oh`` / ``st_ot`` are ``[num_relation, max_len - 2, max_len - 2]``
      and mark (subject head, object head) / (subject tail, object tail).

    Matrix index ``i`` refers to character ``i`` of the text, i.e. token
    ``i + 1`` of the encoded sequence (token 0 is ``[CLS]``).
    """

    def __init__(self, train_mode: bool, transformers_path, max_len=128):
        """Load the requested split and the entity / relation vocabularies.

        Args:
            train_mode: ``True`` reads ``./data/train_data.json``, ``False``
                reads ``./data/test_data.json``.
            transformers_path: path passed to ``BertTokenizer.from_pretrained``.
            max_len: fixed token length of every encoded sample, including
                the ``[CLS]`` and ``[SEP]`` tokens.
        """
        self.train_mode = train_mode
        self.max_len = max_len
        self.tokenizer = BertTokenizer.from_pretrained(transformers_path)
        data_file_path = "./data/train_data.json" if self.train_mode else "./data/test_data.json"
        with open(data_file_path, "r", encoding="utf-8") as f:
            self.dataset = json.load(f)
        self.dataset_length = len(self.dataset)
        print(f"load {'train' if self.train_mode else 'test'} dataset size: {self.dataset_length}")
        with open("./data/entity_types.json", "r", encoding="utf-8") as f:
            self.entity_types = json.load(f)
            self.num_entity = len(self.entity_types)
            self.id2entity = {idx: x for idx, x in enumerate(self.entity_types)}
            self.entity2id = {x: idx for idx, x in enumerate(self.entity_types)}
        with open("./data/relation_types.json", "r", encoding="utf-8") as f:
            self.relation_types = json.load(f)
            self.num_relation = len(self.relation_types)
            self.id2relation = {idx: x for idx, x in enumerate(self.relation_types)}
            self.relation2id = {x: idx for idx, x in enumerate(self.relation_types)}

    def _encode(self, sample):
        """Encode ``sample["text"]`` one character per token.

        The sequence is ``[CLS]`` + characters, truncated so that a final
        ``[SEP]`` fits, then right-padded with ``[PAD]`` to ``max_len``.

        Returns:
            dict with ``input_ids``, ``attention_mask`` and ``token_type_ids``,
            each a ``torch.long`` tensor of length ``max_len``.
        """
        input_text = sample["text"]
        input_text_tokens = ['[CLS]'] + list(input_text)
        # Keep at most max_len - 2 characters so [CLS] ... [SEP] fits in max_len.
        input_text_tokens = input_text_tokens[:self.max_len - 1] + ['[SEP]']
        input_text_tokens = input_text_tokens + ['[PAD]' for _ in range(self.max_len - len(input_text_tokens))]
        input_ids = self.tokenizer.convert_tokens_to_ids(input_text_tokens)
        attention_mask = [1 if token != '[PAD]' else 0 for token in input_text_tokens]
        token_type_ids = [0 for token in input_text_tokens]

        _input_encoding = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids
        }
        _input_encoding = {k: torch.tensor(v, dtype=torch.long) for k, v in _input_encoding.items()}
        return _input_encoding

    def _find_spans(self, entity_text, input_text):
        """Return inclusive ``(head, tail)`` character spans of a mention.

        All non-overlapping literal occurrences of ``entity_text`` inside
        ``input_text`` are returned. An empty ``entity_text`` yields no spans:
        an empty pattern would otherwise match at every position and produce
        spans whose tail is ``head - 1`` (``-1`` for the first one).
        """
        if not entity_text:
            return []
        pattern = re.escape(entity_text)
        return [(match.start(), match.end() - 1) for match in re.finditer(pattern, input_text)]

    def _calc_target_matrix(self, sample):
        """Build the three 0/1 target matrices for one sample.

        Returns:
            ``(eh_et_matrix, sh_oh_matrix, st_ot_matrix)`` as integer numpy
            arrays; see the class docstring for their layout.

        Raises:
            KeyError: if an entity type or predicate is missing from the
                loaded vocabularies.
        """
        # Number of character positions that survive truncation in _encode;
        # valid matrix indices are 0 .. side - 1.
        side = self.max_len - 2
        eh_et_matrix = np.zeros([self.num_entity, side, side], dtype=int)  # entity_head to entity_tail
        sh_oh_matrix = np.zeros([self.num_relation, side, side], dtype=int)  # subject_head to object_head
        st_ot_matrix = np.zeros([self.num_relation, side, side], dtype=int)  # subject_tail to object_tail
        input_text = sample["text"]
        cache_entity = dict()
        for entity in sample["entity_list"]:
            entity_text = entity["text"]
            entity_type_id = self.entity2id[entity["type"]]
            spans = self._find_spans(entity_text, input_text)
            cache_entity[entity_text] = spans
            for (x1, x2) in spans:
                # head <= tail, so checking the tail is enough to stay in range.
                if x2 < side:
                    eh_et_matrix[entity_type_id, x1, x2] = 1

        for relation in sample["relation_list"]:
            subject_text = relation["subject"]
            object_text = relation["object"]
            relation_type_id = self.relation2id[relation["predicate"]]
            # A relation endpoint may be absent from entity_list; locate it in
            # the text on demand instead of failing with a KeyError.
            for mention in (subject_text, object_text):
                if mention not in cache_entity:
                    cache_entity[mention] = self._find_spans(mention, input_text)
            for (sx1, sx2) in cache_entity[subject_text]:
                for (ox1, ox2) in cache_entity[object_text]:
                    if sx2 < side and ox2 < side:
                        sh_oh_matrix[relation_type_id, sx1, ox1] = 1
                        st_ot_matrix[relation_type_id, sx2, ox2] = 1
        return eh_et_matrix, sh_oh_matrix, st_ot_matrix

    def __getitem__(self, idx):
        """Return the encoded input and the three target matrices for ``idx``."""
        sample = self.dataset[idx]
        input_encoding = self._encode(sample)
        eh_et_matrix, sh_oh_matrix, st_ot_matrix = self._calc_target_matrix(sample)
        return input_encoding, eh_et_matrix, sh_oh_matrix, st_ot_matrix

    def __len__(self):
        """Number of samples in the loaded split."""
        return self.dataset_length


class GlobalPointer(nn.Module):
    """Pairwise position scorer producing one score matrix per class.

    Each token embedding is projected to one query and one key vector of size
    ``d_model`` per class; the score of position pair ``(m, n)`` for a class
    is the dot product of query ``m`` with key ``n``.
    """

    def __init__(self, num_class, d_model):
        super().__init__()
        self.num_class = num_class
        self.proj_q = nn.Linear(d_model, num_class * d_model)
        self.proj_k = nn.Linear(d_model, num_class * d_model)

    def forward(self, embedding: torch.Tensor, mask, mask_tri: bool):
        """Score every ordered pair of positions for every class.

        Args:
            embedding: tensor of shape ``[batch, seq_len, d_model]``.
            mask: ``[batch, seq_len]`` tensor, non-zero for positions to keep.
            mask_tri: when true, pairs below the main diagonal (row index
                greater than column index) are masked as well.

        Returns:
            Tensor of shape ``[batch, num_class, seq_len, seq_len]`` in which
            every masked pair holds ``-1e4``.
        """
        batch, seq_len, d_model = embedding.shape
        per_class_shape = [batch, seq_len, self.num_class, d_model]
        queries = self.proj_q(embedding).reshape(per_class_shape)
        keys = self.proj_k(embedding).reshape(per_class_shape)
        scores = torch.einsum("bmcd,bncd->bcmn", queries, keys)

        # A pair is kept only if both of its positions are unmasked; the extra
        # axis broadcasts the pair mask over the class dimension.
        keep = torch.einsum("bm,bn->bmn", mask, mask).unsqueeze(dim=-3)
        if mask_tri:
            upper_triangle = torch.triu(torch.ones_like(scores))
            keep = torch.logical_and(keep, upper_triangle)
        return torch.masked_fill(scores, torch.logical_not(keep), -1e4)


class CustomModel(nn.Module):
    """BERT encoder followed by three GlobalPointer heads.

    The heads score (entity head, entity tail), (subject head, object head)
    and (subject tail, object tail) position pairs.
    """

    def __init__(self, transformers_path, num_entity, num_relation):
        super().__init__()
        self.bert_module = BertModel.from_pretrained(transformers_path)
        hidden_size = self.bert_module.config.hidden_size
        self.eh_et_pointer = GlobalPointer(num_class=num_entity, d_model=hidden_size)
        self.sh_oh_pointer = GlobalPointer(num_class=num_relation, d_model=hidden_size)
        self.st_ot_pointer = GlobalPointer(num_class=num_relation, d_model=hidden_size)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return ``(eh_et_matrix, sh_oh_matrix, st_ot_matrix)`` score tensors.

        The first and last position are sliced off both pair axes of every
        head output, so each returned matrix has side ``seq_len - 2``.
        """
        encoder_outputs = self.bert_module(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        hidden_states = encoder_outputs[0]
        # Only the entity head is restricted to the upper triangle.
        heads = (
            (self.eh_et_pointer, True),   # entity_head to entity_tail
            (self.sh_oh_pointer, False),  # subject_head to object_head
            (self.st_ot_pointer, False),  # subject_tail to object_tail
        )
        eh_et_matrix, sh_oh_matrix, st_ot_matrix = (
            pointer(hidden_states, attention_mask, upper_only)[:, :, 1:-1, 1:-1]
            for pointer, upper_only in heads
        )
        return eh_et_matrix, sh_oh_matrix, st_ot_matrix


def train(model, dataloader, epoch, scaler, optimizer):
    """Run one training epoch with mixed precision and live metrics.

    Args:
        model: module called as ``model(**input_encoding)`` and returning
            three logit tensors (entity, subject/object heads,
            subject/object tails).
        dataloader: yields ``(input_encoding, eh_et, sh_oh, st_ot)`` batches.
        epoch: epoch number, used only in the progress-bar description.
        scaler: gradient scaler used for the backward pass and optimizer step.
        optimizer: optimizer updating ``model``'s parameters.

    The loss and the precision / recall figures shown on the progress bar are
    computed over the most recent 100 batches. Nothing is returned.
    """

    def _confusion(predicted, gold):
        # Element counts of (true positive, false positive, false negative).
        missed = torch.logical_not(predicted)
        negative = torch.logical_not(gold)
        return (
            torch.logical_and(predicted, gold).sum().item(),
            torch.logical_and(predicted, negative).sum().item(),
            torch.logical_and(missed, gold).sum().item(),
        )

    def _precision_recall(tp_hist, fp_hist, fn_hist):
        # The 1e-5 term keeps the ratios defined when a denominator is zero.
        tp, fp, fn = np.sum(tp_hist), np.sum(fp_hist), np.sum(fn_hist)
        return tp / (tp + fp + 1e-5), tp / (tp + fn + 1e-5)

    time.sleep(0.2)
    model.train()

    loss_count = deque([], maxlen=100)
    entity_hist = [deque([], maxlen=100) for _ in range(3)]    # tp, fp, fn
    relation_hist = [deque([], maxlen=100) for _ in range(3)]  # tp, fp, fn

    pbar = tqdm(dataloader)
    pbar.set_description("train epoch {}".format(epoch))
    for input_encoding, eh_et_gold, sh_oh_gold, st_ot_gold in pbar:
        optimizer.zero_grad()
        input_encoding = {name: tensor.to(device) for name, tensor in input_encoding.items()}
        eh_et_gold = eh_et_gold.to(device)
        sh_oh_gold = sh_oh_gold.to(device)
        st_ot_gold = st_ot_gold.to(device)

        with torch.cuda.amp.autocast():
            eh_et_logits, sh_oh_logits, st_ot_logits = model(**input_encoding)
            bce_loss_1 = F.binary_cross_entropy_with_logits(eh_et_logits, eh_et_gold.float())
            bce_loss_2 = F.binary_cross_entropy_with_logits(sh_oh_logits, sh_oh_gold.float())
            bce_loss_3 = F.binary_cross_entropy_with_logits(st_ot_logits, st_ot_gold.float())
            loss = bce_loss_1 + bce_loss_2 + bce_loss_3
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        loss_count.append(loss.item())
        log_loss = np.mean(loss_count)

        # A logit above 0 counts as a positive prediction.
        entity_counts = _confusion(torch.gt(eh_et_logits, 0), torch.eq(eh_et_gold, 1))
        for history, value in zip(entity_hist, entity_counts):
            history.append(value)
        entity_precision, entity_recall = _precision_recall(*entity_hist)

        # Relation counts pool the head-pair and tail-pair matrices.
        head_counts = _confusion(torch.gt(sh_oh_logits, 0), torch.eq(sh_oh_gold, 1))
        tail_counts = _confusion(torch.gt(st_ot_logits, 0), torch.eq(st_ot_gold, 1))
        for history, head_value, tail_value in zip(relation_hist, head_counts, tail_counts):
            history.append(head_value + tail_value)
        relation_precision, relation_recall = _precision_recall(*relation_hist)

        log_str = f"loss={log_loss:0.9f}, entity_precision={entity_precision:0.9f}, entity_recall={entity_recall:0.9f}, " \
                  f"relation_precision={relation_precision:0.9f}, relation_recall={relation_recall:0.9f}"
        pbar.set_postfix_str(log_str)


def test(model, dataloader, epoch):
    """Evaluate ``model`` on ``dataloader`` and show cumulative metrics.

    Args:
        model: module called as ``model(**input_encoding)`` and returning
            three logit tensors (entity, subject/object heads,
            subject/object tails).
        dataloader: yields ``(input_encoding, eh_et, sh_oh, st_ot)`` batches.
        epoch: epoch number, used only in the progress-bar description.

    The mean loss and the precision / recall figures shown on the progress
    bar accumulate over every batch seen so far. Nothing is returned.
    """
    time.sleep(0.2)
    model.eval()
    loss_count = []
    entity_tp_count = []
    entity_fp_count = []
    entity_fn_count = []
    relation_tp_count = []
    relation_fp_count = []
    relation_fn_count = []
    pbar = tqdm(dataloader)
    pbar.set_description("test epoch {}".format(epoch))
    # Evaluation never calls backward(); disabling autograd avoids building a
    # graph for every batch, which saves memory and time.
    with torch.no_grad():
        for input_encoding, eh_et_matrix_target, sh_oh_matrix_target, st_ot_matrix_target in pbar:
            input_encoding = {k: v.to(device) for k, v in input_encoding.items()}
            eh_et_matrix_target, sh_oh_matrix_target, st_ot_matrix_target = \
                eh_et_matrix_target.to(device), sh_oh_matrix_target.to(device), st_ot_matrix_target.to(device)
            with torch.cuda.amp.autocast():
                eh_et_matrix_predict, sh_oh_matrix_predict, st_ot_matrix_predict = model(**input_encoding)
                bce_loss_1 = F.binary_cross_entropy_with_logits(eh_et_matrix_predict, eh_et_matrix_target.float())
                bce_loss_2 = F.binary_cross_entropy_with_logits(sh_oh_matrix_predict, sh_oh_matrix_target.float())
                bce_loss_3 = F.binary_cross_entropy_with_logits(st_ot_matrix_predict, st_ot_matrix_target.float())
                loss = bce_loss_1 + bce_loss_2 + bce_loss_3
            loss_count.append(loss.item())
            log_loss = np.mean(loss_count)

            # A logit above 0 counts as a positive prediction.
            eh_et_matrix_predict = torch.gt(eh_et_matrix_predict, 0)
            eh_et_matrix_target = torch.eq(eh_et_matrix_target, 1)
            entity_tp_count.append(torch.logical_and(eh_et_matrix_predict, eh_et_matrix_target).sum().item())
            entity_fp_count.append(torch.logical_and(eh_et_matrix_predict, torch.logical_not(eh_et_matrix_target)).sum().item())
            entity_fn_count.append(torch.logical_and(torch.logical_not(eh_et_matrix_predict), eh_et_matrix_target).sum().item())

            # The 1e-5 term keeps the ratios defined when a denominator is zero.
            log_entity_tp, log_entity_fp, log_entity_fn = np.sum(entity_tp_count), np.sum(entity_fp_count), np.sum(entity_fn_count)
            log_entity_precision = log_entity_tp / (log_entity_tp + log_entity_fp + 1e-5)
            log_entity_recall = log_entity_tp / (log_entity_tp + log_entity_fn + 1e-5)

            # Relation counts pool the head-pair and tail-pair matrices.
            sh_oh_matrix_predict = torch.gt(sh_oh_matrix_predict, 0)
            st_ot_matrix_predict = torch.gt(st_ot_matrix_predict, 0)
            sh_oh_matrix_target = torch.eq(sh_oh_matrix_target, 1)
            st_ot_matrix_target = torch.eq(st_ot_matrix_target, 1)
            relation_tp_count.append(torch.logical_and(sh_oh_matrix_predict, sh_oh_matrix_target).sum().item() +
                                     torch.logical_and(st_ot_matrix_predict, st_ot_matrix_target).sum().item())
            relation_fp_count.append(torch.logical_and(sh_oh_matrix_predict, torch.logical_not(sh_oh_matrix_target)).sum().item() +
                                     torch.logical_and(st_ot_matrix_predict, torch.logical_not(st_ot_matrix_target)).sum().item())
            relation_fn_count.append(torch.logical_and(torch.logical_not(sh_oh_matrix_predict), sh_oh_matrix_target).sum().item() +
                                     torch.logical_and(torch.logical_not(st_ot_matrix_predict), st_ot_matrix_target).sum().item())
            log_relation_tp, log_relation_fp, log_relation_fn = np.sum(relation_tp_count), np.sum(relation_fp_count), np.sum(relation_fn_count)
            log_relation_precision = log_relation_tp / (log_relation_tp + log_relation_fp + 1e-5)
            log_relation_recall = log_relation_tp / (log_relation_tp + log_relation_fn + 1e-5)

            log_str = f"loss={log_loss:0.9f}, entity_precision={log_entity_precision:0.9f}, entity_recall={log_entity_recall:0.9f}, " \
                      f"relation_precision={log_relation_precision:0.9f}, relation_recall={log_relation_recall:0.9f}"
            pbar.set_postfix_str(log_str)


if __name__ == '__main__':
    # Load both splits; each constructor prints the size of its split.
    dataset_train = CustomDataset(train_mode=True, transformers_path=transformers_path)
    dataset_test = CustomDataset(train_mode=False, transformers_path=transformers_path)

    # Loader settings common to both splits; only the training split is shuffled.
    loader_options = dict(batch_size=batch_size, num_workers=16, pin_memory=True)
    dataloader_train = DataLoader(dataset=dataset_train, shuffle=True, **loader_options)
    dataloader_test = DataLoader(dataset=dataset_test, shuffle=False, **loader_options)

    model = CustomModel(
        transformers_path,
        dataset_train.num_entity,
        dataset_train.num_relation,
    )
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    scaler = torch.cuda.amp.GradScaler()

    # Alternate one training pass and one evaluation pass per epoch.
    for epoch in range(100):
        train(model, dataloader_train, epoch, scaler, optimizer)
        test(model, dataloader_test, epoch)
        # Checkpointing is currently disabled:
        # torch.save(model.state_dict(), f"./model_1/model_{epoch}.pth")


load train dataset size: 18606
load test dataset size: 2068


train epoch 0: 100%|██████████| 1163/1163 [03:56<00:00,  4.91it/s, loss=0.000184903, entity_precision=0.218397290, entity_recall=0.078245046, relation_precision=0.002915452, relation_recall=0.000396197]
test epoch 0: 100%|██████████| 130/130 [00:18<00:00,  6.95it/s, loss=0.000179820, entity_precision=0.290644867, entity_recall=0.098416115, relation_precision=0.001675042, relation_recall=0.000100523]
train epoch 1: 100%|██████████| 1163/1163 [03:54<00:00,  4.95it/s, loss=0.002000586, entity_precision=0.392136023, entity_recall=0.145018667, relation_precision=0.066568046, relation_recall=0.005888511]
test epoch 1: 100%|██████████| 130/130 [00:18<00:00,  7.03it/s, loss=0.000114324, entity_precision=0.387666324, entity_recall=0.232969398, relation_precision=0.178649236, relation_recall=0.016485726]
train epoch 2: 100%|██████████| 1163/1163 [03:55<00:00,  4.93it/s, loss=0.001969838, entity_precision=0.492180710, entity_recall=0.216966679, relation_precision=0.309090907, relation_recall=0.04

KeyboardInterrupt: 

In [None]:
{
  "id": 754,
  "text": "2006年主演《士兵突击》引起空前的反响，这让王宝强一下子成为了家喻户晓的明星",
  "relation_list": [
    {
      "object": "王宝强",
      "subject": "士兵突击",
      "predicate": "主演"
    }
  ],
  "entity_list": [
    {
      "text": "王宝强",
      "type": "人物"
    },
    {
      "text": "士兵突击",
      "type": "影视作品"
    }
  ]
}

