In [1]:
import re
import math
from scipy.stats import poisson
import itertools

In [2]:
import findspark
import os
findspark.init()
import pyspark
sc = pyspark.SparkContext()
sc.setLogLevel('ERROR')

In [3]:
n_partitions = 6
MAX_EDIT_DISTANCE = 3

In [4]:
######################
#
# DOCUMENTATION HERE
#
######################

# number of partitions to be used
n_partitions = 6
MAX_EDIT_DISTANCE = 3

def get_n_deletes_list(w, n):
    '''given a word, derive list of strings with up to n characters deleted'''
    # since this list is generally of the same magnitude as the number of 
    # characters in a word, it may not make sense to parallelize this
    # so we use python to create the list
    deletes = []
    queue = [w]
    for d in range(n):
        temp_queue = []
        for word in queue:
            if len(word)>1:
                for c in range(len(word)):  # character index
                    word_minus_c = word[:c] + word[c+1:]
                    if word_minus_c not in deletes:
                        deletes.append(word_minus_c)
                    if word_minus_c not in temp_queue:
                        temp_queue.append(word_minus_c)
        queue = temp_queue
        
    return deletes

def get_transitions(sentence):
    if len(sentence)<2:
        return None
    else:
        return [((sentence[i], sentence[i+1]), 1) for i in range(len(sentence)-1)]
    
def map_transition_prob(x):
    vals = x[1]
    total = float(sum(vals.values()))
    probs = {k: math.log(v/total) for k, v in vals.items()}
    return (x[0], probs)

def parallel_create_dictionary(fname):
    '''
    Create dictionary, start probabilities and transition
    probabilities using Spark RDDs.
    '''
    # we generate and count all words for the corpus,
    # then add deletes to the dictionary
    # this is a slightly different approach from the SymSpell algorithm
    # that may be more appropriate for Spark processing
    
    print 'Creating dictionary...'
    
    ############
    #
    # load file & initial processing
    #
    ############
    
    # http://stackoverflow.com/questions/22520932/python-remove-all-non-alphabet-chars-from-string
    regex = re.compile('[^a-z ]')

    # convert file into one long sequence of words
    make_all_lower = sc.textFile(fname) \
            .map(lambda line: line.lower()) \
            .filter(lambda x: x!='').cache()
    
    # split into individual sentences and remove other punctuation
    split_sentence = make_all_lower.flatMap(lambda line: line.split('.')) \
            .map(lambda sentence: regex.sub(' ', sentence)) \
            .map(lambda sentence: sentence.split()).cache()
    
    ############
    #
    # generate start probabilities
    #
    ############
    
    # only focus on words at the start of sentences
    start_words = split_sentence.map(lambda sentence: sentence[0] if len(sentence)>0 else None) \
        .filter(lambda word: word!=None)
    
    # add a count to each word
    count_start_words_once = start_words.map(lambda word: (word, 1)).cache()

    # use accumulator to count the number of words at the start of sentences
    accum_total_start_words = sc.accumulator(0)
    count_total_start_words = count_start_words_once.foreach(lambda x: accum_total_start_words.add(1))
    total_start_words = float(accum_total_start_words.value)
    
    # reduce into count of unique words at the start of sentences
    unique_start_words = count_start_words_once.reduceByKey(lambda a, b: a + b, numPartitions = n_partitions)
    
    # convert counts to probabilities
    start_prob_calc = unique_start_words.mapValues(lambda v: math.log(v/total_start_words))
    
    # get default start probabilities (for words not in corpus)
    default_start_prob = math.log(1/total_start_words)
    
    # store start probabilities as a dictionary (will be used as a lookup table)
    start_prob = start_prob_calc.collectAsMap()
    
    ############
    #
    # generate transition probabilities
    #
    ############
    
    # focus on continuous word pairs within the sentence
    # e.g. "this is a test" -> "this is", "is a", "a test"
    # note: as the relevant probability is P(word|previous word)
    # the tuples are ordered as (previous word, word)
    other_words = split_sentence.map(lambda sentence: get_transitions(sentence)).filter(lambda x: x!=None). \
                flatMap(lambda x: x).cache()

    # use accumulator to count the number of transitions
    accum_total_other_words = sc.accumulator(0)
    count_total_other_words = other_words.foreach(lambda x: accum_total_other_words.add(1))
    total_other_words = float(accum_total_other_words.value)
    
    # reduce into count of unique word pairs
    unique_other_words = other_words.reduceByKey(lambda a, b: a + b, numPartitions = n_partitions)
    
    # aggregate by previous word
    # i.e. (previous word, [(word1, word1-previous word count), (word2, word2-previous word count), ...])
    other_words_collapsed = unique_other_words.map(lambda x: (x[0][0], (x[0][1], x[1]))).groupByKey().mapValues(dict)

    # POTENTIAL OPTIMIZATION: FIND AN ALTERNATIVE TO GROUPBYKEY (CREATES ~9.3MB SHUFFLE)
    
    # convert counts to probabilities
    transition_prob_calc = other_words_collapsed.map(lambda x: map_transition_prob(x))
    
    # get default transition probabilities (for word pairs not in corpus)
    default_transition_prob = math.log(1/total_other_words)
    
    # store transition probabilities as dictionary (will be used as lookup table)
    transition_prob = transition_prob_calc.collectAsMap()
    
    ############
    #
    # process corpus for dictionary
    #
    ############
    
    replace_nonalphs = make_all_lower.map(lambda line: regex.sub(' ', line))
    all_words = replace_nonalphs.flatMap(lambda line: line.split())

    # create core corpus dictionary (i.e. only words appearing in file, no "deletes") and cache it
    # output RDD of unique_words_with_count: [(word1, count1), (word2, count2), (word3, count3)...]
    count_once = all_words.map(lambda word: (word, 1))
    unique_words_with_count = count_once.reduceByKey(lambda a, b: a + b, numPartitions = n_partitions).cache()
    
    ############
    #
    # generate deletes list
    #
    ############
    
    # generate list of n-deletes from words in a corpus of the form: [(word1, count1), (word2, count2), ...]
     
    assert MAX_EDIT_DISTANCE > 0  
    
    generate_deletes = unique_words_with_count.map(lambda (parent, count): 
                                                   (parent, get_n_deletes_list(parent, MAX_EDIT_DISTANCE)))
    expand_deletes = generate_deletes.flatMapValues(lambda x: x)
    swap = expand_deletes.map(lambda (orig, delete): (delete, ([orig], 0)))
   
    ############
    #
    # combine delete elements with main dictionary
    #
    ############
    
    corpus = unique_words_with_count.mapValues(lambda count: ([], count))
    combine = swap.union(corpus)  # combine deletes with main dictionary, eliminate duplicates
    
    # since the dictionary will only be a lookup table once created, we can
    # pass on as a Python dictionary rather than RDD by reducing locally and
    # avoiding an extra shuffle from reduceByKey
    dictionary = combine.reduceByKeyLocally(lambda a, b: (a[0]+b[0], a[1]+b[1]))

    words_processed = unique_words_with_count.map(lambda (k, v): v).reduce(lambda a, b: a + b)
    word_count = unique_words_with_count.count()   
    
    # output stats
    print 'Total words processed: %i' % words_processed
    print 'Total unique words in corpus: %i' % word_count 
    print 'Total items in dictionary (corpus words and deletions): %i' % len(dictionary)
    print '  Edit distance for deletions: %i' % MAX_EDIT_DISTANCE
    print 'Total unique words at the start of a sentence: %i' \
        % len(start_prob)
    print 'Total unique word transitions: %i' % len(transition_prob)
    
    return dictionary, start_prob, default_start_prob, transition_prob, default_transition_prob

In [5]:
%%time
dictionary, start_prob, default_start_prob, transition_prob, default_transition_prob = \
    parallel_create_dictionary('testdata/big.txt')

Creating dictionary...
Total words processed: 1105285
Total unique words in corpus: 29157
Total items in dictionary (corpus words and deletions): 2151998
  Edit distance for deletions: 3
Total unique words at the start of a sentence: 15297
Total unique word transitions: 27224
CPU times: user 11.7 s, sys: 1.12 s, total: 12.8 s
Wall time: 54 s


In [64]:
def dameraulevenshtein(seq1, seq2):
    '''
    Calculate the Damerau-Levenshtein distance between sequences.
    Same code as word-level checking.
    '''
    
    # codesnippet:D0DE4716-B6E6-4161-9219-2903BF8F547F
    # Conceptually, this is based on a len(seq1) + 1 * len(seq2) + 1
    # matrix. However, only the current and two previous rows are
    # needed at once, so we only store those.
    
    oneago = None
    thisrow = range(1, len(seq2) + 1) + [0]
    
    for x in xrange(len(seq1)):
        
        # Python lists wrap around for negative indices, so put the
        # leftmost column at the *end* of the list. This matches with
        # the zero-indexed strings and saves extra calculation.
        twoago, oneago, thisrow = \
            oneago, thisrow, [0] * len(seq2) + [x + 1]
        
        for y in xrange(len(seq2)):
            delcost = oneago[y] + 1
            addcost = thisrow[y - 1] + 1
            subcost = oneago[y - 1] + (seq1[x] != seq2[y])
            thisrow[y] = min(delcost, addcost, subcost)
            # This block deals with transpositions
            if (x > 0 and y > 0 and seq1[x] == seq2[y - 1]
                and seq1[x-1] == seq2[y] and seq1[x] != seq2[y]):
                thisrow[y] = min(thisrow[y], twoago[y - 2] + 1)
                
    return thisrow[len(seq2) - 1]

def get_suggestions(string, dictionary, longest_word_length=20, 
                    min_count=100, max_sug=10):
    '''
    Return list of suggested corrections for potentially incorrectly
    spelled word.
    Code based on get_suggestions function from word-level checking,
    with the addition of the min_count parameter, which only
    considers words that have occur more than min_count times in the
    (dictionary) corpus.
    '''
    
    if (len(string) - longest_word_length) > MAX_EDIT_DISTANCE:
        # to ensure Viterbi can keep running -- use the word itself
        return [(string, 0)]
    
    suggest_dict = {}
    
    queue = [string]
    q_dictionary = {}  # items other than string that we've checked
    
    while len(queue)>0:
        q_item = queue[0]  # pop
        queue = queue[1:]
        
        # process queue item
        if (q_item in dictionary) and (q_item not in suggest_dict):
            if (dictionary[q_item][1]>0):
            # word is in dictionary, and is a word from the corpus,
            # and not already in suggestion list so add to suggestion
            # dictionary, indexed by the word with value (frequency
            # in corpus, edit distance)
            # note: q_items that are not the input string are shorter
            # than input string since only deletes are added (unless
            # manual dictionary corrections are added)
                assert len(string)>=len(q_item)
                suggest_dict[q_item] = \
                    (dictionary[q_item][1], len(string) - len(q_item))
            
            # the suggested corrections for q_item as stored in
            # dictionary (whether or not q_item itself is a valid
            # word or merely a delete) can be valid corrections
            for sc_item in dictionary[q_item][0]:
                if (sc_item not in suggest_dict):
                    
                    # compute edit distance
                    # suggested items should always be longer (unless
                    # manual corrections are added)
                    assert len(sc_item)>len(q_item)
                    # q_items that are not input should be shorter
                    # than original string 
                    # (unless manual corrections added)
                    assert len(q_item)<=len(string)
                    if len(q_item)==len(string):
                        assert q_item==string
                        item_dist = len(sc_item) - len(q_item)

                    # item in suggestions list should not be the same
                    # as the string itself
                    assert sc_item!=string           
                    # calculate edit distance using Damerau-
                    # Levenshtein distance
                    item_dist = dameraulevenshtein(sc_item, string)
                    
                    if item_dist<=MAX_EDIT_DISTANCE:
                        # should already be in dictionary if in
                        # suggestion list
                        assert sc_item in dictionary  
                        # trim list to contain state space
                        if (dictionary[q_item][1]>0): 
                            suggest_dict[sc_item] = \
                                (dictionary[sc_item][1], item_dist)
        
        # now generate deletes (e.g. a substring of string or of a
        # delete) from the queue item as additional items to check
        # -- add to end of queue
        assert len(string)>=len(q_item)
        if (len(string)-len(q_item))<MAX_EDIT_DISTANCE \
            and len(q_item)>1:
            for c in range(len(q_item)): # character index        
                word_minus_c = q_item[:c] + q_item[c+1:]
                if word_minus_c not in q_dictionary:
                    queue.append(word_minus_c)
                    # arbitrary value to identify we checked this
                    q_dictionary[word_minus_c] = None

    # return list of suggestions: (correction, edit distance)
    
    # only include words that have appeared a minimum number of times
    # make sure that we do not lose the original word
    as_list = [i for i in suggest_dict.items() 
               if (i[1][0]>min_count or i[0]==string)]
    
    # only include the most likely suggestions (based on frequency
    # and edit distance from original word)
    trunc_as_list = sorted(as_list, 
            key = lambda (term, (freq, dist)): (dist, -freq))[:max_sug]
    
    if len(trunc_as_list)==0:
        # to ensure Viterbi can keep running
        # -- use the word itself if no corrections are found
        return [(string, 0)]
        
    else:
        # drop the word frequency - not needed beyond this point
        return [(i[0], i[1][1]) for i in trunc_as_list]

    '''
    Output format:
    get_suggestions('file', dictionary)
    [('file', 0), ('five', 1), ('fire', 1), ('fine', 1), ('will', 2),
    ('time', 2), ('face', 2), ('like', 2), ('life', 2), ('while', 2)]
    '''
    
def get_emission_prob(edit_dist, poisson_lambda=0.01):
    '''
    The emission probability, i.e. P(observed word|intended word)
    is approximated by a Poisson(k, l) distribution, where 
    k=edit distance and l=0.01.
    
    The lambda parameter matches the one used in the AM207
    lecture notes. Various parameters between 0 and 1 were tested
    to confirm that 0.01 yields the most accurate results.
    '''
    
    return math.log(poisson.pmf(edit_dist, poisson_lambda))

######################
# Multiple helper functions are used to avoid KeyErrors when
# attempting to access values that are not present in dictionaries,
# in which case the previously specified default value is returned.
######################

def get_start_prob(word, start_prob, default_start_prob):
    try:
        return start_prob[word]
    except KeyError:
        return default_start_prob
    
def get_transition_prob(cur_word, prev_word, transition_prob, default_transition_prob):
    try:
        return transition_prob[prev_word][cur_word]
    except KeyError:
        return default_transition_prob

def get_belief(prev_word, prev_belief):
    try:
        return prev_belief[prev_word]
    except KeyError:
        return math.log(math.exp(min(prev_belief.values()))/2.)  


In [65]:
fname = "testdata/test1.txt"

# broadcast Python dictionaries to workers
bc_dictionary = sc.broadcast(dictionary)
bc_start_prob = sc.broadcast(start_prob)
bc_transition_prob = sc.broadcast(transition_prob)

# convert all text to lowercase and drop empty lines
make_all_lower = sc.textFile(fname) \
    .map(lambda line: line.lower()) \
    .filter(lambda x: x!='')

regex = re.compile('[^a-z ]')

# split into sentences -> remove special characters -> convert into list of words
split_sentence = make_all_lower.flatMap(lambda line: line.split('.')) \
        .map(lambda sentence: regex.sub(' ', sentence)) \
        .map(lambda sentence: sentence.split()).cache()

# use accumulator to count the number of words checked
accum_total_words = sc.accumulator(0)
split_words = split_sentence.flatMap(lambda x: x).foreach(lambda x: accum_total_words.add(1))

# assign each sentence a unique id
sentence_id = split_sentence.zipWithIndex().map(lambda (k, v): (v, k)).cache()

sentence_id.collect()

[(0, [u'this', u'is', u'ax', u'test']), (1, [u'her', u'tee', u'set'])]

In [66]:
def get_sentence_word_id(words):
    return [(i, w) for i, w in enumerate(words)]

In [67]:
def split_sentence_words(sentence):
    sent_id, words = sentence
    return [[sent_id, w] for w in words]

In [68]:
# number each word in a sentence, and split into individual words
sentence_word_id = sentence_id.mapValues(lambda v: get_sentence_word_id(v)).flatMap(lambda x: split_sentence_words(x))
sentence_word_id.collect()

[[0, (0, u'this')],
 [0, (1, u'is')],
 [0, (2, u'ax')],
 [0, (3, u'test')],
 [1, (0, u'her')],
 [1, (1, u'tee')],
 [1, (2, u'set')]]

In [69]:
# get suggestions for each word
sentence_word_suggestions = sentence_word_id.mapValues(lambda v: 
                                                       (v[0], v[1], get_suggestions(v[1], bc_dictionary.value))).cache()
# sentence_word_suggestions.filter(lambda x: x[0]==0).collect()

In [70]:
# filter for the first words in sentences
sentence_word_1 = sentence_word_suggestions.filter(lambda (k, v): v[0]==0).mapValues(lambda v: (v[1], v[2]))
sentence_word_1.collect()

[(0,
  (u'this',
   [(u'this', 0),
    (u'his', 1),
    (u'thus', 1),
    (u'thin', 1),
    (u'the', 2),
    (u'that', 2),
    (u'is', 2),
    (u'him', 2),
    (u'they', 2),
    (u'their', 2)])),
 (1,
  (u'her',
   [(u'her', 0),
    (u'he', 1),
    (u'here', 1),
    (u'hear', 1),
    (u'the', 2),
    (u'his', 2),
    (u'had', 2),
    (u'for', 2),
    (u'be', 2),
    (u'or', 2)]))]

In [71]:
def start_word_prob(words, tmp_sp, d_sp):
    orig_word, sug_words = words
    probs = [(w[0], 
              math.exp(get_start_prob(w[0], tmp_sp, d_sp) + get_emission_prob(w[1]))
             ) 
             for w in sug_words]
    sum_probs = sum([p[1] for p in probs])
    probs = [([p[0]], math.log(p[1]/sum_probs)) for p in probs]
    return probs

In [72]:
# calculate probability for each suggestion
# format: (sentence id, [path-probability pairs])
sentence_path = sentence_word_1.mapValues(lambda v: start_word_prob(v, bc_start_prob.value, default_start_prob))
sentence_path.collect()

[(0,
  [([u'this'], -0.010416136127922377),
   ([u'his'], -4.751501506095812),
   ([u'thus'], -7.051702807734582),
   ([u'thin'], -8.997612956789895),
   ([u'the'], -7.479924902965362),
   ([u'that'], -9.697449125500638),
   ([u'is'], -10.680518021105867),
   ([u'him'], -11.02509475953902),
   ([u'they'], -9.835785909400098),
   ([u'their'], -11.093861616786112)]),
 (1,
  [([u'her'], -0.07316779348212661),
   ([u'he'], -2.72717847452876),
   ([u'here'], -5.705854346454598),
   ([u'hear'], -8.539067690510812),
   ([u'the'], -6.615914528578116),
   ([u'his'], -9.185808498256602),
   ([u'had'], -9.857703403156888),
   ([u'for'], -9.548639412388193),
   ([u'be'], -10.411495062806324),
   ([u'or'], -10.27847192940494)])]

In [73]:
###LOOP STARTS HERE###

In [74]:
word_num = 1
word_num

1

In [75]:
# filter for the next words in sentences
sentence_word_next = sentence_word_suggestions.filter(lambda (k, v): v[0]==word_num).mapValues(lambda v: (v[1], v[2]))
sentence_word_next.collect()

[(0,
  (u'is',
   [(u'is', 0),
    (u'in', 1),
    (u'it', 1),
    (u'his', 1),
    (u'as', 1),
    (u'i', 1),
    (u's', 1),
    (u'if', 1),
    (u'its', 1),
    (u'us', 1)])),
 (1,
  (u'tee',
   [(u'the', 1),
    (u'see', 1),
    (u'ten', 1),
    (u'tea', 1),
    (u'to', 2),
    (u'he', 2),
    (u'be', 2),
    (u'her', 2),
    (u'were', 2),
    (u'she', 2)]))]

In [76]:
# check that there are more words left
sentence_word_next.isEmpty()

False

In [77]:
def split_suggestions(sentence):
    sent_id, (word, word_sug)  = sentence
    return [[sent_id, (word, w)] for w in word_sug]

In [78]:
# split into suggestions
sentence_word_next_split = sentence_word_next.flatMap(lambda x: split_suggestions(x))
sentence_word_next_split.collect()

[[0, (u'is', (u'is', 0))],
 [0, (u'is', (u'in', 1))],
 [0, (u'is', (u'it', 1))],
 [0, (u'is', (u'his', 1))],
 [0, (u'is', (u'as', 1))],
 [0, (u'is', (u'i', 1))],
 [0, (u'is', (u's', 1))],
 [0, (u'is', (u'if', 1))],
 [0, (u'is', (u'its', 1))],
 [0, (u'is', (u'us', 1))],
 [1, (u'tee', (u'the', 1))],
 [1, (u'tee', (u'see', 1))],
 [1, (u'tee', (u'ten', 1))],
 [1, (u'tee', (u'tea', 1))],
 [1, (u'tee', (u'to', 2))],
 [1, (u'tee', (u'he', 2))],
 [1, (u'tee', (u'be', 2))],
 [1, (u'tee', (u'her', 2))],
 [1, (u'tee', (u'were', 2))],
 [1, (u'tee', (u'she', 2))]]

In [79]:
# join on previous path
# format: (sentence id, ((current word, (current word suggestion, edit distance)), 
#         [(previous path-probability pairs)]))
sentence_word_next_path = sentence_word_next_split.join(sentence_path)
# sentence_word_next_path.filter(lambda x: x[0]==0).collect()

In [80]:
def subs_word_prob(words, tmp_tp, d_tp):
    
    # unpack values
    sent_id = words[0]
    cur_word = words[1][0][0]
    cur_sug = words[1][0][1][0]
    cur_sug_ed = words[1][0][1][1]
    prev_sug = words[1][1]
    
    # belief + transition probability + emission probability
    (prob, word) = max((p[1]
                 + get_transition_prob(cur_sug, p[0][-1], tmp_tp, d_tp)
                 + get_emission_prob(cur_sug_ed), p[0])
                     for p in prev_sug)
    
    return sent_id, (word + [cur_sug], math.exp(prob))

In [81]:
# calculate path with max probability
sentence_word_next_path_prob = sentence_word_next_path.map(lambda x:
                                        subs_word_prob(x, bc_transition_prob.value, default_transition_prob))
sentence_word_next_path_prob.collect()

[(0, ([u'this', u'is'], 0.07965129629236094)),
 (0, ([u'this', u'in'], 4.52866901996701e-05)),
 (0, ([u'this', u'it'], 1.0655691811687093e-05)),
 (0, ([u'this', u'his'], 7.991768858765314e-06)),
 (0, ([u'this', u'as'], 4.7950613152591945e-05)),
 (0, ([u'this', u'i'], 3.196707543506126e-05)),
 (0, ([u'that', u's'], 1.5724350299737122e-08)),
 (0, ([u'this', u'if'], 2.663922952921773e-06)),
 (0, ([u'this', u'its'], 1.0119590152617674e-08)),
 (0, ([u'this', u'us'], 1.0119590152617674e-08)),
 (1, ([u'her', u'the'], 6.210551616996943e-05)),
 (1, ([u'her', u'see'], 4.140367744664627e-06)),
 (1, ([u'her', u'ten'], 9.504083106300208e-09)),
 (1, ([u'her', u'tea'], 2.070183872332317e-06)),
 (1, ([u'her', u'to'], 1.12825021042111e-06)),
 (1, ([u'her', u'he'], 2.4842206467987784e-07)),
 (1, ([u'he', u'be'], 2.556031225104305e-09)),
 (1, ([u'her', u'her'], 2.0701838723323124e-08)),
 (1, ([u'her', u'were'], 1.0350919361661578e-08)),
 (1, ([u'her', u'she'], 3.8298401638147756e-07))]

In [82]:
def normalize(probs):
    sum_probs = sum([p[1] for p in probs])
    return [(p[0], math.log(p[1]/sum_probs)) for p in probs]

In [83]:
# normalize for numerical stability
sentence_path = sentence_word_next_path_prob.groupByKey().mapValues(lambda v: normalize(v))
sentence_path.collect()

[(0,
  [([u'this', u'is'], -0.001838225822312142),
   ([u'this', u'in'], -7.474238641144874),
   ([u'this', u'it'], -8.921157624081198),
   ([u'this', u'his'], -9.20883969653298),
   ([u'this', u'as'], -7.417080227304924),
   ([u'this', u'i'], -7.822545335413089),
   ([u'that', u's'], -15.439796609526178),
   ([u'this', u'if'], -10.307451985201089),
   ([u'this', u'its'], -15.880533930581192),
   ([u'this', u'us'], -15.880533930581192)]),
 (1,
  [([u'her', u'the'], -0.12135666115371971),
   ([u'her', u'see'], -2.8294068622559303),
   ([u'her', u'ten'], -8.906225245412237),
   ([u'her', u'tea'], -3.522554042815874),
   ([u'her', u'to'], -4.1295235271347694),
   ([u'her', u'he'], -5.642817579015966),
   ([u'he', u'be'], -10.21949099950891),
   ([u'her', u'her'], -8.127724228803967),
   ([u'her', u'were'], -8.82087140936391),
   ([u'her', u'she'], -5.209953496719689)])]

In [84]:
word_num += 1
word_num

2

In [85]:
# filter for the next words in sentences
sentence_word_next = sentence_word_suggestions.filter(lambda (k, v): v[0]==word_num).mapValues(lambda v: (v[1], v[2]))
sentence_word_next.collect()

[(0,
  (u'ax',
   [(u'ax', 0),
    (u'a', 1),
    (u'as', 1),
    (u'at', 1),
    (u'an', 1),
    (u'am', 1),
    (u'ah', 1),
    (u'x', 1),
    (u'and', 2),
    (u'was', 2)])),
 (1,
  (u'set',
   [(u'set', 0),
    (u'see', 1),
    (u'met', 1),
    (u'let', 1),
    (u'yet', 1),
    (u'get', 1),
    (u'sat', 1),
    (u'sent', 1),
    (u'seat', 1),
    (u'st', 1)]))]

In [86]:
# check that there are more words left
sentence_word_next.isEmpty()

False

In [87]:
# split into suggestions
sentence_word_next_split = sentence_word_next.flatMap(lambda x: split_suggestions(x))
sentence_word_next_split.collect()

[[0, (u'ax', (u'ax', 0))],
 [0, (u'ax', (u'a', 1))],
 [0, (u'ax', (u'as', 1))],
 [0, (u'ax', (u'at', 1))],
 [0, (u'ax', (u'an', 1))],
 [0, (u'ax', (u'am', 1))],
 [0, (u'ax', (u'ah', 1))],
 [0, (u'ax', (u'x', 1))],
 [0, (u'ax', (u'and', 2))],
 [0, (u'ax', (u'was', 2))],
 [1, (u'set', (u'set', 0))],
 [1, (u'set', (u'see', 1))],
 [1, (u'set', (u'met', 1))],
 [1, (u'set', (u'let', 1))],
 [1, (u'set', (u'yet', 1))],
 [1, (u'set', (u'get', 1))],
 [1, (u'set', (u'sat', 1))],
 [1, (u'set', (u'sent', 1))],
 [1, (u'set', (u'seat', 1))],
 [1, (u'set', (u'st', 1))]]

In [88]:
# join on previous path
# format: (sentence id, ((current word, (current word suggestion, edit distance)), 
#         [(previous path-probability pairs)]))
sentence_word_next_path = sentence_word_next_split.join(sentence_path)
sentence_word_next_path.collect()

[(0,
  ((u'ax', (u'ax', 0)),
   [([u'this', u'is'], -0.001838225822312142),
    ([u'this', u'in'], -7.474238641144874),
    ([u'this', u'it'], -8.921157624081198),
    ([u'this', u'his'], -9.20883969653298),
    ([u'this', u'as'], -7.417080227304924),
    ([u'this', u'i'], -7.822545335413089),
    ([u'that', u's'], -15.439796609526178),
    ([u'this', u'if'], -10.307451985201089),
    ([u'this', u'its'], -15.880533930581192),
    ([u'this', u'us'], -15.880533930581192)])),
 (0,
  ((u'ax', (u'a', 1)),
   [([u'this', u'is'], -0.001838225822312142),
    ([u'this', u'in'], -7.474238641144874),
    ([u'this', u'it'], -8.921157624081198),
    ([u'this', u'his'], -9.20883969653298),
    ([u'this', u'as'], -7.417080227304924),
    ([u'this', u'i'], -7.822545335413089),
    ([u'that', u's'], -15.439796609526178),
    ([u'this', u'if'], -10.307451985201089),
    ([u'this', u'its'], -15.880533930581192),
    ([u'this', u'us'], -15.880533930581192)])),
 (0,
  ((u'ax', (u'as', 1)),
   [([u'this', u

In [89]:
# calculate path with max probability
sentence_word_next_path_prob = sentence_word_next_path.map(lambda x:
                                        subs_word_prob(x, bc_transition_prob.value, default_transition_prob))
sentence_word_next_path_prob.collect() #filter(lambda x: x[0]==0).collect()

[(0, ([u'this', u'is', u'ax'], 1.0206768458569286e-06)),
 (0, ([u'this', u'is', u'a'], 0.0006417656065436952)),
 (0, ([u'this', u'is', u'as'], 4.482519568703837e-05)),
 (0, ([u'this', u'is', u'at'], 4.263860077547551e-05)),
 (0, ([u'this', u'is', u'an'], 0.00012026272013595663)),
 (0, ([u'this', u'i', u'am'], 3.3681971618496976e-07)),
 (0, ([u'this', u'is', u'ah'], 1.0932974557814245e-06)),
 (0, ([u'this', u'is', u'x'], 1.0206768458569282e-08)),
 (0, ([u'this', u'is', u'and'], 8.199730918360683e-08)),
 (0, ([u'this', u'is', u'was'], 1.093297455781426e-08)),
 (1, ([u'her', u'to', u'set'], 1.674747778516495e-05)),
 (1, ([u'her', u'to', u'see'], 2.0216598183520555e-06)),
 (1, ([u'her', u'she', u'met'], 1.4880518962107781e-07)),
 (1, ([u'her', u'to', u'let'], 3.708370081000812e-07)),
 (1, ([u'her', u'the', u'yet'], 9.056951996826441e-09)),
 (1, ([u'her', u'to', u'get'], 1.1304547504986358e-06)),
 (1, ([u'her', u'she', u'sat'], 4.315350499011254e-07)),
 (1, ([u'her', u'he', u'sent'], 4.9229

In [90]:
# normalize for numerical stability
sentence_path = sentence_word_next_path_prob.groupByKey().mapValues(lambda v: normalize(v))
sentence_path.collect()

[(0,
  [([u'this', u'is', u'ax'], -6.727174598010501),
   ([u'this', u'is', u'a'], -0.2834174414245401),
   ([u'this', u'is', u'as'], -2.944870194548328),
   ([u'this', u'is', u'at'], -2.9948806151229896),
   ([u'this', u'is', u'an'], -1.9579618954602194),
   ([u'this', u'i', u'am'], -7.8358480381938485),
   ([u'this', u'is', u'ah'], -6.658442261252635),
   ([u'this', u'is', u'x'], -11.332344783998593),
   ([u'this', u'is', u'and'], -9.248709426698461),
   ([u'this', u'is', u'was'], -11.263612447240725)]),
 (1,
  [([u'her', u'to', u'set'], -0.7073236573108231),
   ([u'her', u'to', u'see'], -2.8216524579910995),
   ([u'her', u'she', u'met'], -5.43068860495902),
   ([u'her', u'to', u'let'], -4.517563968429027),
   ([u'her', u'the', u'yet'], -8.229793963975347),
   ([u'her', u'to', u'get'], -3.4029513384144745),
   ([u'her', u'she', u'sat'], -4.365977867966592),
   ([u'her', u'he', u'sent'], -6.536828443436029),
   ([u'her', u'the', u'seat'], -1.0631219193848052),
   ([u'her', u'the', u's

In [91]:
word_num += 1
word_num

3

In [92]:
sentence_word_next = sentence_word_suggestions.filter(lambda (k, v): v[0]==word_num) \
                .mapValues(lambda v: (v[1], v[2]))
sentence_word_next.collect()

[(0,
  (u'test',
   [(u'test', 0),
    (u'west', 1),
    (u'best', 1),
    (u'rest', 1),
    (u'that', 2),
    (u'these', 2),
    (u'went', 2),
    (u'must', 2),
    (u'most', 2),
    (u'left', 2)]))]

In [93]:
sentence_word_next_split = sentence_word_next.flatMap(lambda x: split_suggestions(x))
sentence_word_next_split.collect()

[[0, (u'test', (u'test', 0))],
 [0, (u'test', (u'west', 1))],
 [0, (u'test', (u'best', 1))],
 [0, (u'test', (u'rest', 1))],
 [0, (u'test', (u'that', 2))],
 [0, (u'test', (u'these', 2))],
 [0, (u'test', (u'went', 2))],
 [0, (u'test', (u'must', 2))],
 [0, (u'test', (u'most', 2))],
 [0, (u'test', (u'left', 2))]]

In [94]:
sentence_word_next_path = sentence_word_next_split.join(sentence_path)
sentence_word_next_path.collect()

[(0,
  ((u'test', (u'test', 0)),
   [([u'this', u'is', u'ax'], -6.727174598010501),
    ([u'this', u'is', u'a'], -0.2834174414245401),
    ([u'this', u'is', u'as'], -2.944870194548328),
    ([u'this', u'is', u'at'], -2.9948806151229896),
    ([u'this', u'is', u'an'], -1.9579618954602194),
    ([u'this', u'i', u'am'], -7.8358480381938485),
    ([u'this', u'is', u'ah'], -6.658442261252635),
    ([u'this', u'is', u'x'], -11.332344783998593),
    ([u'this', u'is', u'and'], -9.248709426698461),
    ([u'this', u'is', u'was'], -11.263612447240725)])),
 (0,
  ((u'test', (u'west', 1)),
   [([u'this', u'is', u'ax'], -6.727174598010501),
    ([u'this', u'is', u'a'], -0.2834174414245401),
    ([u'this', u'is', u'as'], -2.944870194548328),
    ([u'this', u'is', u'at'], -2.9948806151229896),
    ([u'this', u'is', u'an'], -1.9579618954602194),
    ([u'this', u'i', u'am'], -7.8358480381938485),
    ([u'this', u'is', u'ah'], -6.658442261252635),
    ([u'this', u'is', u'x'], -11.332344783998593),
    ([

In [95]:
sentence_word_next_path_prob = sentence_word_next_path.map(lambda x:
                                                subs_word_prob(x, bc_transition_prob.value, default_transition_prob))
sentence_word_next_path_prob.collect()

[(0, ([u'this', u'is', u'a', u'test'], 0.00015533214292411503)),
 (0, ([u'this', u'is', u'a', u'west'], 7.701937025599192e-09)),
 (0, ([u'this', u'is', u'at', u'best'], 4.6791753999290885e-07)),
 (0, ([u'this', u'is', u'a', u'rest'], 1.1649910719308636e-06)),
 (0, ([u'this', u'is', u'at', u'that'], 7.252721869890078e-08)),
 (0, ([u'this', u'is', u'at', u'these'], 5.848969249911342e-09)),
 (0, ([u'this', u'is', u'a', u'went'], 3.850968512799608e-11)),
 (0, ([u'this', u'is', u'a', u'must'], 3.850968512799608e-11)),
 (0, ([u'this', u'is', u'a', u'most'], 8.154937503516043e-08)),
 (0, ([u'this', u'is', u'a', u'left'], 3.883303573102879e-09))]

In [96]:
sentence_path = sentence_word_next_path_prob.groupByKey().mapValues(lambda v: normalize(v))
sentence_path.collect()

[(0,
  [([u'this', u'is', u'a', u'test'], -0.011550059383306386),
   ([u'this', u'is', u'a', u'west'], -9.923399161225545),
   ([u'this', u'is', u'at', u'best'], -5.816578936418101),
   ([u'this', u'is', u'a', u'rest'], -4.904402317823178),
   ([u'this', u'is', u'at', u'that'], -7.680909098480993),
   ([u'this', u'is', u'at', u'these'], -10.198605571091987),
   ([u'this', u'is', u'a', u'went'], -15.221716527773578),
   ([u'this', u'is', u'a', u'must'], -15.221716527773578),
   ([u'this', u'is', u'a', u'most'], -7.563662354755956),
   ([u'this', u'is', u'a', u'left'], -10.60818479247938)])]

In [97]:
word_num += 1
word_num

4

In [98]:
sentence_word_next = sentence_word_suggestions.filter(lambda (k, v): v[0]==word_num) \
                .mapValues(lambda v: (v[1], v[2]))
sentence_word_next.collect()

[]

In [99]:
def get_max_path(final_paths):
    max_path = max((p[1], p[0]) for p in final_paths)
    return max_path[1]

In [100]:
sentence_suggestion = sentence_path.mapValues(lambda v: get_max_path(v))
sentence_suggestion.collect()

[(0, [u'this', u'is', u'a', u'test'])]

In [101]:
sentence_max_prob = sentence_id.join(sentence_suggestion)
sentence_max_prob.collect()

[(0, ([u'this', u'is', u'ax', u'test'], [u'this', u'is', u'a', u'test']))]