## Pointwise mutual information (PMI) for token pairs in Shakespeare's works that co-occur on at least a given number of lines

In [None]:
from simple_tokenize import simple_tokenize
from math import log
from itertools import permutations
from pyspark import SparkContext, SparkConf
# Reuse the active SparkContext if one already exists (e.g. when this cell is
# re-run); constructing a second SparkContext directly raises a ValueError.
# Note: if a context already exists, the conf below is ignored.
# "local[2]" runs Spark locally with 2 worker threads.
sc = SparkContext.getOrCreate(
    SparkConf().setAppName("MyApp").setMaster("local[2]"))

In [2]:
def _pmi(nxy, nx, ny, nlines, base=10):
    """Return the pointwise mutual information of a token pair.

    Probabilities are estimated from line counts:
    p(x, y) = nxy / nlines, p(x) = nx / nlines, p(y) = ny / nlines, and
    PMI(x, y) = log_base(p(x, y) / (p(x) * p(y))).

    Args:
        nxy: number of lines containing both x and y.
        nx: number of lines containing x.
        ny: number of lines containing y.
        nlines: total number of lines.
        base: logarithm base (10 by default).

    Returns:
        The PMI as a float: positive when the pair co-occurs more often than
        p(x) * p(y) predicts, negative when less often.
    """
    pxy = nxy / nlines
    px = nx / nlines
    py = ny / nlines
    return log(pxy / (px * py), base)


def PMI(threshold, path='Shakespeare.txt', base=10):
    """Compute PMI for token pairs that co-occur on at least `threshold` lines.

    Each line of the input file is one observation; a token repeated on a
    line counts only once for that line.

    Args:
        threshold: minimum number of lines on which a pair must co-occur.
        path: text file to read through the global SparkContext ``sc``.
        base: logarithm base used for the PMI (10 by default).

    Returns:
        A list of tuples of the form
        ``((token1, token2), pmi, co_occurrence_count, token1_count, token2_count)``.
        Pairs are ordered, so both (x, y) and (y, x) are reported. The list
        comes straight from ``collect()`` and is not sorted.
    """
    # Tokenize every line exactly once. The result feeds the line count, the
    # pair counts and the token counts, so cache it instead of re-reading and
    # re-tokenizing the file for each of them.
    line_tokens = (sc.textFile(path)
                   .map(lambda line: set(simple_tokenize(line)))
                   .cache())
    try:
        nlines = line_tokens.count()

        # ((x, y), n(x, y)) for every ordered pair of distinct tokens sharing
        # a line, kept only when the pair co-occurs on >= threshold lines.
        pair_count = (line_tokens
                      .flatMap(lambda tokens: permutations(tokens, 2))
                      .map(lambda pair: (pair, 1))
                      .reduceByKey(lambda a, b: a + b)
                      .filter(lambda item: item[1] >= threshold))

        # (token, n(token)): number of lines containing each token.
        token_count = (line_tokens
                       .flatMap(lambda tokens: tokens)
                       .map(lambda token: (token, 1))
                       .reduceByKey(lambda a, b: a + b))

        def key_by_first(item):
            # ((x, y), nxy) -> (x, (y, nxy)) so it can be joined on x.
            (x, y), nxy = item
            return x, (y, nxy)

        def key_by_second(item):
            # (x, ((y, nxy), nx)) -> (y, (x, nxy, nx)) so it can be joined on y.
            x, ((y, nxy), nx) = item
            return y, (x, nxy, nx)

        def to_result(item):
            # (y, ((x, nxy, nx), ny)) -> ((x, y), pmi, nxy, nx, ny)
            y, ((x, nxy, nx), ny) = item
            return (x, y), _pmi(nxy, nx, ny, nlines, base), nxy, nx, ny

        pmi_data = (pair_count.map(key_by_first)
                    .join(token_count)
                    .map(key_by_second)
                    .join(token_count)
                    .map(to_result))

        return pmi_data.collect()
    finally:
        # Release the cached tokens once the result is on the driver.
        line_tokens.unpersist()

In [4]:
# Token pairs appearing together on at least 2000 lines; the returned list is
# this cell's displayed value.
PMI(threshold=2000)

[(('i', 'of'), -0.08531809933104954, 2081, 18657, 16624),
 (('the', 'of'), 0.34294075191889295, 7266, 24300, 16624),
 (('and', 'of'), 0.028305447826683594, 3565, 24604, 16624),
 (('a', 'of'), 0.13551796879761382, 2463, 13280, 16624),
 (('i', 'have'), 0.447248246874471, 2450, 18657, 5742),
 (('i', 'that'), 0.11221751207929727, 2085, 18657, 10569),
 (('the', 'that'), 0.069458604407197, 2461, 24300, 10569),
 (('i', 'and'), -0.05037403533805645, 3338, 18657, 24604),
 (('of', 'and'), 0.028305447826683594, 3565, 16624, 24604),
 (('in', 'and'), 0.04031821796867754, 2340, 10614, 24604),
 (('the', 'and'), 0.0459349918330654, 5427, 24300, 24604),
 (('to', 'and'), 0.017522588038697152, 3815, 18237, 24604),
 (('a', 'and'), 0.0006198226234107661, 2672, 13280, 24604),
 (('you', 'and'), -0.05963472348492286, 2136, 12196, 24604),
 (('my', 'and'), 0.0056896916303759305, 2351, 11549, 24604),
 (('i', 'to'), 0.046852605922459045, 3095, 18657, 18237),
 (('the', 'to'), 0.05123525982989819, 4072, 24300, 1823