In [1]:
import os
import jieba
from collections import Counter
import re

In [2]:
# settings
n = 3  # 3-gram

data_path       = './Records_data.txt'          # 存放预处理后监督记录数据
wordtable_path  = './wordtable.txt'  # 词表
stopwords_path  = './CSDN_stopwords.txt' # 停止词表
testset_path    = './questions.txt'       # 测试集
prediction_path = './predictions.txt'   # 预测结果

In [3]:
ngrams_list = []  # n元组（分子）
prefix_list = []  # n-1元组（分母）

# 遍历所有预处理过的新闻文件
with open(data_path') as f:
    for line in f:
        sentence = ['<BOS>'] + line.split() + ['<EOS>']  # 列表，形如：['<BOS>', '显得', '十分', '明亮', '<EOS>']
        ngrams = list(zip(*[sentence[i:] for i in range(n)]))   # 一个句子中n-gram元组的列表
        prefix = list(zip(*[sentence[i:] for i in range(n-1)])) # 历史前缀元组的列表
        ngrams_list += ngrams
        prefix_list += prefix

ngrams_counter = Counter(ngrams_list)
prefix_counter = Counter(prefix_list)

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xbc in position 0: invalid start byte

In [None]:
all_words = []  # 词表中的全部词
with open(wordtable_path, encoding='utf-8') as f:
    for line in f.readlines()[1:]:
        all_words.append(line.split()[-1])

In [None]:
# 停止词
with open(stopwords_path, encoding='utf-8') as f:
    stopwords = f.readlines()
stopwords = set(map(lambda x:x.strip(), stopwords))  # 去除末尾换行符

In [None]:
def probability(sentence):
    """
    计算一个句子的概率。
    Params:
        sentence: 由词构成的列表表示的句子。
    Returns:
        句子的概率。
    """
    prob = 1  # 初始化句子概率
    ngrams = list(zip(*[sentence[i:] for i in range(n)]))   # 将句子处理成n-gram的列表
    for ngram in ngrams:
        # 累乘每个n-gram的概率，并使用加一法进行数据平滑
        prob *= (1 + ngrams_counter[ngram]) / (len(prefix_counter) + prefix_counter[(ngram[0], ngram[1])])
    return prob

In [None]:
def predict(pre_sentence, post_sentence, all_words, cand_num=1):
    """
    根据历史进行一个词的预测。
    Params:
        pre_sentence: 待预测词之前部分句子的分词结果构成的列表。
        post_sentence: 待预测词之后部分句子的分词结果构成的列表。
        all_words: 所有候选词构成的列表。
        cand_num: 候选词数，默认为1。
    Returns:
        一个含有cand_num个元素的列表，表示预测的词，概率由大到小排序；
        如果预测失败，返回None。
    """
    word_prob = []  # 候选词及其概率构成的元组的列表
    for word in all_words:
        # 实际上不需要算整个句子的概率，只需要算待预测词附近的概率即可，因为句子其他部分的概率不受待预测词影响
        test_sentence = pre_sentence[-(n-1):] + [word] + post_sentence[:(n-1)]  # 待预测词及其前后各n-1个词的列表
        word_prob.append( (word, probability(test_sentence)) )                  # (词, 概率)元组构成的列表

    return sorted(word_prob, key=lambda tup: tup[1], reverse=True)[:cand_num]  # 按概率降序排序并取前cand_num个

In [None]:
# 加载测试集标签（答案）
with open('./questions.txt', encoding='utf-8') as f:
    answers = [answer.strip() for answer in f]  # 答案构成的列表
    
prediction_file = open(prediction_path, 'w', encoding='utf-8')  # 存放预测结果

# 开始测试
correct_count = 0  # 预测正确的数量

with open('questions.txt', encoding='utf-8') as f:
    questions = f.readlines()  # 测试集规模
    total_count = len(questions)
    for i, question in enumerate(questions):
        question = question.strip()
        pre_mask = question[:question.index('[MASK]')]     # 待预测词的历史
        post_mask = question[question.index('[MASK]')+6:]  # 待预测词后的剩余部分
        
        pre_sentence = jieba.cut(pre_mask.replace('，', ' '))  # 分词
        post_sentence = jieba.cut(post_mask.replace('，', ' '))  # 分词
        pre_sentence = [word.strip() for word in pre_sentence if word.strip() and word not in stopwords]  # 去除停止词、空串
        post_sentence = [word.strip() for word in post_sentence if word.strip() and word not in stopwords]  # 去除停止词、空串

        predict_cand = predict(pre_sentence, post_sentence, all_words)  # 预测一个概率最大的词
        prediction_file.write(' '.join([w[0] for w in predict_cand]) + '\n')  # 将预测结果写入文件

        # 遍历多个预测结果
        for j, p in enumerate(predict_cand):
            if p[0] == answers[i]:
                print(i, '{} [{}] {}'.format(pre_mask, p[0], post_mask))
                correct_count += 1
                break
                    
prediction_file.close()

In [None]:
print('准确率：{}/{}'.format(correct_count, total_count))