In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import json
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.nn import Embedding
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


# 数据预处理

In [2]:
def read_sentence (df: pd.DataFrame, sentences: list):
    for i in range(df.__len__()):
        df.context[i].split()
        sentence = [[j, "O"] for j in df.context[i].split()]
        for label in df.labels[i]:
            if label["end_position"] is not None :
                for sp in label["start_position"]:
                    sentence[sp][1] = "B-" + label["entity_label"]
                for sl in label["span_list"]:
                    for idx in range(sl[0][0] + 1, sl[0][1] + 1):
                        sentence[idx][1] = "I-" + label["entity_label"]
        sentences.append([df.context[i], [i[1] for i in sentence]])

In [3]:
## 加载所有句子，用于构建总词典
all_sentences = []
df = pd.read_json("/Users/zhangyuyao/Library/Containers/com.tencent.xinWeChat/Data/Library/Application Support/com.tencent.xinWeChat/2.0b4.0.9/3b589b8fb0516c8e072f981ad5ce1b32/Message/MessageTemp/9a8944b8b79ab3df4a4f9e99ebae488b/File/第一次作业布置22.09.27/data/conll03/train.json")
read_sentence(df, all_sentences)
df = pd.read_json("/Users/zhangyuyao/Library/Containers/com.tencent.xinWeChat/Data/Library/Application Support/com.tencent.xinWeChat/2.0b4.0.9/3b589b8fb0516c8e072f981ad5ce1b32/Message/MessageTemp/9a8944b8b79ab3df4a4f9e99ebae488b/File/第一次作业布置22.09.27/data/conll03/test.json")
read_sentence(df, all_sentences)
df = pd.read_json("/Users/zhangyuyao/Library/Containers/com.tencent.xinWeChat/Data/Library/Application Support/com.tencent.xinWeChat/2.0b4.0.9/3b589b8fb0516c8e072f981ad5ce1b32/Message/MessageTemp/9a8944b8b79ab3df4a4f9e99ebae488b/File/第一次作业布置22.09.27/data/conll03/dev.json")
read_sentence(df, all_sentences)
all_sentences.__len__()

20741

In [4]:
words = set()
for sentence in all_sentences:
    words_list = sentence[0].split()
    for word in words_list :
        words.add(word)
len(words)

30288

## 构建词典

In [6]:
#tokenize时的字典
words_dict = {}
for i, word in enumerate(words):
    words_dict[word] = i + 1
words_dict["<PAD>"] = 0

labels_dict = {
    'O': 1,
    'B-ORG': 2,
    'I-ORG': 3,
    'B-PER': 4,
    'I-PER': 5,
    'B-LOC': 6,
    'I-LOC': 7,
    'B-MISC': 8,
    'I-MISC': 9,
    '<PAD>': 0
}

## 反向映射，便于输出结果
words_dict2 = {}
for i, word in enumerate(words):
    words_dict2[i + 1] = word
words_dict2[0] = "<PAD>"
labels_dict2 = {
    1: 'O',
    2: 'B-ORG',
    3: 'I-ORG',
    4: 'B-PER',
    5: 'I-PER',
    6: 'B-LOC',
    7: 'I-LOC',
    8: 'B-MISC',
    9: 'I-MISC',
    0: '<PAD>'
}

## 构建数据集

In [7]:
def data_load(sentences):
    data = []
    for sentence in sentences:
        input = [words_dict[i] for i in sentence[0].split()]
        label = [labels_dict[i] for i in sentence[1]]
        pad_num = 124 - input.__len__()
        for i in range(pad_num):
            input.append(words_dict["<PAD>"])
            label.append(labels_dict["<PAD>"])
        data.append((torch.LongTensor(input), torch.LongTensor(label)))
    return data

def data2Tensor(data_list: list, mode: str = "train"):
    df = pd.read_json("/Users/zhangyuyao/Library/Containers/com.tencent.xinWeChat/Data/Library/Application Support/com.tencent.xinWeChat/2.0b4.0.9/3b589b8fb0516c8e072f981ad5ce1b32/Message/MessageTemp/9a8944b8b79ab3df4a4f9e99ebae488b/File/第一次作业布置22.09.27/data/conll03/{}.json".format(mode))
    read_sentence(df, data_list)
    data = data_load(data_list)
    data_tensor, label_tensor = data[0][0].view(1,-1), data[0][1].view(1,-1)
    for i in data[1:]:
        data_tensor = torch.concat((data_tensor, i[0].view(1, -1)), dim = 0)
        label_tensor = torch.concat((label_tensor, i[1].view(1, -1)), dim = 0)
    return data_tensor, label_tensor

In [8]:
train_sentences = []
data_tensor, label_tensor = data2Tensor(train_sentences, mode="train")
data_tensor.size(), label_tensor.size()

(torch.Size([14040, 124]), torch.Size([14040, 124]))

In [248]:
# 保存训练数据
input = data_tensor.numpy()
np.save("input_train.npy", input)
label = label_tensor.numpy()
np.save("label_train.npy", label)

In [74]:
## test_data
test_sentences = []
data_tensor, label_tensor = data2Tensor(test_sentences, mode="test")
data_tensor.size(), label_tensor.size()

(torch.Size([3452, 124]), torch.Size([3452, 124]))

In [250]:
# 保存测试数据
input = data_tensor.numpy()
np.save("input_test.npy", input)
label = label_tensor.numpy()
np.save("label_test.npy", label)

In [241]:
dev_sentences = []
data_tensor, label_tensor = data2Tensor(dev_sentences, mode="dev")
input = data_tensor.numpy()
np.save("input_dev.npy", input)
label = label_tensor.numpy()
np.save("label_dev.npy", label)

In [9]:
data_tensor = torch.LongTensor(np.load("input_train.npy", allow_pickle=True))
label_tensor = torch.LongTensor(np.load("label_train.npy", allow_pickle=True))
data_tensor.size()
data_tensor.size(), label_tensor.size()

(torch.Size([14040, 124]), torch.Size([14040, 124]))

In [400]:
data_tensor = torch.LongTensor(np.load("input_test.npy", allow_pickle=True))
label_tensor = torch.LongTensor(np.load("label_test.npy", allow_pickle=True))
data_tensor.size()
data_tensor.size(), label_tensor.size()

(torch.Size([3452, 124]), torch.Size([3452, 124]))

# 模型搭建

# 双向LSTM模型

In [13]:
class dataset_generator(Dataset):
    def __init__(self, data_tensor, label_tensor):
        super().__init__()
        self.input = data_tensor
        self.output = label_tensor
        self.len = data_tensor.size()[0]
        
    def __getitem__(self, index):
        return self.input[index], self.output[index]
 
    def __len__(self):
        return self.len

class BiLSTM(nn.Module):
    def __init__(self, dict_size, embedding_size, hidden_size, output_size):
        """初始化参数：
            dict_size：字典的大小
            embedding_size：词向量的维数
            hidden_size：隐向量的维数
            output_size：标签的维数
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.embedding = Embedding(dict_size, embedding_size, padding_idx = dict_size - 1)
        self.bilstm = nn.LSTM(embedding_size, hidden_size,
                              batch_first = True,
                              bidirectional = True)
        self.linear = nn.Linear(hidden_size * 2, output_size)
        
    def forward(self, input_tensor: torch.LongTensor, seq_length, hidden_cell = None):
        batch_size = input_tensor.size()[0]
        
        if hidden_cell is None:
            hidden_cell = (torch.zeros(2, batch_size, self.hidden_size), 
                           torch.zeros(2, batch_size, self.hidden_size))
        
        embbeding = self.embedding(input_tensor)
        embbeding = pack_padded_sequence(embbeding, seq_length,
                                         batch_first = True,
                                         enforce_sorted = False)
        
        lstm_out, hidden_cell = self.bilstm(embbeding, hidden_cell)
        lstm_out, _ = pad_packed_sequence(lstm_out, 
                                          batch_first = True,
                                          total_length = 124)
        
        output = self.linear(lstm_out)
        return output, hidden_cell
    
    def train(self, epochs: int, trainDataLoader: DataLoader):
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(self.parameters(), lr = 3)
        for epoch in tqdm(range(epochs)):
            Loss = 0
            for data in trainDataLoader:
                optimizer.zero_grad()
                input, output = data
                seq_length = torch.count_nonzero(output, dim = -1)
                pred_output, _ = self(input, seq_length)
                loss = criterion(pred_output.view(-1, 10), output.view(-1))
                Loss += loss.item()
                loss.backward()
                optimizer.step()
            optimizer.step()
            if (epoch + 1) % 5 == 0:
                print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.4f}'.format(Loss))
                
    def test(self, data_tensor):
        seq_length = torch.count_nonzero(label_tensor, dim = -1)
        output, _ = self(data_tensor, seq_length)
        output = torch.argmax(output, dim = -1)
        return output

## 模型训练

In [14]:
dataset = dataset_generator(data_tensor, label_tensor)
trainDataLoader = DataLoader(dataset, batch_size = 32, shuffle = True)
dataset.__len__()

14040

In [15]:
bilstm = BiLSTM(dict_size = len(words_dict),
                embedding_size = 512,
                hidden_size = 256,
                output_size = len(labels_dict)
                )

In [16]:
bilstm.train(epochs = 100, trainDataLoader = trainDataLoader)

  5%|▌         | 5/100 [03:09<1:01:43, 38.98s/it]

Epoch: 0005 loss = 6.6087


 10%|█         | 10/100 [06:57<1:06:59, 44.66s/it]

Epoch: 0010 loss = 1.5170


 15%|█▌        | 15/100 [10:49<1:06:00, 46.59s/it]

Epoch: 0015 loss = 0.4950


 20%|██        | 20/100 [14:44<1:02:44, 47.06s/it]

Epoch: 0020 loss = 0.2488


 25%|██▌       | 25/100 [18:32<57:22, 45.90s/it]  

Epoch: 0025 loss = 0.1594


 30%|███       | 30/100 [22:26<54:38, 46.84s/it]

Epoch: 0030 loss = 0.1141


 35%|███▌      | 35/100 [26:23<51:13, 47.28s/it]

Epoch: 0035 loss = 0.0890


 40%|████      | 40/100 [30:19<48:05, 48.10s/it]

Epoch: 0040 loss = 0.0720


 45%|████▌     | 45/100 [34:39<47:07, 51.42s/it]

Epoch: 0045 loss = 0.0608


 50%|█████     | 50/100 [38:55<42:41, 51.23s/it]

Epoch: 0050 loss = 0.0523


 55%|█████▌    | 55/100 [43:03<37:03, 49.42s/it]

Epoch: 0055 loss = 0.0452


 60%|██████    | 60/100 [47:06<32:26, 48.67s/it]

Epoch: 0060 loss = 0.0417


 65%|██████▌   | 65/100 [51:19<29:30, 50.57s/it]

Epoch: 0065 loss = 0.0359


 70%|███████   | 70/100 [55:35<25:35, 51.18s/it]

Epoch: 0070 loss = 0.0326


 75%|███████▌  | 75/100 [59:46<21:08, 50.73s/it]

Epoch: 0075 loss = 0.0300


 80%|████████  | 80/100 [1:04:00<16:54, 50.72s/it]

Epoch: 0080 loss = 0.0275


 85%|████████▌ | 85/100 [1:08:13<12:40, 50.72s/it]

Epoch: 0085 loss = 0.0257


 90%|█████████ | 90/100 [1:12:31<08:32, 51.29s/it]

Epoch: 0090 loss = 0.0240


 95%|█████████▌| 95/100 [1:16:44<04:12, 50.59s/it]

Epoch: 0095 loss = 0.0221


100%|██████████| 100/100 [1:21:01<00:00, 48.62s/it]

Epoch: 0100 loss = 0.0208





In [17]:
# 保存模型参数
torch.save(bilstm.state_dict(), "bilstm.pkl")

In [19]:
new_model = BiLSTM(dict_size = len(words_dict),
                embedding_size = 512,
                hidden_size = 256,
                output_size = len(labels_dict)
                )
new_model.load_state_dict(torch.load("bilstm.pkl"))

<All keys matched successfully>

## 性能评估

In [20]:
data_tensor = torch.LongTensor(np.load("input_test.npy", allow_pickle=True))
label_tensor = torch.LongTensor(np.load("label_test.npy", allow_pickle=True))

In [21]:
pred_labels_lists = []
pred_labels_tensor = bilstm.test(data_tensor)
seq_length = torch.count_nonzero(label_tensor, dim = -1).tolist()
for i, j in enumerate(seq_length):
    tmp_sentence, tmp_label = data_tensor[i][:j].tolist(), pred_labels_tensor[i][:j].tolist()
    res = []
    for i, j in zip(tmp_sentence, tmp_label):
        res.append(labels_dict2[j])
    pred_labels_lists.append(res)

In [22]:
len(pred_labels_lists)

3452

# 隐马尔可夫模型


## 构建训练数据

In [320]:
train_sentences = []
df = pd.read_json("/Users/zhangyuyao/Library/Containers/com.tencent.xinWeChat/Data/Library/Application Support/com.tencent.xinWeChat/2.0b4.0.9/3b589b8fb0516c8e072f981ad5ce1b32/Message/MessageTemp/9a8944b8b79ab3df4a4f9e99ebae488b/File/第一次作业布置22.09.27/data/conll03/train.json")
read_sentence(df, train_sentences)

## 构建用于隐马尔可夫模型的词典

In [28]:
words = set()
for sentence in all_sentences:
    words_list = sentence[0].split()
    for word in words_list :
        words.add(word)

words_dict_hmm = {}

for i, word in enumerate(words):
    words_dict_hmm[word] = i
    
labels_dict_hmm = {
    'O': 0,
    'B-ORG': 1,
    'I-ORG': 2,
    'B-PER': 3,
    'I-PER': 4,
    'B-LOC': 5,
    'I-LOC': 6,
    'B-MISC': 7,
    'I-MISC': 8,
}

## 模型构建

In [29]:
class HMM(object):
    def __init__(self, N, M):
        
        self.N = N 
        self.M = M
        
        # 0. 初始化转移概率矩阵、估计观测概率矩阵、初始状态概率向量
        self.A = np.zeros((N, N))
        self.B = np.zeros((N, M))
        self.pi = np.zeros(N)
    
    def train(self, data_list: list, words_dict: dict, labels_dict: dict):
        """use corpus to train the hmm model.

        Args:
            data_list (list): each item is a list, item[0] is sentence and item[1] is label_list
            words_dict (dict): the words_dict, mapping word to its token
            labels_dict (dict): the labels_dict mapping label to its token
        """
        
        ## 1.估计转移概率矩阵
        for data in data_list:
            for i in range(len(data[1]) - 1):
                current_id = labels_dict[data[1][i]]
                next_id = labels_dict[data[1][i + 1]]
                self.A[current_id][next_id] += 1
        self.A[self.A == 0.] = 1e-10
        self.A = self.A / self.A.sum(axis = 1, keepdims = True)
        
        ## 2.估计观测概率矩阵
        ## 实质为根据频率派思想估计每个状态对应观测到各个观察值（词语）的概率
        for data in data_list:
            label_list = data[1]
            word_list = data[0].split()
            for label, word in zip(label_list, word_list):
                label_id = labels_dict[label]
                word_id = words_dict[word]
                self.B[label_id][word_id] += 1
        self.B[self.B == 0.] = 1e-10
        self.B = self.B / self.B.sum(axis = 1, keepdims = True)
        
        ## 3.估计初始状态概率
        ## 实质为根据频率派思想估计每个标签作为一个序列的首的概率
        for data in data_list:
            self.pi[labels_dict[data[1][0]]] += 1
        self.pi[self.pi == 0.] = 1e-10
        self.pi = self.pi / self.pi.sum()
    
    def decoding(self, sentence: str,
                 words_dict: dict = words_dict_hmm,
                 labels_dict: dict = labels_dict_hmm):
        """use Viterbi algorithm to decode the hidden_state into label_sequence

        Args:
            sentence (str): default.
            words_dict (dict): the words_dict, mapping word to its token
            labels_dict (dict): the labels_dict, mapping label to its token
        """
        A = np.log(self.A)
        B = np.log(self.B)
        Bt = B.T
        Pi = np.log(self.pi)
        word_list = sentence.split()
        seq_length = len(word_list)
        viterbi_matrix = np.zeros((self.N, seq_length))
        backpointer = torch.zeros(self.N, seq_length).long()
        start_idx = words_dict.get(word_list[0], None)
        
        if start_idx is None:
            bt = np.log(np.ones(self.N) / self.N)
        else:
            bt = Bt[start_idx]
            
        viterbi_matrix[:, 0] = Pi + bt
        backpointer[:, 0] = -1
        
        for i in range(1, seq_length):
            word_idx = words_dict.get(word_list[i], None)
            
            if word_idx is None:
                bt = np.log(np.ones(self.N) / self.N)
            else:
                bt = Bt[word_idx]
            for label_idx in range(self.N):
                max_prob = np.max(viterbi_matrix[:, i - 1] + A[:, label_idx], axis = 0)
                max_idx = np.argmax(viterbi_matrix[:, i - 1] + A[:, label_idx], axis = 0)
                viterbi_matrix[label_idx, i] = max_prob + bt[label_idx]
                backpointer[label_idx, i] = max_idx
        
        best_path_prob = np.max(viterbi_matrix[:, seq_length - 1], axis=0)
        best_path_pointer = np.argmax(viterbi_matrix[:, seq_length - 1], axis=0)
        best_path_pointer = best_path_pointer.item()
        best_path = [best_path_pointer]
        
        for back_step in range(seq_length - 1, 0, -1):
            best_path_pointer = backpointer[best_path_pointer, back_step]
            best_path_pointer = best_path_pointer.item()
            best_path.append(best_path_pointer)

        id2label = dict((id, tag) for tag, id in labels_dict.items())
        label_list = [id2label[id] for id in reversed(best_path)]
        
        return word_list, label_list


## 评价指标预测及输出
def ent_predict(label_list):
    """

    Args:
        label_list (list): the item of this list is label of entity

    Returns:
        Tuple(list, str): the list of entity format and string to output
    """
    span_list_ORG = []
    span_list_PER = []
    span_list_LOC = []
    span_list_MISC = []
    res = []
    for i, label in enumerate(label_list):
        if label == "B-ORG":
            for j, label in enumerate(label_list[i + 1:]):
                if label != "I-ORG":
                    span_list_ORG.append([i, i + j])
                    break
        
        if label == "B-PER":
            for j, label in enumerate(label_list[i + 1:]):
                if label != "I-PER":
                    span_list_PER.append([i, i + j])
                    break
                
        if label == "B-LOC":
            for j, label in enumerate(label_list[i + 1:]):
                if label != "I-LOC":
                    span_list_LOC.append([i, i + j])
                    break
        
        if label == "B-MISC":
            for j, label in enumerate(label_list[i + 1:]):
                if label != "I-MISC":
                    span_list_MISC.append([i, i + j])
                    break
        
    for item in span_list_ORG:
        res.append(('ORG', item[0], item[1]))
    
    for item in span_list_PER:
        res.append(('PER', item[0], item[1]))
    
    for item in span_list_LOC:
        res.append(('LOC', item[0], item[1]))
        
    for item in span_list_MISC:
        res.append(('MISC', item[0], item[1]))
    
    return res, "; ".join([str(it) for it in res])

## 模型训练

In [30]:
hmm = HMM(9, len(words_dict_hmm))
hmm.train(train_sentences, words_dict_hmm, labels_dict_hmm)

# 模型测试与评估

In [31]:
def Metrics(predict_list, golen_list):
    a = 0
    b = 0
    c = 0
    for res1, res2 in zip (predict_list, golen_list):
        for res_1_item in res1:
            if res_1_item in res2:
                a += 1
            else:
                b += 1
        for res_2_item in res2:
            if res_2_item  not in res1:
                c += 1
    p = a / (a + c)
    r = a / (a + b)
    f1 = 2 * p * r / (p + r)
    return p, r, f1

In [32]:
test_sentences = []
df = pd.read_json("/Users/zhangyuyao/Library/Containers/com.tencent.xinWeChat/Data/Library/Application Support/com.tencent.xinWeChat/2.0b4.0.9/3b589b8fb0516c8e072f981ad5ce1b32/Message/MessageTemp/9a8944b8b79ab3df4a4f9e99ebae488b/File/第一次作业布置22.09.27/data/conll03/test.json")
read_sentence(df, test_sentences)

In [33]:
golen_list = []
for i in range(len(df)):
    tmp = []
    label_items = df.loc[i, "labels"]
    for item in label_items: 
        name = item["entity_label"]
        if len(item["span_list"]) > 0:
            for span in item["span_list"]:
                start = span[0][0]
                end = span[0][1]
                tmp.append((name, start, end))
    golen_list.append(tmp)

In [34]:
predict_list_hmm = []
for i in range(len(df)):
    _ , label_list = hmm.decoding(df.loc[i, "context"])
    result , _ = ent_predict(label_list)
    predict_list_hmm.append(result)

In [35]:
predict_list_bilstm = []
for i in range(len(df)):
    res , _ = ent_predict(pred_labels_lists[i])
    predict_list_bilstm.append(res)

In [426]:
predict_list_bilstm

[[('PER', 2, 2), ('PER', 9, 9)],
 [],
 [('LOC', 0, 0), ('LOC', 2, 3)],
 [('LOC', 0, 0), ('LOC', 15, 15), ('MISC', 6, 7)],
 [('ORG', 1, 1)],
 [('PER', 18, 19), ('LOC', 0, 0), ('LOC', 16, 16), ('MISC', 34, 34)],
 [('PER', 1, 1)],
 [('MISC', 2, 2), ('MISC', 8, 9)],
 [('MISC', 3, 4)],
 [('LOC', 11, 11), ('LOC', 26, 26)],
 [('PER', 0, 1)],
 [('LOC', 7, 7)],
 [('PER', 0, 0), ('PER', 1, 1), ('LOC', 27, 27)],
 [('MISC', 0, 0)],
 [('LOC', 0, 0), ('MISC', 6, 6), ('MISC', 18, 18)],
 [('PER', 0, 0)],
 [('PER', 2, 3), ('LOC', 0, 0)],
 [('MISC', 1, 1)],
 [],
 [('ORG', 16, 16), ('LOC', 0, 0), ('LOC', 8, 8), ('MISC', 5, 6)],
 [('PER', 0, 0),
  ('LOC', 0, 0),
  ('LOC', 1, 1),
  ('LOC', 3, 3),
  ('LOC', 5, 6),
  ('LOC', 9, 9)],
 [],
 [('ORG', 0, 1)],
 [('PER', 0, 0)],
 [('LOC', 0, 0)],
 [('LOC', 6, 6)],
 [('PER', 4, 5), ('LOC', 30, 30)],
 [('PER', 15, 15), ('LOC', 23, 23)],
 [('LOC', 19, 19), ('LOC', 23, 23), ('MISC', 7, 8)],
 [('PER', 0, 0)],
 [('PER', 12, 12), ('MISC', 4, 5)],
 [],
 [],
 [('PER', 2, 3

In [36]:
p, r, f1 = Metrics(predict_list_hmm, golen_list)
p, r, f1

(0.5761871013465627, 0.6407881773399015, 0.6067730198712565)

In [38]:
p, r, f1 = Metrics(predict_list_bilstm, golen_list)
p, r, f1

(0.5914245216158752, 0.7402971834109559, 0.6575396434551364)

# 输出结果

In [39]:
def resOutput(sentences: list, model = hmm, mode: str = "hmm", form: str = "BIO"):
    if form == "BIO":
        with open("{}_BIO_results.txt".format(mode), "w") as f:
            for sentence in sentences:
                word_list, label_list = model.decoding(sentence[0])
                assert len(word_list) == len(label_list)
                for word, label in zip(word_list, label_list):
                    f.write(word + "\t" + label + "\n")
    else:
        label_output = []
        for i in range(len(df)):
            out = {}
            out["text"] = df.loc[i, "context"]
            
            _ , label_list = model.decoding(df.loc[i, "context"])
            _ , out["ent_predict"] = ent_predict(label_list)
            
            tmp = []
            label_items = df.loc[i, "labels"]
            for item in label_items: 
                name = item["entity_label"]
                if len(item["span_list"]) > 0:
                    for span in item["span_list"]:
                        start = span[0][0]
                        end = span[0][1]
                        tmp.append((name, start, end))
            out["golen_label"] = "; ".join([str(it) for it in tmp])
            
            label_output.append(out)
            
        with open('{}_Ent_results.json'.format(mode),'w') as f:
            json.dump(label_output, f, indent = 2)

In [40]:
resOutput(test_sentences, model=hmm)
resOutput(test_sentences, model=hmm, form="Ent")

In [41]:
def resOutput2(sentences: list, model = bilstm, form: str = "BIO"):
    sentence_list = [i[0] for i in sentences]
    if form == "BIO":
        with open("bilstm_BIO_results.txt", "w") as f:
            for i, sentence in enumerate(sentence_list):
                words_list = sentence.split()
                labels_list = pred_labels_lists[i]
                assert len(words_list) == len(labels_list)
                for word, label in zip(words_list, labels_list):
                    f.write(word + "\t" + label + "\n")
    else:
        label_output = []
        for i in range(len(df)):
            out = {}
            out["text"] = df.loc[i, "context"]
            
            label_list = pred_labels_lists[i]
            _ , out["ent_predict"] = ent_predict(label_list)
            
            tmp = []
            label_items = df.loc[i, "labels"]
            for item in label_items: 
                name = item["entity_label"]
                if len(item["span_list"]) > 0:
                    for span in item["span_list"]:
                        start = span[0][0]
                        end = span[0][1]
                        tmp.append((name, start, end))
            out["golen_label"] = "; ".join([str(it) for it in tmp])
            
            label_output.append(out)
            
        with open('bilstm_Ent_results.json','w') as f:
            json.dump(label_output, f, indent = 2)

In [42]:
resOutput2(test_sentences)
resOutput2(test_sentences, form="Ent")