In [1]:
import math

from pyspark.sql import SQLContext
from pyspark.sql.types import *
from pyspark.sql import Row
from pyspark.sql.functions import udf
from pyspark.sql.functions import *
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
from pyspark.sql.types import StructType, StructField
from pyspark import SparkContext
from pyspark.sql import SparkSession

# Create the SparkSession first and reuse its SparkContext.
# - SparkContext(...) built directly raises when a context is already running
#   (e.g. when this cell is re-run).
# - A session built on top of an existing context does not rename it, so the
#   SparkContext kept the name it was created with ('app'). That name is kept here.
spark = (SparkSession.builder
         .master("local")
         .appName('app')
         .config('spark.sql.shuffle.partitions', 10)
         .getOrCreate())
sc = spark.sparkContext

In [6]:
# Source data: https://www.cse.ust.hk/msbd5003/data
# Each line is split on whitespace. Lines that do not yield exactly two words
# (adjective, noun) are dropped.
numPartitions = 10
lines = sc.textFile('adj_noun_pairs.txt', minPartitions=numPartitions)
pairs = (lines
         .map(lambda text: tuple(text.split()))
         .filter(lambda words: len(words) == 2))
pairs.cache()

PythonRDD[5] at RDD at PythonRDD.scala:48

In [10]:
# Count how many times each (adjective, noun) pair occurs.
pair_freqs = pairs.map(lambda p: (p, 1)).reduceByKey(lambda f1, f2: f1 + f2)
pair_freqs.cache()
# Bare trailing expression instead of the Python 2-only `print` statement,
# which is a SyntaxError on Python 3 kernels.
pair_freqs.take(5)

[((u'move', u'army'), 3), ((u'beautiful', u'Girl'), 3), ((u'great', u'reduction'), 5), ((u'female', u'37,794'), 1), ((u'=', u'Pesach'), 1)]


In [14]:
# Rank pairs by frequency, most frequent first.
sort_pair_freqs = pair_freqs.sortBy(lambda w: w[1], ascending=False)
sort_pair_freqs.cache()
# take(5) returns the same leading records as collect()[:5], but only ships
# those 5 to the driver instead of materializing the whole sorted RDD there.
sort_pair_freqs.take(5)

[((u'external', u'link'), 8136), ((u'19th', u'century'), 2869), ((u'20th', u'century'), 2816), ((u'same', u'time'), 2744), ((u'first', u'time'), 2632)]


In [15]:
# Attach each pair's 0-based rank in the frequency ordering:
# ((pair, freq), rank).
with_index_pair = sort_pair_freqs.zipWithIndex()
with_index_pair.cache()
# Preview only the head; collect()[:5] would pull every record to the driver.
with_index_pair.take(5)

[(((u'external', u'link'), 8136), 0), (((u'19th', u'century'), 2869), 1), (((u'20th', u'century'), 2816), 2), (((u'same', u'time'), 2744), 3), (((u'first', u'time'), 2632), 4)]


In [19]:
# Find the pair ranked in the middle of the frequency ordering.
# zipWithIndex ranks are 0-based, so with n ranked pairs the middle one sits at
# rank n // 2 (for even n: the upper of the two middle pairs).
# The previous (n + 1) / 2 is a 1-based position, which selected the pair one
# past the median whenever n was odd. On Python 3, `/` gives a float for even n
# that matches no rank.
amount_pair = with_index_pair.count()
if amount_pair == 0:
    raise ValueError("with_index_pair is empty; there is no median pair")
median_number = amount_pair // 2

temp_answer = with_index_pair.filter(lambda w: w[1] == median_number)
# Records look like ((pair, freq), rank); keep only the pair.
median_pair = temp_answer.first()[0][0]
median_pair

(u'widespread', u'matriarchy')


In [3]:
# Number of partitions backing the `lines` RDD.
# NOTE(review): the recorded output (8) is below the numPartitions = 10 passed
# to sc.textFile where `lines` is created. The execution count here (In [3])
# also precedes that cell. The output looks stale from an earlier session;
# re-run from a fresh kernel to confirm.
lines.getNumPartitions()

8

In [4]:
# Peek at the first few raw input lines.
lines.take(num=5)

[u'early radical', u'french revolution', u'pejorative way', u'violent means', u'positive label']

In [5]:
# Turn each line into a tuple of its words. The data is dirty: some lines do
# not contain exactly two words, so only the 2-word tuples are kept.
# NOTE(review): identical to the `pairs` definition in the earlier loading cell.
pairs = (lines
         .map(lambda text: tuple(text.split()))
         .filter(lambda words: len(words) == 2))
pairs.cache()

PythonRDD[4] at RDD at PythonRDD.scala:48

In [6]:
# Peek at the first few parsed (adjective, noun) pairs.
pairs.take(num=5)

[(u'early', u'radical'), (u'french', u'revolution'), (u'pejorative', u'way'), (u'violent', u'means'), (u'positive', u'label')]

In [7]:
# Total number of well-formed pairs; used as N in the PMI formula (pmi_score).
N = pairs.count()

In [8]:
# Display the total pair count.
N

3162674

In [9]:
# Compute the frequency of each pair, ignoring pairs that are not frequent
# enough (fewer than MIN_PAIR_FREQ occurrences).
# NOTE: this rebinds `pair_freqs`, replacing the unfiltered counts built in the
# earlier cell.
MIN_PAIR_FREQ = 100
pair_freqs = (pairs
              .map(lambda p: (p, 1))
              .reduceByKey(lambda f1, f2: f1 + f2)
              .filter(lambda pf: pf[1] >= MIN_PAIR_FREQ))

In [10]:
# Peek at a few of the frequent pairs and their counts.
pair_freqs.take(num=5)

[((u'lead', u'role'), 298), ((u'other', u'means'), 202), ((u'huge', u'number'), 129), ((u'young', u'boy'), 156), ((u'old', u'age'), 174)]

In [11]:
# Marginal frequencies: how many times each adjective (first word of a pair)
# and each noun (second word of a pair) occurs across all pairs.
a_freqs = pairs.keys().map(lambda adj: (adj, 1)).reduceByKey(lambda x, y: x + y)
n_freqs = pairs.values().map(lambda noun: (noun, 1)).reduceByKey(lambda x, y: x + y)

In [12]:
# Sample a few (adjective, count) entries.
a_freqs.take(num=5)

[(u'algeria-related', 1), (u'funereal', 5), (u'datalink', 1), (u'then-leading', 1), (u'214th', 3)]

In [13]:
# Number of distinct nouns (second words of the pairs).
n_freqs.count()

106333

In [14]:
# Broadcast the adjective and noun frequency tables as read-only variables that
# tasks reuse, rather than shipping the dicts with every task.
# collectAsMap() brings each (word, count) RDD back to the driver as a dict.
a_dict = sc.broadcast(a_freqs.collectAsMap())
n_dict = sc.broadcast(n_freqs.collectAsMap())
# Sanity check: how often 'violent' occurs as the adjective of a pair.
a_dict.value['violent']

1191

In [15]:
from math import *

# Computing the PMI for a pair.
def pmi_score(pair_freq):
    """Score one ((w1, w2), count) record by pointwise mutual information.

    PMI = log2(count * N / (freq(w1) * freq(w2))), where freq(w1) comes from
    the broadcast adjective table ``a_dict``, freq(w2) from the broadcast noun
    table ``n_dict``, and ``N`` is the total number of pairs.

    Returns:
        (pmi, (w1, w2)) -- PMI first, so RDD ordering (e.g. ``top``) is by PMI.
    """
    (w1, w2), f = pair_freq
    # math.log is used explicitly. A bare `log` resolves to whichever star
    # import (pyspark.sql.functions or math) ran last in the kernel.
    pmi = math.log(float(f) * N / (a_dict.value[w1] * n_dict.value[w2]), 2)
    return pmi, (w1, w2)

In [16]:
# Computing the PMI for all pairs.
# In notebook order this maps the frequency-filtered `pair_freqs`; each record
# becomes (pmi, (w1, w2)).
scored_pairs = pair_freqs.map(pmi_score)

In [17]:
# Show the 10 most strongly associated pairs. top() uses the natural ordering
# of the (pmi, (w1, w2)) records, i.e. highest PMI first.
scored_pairs.top(num=10)

[(14.41018838546462, (u'magna', u'carta')), (13.071365888694997, (u'polish-lithuanian', u'Commonwealth')), (12.990597616733414, (u'nitrous', u'oxide')), (12.64972604311254, (u'latter-day', u'Saints')), (12.50658937509916, (u'stainless', u'steel')), (12.482331020687814, (u'pave', u'runway')), (12.19140721768055, (u'corporal', u'punishment')), (12.183248694293388, (u'capital', u'punishment')), (12.147015483562537, (u'rush', u'yard')), (12.109945794428935, (u'globular', u'cluster'))]