In [1]:
import torch
import torch.nn as nn

In [3]:
class HMM:
    """Hidden Markov model for sequence labelling, estimated by counting.

    Attributes:
        N: number of hidden states (distinct tags).
        M: number of observation symbols (distinct tokens).
        A: (N, N) transition matrix; after training, A[i][j] is the
            probability of moving from state i to state j.
        B: (N, M) emission matrix; after training, B[i][j] is the
            probability of state i emitting observation j.
        Pi: (N,) initial-state distribution.
        word2id: mapping from token to its column index in B.
        tag2id: mapping from tag to its state index.
    """

    # Pseudo-count substituted for events never seen in training, so that the
    # log-probabilities used while decoding stay finite.
    _SMOOTHING = 1e-10

    def __init__(self, N, M, word2id, tag2id):
        """Create an untrained model.

        Args:
            N: number of states, i.e. how many distinct tags exist.
            M: number of observations, i.e. how many distinct tokens exist.
            word2id: dict mapping each token to an integer id in [0, M).
            tag2id: dict mapping each tag to an integer id in [0, N).
        """
        self.N = N
        self.M = M

        # State transition probability matrix.
        self.A = torch.zeros(N, N)
        # Emission probability matrix: B[i][j] is the probability of
        # generating observation j while in state i.
        self.B = torch.zeros(N, M)
        # Initial state probabilities.
        self.Pi = torch.zeros(N)
        self.word2id = word2id   # token -> id
        self.tag2id = tag2id     # tag -> id

    def train(self, word_lists, tag_lists):
        """Estimate A, B and Pi by maximum likelihood (relative frequencies).

        Every call re-estimates the parameters from scratch using only the
        data passed in; statistics from an earlier call are discarded.

        Args:
            word_lists: list of token sequences, e.g. ['担', '任', '科', '员'].
            tag_lists: list of tag sequences aligned with ``word_lists``,
                e.g. ['O', 'O', 'B-TITLE', 'E-TITLE'].

        Raises:
            AssertionError: if the two corpora, or any aligned pair of
                sequences, differ in length.
            KeyError: if a token or tag is missing from word2id / tag2id.
        """
        assert len(tag_lists) == len(word_lists)

        # Fresh count tensors: accumulating onto self.A/B/Pi would blend raw
        # counts with the normalised probabilities left by a previous call.
        trans_counts = torch.zeros(self.N, self.N)
        emit_counts = torch.zeros(self.N, self.M)
        init_counts = torch.zeros(self.N)

        for tag_list, word_list in zip(tag_lists, word_lists):
            assert len(tag_list) == len(word_list)
            if len(tag_list) == 0:
                # An empty sentence carries no statistics (and has no first tag).
                continue

            tag_ids = [self.tag2id[tag] for tag in tag_list]
            word_ids = [self.word2id[word] for word in word_list]

            # Initial-state counts.
            init_counts[tag_ids[0]] += 1
            # Transition counts over adjacent tag pairs.
            for current_tagid, next_tagid in zip(tag_ids, tag_ids[1:]):
                trans_counts[current_tagid, next_tagid] += 1
            # Emission counts.
            for tag_id, word_id in zip(tag_ids, word_ids):
                emit_counts[tag_id, word_id] += 1

        # Smooth events that never occurred, then normalise to probabilities.
        trans_counts[trans_counts == 0.] = self._SMOOTHING
        self.A = trans_counts / trans_counts.sum(dim=1, keepdim=True)

        emit_counts[emit_counts == 0.] = self._SMOOTHING
        self.B = emit_counts / emit_counts.sum(dim=1, keepdim=True)

        init_counts[init_counts == 0] = self._SMOOTHING
        self.Pi = init_counts / init_counts.sum()

    def test(self, word_lists):
        """Predict a tag sequence for every token sequence.

        Args:
            word_lists: list of token sequences.

        Returns:
            A list with one predicted tag list per input sequence.
        """
        return [self.decoding(word_list) for word_list in word_lists]

    def decoding(self, word_list):
        """Find the most likely tag sequence for ``word_list`` with Viterbi.

        Args:
            word_list: sequence of tokens. Tokens missing from ``word2id``
                are given a uniform emission distribution over the states.

        Returns:
            List of tags with the same length as ``word_list``; an empty list
            for an empty input.
        """
        seq_len = len(word_list)
        if seq_len == 0:
            # Nothing to label; also avoids indexing word_list[0] below.
            return []

        # Multiplying many small probabilities along a long chain underflows,
        # so work with log-probabilities and add instead of multiply.
        A = torch.log(self.A)
        B = torch.log(self.B)
        Pi = torch.log(self.Pi)
        # Emission scores used for out-of-vocabulary tokens.
        uniform = torch.log(torch.ones(self.N) / self.N)

        def emission(word):
            """Log emission score of every state for ``word``."""
            word_id = self.word2id.get(word, None)
            return uniform if word_id is None else B[:, word_id]

        # viterbi[i, j]: best log-score among all tag sequences whose j-th
        # tag is i.
        viterbi = torch.zeros(self.N, seq_len)
        # backpointer[i, j]: id of the (j-1)-th tag on that best sequence.
        backpointer = torch.zeros(self.N, seq_len).long()

        viterbi[:, 0] = Pi + emission(word_list[0])
        backpointer[:, 0] = -1

        # Recurrence:
        # viterbi[tag, step] = max_prev(viterbi[prev, step-1] + A[prev, tag])
        #                      + emission(word)[tag]
        for step in range(1, seq_len):
            # scores[prev, tag] = viterbi[prev, step-1] + A[prev, tag];
            # reducing over dim 0 picks the best predecessor of every tag.
            scores = viterbi[:, step - 1].unsqueeze(1) + A
            max_prob, max_id = torch.max(scores, dim=0)
            viterbi[:, step] = max_prob + emission(word_list[step])
            backpointer[:, step] = max_id

        # Termination: the largest score in the last column belongs to the
        # best path.
        best_path_prob, best_path_pointer = torch.max(
            viterbi[:, seq_len - 1], dim=0
        )

        # Backtrack to recover the best path (built last tag first).
        best_path_pointer = best_path_pointer.item()
        best_path = [best_path_pointer]
        for back_step in range(seq_len - 1, 0, -1):
            best_path_pointer = backpointer[best_path_pointer, back_step].item()
            best_path.append(best_path_pointer)

        # Convert the sequence of tag ids back into tags.
        assert len(best_path) == len(word_list)
        id2tag = dict((id_, tag) for tag, id_ in self.tag2id.items())
        tag_list = [id2tag[id_] for id_ in reversed(best_path)]

        return tag_list

# 读取数据

In [2]:
from DataProcess.data import build_corpus

In [4]:
# Load the training split. The call is unpacked into four values: token
# sequences, tag sequences, and the word->id / tag->id mappings
# (presumably vocabularies built from this split - confirm in DataProcess.data).
train_word_lists, train_tag_lists, word2id, tag2id = build_corpus("train")

In [5]:
# Load the dev and test splits. With make_vocab=False only two values are
# unpacked per call (token sequences and tag sequences); the id mappings
# obtained from the training split are reused for these splits.
dev_word_lists, dev_tag_lists = build_corpus("dev", make_vocab=False)
test_word_lists, test_tag_lists = build_corpus("test", make_vocab=False)

# 训练模型

In [6]:
from DataProcess.utils import save_model
from DataProcess.evaluating import Metrics

In [7]:
def hmm_train_eval(train_data, test_data, word2id, tag2id, remove_O=False,
                   model_path='./ckpts/hmm.pkl'):
    """Train an HMM tagger, save it, and report its scores on ``test_data``.

    Args:
        train_data: pair ``(word_lists, tag_lists)`` used to fit the model.
        test_data: pair ``(word_lists, tag_lists)`` used for evaluation.
        word2id: dict mapping tokens to ids; its size is the number of
            observation symbols of the model.
        tag2id: dict mapping tags to ids; its size is the number of states.
        remove_O: forwarded unchanged to ``Metrics``.
        model_path: where the trained model is handed to ``save_model``.
            Defaults to the previously hard-coded ``'./ckpts/hmm.pkl'``.

    Returns:
        The predicted tag lists for the test word lists.
    """
    import os

    train_word_lists, train_tag_lists = train_data
    test_word_lists, test_tag_lists = test_data

    hmm_model = HMM(len(tag2id), len(word2id), word2id, tag2id)
    hmm_model.train(train_word_lists, train_tag_lists)

    # Make sure the checkpoint directory exists so that saving does not fail
    # on a fresh checkout. NOTE(review): assumes save_model writes a file at
    # the given path - confirm in DataProcess.utils.
    ckpt_dir = os.path.dirname(model_path)
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)
    save_model(hmm_model, model_path)

    # Evaluate the HMM on the held-out data.
    pred_tag_lists = hmm_model.test(test_word_lists)
    metrics = Metrics(test_tag_lists, pred_tag_lists, remove_O=remove_O)
    metrics.report_scores()
    metrics.report_confusion_matrix()

    return pred_tag_lists

In [8]:
# Fit the HMM on the training split and evaluate it on the test split; the
# dev split is not used by this call. hmm_pred holds the predicted tag lists
# for the test sentences.
hmm_pred = hmm_train_eval((train_word_lists, train_tag_lists), (test_word_lists, test_tag_lists),
                         word2id, tag2id)

           precision    recall  f1-score   support
   B-RACE     1.0000    0.9286    0.9630        14
   B-NAME     0.9800    0.8750    0.9245       112
  M-TITLE     0.9038    0.8751    0.8892      1922
    B-LOC     0.3333    0.3333    0.3333         6
    E-ORG     0.8262    0.8680    0.8466       553
    E-PRO     0.6512    0.8485    0.7368        33
    E-LOC     0.5000    0.5000    0.5000         6
   M-CONT     0.9815    1.0000    0.9907        53
  B-TITLE     0.8811    0.8925    0.8867       772
    M-ORG     0.9002    0.9327    0.9162      4325
   E-RACE     1.0000    0.9286    0.9630        14
  E-TITLE     0.9514    0.9637    0.9575       772
    B-EDU     0.9000    0.9643    0.9310       112
   E-CONT     0.9655    1.0000    0.9825        28
        O     0.9568    0.9177    0.9369      5190
    E-EDU     0.9167    0.9821    0.9483       112
    B-ORG     0.8422    0.8879    0.8644       553
    M-EDU     0.9348    0.9609    0.9477       179
   E-NAME     0.9000    0.8036 