# 统计分词

## HMM

In [1]:
class HMM:
    def __init__(self, model_file, trained=False):
        import os
        # 存储算法中间结果
        self.model_file = model_file
        # 状态值集合
        # B：词首，M：词中，E：词尾，S：单独成词
        self.state_list = ["B", "M", "E", "S"]
        self.invalid_trans = [("B", "S"), 
                              ("B", "B"), 
                              ("M", "B"),
                              ("M", "S"),
                              ("E", "M"), 
                              ("E", "E")]
        self.trained = trained
    
    def load_model(self):
        """从模型文件中读取之前保存的概率
        A_dic: 转移概率
        B_dic: 发射概率
        Pi_dic: 初始概率
        """
        import pickle
        with open(model_file, mode="rb") as f:
            self.A_dic, self.B_dic, self.Pi_dic = pickle.load(f)
        f.close()
    
    def init_model(self):
        """初始化各概率
        A_dic: 转移概率
        B_dic: 发射概率
        Pi_dic: 初始概率
        """
        self.A_dic = {}
        self.B_dic = {}
        self.Pi_dic = {}
        for state in self.state_list:
            self.A_dic[state] = {s: 0.0 for s in self.state_list}
            self.B_dic[state] = {}
            self.Pi_dic[state] = 0
    
    def _makeLabel(self, word):
        """
        给单一单词的各个字状态打标，返回一个状态列表
        """
        length = len(word)
        if length == 1:
            # 单字成词，因此返回 S
            return ["S"]
        else:
            # 多字词，除去词首 B， 词尾 E，其余均为 M
            return ["B"] + ["M"] * (length - 2) + ["E"]
        
    def _check_trans(self):
        """
        检查统计后的转移矩阵参数是否正确，不合理的转移是否出现
        """
        for former, latter in self.invalid_trans:
            check_value = self.A_dic[former][latter]
            error_text = "invalid trans between %s and %s with P: %.4f" % (former, latter, check_value)
            assert check_value == 0.0, error_text
            
    
    def train(self, path):
        """
        从已经分好词的语料中，学习转移概率，发射概率，初始概率
        """
        self.init_model()
        # 统计各状态出现次数，便于计算 pi。初始化为 0
        count_dic = {s: 0 for s in self.state_list}
        line_num = 0
        vocabs = set()
        with open(path, encoding="UTF-8") as f:
            for line in f:
                line_num += 1
                line = line.strip()
                if len(line) == 0:
                    continue
                character_list = [i for i in line if i != ' ']
                
                line_list = line.split()
                line_state_list = []
                for word in line_list:
                    state = self._makeLabel(word)
                    line_state_list.extend(state)
                
                assert len(character_list) == len(line_state_list), "字数量与状态数量不一致"
                
                former_state = None
                for index, (state, char) in enumerate(zip(line_state_list, character_list)):
                    count_dic[state] += 1
                    # 句首识别与统计，初始概率计算相关，一行只会有一个
                    if index == 0:
                        self.Pi_dic[state] += 1
                    # 其他概率的统计
                    else:
                        ## former_state -> state 计数 + 1
                        self.A_dic[former_state][state] += 1
                        ## state -> char 计数 + 1
                        try:
                            self.B_dic[state][char] += 1
                        except KeyError:
                            self.B_dic[state][char] = 1
                    former_state = state
        # 最后对统计计数进行归一化，得到概率：
        self.Pi_dic = {k: v * 1.0 / line_num for k, v in self.Pi_dic.items()}
        self.A_dic = {k: {k1: v1 / count_dic[k] for k1, v1 in v.items()} for k, v in self.A_dic.items()}
        ## 对发射概率 + 1做平滑
        self.B_dic = {k: {k1: (v1 + 1) / count_dic[k] for k1, v1 in v.items()} for k, v in self.B_dic.items()}
        
        self._check_trans()
        
        import dill as pickle
        with open(self.model_file, mode="wb") as f:
            pickle.dump([self.A_dic, self.B_dic, self.Pi_dic], f, pickle.HIGHEST_PROTOCOL)
        f.close()
        
        return self 
        
    def viterbi(self, text, states, start_p, trans_p, emit_p):
        """
        利用维特比算法对给定的text进行解码（state），得到最大似然的隐藏状态序列（分词状态）
        start_p: 起始状态概率 (Pi_dic) state: P
        trans_p: 状态转移概率 (A_dic) state: {state: P}
        emit_p: 发射概率 (B_dic) state: {word: P}
        """
        import itertools
        V = []  # 记录各步各状态对应的概率
        path = {}  # 记录到各点的最优路径
        vocabs = set(list(itertools.chain.from_iterable([list(v.keys()) for v in emit_p.values()])))  # 发射词汇
        
        # 前溯过程
        for t, w in enumerate(text):
            V.append({})
            # 初始化
            if t == 0:
                for y in states:
                    V[t][y] = start_p[y] * emit_p[y].get(w, 0)
                    path[y] = [y]
            else:
                ## 如果出现了未知字则设置发射概率为 1.0
                neverSeen = w not in vocabs
                emitP = emit_p[y].get(w, 0) if not neverSeen else 1.0
                new_path = {}
                
                ## 针对句子最后一字的状态选择，只能是词尾或单独成词
                valid_states = states if t < len(text) - 1 else ["E", "S"]
                for y in valid_states:
                    max_value = 0.0
                    max_y0 = ""
                    for y0 in states:
                        prob = V[t - 1][y0] * trans_p[y0].get(y, 0) * emitP
                        if prob > max_value:
                            max_value = prob
                            max_y0 = y0
                    V[t][y] = max_value  # 记录到达当前状态的最大概率
                    new_path[y] = path[max_y0] + [y]  # 记录到达当前状态的之前最优路径
                path = new_path  # 重新赋值给最优路径记录
                
        # 获取终点的最优状态，回溯最优路径
        (_, final_state) = max([(V[-1][y], y) for y in ["E", "S"]])
        return path[final_state]
    
    def cut(self, text):
        import os
        # 如果需要 load已有模型，则此处load
        if self.trained:
            self.load_model()
        
        char_states = self.viterbi(text, self.state_list, self.Pi_dic, self.A_dic, self.B_dic)
        begin, next_ = 0, 0
        for i, (state, char) in enumerate(zip(char_states, text)):
            if state == "B":
                begin = i
            elif state == "E":
                yield text[begin: (i + 1)]
                next_ = i + 1
            elif state == "S":
                yield char
                next_ = i + 1
    

In [2]:
model_path = "C:/Users/Cigar/Documents/jupyter/NLP_learn/NLP_In_Action/Chap03_save/model"
corpus_path = "F:/for learn/Python/NLP_in_Action/chapter-3/data/trainCorpus.txt_utf8"

In [3]:
hmm = HMM(model_file=model_path, trained=False)
hmm.train(corpus_path)

<__main__.HMM at 0x24e335deb70>

In [8]:
print(list(hmm.cut("这是一个非常棒的方案")))

['这是', '一个', '非常', '棒的', '方案']
