In [1]:
import findspark
findspark.init()

from pyspark import SparkContext
import json
import time
import collections
from operator import add
from pyspark.mllib.fpm import FPGrowth

In [2]:
# Wall-clock start; the last cell prints the seconds elapsed since this point.
startTime = time.time()

# A key is kept only when its set of distinct values is strictly larger than this.
# NOTE(review): presumably "users who reviewed more than 70 distinct businesses" —
# confirm against the column order of the CSV.
filter_threshold = 70
# Numerator of the FP-growth minimum support: minSupport = global_threshold / number of kept baskets.
global_threshold = 50
# CSV input: its first line is skipped as a header and the first two comma-separated
# fields of every other line form the (key, value) pairs.
input_file_path = r'user_business.csv'
# Summary file that receives the Task2 / Task3 / Intersection counts.
output_file_path = r'task3_Output.txt'

# Reuse the running SparkContext if one exists, otherwise start a new one.
sc = SparkContext.getOrCreate()

def turnStr2Pair(pairStr):
    """Turn one comma-separated line into a ``(first, second)`` tuple.

    Only the first two comma-separated fields are returned; any further
    fields are ignored.

    Raises:
        IndexError: if ``pairStr`` contains no comma.
    """
    fields = pairStr.split(',')
    return fields[0], fields[1]
    
# Read the CSV and turn every data row into a raw "key,value" string.
# The first line is the header and is skipped. The `with` block closes the file
# on exit, so no explicit close() is needed.
with open(input_file_path) as f:
    rawStrList = f.readlines()[1:]

# Drop the trailing newline of each row and skip blank rows: an empty string has
# no comma, so turnStr2Pair would raise IndexError inside a Spark task.
pairList = [row.rstrip('\n') for row in rawStrList if row.strip()]

# Build one basket per key:
#   (key, value) pairs -> key -> set of distinct values
#   -> keep keys with strictly more than `filter_threshold` distinct values
#   -> keep only the value set (the basket handed to FP-growth).
# The RDD is cached because it is consumed more than once downstream
# (a count() plus the FP-growth training passes).
qualifiedUsersRDD = (
    sc.parallelize(pairList)
    .map(turnStr2Pair)
    .groupByKey()
    .mapValues(set)
    .filter(lambda user_bus_pair: len(user_bus_pair[1]) > filter_threshold)
    .map(lambda pair: pair[1])
    .cache()
)

In [3]:
# Partition count handed to FP-growth (previously an unnamed literal).
num_partitions = 50

# FP-growth takes a *relative* support, so the absolute threshold is divided by
# the number of baskets.
# NOTE(review): Spark presumably converts this back to an absolute count
# internally; confirm that floating-point rounding cannot shift the effective
# threshold by one.
numQualifiedUsers = qualifiedUsersRDD.count()

if numQualifiedUsers == 0:
    # No basket passed the filter: there is nothing to mine, and dividing by the
    # count would raise ZeroDivisionError. Report "no frequent itemsets" instead.
    result = []
else:
    minSupport = global_threshold / numQualifiedUsers
    # Despite its name, `fpgRDD` holds the trained FP-growth model, not an RDD.
    fpgRDD = FPGrowth.train(
        qualifiedUsersRDD,
        minSupport=minSupport,
        numPartitions=num_partitions,
    )
    result = sorted(fpgRDD.freqItemsets().collect())

In [4]:
# Load the frequent itemsets that task 2 saved in 'task2Output.txt'.
# Everything after the 'Frequent Itemsets: ' header line belongs to the section;
# groups of itemsets are separated by a blank line and itemsets inside a group
# are written as "(...),(...)".
foundFrequentItemsetsSection = False
frequentItemsetsLines = list()
# `with` guarantees the file is closed even if reading fails.
with open('task2Output.txt', 'r') as fileReader:
    for line in fileReader:
        if foundFrequentItemsetsSection:
            frequentItemsetsLines.append(line)
        elif line == 'Frequent Itemsets: \n':
            foundFrequentItemsetsSection = True

# Strip leading/trailing whitespace so newlines at the edges of the section do
# not end up inside the first or last item (e.g. 'id\n').
frequentItemsetsString = ''.join(frequentItemsetsLines).strip()

# An empty or missing section means "no itemsets"; splitting the empty string
# would otherwise produce the bogus single itemset ('',).
if frequentItemsetsString:
    task2RawResultList = frequentItemsetsString.replace('\n\n', ',').split('),(')
else:
    task2RawResultList = list()

# Remove quotes and parentheses from every chunk and split it into its items.
task2ResultList = list()
for frequentItemset in task2RawResultList:
    cleanItem = frequentItemset.replace('"','').replace('(','').replace(')','').replace('\'','')
    task2ResultList.append(tuple(cleanItem.split(', ')))
print('Length of task2 result:', len(task2ResultList))

Length of task2 result: 2582


In [5]:
# Take element 0 of every FP-growth result entry (presumably the item list of
# the itemset), freeze it into a tuple, and order the tuples by size first and
# by content second.
fpgResultList = sorted(
    (tuple(itemset[0]) for itemset in result),
    key=lambda items: (len(items), items),
)
print('Length of task3 result:', len(fpgResultList))

Length of task3 result: 2582


In [6]:
# An itemset is unordered, but tuple equality is order-sensitive: the same
# itemset listed in a different item order by the two tasks would be counted as
# a mismatch. Compare as frozensets so only the membership matters.
intersection = (
    {frozenset(itemset) for itemset in task2ResultList}
    & {frozenset(itemset) for itemset in fpgResultList}
)

In [7]:
# Write the three summary counts. The `with` block closes (and flushes) the
# file even if one of the writes raises.
with open(output_file_path, 'w') as outputFile:
    outputFile.write('Task2,'+str(len(task2ResultList))+'\n')
    outputFile.write('Task3,'+str(len(fpgResultList))+'\n')
    outputFile.write('Intersection,'+str(len(intersection))+'\n')

# Whole seconds elapsed since `startTime` was taken in the configuration cell.
print("Duration: %d" % (time.time() - startTime))

Duration: 88
