# n元语言模型回退算法

本次作业要求补全本笔记中的n元语言模型的采用Good-Turing折扣的Katz回退算法。

### 预处理

首先创建一些预处理函数。

引入必要的模块，定义些类型别名。

In [1]:
import re
import itertools

from typing import List, Dict, Tuple

Sentence = List[str]
IntSentence = List[int]

Corpus = List[Sentence]
IntCorpus = List[IntSentence]

Gram = Tuple[int]

下面的函数用于将文本正则化并词元化。该函数会将所有英文文本转为小写，去除文本中所有的标点，简单起见将所有连续的数字用一个`N`代替，将形如`let's`的词组拆分为`let`和`'s`两个词。

In [2]:
_splitor_pattern = re.compile(r"[^a-zA-Z']+|(?=')")
_digit_pattern = re.compile(r"\d+")
def normaltokenize(corpus: List[str]) -> Corpus:
    """
    Normalizes and tokenizes the sentences in `corpus`. Turns the letters into
    lower case and removes all the non-alphadigit characters and splits the
    sentence into words and added BOS and EOS marks.

    Args:
        corpus - list of str

    Return:
        list of list of str where each inner list of str represents the word
          sequence in a sentence from the original sentence list
    """

    tokeneds = [ ["<s>"]
               + list(
                   filter(lambda tkn: len(tkn)>0,
                       _splitor_pattern.split(
                           _digit_pattern.sub("N", stc.lower()))))
               + ["</s>"]
                    for stc in corpus
               ]
    return tokeneds

接下来定义两个函数用来从训练语料中构建词表，并将句子中的单词从字符串表示转为整数索引表示。

In [3]:
def extract_vocabulary(corpus: Corpus) -> Dict[str, int]:
    """
    Extracts the vocabulary from `corpus` and returns it as a mapping from the
    word to index. The words will be sorted by the codepoint value.

    Args:
        corpus - list of list of str

    Return:
        dict like {str: int}
    """

    vocabulary = set(itertools.chain.from_iterable(corpus))
    vocabulary = dict(
            map(lambda itm: (itm[1], itm[0]),
                enumerate(
                    sorted(vocabulary))))
    return vocabulary

def words_to_indices(vocabulary: Dict[str, int], sentence: Sentence) -> IntSentence:
    """
    Convert sentence in words to sentence in word indices.

    Args:
        vocabulary - dict like {str: int}
        sentence - list of str

    Return:
        list of int
    """

    return list(map(lambda tkn: vocabulary.get(tkn, len(vocabulary)), sentence))

接下来读入训练数据，将数据预处理。

In [4]:
import functools

with open("data/news.2007.en.shuffled.deduped.train", encoding='UTF-8') as f:
    texts = list(map(lambda l: l.strip(), f.readlines()))

print("Loaded training set.")

corpus = normaltokenize(texts)
vocabulary = extract_vocabulary(corpus)
corpus = list(
        map(functools.partial(words_to_indices, vocabulary),
            corpus))

print("Preprocessed training set.")

Loaded training set.
Preprocessed training set.


In [5]:
print(len(corpus[0]), len(corpus[1]))

print(corpus[0])
print(corpus[1])

36 32
[3581, 165570, 86217, 62968, 165570, 34971, 167440, 82721, 178887, 132056, 12791, 165570, 142662, 142568, 6618, 134164, 176515, 9880, 131488, 7782, 99960, 165570, 169869, 80313, 3582, 9880, 165570, 109285, 8702, 144231, 119677, 165570, 68036, 80313, 3582, 3580]
[3581, 165570, 178819, 62968, 165570, 58547, 180493, 41355, 82173, 135885, 51060, 166148, 179627, 169773, 36804, 165541, 83663, 106975, 76095, 118831, 63602, 135877, 41396, 120257, 55961, 37434, 135208, 165710, 79221, 81152, 5006, 3580]


### 设计模型

参照公式

$$
P_{\text{bo}}(w_k | W_{k-n+1}^{k-1}) = \begin{cases}
    d(W_{k-n+1}^k) \dfrac{C(W_{k-n+1}^k)}{C(W_{k-n+1}^{k-1})} &  C(W_{k-n+1}^k) > 0 \\
    \alpha(W_{k-n+1}^{k-1}) P_{\text{bo}}(w_k | W_{k-n+2}^{k-1}) &  \text{否则} \\
\end{cases}
$$

实现n元语言模型及采用Good-Turing折扣的Katz回退算法。

需要实现的功能包括：

1. 统计各词组（gram）在训练语料中的频数
2. 计算同频词组个数$N_r$
3. 计算$d(W_{k-n+1}^k)$
4. 计算$\alpha(W_{k-n+1}^{k-1})$
5. 根据公式计算回退概率
6. 计算概率对数与困惑度（PPL）

$d$与$\alpha$如何计算可以参考作业文件中的算法说明以及[SRILM](http://www.speech.sri.com/projects/srilm/)的[`ngram-discount(7)`手册页](http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html)。

In [6]:
import math

class NGramModel:
    def __init__(self, vocab_size: int, n: int = 4):
        """
        Constructs `n`-gram model with a `vocab_size`-size vocabulary.

        Args:
            vocab_size - int
            n - int
        """

        self.vocab_size: int = vocab_size
        self.n: int = n

        self.frequencies: List[Dict[Gram, int]] \
            = [{} for _ in range(n)]

        # 我认为应该更改为float 10/20 11:17
        self.disfrequencies: List[Dict[Gram, float]] \
            = [{} for _ in range(n)]

        self.ncounts: Dict[Gram
        , Dict[int, int]
        ] = {}

        self.discount_threshold: int = 7
        self._d: Dict[Gram, Tuple[float, float]] = {}
        self._alpha: List[Dict[Gram, float]] \
            = [{} for _ in range(n)]

        self.eps = 1e-10

    def learn(self, corpus: IntCorpus):
        """
        Learns the parameters of the n-gram model.

        Args:
            corpus - list of list of int
        """

        # self.n: lenth of n-gram
        for stc in corpus:
            # print(stc)
            for i in range(1, len(stc) + 1):  # 查看到i为止的序列
                for j in range(min(i, self.n)):  # 以j作为gram的迭代
                    # TODO: count the frequencies of the grams
                    tmp_gram = stc[i - j - 1:i]
                    tmp_gram = tuple(tmp_gram)
                    if self.frequencies[j].get(tmp_gram, 0):
                        self.frequencies[j][tmp_gram] = self.frequencies[j][tmp_gram] + 1

                        # print(self.frequencies[j][tmp_gram])
                    else:
                        self.frequencies[j][tmp_gram] = 1

                    # print(self.frequencies[j][tmp_gram])
        # print(self.frequencies[1])

        for i in range(1, self.n):  # 按照gram长度划分
            grams = itertools.groupby(
                sorted(
                    sorted(
                        map(lambda itm: (itm[0][:-1], itm[1]),  # 去除最后一项后的键值对
                            self.frequencies[i].items()),
                        key=(lambda itm: itm[1])),
                    key=(lambda itm: itm[0])))  # 再按照第一位元素排序， 后面的值为频率
            # TODO: calculates the value of $N_r$

            for key, group in grams:  # key[0]: 表示上一项的值, key[1]: 下一项出现频次

                cnt = 0
                for _ in group:
                    cnt = cnt + 1

                if key[0] in self.ncounts:
                    self.ncounts[key[0]][key[1]] = cnt  # 上一长度gram: (多一个元素所有可能的频次: 该频次的出现次数)
                else:
                    self.ncounts[key[0]] = {}
                    self.ncounts[key[0]][key[1]] = cnt

                # 记录了: 对本项(key)，可能的下一项的(r , N_r)

        return self.frequencies

    # 输入gram,计算其对应插值参数d
    # func 'd' not debugged yet
    def d(self, gram: Gram) -> float:
        """
        Calculates the interpolation coefficient.

        Args:
            gram - tuple of int

        Return:
            float
        """
        # r即本gram的出现频次
        length = len(gram)
        r = self.frequencies[length].get(gram, self.eps)

        # 调用self.ncounts， 查看下一项的不同频次的出现次数
        if gram not in self._d:
            # print(gram)
            # print(gram[:-1])
            # TODO: calculates the value of $d'$
            ncounts = self.ncounts[gram[:-1]]  # 查看前缀, counts是存储频次的字典
            # print(ncounts)

            def counts(num):
                if num in ncounts:
                    return float(ncounts[num])
                else:
                    return float(self.eps)  # 如有报错,改为0

            lamda = float(counts(1) / (counts(1) - 8 * counts(8)))
            N_r = counts(r)
            N_r_plus1 = counts(r + 1)
            d_dot = lamda * ((r + 1) * N_r_plus1) / (r * N_r) + 1 - lamda  # 求出各项参数 及d'

            self._d[gram] = (d_dot, 1.0)

        # self._d[gram] = (numerator1 / denominator, - numerator2 / denominator)  # 10/19 ques ?

        if r > 7:
            return self._d[gram][1]
        else:
            return self._d[gram][0]

    def alpha(self, gram: Gram) -> float:  # '回退' 意为降低gram大小要求
        """
        Calculates the back-off weight alpha(`gram`)

        Args:
            gram - tuple of int

        Return:
            float
        """

        n = len(gram)
        # 概率存放在 disfrequencies 内
        if gram not in self._alpha[n]:
            if gram in self.frequencies[n - 1]:
                # TODO: calculates the value of $\alpha$

                # V+, V- Accumulated, sum add
                V_plus = []
                V_minus = []

                sum_pls = 0
                sum_minus = 0
                for i in range(1, self.vocab_size):
                    # 添加元素:查询V+和V-
                    index = len(gram)  # 减一是原长度 checked in test_funcs 10/19 21:35
                    gram_add = list(gram)
                    gram_add.append(i)
                    gram_add = tuple(gram_add)

                    if gram_add in self.frequencies[index]:
                        # print(gram_add)
                        # print(self.frequencies[index].get(gram_add, 0))
                        V_plus.append(i)
                        sum_pls = sum_pls + self.disfrequencies[index].get(gram_add, 0)  # checked in test
                        # sum_pls = sum_pls + self.disfrequencies[index].get(gram_add, 0)
                    else:
                        V_minus.append(i)
                        gram_minus = gram_add[1:]
                        index = index - 1

                        sum_minus = sum_minus + self.disfrequencies[index - 1].get(gram_minus, self.eps)

                numerator = 1 - sum_pls
                denominator = 1 - sum_minus

                self._alpha[n][gram] = numerator / denominator
            else:
                self._alpha[n][gram] = 1.
        return self._alpha[n][gram]

    def __getitem__(self, gram: Gram) -> float:  # 计算回退概率(调用之前实现的函数)
        """
        Calculates smoothed conditional probability P(`gram[-1]`|`gram[:-1]`).

        Args:
            gram - tuple of int

        Return:
            float
        """

        n = len(gram) - 1

        if gram not in self.disfrequencies[n]:
            if n > 0:
                # TODO: calculates the smoothed probability value according to the formulate

                # 1:C>0, use param_d
                if self.disfrequencies[n].get(gram, 0) > 0:
                    param_d = self.d(gram)
                    # print('param_d = ', param_d)
                    Ck = self.frequencies[n][gram]
                    Ck_1 = self.frequencies[n - 1][gram[:-1]]
                    P = param_d * Ck / Ck_1
                    self.disfrequencies[n][gram] = P

                # 2:C=0, use param_a
                else:
                    param_a = self.alpha(gram[:-1])
                    # print('param_a = ', param_a)
                    # P = self.disfrequencies[n-1].get(gram[1:], self.eps) * param_a
                    P = self.__getitem__(gram[1:]) * param_a
                    self.disfrequencies[n][gram] = P

            # 第一项
            else:
                self.disfrequencies[n][gram] = self.frequencies[n].get(gram, self.eps) / float(len(self.frequencies[0]))
        return self.disfrequencies[n][gram]

    # 交叉熵
    def log_prob(self, sentence: IntSentence) -> float:
        """
        Calculates the log probability of the given sentence. Assumes that the
        first token is always "<s>".

        Args:
            sentence: list of int

        Return:
            float
        """

        log_prob = 0.
        cnt = 0
        n = self.n
        for i in range(2, len(sentence) + 1):
            # TODO: calculates the log probability
            # 思路是递进地对每一个ngram求概率
            cnt = float(cnt+1)
            length = min(n, i)
            gram = tuple(sentence[i-length:i])  # 可能超出范围,有待测试
            # print(len(gram))
            # print(gram)
            log_prob = log_prob + math.log2(self.__getitem__(gram))

        log_prob = - log_prob / cnt
        return log_prob


    def ppl(self, sentence: IntSentence) -> float:
        """
        Calculates the PPL of the given sentence. Assumes that the first token
        is always "<s>".

        Args:
            sentence: list of int

        Return:
            float
        """

        # calculates the PPL

        PPL = 1.
        cnt = 0
        n = self.n
        for i in range(2, len(sentence) + 1):
            # calculates the log probability
            cnt = cnt + 1
            length = min(n, i)
            gram = tuple(sentence[i - length + 1:i])  # 可能超出范围,有待测试
            PPL = PPL / self.__getitem__(gram)

        # print(PPL)
        # print(cnt)

        PPL = math.pow(PPL, 1/cnt)
        return PPL


### 训练与测试

现在数据与模型均已齐备，可以训练并测试了。

训练模型：

In [8]:
import pickle as pkl

model = NGramModel(len(vocabulary))
model.learn(corpus)
with open("model.pkl", "wb") as f:
    pkl.dump(vocabulary, f)
    pkl.dump(model, f)

print("Dumped model.")

Dumped model.


在测试集上测试计算困惑度：

In [9]:
with open("model.pkl", "rb") as f:
    vocabulary = pkl.load(f)
    model = pkl.load(f)
print("Loaded model.")

with open("data/news.2007.en.shuffled.deduped.test") as f:
    test_set = list(map(lambda l: l.strip(), f.readlines()))
test_corpus = normaltokenize(test_set)
test_corpus = list(
        map(functools.partial(words_to_indices, vocabulary),
            test_corpus))
ppls = []
for t in test_corpus:
    ppls.append(model.ppl(t))
    print(ppls[-1])
print("Avg: ", sum(ppls)/len(ppls))

Loaded model.
331.7920174085309
5.563344879350672
7.681535885346684
9.231407600197334
2058.5496313403114
5.170079774739926
7.346955771855707
6.708097135516273
35.8399602566888
7.800398424769536
4.98059473069868
5.740536545302932
106.72802995303684
24.55205021534513
10.23939484159557
14.050691211160322
22.110248446631623
560.811966536818
7.498264092662078
12.986887242470827
13.721430265811495
31.509170294957258
17.723829181289716
16.66076665181356
10.942950730870209
5.626950182695882
11.97790166750024
4.039504825190963
10.395047853265448
6.724573289134925
14.531174264417785
24.018847141286066
9.833317186292778
4.5076702318899775
11.171390631558685
10.245512255962051
14.991131855280008
16.742101922318962
9.203150354256293
30.42616456485172
5.835219219013738
31.767545278188052
9.49133763841313
6.593217047976487
11.942381153275067
54.20236759260752
14.461080582736281
31.70004137384701
47.85187170003562
13.874095968074558
Avg:  74.96187670395679
