In [1]:
from pyspark import SparkContext
import csv
from itertools import combinations
from collections import OrderedDict,Counter
import random
from numpy.random import rand

In [2]:
def id_to_index(iteration,order="u,b"):
    """Turn raw review records into (key index, [value index]) pairs.

    Each record is indexed positionally: element 0 is looked up in the
    module-level ``uid2idx`` mapping and element 1 in the module-level
    ``bid2idx`` mapping (both must exist before this generator is consumed).

    Args:
        iteration: iterable of records whose first two elements are a user id
            and a business id.
        order: ``"u,b"`` yields ``(user index, [business index])``;
            ``"b,u"`` yields ``(business index, [user index])``.
            Any other value yields nothing (the lookups still happen).

    Yields:
        A 2-tuple whose second element is a one-item list, so that values can
        be merged later by list concatenation.
    """
    for record in iteration:
        user_index = uid2idx[record[0]]
        business_index = bid2idx[record[1]]
        if order == "u,b":
            yield (user_index, [business_index])
            continue
        if order == "b,u":
            yield (business_index, [user_index])

def sort_idxs(iteration):
    """Yield ``(key, sorted copy of the index list)`` for every input entry.

    Args:
        iteration: iterable of entries where element 0 is a key and element 1
            is an iterable of comparable indices.

    Yields:
        ``(entry[0], ordered)`` where ``ordered`` is a new ascending list; the
        input list is not modified.
    """
    for entry in iteration:
        ordered = list(entry[1])
        ordered.sort()
        yield (entry[0], ordered)
        

# [uidx, [h1(uidx), h2(uidx), ...]]
def cal_hash(iters,a_list,b_list,user_cnt):
    """Evaluate a family of linear hash functions on every input value.

    The k-th hash is ``(x * a_list[k] + b_list[k]) % user_cnt``.

    Args:
        iters: iterable of integer values to hash.
        a_list: multipliers, one per hash function.
        b_list: offsets, paired positionally with ``a_list``.
        user_cnt: modulus applied to every hash value.

    Yields:
        ``(x, [h_0(x), h_1(x), ...])`` for each ``x`` in ``iters``.
    """
    for value in iters:
        hashed = []
        for slope, offset in zip(a_list, b_list):
            hashed.append((value * slope + offset) % user_cnt)
        yield (value, hashed)

        
def emit_sig_band(signature_cols,r,b):
    """Split each signature column into bands and emit one bucket key per band.

    Args:
        signature_cols: iterable of ``(column id, signature list)`` entries.
        r: number of consecutive signature values in one band.
        b: number of bands to emit per column.

    Yields:
        ``(bucket key, [column id])`` where the key is the band number followed
        by that band's signature values, all joined with ``'_'``.  Prefixing
        the band number keeps equal slices from different bands apart.
    """
    for column in signature_cols:
        for band in range(b):
            start = band * r
            band_values = column[1][start:start + r]
            bucket_key = str(band) + '_' + '_'.join(map(str, band_values))
            yield (bucket_key, [column[0]])  # emit (sig_bucket, bidx)
    

            

def emit_candidate_pairs(buckets):
    """Expand LSH buckets into de-duplicated candidate pairs.

    Args:
        buckets: iterable of ``(bucket key, members)`` entries, where
            ``members`` is a list of comparable ids that share the bucket.

    Yields:
        Each distinct pair as the string ``"<smaller>,<larger>"``, in order of
        first appearance.  De-duplication only covers the pairs seen by this
        call (i.e. one partition when used with ``mapPartitions``).
    """
    # A set gives O(1) membership checks; the previous list made
    # de-duplication quadratic in the number of emitted pairs.
    seen = set()
    for bucket in buckets:
        # combinations() yields nothing for fewer than two members and exactly
        # one pair for two members, so no size-specific branches are needed.
        for combo in combinations(bucket[1], 2):
            pair = ','.join(str(i) for i in sorted(combo))
            if pair not in seen:
                seen.add(pair)
                yield pair

                    
def min_hash(iters,hash_bc):
    """Compute a min-hash signature for every column.

    Args:
        iters: iterable of ``(bidx, uidxs)`` entries; ``uidxs`` lists the row
            indices present in that column.
        hash_bc: object whose ``.value`` maps a row index to its list of hash
            values.  The entry for index ``0`` is used to determine how many
            hash functions there are, so it must exist.

    Yields:
        ``(bidx, signature)`` where ``signature[k]`` is the minimum of the
        k-th hash value over all rows in ``uidxs``.

    Raises:
        ValueError: if a column has no rows (minimum of an empty sequence).
    """
    hash_num = len(hash_bc.value[0])
    for entry in iters:
        bidx = entry[0]
        # Fetch each row's hash list once, then take the minimum position by position.
        member_hashes = [hash_bc.value[uidx] for uidx in entry[1]]
        signature_col = [
            min(hashes[k] for hashes in member_hashes)
            for k in range(hash_num)
        ]
        yield (bidx, signature_col)


def cal_jaccard_sim(pairs,char_mat_bc,threshold=0.5):
    """Compute exact Jaccard similarity for candidate pairs and keep the similar ones.

    Args:
        pairs: iterable of strings of the form ``"<idx1>,<idx2>"``.
        char_mat_bc: object whose ``.value`` maps a column index to the
            collection of row indices present in that column.
        threshold: minimum similarity (inclusive) for a pair to be emitted.
            Defaults to 0.5, the value that was previously hard-coded.

    Yields:
        ``[idx1, idx2, similarity]`` for every pair whose Jaccard similarity
        is at least ``threshold``.
    """
    for pair in pairs:
        idxs = [int(i) for i in pair.split(',')]
        s1 = set(char_mat_bc.value[idxs[0]])
        s2 = set(char_mat_bc.value[idxs[1]])
        union_size = len(s1 | s2)
        if union_size == 0:
            # Both columns are empty: the ratio is undefined, so skip the pair
            # instead of failing the whole partition with ZeroDivisionError.
            continue
        similarity = len(s1 & s2) / float(union_size)
        if similarity >= threshold:
            yield [idxs[0], idxs[1], similarity]

def transform_to_bid_pair(similar_pairs,idx_to_bid_bc):
    """Replace the two indices of each similar pair with their ids, smaller id first.

    Args:
        similar_pairs: iterable of ``[idx1, idx2, similarity]`` entries.
        idx_to_bid_bc: object whose ``.value`` maps an index to its id.

    Yields:
        ``[smaller id, larger id, similarity]``.
    """
    for entry in similar_pairs:
        first_bid = idx_to_bid_bc.value[entry[0]]
        second_bid = idx_to_bid_bc.value[entry[1]]
        # Put the pair into ascending order by swapping when needed.
        if second_bid < first_bid:
            first_bid, second_bid = second_bid, first_bid
        yield [first_bid, second_bid, entry[2]]

        
def transform_to_bid_pair2(similar_pairs,idx_to_bid_bc):
    """Convert ``"<idx1>,<idx2>"`` strings into ordered id pairs.

    Args:
        similar_pairs: iterable of comma-separated index-pair strings.
        idx_to_bid_bc: object whose ``.value`` maps an index to its id.

    Yields:
        ``[smaller id, larger id]`` for every input string.
    """
    for encoded in similar_pairs:
        indices = list(map(int, encoded.split(',')))
        lookup = idx_to_bid_bc.value
        bids = [lookup[indices[0]], lookup[indices[1]]]
        bids.sort()
        yield bids
        

In [3]:
sc = SparkContext.getOrCreate()

# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
input_file = "/Users/liangsiqi/Documents/Dataset/yelp_rec_data/yelp_train.csv"
minPartition = 5
raw_data = sc.textFile(input_file,minPartition) #(input_file, minPartition)
header = raw_data.first()
# Bind the header text as a default argument: Spark evaluates this filter
# lazily, and a plain reference to the global `header` would pick up whatever
# that name holds when the job is serialized (a later cell rebinds `header`).
clean_data = raw_data.filter(lambda x, header=header: x != header).mapPartitions(lambda x: csv.reader(x))

# Line count including the header vs. number of parsed data rows.
print(raw_data.count(),clean_data.count())

455855 455854


In [4]:
# Sanity check: show the header and the first few parsed rows.
print("Raw data example:")
print(header)
for sample_row in clean_data.take(3):
    print(sample_row)
print("\nTotal number in '%s': %d" % (input_file, clean_data.count()))

# Distinct ids are collected to the driver to build dense integer indices.
user_ids = clean_data.map(lambda review: review[0]).distinct().collect()
business_ids = clean_data.map(lambda review: review[1]).distinct().collect()
print("User id numbers: %d" % len(user_ids))
print("Business id numbers: %d" % len(business_ids))

# id -> index maps (read as globals by id_to_index) and the reverse map for businesses.
uid2idx = {uid: position for position, uid in enumerate(user_ids)}
bid2idx = {bid: position for position, bid in enumerate(business_ids)}
idx2bid = dict(enumerate(business_ids))

idx2bid_bc = sc.broadcast(idx2bid)

# convert to [uidx,[bidx1, bidx2, ...]], [bidx1, bidx2, ...] is sorted
rows = (
    clean_data
    .mapPartitions(lambda reviews: id_to_index(reviews, "u,b"))
    .reduceByKey(lambda left, right: left + right)
    .mapPartitions(lambda grouped: sort_idxs(grouped))
)
rows.persist()  # TODO: confirm caching pays off; `rows` is read again when the hash table is built
# convert to [bidx,[uidx1, uidx2, ...]] , [uidx1, uidx2, ...] is sorted
columns = (
    clean_data
    .mapPartitions(lambda reviews: id_to_index(reviews, "b,u"))
    .reduceByKey(lambda left, right: left + right)
    .mapPartitions(lambda grouped: sort_idxs(grouped))
)
columns.persist()  # TODO: confirm caching pays off; `columns` feeds both the broadcast below and the signatures
columns_bc = sc.broadcast(columns.collectAsMap())

Raw data example:
user_id, business_id, stars
['vxR_YV0atFxIxfOnF9uHjQ', 'gTw6PENNGl68ZPUpYWP50A', '5.0']
['o0p-iTC5yTBV5Yab_7es4g', 'iAuOpYDfOTuzQ6OPpEiGwA', '4.0']
['-qj9ouN0bzMXz1vfEslG-A', '5j7BnXXvlS69uLVHrY9Upw', '2.0']

Total number in '/Users/liangsiqi/Documents/Dataset/yelp_rec_data/yelp_train.csv': 455854
User id numbers: 11270
Business id numbers: 24732


In [5]:
# Fixed seed so the randomly drawn hash parameters below are reproducible.
random.seed(42)
user_cnt = len(user_ids)
business_cnt = len(business_ids)
# r = signature rows per band, b = number of bands (see emit_sig_band).
r = 3
b = 30
# (1/b)**(1/r) is the usual approximation of the similarity threshold at which
# banded LSH starts to report a pair as a candidate.
print((1/b)**(1/r))
# One hash function per signature row: r rows in each of b bands.
hash_num = r*b
a_values = []
b_values = []
for i in range(hash_num):
    # Redraw until the multiplier is unique, so no two hash functions share `a`.
    a_r = random.randint(1,business_cnt)
    while a_r in a_values:
        a_r = random.randint(1,business_cnt)
    a_values.append(a_r)
    # NOTE(review): a and b are drawn from [1, business_cnt] while the hashes
    # are reduced modulo user_cnt below -- confirm this range is intended.
    b_values.append(random.randint(1,business_cnt))



# {uidx1: [h1(uidx1), h2(uidx1), ...], 
#  uidx2: [h1(uidx2), h2(uidx2), ...],
#  ...}
# Collected to the driver as a dict; min_hash reads entry 0 to learn hash_num.
hash_values = rows.keys().mapPartitions(lambda iters: cal_hash(iters, a_values,b_values,user_cnt)).collectAsMap()

# Broadcast so every executor can look up a user's hash values locally.
hash_values_bc = sc.broadcast(hash_values)


# (bidx, [sig1,sig2, ...])
# Lazy transformation: nothing is computed until an action runs on it.
signature_mat = columns.mapPartitions(lambda iteration: min_hash(iteration, hash_values_bc))


0.3218297948685433


In [6]:
# Group business columns that share an identical signature slice in some band.
# `r` and `b` are bound as lambda defaults: this RDD is evaluated lazily and
# reused by later cells, and a plain global lookup would see whatever `r`/`b`
# hold when the job is serialized (another cell rebinds `b` to a list).
candidates = (
    signature_mat
    .mapPartitions(lambda signature_cols, r=r, b=b: emit_sig_band(signature_cols, r, b))
    .reduceByKey(lambda left, right: left + right)
    # Buckets holding a single business cannot produce a pair.
    .filter(lambda bucket: len(bucket[1]) > 1)
)
print("===Total number of possible candidate groups: ",candidates.count())

===Total number of possible candidate groups:  21813


In [7]:
# Histogram of bucket sizes: how many buckets contain 2, 3, ... businesses.
Counter(candidates.values().map(len).collect())

Counter({3: 3018,
         2: 16690,
         11: 31,
         4: 1001,
         5: 447,
         7: 125,
         13: 15,
         6: 212,
         14: 12,
         37: 1,
         20: 3,
         10: 42,
         8: 84,
         12: 21,
         26: 3,
         18: 9,
         9: 48,
         16: 7,
         58: 1,
         25: 1,
         27: 3,
         15: 9,
         23: 2,
         94: 1,
         24: 2,
         17: 9,
         21: 3,
         30: 2,
         38: 2,
         40: 1,
         19: 3,
         36: 1,
         47: 1,
         70: 1,
         35: 1,
         22: 1})

In [8]:
# Expand buckets into pairs; distinct() removes duplicates across partitions.
distinct_candidate_pair = (
    candidates
    .mapPartitions(emit_candidate_pairs)
    .distinct()
)
print("candidate pairs after distinct:", distinct_candidate_pair.count())

# Verify each candidate with the exact Jaccard similarity of its two columns.
similar_pair_idx = distinct_candidate_pair.mapPartitions(
    lambda candidate_pairs: cal_jaccard_sim(candidate_pairs, columns_bc)
)
# Map indices back to business ids and order the result by (id1, id2).
similar_pair_bid = (
    similar_pair_idx
    .mapPartitions(lambda verified: transform_to_bid_pair(verified, idx2bid_bc))
    .sortBy(lambda entry: entry[:2])
)

(similar_pair_bid.take(3), similar_pair_bid.count())

candidate pairs after distinct: 69812


([['-8O4kt8AIRhM3OUxt-pWLg', '_p64KqqRmPwGKhZ-xZwhtg', 0.5],
  ['-A5jntJgFglQ6zwAmOiOMw', 'cTqIuG-fvlQQL0OWzsFdig', 0.5],
  ['-Jhlh8Scjy669NdtCfKSSg', 'o5Mofj5KJkYAMs_fhxftpg', 0.5]],
 639)

In [9]:
# "bid1,bid2" keys, used to compare against the ground truth.
res = similar_pair_bid.map(lambda entry: ','.join(entry[:2])).collect()
# "bid1,bid2,similarity" lines with the similarity rounded to two decimals.
res_lines = similar_pair_bid.map(lambda entry: '%s,%s,%.2f' % (entry[0], entry[1], entry[2])).collect()

In [10]:
# NOTE(review): absolute, machine-specific path; adjust before running elsewhere.
truth_file = "/Users/liangsiqi/Documents/Dataset/yelp_rec_data/pure_jaccard_similarity.csv"
raw_ground_truth = sc.textFile(truth_file)
# Use a dedicated name (bound as a lambda default) so the training-file
# `header` used when building `clean_data` is not overwritten.
truth_header = raw_ground_truth.first()
ground_truth_lines = raw_ground_truth.filter(lambda x, truth_header=truth_header: x != truth_header).mapPartitions(lambda x: csv.reader(x))
# "bid1,bid2" keys in the same format as `res`.
ground_truth = ground_truth_lines.map(lambda x: x[0]+','+x[1]).collect()

In [11]:
# Equal numbers mean the reported pairs contain no duplicates.
(len(res), len({*res}))

(639, 639)

In [12]:
# Set membership is O(1); the list is kept for the recall denominator so its
# length (including any repeated lines) is unchanged.
truth_pairs = set(ground_truth)
in_ground_truth = sum(1 for pair in res if pair in truth_pairs)
print(in_ground_truth)
# Precision: share of reported pairs that are true similar pairs.
print("Precision: %.3f" % (in_ground_truth/len(res_lines)))
# Recall: share of true similar pairs that were reported.
print("Recall: %.3f" % (in_ground_truth/len(ground_truth)))

639
Precision: 1.000
Recall: 0.992


In [13]:
# Number of ground-truth pairs loaded from truth_file.
len(ground_truth)

644

In [28]:
# Dump the size of every candidate bucket, one per line, to a file named
# after the current r and b.
sig_band_cnt = candidates.map(lambda bucket: len(bucket[1])).collect()
dump_path = "./tmp_sig_band_cnt%d_%d.txt" % (r,b)
print(dump_path)
with open(dump_path,'w') as f:
    f.writelines('%s\n' % size for size in sig_band_cnt)

./tmp_sig_band_cnt5_32.txt


In [44]:
# Micro-benchmark: row-wise minimum via map() versus a list comprehension,
# on a hash_num x business_cnt matrix of random floats held as nested lists.
mat = rand(hash_num,business_cnt).tolist()

%timeit list(map(min,mat))
%timeit [min(i) for i in mat]

24.4 ms ± 651 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
23.8 ms ± 428 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [17]:
# Sweep the approximate LSH threshold (1/b)**(1/r) for 4 rows per band over a
# range of band counts.  Dedicated names are used so the global `r` and `b`
# (read by the lazily evaluated Spark lambdas and by the bucket-size dump
# filename) are not overwritten -- previously `b` became a list here.
trial_r = 4
trial_bands = list(range(20,31))
print(trial_bands)
trial_thresholds = [(1/band_count)**(1/trial_r) for band_count in trial_bands]
for trial_threshold in trial_thresholds:
    print(trial_threshold)

[20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]
0.4728708045015879
0.4671379777282001
0.4617366309441026
0.4566337854967312
0.4518010018049224
0.4472135954999579
0.4428500142691474
0.4386913376508308
0.43472087194499137
0.43092381945890607
0.42728700639623407
