# corpus bootstrapping with LM perplexity

toy test of Ramaswamy, Printz, Gopalakrishnan: *A Bootstrap Technique for Building Domain-Dependent Langauge Models*
http://mirlab.org/conference_papers/International_Conference/ICSLP%201998/PDF/SCAN/SL980611.PDF

- uses simple bigram LM with add-k smoothing
- in-domain data from Jane Austen
- 'unlabeled' corpus of Austen,Carroll and Melville sentences

In [1]:
from collections import Counter
import math
import nltk
import random
import re

# in-domain corpus data

load the three jane austen texts as tokenized sentences, preprocess (lowercase, remove punctuation etc, add `<s>` and `</s>` tags)

In [2]:
# read corpora
austen1 = nltk.corpus.gutenberg.sents('austen-emma.txt')
austen2 = nltk.corpus.gutenberg.sents('austen-persuasion.txt')
austen3 = nltk.corpus.gutenberg.sents('austen-sense.txt')
data = austen1 + austen2 + austen3
print(len(data))

16498


In [3]:
%%time
# shuffle data and withhold random set
indices = [i for i in range(len(data))]
random.shuffle(indices)
data = [data[i] for i in indices]

test_idx = int(len(data)*0.25)
corpus = data[:test_idx]
withheld = data[test_idx:]
data = None # clear
print(len(corpus), len(withheld))

4124 12374
CPU times: user 5.88 s, sys: 92 ms, total: 5.97 s
Wall time: 5.97 s


## preprocessing

remove sents of len < 5 (words)

In [4]:
%%time
# preprocess data (with function)
def preprocess(tokens):
    processed = []
    for sent in tokens:
        if len(sent) > 4:
            this_sent = []
            for word in sent:
                if re.findall(r'[0-9A-Za-z]+', word):
                    this_sent.append(word.lower())
            this_sent = ['<s>'] + this_sent + ['</s>']
            processed.append(this_sent)
    return processed

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 3.58 µs


In [5]:
corpus = preprocess(corpus)
withheld = preprocess(withheld)
print("seed corpus:", len(corpus), "withheld:", len(withheld))

seed corpus: 3858 withheld: 11600


# build bigram language model

we will constuct this as a function so we can iterate

In [6]:
def languagemodel(corpus):
    
    # bigramize and get vocabulary
    bigrams = []
    vocab = []
    for sent in corpus:
        # split
        for i in range(len(sent)-1):
            vocab.append(sent[i])
            bigrams.append((sent[i], sent[i+1]))
        vocab.append(sent[-1])
        
    # get vocabulary size, counters
    # oovs will be handled by smoothing
    vocabsize = len(set(vocab))
    vocab_counts = Counter(vocab)
    bigram_counts = Counter(bigrams)
    
    return vocabsize, vocab_counts, bigram_counts

In [7]:
%%time
vocabsize, vocab_counts, bigram_counts = languagemodel(corpus)

CPU times: user 68 ms, sys: 12 ms, total: 80 ms
Wall time: 78.8 ms


# perplexity measurement

this is a function of the probability of a given sentence under this language model's assumptions (i.e the product of the probability of each bigram normalized by the total sentence length)

we use add-k smoothing to account for out-of-vocab terms

In [8]:
# perplexity calculation with add-k smoothing
# from http://www3.cs.stonybrook.edu/~has/CSE594/Notes/(1)%20Probability%20Theory%20Application%201-28.pdf
def perplexity(sent, k=0.1):
    # tokenize and bigramize
    bigrams = []
    for i in range(len(sent)-1):
        bigrams.append((sent[i], sent[i+1]))
        
    # for each bigram, get probability (add-one smoothed)
    # p(w_i | w_i-1) = count(w_i-1, w_i) + 1 / count(w_i-1 + N)
    probs = []
    for bigram in bigrams:
        b_c = bigram_counts[bigram] + k
        w_c = vocab_counts[bigram[0]] + k*vocabsize
        probs.append(b_c/w_c)
    
    # get inverse
    inverse = [1.0/prob for prob in probs]
    
    # get product
    product = 1.0
    for inv in inverse:
        product *= inv
    
    # n-th root
    perplexity = math.pow(product, (1/len(sent)))
    
    return perplexity

# external corpus

this is data from an external source that (hopefully) includes some sentences that we can use for data augmentation.

here we (artificially) create a mixed id/ood corpus by mixing our withheld data in with some text from another source. we will use moby dick because it is one of the NLTK prose texts closer in time to Jane Austen

for testing, we will label each sentence according to source

In [9]:
withheld = [('austen', s) for s in withheld]
len(withheld)

11600

In [10]:
melville = nltk.corpus.gutenberg.sents('melville-moby_dick.txt')
melville = preprocess(melville)
melville = [('melville', s) for s in melville]

carroll = nltk.corpus.gutenberg.sents('carroll-alice.txt')
carroll = preprocess(carroll)
carroll = [('carroll', s) for s in carroll]

len(melville), len(carroll)

(8857, 1514)

In [11]:
unlabeled = withheld + melville + carroll
len(unlabeled)

21971

# sort by perplexity score

as we can see, in-domain answers are at the top. of course it is not the case that necessarily all *(true)* in-domain sentences are at the top of the list.

In [12]:
%%time
# get perplexities
perplexities = [perplexity(s[1]) for s in unlabeled]

CPU times: user 448 ms, sys: 0 ns, total: 448 ms
Wall time: 445 ms


In [13]:
# sort by perplexity (lower = better)
[(x[0], ' '.join(x[1]), y) for x, y in sorted(zip(unlabeled, perplexities), key=lambda pair: pair[1])]

[('austen', '<s> i do not know </s>', 14.729173304674147),
 ('austen', '<s> i do not know </s>', 14.729173304674147),
 ('austen', '<s> i do not know </s>', 14.729173304674147),
 ('austen', '<s> i am sure i do not know </s>', 16.39012305889204),
 ('carroll', '<s> i had not </s>', 17.045270817412888),
 ('austen', '<s> she had seen him </s>', 18.171172050776057),
 ('austen', '<s> i am sure i had not </s>', 18.19799968118952),
 ('melville', '<s> i am sure that i did not </s>', 19.704061956263423),
 ('austen', '<s> i do not see it </s>', 19.839176305286784),
 ('austen', '<s> i think it is so </s>', 20.05698120009778),
 ('austen', '<s> i am very happy </s>', 20.77234144496522),
 ('austen', '<s> i am sure you will not like him </s>', 20.95163418579466),
 ('melville', '<s> i will tell you </s>', 21.454686379053655),
 ('austen', '<s> yes i do </s>', 22.640216351337635),
 ('melville', '<s> but it is not so </s>', 22.66385587291757),
 ('austen', '<s> i am said he </s>', 22.958280323796355),
 ('au

# iterate

this is meant to be an iterative algorithm, so we add the top sentences (using threshold) to the original training data, make a new language model, and calculate new perplexity scores over the outside data.

In [14]:
iters = 500

add_corpus = corpus[:]       # the expanding id-corpus
rem_unlabeled = unlabeled[:] # the shrinking unlabeled data
additions = []               # track additions to lm corpus
threshhold = 100.0           # perplexity threshhold
cutoff = 25                  # cutoff for added sents, make large to 'ignore'
k = 0.001                    # smoothing term for perplexity add-k smoothing

for i in range(iters):
    
    # EarlyStopping
    cnt = 0
    
    # indices to remove from unlabeled data
    remove_idx = []
    
    # build language model
    vocabsize, vocab_counts, bigram_counts = languagemodel(add_corpus)
    
    # rank perplexities
    perplexities = [perplexity(s[1], k) for s in rem_unlabeled]
    
    # take top sents
    add = 0
    for idx, perp in enumerate(perplexities):
        if perp < threshhold:
            additions.append(rem_unlabeled[idx])
            add_corpus.append(rem_unlabeled[idx][1])
            remove_idx.append(idx)
            cnt += 1
            add += 1
        if add == cutoff:
            break
    
    # filter out additions
    rem_unlabeled = [rem_unlabeled[i] for i in range(len(rem_unlabeled)) if i not in remove_idx]
    
    # if no added sents, terminate
    if cnt == 0:
        print("no added sentences, stopping...\n")
        debug = [(x[0], ' '.join(x[1]), y) for x, y in sorted(zip(rem_unlabeled, perplexities), key=lambda pair: pair[1])][:100]
        for d in debug[:10]:
            print(d)
        break
    
    if i > 0 and i % 10 == 0:
        print("iter", i, ": total added", len(additions), "sents")
            

iter 10 : total added 275 sents
iter 20 : total added 525 sents
iter 30 : total added 775 sents
iter 40 : total added 1025 sents
iter 50 : total added 1275 sents
iter 60 : total added 1525 sents
iter 70 : total added 1775 sents
iter 80 : total added 2025 sents
no added sentences, stopping...

('melville', '<s> i do not think that any sensation lurks in it </s>', 100.19508495878277)
('austen', '<s> the consequence said she has been a state of perpetual suffering to me and so it ought </s>', 100.55487042161359)
('austen', '<s> no no no how can you suspect me of such a thing </s>', 100.73587402614565)
('austen', '<s> there was much to regret </s>', 100.73615839319187)
('melville', '<s> he snorts to think of it </s>', 100.84287118187619)
('austen', '<s> it is time for us to be going indeed </s>', 100.98943711050511)
('austen', '<s> said i i cannot think who you mean </s>', 101.22991326463982)
('austen', '<s> i hope i know better now than to care for mr martin or to be suspected of it </s>'

## evaluation

In [15]:
diff = len(add_corpus) - len(corpus)
totl = len(unlabeled)
print("sents found:", diff, "(%", diff*100/totl, "of unlabeled)")

sents found: 2157 (% 9.817486686996496 of unlabeled)


In [16]:
labels = [t[0] for t in additions]
corrects = [t[0] for t in additions if t[0]=='austen']
print("precision of found sents: ", len(corrects)/len(labels))

precision of found sents:  0.9058878071395456


In [17]:
recall = len(corrects)/len(withheld)
print("recall of unlabeled austen sents: ", recall)

recall of unlabeled austen sents:  0.16844827586206895
