In [202]:
import findspark
findspark.init("/Users/chukuemekaogudu/Documents/Dev-Spark-Apache/Apache-Spark/spark-2.4.5-bin-hadoop2.7")

from pyspark import SparkContext, SparkConf
import json
import time
import os
import random
from collections import defaultdict
from functools import reduce
from itertools import combinations

random.seed(25)

In [203]:
data_dir = "/Volumes/oli2/inf533_datasets/"

In [204]:
config = SparkConf().setMaster("local[*]") \
                    .setAppName("Task1") \
                    .set("spark.executor.memory", "4g") \
                    .set("spark.driver.memory", "4g")
sc = SparkContext(conf=config).getOrCreate()

In [205]:
lines = sc.textFile(os.path.join(data_dir, "train_review.json"))

In [206]:
rdd = lines.map(json.loads) \
           .map(lambda x: (x["user_id"], x["business_id"])).cache()

In [207]:
business_map = rdd.map(lambda x: x[1]) \
                   .distinct() \
                   .zipWithIndex().cache()

b_dict = business_map.collectAsMap()
reversed_b_dict = business_map.map(lambda x: (x[1], x[0])).collectAsMap()

In [208]:
user_dict = rdd.map(lambda x: x[0]) \
               .distinct() \
               .zipWithIndex() \
               .collectAsMap()

In [209]:
user_business_rdd = rdd.map(lambda x: (user_dict[x[0]], b_dict[x[1]])) \
                       .groupByKey() \
                       .mapValues(lambda x: list(set(x))).cache()

In [210]:
business_user_dict = rdd.map(lambda x: (b_dict[x[1]], user_dict[x[0]])) \
                        .groupByKey() \
                        .mapValues(lambda x: list(set(x))) \
                        .collectAsMap()

### MinHash

In [177]:
n_buckets = 100
business_count = len(b_dict)

In [178]:
def dynamic_hash(idx):
    p = 2**35 - 365
    a = random.randint(1, p - 1)
    b = random.randint(217, p - 1)
    m = 4294967295
    return lambda x: ((a * x + b * idx) % p) % m

In [179]:
hash_funcs = []
for i in range(n_buckets):
    hash_funcs.append(dynamic_hash(i))

In [180]:
def get_hash(row):
    hash_vals = []
    for hash_func in hash_funcs:
        hash_vals.append(hash_func(row))
    return hash_vals

In [181]:
hash_rdd = user_business_rdd.map(lambda x: (x[0], get_hash(x[0])))

In [182]:
joined_hash_rdd = hash_rdd.join(user_business_rdd).partitionBy(7, lambda x: hash(x))

In [183]:
def get_user_hash(pair):
    hash_val = pair[0]
    users = pair[1]
    return [(user, hash_val) for user in users]

In [184]:
def min_hash(h1, h2):
    signature = []
    for v1, v2, in zip(h1, h2):
        signature.append(min(v1, v2))
    return signature

In [185]:
signature_mat = joined_hash_rdd.map(lambda x: get_user_hash(x[1])) \
                               .flatMap(lambda x: x) \
                               .reduceByKey(lambda h1, h2: min_hash(h1, h2)).cache()

### LSH

In [186]:
BANDS = 25

In [194]:
def lsh_hash(idx):
    p = 2**75 - 545
    a = random.randint(71**2, p - 1)
    b = random.randint(0, p - 1)
    m = 729351
    return lambda x: ((a * x + b * idx) % p) % m

In [195]:
lsh_hash_funcs = []
for i in range(BANDS):
    lsh_hash_funcs.append(lsh_hash(i))

In [196]:
def generate_bands(signature):
    bands = []
    length = len(signature)
    window = length//BANDS
    idx = 1
    for i in range(0, length, window):
        start = i
        bands.append((idx, signature[i: i + window]))
        idx += 1
    return bands

In [197]:
def group_bands(pairs):
    business_idx = pairs[0]
    group = []

    for band_pair in pairs[1]:
        group.append((band_pair[0], (business_idx, band_pair[1])))
    return group

In [198]:
def lsh(bands):
    band_id = bands[0]
    pairs = bands[1]
    
    hash_table = defaultdict(list)
    for pair in pairs:
        business_id = pair[0]
        hash_sum = hash(tuple(pair[1]))
        hash_func = lsh_hash_funcs[band_id - 1]
        hash_val = hash_func(hash_sum)
        hash_table[hash_val].append(business_id)
        
    results = [v for _, v in hash_table.items()]
    return (band_id, results)

In [199]:
candidates = signature_mat.map(lambda x: (x[0], generate_bands(x[1]))) \
                               .map(group_bands) \
                               .flatMap(lambda x: x).groupByKey() \
                               .map(lsh) \
                               .flatMap(lambda x: x[1]) \
                               .filter(lambda x: len(x) > 1) \
                               .flatMap(lambda pairs: [pair for pair in combinations(pairs, 2)])

In [200]:
print(len(candidates.collect()))

3158


In [201]:
def jaccard(s1, s2):
    return s1.intersect(s2)/s1.union(s2)

In [None]:
def compute_similarity():
    