<div class="alert alert-danger">
  <strong>To run this program, restart notebook, and start executing the cells of this section starting here.</strong> <br><p>
  This version parallelizes the word check across all the words in a document, using word-level correction. Since Spark does not permit RDD manipulation from within an RDD transformation (i.e. no parallelism within a parallel task), we converted the `get_suggestions` function, which acts on an individual word, to a serial method. This allows us to parallelize across the words of a document instead. <i>This is a reasonable trade-off when the number of words in a document is much larger than the number of suggestions likely to be found for any given word.</i> <br><p>
  Also note the (modified) `no_RDD_get_suggestions` function still returns an entire list of all possible suggestions to the calling function (e.g. for context checking), even if only the top match is used or required. Future improvements may be made to `no_RDD_get_suggestions` to terminate early once a "top" match (e.g. minimum edit distance) is found; a speedup in that function will in turn lead to a performance improvement of the document checking function as well.
</div>

In [2]:
'''
v 4.0 last revised 27 Nov 2015

This program is a Spark (PySpark) version of a spellchecker based on SymSpell,
a Symmetric Delete spelling correction algorithm developed by Wolf Garbe
and originally written in C#.

This cell defines the document-level checker (correct_document), which
looks up every word of a document in parallel using the serial
single-word checker no_RDD_get_suggestions (defined elsewhere).
'''
import re

n_partitions = 6  # number of partitions to be used
# NOTE(review): not referenced by the functions in this cell; presumably read by
# the dictionary-building / suggestion helpers defined elsewhere - confirm
max_edit_distance = 3

# helper functions

    
def copartitioned(RDD1, RDD2):
    """Return True when both RDDs report the same partitioner (i.e. they are co-partitioned)."""
    first_partitioner, second_partitioner = RDD1.partitioner, RDD2.partitioner
    return first_partitioner == second_partitioner

def combine_joined_lists(tup):
    '''Merge the two sides of a joined value into one list without duplicates.

    tup is (a, b) where each side is either a list or None (never both None).
    A None side contributes nothing. The order of the returned elements is
    whatever set iteration yields.
    '''
    left, right = tup[0], tup[1]
    merged = left if right is None else (right if left is None else left + right)
    return list(set(merged))

def dameraulevenshtein(seq1, seq2):
    """Calculate the Damerau-Levenshtein distance (an integer) between sequences.

    Source: http://mwh.geek.nz/2009/04/26/python-damerau-levenshtein-distance/
    The algorithm is unchanged from the original. The only modification is
    that ``xrange`` and list concatenation with ``range(...)`` were replaced
    by constructs that behave the same under Python 2 and Python 3.

    This distance is the number of additions, deletions, substitutions,
    and transpositions needed to transform the first sequence into the
    second. Although generally used with strings, any sequences of
    comparable objects will work.

    Transpositions are exchanges of *consecutive* characters; all other
    operations are self-explanatory. A transposed pair is never edited
    again, so this is the restricted ("optimal string alignment") variant:
    dameraulevenshtein('ca', 'abc') is 3, while the unrestricted
    Damerau-Levenshtein distance would be 2.

    This implementation is O(N*M) time and O(M) space, for N and M the
    lengths of the two sequences.

    >>> dameraulevenshtein('ba', 'abc')
    2
    >>> dameraulevenshtein('fee', 'deed')
    2
    >>> dameraulevenshtein('ca', 'abc')
    3

    It works with arbitrary sequences too:
    >>> dameraulevenshtein('abcd', ['b', 'a', 'c', 'd', 'e'])
    2
    """
    # codesnippet:D0DE4716-B6E6-4161-9219-2903BF8F547F
    # Conceptually, this is based on a (len(seq1) + 1) x (len(seq2) + 1) matrix.
    # However, only the current and two previous rows are needed at once,
    # so we only store those.
    oneago = None
    thisrow = list(range(1, len(seq2) + 1)) + [0]
    for x in range(len(seq1)):
        # Python lists wrap around for negative indices, so put the
        # leftmost column at the *end* of the list. This matches with
        # the zero-indexed strings and saves extra calculation.
        twoago, oneago, thisrow = oneago, thisrow, [0] * len(seq2) + [x + 1]
        for y in range(len(seq2)):
            delcost = oneago[y] + 1
            addcost = thisrow[y - 1] + 1
            subcost = oneago[y - 1] + (seq1[x] != seq2[y])
            thisrow[y] = min(delcost, addcost, subcost)
            # This block deals with transpositions
            if (x > 0 and y > 0 and seq1[x] == seq2[y - 1]
                and seq1[x-1] == seq2[y] and seq1[x] != seq2[y]):
                thisrow[y] = min(thisrow[y], twoago[y - 2] + 1)
    # for an empty seq2 this reads thisrow[-1], the leftmost column (= len(seq1))
    return thisrow[len(seq2) - 1]

    
def correct_document(fname, d, lwl=float('inf'), printlist=True):
    '''Correct an entire document using word-level correction.
    
    Note: Uses a serialized version of an individual word checker
    (no_RDD_get_suggestions, defined elsewhere), applied to every word
    of the document in parallel.
    
    fname: filename
    d: the main dictionary (python dict), which includes deletes
             entries, is in the form of: {word: ([suggested corrections], 
                                                 frequency of word in corpus), ...}
    lwl: optional identifier of longest real word in masterdict
    printlist: identify unknown words and words with error (default is True)

    Prints the unknown words and the words with a suggested correction, each
    keyed by the 0-based index of the line it appears on, followed by totals.
    Returns None.

    The broadcast copy of d and the cached intermediate RDDs are released
    before returning (also when an error occurs), so repeated calls do not
    keep accumulating copies of the dictionary on the executors.
    '''
    # broadcast lookup dictionary to workers
    bd = sc.broadcast(d)
    cached_rdds = []  # RDDs persisted below; unpersisted in the finally clause

    # each checked entry has the form (word, (suggestion list, line index))
    def is_unknown(entry):
        '''No suggestion at all was found for the word.'''
        _, (suggestions, _) = entry
        return len(suggestions) == 0

    def is_error(entry):
        '''The top suggestion differs from the word in the text.'''
        word, (suggestions, _) = entry
        return len(suggestions) > 0 and word != suggestions[0][0]

    def format_unknown(entry):
        word, (_, line_index) = entry
        return (line_index, str(word))

    def format_error(entry):
        word, (suggestions, line_index) = entry
        return (line_index, str(word) + " --> " + str(suggestions[0][0]))

    try:
        print("Finding misspelled words in your document...")

        # http://stackoverflow.com/questions/22520932/python-remove-all-non-alphabet-chars-from-string
        regex = re.compile('[^a-z ]')

        # (lowercased line, 0-based line index)
        make_all_lower = sc.textFile(fname).map(lambda line: line.lower()).zipWithIndex()
        # replace every character other than a-z and space by a space
        replace_nonalphs = make_all_lower.map(
            lambda line_idx: (regex.sub(' ', line_idx[0]), line_idx[1]))
        # one (word, line index) pair per word occurrence
        flattened = replace_nonalphs.flatMap(
            lambda line_idx: [(word, line_idx[1]) for word in line_idx[0].split()])

        # create RDD with (each word in document, corresponding line index)
        # key value pairs and cache it
        all_words = flattened.partitionBy(n_partitions).cache()
        cached_rdds.append(all_words)

        # check all words in parallel -- stores whole list of suggestions for each word
        get_corrections = all_words.map(
            lambda word_idx: (word_idx[0],
                              (no_RDD_get_suggestions(word_idx[0], bd.value, lwl, True),
                               word_idx[1])),
            preservesPartitioning=True).cache()
        cached_rdds.append(get_corrections)

        # UNKNOWN words are words where the suggestion list is empty
        unknown_words = get_corrections.filter(is_unknown)
        if printlist:
            print("    Unknown words (line number, word in text):")
            print(unknown_words.map(format_unknown).sortByKey().collect())

        # ERROR words are words where the word does not match the first tuple's word (top match)
        error_words = get_corrections.filter(is_error)
        if printlist:
            print("    Words with suggested corrections (line number, word in text, top match):")
            print(error_words.map(format_error).sortByKey().collect())

        print("-----")
        print("total words checked: %i" % get_corrections.count())
        print("total unknown words: %i" % unknown_words.count())
        print("total potential errors found: %i" % error_words.count())
    finally:
        # release executor memory held by this call
        for rdd in cached_rdds:
            rdd.unpersist()
        bd.unpersist()

<div class="alert alert-danger">
  <strong>Run the cell below only once to build the dictionary.</strong>
</div>

In [3]:
%%time
# Build the lookup dictionary (corpus words plus their deletions) from the training
# corpus; lwl is the length of the longest word in the corpus (see printed summary).
# NOTE(review): parallel_create_dictionary is not defined in this section - presumably
# in an earlier part of the notebook.
d, lwl = parallel_create_dictionary("testdata/big.txt")

Creating dictionary...
total words processed: 1105285
total unique words in corpus: 29157
total items in dictionary (corpus words and deletions): 2151998
  edit distance for deletions: 3
  length of longest word in corpus: 18
CPU times: user 11.6 s, sys: 1.19 s, total: 12.8 s
Wall time: 47.3 s


<div class="alert alert-success">
  <strong>Enter word to correct below.</strong>
</div>

In [4]:
%%time
# Suggestions for a single word as (word, (corpus frequency, edit distance)) tuples;
# the output below is ordered by edit distance, then by decreasing frequency.
no_RDD_get_suggestions("there", d, lwl)

looking up suggestions based on input word...
number of possible corrections: 604
  edit distance for deletions: 3
CPU times: user 60.2 ms, sys: 3.25 ms, total: 63.5 ms
Wall time: 61.7 ms


[('there', (2972, 0)),
 ('these', (1231, 1)),
 ('where', (977, 1)),
 ('here', (691, 1)),
 ('three', (584, 1)),
 ('thee', (26, 1)),
 ('chere', (9, 1)),
 ('theme', (8, 1)),
 ('the', (80030, 2)),
 ('her', (5284, 2)),
 ('were', (4289, 2)),
 ('they', (3938, 2)),
 ('their', (2955, 2)),
 ('them', (2241, 2)),
 ('then', (1558, 2)),
 ('other', (1502, 2)),
 ('those', (1201, 2)),
 ('others', (410, 2)),
 ('third', (239, 2)),
 ('term', (133, 2)),
 ('threw', (96, 2)),
 ('mere', (79, 2)),
 ('theory', (79, 2)),
 ('share', (69, 2)),
 ('hero', (55, 2)),
 ('tree', (42, 2)),
 ('hare', (36, 2)),
 (u'thereby', (32, 2)),
 ('sphere', (31, 2)),
 ('hers', (30, 2)),
 (u'thereof', (26, 2)),
 ('cher', (25, 2)),
 ('tore', (18, 2)),
 ('herd', (15, 2)),
 ('theirs', (14, 2)),
 ('thiers', (13, 2)),
 ('shore', (11, 2)),
 ('thence', (10, 2)),
 ('tete', (9, 2)),
 ('ether', (8, 2)),
 ('adhere', (8, 2)),
 ('sheer', (8, 2)),
 ('tver', (7, 2)),
 (u'therein', (6, 2)),
 ('herb', (5, 2)),
 ('cheer', (5, 2)),
 ('hire', (5, 2)),
 (

In [5]:
%%time
# A word with no match in the dictionary yields an empty list (see output below).
no_RDD_get_suggestions("zzffttt", d, lwl)

looking up suggestions based on input word...
number of possible corrections: 0
  edit distance for deletions: 3
CPU times: user 272 µs, sys: 97 µs, total: 369 µs
Wall time: 290 µs


[]

<div class="alert alert-success">
  <strong>Enter file name of document to correct below.</strong>
</div>

In [6]:
%%time
# Check every word of the document in parallel; printed rows are keyed by 0-based line index.
correct_document("testdata/OCRsample.txt", d, lwl)

Finding misspelled words in your document...
    Unknown words (line number, word in text):
[(11, 'oonipiittee'), (42, 'senbrnrgs'), (82, 'ghmhvestigat')]
    Words with suggested corrections (line number, word in text, top match):
[(3, 'taiths --> faith'), (13, 'gjpt --> get'), (13, 'tj --> to'), (13, 'mnnff --> snuff'), (15, 'bh --> by'), (15, 'uth --> th'), (15, 'unuer --> under'), (15, 'snc --> sac'), (20, 'mthiitt --> thirty'), (21, 'cas --> was'), (22, 'pythian --> scythian'), (26, 'brainin --> brain'), (27, 'jfl --> of'), (28, 'eug --> dug'), (28, 'stice --> stick'), (28, 'blaci --> black'), (28, 'ji --> i'), (28, 'debbs --> debts'), (29, 'nericans --> americans'), (30, 'ergs --> eggs'), (30, 'ainin --> again'), (31, 'trumped --> trumpet'), (32, 'erican --> american'), (33, 'thg --> the'), (33, 'nenance --> penance'), (33, 'unorthodox --> orthodox'), (34, 'rgs --> rags'), (34, 'sln --> son'), (38, 'eu --> e'), (38, 'williaij --> william'), (40, 'fcsf --> ff'), (40, 'ber --> be')

***

<div class="alert alert-info">
  <strong>START RUNNING CODE HERE</strong>
</div>

In [1]:
import re
import math
from scipy.stats import poisson

In [2]:
# findspark locates the local Spark installation and makes `pyspark` importable,
# so it is initialized before `import pyspark`
import findspark
import os  # NOTE(review): not used in the cells shown here
findspark.init()
import pyspark
# one SparkContext shared by all the cells below
sc = pyspark.SparkContext()
# only show Spark log messages at ERROR level or above
sc.setLogLevel('ERROR')

***
# Pre-processing

In [3]:
n_partitions = 6  # number of partitions to be used
# maximum number of characters deleted from each corpus word when building the
# dictionary, and the largest edit distance accepted for a suggestion
MAX_EDIT_DISTANCE = 3

In [4]:
def get_n_deletes_list(w, n):
    '''Given a word, derive the list of strings with up to n characters deleted.

    w: the word to derive deletes from
    n: maximum number of characters to delete

    Returns the unique deletes in breadth-first order: all one-character
    deletes first, then two-character deletes, and so on. Strings of
    length 1 are not reduced further, so the empty string is never produced.
    '''
    # since this list is generally of the same magnitude as the number of
    # characters in a word, it may not make sense to parallelize this
    # so we use python to create the list
    deletes = []
    seen_deletes = set()  # O(1) membership mirror of `deletes`
    queue = [w]
    for _ in range(n):
        next_queue = []
        queued = set()  # O(1) membership mirror of `next_queue`
        for word in queue:
            if len(word) <= 1:
                continue
            for c in range(len(word)):  # character index
                word_minus_c = word[:c] + word[c+1:]
                if word_minus_c not in seen_deletes:
                    seen_deletes.add(word_minus_c)
                    deletes.append(word_minus_c)
                if word_minus_c not in queued:
                    queued.add(word_minus_c)
                    next_queue.append(word_minus_c)
        queue = next_queue

    return deletes

In [5]:
############
#
# load file & initial processing
#
############

In [6]:
# training corpus: source of the dictionary and of the start/transition probabilities
fname = "testdata/big.txt"

In [7]:
# matches every character that is not a lowercase letter or a space
# (applied after lowercasing)
regex = re.compile('[^a-z ]')

In [8]:
# lowercased lines of the corpus
make_all_lower = sc.textFile(fname).map(lambda line: line.lower())

In [9]:
# print make_all_lower
# print make_all_lower.getNumPartitions()
# print make_all_lower.count()
# print make_all_lower.take(5)

In [10]:
# sentences (split on '.') as lists of words; non a-z characters act as separators.
# Pieces with no letters (e.g. text after a trailing '.') become empty lists.
split_sentence = make_all_lower.flatMap(lambda line: line.split('.')).map(lambda sentence: regex.sub(' ', sentence)) \
            .map(lambda sentence: sentence.split())

In [11]:
# print split_sentence
# print split_sentence.getNumPartitions()
# print split_sentence.count()
# print split_sentence.take(5)

In [12]:
############
#
# generate start probabilities
#
############

In [13]:
# first word of every non-empty sentence
start_words = split_sentence.map(lambda sentence: sentence[0] if len(sentence)>0 else None) \
    .filter(lambda word: word!=None)

In [14]:
# print start_words
# print start_words.getNumPartitions()
# print start_words.count()
# print start_words.take(5)

In [15]:
# count all start words with an accumulator (same value as count_start_words_once.count());
# foreach returns None, so count_total_start_words is always None.
# float() so the divisions by total_start_words below are true divisions under Python 2.
accum_total_start_words = sc.accumulator(0)
count_start_words_once = start_words.map(lambda word: (word, 1))
count_total_start_words = count_start_words_once.foreach(lambda x: accum_total_start_words.add(1))
total_start_words = float(accum_total_start_words.value)

In [16]:
# print count_start_words_once
# print count_start_words_once.getNumPartitions()
# print count_start_words_once.count()
# print count_start_words_once.take(5)

print 'Total start words:', total_start_words

Total start words: 137073.0


In [17]:
# (start word, number of sentences it starts)
unique_start_words = count_start_words_once.reduceByKey(lambda a, b: a + b, numPartitions = n_partitions)

In [18]:
# print unique_start_words
# print unique_start_words.getNumPartitions()
# print unique_start_words.count()
# print unique_start_words.take(5)

In [19]:
# log P(sentence starts with word) = log(count / total start words)
start_prob_calc = unique_start_words.map(lambda (k,v): (k, math.log(v/total_start_words)))
# fallback for words never seen at a sentence start: log-probability of a single occurrence
default_start_prob = math.log(1/total_start_words)

In [20]:
# print start_prob_calc
# print start_prob_calc.getNumPartitions()
# print start_prob_calc.count()
# print start_prob_calc.take(5)

print 'Default start probability:', default_start_prob

Default start probability: -11.8282689096


In [21]:
# bring the start log-probabilities to the driver as a plain dict (broadcast to workers later)
start_prob = start_prob_calc.collectAsMap()

In [22]:
############
#
# generate transition probabilities
#
############

In [23]:
def get_transitions(sentence):
    '''Return ((word, next word), 1) for every adjacent pair of words in
    sentence, or None when the sentence has fewer than two words.'''
    if len(sentence) < 2:
        return None
    return [((current, following), 1)
            for current, following in zip(sentence, sentence[1:])]

In [24]:
# ((previous word, next word), 1) for every adjacent pair within a sentence; the total
# number of pairs is counted with an accumulator (foreach returns None, so
# count_total_other_words is always None)
accum_total_other_words = sc.accumulator(0)
other_words = split_sentence.map(lambda sentence: get_transitions(sentence)).filter(lambda x: x!=None). \
                flatMap(lambda x: x)
count_total_other_words = other_words.foreach(lambda x: accum_total_other_words.add(1))
total_other_words = float(accum_total_other_words.value)

In [25]:
# print other_words
# print other_words.getNumPartitions()
# print other_words.count()
# print other_words.take(5)

print 'Total other words', total_other_words

Total other words 968212.0


In [26]:
# ((previous word, next word), count) for each distinct adjacent pair
unique_other_words = other_words.reduceByKey(lambda a, b: a + b, numPartitions = n_partitions)

In [27]:
# print unique_other_words
# print unique_other_words.getNumPartitions()
# print unique_other_words.count()
# print unique_other_words.take(5)

In [28]:
# regroup by previous word: (previous word, {next word: count, ...})
other_words_collapsed = unique_other_words.map(lambda x: (x[0][0], (x[0][1], x[1]))).groupByKey().mapValues(dict)

In [29]:
# print other_words_collapsed
# print other_words_collapsed.getNumPartitions()
# print other_words_collapsed.count()
# print other_words_collapsed.take(5)

In [30]:
def map_transition_prob(x):
    '''Turn (word, {next word: count}) into (word, {next word: log P(next word | word)}).'''
    word, successor_counts = x[0], x[1]
    total = float(sum(successor_counts.values()))
    log_probs = dict((successor, math.log(count / total))
                     for successor, count in successor_counts.items())
    return (word, log_probs)

In [31]:
# (previous word, {next word: log P(next word | previous word)})
transition_prob_calc = other_words_collapsed.map(lambda x: map_transition_prob(x))
# fallback for unseen pairs: log-probability of a single occurrence among all pairs
default_transition_prob = math.log(1/total_other_words)

In [32]:
# print transition_prob_calc
# print transition_prob_calc.getNumPartitions()
# print transition_prob_calc.count()
# print transition_prob_calc.take(5)

print 'Default transition probability:', default_transition_prob

Default transition probability: -13.7832063505


In [33]:
# collect to the driver as {previous word: {next word: log-prob}} (broadcast to workers later)
transition_prob = transition_prob_calc.collectAsMap()

In [34]:
############
#
# generate dictionary
#
############

In [35]:
# every word occurrence in the training corpus (non a-z characters become separators)
all_words = make_all_lower.map(lambda line: regex.sub(' ', line)).flatMap(lambda line: line.split())

In [36]:
# print all_words
# print all_words.getNumPartitions()
# print all_words.count()
# print all_words.take(5)

In [37]:
# (word, 1) pairs for counting
count_once = all_words.map(lambda word: (word, 1))

In [38]:
# print count_once
# print count_once.getNumPartitions()
# print count_once.count()
# print count_once.take(5)

In [39]:
# (word, corpus frequency); cached because it feeds the deletes, the corpus
# entries and the longest-word computation below
unique_words_with_count = count_once.reduceByKey(lambda a, b: a + b, numPartitions = n_partitions).cache()

In [40]:
# print unique_words_with_count
# print unique_words_with_count.getNumPartitions()
# print unique_words_with_count.count()
# print unique_words_with_count.take(5)

In [41]:
# guard: the delete generation below expects a positive edit distance
assert MAX_EDIT_DISTANCE>0 

In [42]:
# (word, [every string obtained by deleting up to MAX_EDIT_DISTANCE characters])
generate_deletes = unique_words_with_count.map(lambda (parent, count): 
                                                   (parent, get_n_deletes_list(parent, MAX_EDIT_DISTANCE)))

In [43]:
# print generate_deletes
# print generate_deletes.getNumPartitions()
# print generate_deletes.count()
# print generate_deletes.take(5)

In [44]:
# one (word, delete) pair per delete
expand_deletes = generate_deletes.flatMapValues(lambda x: x)

In [45]:
# print expand_deletes
# print expand_deletes.getNumPartitions()
# print expand_deletes.count()
# print expand_deletes.take(5)

In [46]:
# invert to (delete, ([word], 0)): the delete points back to its word; the 0
# frequency means the delete is not counted as a corpus word
swap = expand_deletes.map(lambda (orig, delete): (delete, ([orig], 0)))

In [47]:
# print swap
# print swap.getNumPartitions()
# print swap.count()
# print swap.take(5)

In [48]:
# (word, ([], frequency)): corpus words start with no suggestions
corpus = unique_words_with_count.mapValues(lambda count: ([], count))

In [49]:
# print corpus
# print corpus.getNumPartitions()
# print corpus.count()
# print corpus.take(5)

In [50]:
combine = swap.union(corpus)  # union keeps duplicate keys; they are merged by reduceByKeyLocally below

In [51]:
# print combine
# print combine.getNumPartitions()
# print combine.count()
# print combine.take(5)

In [52]:
# merge entries per key (concatenate word lists, add frequencies); the result is
# returned to the driver as a Python dict:
# {string: ([words that string is a delete of], corpus frequency of string)}
dictionary = combine.reduceByKeyLocally(lambda a, b: (a[0]+b[0], a[1]+b[1]))

In [53]:
# length of the longest corpus word; get_suggestions uses it to reject inputs that are too long
longest_word_length = unique_words_with_count.map(lambda (k, v): len(k)).reduce(max)

***
# Sentence-level parallelization

In [54]:
def get_emission_prob(edit_dist, poisson_lambda=0.01):
    '''
    The emission probability, i.e. P(word typed|word intended),
    is approximated by a Poisson(k, l) distribution, where
    k=edit distance and l=poisson_lambda (default 0.01).
    
    The lambda parameter matches the one used in the AM207
    lecture notes. Various parameters between 0 and 1 were tested
    to confirm that 0.01 yields the most accurate results.

    Returns the natural log of that probability as a Python float.
    It is computed with poisson.logpmf rather than
    math.log(poisson.pmf(...)): the pmf underflows to 0.0 for large
    edit distances, and math.log(0.0) raises ValueError.
    '''
    
    return float(poisson.logpmf(edit_dist, poisson_lambda))

In [55]:
def get_start_prob(word, start_prob, default_start_prob):
    '''Log-probability that a sentence starts with word, or
    default_start_prob when word has no entry in start_prob.'''
    try:
        log_prob = start_prob[word]
    except KeyError:
        log_prob = default_start_prob
    return log_prob

In [56]:
def get_transition_prob(cur_word, prev_word, transition_prob, default_transition_prob):
    '''Log-probability of cur_word following prev_word, or default_transition_prob
    when prev_word has no successors recorded or cur_word is not among them.'''
    try:
        successors = transition_prob[prev_word]
        log_prob = successors[cur_word]
    except KeyError:
        log_prob = default_transition_prob
    return log_prob

In [57]:
def get_belief(prev_word, prev_belief):
    '''
    Return the log-belief stored for prev_word in prev_belief.

    When prev_word has no entry, fall back to half of the smallest belief
    present: log(exp(min) / 2) = min - log(2). This is computed directly
    in log space, because exponentiating a very negative log-belief
    underflows to 0.0 and math.log(0.0) raises ValueError.

    Raises ValueError (from min) if prev_word is missing and prev_belief
    is empty.
    '''
    try:
        return prev_belief[prev_word]
    except KeyError:
        return min(prev_belief.values()) - math.log(2.)

In [58]:
def dameraulevenshtein(seq1, seq2):
    '''
    Calculate the Damerau-Levenshtein distance between sequences.
    Same algorithm as word-level checking, written so it runs under
    both Python 2 and Python 3 (no xrange, no list + range concatenation).

    This is the restricted ("optimal string alignment") variant: a
    transposed pair is never edited again, so e.g.
    dameraulevenshtein('ca', 'abc') is 3, not 2.
    '''
    
    # codesnippet:D0DE4716-B6E6-4161-9219-2903BF8F547F
    # Conceptually, this is based on a (len(seq1) + 1) x (len(seq2) + 1)
    # matrix. However, only the current and two previous rows are
    # needed at once, so we only store those.
    
    oneago = None
    thisrow = list(range(1, len(seq2) + 1)) + [0]
    
    for x in range(len(seq1)):
        
        # Python lists wrap around for negative indices, so put the
        # leftmost column at the *end* of the list. This matches with
        # the zero-indexed strings and saves extra calculation.
        twoago, oneago, thisrow = \
            oneago, thisrow, [0] * len(seq2) + [x + 1]
        
        for y in range(len(seq2)):
            delcost = oneago[y] + 1
            addcost = thisrow[y - 1] + 1
            subcost = oneago[y - 1] + (seq1[x] != seq2[y])
            thisrow[y] = min(delcost, addcost, subcost)
            # This block deals with transpositions
            if (x > 0 and y > 0 and seq1[x] == seq2[y - 1]
                and seq1[x-1] == seq2[y] and seq1[x] != seq2[y]):
                thisrow[y] = min(thisrow[y], twoago[y - 2] + 1)
    
    # for an empty seq2 this reads thisrow[-1], the leftmost column (= len(seq1))
    return thisrow[len(seq2) - 1]

In [59]:
def get_suggestions(string, dictionary, 
                    longest_word_length, min_count):
    '''
    Return list of suggested corrections for potentially incorrectly
    spelled word.

    Code based on get_suggestions function from word-level checking,
    with the addition of the min_count parameter: only words that occur
    at least min_count times in the (dictionary) corpus are suggested.

    string: the word to check
    dictionary: {string: ([corpus words that string is a delete of],
        frequency in corpus), ...}; deletes that are not themselves
        corpus words have frequency 0
    longest_word_length: length of the longest word in the corpus
    min_count: minimum corpus frequency of a suggested word

    Uses the module-level MAX_EDIT_DISTANCE and dameraulevenshtein.

    Returns a list of (correction, (frequency in corpus, edit distance)),
    sorted by edit distance and then by decreasing frequency, e.g.
    get_suggestions('file', ...)
    [('file', (5, 0)),
     ('five', (67, 1)),
     ('fire', (54, 1)),
     ('fine', (17, 1))...]
    '''
    import collections

    # too long to be within MAX_EDIT_DISTANCE of any corpus word
    if (len(string) - longest_word_length) > MAX_EDIT_DISTANCE:
        return []

    # {word: (frequency in corpus, edit distance)}
    suggest_dict = {}

    # breadth-first queue: string, then its deletes (longest first);
    # a deque makes each pop O(1) instead of copying the list
    queue = collections.deque([string])
    q_seen = set()  # deletes already queued, so each is processed once

    while queue:
        q_item = queue.popleft()

        # process queue item
        if (q_item in dictionary) and (q_item not in suggest_dict):
            q_count = dictionary[q_item][1]
            if q_count >= min_count:
                # q_item is itself a (frequent enough) corpus word. q_items
                # other than string are deletes of it, hence shorter, so the
                # edit distance is the length difference (unless manual
                # dictionary corrections are added).
                assert len(string) >= len(q_item)
                suggest_dict[q_item] = (q_count, len(string) - len(q_item))

            # the words stored for q_item (whether q_item itself is a valid
            # word or merely a delete) are candidate corrections
            for sc_item in dictionary[q_item][0]:
                if sc_item in suggest_dict:
                    continue
                # stored words are longer than their deletes, and q_item is
                # never longer than string (unless manual corrections added)
                assert len(sc_item) > len(q_item)
                assert len(q_item) <= len(string)

                # calculate edit distance using Damerau-Levenshtein distance
                item_dist = dameraulevenshtein(sc_item, string)

                if item_dist <= MAX_EDIT_DISTANCE:
                    # should already be in dictionary if in suggestion list
                    assert sc_item in dictionary
                    sc_count = dictionary[sc_item][1]
                    # trim list to contain state space: the frequency checked
                    # is that of the suggested word itself. Checking q_item's
                    # frequency instead dropped every suggestion reached
                    # through a delete that is not a corpus word (frequency 0).
                    # This also covers sc_item == string when string's own
                    # frequency is below min_count: it is simply not suggested.
                    if sc_count >= min_count:
                        suggest_dict[sc_item] = (sc_count, item_dist)

        # now generate deletes (e.g. a substring of string or of a
        # delete) from the queue item as additional items to check
        # -- add to end of queue
        assert len(string) >= len(q_item)
        if (len(string) - len(q_item)) < MAX_EDIT_DISTANCE and len(q_item) > 1:
            for c in range(len(q_item)):  # character index
                word_minus_c = q_item[:c] + q_item[c+1:]
                if word_minus_c not in q_seen:
                    q_seen.add(word_minus_c)
                    queue.append(word_minus_c)

    # sort by edit distance, then by decreasing corpus frequency
    return sorted(suggest_dict.items(),
                  key=lambda item: (item[1][1], -item[1][0]))

In [60]:
def viterbi(words, dictionary, longest_word_length,
            start_prob, default_start_prob, 
            transition_prob, default_transition_prob,
            min_count=1,num_word_suggestions=5000):
    
    V = [{}]
    path = {}
    path_context = []
    
    # character level correction - used to determine state space
    corrections = get_suggestions(
        words[0], dictionary, longest_word_length, min_count)

    # to ensure Viterbi can keep running
    # -- use the word itself if no corrections are found
    if len(corrections) == 0:
        corrections = [(words[0], (1, 0))]
    else:    
        if len(corrections) > num_word_suggestions:
            corrections = corrections[0:num_word_suggestions]
        
    # Initialize base cases (t == 0)
    for sug_word in corrections:
        
        # compute the value for all possible starting states
        V[0][sug_word[0]] = math.exp(
            get_start_prob(sug_word[0], start_prob, 
                           default_start_prob)
            + get_emission_prob(sug_word[1][1]))
        
        # remember all the different paths (only one state so far)
        path[sug_word[0]] = [sug_word[0]]
 
    # normalize for numerical stability
    path_temp_sum = sum(V[0].values())
    V[0].update({k: math.log(v/path_temp_sum) 
                 for k, v in V[0].items()})
    
    # keep track of previous state space
    prev_corrections = [i[0] for i in corrections]
    
    if len(words) == 1:
        path_context = [max(V[0], key=lambda i: V[0][i])]
        return path_context

    # run Viterbi for t > 0
    for t in range(1, len(words)):

        V.append({})
        new_path = {}
        
        # character level correction
        corrections = get_suggestions(
            words[t], dictionary, longest_word_length, min_count)
        
        # to ensure Viterbi can keep running
        # -- use the word itself if no corrections are found
        if len(corrections) == 0:
            corrections = [(words[t], (1, 0))]
        else:
            if len(corrections) > num_word_suggestions:
                corrections = corrections[0:num_word_suggestions]
 
        for sug_word in corrections:
        
            sug_word_emission_prob = get_emission_prob(sug_word[1][1])
            
            # compute the values coming from all possible previous
            # states, only keep the maximum
            (prob, word) = max(
                (get_belief(prev_word, V[t-1]) 
                + get_transition_prob(sug_word[0], prev_word, 
                    transition_prob, default_transition_prob)
                + sug_word_emission_prob, prev_word) 
                               for prev_word in prev_corrections)

            # save the maximum value for each state
            V[t][sug_word[0]] = math.exp(prob)
            
            # remember the path we came from to get this maximum value
            new_path[sug_word[0]] = path[word] + [sug_word[0]]
            
        # normalize for numerical stability
        path_temp_sum = sum(V[t].values())
        V[t].update({k: math.log(v/path_temp_sum) 
                     for k, v in V[t].items()})
        
        # keep track of previous state space
        prev_corrections = [i[0] for i in corrections]
 
        # don't need to remember the old paths
        path = new_path
     
    (prob, word) = max((V[t][sug_word[0]], sug_word[0]) 
                       for sug_word in corrections)
    path_context = path[word]
    
    return path_context

In [61]:
############
#
# load file & initial processing
#
############

In [62]:
# document to check (rebinds fname, make_all_lower and split_sentence, which
# previously referred to the training corpus)
fname = "testdata/test.txt"

In [63]:
# broadcast Python dictionaries to workers
bc_dictionary = sc.broadcast(dictionary)
bc_start_prob = sc.broadcast(start_prob)
bc_transition_prob = sc.broadcast(transition_prob)

In [64]:
make_all_lower = sc.textFile(fname).map(lambda line: line.lower())

In [65]:
# print make_all_lower
# print make_all_lower.getNumPartitions()
# print make_all_lower.count()
# print make_all_lower.take(5)

In [66]:
# same preprocessing as the training corpus: split on '.', replace non a-z
# characters, split into words. Pieces with no letters (e.g. text after a
# trailing '.') become empty word lists.
split_sentence = make_all_lower.flatMap(lambda line: line.split('.')).map(lambda sentence: regex.sub(' ', sentence)) \
            .map(lambda sentence: sentence.split())

In [67]:
# print split_sentence
# print split_sentence.getNumPartitions()
# print split_sentence.count()
# print split_sentence.take(5)
print split_sentence.collect()

[[u'this', u'is', u'a', u'test'], [u'this', u'is', u'a', u'test'], [u'here', u'is', u'a', u'test'], [u'this', u'is', u'ax', u'tesst'], [u'this', u'is', u'za', u'test'], [u'thee', u'is', u'a', u'test'], [u'her', u'tee', u'set']]


In [68]:
# use accumulator to count the number of words checked
# (foreach returns None, so split_words is None; only the accumulator carries the result)
accum_total_words = sc.accumulator(0)
split_words = split_sentence.flatMap(lambda x: x).foreach(lambda x: accum_total_words.add(1))
print 'Words checked: ', accum_total_words.value

Words checked:  27


In [69]:
# key each sentence by its 0-based position in the document: (sentence index, [words])
sentence_id = split_sentence.zipWithIndex().map(lambda (k, v): (v, k))

In [70]:
# print sentence_id
# print sentence_id.getNumPartitions()
# print sentence_id.count()
# print sentence_id.take(5)
print sentence_id.collect()

[(0, [u'this', u'is', u'a', u'test']), (1, [u'this', u'is', u'a', u'test']), (2, [u'here', u'is', u'a', u'test']), (3, [u'this', u'is', u'ax', u'tesst']), (4, [u'this', u'is', u'za', u'test']), (5, [u'thee', u'is', u'a', u'test']), (6, [u'her', u'tee', u'set'])]


In [71]:
sentence_correction = sentence_id.map(lambda (k, v): (k, (v, viterbi(
                v, bc_dictionary.value, longest_word_length, bc_start_prob.value, 
                default_start_prob, bc_transition_prob.value, default_transition_prob))))

In [74]:
# print sentence_correction
# print sentence_correction.getNumPartitions()
# print sentence_correction.count()
# print sentence_correction.take(5)
# each element: (sentence index, (original words, corrected words))
print sentence_correction.collect()

[(0, ([u'this', u'is', u'a', u'test'], [u'this', u'is', u'a', u'test'])), (1, ([u'this', u'is', u'a', u'test'], [u'this', u'is', u'a', u'test'])), (2, ([u'here', u'is', u'a', u'test'], [u'here', u'is', u'a', u'test'])), (3, ([u'this', u'is', u'ax', u'tesst'], [u'this', u'is', u'a', u'test'])), (4, ([u'this', u'is', u'za', u'test'], [u'this', u'is', u'a', u'test'])), (5, ([u'thee', u'is', u'a', u'test'], [u'there', u'is', u'a', u'test'])), (6, ([u'her', u'tee', u'set'], [u'her', u'to', u'set']))]


In [75]:
def get_sentence_mismatches(sentences):
    '''Compare an (original words, suggested words) pair position by position.

    Returns a list of (original word, suggested word) tuples for every
    position where the two differ, or None when nothing differs.
    '''
    original, suggested = sentences
    differing = []
    for position, word in enumerate(original):
        replacement = suggested[position]
        if word != replacement:
            differing.append((word, replacement))
    return differing if differing else None

In [76]:
# keep only sentences with at least one changed word:
# (sentence index, [(original word, suggested word), ...])
sentence_mismatch = sentence_correction.map(lambda (k, v): (k, get_sentence_mismatches(v))) \
                .filter(lambda (k,v): v!=None)

In [77]:
# print sentence_mismatch
# print sentence_mismatch.getNumPartitions()
# print sentence_mismatch.count()
# print sentence_mismatch.take(5)
sentence_mismatch.collect()

[(3, [(u'ax', u'a'), (u'tesst', u'test')]),
 (4, [(u'za', u'a')]),
 (5, [(u'thee', u'there')]),
 (6, [(u'tee', u'to')])]

In [84]:
def split_mismatches(mismatches):
    '''Flatten (sentence index, [(original, suggestion), ...]) into a list of
    [sentence index, original, suggestion] rows, one per mismatched word.'''
    sent_id, word_list = mismatches
    return [[sent_id, pair[0], pair[1]] for pair in word_list]

In [91]:
# one [sentence index, original word, suggested word] row per mismatched word
word_mismatch = sentence_mismatch.flatMap(lambda x: split_mismatches(x))

In [92]:
# print word_mismatch
# print word_mismatch.getNumPartitions()
print word_mismatch.count()
# print word_mismatch.take(5)
print word_mismatch.collect()

5
[[3, u'ax', u'a'], [3, u'tesst', u'test'], [4, u'za', u'a'], [5, u'thee', u'there'], [6, u'tee', u'to']]


In [81]:
# use accumulator to count the number of mismatches
# (same value as word_mismatch.count(); foreach returns None, so count_mismatches is None)
accum_total_mismatches = sc.accumulator(0)
count_mismatches = word_mismatch.foreach(lambda x: accum_total_mismatches.add(1))
print 'Potential mismatches: ', accum_total_mismatches.value

Potential mismatches:  5


In [82]:
# set to False to skip printing the list of suggested corrections below
printlist=True

In [83]:
# ERROR words are words where the word does not match the first tuple's word (top match)
if printlist:
    print '    Words with suggested corrections (line number, word in text, top match):'
    print word_mismatch.map(lambda x: (x[0], str(x[1]) + " --> " + str(x[2]))).collect()

    Words with suggested corrections (line number, word in text, top match):
[(3, 'ax --> a'), (3, 'tesst --> test'), (4, 'za --> a'), (5, 'thee --> there'), (6, 'tee --> to')]
