In [None]:
# Spark configuration
# Both environment variables are set here, before the SparkSession is built
# further down in this cell.
import os 
# Arguments passed to the PySpark launcher: give the driver 3 GB of memory.
os.environ['PYSPARK_SUBMIT_ARGS'] ="--conf spark.driver.memory=3g  pyspark-shell"

from pyspark.sql import SparkSession
# NOTE(review): machine-specific absolute path - adjust to the local JDK install.
JAVA_HOME = "/usr/lib/jvm/java-1.17.0-openjdk-amd64" # Set Java version (>=17)
os.environ["JAVA_HOME"] = JAVA_HOME

# If this cell is re-run, stop the session left over from the previous run so
# the builder below starts a fresh application.
try:
    spark
except NameError:
    # Fresh kernel: no session has been created yet, nothing to stop.
    pass
else:
    print("Spark application already started. Terminating existing application and starting new one")
    try:
        spark.stop()
    except Exception as exc:
        # Stopping is best-effort, but report the failure instead of hiding it.
        print(f"Could not stop the existing Spark application cleanly: {exc!r}")

# Create (or reuse) a local session named "demoRDD" that uses all local cores.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("demoRDD")
    .getOrCreate()
)

# Handle to the underlying SparkContext, used by the RDD code below.
sc = spark.sparkContext

# Last expression of the cell: display the context summary.
sc

In [None]:
# Load the parsed posts and turn every line into a (post id, shingles) pair.
# Each input line is expected to hold the post id followed by its (unhashed)
# shingles, separated by commas; convert_xml.py produces this layout.

input_data = "data/se_parsed.txt"


def _parse_post(line):
    """Split one input line into (integer post id, list of shingle strings)."""
    post_id, *shingles = line.split(",")
    return int(post_id), shingles


# Posts without any shingle are dropped.
unhashed_shingles = (
    sc.textFile(input_data)
    .map(_parse_post)
    .filter(lambda post: bool(post[1]))
)

In [None]:
import random 
import hashlib

def str_to_int32(s):
    """Map a string to an unsigned 32-bit integer.

    The string is UTF-8 encoded and hashed with SHA-1; the first four bytes of
    the digest are read as a little-endian integer.
    """
    digest = hashlib.sha1(s.encode("utf-8")).digest()
    leading_bytes = digest[:4]
    return int.from_bytes(leading_bytes, byteorder="little")

# Coefficients of the universal hash h(x) = ((a1*x + b1) % p) % 2**32 used for
# shingles and LSH buckets.
# They come from a dedicated, seeded generator so that results are reproducible
# across kernel restarts and the global `random` stream is left untouched.
_shingle_hash_rng = random.Random(7919)
a1 = _shingle_hash_rng.randint(1, 2**32 - 1)  # multiplier must be non-zero
b1 = _shingle_hash_rng.randint(0, 2**32 - 1)
p = 2**61 - 1 # Mersenne prime 

def hash_shingle(s) -> int:
    """Hash one shingle string to an integer in [0, 2**32).

    The string is first reduced to 32 bits with str_to_int32, then passed
    through the linear hash ((a1*x + b1) % p) % 2**32.
    """
    return ((a1 * str_to_int32(s) + b1) % p) % 2**32

def hash_post(shingles) -> list[int]:
    """Return the hash of every shingle of a post, in input order."""
    return [hash_shingle(shingle) for shingle in shingles]

In [None]:
# Replace each post's shingle strings by their integer hashes; the result is
# persisted because several later cells read it.
hashed_shingles = unhashed_shingles.mapValues(hash_post).persist()

In [None]:
# MinHash parameters
import random

# Length of every MinHash signature, i.e. how many hash functions are applied.
num_minhashes = 45

# Exclusive upper bound of the unsigned 32-bit range.
max_uint32 = 1 << 32

# Fix the seed so the hash-function coefficients drawn next are repeatable.
random.seed(42)

In [None]:
def _draw_hash_params(count, modulus):
    """Draw `count` coefficient pairs (a, b) with 1 <= a < modulus and 0 <= b < modulus."""
    drawn = []
    for _index in range(count):
        coeff_a = random.randint(1, modulus - 1)
        coeff_b = random.randint(0, modulus - 1)
        drawn.append((coeff_a, coeff_b))
    return drawn


# One (a, b) pair per MinHash function, also published as a broadcast variable.
hash_params = _draw_hash_params(num_minhashes, max_uint32)
hash_params_bc = sc.broadcast(hash_params)

In [None]:
def create_minhash(shingle_hashes, params=None):
    """Compute the MinHash signature of one post.

    Args:
        shingle_hashes: collection of integer shingle hashes of the post.
        params: optional sequence of (a, b) coefficient pairs, one per hash
            function. Defaults to the broadcast coefficients
            (`hash_params_bc.value`).

    Returns:
        A list with one entry per hash function, each the minimum of
        ((a*x + b) % p) % 2**32 over the post's shingle hashes. For a post
        without shingles every entry is the sentinel -1.
    """
    if params is None:
        # Read the coefficients through the broadcast variable created for
        # this purpose rather than closing over the driver-side list.
        params = hash_params_bc.value
    if not shingle_hashes:
        return [-1] * len(params)
    # Duplicate shingles cannot change a minimum, so dedupe once up front
    # instead of re-hashing them for every hash function.
    distinct_hashes = set(shingle_hashes)
    return [
        min(((a * x + b) % p) % 2**32 for x in distinct_hashes)
        for a, b in params
    ]

# Signature per post; persisted because both the LSH step and the final
# comparison read it.
minhash_sigs = hashed_shingles.mapValues(create_minhash).persist()

In [None]:
# LSH parameters
# NOTE(review): t is presumably the desired similarity threshold; in this cell
# it is only reported by the print below - confirm intended use.
t = 0.4
rows_per_band = 3  # signature rows combined into one band
bands = 15  # number of bands the signature is cut into

# Check that rows * bands is equal to the number of MinHashes
assert rows_per_band * bands == num_minhashes, f"rows * bands = {rows_per_band * bands} != minhash_rows = {num_minhashes}"
# Show t next to (1/b)^(1/r) computed from the chosen band layout.
print(f"t = {t}, (1/b)^(1/r) = {(1 / bands) ** (1 / rows_per_band)}")

In [None]:
# Book on LSH:
# "For each band, there is a hash function that takes vectors of r integers
# (the portion of one column within that band) and hashes them to some large
# number of buckets. We can use the same hash function for all the bands, but
# we use a separate bucket array for each band, so columns with the same vector
# in different bands will not hash to the same bucket."

def lsh_bucket_pairs(post_id, sig):
    """Yield one ((band, bucket_hash), post_id) record per LSH band.

    The signature is cut into `bands` consecutive slices of `rows_per_band`
    values. Each slice is folded into a 32-bit integer that depends on both
    the values and their position, and that integer is then hashed with
    ((a1*x + b1) % p) % 2**32 to obtain the bucket. The band index is part of
    the key, so every band has its own set of buckets.
    """
    for band in range(bands):
        offset = band * rows_per_band
        mixed = 0
        for row, value in enumerate(sig[offset:offset + rows_per_band]):
            # Position-dependent salt so reordered values give another result.
            salt = value + 0x9e3779b9 + (row << 6) + (row >> 2)
            mixed = ((mixed * 1000003) ^ salt) & 0xFFFFFFFF
        yield ((band, ((a1 * mixed + b1) % p) % 2**32), post_id)

# create RDD of ((band, bucket), [post_ids])
# groupByKey gathers the ids of each bucket directly. Building the lists with
# reduceByKey and list concatenation re-copies the growing list on every merge
# and gains nothing from map-side combining.
buckets_rdd = (
    minhash_sigs
    .flatMap(lambda kv: lsh_bucket_pairs(kv[0], kv[1]))
    .groupByKey()
    .mapValues(list)
)

# Keep only buckets with more than one post 
candidate_buckets = buckets_rdd.filter(lambda kv: len(kv[1]) > 1)


In [None]:
from itertools import combinations

# Count, for every unordered pair of posts, in how many buckets they collide.
# Sorting a bucket's ids first makes combinations() emit each pair as
# (smaller id, larger id), so the same pair always maps to the same key.
pair_counts = (
    candidate_buckets
    .flatMap(
        lambda bucket: (
            ((first, second), 1)
            for first, second in combinations(sorted(bucket[1]), 2)
        )
    )
    .reduceByKey(lambda left, right: left + right)
).persist()

In [None]:
# The 100 pairs that share the most buckets, as (count, (id_i, id_j)) in
# descending order of count. RDD.top selects the largest elements without
# globally sorting every candidate pair first, which is what made
# sortByKey(False).take(100) slow.
top_pairs = (
    pair_counts
    .map(lambda kv: (kv[1], kv[0]))
    .top(100, key=lambda count_and_pair: count_and_pair[0])
)

In [None]:
# Test the number of buckets a pair was found in,
# vs the true Jaccard similarity and MinHash similarity

# Ids of every post that occurs in one of the top pairs.
needed_ids = {post_id for _, pair in top_pairs for post_id in pair}
# Fetch only those posts' hashed shingles and signatures to the driver.
hs_map = hashed_shingles.filter(lambda kv: kv[0] in needed_ids).collectAsMap()
ms_map = minhash_sigs.filter(lambda kv: kv[0] in needed_ids).collectAsMap()

for count, (i, j) in top_pairs:
    shingles_i = set(hs_map.get(i, []))
    shingles_j = set(hs_map.get(j, []))

    sigs_i = ms_map.get(i)
    sigs_j = ms_map.get(j)

    # Jaccard similarity of the hashed shingle sets: two empty sets count as
    # identical, and a single empty set gives 0 through the empty intersection.
    union = shingles_i | shingles_j
    if not union:
        true_jaccard = 1.0
    else:
        true_jaccard = len(shingles_i & shingles_j) / len(union)

    # Fraction of signature positions on which the two posts agree.
    if sigs_i is None or sigs_j is None:
        minhash_sim = None
        # Formatting None with ':.3f' raises TypeError, so print a placeholder.
        minhash_text = "n/a"
    else:
        minhash_sim = sum(1 for a, b in zip(sigs_i, sigs_j) if a == b) / len(sigs_i)
        minhash_text = f"{minhash_sim:.3f}"

    print(f"{(i, j)}: {count}, jaccard similarity: {true_jaccard:.3f}, minhash similarity: {minhash_text}")


In [None]:
# Shut down the SparkContext now that the analysis is finished.
sc.stop()