In [2]:
# Data file at https://www.cse.ust.hk/msbd5003/data
# Load the adjective-noun pair corpus as an RDD with one element per text line.
# NOTE(review): `sc` is not defined in the visible cells -- presumably the
# SparkContext injected by the pyspark notebook launcher; confirm.

lines = sc.textFile('../data/adj_noun_pairs.txt')

In [3]:
# Number of lines in the input file.
lines.count()

3162692

In [4]:
# Number of partitions the RDD of lines is split into.
lines.getNumPartitions()

2

In [6]:
# Peek at the first 5 raw lines to check the input format.
lines.take(5)

[u'early radical', u'french revolution', u'pejorative way', u'violent means', u'positive label']

In [7]:
# Split every line on whitespace and keep the words as a tuple.
# The data is dirty: some lines do not consist of exactly two words, so
# only tuples of length 2 are kept as (adjective, noun) pairs.
pairs = (
    lines
    .map(lambda line: tuple(line.split()))
    .filter(lambda words: len(words) == 2)
)
# Mark the RDD for caching so that repeated uses do not re-parse the file.
pairs.cache()

PythonRDD[7] at RDD at PythonRDD.scala:48

In [8]:
# Peek at the first 5 parsed (adjective, noun) tuples.
pairs.take(5)

[(u'early', u'radical'), (u'french', u'revolution'), (u'pejorative', u'way'), (u'violent', u'means'), (u'positive', u'label')]

In [9]:
# Total number of two-word pairs that survived the filter.
N = pairs.count()

In [10]:
# Display N, the number of well-formed pairs.
N

3162674

In [11]:
# Count the occurrences of each distinct pair, then discard the pairs
# that are not frequent enough (fewer than 100 occurrences).
pair_freqs = (
    pairs
    .map(lambda pair: (pair, 1))
    .reduceByKey(lambda left, right: left + right)
    .filter(lambda entry: entry[1] >= 100)
)

In [12]:
# Peek at 5 of the ((adjective, noun), count) records that passed the frequency filter.
pair_freqs.take(5)

[((u'much', u'debate'), 136), ((u'new', u'name'), 221), ((u'other', u'country'), 1857), ((u'other', u'book'), 223), ((u'present', u'husband'), 414)]

In [13]:
# Word-level frequencies over all pairs: the adjective is the first
# element of a pair and the noun is the second.
a_freqs = (
    pairs
    .map(lambda pair: (pair[0], 1))
    .reduceByKey(lambda left, right: left + right)
)
n_freqs = (
    pairs
    .map(lambda pair: (pair[1], 1))
    .reduceByKey(lambda left, right: left + right)
)

In [14]:
# Peek at 5 (adjective, count) records.
a_freqs.take(5)

[(u'fawn', 2), (u'base-paired', 3), (u'eicosapentanoic', 1), (u'host-cell', 2), (u'1,800', 1)]

In [15]:
# Number of distinct nouns (one record per key after reduceByKey).
n_freqs.count()

106333

In [16]:
# Collect both frequency tables to the driver as plain dicts and broadcast
# them so that they can be looked up inside functions run on the RDDs.
# Despite the names, a_dict and n_dict are Broadcast handles; the actual
# dictionaries are reached through their .value attribute.
a_dict = sc.broadcast(a_freqs.collectAsMap())
n_dict = sc.broadcast(n_freqs.collectAsMap())
# Sanity check: frequency of one adjective.
a_dict.value['violent']

1191

In [17]:
from math import *

# Computing the PMI for a pair.
def pmi_score(pair_freq):
    """Score one (pair, frequency) record by pointwise mutual information.

    The score is log2(f * N / (f_a * f_n)), where f is the pair's count,
    N is the global number of pairs, and f_a / f_n are the counts of the
    first and second word taken from the broadcast tables ``a_dict`` and
    ``n_dict``.

    Args:
        pair_freq: a record whose element 0 is the two-word pair and whose
            element 1 is that pair's count.

    Returns:
        A tuple ``(pmi, (w1, w2))`` with the score first.
    """
    first, second = pair_freq[0]
    count = pair_freq[1]
    marginal_product = a_dict.value[first] * n_dict.value[second]
    # Same evaluation order as (float(count) * N) / marginal_product.
    ratio = float(count) * N / marginal_product
    return log(ratio, 2), (first, second)

In [18]:
# Compute the PMI of every pair in pair_freqs (i.e. the pairs that passed
# the frequency filter). Each record becomes (pmi, (adjective, noun)).
scored_pairs = pair_freqs.map(pmi_score)

In [19]:
# Show the 10 most strongly associated pairs.
# The records are (pmi, pair) tuples, so they are ranked by PMI score first.
scored_pairs.top(10)

[(14.41018838546462, (u'magna', u'carta')), (13.071365888694997, (u'polish-lithuanian', u'Commonwealth')), (12.990597616733414, (u'nitrous', u'oxide')), (12.64972604311254, (u'latter-day', u'Saints')), (12.50658937509916, (u'stainless', u'steel')), (12.482331020687814, (u'pave', u'runway')), (12.19140721768055, (u'corporal', u'punishment')), (12.183248694293388, (u'capital', u'punishment')), (12.147015483562537, (u'rush', u'yard')), (12.109945794428935, (u'globular', u'cluster'))]