In [1]:
import ast
import json
import operator
from dataclasses import dataclass, field

import torch
from transformers import AutoTokenizer

def _described(default, help_text):
    """Return a dataclass field with ``default`` and a ``{"help": help_text}`` metadata entry."""
    return field(default=default, metadata={"help": help_text})


@dataclass
class Config:
    """Paths and runtime settings read by the inference cells of this notebook.

    Every attribute carries a ``help`` string in its dataclass field metadata.
    """

    # Location of the test data (JSON file).
    test_path: str = _described("../data/test_dataset_A.json", "测试数据位置")
    # Location of the file that stores the list of labels.
    id2label_path: str = _described("../data/id2label.txt", "标签集合的列表存放位置")
    # Location of the pretrained model (used to load the tokenizer).
    pretrained_model_path: str = _described(
        "/tf/FangGexiang/2.SememeV2/pretrained_model/chinese-bert-wwm-ext/",
        "预训练模型的存放位置",
    )
    # Maximum length of a sentence passed to the tokenizer.
    max_len: int = _described(512, "每句最长大小")
    # Device string that tensors and the model are moved to.
    device: str = _described("cuda:3", "GPU")
    # Location where the trained model was saved.
    saved_model_path: str = _described(
        "/tf/FangGexiang/3.EventDetectation/7.V7/model_saved/",
        "训好的模型的保存位置",
    )
    # Location the prediction results are written to.
    save_path: str = _described("./result.json", "结果保存位置")

config = Config()

# The first line of the label file is parsed as a Python literal (expected: a
# list of label strings, since it is enumerated below to build `label2id`).
# `ast.literal_eval` is used instead of `eval` so that the file can only
# supply data, never executable code, and the `with` block closes the handle.
with open(config.id2label_path, "r", encoding="utf-8") as id2label_file:
    id2label = ast.literal_eval(id2label_file.readline())

# Reverse lookup: label string -> integer id (its position in `id2label`).
label2id = { label: i for i, label in enumerate(id2label) }

# Tokenizer is loaded from the pretrained-model directory, not the fine-tuned one.
tokenizer = AutoTokenizer.from_pretrained(config.pretrained_model_path)

In [2]:
# ======== Set Seed ======== #
import random
import numpy as np
import torch

seed = 42
# Seed each random number generator in turn with the same value:
# Python's `random`, NumPy, PyTorch, and PyTorch's CUDA generator.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_rng(seed)
# ========================= #

In [3]:
from torch.utils.data import Dataset, DataLoader

class myDataset(Dataset):
    """In-memory dataset built from the JSON file at ``config.test_path``.

    The file must contain a list of objects with the keys ``"id"``,
    ``"sentence"`` (a string) and ``"tokens"`` (the segmented sentence as a
    list of strings).

    Each item is a list ``[idx, sentence, participle_sentence, part_token_ids]``:

    * ``idx`` - the record's ``"id"`` value,
    * ``sentence`` - the stripped sentence split into single characters,
    * ``participle_sentence`` - the ``"tokens"`` list, unchanged,
    * ``part_token_ids`` - one inclusive ``[first, last]`` character-position
      pair per token.
    """

    def __init__(self, config):
        """Load and pre-process every record of ``config.test_path``."""
        self.data = list()

        # Read the whole file inside a context manager so the handle is closed
        # as soon as parsing finishes.
        with open(config.test_path, "r", encoding="utf-8") as test_file:
            raw_items = json.load(test_file)

        for data_item in raw_items:
            idx = data_item["id"]
            sentence = list(data_item["sentence"].strip())
            participle_sentence = data_item["tokens"]

            # Positions are counted from 1, not 0 - presumably to leave index 0
            # for the special token the tokenizer prepends; confirm against the model.
            begin, part_token_ids = 1, list()
            for token in participle_sentence:
                part_token_ids.append([begin, begin + len(token) - 1])
                begin += len(token)

            self.data.append([idx, sentence, participle_sentence, part_token_ids])

    def __getitem__(self, idx):
        """Return the pre-processed record at position ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Return the number of records loaded."""
        return len(self.data)
    
# Load the whole test set into memory once, from the path held in `config`.
test_dataset = myDataset(config)

In [4]:
def pad_sequence_ls(ls):
    """Right-pad every span list to the longest one and build a matching mask.

    Args:
        ls: one entry per sample; each entry is a list of ``[start, end]`` pairs.

    Returns:
        Two ``torch.long`` tensors: the padded spans, and a mask holding 0
        wherever a span equals ``[0, 0]`` (padding as well as any ``[0, 0]``
        already present in the input) and 1 everywhere else.
    """
    longest = max(len(spans) for spans in ls)
    padded = [spans + [[0, 0]] * (longest - len(spans)) for spans in ls]
    mask = [[int(span != [0, 0]) for span in spans] for spans in padded]

    return torch.tensor(padded, dtype=torch.long), torch.tensor(mask, dtype=torch.long)

def myFn(batch):
    """Collate dataset items into model-ready inputs.

    Args:
        batch: list of ``[idx, sentence, participle_sentence, part_token_ids]``
            items as produced by the dataset.

    Returns:
        A list ``[idx, batch_input_encoded, spans, mask, participle_sentence]``
        where ``batch_input_encoded`` is the tokenizer output as a dict of
        tensors, ``spans`` / ``mask`` come from ``pad_sequence_ls``, and all
        tensors have been moved to ``config.device``. ``idx`` and
        ``participle_sentence`` are plain per-sample lists.
    """
    # Transpose the batch: one list per field instead of one list per sample.
    idx, batch_sentence, participle_sentence, batch_entities_ids = (
        list(column) for column in zip(*batch)
    )

    # Surround every sample's spans with a [0, 0] placeholder at both ends.
    batch_entities_ids = [[[0, 0]] + spans + [[0, 0]] for spans in batch_entities_ids]

    # NOTE(review): truncation only applies to the tokenizer input; the span
    # indices above are not truncated to `config.max_len`.
    tokenizer_options = dict(
        padding=True,
        truncation=True,
        max_length=config.max_len,
        is_split_into_words=True,
        return_tensors="pt",
    )
    encoded = tokenizer(batch_sentence, **tokenizer_options)
    batch_entities_ids_padded, entities_mask = pad_sequence_ls(batch_entities_ids)

    target_device = config.device
    batch_input_encoded = {name: tensor.to(target_device) for name, tensor in encoded.items()}

    return [
        idx,
        batch_input_encoded,
        batch_entities_ids_padded.to(target_device),
        entities_mask.to(target_device),
        participle_sentence,
    ]

# One sample per batch, in dataset order; `myFn` tokenises and pads each batch.
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=myFn)

In [5]:
import sys
# Extend the import search path with `../training` before the import below;
# presumably that directory contains the `model` package - confirm.
sys.path.append("../training")

from model.bert import BertForSpanClassification

# Load the saved model from `config.saved_model_path`, passing the notebook
# config as `model_args`, then move it to the configured device.
model = BertForSpanClassification.from_pretrained(config.saved_model_path, model_args=config).to(config.device)

In [6]:
import torch
from tqdm import tqdm

def _candidate_spans(prediction, pre_score, outside_id):
    """Collect candidate trigger spans from per-token predictions.

    Args:
        prediction: predicted label id for every token.
        pre_score: probability of the predicted label for every token.
        outside_id: label id that marks "no event" tokens.

    Returns:
        A list of ``[start, end, mean_score]`` entries with ``end`` exclusive.
        Every token whose label differs from ``outside_id`` opens a candidate
        that extends to the end of its run of such tokens, so a run of length
        k yields k candidates (the run itself and each of its suffixes).
        Adjacent tokens are merged regardless of which non-outside label they
        carry.
    """
    candidates = []
    n_tokens = len(prediction)
    for start, label in enumerate(prediction):
        if label == outside_id:
            continue
        # Extend to the end of the run. Stopping at `n_tokens` as well as at
        # the next outside token ensures a run that reaches the end of the
        # sentence is recorded in full instead of being dropped.
        end = start + 1
        while end < n_tokens and prediction[end] != outside_id:
            end += 1
        run_scores = pre_score[start:end]
        candidates.append([start, end, sum(run_scores) / len(run_scores)])
    return candidates


model.eval()
results = list()
outside_id = int(label2id["O"])
test_bar = tqdm(test_dataloader)
with torch.no_grad():
    for idx, batch_input_encoded, batch_entities_ids_padded, entities_mask, participle_sentence in test_bar:
        logits = model(batch_input_encoded, batch_entities_ids_padded, entities_mask=entities_mask).logits
        softmax_logits = torch.softmax(logits, dim=-1)

        # Drop the first and last rows - presumably the [0, 0] placeholder
        # spans that the collate function adds around each sample; this assumes
        # `logits` is (num_spans, num_labels) for the single-sample batch - TODO confirm.
        pre_score, prediction = torch.max(softmax_logits[1:-1], dim=-1)
        pre_score, prediction = pre_score.tolist(), prediction.tolist()

        record = _candidate_spans(prediction, pre_score, outside_id)

        return_dict = { "id": int(idx[0]), "event_mention": dict() }
        # Gate on the candidates themselves rather than on `any(prediction)`,
        # which is only equivalent when the "O" label has id 0.
        if record:
            # Highest mean score wins; among equal scores the last candidate is kept.
            selected = sorted(record, key=lambda x: x[-1])[-1]
            span_start, span_end = selected[0], selected[1]

            # Batch size is 1, so the sample's tokens are at index 0.
            tokens = participle_sentence[0]
            text = "".join(tokens[token_idx] for token_idx in range(span_start, span_end))

            return_dict["event_mention"].update({
                "trigger": {
                    "text": text,
                    # Token-level offsets, end exclusive.
                    "offset": [span_start, span_end]
                },
                # The event type is taken from the first token of the span.
                "event_type": id2label[prediction[span_start]]
            })

        results.append(return_dict)

100%|██████████| 2000/2000 [00:24<00:00, 81.45it/s]


In [7]:
# Serialise all predictions as one JSON document, keeping non-ASCII text readable.
with open(config.save_path, "w", encoding="utf-8") as result_file:
    json.dump(results, result_file, ensure_ascii=False)