# corpus bootstrapping with LM perplexity

**purpose:** here we use the perplexity of a domain-specific language model to extract in-domain sentences from a general "unlabeled" corpus that mixes in-domain and out-of-domain sentences.

this is a toy example of Ramaswamy, Printz, Gopalakrishnan: *A Bootstrap Technique for Building Domain-Dependent Language Models*  
http://mirlab.org/conference_papers/International_Conference/ICSLP%201998/PDF/SCAN/SL980611.PDF

- uses simple bigram LM with add-k smoothing
- in-domain data from Jane Austen
- 'unlabeled' corpus of Jane Austen, Lewis Carroll and Herman Melville sentences

In [1]:
from collections import Counter
import math
import nltk
import random
import re

# load in-domain corpus data

load the three jane austen texts from NLTK as tokenized sentences and preprocess them

(lowercase, remove punctuation etc, add `<s>` and `</s>` tags)

In [2]:
# read corpora
austen1 = nltk.corpus.gutenberg.sents('austen-emma.txt')
austen2 = nltk.corpus.gutenberg.sents('austen-persuasion.txt')
austen3 = nltk.corpus.gutenberg.sents('austen-sense.txt')
data = austen1 + austen2 + austen3
print(len(data))

16498


In [3]:
%%time
# shuffle data and withhold random set
indices = [i for i in range(len(data))]
random.shuffle(indices)
data = [data[i] for i in indices]

test_idx = int(len(data)*0.25)
corpus = data[:test_idx]
withheld = data[test_idx:]
data = None # clear
print(len(corpus), len(withheld))

4124 12374
CPU times: user 5.77 s, sys: 120 ms, total: 5.89 s
Wall time: 5.89 s


### filtering in-domain data by length

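keep only sentences with more than 6 tokens. the length is checked on the raw NLTK tokens (punctuation included), so a kept sentence can end up with fewer than 6 words once punctuation is removed.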

In [4]:
%%time
# preprocess data (with function)
def preprocess(tokens):
    processed = []
    for sent in tokens:
        if len(sent) > 6:
            this_sent = []
            for word in sent:
                if re.findall(r'[0-9A-Za-z]+', word):
                    this_sent.append(word.lower())
            this_sent = ['<s>'] + this_sent + ['</s>']
            processed.append(this_sent)
    return processed

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 3.58 µs


In [5]:
corpus = preprocess(corpus)
withheld = preprocess(withheld)
print("seed corpus:", len(corpus), "withheld:", len(withheld))

seed corpus: 3609 withheld: 10886
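
to make the filtering concrete, here is a minimal single-sentence sketch of the same steps. the helper name `preprocess_sentence`, the `min_tokens` parameter and the toy token lists are assumptions made for illustration; they are not part of the notebook above.

```python
import re

def preprocess_sentence(sent, min_tokens=7):
    """Return the cleaned sentence, or None if the raw sentence is too short."""
    # the length check runs on the raw tokens, punctuation included
    if len(sent) < min_tokens:
        return None
    # keep only tokens containing at least one letter or digit, lowercased
    words = [w.lower() for w in sent if re.search(r'[0-9A-Za-z]', w)]
    return ['<s>'] + words + ['</s>']

# hypothetical raw tokenization: 7 tokens, 3 of them pure punctuation
raw = ['"', 'I', 'am', ',"', 'said', 'he', '.']
print(preprocess_sentence(raw))
# ['<s>', 'i', 'am', 'said', 'he', '</s>']

# 4 raw tokens: dropped by the length filter
print(preprocess_sentence(['I', 'am', 'glad', '.']))
# None
```

this is why short-looking sentences can survive the filter: the raw token count included punctuation tokens.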


# build bigram language model

this is a simple bigram language model with rudimentary add-k (Laplace-style) smoothing. we build it by splitting each sentence into a list of bigram tuples and counting how often each tuple occurs with a `Counter`. the smoothing itself is applied later, when probabilities are computed.

we will construct this as a function so we can call it repeatedly during the main iterative process.

In [6]:
def languagemodel(corpus):
    
    # bigramize and get vocabulary
    bigrams = []
    vocab = []
    for sent in corpus:
        # split
        for i in range(len(sent)-1):
            vocab.append(sent[i])
            bigrams.append((sent[i], sent[i+1]))
        vocab.append(sent[-1])
        
    # get vocabulary size, counters
    # oovs will be handled by smoothing
    vocabsize = len(set(vocab))
    vocab_counts = Counter(vocab)
    bigram_counts = Counter(bigrams)
    
    return vocabsize, vocab_counts, bigram_counts

In [7]:
%%time
vocabsize, vocab_counts, bigram_counts = languagemodel(corpus)

CPU times: user 76 ms, sys: 0 ns, total: 76 ms
Wall time: 74.7 ms
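
as a sanity check, the counting step can be sketched on a toy corpus. this is a compact variant of the same idea using `Counter.update`, not the function above; the name `count_bigrams` and the two toy sentences are assumptions for illustration.

```python
from collections import Counter

def count_bigrams(sentences):
    """Count unigrams and adjacent-token bigrams over tokenized sentences."""
    unigram_counts = Counter()
    bigram_counts = Counter()
    for sent in sentences:
        unigram_counts.update(sent)
        # zip(sent, sent[1:]) pairs each token with its successor
        bigram_counts.update(zip(sent, sent[1:]))
    return len(unigram_counts), unigram_counts, bigram_counts

toy = [
    ['<s>', 'i', 'am', 'sure', '</s>'],
    ['<s>', 'i', 'am', 'glad', '</s>'],
]
vocabsize, unigram_counts, bigram_counts = count_bigrams(toy)
print(vocabsize)                       # 6
print(bigram_counts[('i', 'am')])      # 2
print(bigram_counts[('am', 'happy')])  # 0
```

a `Counter` returns 0 for keys it has never seen, which is what lets the smoothing step treat unseen bigrams and out-of-vocabulary words without special cases.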


# perplexity measurement

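perplexity is a function of the probability of a sentence under the language model: we take the product of the inverse probabilities of its bigrams and normalize by taking the n-th root, where n is the sentence length. lower perplexity means the sentence looks more like the training data.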

we use add-k smoothing to account for out-of-vocab terms
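
written out, the score computed by the `perplexity` function in the next cell is roughly the following. the symbols are our own notation: $c(\cdot)$ is a count from the language model, $V$ is the vocabulary size, $k$ is the smoothing constant, and $w_1 \dots w_N$ are the tokens of the sentence including `<s>` and `</s>`.

```latex
P(w_i \mid w_{i-1}) = \frac{c(w_{i-1}, w_i) + k}{c(w_{i-1}) + kV}
\qquad
\mathrm{PP}(w_1 \dots w_N) = \left( \prod_{i=2}^{N} \frac{1}{P(w_i \mid w_{i-1})} \right)^{1/N}
```

note that the product runs over the $N-1$ bigrams while the root uses the token count $N$, matching `1/len(sent)` in the code.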

In [8]:
# perplexity calculation with add-k smoothing
# from http://www3.cs.stonybrook.edu/~has/CSE594/Notes/(1)%20Probability%20Theory%20Application%201-28.pdf
def perplexity(sent, k=0.1):
    # tokenize and bigramize
    bigrams = []
    for i in range(len(sent)-1):
        bigrams.append((sent[i], sent[i+1]))
        
    # for each bigram, get probability (add-k smoothed)
    # p(w_i | w_i-1) = (count(w_i-1, w_i) + k) / (count(w_i-1) + k*V)
    probs = []
    for bigram in bigrams:
        b_c = bigram_counts[bigram] + k
        w_c = vocab_counts[bigram[0]] + k*vocabsize
        probs.append(b_c/w_c)
    
    # get inverse
    inverse = [1.0/prob for prob in probs]
    
    # get product
    product = 1.0
    for inv in inverse:
        product *= inv
    
    # n-th root
    perplexity = math.pow(product, (1/len(sent)))
    
    return perplexity
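
one possible refinement, sketched here under assumptions: for very long sentences the running product of inverse probabilities can overflow to infinity, so the same score can be accumulated in log space instead. the function name `log_perplexity` and the choice to pass the counts in as arguments (rather than reading globals, as the cell above does) are ours.

```python
import math
from collections import Counter

def log_perplexity(sent, bigram_counts, unigram_counts, vocabsize, k=0.1):
    """Add-k smoothed bigram perplexity, accumulated in log space."""
    log_inv = 0.0
    for prev, word in zip(sent, sent[1:]):
        prob = (bigram_counts[(prev, word)] + k) / (unigram_counts[prev] + k * vocabsize)
        log_inv -= math.log(prob)  # adds log(1/prob)
    # same normalization as the original: divide by the token count len(sent)
    return math.exp(log_inv / len(sent))

# toy model built from a single sentence (illustrative data only)
sent = ['<s>', 'i', 'am', 'sure', '</s>']
unigram_counts = Counter(sent)
bigram_counts = Counter(zip(sent, sent[1:]))
pp = log_perplexity(sent, bigram_counts, unigram_counts, vocabsize=5, k=0.1)
print(pp)
```

for sentences where the product does not overflow, this should agree with the product-and-root version up to floating point error.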

# external corpus

this is data from an external source that (hopefully) includes some sentences that we can use for data augmentation.

here we (artificially) create a mixed id/ood corpus by mixing our withheld data in with some text from another source. we will use herman melville and lewis carroll because they are authors relatively close in time to jane austen (amongst those included in the NLTK corpus).

for evaluation, we will label each sentence according to source

In [9]:
withheld = [('austen', s) for s in withheld]
len(withheld)

10886

In [10]:
melville = nltk.corpus.gutenberg.sents('melville-moby_dick.txt')
melville = preprocess(melville)
melville = [('melville', s) for s in melville]

carroll = nltk.corpus.gutenberg.sents('carroll-alice.txt')
carroll = preprocess(carroll)
carroll = [('carroll', s) for s in carroll]

len(melville), len(carroll)

(8213, 1360)

In [11]:
unlabeled = withheld + melville + carroll
len(unlabeled)

20459

# test: sort by perplexity score

once the sentences are scored, in-domain sentences should cluster at the low-perplexity end of the list. of course, not every *(true)* in-domain sentence will necessarily rank near the top.

In [12]:
%%time
# get perplexities
perplexities = [perplexity(s[1]) for s in unlabeled]

CPU times: user 460 ms, sys: 4 ms, total: 464 ms
Wall time: 465 ms


# look at closest-correlated sentences

if we sort by perplexity score (lowest first), the top sentences should be the ones 'closest' to the known jane austen sentences. as the output shows, we don't do too poorly.

In [13]:
# sort by perplexity (lower = better)
[(x[0], ' '.join(x[1]), y) for x, y in sorted(zip(unlabeled, perplexities), key=lambda pair: pair[1])][:20]

[('austen', '<s> i am sure i had not </s>', 17.193319524167045),
 ('austen', '<s> i have no doubt of it </s>', 18.244950322894468),
 ('melville', '<s> i am sure that i did not </s>', 19.135558787152608),
 ('austen', '<s> i am sure you will not like him </s>', 23.203222329802095),
 ('austen', '<s> i am glad of it </s>', 23.425978832850518),
 ('austen', '<s> i should be sorry to be more </s>', 24.39050124746715),
 ('melville', '<s> you must have heard of it </s>', 27.21363298314665),
 ('austen', '<s> that i am sure you would </s>', 28.750255105743225),
 ('austen', '<s> she had a great wish to see him </s>', 29.522205060675578),
 ('austen', '<s> as to that i do not </s>', 31.522506322212063),
 ('austen', '<s> i am said he </s>', 32.0885518355778),
 ('austen', '<s> i do not pretend to it </s>', 32.159353910701086),
 ('austen', '<s> no that he is not </s>', 32.86057514751998),
 ('austen', '<s> no indeed i do not </s>', 33.00707727014687),
 ('austen', '<s> i dare say she had </s>', 33.133797

# iterate

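this is meant to be an iterative algorithm: in each round we add the lowest-perplexity sentences (those below a threshold, up to a per-round cutoff) to the original training data, build a new language model, and recompute perplexity scores over the remaining outside data. we stop when no sentence falls below the threshold.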

In [14]:
iters = 500

add_corpus = corpus[:]       # the expanding id-corpus
rem_unlabeled = unlabeled[:] # the shrinking unlabeled data
additions = []               # track additions to lm corpus
threshhold = 100.0           # perplexity threshhold
cutoff = 25                  # cutoff for added sents, make large to 'ignore'
k = 0.001                    # smoothing term for perplexity add-k smoothing

for i in range(iters):
    
    # EarlyStopping
    cnt = 0
    
    # indices to remove from unlabeled data
    remove_idx = []
    
    # build language model
    vocabsize, vocab_counts, bigram_counts = languagemodel(add_corpus)
    
    # get perplexities
    perplexities = [perplexity(s[1], k) for s in rem_unlabeled]
    
    # indices, sort perplexities
    indices = [i for i in range(len(perplexities))]
    sorted_perplexities = [(x, y) for x, y in sorted(zip(indices, perplexities), key=lambda pair: pair[1])]
    
    # take top sents
    add = 0
    for jdx, tup in enumerate(sorted_perplexities):
        idx = tup[0]
        perp = tup[1]
        if perp < threshhold:
            additions.append(rem_unlabeled[idx])
            add_corpus.append(rem_unlabeled[idx][1])
            remove_idx.append(idx)
            cnt += 1
            add += 1
        if add == cutoff:
            break
    
    # filter out additions
    rem_unlabeled = [rem_unlabeled[i] for i in range(len(rem_unlabeled)) if i not in remove_idx]
    
    # if no added sents, terminate
    if cnt == 0:
        print("no added sentences, stopping...\n")
        break
    
    if i > 0 and i % 10 == 0:
        print("iter", i, ": total added", len(additions), "sents")
            

iter 10 : total added 275 sents
iter 20 : total added 525 sents
iter 30 : total added 775 sents
iter 40 : total added 1025 sents
iter 50 : total added 1275 sents
iter 60 : total added 1525 sents
no added sentences, stopping...
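
the selection step inside the loop can be isolated as a small function, which makes it easier to test on its own. this is a sketch under assumptions: the name `select_candidates` and the toy scores are illustrative, and a `set` is used for the removal step so that membership checks stay fast.

```python
def select_candidates(scores, threshold, cutoff):
    """Return indices of up to `cutoff` lowest scores that are below `threshold`."""
    ranked = sorted(range(len(scores)), key=lambda idx: scores[idx])
    selected = []
    for idx in ranked:
        # scores are visited in ascending order, so once one fails
        # the threshold nothing later can qualify
        if scores[idx] >= threshold or len(selected) == cutoff:
            break
        selected.append(idx)
    return selected

scores = [120.0, 15.5, 99.9, 42.0, 300.0]
picked = select_candidates(scores, threshold=100.0, cutoff=2)
print(picked)     # [1, 3]

# remove the picked items from the pool
keep_out = set(picked)
remaining = [s for i, s in enumerate(scores) if i not in keep_out]
print(remaining)  # [120.0, 99.9, 300.0]
```

an empty result from `select_candidates` corresponds to the "no added sentences" stopping condition.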



## evaluation

looking at the top results from the extracted sentences, the results subjectively look good at differentiating jane austen from the other texts. but we can also see that the strongest 'features' are perhaps rather trivial: it looks like ('<s>', 'i'), ('i', 'am'), ('i', 'do'), and ('am', 'sure') are probably more strongly associated with austen's works than with the others. we could confirm this experimentally by examining counts. a better language model should also help diversify our results. still, this serves to demonstrate the technique.
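
a minimal sketch of what "examining counts" could look like: compute the relative frequency of each bigram per author and compare the suspected features side by side. the helper `bigram_rel_freq` and the toy sentences are made up for illustration; on the real data one would group the labeled `unlabeled` pairs by author instead.

```python
from collections import Counter

def bigram_rel_freq(sentences):
    """Relative frequency of each bigram within one group of tokenized sentences."""
    counts = Counter()
    for sent in sentences:
        counts.update(zip(sent, sent[1:]))
    total = sum(counts.values())
    return {bg: c / total for bg, c in counts.items()}

# toy stand-ins for the labeled data (not real corpus sentences)
by_author = {
    'austen': [['<s>', 'i', 'am', 'sure', '</s>'], ['<s>', 'i', 'do', 'not', '</s>']],
    'melville': [['<s>', 'the', 'whale', 'rose', '</s>'], ['<s>', 'i', 'am', 'here', '</s>']],
}
freqs = {author: bigram_rel_freq(sents) for author, sents in by_author.items()}

for bg in [('<s>', 'i'), ('i', 'am'), ('i', 'do'), ('am', 'sure')]:
    print(bg, {author: round(f.get(bg, 0.0), 3) for author, f in freqs.items()})
```

bigrams whose relative frequency is much higher for austen than for the other authors would be the ones driving the low perplexity scores.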

In [16]:
add_perplexities = [perplexity(s[1]) for s in additions]
[(x[0], ' '.join(x[1]), y) for x, y in sorted(zip(additions, add_perplexities), key=lambda pair: pair[1])][:20]

[('austen', '<s> i am sure i had not </s>', 12.756553442828348),
 ('austen', '<s> i have no doubt of it </s>', 12.964209650994526),
 ('melville', '<s> i am sure that i did not </s>', 14.638938101079416),
 ('austen', '<s> i am glad of it </s>', 14.951632549419598),
 ('austen', '<s> i am sure you will not like him </s>', 15.053524838875829),
 ('austen', '<s> i am said he </s>', 19.31944247856956),
 ('austen', '<s> i should be sorry to be more </s>', 19.65881676454139),
 ('melville', '<s> you must have heard of it </s>', 19.68569980998823),
 ('austen', '<s> that i am sure you would </s>', 21.715270354841675),
 ('austen', '<s> i could not have believed it </s>', 21.958631490355476),
 ('austen', '<s> i do not think they were </s>', 21.984407431754054),
 ('austen', '<s> i do not pretend to it </s>', 22.35103112504964),
 ('austen', '<s> no indeed i do not </s>', 22.447764756276353),
 ('austen', '<s> she had a great wish to see him </s>', 22.555253968187362),
 ('austen', '<s> what do you think

In [17]:
diff = len(add_corpus) - len(corpus)
totl = len(unlabeled)
print("sents found:", diff, "(%", diff*100/totl, "of unlabeled)")

sents found: 1646 (% 8.045359010704335 of unlabeled)


In [19]:
punlabeled = len(withheld)/len(withheld + melville + carroll)
print("percentage of trues in unlabeled:", punlabeled)

percentage of trues in unlabeled: 0.5320885673786597


In [18]:
labels = [t[0] for t in additions]
corrects = [t[0] for t in additions if t[0]=='austen']
print("precision of found sents: ", len(corrects)/len(labels))

precision of found sents:  0.9270959902794653


In [20]:
recall = len(corrects)/len(withheld)
print("recall of unlabeled austen sents: ", recall)

recall of unlabeled austen sents:  0.14018004776777512
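
the two metrics above can be wrapped in one helper so they are always computed from the same set of hits. this is a sketch; the name `precision_recall`, its parameters and the toy labels are assumptions, not part of the notebook.

```python
def precision_recall(selected_labels, target, n_target_total):
    """Precision and recall of a selection with respect to one target label."""
    hits = sum(1 for label in selected_labels if label == target)
    # guard against empty inputs instead of dividing by zero
    precision = hits / len(selected_labels) if selected_labels else 0.0
    recall = hits / n_target_total if n_target_total else 0.0
    return precision, recall

# toy example: 4 selected sentences, 3 truly in-domain, 10 in-domain overall
selected = ['austen', 'austen', 'melville', 'austen']
p, r = precision_recall(selected, target='austen', n_target_total=10)
print(p, r)  # 0.75 0.3
```

with the notebook's variables, the call would look like `precision_recall([t[0] for t in additions], 'austen', len(withheld))`.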
