# Next Word Prediction

Network used if n=3:
<img src="PrototypeBN.png" alt="BN" width="50%" />

Below is all the code to create and parse the network:

In [1]:
#
# Predict the next word of a sentence using a Bayesian Network. Learn the
# probabilities from Gutenburg Project text documents.
#
import re
import sys
import argparse
import pathlib
import operator
import pomegranate
import readline

#
# Clean up a paragraph
#
def cleanup(text):
    text = text.strip()
    # Convert dashes and line returns to spaces
    text = re.sub('(--|\n)', ' ', text)
    # Only allow case-insensitive a-z in words
    text = re.sub('[^a-zA-Z \'-]+', '', text)
    return text

#
# Find all the words and the next word from a Gutenburg Project text document
# and use these to construct a Bayesian network
#
def constructNetworkFromFiles(filenames, number):
    assert number > 0, "n > 0"

    paragraphs = []

    # Parse each of the files
    for filename in filenames:
        p = pathlib.Path(filename)
        assert p.is_file(), "Must specify valid file"

        # Get all the book data from this one file
        data = ""

        try:
            with p.open() as f:
                start = False
                end   = False

                for line in f:
                    if re.match("^\*\*\* END OF THIS PROJECT", line):
                        end = True

                    # Only use this line if it's between the start and end
                    if start and not end:
                        data += line

                    # Find the start and end of the document
                    if re.match("^\*\*\* START OF THIS PROJECT", line):
                        start = True
        except OSError:
            print("Could not open file")

        # Split by paragraphs
        word_usage = [{} for i in range(0, number)]
        paragraphs += data.split('\n\n')

    # Using the paragraphs obtained from all the files, look for word sequences
    for paragraph in paragraphs:
        if paragraph:
            # Split by words
            words = cleanup(paragraph).split(' ')

            # For each length sequence of words (i.e., if n=3, for a sequence
            # of length 1, 2, and 3)
            for num in range(1, number+1):
                # Look at a sliding window of the specified number of words
                # finding the frequency of each sequenc
                for i in range(num, len(words)):
                    s = ','.join([words[i-(num-j)] for j in range(0, num)])

                    if s in word_usage[num-1]:
                        word_usage[num-1][s] += 1
                    else:
                        word_usage[num-1][s] = 1

    # Total number to divide by to compute the probabilities
    total_freq = [sum(d.values()) for d in word_usage]

    # Construct the network
    word_distributions = []

    # First word, only one that isn't a CPD. Just take the frequency and divide
    # by the total to get probabilities.
    word_distributions.append(pomegranate.DiscreteDistribution(
        { key: value/total_freq[0] for (key, value) in word_usage[0].items()}))

    # The other words (2, 3, etc.)
    for i in range(1, number):
        word_distributions.append(pomegranate.ConditionalProbabilityTable(
            # CPT: word1, ..., this_word, prob(this_word)
            [key.split(',')+[value] for (key, value) in word_usage[i].items()],
            # depends on all words up to this word
            [word_distributions[j] for j in range(0, i)]))

    # Create all the nodes
    nodes = []

    for i in range(0, number):
        nodes.append(pomegranate.State(word_distributions[i],
            name="word"+str(i+1)))

    network = pomegranate.BayesianNetwork("Word Prediction")
    network.add_states(nodes)

    # Add edges from current word to all later words
    for i in range(0, number):
        for j in range(i+1, number):
            network.add_transition(nodes[i], nodes[j])

    network.bake()

    return network

#
# Make a prediction of the next word based on a string of words
#
def predict(network, number, s):
    words = s.split(" ")
    num = number

    # Start off with the desired number-1 of observations, but if we can't
    # predict based on that, keep decreasing number until we can or until we
    # find out we just can't predict.
    while num > 1:
        observations = {}

        # Start at the last word and look backwards for our observations. But,
        # to predict, we need the last node to be the next word, so don't
        # observe all number of words in our network, just observe n-1 words.
        for i, word in enumerate(words[-(num-1):]):
            observations["word"+str(i+1)] = word

        # Update the beliefs in the network
        try:
            if debug:
                print("Input string:", s)
                print("Observations:", observations)

            marginals = network.predict_proba(observations)

            # Which node do we want to look at? If there are no previous words,
            # look at the first one. If there was one, look at the second, etc.
            # If there was n-1 or more, look at the last one.
            prediction_node = min(len(words), num-1)

            # Look at the node we care about to read off the new probability
            # distribution
            prediction = marginals[prediction_node].parameters[0]

            # Sort the predictions in descending order
            sorted_prediction = sorted(prediction.items(),
                    key=operator.itemgetter(1), reverse=True)

            return sorted_prediction

        # If our observations "aren't possible," i.e. they didn't occur in the
        # training data then try again using one fewer word as an observation,
        # but only if we haven't already used that number of words
        except ZeroDivisionError:
            if len(words) >= number-1:
                num -= 1
            else:
                break
        except ValueError:
            if len(words) >= number-1:
                num -= 1
            else:
                break

    # Couldn't make a prediction
    return []

#
# Print the predictions
#
# Assuming input is sorted in descending order of most probable to least
# probable
#
def printPrediction(predictions):
    num = min(25, len(predictions))

    if num == 0:
        print("No prediction found")
    else:
        # Print out the top ones that are greater than zero
        for i in range(0, num):
            if predictions[i][1] > 0:
                print(" ", predictions[i], sep="")

Now let's create the network from one of these documents:

In [2]:
global debug
debug = True
files = [
    'alices_adventures_in_wonderland.txt'
]
number = 3

# Construct the network based on the input training files
network = constructNetworkFromFiles(files[-1:], number)

# Simplify calls below
def p(s):
    printPrediction(predict(network, number, s))

Now let's do some predicting:

In [3]:
p("this is")

Input string: this is
Observations: {'word2': 'is', 'word1': 'this'}
 ('May', 1.0)


In [4]:
p("why did")

Input string: why did
Observations: {'word2': 'did', 'word1': 'why'}
 ('they', 1.0)


In [5]:
p("they did not")

Input string: they did not
Observations: {'word2': 'not', 'word1': 'did'}
 ('like', 0.22222222222222196)
 ('quite', 0.07407407407407404)
 ('get', 0.07407407407407404)
 ('at', 0.07407407407407404)
 ('dare', 0.07407407407407404)
 ('venture', 0.07407407407407404)
 ('much', 0.03703703703703708)
 ('look', 0.03703703703703708)
 ('seem', 0.03703703703703708)
 ('appear', 0.03703703703703708)
 ('come', 0.03703703703703708)
 ('feel', 0.03703703703703708)
 ('notice', 0.03703703703703708)
 ('sneeze', 0.03703703703703708)
 ('answer', 0.03703703703703708)
 ('wish', 0.03703703703703708)
 ('see', 0.03703703703703708)


In [6]:
p("four score and")

Input string: four score and
Observations: {'word2': 'and', 'word1': 'score'}
Input string: four score and
Observations: {'word1': 'and'}
 ('the', 0.4845312646041715)
 ('she', 0.1728198897093195)
 ('then', 0.0732778764370508)
 ('was', 0.031965604262080775)
 ('a', 0.023927469856996182)
 ('Alice', 0.021030002804000518)
 ('as', 0.015795868772782582)
 ('began', 0.01570240209365373)
 ('said', 0.014580801944107002)
 ('looked', 0.011309468174595827)
 ('went', 0.010281334704178016)
 ('all', 0.007570801009440203)
 ('I', 0.005981867464249034)
 ('they', 0.005981867464249034)
 ('he', 0.005981867464249034)
 ('when', 0.005981867464249013)
 ('it', 0.004579867277315666)
 ('this', 0.004579867277315657)
 ('had', 0.004579867277315649)
 ('there', 0.0033648004486400767)
 ('that', 0.0033648004486400767)
 ('in', 0.0033648004486400767)
 ('very', 0.0028040003738667297)
 ('after', 0.002336666978222274)
 ('found', 0.002336666978222274)


In [7]:
p("with much")

Input string: with much
Observations: {'word2': 'much', 'word1': 'with'}
Input string: with much
Observations: {'word1': 'much'}
 ('as', 0.1428571428571425)
 ('of', 0.1428571428571425)
 ('surprised', 0.06349206349206363)
 ('the', 0.06349206349206363)
 ('frightened', 0.06349206349206363)
 ('at', 0.06349206349206363)
 ('accustomed', 0.015873015873015876)
 ('overcome', 0.015873015873015876)
 ('out', 0.015873015873015876)
 ('to', 0.015873015873015876)
 ('pleased', 0.015873015873015876)
 ('already', 0.015873015873015876)
 ('so', 0.015873015873015876)
 ('confused', 0.015873015873015876)
 ('matter', 0.015873015873015876)
 ('larger', 0.015873015873015876)
 ('under', 0.015873015873015876)
 ("right'", 0.015873015873015876)
 ('evidence', 0.015873015873015876)
 ('pepper', 0.015873015873015876)
 ('care', 0.015873015873015876)
 ('farther', 0.015873015873015876)
 ('about', 0.015873015873015876)
 ('if', 0.015873015873015876)
 ('use', 0.015873015873015876)


In [8]:
p("and then")

Input string: and then
Observations: {'word2': 'then', 'word1': 'and'}
 ('said', 0.07142857142857137)
 ('the', 0.07142857142857137)
 ('sat', 0.03571428571428575)
 ("'we", 0.03571428571428575)
 ("I'll", 0.03571428571428575)
 ('dipped', 0.03571428571428575)
 ('and', 0.03571428571428575)
 ('they', 0.03571428571428575)
 ('added', 0.03571428571428575)
 ('keep', 0.03571428571428575)
 ('another', 0.03571428571428575)
 ('she', 0.03571428571428575)
 ('turned', 0.03571428571428575)
 ('hurried', 0.03571428571428575)
 ('after', 0.03571428571428575)
 ('all', 0.03571428571428575)
 ('quietly', 0.03571428571428575)
 ('if', 0.03571428571428575)
 ('Alice', 0.03571428571428575)
 ('such', 0.03571428571428575)
 ('at', 0.03571428571428575)
 ('a', 0.03571428571428575)
 ('raised', 0.03571428571428575)
 ('treading', 0.03571428571428575)
 ('nodded', 0.03571428571428575)
