# Notebook illustrating the lecture on LMs

We will use NLTK's LM module to better understand what LMs are and how they work.

In the first part of the notebook, we will guide you through this module to get a better understanding of it and illustrate the notions that we've seen during the lecture. You'll get to see counts, estimate LMs from the counts with a number of discounting techniques, and measure the log-prob/perplexity of a text.

In the second part, you will be guided through training a small LM on real data, using different amounts of data. With a trained LM, you're then asked to write a function that generates text by sampling from the LM distribution.

Here are useful reference links to the documentation of the LM module of NLTK and of the main classes and functions that we will make use of.

Entry points in the documentation:
- https://www.nltk.org/api/nltk.lm.html
- https://www.nltk.org/howto/lm.html

You will be given more targeted links in Part I of the notebook.

Also remember to have a look at Chen and Goodman's paper if you want details on the various discounting, interpolation and backoff techniques.

In [1]:
import json, gzip

from nltk.lm import NgramCounter, Vocabulary
from nltk.lm.models import MLE, Laplace, WittenBellInterpolated, KneserNeyInterpolated, AbsoluteDiscountingInterpolated
from nltk.lm.preprocessing import flatten
from nltk.util import ngrams, everygrams

from statistics import mean
from collections import Counter

from tqdm import tqdm

## Part I -- Playing with a toy example to understand how things work

Let's play with a very small artificial corpus to get acquainted with NLTK's LM module and verify it does exactly what we've seen in the classroom. The idea is to
1. see the interplay between vocabulary and LM with cutoff and unk
2. get ngram probabilities with MLE estimation and Laplace smoothing
3. compute the probability / perplexity of a sequence

Let's first define a toy corpus of texts to train on.

In [None]:
text = [['<s>', 'a', 'b', 'a', '</s>'], 
        ['<s>', 'a', 'c', 'b', 'a', 'b', '</s>'], 
        ['<s>', 'a', 'a', 'c', 'b', '</s>']]

### Define the vocabulary

First thing to do is to define the vocabulary on which the language model operates. 

The vocabulary can be conveniently obtained from the (flattened) text corpus, using a specified cutoff, say n, to retain only tokens that appear at least n times in the data.

See https://www.nltk.org/api/nltk.lm.vocabulary.html for details.
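To make the cutoff mechanism concrete before using NLTK, here is a minimal pure-Python sketch of the same idea. The helpers `build_vocab()` and `lookup()` are hypothetical names written for this illustration (they are not NLTK functions), and the toy corpus is repeated so that the snippet runs on its own.

```python
from collections import Counter

# Toy corpus, repeated here so that the snippet is self-contained.
toy_text = [['<s>', 'a', 'b', 'a', '</s>'],
            ['<s>', 'a', 'c', 'b', 'a', 'b', '</s>'],
            ['<s>', 'a', 'a', 'c', 'b', '</s>']]

def build_vocab(utterances, cutoff):
    """Keep the tokens seen at least `cutoff` times (hypothetical helper, not NLTK)."""
    token_counts = Counter(tok for utt in utterances for tok in utt)
    return {tok for tok, n in token_counts.items() if n >= cutoff}

def lookup(tokens, kept, unk='<UNK>'):
    """Replace every token outside the kept set by the unknown symbol."""
    return tuple(tok if tok in kept else unk for tok in tokens)

kept = build_vocab(toy_text, cutoff=3)
encoded = lookup(('<s>', 'a', 'b', 'a', 'b', 'c', '</s>'), kept)
print(sorted(kept))
print(encoded)
```

In this corpus `'c'` occurs only twice, so a cutoff of 3 removes it from the vocabulary and it is encoded as `'<UNK>'`.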

In [None]:

flat_text = [token for utterance in text for token in utterance]

vocab = Vocabulary(flat_text, unk_cutoff=1) # define the vocabulary
print('vocab size =', len(vocab))
print('vocab =', vocab.counts)

# Note for later use that the vocab object has a lookup method that maps the tokens of
# an input utterance to vocabulary entries (out-of-vocabulary tokens become '<UNK>'), e.g.,
test = ('<s>', 'a', 'b', 'a', 'b', 'c', '</s>')
print()
print('Encoding', test, 'as', vocab.lookup(test))

# Let's try with cutoff=3: Note that you don't need to recompute the vocabulary 
# counts and can make use of the counts computed in vocab to get a new vocab 
# with a different cutoff

vocab2 = Vocabulary(vocab.counts, unk_cutoff=3)
print('Re encoding', test, 'as', vocab2.lookup(test))


### Creating and counting ngrams

NLTK provides a counter of ngrams, `NgramCounter()`, that can be used to compute, store and manipulate the counts of each ngram that appears in a corpus.

Unfortunately, `NgramCounter()` does not take text as input but rather the list of ngrams to be considered in a text. In other words, for the text `['<s>', 'a', 'b', '</s>']` and considering a maximum ngram order of 2, updating the counter requires passing `('<s>',), ('a',), ('b',), ('</s>',), ('<s>', 'a'), ('a', 'b'), ('b', '</s>')`.

Fortunately, NLTK provides the function `everygrams()` that converts a text into the corresponding list of ngrams. Note that there are other similar functions (that we will not use in this notebook) such as `ngrams()`, `bigrams()`, `trigrams()`.

In the cells below, we illustrate things with an LM order of 3, i.e., we go up to trigrams but no further. Hence we only generate unigrams, bigrams and trigrams in the counts and estimate the LM as a trigram. You can easily change the order.

See https://www.nltk.org/api/nltk.lm.counter.html for details.
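As a rough picture of what the ngram expansion does, here is a small pure-Python sketch. The function `all_ngrams()` is a hypothetical helper written for this illustration; NLTK's own function may return the same ngrams in a different order.

```python
def all_ngrams(tokens, max_len):
    """Yield every ngram of length 1..max_len as a tuple (hypothetical helper)."""
    for n in range(1, max_len + 1):
        # slide a window of width n over the token sequence
        for i in range(len(tokens) - n + 1):
            yield tuple(tokens[i:i + n])

grams = list(all_ngrams(['<s>', 'a', 'b', '</s>'], max_len=2))
print(grams)
```

For a text of 4 tokens and a maximum order of 2, this yields the 4 unigrams followed by the 3 bigrams listed above.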

In [None]:
# Illustrating everygrams() before we use it:
print(text[0], '-->', list(everygrams(text[0], max_len=3)), '\n')

counts = NgramCounter([everygrams(s, max_len=3) for s in text])
print('Total number of ngrams =', counts.N())

In [None]:
# You can access counts by specifying the history (called context in NLTK's terminology) 
# as an index: this provides all the possible completions for this history and the number 
# of times you've seen each completion. Note that the history is a list (not a tuple).

print("After 'a':", sorted(counts[['a']].items())) # get bigrams starting with 'a'
print("After 'c':", sorted(counts[['c']].items())) 

print("After 'a', 'b':", sorted(counts[['a', 'b']].items())) # get trigrams starting with ('a', 'b')
print("After 'c', 'b':", sorted(counts[['c', 'b']].items()))
print()

# You can also access all counts for a given order
print('Accessing ngrams for a given order directly:')
unigram = counts[1] # a dictionary token=count
print(list(unigram.keys()))
print('unigram counts:', sorted(unigram.items()))

bigram = counts[2] # a dictionary history=dictionary of counts for the bigram following this history
print(list(bigram.keys()))
print("bigram counts starting with 'a':", sorted(bigram[('a',)].items()))


### Estimating LM probabilities

From the counts and vocabulary, one can easily estimate LM probabilities and instantiate a LM to get the probabilities of a sequence or simply that of a particular ngram.

Let's try with a maximum likelihood estimation of the ngram probabilities and then let's try smoothing.

See https://www.nltk.org/api/nltk.lm.api.html for the general definition of the LM class and https://www.nltk.org/api/nltk.lm.models.html for specific models that may include discounting and (recursive) interpolation.
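If you want to check the values by hand, the two estimators can be sketched in a few lines of plain Python for the bigram case. The helpers `mle()` and `add_one()` are hypothetical names for this illustration. The vocabulary size of 6 is an assumption: the 5 token types of the toy corpus plus one `'<UNK>'` entry.

```python
from collections import Counter, defaultdict

# Toy corpus, repeated here so that the snippet is self-contained.
toy_text = [['<s>', 'a', 'b', 'a', '</s>'],
            ['<s>', 'a', 'c', 'b', 'a', 'b', '</s>'],
            ['<s>', 'a', 'a', 'c', 'b', '</s>']]

# bigram_counts[h][w] = number of times token w follows history h
bigram_counts = defaultdict(Counter)
for utt in toy_text:
    for h, w in zip(utt, utt[1:]):
        bigram_counts[h][w] += 1

def mle(w, h):
    """Maximum likelihood estimate: c(h, w) / c(h)."""
    total = sum(bigram_counts[h].values())
    return bigram_counts[h][w] / total if total else 0.0

def add_one(w, h, vocab_size):
    """Laplace (add-one) estimate: (c(h, w) + 1) / (c(h) + |V|)."""
    total = sum(bigram_counts[h].values())
    return (bigram_counts[h][w] + 1) / (total + vocab_size)

vocab_size = 6  # assumption: 5 observed token types plus one '<UNK>' entry
print('MLE      P[b|h=a] = {:.4f}'.format(mle('b', 'a')))
print('Laplace  P[b|h=a] = {:.4f}'.format(add_one('b', 'a', vocab_size)))
```

Here `'a'` is followed 6 times by some token, 2 of which are `'b'`, so the MLE is 2/6 and the add-one estimate is (2+1)/(6+6).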

In [None]:
print('MLE LM')

# Create and estimate LM from the vocabulary and counts
lm = MLE(3, vocabulary=vocab, counter=counts)

print('  P[b|h=a] = {:.4f}'.format(lm.score('b', context=('a',))))
print('  P[b|h=c] = {:.4f}'.format(lm.score('b', context=('c',))))
print('  P[c|h=ab] = {:.4f}'.format(lm.score('c', context=('a','b'))))
print('  P[a|h=cb] = {:.4f}'.format(lm.score('a', context=('c','b'))))

# You also have a logscore() function if you want the log-probability (base 2) directly

# TODO: Verify that these probabilities are the ones you expected!

In [None]:
print('Laplace smoothing')

lm = Laplace(3, vocabulary=vocab, counter=counts)

print('  P[b|h=a] = {:.4f}'.format(lm.score('b', context=('a',))))
print('  P[b|h=c] = {:.4f}'.format(lm.score('b', context=('c',))))
print('  P[c|h=ab] = {:.4f}'.format(lm.score('c', context=('a','b'))))
print('  P[a|h=cb] = {:.4f}'.format(lm.score('a', context=('c','b'))))

# TODO: Verify that these probabilities are the ones you expected!

There are many other discounting and smoothing schemes implemented in NLTK's LM module, among which:
- Lidstone: same as Laplace except you specify the value to add (Lidstone with 1 == Laplace)
- AbsoluteDiscountingInterpolated: absolute discounting with interpolation as in the lecture
- KneserNeyInterpolated: a smart variant of absolute discounting
- WittenBellInterpolated: a variant of Jelinek-Mercer smoothing implementing recursive interpolation, no discounting

Invoking such LMs is straightforward from the above examples. Note however that some discounting schemes have specific parameters that you can specify. For instance, absolute discounting requires fixing the discount value. This is illustrated below.
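To see what the discount does, here is a textbook-style sketch of interpolated absolute discounting, restricted to bigrams interpolated with an MLE unigram. It is a simplification made for this illustration: NLTK's model interpolates recursively over all orders, so its numbers need not match this sketch. The helper names are hypothetical and the discount `d` is assumed to lie in (0, 1].

```python
from collections import Counter, defaultdict

# Toy corpus, repeated here so that the snippet is self-contained.
toy_text = [['<s>', 'a', 'b', 'a', '</s>'],
            ['<s>', 'a', 'c', 'b', 'a', 'b', '</s>'],
            ['<s>', 'a', 'a', 'c', 'b', '</s>']]

unigram = Counter(tok for utt in toy_text for tok in utt)
bigram = defaultdict(Counter)
for utt in toy_text:
    for h, w in zip(utt, utt[1:]):
        bigram[h][w] += 1

def p_unigram(w):
    """MLE unigram probability, used as the lower-order distribution."""
    return unigram[w] / sum(unigram.values())

def p_abs_discount(w, h, d):
    """Bigram absolute discounting interpolated with the unigram (sketch)."""
    followers = bigram[h]
    total = sum(followers.values())
    if total == 0:
        return p_unigram(w)  # unseen history: only the unigram is available
    discounted = max(followers[w] - d, 0.0) / total
    # each seen completion loses d counts; that mass goes to the unigram
    freed_mass = d * len(followers) / total
    return discounted + freed_mass * p_unigram(w)

for d in (0.2, 0.7):
    print('d={}  P[b|h=a] = {:.4f}'.format(d, p_abs_discount('b', 'a', d)))
```

A larger discount moves more probability mass from the observed bigrams to the lower-order distribution, while the probabilities for a given history still sum to one.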

In [None]:
print('Absolute Discounting with a discount of 0.2')

lm = AbsoluteDiscountingInterpolated(3, discount=0.2, vocabulary=vocab, counter=counts)
print('  P[b|h=a] = {:.4f}'.format(lm.score('b', context=('a',))))
print('  P[b|h=c] = {:.4f}'.format(lm.score('b', context=('c',))))
print('  P[c|h=ab] = {:.4f}'.format(lm.score('c', context=('a','b'))))
print('  P[a|h=cb] = {:.4f}'.format(lm.score('a', context=('c','b'))))

print('\nAbsolute Discounting with a discount of 0.7')
lm = AbsoluteDiscountingInterpolated(3, discount=0.7, vocabulary=vocab, counter=counts)
print('  P[b|h=a] = {:.4f}'.format(lm.score('b', context=('a',))))
print('  P[b|h=c] = {:.4f}'.format(lm.score('b', context=('c',))))
print('  P[c|h=ab] = {:.4f}'.format(lm.score('c', context=('a','b'))))
print('  P[a|h=cb] = {:.4f}'.format(lm.score('a', context=('c','b'))))


**We're now all set to move to real data and have fun!**

A few things that might be useful to know before moving on.

- Remember the `'<UNK>'` token that appeared when we encoded with a vocabulary built with a cutoff at three occurrences? We didn't really pay attention to this so far because we played with a closed LM where no unknown tokens appear in the training and test data. In practice, we need to encode the data (train and test) with `vocab.lookup()` before generating the ngrams, counts and probabilities. Note that the `lm.score()` and `lm.logscore()` functions perform the `vocab.lookup()` operation. But `everygrams()` does not, so you'll need to do it yourself.

- We will have to compute perplexity of texts to evaluate the various LMs we will build. NLTK provides a perplexity function that you can make use of. But I'm not super happy with how it's implemented and sort of disagree with how perplexity is computed (better said, there are high risks that you do the wrong thing if you use it directly). We will thus reimplement one.

- Note that `Vocabulary()` and `NgramCounter()` both provide an `update()` method so you don't really have to provide all the required information at once: you can invoke `update()` in an iterative manner while you scan your data. This is not super efficient though and we fortunately don't need that here. Just in case you need that at some point. 

Anyway, there's much more to this library than what's illustrated here. But keep in mind that it's by far not the most efficient and practical implementation of ngram LMs.


## Part II -- Play with real data and make real LMs

To facilitate things, we provide you with pre-processed data from the Signal Media 1M article dataset. It consists of a fairly large number of news articles (from a decade back). Details on the dataset can be found at https://ceur-ws.org/Vol-1568/paper8.pdf and the original data can be downloaded from https://research.signal-ai.com/newsir16/signal-dataset.html.

To make things easy, the original corpus went through heavy article and utterance selection and cleansing, with preprocessing by spaCy's en_core_web_md pipeline. We kept the first 100k articles, with the following global statistics:
```
# of docs = 100,000 (out of 886,145 in the full dataset after basic filtering)
avg. # of utterances per doc = 3.28897
total number of utterances = 328897
stats on utterance length: mean=26.90  median=24  stdev=19.469142738145738 min=1  max=2158
```
These 100k documents were split into 90k documents for training, 5k for validation, 5k for test. For each fold, we made a list of sentences with a number of tokens between 10 and 100, resulting in 271,607 utterances for training, 15,316 for validation and 15,062 for testing.

You can get this dataset from the gzipped JSON file provided, which has the following format:
```
data['train'] # list of utterances in train
  data['train'][0] # list of tokens
  data['train'][1]
  ...
data['valid']
  ...
data['test']
```

In [2]:
with gzip.open('./signalmedia-utterances.json.gz', 'rt', encoding="utf-8") as f:
    data = json.load(f)
    
print(len(data['train']), len(data['valid']), len(data['test']))
print(data['train'][:10])

271607 15316 15062
[['<s>', 'Ally', 'Financial', 'Inc.', ' ', 'is', 'a', 'leading', 'automotive', 'financial', 'services', 'company', 'powered', 'by', 'a', 'top', 'direct', 'banking', 'franchise', '</s>'], ['<s>', 'Ally', "'s", 'automotive', 'services', 'business', 'offers', 'a', 'full', 'spectrum', 'of', 'financial', 'products', 'and', 'services', ',', 'including', 'new', 'and', 'used', 'vehicle', 'inventory', 'and', 'consumer', 'financing', ',', 'leasing', ',', 'vehicle', 'service', 'contracts', ',', 'commercial', 'loans', 'and', 'vehicle', 'remarketing', 'services', ',', 'as', 'well', 'as', 'a', 'variety', 'of', 'insurance', 'offerings', ',', 'including', 'inventory', 'insurance', ',', 'insurance', 'consultative', 'services', 'for', 'dealers', 'and', 'other', 'ancillary', 'products', '</s>'], ['<s>', 'Ally', 'Bank', ',', 'the', 'company', "'s", 'direct', 'banking', 'subsidiary', 'and', 'member', 'FDIC', ',', 'offers', 'an', 'array', 'of', 'deposit', 'products', ',', 'including', 'ce

### Vocabulary

Let's first have a look at the vocabulary: 

- we'll use the `flatten()` utility function provided by NLTK, which makes a (lazy) flat list of all the tokens in all utterances in the train data, an efficient equivalent of `[token for utterance in data['train'] for token in utterance]`
- we'll use `unk_cutoff=1` (i.e., no actual cutoff) to keep all counts initially, playing with the cutoff in a second stage
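What "lazy" means here can be sketched with the standard library, assuming that `flatten()` behaves like `itertools.chain.from_iterable`: tokens are produced one at a time, without building the full flat list in memory. The two short utterances below are made up for the illustration.

```python
from itertools import chain

# Made-up utterances standing in for data['train'].
utterances = [['<s>', 'a', '</s>'], ['<s>', 'b', 'c', '</s>']]

lazy_tokens = chain.from_iterable(utterances)  # an iterator, nothing is copied yet
first_three = [next(lazy_tokens) for _ in range(3)]  # consume the first utterance
remaining = list(lazy_tokens)  # the iterator resumes where it stopped

print(first_three)
print(remaining)
```

Because the result is an iterator, it can only be traversed once, which is all that is needed to count tokens for the vocabulary.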

In [4]:
vocab = Vocabulary(flatten(data['train']), unk_cutoff=1)

In [12]:
# Let's see what happens when we use a cut_off to select the vocabulary and how the out-of-vocabulary (OOV) rate
# evolves as we limit the vocabulary size.
#
# I'm afraid this takes a little bit of time: if too slow or painful, you can compute the OOV rate on a fraction
# only of the train and test set to get an idea.

def oov_rate(_data, _vocab) -> float:
    """
    Returns the out of vocabulary rate in the data.
    
    _data: list of utterances, not encoded with lookup()
    _vocab: vocabulary
    """
    
    nunk, ntot = 0, 0
    
    for u in _data:
        # count the '<UNK>' tokens in the encoded utterance
        nunk += sum(tok == '<UNK>' for tok in _vocab.lookup(u))
        ntot += len(u)
    
    return 100 * nunk / ntot


for cutoff in range(1,20):
    V = Vocabulary(vocab.counts, unk_cutoff=cutoff)
    oov1, oov2 = oov_rate(data['train'], V), oov_rate(data['test'], V)
    print('cutoff={:-2d}   vocsize={:6d}    %oov={:.2f}/{:.2f}'.format(cutoff, len(V), oov1, oov2))


cutoff= 1   vocsize=159561    %oov=0.00/1.07
cutoff= 2   vocsize= 88038    %oov=0.91/1.62
cutoff= 3   vocsize= 65258    %oov=1.49/2.06
cutoff= 4   vocsize= 53831    %oov=1.93/2.44
cutoff= 5   vocsize= 46327    %oov=2.31/2.77
cutoff= 6   vocsize= 41202    %oov=2.64/3.07
cutoff= 7   vocsize= 37156    %oov=2.94/3.35
cutoff= 8   vocsize= 34041    %oov=3.22/3.62
cutoff= 9   vocsize= 31447    %oov=3.49/3.89
cutoff=10   vocsize= 29375    %oov=3.72/4.10
cutoff=11   vocsize= 27552    %oov=3.96/4.31
cutoff=12   vocsize= 26021    %oov=4.17/4.51
cutoff=13   vocsize= 24664    %oov=4.38/4.69
cutoff=14   vocsize= 23473    %oov=4.57/4.88
cutoff=15   vocsize= 22408    %oov=4.76/5.05
cutoff=16   vocsize= 21462    %oov=4.94/5.20
cutoff=17   vocsize= 20623    %oov=5.12/5.34
cutoff=18   vocsize= 19917    %oov=5.27/5.49
cutoff=19   vocsize= 19208    %oov=5.43/5.65


### Counts

Fix the vocabulary and generate the counts. We will generate counts for different corpus size and see how counts evolve. Obviously, in the end, we'll take the counts from the whole corpus.


**TODO** Make a (wise) choice for the unk_cutoff value to define your vocabulary before moving on

In [5]:

# It should take 1 or 2 minutes to compute all the counts, be patient!

cutoff = 18
V = Vocabulary(vocab.counts, unk_cutoff=cutoff)

counts = NgramCounter()
counts.update([everygrams(V.lookup(s), max_len=3) for s in data['train']])


In [6]:
#
# Let's have a look at the counts
#

def count_sum(_counts, order=3):
    """
    Sum counts to return, for each n in [1, order], the number of distinct ngrams
    and their total number of occurrences.
    """

    stats = []

    # unigrams: a flat dictionary token -> count
    unigrams = _counts[1]
    stats.append([len(unigrams), sum(unigrams.values())])

    # higher orders: for each history/context, a dictionary completion -> count
    for n in range(2, order + 1):

        ndistincts, noccurrences = 0, 0

        for v in _counts[n].values():
            ndistincts += len(v)
            noccurrences += sum(v.values())

        stats.append([ndistincts, noccurrences])

    return stats

print('#utterances =', len(data['train']))
ntot = count_sum(counts, order=3)
for i in range(3):
    print('   order={}   #distinct={:8d}   #total={}'.format(i, ntot[i][0], ntot[i][1]))

#utterances = 271607
   order=0   #distinct=   19917   #total=7855042
   order=1   #distinct= 1219589   #total=7583435
   order=2   #distinct= 3636446   #total=7311828


In [None]:
#
# Just to illustrate the tables seen in the classroom, let's compute these stats 
# (number of distinct ngrams, number of ngram occurrences) for increasing corpus sizes
# 

for n in (50000, 100000, 200000):
    print('#utterances =', n)
    
    c = NgramCounter()
    c.update([everygrams(V.lookup(s), max_len=3) for s in data['train'][:n]])

    ntot = count_sum(c, order=3)
    for i in range(3):
        print('   order={}   #distinct={:8d}   #total={}'.format(i, ntot[i][0], ntot[i][1]))
    # print('   #total={}'.format(sum([x[1] for x in ntot])))

# QUESTION: What do you observe in these numbers?

### Building and comparing models

We are now ready to create a first model and evaluate it.

A quick recall on evaluation through perplexity. Perplexity is defined as
$$
P(C) = 2^{-\mbox{avg_log_prob}(C)}
$$
where $\mbox{avg_log_prob}(C)$ is the average ngram log-probability (base 2) measured on the corpus $C$, i.e.,
$$
    \mbox{avg_log_prob}(C) = \frac{1}{|\mbox{ngrams}(C)|} \sum_{hw \in \mbox{ngrams}(C)} \log_2 P[w|h] \enspace .
$$
Here, $\mbox{ngrams}(C)$ designates all occurrences of ngrams $hw$ (history $h$ followed by $w$) that occur in the corpus $C$, and $|\mbox{ngrams}(C)|$ is the number of such occurrences. The ngram occurrences can easily be obtained through `ngrams(utterance, n)` after vocabulary lookup (to map unknown tokens to `'<UNK>'`) for all utterances in $C$.
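As a quick sanity check of the definition, the following sketch computes perplexity directly from a list of ngram probabilities $P[w|h]$. The helper `perplexity_from_probs()` is a hypothetical name for this illustration and the probability values are made up.

```python
import math

def perplexity_from_probs(probs):
    """Perplexity of a list of ngram probabilities P[w|h] (hypothetical helper)."""
    if any(p == 0.0 for p in probs):
        # log2(0) is minus infinity: a single zero-probability ngram is enough
        return math.inf
    avg_log_prob = sum(math.log2(p) for p in probs) / len(probs)
    return 2.0 ** (-avg_log_prob)

uniform = perplexity_from_probs([0.25] * 8)
with_zero = perplexity_from_probs([0.5, 0.25, 0.0])
print(uniform, with_zero)
```

A model that always assigns probability 1/4 has a perplexity of 4, whatever the number of ngrams, and a single ngram with zero probability makes the perplexity infinite. Keep the second case in mind when you evaluate an unsmoothed model.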

We provide a quick and dirty implementation of perplexity hereunder.

In [8]:
def avg_logscore(_data, _vocab, _lm, order=None) -> float:
    """
    Return the average ngram log-prob of the utterances in _data
    
    Input:
      _data: input utterances on which to measure ngram log-probs
      _vocab: vocabulary (to map unknown tokens)
      _lm: language model
      
    Options:
      order=n use ngrams (defaults to LM order)
      
    Notes: This is a dirty implementation that disregards the first m-grams 
    for m < n of the utterance but that'll do the job.
    """
    
    order = _lm.order if order is None else order
    
    buf = [ngrams(_vocab.lookup(x), order) for x in _data]
    
    return mean([_lm.logscore(x[-1], x[:-1]) for x in flatten(buf)])


def perplexity(_data, _vocab, _lm, order=None) -> float:
    """
    Return the perplexity measured on the utterances in _data
    
    Input:
      _data: input utterances on which to measure ngram log-probs
      _vocab: vocabulary (to map unknown tokens)
      _lm: language model
      
    Options:
      order=n use ngrams (defaults to LM order)    
    """
    
    return pow(2.0, -avg_logscore(_data, _vocab, _lm, order))

In [9]:
# As some of the models are pretty slow, we make a small version of the test utterances to measure perplexity

C = data['test'][:1000]

**At this stage, you should be able to play on your own!**

So here's what you're asked to do:
1. Measure (on the downsized test data) the perplexity of a trigram LM (a) with no smoothing and (b) with Laplace smoothing. How do you explain the result?

In [10]:
lm = MLE(3, vocabulary=V, counter=counts)
print(perplexity(C, V, lm))

lm = Laplace(3, vocabulary=V, counter=counts)
print(perplexity(C, V, lm))

inf
4224.250426340526


2. Using absolute discounting, compare the perplexities of the bigram and trigram LMs with different discount factors, filling in the table below. What is the impact of the discount factor on the two models?

| discount | 0.4 | 0.6 | 0.8 | 1 |
|----------|-----|-----|-----|---|
| bigram   |     |     |     |   |
| trigram  |     |     |     |   |


In [11]:
for d in (0.4, 0.6, 0.8, 1):
    lm = AbsoluteDiscountingInterpolated(3, discount=d, vocabulary=V, counter=counts)
    
    print(d, perplexity(C, V, lm, order=2), perplexity(C, V, lm))

0.4 216.57293899633783 142.5137685878291
0.6 212.15880849925716 127.18462408534691
0.8 212.3927900785234 121.98059526438729
1 232.6537718597303 155.7681837866273


3. Retrain a model with absolute discounting on a smaller train set, say 50k utterances rather than the full 271k ones, and compare with what you had previously. Is the optimal discount factor the same and why? How's perplexity affected by the training dataset size?

4. Check with Witten-Bell

*The corpus is too small for Kneser-Ney so you'll probably get an error if you try that.*

In [None]:
lm = WittenBellInterpolated(3, vocabulary=V, counter=counts)
print(perplexity(C, V, lm, order=2), perplexity(C, V, lm))

### Text generation

You should now have enough experience to write your own text generator given an LM. The idea is to write a function that takes a prompt and completes it given an LM. To simplify things, we will assume your prompt to be tokenized with a tokenization that resembles the one used to prepare the data.

*Hint*: Given a history as a tuple, you can easily get the probability distribution over the vocabulary with
```
h = ('certificates', 'of')
distrib = [lm.score(w, h) for w in lm.vocab]  # iterate over the vocabulary the LM was built with
```
and use `numpy.random.choice()` to pick one word from this distribution.
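The sampling step itself can be sketched with the standard library only, using `random.choices()` instead of `numpy.random.choice()`. The helper `sample_next()` is a hypothetical name and the distribution below is made up for the illustration; in your generator it would come from the LM scores.

```python
import random

def sample_next(distribution, rng):
    """Draw one token from a {token: probability} mapping (hypothetical helper)."""
    tokens = list(distribution)
    weights = [distribution[t] for t in tokens]
    return rng.choices(tokens, weights=weights, k=1)[0]

rng = random.Random(0)  # seeded generator so that runs are reproducible

# Made-up distribution over three tokens, for illustration only.
toy_distribution = {'deposit': 0.6, 'incorporation': 0.3, '</s>': 0.1}
draws = [sample_next(toy_distribution, rng) for _ in range(1000)]
print({tok: draws.count(tok) for tok in toy_distribution})
```

Over many draws, the frequency of each token should approach its probability. This only covers picking one token: building the history from the prompt and deciding when to stop is left to you.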

Unfortunately, with NLTK, this takes a while to compute with a 20k vocabulary if you use an interpolated LM (which I strongly recommend) but it remains manageable. You may want to downsize your vocabulary if you are not patient enough.

Once you're done with your generation function, enjoy and play with it, try with various smoothing techniques or with various LM order.

In [None]:
def generate(_lm, _vocab, prompt=['<s>'], order=None, maxlen=100) -> list[str]:
    """
    Generate text from the prompt with the given LM.
    
    Input:
      _lm: language model
      _vocab: vocabulary
      
    Options:
      prompt=   prompt as a list[str]
      order=    order of the LM (defaults to LM order)
      maxlen=   stop generating after maxlen tokens
    """
    text = [token for token in prompt]
    
    # TO BE COMPLETED
    
    return text
    