In [2]:
from nltk.tokenize import word_tokenize
from collections import Counter
from math import log2
import re

### 数据预处理，将文本处理成单词列表

In [3]:
def preprocess_text(text, tokenize=None):
    """Turn a raw corpus into one flat list of lowercase word tokens.

    The text is split into utterances on the ``__eou__`` marker. Each
    utterance has every character that is not a word character, whitespace
    or an apostrophe removed, is lowercased, and is then tokenized.

    Args:
        text: Raw corpus string whose utterances are delimited by ``__eou__``.
        tokenize: Optional callable mapping a string to a sequence of tokens.
            When omitted, the ``word_tokenize`` imported by this file is used.

    Returns:
        A list containing the tokens of all utterances, in corpus order.
    """
    if tokenize is None:
        tokenize = word_tokenize
    sentences = text.split("__eou__")
    # The piece after the final marker is normally just a newline or empty.
    # Drop it only in that case: an unconditional pop() would silently discard
    # a real utterance whenever the text does not end with the marker.
    # (str.split always returns at least one element, so [-1] is safe.)
    if not sentences[-1].strip():
        sentences.pop()
    words = []
    for sentence in sentences:
        # Strip punctuation but keep apostrophes so contractions stay intact.
        sentence = re.sub(r"[^\w\s']", "", sentence).lower()
        words.extend(tokenize(sentence))
    return words

### 数据预处理：将文本处理成二维列表，外层列表的每个元素对应一个句子，每个句子是由该句单词组成的列表

In [4]:
def preprocess_text2(text, tokenize=None):
    """Turn a raw corpus into a list of tokenized utterances.

    The text is split into utterances on the ``__eou__`` marker. Each
    utterance has every character that is not a word character, whitespace
    or an apostrophe removed, is lowercased, and is then tokenized.

    Args:
        text: Raw corpus string whose utterances are delimited by ``__eou__``.
        tokenize: Optional callable mapping a string to a sequence of tokens.
            When omitted, the ``word_tokenize`` imported by this file is used.

    Returns:
        A list with one entry per utterance, each entry being that
        utterance's tokens. A blank utterance in the middle of the corpus is
        kept (as whatever ``tokenize`` returns for it) so positions still
        line up with the utterances in the input.
    """
    if tokenize is None:
        tokenize = word_tokenize
    sentences = text.split("__eou__")
    # The piece after the final marker is normally just a newline or empty.
    # Drop it only in that case: an unconditional pop() would silently discard
    # a real utterance whenever the text does not end with the marker.
    if not sentences[-1].strip():
        sentences.pop()
    # Use a separate name for the result instead of rebinding the `text`
    # parameter to a value of a different type.
    tokenized = []
    for sentence in sentences:
        # Strip punctuation but keep apostrophes so contractions stay intact.
        sentence = re.sub(r"[^\w\s']", "", sentence).lower()
        tokenized.append(tokenize(sentence))
    return tokenized

### 构建词汇表，vocab为字典，每个项的格式为{word:count}

In [5]:
def build_vocab(words):
    """Count how often each word occurs.

    Args:
        words: Iterable of word tokens.

    Returns:
        A ``Counter`` mapping each distinct word to its number of occurrences.
    """
    counts = Counter()
    counts.update(words)
    return counts

### 计算unigram概率（加一平滑）

In [6]:
def calculate_unigram_probs(vocab, total_words):
    """Estimate add-one smoothed unigram probabilities.

    Each word gets ``(count + 1) / (total_words + len(vocab))``.

    Args:
        vocab: Mapping from word to its occurrence count.
        total_words: Token count used in the denominator.

    Returns:
        A dict mapping every word in ``vocab`` to its smoothed probability.
    """
    # The denominator is the same for every word, so compute it once.
    denominator = total_words + len(vocab)
    return {
        token: (freq + 1) / denominator
        for token, freq in vocab.items()
    }

### 计算句子困惑度

In [7]:
def sentence_perplexity(text, unigram_probs, vocab, total_words):
    """Compute the unigram perplexity of each non-empty sentence.

    For a sentence of ``n`` tokens the perplexity is
    ``2 ** (-(sum of log2 p(token)) / n)``.

    Args:
        text: Iterable of sentences, each a sequence of word tokens.
        unigram_probs: Mapping from word to its probability.
        vocab: Vocabulary mapping; only its size is used here.
        total_words: Token count used in the unseen-word probability.

    Returns:
        A list of perplexities, one per non-empty sentence, in input order.
        Sentences with no tokens are skipped because their perplexity is
        undefined (the per-token average would divide by zero).
    """
    perplexity = []
    for sentence in text:
        if not sentence:
            # e.g. an utterance that was only punctuation before cleaning
            continue
        log_prob = 0
        for word in sentence:
            if word in unigram_probs:
                log_prob += log2(unigram_probs[word])
            else:
                # Unseen word: probability 1 / (len(vocab) + total_words),
                # i.e. the add-one estimate for a word with count zero.
                log_prob += log2(1 / (len(vocab) + total_words))
        perplexity.append(pow(2, -(log_prob / len(sentence))))
    return perplexity

### 评估文本困惑度（取各句子困惑度的算术平均值）

In [8]:
def text_perplexity(perplexity):
    """Return the arithmetic mean of a list of per-sentence perplexities.

    Args:
        perplexity: Non-empty list of per-sentence perplexity values.

    Returns:
        The sum of the values divided by how many there are.
    """
    sentence_count = len(perplexity)
    total = sum(perplexity)
    return total / sentence_count

### 加载数据

In [9]:
# Read the raw training and test corpora from disk.
with open("train_LM.txt", "r", encoding="utf-8") as file:
    train_text = file.read()
with open("test_LM.txt", "r", encoding="utf-8") as file:
    test_text = file.read()

# Fit the unigram model on the training corpus.
words = preprocess_text(train_text)  # flat list of training tokens
train_token_count = len(words)
vocab = build_vocab(words)  # word -> occurrence count
unigram_probs = calculate_unigram_probs(vocab, train_token_count)  # smoothed unigram probabilities

# Score the test corpus. Note that test_text is rebound here from the raw
# string to a list of token lists (one per sentence).
test_text = preprocess_text2(test_text)
perplexity = sentence_perplexity(test_text, unigram_probs, vocab, train_token_count)  # per-sentence perplexities
test_perplexity = text_perplexity(perplexity)  # mean perplexity over the test sentences

print(test_perplexity)

902.3264658565882
