In [7]:
from nltk.tokenize import word_tokenize
from collections import Counter, defaultdict
from math import log2
import re

### 预处理文本

In [8]:
def preprocess_text(text):
    sentences = text.split("__eou__")
    sentences.pop()
    words = []
    for sentence in sentences:
        sentence = re.sub(r"[^\w\s']", "", sentence).lower()
        words += word_tokenize(sentence)
        words.append("<beg>")
    return words


def preprocess_text2(text):
    sentences = text.split("__eou__")
    sentences.pop()
    words = []
    for sentence in sentences:
        sentence = re.sub(r"[^\w\s']", "", sentence).lower()
        words.append(["<beg>"] + word_tokenize(sentence) + ["</end>"])
    return words

### 构建词汇表

In [9]:
def build_vocab(words):
    vocab = Counter(words)
    return vocab

### 计算bigram词频

In [10]:
def calculate_bigram(text):
    bigram_counts = defaultdict(dict)
    for sentence in text:
        for i in range(len(sentence) - 1):
            if sentence[i + 1] not in bigram_counts[sentence[i]]:
                bigram_counts[sentence[i]][sentence[i + 1]] = 1
            else:
                bigram_counts[sentence[i]][sentence[i + 1]] += 1
    return bigram_counts

### 计算bigram概率

In [11]:
def calculate_bigram_probs(bigram_counts,vocab):
    bigram_probs = defaultdict(dict)
    for prev_word, list in bigram_counts.items():
        for back_word, count in list.items():
            bigram_probs[prev_word][back_word] = (count + 1) /(
                vocab[prev_word] + len(vocab))
    return bigram_probs

### 计算句子困惑度

In [12]:
def sentence_perplexity(text, bigram_probs, vocab, bigram_counts):
    perplexity = []
    for sentence in text:
        prob = 0
        for i in range(len(sentence) - 1):
            if sentence[i] not in vocab:  # w1是未登录词
                prob += len(vocab)
            elif sentence[i + 1] not in bigram_probs[sentence[i]]: 
                # w1不是未登录词而w2是
                prob += log2(1 / (vocab[sentence[i]] + len(vocab)))
            else:
                # 都不是未登录词
                prob += log2(bigram_probs[sentence[i]][sentence[i + 1]])
        perplexity.append(pow(2, -(prob / (len(sentence) - 1))))
    return perplexity

### 测试

In [13]:
def text_perplexity(perplexity):
    return sum(perplexity) / len(perplexity)
with open("train_LM.txt", "r", encoding="utf-8") as file:
    text = file.read()
with open("test_LM.txt", "r", encoding="utf-8") as file:
    test_text = file.read()

words = preprocess_text(text)
vocab = build_vocab(words)

train_text = preprocess_text2(text)
bigram_counts = calculate_bigram(train_text)
bigram_probs = calculate_bigram_probs(bigram_counts,vocab)
test_text = preprocess_text2(test_text)
perplexity = sentence_perplexity(test_text, bigram_probs, vocab, bigram_counts)
test_perplexity = text_perplexity(perplexity)

print(test_perplexity)

675.7854405101131
