In [1]:
!nvidia-smi

Fri Jun 18 00:49:59 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    25W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [3]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/00/92/6153f4912b84ee1ab53ab45663d23e7cf3704161cb5ef18b0c07e207cef2/transformers-4.7.0-py3-none-any.whl (2.5MB)
[K     |████████████████████████████████| 2.5MB 4.0MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/75/ee/67241dc87f266093c533a2d4d3d69438e57d7a90abb216fa076e7d475d4a/sacremoses-0.0.45-py3-none-any.whl (895kB)
[K     |████████████████████████████████| 901kB 31.8MB/s 
[?25hCollecting huggingface-hub==0.0.8
  Downloading https://files.pythonhosted.org/packages/a1/88/7b1e45720ecf59c6c6737ff332f41c955963090a18e72acbcbeac6b25e86/huggingface_hub-0.0.8-py3-none-any.whl
Collecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/d4/e2/df3543e8ffdab68f5acc73f613de9c2b155ac47f162e725dcac87c521c11/tokenizers-0.10.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.3MB)
[K     |█

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
import json
from tqdm import tqdm
import numpy as np
from collections import deque, defaultdict
import time
import math

batch_size = 16
max_length = 256
yamma = 4
transformers_path = "hfl/chinese-roberta-wwm-ext"
# train_data_file_path = "D:/db/ner/renminribao_2014/train.json"
# test_data_file_path = "D:/db/ner/renminribao_2014/test.json"
# train_data_file_path = "./data/train.json"
# test_data_file_path = "./data/test.json"

train_data_file_path = "/content/drive/MyDrive/dataset/ner/renminribao_2014/train.json"
test_data_file_path = "/content/drive/MyDrive/dataset/ner/renminribao_2014/test.json"

ENTITY_TYPE = ['PER', 'ORG', 'T', 'LOC']
entity_type_to_ids = {type_str: i for i, type_str in enumerate(ENTITY_TYPE)}
num_entity = len(ENTITY_TYPE)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class CustomNerDataset(Dataset):
    def __init__(self, train_mode: bool):
        self.train_mode = train_mode
        self.tokenizer = BertTokenizer.from_pretrained(transformers_path)
        data_file_path = train_data_file_path if self.train_mode else test_data_file_path
        with open(data_file_path, "r", encoding="utf-8") as f:
            self.dataset = json.load(f)
        self.dataset_length = len(self.dataset)
        print(f"load {'train' if self.train_mode else 'test'} dataset size: {self.dataset_length}")

    def _encode(self, input_text: str, max_length: int):
        input_text_tokens = input_text.strip().split(" ")
        input_text_tokens = ['[CLS]'] + input_text_tokens
        input_text_tokens = input_text_tokens[:max_length - 1] + ['[SEP]']
        input_text_tokens = input_text_tokens + ['[PAD]' for _ in range(max_length - len(input_text_tokens))]

        input_ids = self.tokenizer.convert_tokens_to_ids(input_text_tokens)
        attention_mask = [1 if token != '[PAD]' else 0 for token in input_text_tokens]
        token_type_ids = [0 for token in input_text_tokens]

        _input_encoding = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids
        }
        _input_encoding = {k: torch.tensor(v, dtype=torch.long) for k, v in _input_encoding.items()}
        return _input_encoding

    def _calc_target_matrix(self, tags, max_length: int):
        target_matrix = np.zeros([num_entity, max_length, max_length], dtype=int)
        for tag in tags:
            tag_idx = entity_type_to_ids[tag["label"]]
            start_index = tag["start_index"] + 1
            end_index = tag["end_index"] + 1
            if start_index < max_length - 1 and end_index < max_length - 1:
                target_matrix[tag_idx, start_index, end_index] = 1
        target_matrix = torch.tensor(target_matrix, dtype=torch.long)
        return target_matrix

    def __getitem__(self, index):
        sample = self.dataset[index]
        text = sample["text"]
        tags = sample["tags"]
        return self._encode(text, max_length=max_length), self._calc_target_matrix(tags, max_length=max_length)

    def __len__(self):
        return self.dataset_length


class RotationalPositionEmbedding(nn.Module):
    def __init__(self, seq_len, d_model):
        super(RotationalPositionEmbedding, self).__init__()
        position = torch.arange(0, seq_len, step=1, dtype=torch.float).unsqueeze(dim=1)
        div_term = torch.exp(torch.arange(0, d_model, step=2, dtype=torch.float) * (-math.log(10000)) / d_model).unsqueeze(dim=0)
        sin_part = torch.sin(position * div_term)
        sin_part = torch.stack([sin_part, sin_part], dim=-1).reshape([1, seq_len, 1, d_model])
        cos_part = torch.cos(position * div_term)
        cos_part = torch.stack([cos_part, cos_part], dim=-1).reshape([1, seq_len, 1, d_model])
        self.register_buffer("sin_part", sin_part)
        self.register_buffer("cos_part", cos_part)

    def forward(self, embedding):
        # embedding shape [batch, seq_length, d_model]
        embedding_2 = torch.stack([-embedding[..., 1::2], embedding[..., 0::2]], dim=-1).reshape(embedding.shape)
        return embedding * self.cos_part + embedding_2 * self.sin_part


class GlobalPointer(nn.Module):
    def __init__(self, seq_len, d_model):
        super(GlobalPointer, self).__init__()
        self.seq_len = seq_len
        self.d_model = d_model
        self.proj_p = nn.Linear(d_model, num_entity * d_model)
        self.proj_q = nn.Linear(d_model, num_entity * d_model)
        self.pe = RotationalPositionEmbedding(seq_len, d_model)

    def forward(self, embedding, mask):
        # embedding shape [batch, seq_length, d_model]
        p = self.proj_p(embedding).reshape([-1, self.seq_len, num_entity, self.d_model])
        q = self.proj_q(embedding).reshape([-1, self.seq_len, num_entity, self.d_model])
        p = self.pe(p)
        q = self.pe(q)
        tag_matrix = torch.einsum("bmed,bned->bemn", p, q)  # [batch, num_entity, seq_length, seq_length]

        mask_seq = torch.einsum("bm,bn->bmn", mask, mask).unsqueeze(dim=-3)
        mask_tri = torch.triu(torch.ones_like(tag_matrix))
        mask = torch.logical_and(mask_seq, mask_tri)
        tag_matrix = torch.masked_fill(tag_matrix, torch.logical_not(mask), -1e4)
        return tag_matrix


class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.bert_module = BertModel.from_pretrained(transformers_path)
        self.global_pointer = GlobalPointer(seq_len=max_length, d_model=self.bert_module.config.hidden_size)

    def forward(self, input_ids, attention_mask, token_type_ids):
        bert_output = self.bert_module(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]
        tag_matrix = self.global_pointer(bert_output, attention_mask)
        return tag_matrix


def train():
    time.sleep(0.2)
    model.train()
    train_loss = deque([], maxlen=100)
    TP_count = defaultdict(int)
    FP_count = defaultdict(int)
    FN_count = defaultdict(int)
    TN_count = defaultdict(int)
    pbar = tqdm(dataloader_train, position=0, leave=True)
    pbar.set_description("train epoch {}".format(epoch))
    for input_encoding, y_target in pbar:
        optimizer.zero_grad()
        input_encoding = {k: v.to(device) for k, v in input_encoding.items()}
        y_target = y_target.to(device)
        with torch.cuda.amp.autocast():
            y_predict = model(**input_encoding)
            bce_loss = F.binary_cross_entropy_with_logits(y_predict, y_target.float(), reduction='none')
            focal_loss = torch.pow(1 - torch.exp(-bce_loss), yamma) * bce_loss
            loss = torch.mean(focal_loss)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        train_loss.append(loss.item())

        y_predict = torch.sigmoid(y_predict)
        for i, label_str in enumerate(ENTITY_TYPE):
            y_predict_label = torch.gt(y_predict[:, i, :, :], 0.5)
            y_target_label = torch.eq(y_target[:, i, :, :], 1)
            TP_count[label_str] += torch.logical_and(y_predict_label, y_target_label).sum().item()
            FP_count[label_str] += torch.logical_and(y_predict_label, torch.logical_not(y_target_label)).sum().item()
            FN_count[label_str] += torch.logical_and(torch.logical_not(y_predict_label), y_target_label).sum().item()
            TN_count[label_str] += torch.logical_and(torch.logical_not(y_predict_label), torch.logical_not(y_target_label)).sum().item()

        log_str = "loss={}".format(np.mean(train_loss))
        pbar.set_postfix_str(log_str)
    for i, label_str in enumerate(ENTITY_TYPE):
        nums = TP_count[label_str] + FN_count[label_str]
        precision = TP_count[label_str] / (TP_count[label_str] + FP_count[label_str] + 1e-5)
        recall = TP_count[label_str] / (TP_count[label_str] + FN_count[label_str] + 1e-5)
        f1 = (2 * precision * recall) / (precision + recall + 1e-5)
        print(f"label {label_str}: precision={precision}, recall={recall}, f1={f1}, nums={nums}")


def test():
    time.sleep(0.2)
    model.eval()
    test_loss = []
    TP_count = defaultdict(int)
    FP_count = defaultdict(int)
    FN_count = defaultdict(int)
    TN_count = defaultdict(int)
    pbar = tqdm(dataloader_test, position=0, leave=True)
    pbar.set_description("test epoch {}".format(epoch))
    for input_encoding, y_target in pbar:
        input_encoding = {k: v.to(device) for k, v in input_encoding.items()}
        y_target = y_target.to(device)

        y_predict = model(**input_encoding)
        bce_loss = F.binary_cross_entropy_with_logits(y_predict, y_target.float(), reduction='none')
        focal_loss = torch.pow(1 - torch.exp(-bce_loss), yamma) * bce_loss
        loss = torch.mean(focal_loss)

        test_loss.append(loss.item())
        y_predict = torch.sigmoid(y_predict)
        for i, label_str in enumerate(ENTITY_TYPE):
            y_predict_label = torch.gt(y_predict[:, i, :, :], 0.5)
            y_target_label = torch.eq(y_target[:, i, :, :], 1)
            TP_count[label_str] += torch.logical_and(y_predict_label, y_target_label).sum().item()
            FP_count[label_str] += torch.logical_and(y_predict_label, torch.logical_not(y_target_label)).sum().item()
            FN_count[label_str] += torch.logical_and(torch.logical_not(y_predict_label), y_target_label).sum().item()
            TN_count[label_str] += torch.logical_and(torch.logical_not(y_predict_label), torch.logical_not(y_target_label)).sum().item()

        log_str = "loss={}".format(np.mean(test_loss))
        pbar.set_postfix_str(log_str)
    for i, label_str in enumerate(ENTITY_TYPE):
        nums = TP_count[label_str] + FN_count[label_str]
        precision = TP_count[label_str] / (TP_count[label_str] + FP_count[label_str] + 1e-5)
        recall = TP_count[label_str] / (TP_count[label_str] + FN_count[label_str] + 1e-5)
        f1 = (2 * precision * recall) / (precision + recall + 1e-5)
        print(f"label {label_str}: precision={precision}, recall={recall}, f1={f1}, nums={nums}")


if __name__ == '__main__':
    dataset_train = CustomNerDataset(train_mode=True)
    dataloader_train = DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
    dataset_test = CustomNerDataset(train_mode=False)
    dataloader_test = DataLoader(dataset=dataset_test, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)

    model = CustomModel()
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
    scaler = torch.cuda.amp.GradScaler()

    for epoch in range(100):
        train()
        test()
        # torch.save(model.state_dict(), f"./model_1/model_{epoch}.pth")


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=109540.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=2.0, style=ProgressStyle(description_wi…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=112.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=19.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=268961.0, style=ProgressStyle(descripti…


load train dataset size: 257642
load test dataset size: 28627


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=689.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=411578458.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
train epoch 0: 100%|██████████| 16103/16103 [41:58<00:00,  6.39it/s, loss=6.748016620861108e-07]


label PER: precision=0.38999247838512796, recall=0.7126851650063938, f1=0.5041171096364067, nums=202252
label ORG: precision=0.03562948937398544, recall=0.4284681643103017, f1=0.06578690002716407, nums=16601
label T: precision=0.29324211564957114, recall=0.6756248049503465, f1=0.40897175477416425, nums=198702
label LOC: precision=0.3572920190253864, recall=0.694611776022859, f1=0.4718623238183394, nums=228090


test epoch 0: 100%|██████████| 1790/1790 [03:37<00:00,  8.22it/s, loss=5.690854374873613e-07]


label PER: precision=0.8884304928751433, recall=0.8797903988277498, f1=0.8840843368646343, nums=22519
label ORG: precision=0.9208899819475279, recall=0.8159912332092485, f1=0.8652679516852517, nums=1826
label T: precision=0.8267319158790595, recall=0.9646250609070682, f1=0.8903662231615062, nums=22417
label LOC: precision=0.8623465408871769, recall=0.9193542063092777, f1=0.8899333642212865, nums=25519


train epoch 1: 100%|██████████| 16103/16103 [41:43<00:00,  6.43it/s, loss=3.71766838824783e-07]


label PER: precision=0.8993262295716983, recall=0.8625576013654966, f1=0.8805532567326319, nums=202252
label ORG: precision=0.8429451627585057, recall=0.824106981010718, f1=0.8334146340433806, nums=16601
label T: precision=0.9310688954154009, recall=0.9202776015882942, f1=0.9256367980700557, nums=198702
label LOC: precision=0.9101290313441125, recall=0.8847472488541915, f1=0.8972536761438381, nums=228090


test epoch 1: 100%|██████████| 1790/1790 [03:37<00:00,  8.22it/s, loss=3.1269937701540997e-07]


label PER: precision=0.898307280033549, recall=0.9237976815472279, f1=0.9108691823616316, nums=22519
label ORG: precision=0.9565217337294086, recall=0.9277108382929308, f1=0.9418910204232415, nums=1826
label T: precision=0.9404101634032374, recall=0.9757327024241724, f1=0.9577408634653752, nums=22417
label LOC: precision=0.9544207375591053, recall=0.9132803005943492, f1=0.9333924148708099, nums=25519


train epoch 2: 100%|██████████| 16103/16103 [41:46<00:00,  6.43it/s, loss=2.5272726944081115e-07]


label PER: precision=0.9243613719494667, recall=0.8895536261253509, f1=0.9066185328947508, nums=202252
label ORG: precision=0.932764272324634, recall=0.9083790127652678, f1=0.9204051565924793, nums=16601
label T: precision=0.9587837256884734, recall=0.954831858715321, f1=0.956798711687741, nums=198702
label LOC: precision=0.9384061386899786, recall=0.9163005830629883, f1=0.9272166276894563, nums=228090


test epoch 2: 100%|██████████| 1790/1790 [03:37<00:00,  8.22it/s, loss=2.5611260307270645e-07]


label PER: precision=0.941359863857063, recall=0.9117633993908418, f1=0.9263202885000463, nums=22519
label ORG: precision=0.9604335991304412, recall=0.9704271578837506, f1=0.9653995168288367, nums=1826
label T: precision=0.9589763706692055, recall=0.9812642186816861, f1=0.9699872833577958, nums=22417
label LOC: precision=0.9539889569680767, recall=0.9343626313984237, f1=0.9440688029762326, nums=25519


train epoch 3: 100%|██████████| 16103/16103 [41:50<00:00,  6.42it/s, loss=2.158496661230913e-07]


label PER: precision=0.9342099218045843, recall=0.9049205940655757, f1=0.9193270336077379, nums=202252
label ORG: precision=0.9577810499768158, recall=0.940184325679065, f1=0.9488961154824866, nums=16601
label T: precision=0.9698195426041831, recall=0.9677305713597383, f1=0.9687689309005482, nums=198702
label LOC: precision=0.9492154320287356, recall=0.9309044674939321, f1=0.9399657826983878, nums=228090


test epoch 3: 100%|██████████| 1790/1790 [03:35<00:00,  8.30it/s, loss=2.2619656586892502e-07]


label PER: precision=0.946517697263374, recall=0.9179359647773276, f1=0.9320027558652444, nums=22519
label ORG: precision=0.9513998893217122, recall=0.9863088664495681, f1=0.968534926528733, nums=1826
label T: precision=0.9693666764956855, recall=0.9838961498042128, f1=0.9765723742333726, nums=22417
label LOC: precision=0.955314960253813, recall=0.9508601430499393, f1=0.9530773461510341, nums=25519


train epoch 4: 100%|██████████| 16103/16103 [41:53<00:00,  6.41it/s, loss=2.0035272747520593e-07]


label PER: precision=0.9414145620087749, recall=0.9165447065583261, f1=0.9288081865331517, nums=202252
label ORG: precision=0.9696821808273762, recall=0.9575326781774998, f1=0.9635641334224685, nums=16601
label T: precision=0.9756509379264916, recall=0.9741975419988628, f1=0.9749186983183106, nums=198702
label LOC: precision=0.9567950179538864, recall=0.9416852996211283, f1=0.9491750311466176, nums=228090


test epoch 4: 100%|██████████| 1790/1790 [03:38<00:00,  8.18it/s, loss=1.9975958657766587e-07]


label PER: precision=0.9496125783000532, recall=0.9306363511121114, f1=0.9400237073897214, nums=22519
label ORG: precision=0.9408713644145676, recall=0.9934282530480381, f1=0.9664308003797198, nums=1826
label T: precision=0.9751070305092876, recall=0.9855466828810516, f1=0.9802940634940971, nums=22417
label LOC: precision=0.9667780691625774, recall=0.9521924836583752, f1=0.9594248461446021, nums=25519


train epoch 5: 100%|██████████| 16103/16103 [42:01<00:00,  6.39it/s, loss=1.9993683053387557e-07]


label PER: precision=0.9468548394784809, recall=0.9245050728336677, f1=0.9355414945764606, nums=202252
label ORG: precision=0.9755429817025352, recall=0.9659056677513971, f1=0.970695405154891, nums=16601
label T: precision=0.9793508540767152, recall=0.9783897494248478, f1=0.9788650658621677, nums=198702
label LOC: precision=0.9622147000647577, recall=0.9496602218006199, f1=0.9558912411859661, nums=228090


test epoch 5: 100%|██████████| 1790/1790 [03:37<00:00,  8.22it/s, loss=1.8117538526351063e-07]


label PER: precision=0.9451501150536729, recall=0.9450242013655028, f1=0.9450821540422013, nums=22519
label ORG: precision=0.9678628764442267, recall=0.9895947371873235, f1=0.9786031729753814, nums=1826
label T: precision=0.9804503941038876, recall=0.986617298930893, f1=0.9835191797372524, nums=22417
label LOC: precision=0.9710727355677731, recall=0.9563462514376165, f1=0.9636482348191588, nums=25519


train epoch 6: 100%|██████████| 16103/16103 [41:53<00:00,  6.41it/s, loss=1.5985889884717607e-07]


label PER: precision=0.9515787985813482, recall=0.9323121649757574, f1=0.9418419619843404, nums=202252
label ORG: precision=0.9815558796070393, recall=0.971327027907158, f1=0.9764096654711931, nums=16601
label T: precision=0.9821088674665337, recall=0.9815552938077345, f1=0.9818270026344283, nums=198702
label LOC: precision=0.9660644216911897, recall=0.955285194398909, f1=0.960639571216908, nums=228090


test epoch 6: 100%|██████████| 1790/1790 [03:38<00:00,  8.20it/s, loss=1.7692027740204239e-07]


label PER: precision=0.9550218139904548, recall=0.94289266799463, f1=0.9489134838965714, nums=22519
label ORG: precision=0.953231734349807, recall=0.9934282530480381, f1=0.9729099875535775, nums=1826
label T: precision=0.9729753362895858, recall=0.9925502962070971, f1=0.9826603418120002, nums=22417
label LOC: precision=0.9725115851497161, recall=0.9621458517331613, f1=0.9672959492426567, nums=25519


train epoch 7:  51%|█████     | 8133/16103 [21:14<21:18,  6.23it/s, loss=1.5159697370314972e-07]