In [30]:
from tqdm.notebook import tqdm,trange
import os
class GetData:
    def read(self, data_path):
        data_parts = ['train', 'valid', 'test']
        extension = '.txt'
        dataset = {}
        bar = tqdm(data_parts)
        for data_part in bar:
            bar.set_description("正在读取数据集")
            file_path = os.path.join(data_path, data_part+extension)
            dataset[data_part] = self.read_file(str(file_path))
            if data_part == 'test':
                bar.set_description('数据读取完毕')
        return dataset

    def read_file(self, file_path):
        samples = []
        tokens = []   # 单词
        tags = []     # 实体标注
        with open(file_path,'r', encoding='utf-8') as fb:
            for line in fb:
                line = line.strip('\n')
                if line == '-DOCSTART- -X- -X- O':   # 去除数据头
                    pass
                elif line =='':                      # 一句话结束
                    if len(tokens) != 0:
                        samples.append((tokens, tags))
                        tokens = []
                        tags = []
                else:
                    items = line.split(' ')
                    tokens.append(items[0])
                    tags.append(items[-1])
        return samples

In [31]:
# Author: Robert Guthrie

import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(1)    # 人工设定随机种子以保证相同的初始化参数，实现模型的可复现性。

<torch._C.Generator at 0x15cb5545cb0>

In [32]:
def argmax(vec):  # 给定输入二维序列，取每行（第一维度）的最大值，返回对应索引。
    # return the argmax as a python int
    _, idx = torch.max(vec, 1)
    return idx.item()


def prepare_sequence(seq, to_ix):    # 利用to_ix这个word2id字典，将序列seq中的词转化为数字表示，包装为torch.long后返回
    idxs = [to_ix[w] for w in seq]
    return torch.tensor(idxs, dtype=torch.long)


# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec):                # 函数目的相当于log∑exi 首先取序列中最大值，输入序列是一个二维序列(shape[1,tags_size])。下面的计算先将每个值减去最大值，再取log_sum_exp，最后加上最大值。
    max_score = vec[0, argmax(vec)]
    max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
    return max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))

In [33]:
class BiLSTM_CRF(nn.Module):

    def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim):
        super(BiLSTM_CRF, self).__init__()
        self.embedding_dim = embedding_dim # 词嵌入维度，即输入维度
        self.hidden_dim = hidden_dim   # 隐层维度
        self.vocab_size = vocab_size   # 训练集词典大小
        self.tag_to_ix = tag_to_ix     # 标签索引表
        self.tagset_size = len(tag_to_ix) # 标注 类型数
        print(f'tagset_size={self.tagset_size}')
        self.word_embeds = nn.Embedding(vocab_size, embedding_dim)  # （词嵌入的个数，嵌入维度）
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,         # （输入节点数，隐层节点数，隐层层数，是否双向）
                            num_layers=1, bidirectional=True)       #  hidden_size除以2是为了使BiLSTM的输出维度依然是hidden_size,而不用乘以2

        # Maps the output of the LSTM into tag space.
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)   # （输入x的维度，输出y的维度），将LSTM的输出线性映射到标签空间

        # Matrix of transition parameters.  Entry i,j is the score of
        # transitioning *to* i *from* j.
        self.transitions = nn.Parameter(                            # 转移矩阵，标注j转移到标注i的概率，后期要学习更新
            torch.randn(self.tagset_size, self.tagset_size))

        # These two statements enforce the constraint that we never transfer
        # to the start tag and we never transfer from the stop tag
        self.transitions.data[tag_to_ix[START_TAG], :] = -10000     # 不会有标注转移到开始标注
        self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000      # 结束标注不会转移到其他标注

        self.hidden = self.init_hidden()

    def init_hidden(self):                                          # 初始化隐层（两层，3维）
        return (torch.randn(2, 1, self.hidden_dim // 2),            # (num_layer * num_direction, batch_size)
                torch.randn(2, 1, self.hidden_dim // 2))            # (隐层层数2 * 方向数1， 批大小1， 每层节点数)

    def _forward_alg(self, feats):                                  # 得到所有路径的分数/概率
        # Do the forward algorithm to compute the partition function
        init_alphas = torch.full((1, self.tagset_size), -10000.)    # P，(1, m)维，初始化为-10000
        # START_TAG has all of the score.
        init_alphas[0][self.tag_to_ix[START_TAG]] = 0.

        # Wrap in a variable so that we will get automatic backprop
        forward_var = init_alphas                                   # 前向状态，记录当前t之前的所有路径的分数
  
        # Iterate through the sentence
        for feat in feats:                                          # 动态规划思想，具体见onenote上的笔记
            alphas_t = []  # The forward tensors at this timestep
            for next_tag in range(self.tagset_size):
                # broadcast the emission score: it is the same regardless of
                # the previous tag
                emit_score = feat[next_tag].view(1, -1).expand(1, self.tagset_size)
                # the ith entry of trans_score is the score of transitioning to
                # next_tag from i
                trans_score = self.transitions[next_tag].view(1, -1)
                # The ith entry of next_tag_var is the value for the
                # edge (i -> next_tag) before we do log-sum-exp
                next_tag_var = forward_var + trans_score + emit_score
                # The forward variable for this tag is log-sum-exp of all the
                # scores.
                alphas_t.append(log_sum_exp(next_tag_var).view(1))
            forward_var = torch.cat(alphas_t).view(1, -1)
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        alpha = log_sum_exp(terminal_var)
        return alpha                                                # 返回的是所有路径的分数

    def _get_lstm_features(self, sentence):             # 通过BiLSTM层，输出得到发射分数
        self.hidden = self.init_hidden()
        embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)      # 对输入语句 词嵌入化
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)              # 词嵌入通过lstm网络输出,lstm传入参数之后会自动调用其forward方法
        lstm_out = lstm_out.view(len(sentence), self.hidden_dim)            # 将输出转为2维（原本是3维，但是batch_size=1，可以去掉这一维）
        lstm_feats = self.hidden2tag(lstm_out)                              # 将输出映射到标签空间，得到单词-分数表
        return lstm_feats

    def _score_sentence(self, feats, tags):             # 计算给定路径的分数
        # feats : LSTM的所有输出，发射分数矩阵
        # tags : golden路径的标注序列
        # Gives the score of a provided tag sequence
        score = torch.zeros(1)
        tags = torch.cat([torch.tensor([self.tag_to_ix[START_TAG]], dtype=torch.long), tags])   # 在标注序列最前加上开始标注
        for i, feat in enumerate(feats):                                                        # 计算给定序列的分数，Σ发散分数+Σ转移分数
            score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
        score = score + self.transitions[self.tag_to_ix[STOP_TAG], tags[-1]]
        return score

    def _viterbi_decode(self, feats):
        backpointers = []

        # Initialize the viterbi variables in log space
        init_vvars = torch.full((1, self.tagset_size), -10000.)                      # 初始化forward_var,并且 开始标注 的分数为0,确保一定是从START_TAG开始的,
        init_vvars[0][self.tag_to_ix[START_TAG]] = 0

        # forward_var at step i holds the viterbi variables for step i-1
        forward_var = init_vvars                                                     # forward_var记录每个标签的前向状态得分，即w{i-1}被打作每个标签的对应得分值
        for feat in feats:                                                           # feats是LSTM的输出，每一个feat都是一个词w{i}，feat[tag]就是这个词tag标注的分数
            bptrs_t = []  # holds the backpointers for this step                     # 记录当前词w{i}对应每个标签的最优转移结点
            viterbivars_t = []  # holds the viterbi variables for this step          # 记录当前词各个标签w{i, j}对应的最高得分
                                                                                     # 动态规划：w{i，j}=max{forwar_var + transitions[j]}，词存于bptrs_t中，分数存于viterbivars_t中

            for next_tag in range(self.tagset_size):                                 # 对当前词w{i}的每个标签 运算
                # next_tag_var[i] holds the viterbi variable for tag i at the
                # previous step, plus the score of transitioning
                # from tag i to next_tag.
                # We don't include the emission scores here because the max
                # does not depend on them (we add them in below)
                next_tag_var = forward_var + self.transitions[next_tag]
                best_tag_id = argmax(next_tag_var)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            # Now add in the emission scores, and assign forward_var to the set
            # of viterbi variables we just computed
            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
            backpointers.append(bptrs_t)                                              # 记忆，方便回溯

        # Transition to STOP_TAG
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        best_tag_id = argmax(terminal_var)                                            # 结束标记前的一个词的最高前向状态得分就是最优序列尾
        path_score = terminal_var[0][best_tag_id]

        # Follow the back pointers to decode the best path.
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):                                        # 回溯
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        # Pop off the start tag (we dont want to return that to the caller)
        start = best_path.pop()
        assert start == self.tag_to_ix[START_TAG]  # Sanity check
        best_path.reverse()
        return path_score, best_path

    def neg_log_likelihood(self, sentence, tags):       # CRF的损失函数：-gold分数-logsumexp(所有序列)
        feats = self._get_lstm_features(sentence)                   # 通过BiLSTM层，获得每个 {词-标签}对 的发射分数
        forward_score = self._forward_alg(feats)                    # 根据发射分数计算所有路径的分数
        gold_score = self._score_sentence(feats, tags)              # 传入标注序列真实值，计算语句的真实分数gold_score
        return forward_score - gold_score                           # 返回误差值

    def forward(self, sentence):                                    # 重载前向传播函数，对象传入参数后就会自动调用该函数
        # Get the emission scores from the BiLSTM 
        lstm_feats = self._get_lstm_features(sentence)              # 通过LSTM层得到输出

        # Find the best path, given the features.
        score, tag_seq = self._viterbi_decode(lstm_feats)           # 通过CFR层得到最优路径及其分数
        return score, tag_seq

In [34]:
START_TAG = "<START>"
STOP_TAG = "<STOP>"
EMBEDDING_DIM = 5       # 词嵌入维度
HIDDEN_DIM = 4          # 隐层层数

# Make up some training data
training_data = [(
    "the wall street journal reported today that apple corporation made money".split(),
    "B I I I O O O B I O O".split()
), (
    "georgia tech is a university in georgia".split(),
    "B I O O O O B".split()
)]
ds_rd = GetData()
data = ds_rd.read("./data")
training_data = data['train'][0:10]

word_to_ix = {}                             # 训练集词典 {词——索引}
for sentence, tags in training_data:
    for word in sentence:
        if word not in word_to_ix:
            word_to_ix[word] = len(word_to_ix)

tag_to_ix = {"B-PER": 0, "B-LOC": 1, "B-ORG": 2, "B-MISC": 3,
             "I-PER": 4, "I-LOC": 5, "I-ORG": 6, "I-MISC": 7,
             "O": 8, START_TAG: 9, STOP_TAG: 10}  # 标签词典 {标注——索引}
model = BiLSTM_CRF(len(word_to_ix), tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM)   # 模型BiLSTM-CRF
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-4)       # 优化器：使用SGD更新参数

# Check predictions before training
with torch.no_grad():                                                       # 在训练前测试一次预测结果，和训练后对比
    precheck_sent = prepare_sequence(training_data[0][0], word_to_ix)       # 将第一个训练样本（词序列）转成索引序列
    precheck_tags = torch.tensor([tag_to_ix[t] for t in training_data[0][1]], dtype=torch.long) # 将第一个训练样本的标签转成索引序列
    print(model(precheck_sent))                                             # 输出第一次预测的结果（model(·)自动调用forward函数）

# Make sure prepare_sequence from earlier in the LSTM section is loaded
for epoch in trange(300,desc='模型训练进度'):  # again, normally you would NOT do 300 epochs, it is toy data     # 训练，迭代300次
    bar = tqdm(training_data, leave=False)
    for sentence, tags in bar:
        bar.set_description(f'epoch【{epoch}】')
        # Step 1. Remember that Pytorch accumulates gradients.
        # We need to clear them out before each instance
        model.zero_grad()                                                                   # 每次迭代前梯度清零（因为默认会叠加梯度）

        # Step 2. Get our inputs ready for the network, that is,
        # turn them into Tensors of word indices.
        sentence_in = prepare_sequence(sentence, word_to_ix)                                # 输入：语句转为词索引
        targets = torch.tensor([tag_to_ix[t] for t in tags], dtype=torch.long)              # 真实值：标注序列转为索引

        # Step 3. Run our forward pass.
        loss = model.neg_log_likelihood(sentence_in, targets)                               # 计算误差

        # Step 4. Compute the loss, gradients, and update the parameters by
        # calling optimizer.step()
        loss.backward()                                                                     # 计算当前梯度，反向传播
        optimizer.step()                                                                    # 根据当前梯度更新网络参数

# Check predictions after training
# with torch.no_grad():                                                      # 在训练后预测一次预测，和训练前对比
#     precheck_sent = prepare_sequence(training_data[0][0], word_to_ix)
#     print(model(precheck_sent))
print('traning over!')
torch.save(model,'pre_model.pth')                                  # 保存模型
torch.save(model.state_dict(),'model_params.pth')                  # 保存模型参数

  0%|          | 0/3 [00:00<?, ?it/s]

tagset_size=11
(tensor(23.0285), [4, 3, 0, 3, 0, 3, 0, 3, 0])


模型训练进度:   0%|          | 0/300 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

traning over!


In [50]:
pre_model = torch.load('pre_model.pth')                            # 直接加载模型
with torch.no_grad():
    for i in range(10):
        test_data = data['train'][i]
        In = prepare_sequence(test_data[0],word_to_ix)
        Out = pre_model(In)
        print(f'模型预测输出: {Out}')
        targets = torch.tensor([tag_to_ix[tag] for tag in test_data[1]], dtype=torch.long)
        print(f'真值: {targets}')
        print(f'预测中词性标注错误的个数: {(torch.tensor(Out[1],dtype=torch.long)-targets).sum().item()}')
        print()
        # We got it!

模型预测输出: (tensor(72.6214), [2, 8, 3, 8, 8, 8, 3, 8, 8])
真值: tensor([2, 8, 3, 8, 8, 8, 3, 8, 8])
预测中词性标注错误的个数: 0

模型预测输出: (tensor(10.8929), [0, 4])
真值: tensor([0, 4])
预测中词性标注错误的个数: 0

模型预测输出: (tensor(12.6839), [1, 8])
真值: tensor([1, 8])
预测中词性标注错误的个数: 0

模型预测输出: (tensor(266.2206), [8, 2, 6, 8, 8, 8, 8, 8, 8, 3, 8, 8, 8, 8, 8, 3, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8])
真值: tensor([8, 2, 6, 8, 8, 8, 8, 8, 8, 3, 8, 8, 8, 8, 8, 3, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8])
预测中词性标注错误的个数: 0

模型预测输出: (tensor(262.9335), [1, 8, 8, 8, 8, 2, 6, 8, 8, 8, 0, 4, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 1, 8, 8, 8, 8, 8, 8, 8])
真值: tensor([1, 8, 8, 8, 8, 2, 6, 8, 8, 8, 0, 4, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 1,
        8, 8, 8, 8, 8, 8, 8])
预测中词性标注错误的个数: 0

模型预测输出: (tensor(292.3787), [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 2, 8, 8, 8, 0, 4, 4, 4, 8, 8, 8, 8, 8])
真值: tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 2, 8, 8, 8,
        0, 4, 4, 4, 8, 8, 8, 8, 8])


In [52]:
model2 = BiLSTM_CRF(len(word_to_ix), tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM)           # 读取参数加载模型
model2.load_state_dict(torch.load('model_params.pth'))
with torch.no_grad():
    for i in range(10):
        test_data = data['train'][i]
        In = prepare_sequence(test_data[0],word_to_ix)
        Out = model2(In)
        print(f'模型预测输出: {Out}')
        targets = torch.tensor([tag_to_ix[tag] for tag in test_data[1]], dtype=torch.long)
        print(f'真值: {targets}')
        print(f'预测中词性标注错误的个数: {(torch.tensor(Out[1],dtype=torch.long)-targets).sum().item()}')
        print()
        # We got it!

tagset_size=11
模型预测输出: (tensor(71.8458), [2, 8, 3, 8, 8, 8, 3, 8, 8])
真值: tensor([2, 8, 3, 8, 8, 8, 3, 8, 8])
预测中词性标注错误的个数: 0

模型预测输出: (tensor(10.7027), [0, 4])
真值: tensor([0, 4])
预测中词性标注错误的个数: 0

模型预测输出: (tensor(16.0988), [1, 8])
真值: tensor([1, 8])
预测中词性标注错误的个数: 0

模型预测输出: (tensor(267.9419), [8, 2, 6, 8, 8, 8, 8, 8, 8, 3, 8, 8, 8, 8, 8, 3, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8])
真值: tensor([8, 2, 6, 8, 8, 8, 8, 8, 8, 3, 8, 8, 8, 8, 8, 3, 8, 8, 8, 8, 8, 8, 8, 8,
        8, 8, 8, 8, 8, 8])
预测中词性标注错误的个数: 0

模型预测输出: (tensor(266.1242), [8, 8, 8, 8, 8, 2, 6, 8, 8, 8, 0, 4, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 1, 8, 8, 8, 8, 8, 8, 8])
真值: tensor([1, 8, 8, 8, 8, 2, 6, 8, 8, 8, 0, 4, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 1,
        8, 8, 8, 8, 8, 8, 8])
预测中词性标注错误的个数: 7

模型预测输出: (tensor(290.9359), [8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 2, 8, 8, 8, 0, 4, 4, 4, 8, 8, 8, 8, 8])
真值: tensor([8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 2, 8, 8, 8,
        0, 4, 4, 4, 8

In [24]:
# from tqdm.notebook import trange, tqdm
# from time import sleep
# out_bar = trange(5, desc='训练中',colour='red')
# for i in out_bar:
#     bar = tqdm(range(2), desc='2nd loop', leave=False)
#     for j in bar:
#         bar.set_description(f'epoch【{i}】')
#         sleep(0.5)
#     if i == 4:
#         out_bar.set_description('训练完毕')
        
# #     out_bar.set_description('总进度: ')

训练中:   0%|          | 0/5 [00:00<?, ?it/s]

2nd loop:   0%|          | 0/2 [00:00<?, ?it/s]

2nd loop:   0%|          | 0/2 [00:00<?, ?it/s]

2nd loop:   0%|          | 0/2 [00:00<?, ?it/s]

2nd loop:   0%|          | 0/2 [00:00<?, ?it/s]

2nd loop:   0%|          | 0/2 [00:00<?, ?it/s]

# Json处理

In [4]:
pip install pandas

Collecting pandasNote: you may need to restart the kernel to use updated packages.

  Downloading pandas-1.4.3-cp38-cp38-win_amd64.whl (10.6 MB)
     ---------------------------------------- 10.6/10.6 MB 9.4 MB/s eta 0:00:00
Installing collected packages: pandas
Successfully installed pandas-1.4.3


In [48]:
import pandas as pd
from tqdm import tqdm

In [49]:
file = './data/data.json'
df = pd.read_json(file)
a, b = df.loc[4]
print(type(a))
print(b)
print(len(df)//5*4)
df

<class 'str'>
[['NAME', [3, 14]], ['TICKER', [35, 39]], ['NOTIONAL', [21, 34]]]
2400


Unnamed: 0,text,label
0,Dear 568.763million AXNVF,"[[TICKER, [20, 25]], [NOTIONAL, [5, 19]]]"
1,Buy 703.363thousand HEOFF,"[[TICKER, [20, 25]], [NOTIONAL, [4, 19]]]"
2,May I 927.795hundred RLXXF Put,"[[TICKER, [21, 26]], [NOTIONAL, [6, 20]]]"
3,77.574hundred BRGGF,"[[TICKER, [14, 19]], [NOTIONAL, [0, 13]]]"
4,Hi Mark Romero Can I 66.585million ABST Thank ...,"[[NAME, [3, 14]], [TICKER, [35, 39]], [NOTIONA..."
...,...,...
2995,860.036 PCLOF Buy,"[[TICKER, [8, 13]], [NOTIONAL, [0, 7]]]"
2996,Jennifer Long Can I 956.436thousand CAE call ...,"[[NAME, [1, 14]], [TICKER, [37, 40]], [NOTIONA..."
2997,May I 282.922billion FTRP,"[[TICKER, [21, 25]], [NOTIONAL, [6, 20]]]"
2998,49.527trillion NSCIF sell,"[[TICKER, [15, 20]], [NOTIONAL, [0, 14]]]"


In [50]:
# word_to_ix = {}                             # 训练集词典 {词——索引}
# for sentence, tags in training_data:
#     for word in sentence:
#         if word not in word_to_ix:
#             word_to_ix[word] = len(word_to_ix)
# START_TAG = "<START>"
# STOP_TAG = "<STOP>"
# tag_to_ix = {"B-NAM": 0, "B-TIC": 1, "B-NOT": 2, 
#              "I-NAM": 3, "I-TIC": 4, "I-NOT": 5,
#              "O": 6, START_TAG: 7, STOP_TAG: 8}  # 标签词典 {标注——索引}

def get_tags(label, length):
    tags = ['O'] * length
    for entity_type, [start, end] in label:
        entity_type = entity_type[0:3]
        tags[start] = 'B-' + entity_type
        for i in range(start+1, end):
            tags[i] = 'I-' + entity_type
    return tags

In [58]:
data_size = len(df)
data_set = {'training': [], 'test': []}
tokens = []
tags = []
for i in tqdm(range(0, data_size//5*4), desc = '读取训练数据', position=0):
    text, label = df.loc[i]
    tokens = text.split()
    tags = get_tags(label, len(text))
    data_set['training'].append((tokens, tags))
for i in tqdm(range(data_size//5*4, data_size), desc = '读取测试数据', position=0):
    text, label = df.loc[i]
    tokens = text.split()
    tags = get_tags(label, len(text))
    data_set['test'].append((tokens, tags))
print(len(data_set['training']))
print(len(data_set['test']))

读取训练数据: 100%|█████████████████████████████████████████████████████████████| 2400/2400 [00:00<00:00, 14349.50it/s]
读取测试数据: 100%|███████████████████████████████████████████████████████████████| 600/600 [00:00<00:00, 12278.05it/s]

2400
600





In [47]:
from tqdm import tqdm
def fun():
    j = 0
    for i in range(1000):
        j += 1
qdm(fun()):


0it [00:00, ?it/s][A


TypeError: 'NoneType' object is not iterable

In [61]:
# defaultdict测试
from collections import defaultdict
word_num = -1
def zero():
    word_num += 1
    return word_num
tokens = ['猪','狗','ni','you','what','fuck','猪','ni']
word_to_ix = defaultdict(lambda)
for word in tokens:
    word_to_ix[word] = len(word_to_ix)
word_to_ix

NameError: name 'word_to_ix' is not defined