In [None]:
import pyspark.sql.functions as PySQL
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from random import randint
import numpy as np
import json

In [None]:
###############
#   GLOBALS   #
###############
SHINGLE_SIZE = 5    # characters per shingle
BAND_SIZE    = 20   # number of LSH bands (PERMUTATION_COUNT / ROW_SIZE)
ROW_SIZE     = 5    # minhash values hashed together into one band

# Switches selecting which data set the paths below point at.
TEST = False
PERF = False

# Spark
SPARK = (
    SparkSession.builder
    .master("local[*]")
    .appName("mySession")
    .getOrCreate()
)

# Shingles
# Radix used when a shingle is encoded as an integer: one digit per letter "A".."Z".
SHINGLE_BASE = ord("Z") - ord("A") + 1

# Minhash
HASH_MOD = 1_000_000_007
PERMUTATION_COUNT = BAND_SIZE * ROW_SIZE
RAND_MAX = (2 ** 32) - 1
# Fixed seed: every run draws the same (a, b) coefficients, so signatures and
# candidate pairs are reproducible across kernel restarts.
PERMUTATION_SEED = 42
_PERMUTATION_RNG = np.random.default_rng(PERMUTATION_SEED)
# Array of shape (PERMUTATION_COUNT, 2): column 0 holds a in [1, RAND_MAX],
# column 1 holds b in [0, RAND_MAX]. Explicit int64 keeps a * shingle + b
# from depending on the platform's default integer width.
PERMUTATION_ARR_BROADCAST = SPARK.sparkContext.broadcast(
    np.column_stack((
        _PERMUTATION_RNG.integers(1, RAND_MAX, size=PERMUTATION_COUNT, dtype=np.int64, endpoint=True),
        _PERMUTATION_RNG.integers(0, RAND_MAX, size=PERMUTATION_COUNT, dtype=np.int64, endpoint=True),
    ))
)

# Group definition file
# NOTE(review): TEST and the default both resolve to "data/group_definition.json";
# confirm whether a test-specific definition file was intended.
GROUP_DEFINITION_PATH   = "data/group_definition.json" if TEST else "data/bruh.json" if PERF else "data/group_definition.json"
GROUP_DEFINITION_SCHEMA = StructType([
    StructField("group", StringType(), False),
    StructField("protein_list", ArrayType(StringType(), False), False)
])

# Fasta directory
FASTA_PATH   = "data/test_fasta" if TEST else "data/bruh_fasta" if PERF else "data/fasta"
FASTA_SCHEMA = StructType([
    StructField("name", StringType(), False),
    StructField("value", StringType(), False)
])

In [None]:
########################
#   HELPER FUNCTIONS   #
########################
def printSparkDetails(spark):
    """Print the application name and master of a Spark session.

    Args:
        spark: Object exposing ``sparkContext.appName`` and
            ``sparkContext.master``.
    """
    print("Details of SparkContext:")
    context = spark.sparkContext
    print(f"App Name : {context.appName}")
    print(f"Master : {context.master}")

def loadDataFrameGroupDefinition():
    """Load the group definition JSON file into a Spark DataFrame.

    The file at GROUP_DEFINITION_PATH is parsed as a JSON object and each
    key/value pair becomes one row laid out per GROUP_DEFINITION_SCHEMA.

    Returns:
        The DataFrame produced by ``SPARK.createDataFrame``.
    """
    # JSON text is UTF-8 by specification; do not depend on the locale default.
    with open(GROUP_DEFINITION_PATH, encoding="utf-8") as group_definitions_file:
        group_definitions = json.load(group_definitions_file)

    # Hand Spark a real list of (key, value) pairs rather than a dict view.
    rows = list(group_definitions.items())
    return SPARK.createDataFrame(rows, GROUP_DEFINITION_SCHEMA)

def loadDataFrameFasta():
    """Read the JSON data at FASTA_PATH into a DataFrame using FASTA_SCHEMA.

    Returns:
        The DataFrame returned by Spark's JSON reader.
    """
    reader = SPARK.read.schema(FASTA_SCHEMA)
    return reader.json(FASTA_PATH)

def shingle_int(shingle):
    """Encode a shingle as an integer in base SHINGLE_BASE.

    Each character contributes the digit ``ord(char) - ord("A")``; the first
    character is the most significant digit. An empty shingle encodes to 0.

    Args:
        shingle: String of characters to encode.

    Returns:
        The integer value of the shingle.
    """
    encoded = 0
    # Horner's scheme: shift the digits accumulated so far, then add the next.
    for aminoacid in shingle:
        encoded = encoded * SHINGLE_BASE + (ord(aminoacid) - ord("A"))
    return encoded

In [None]:
################
#   WORKLOAD   #
################
def getMinhashesOfBands(value):
    """Compute one bucket hash per LSH band for a sequence string.

    The sequence is cut into overlapping shingles of SHINGLE_SIZE characters,
    each shingle is encoded with shingle_int, and a signature with one entry
    per broadcast (a, b) pair is built as ``min((a * shingle + b) % HASH_MOD)``.
    The signature is then cut into consecutive bands of ROW_SIZE entries and
    each band is reduced to a single value with ``hash``.

    Args:
        value: Sequence string. May be None or shorter than SHINGLE_SIZE.

    Returns:
        A list of ``(band_index, band_hash)`` tuples, one per band, or an
        empty list when no shingle can be formed from ``value``.
    """
    # No shingle -> no signature. Without this guard np.min() raises
    # ValueError on the empty shingle array (and len(None) raises TypeError),
    # which would abort the whole job for a single short or missing record.
    if value is None or len(value) < SHINGLE_SIZE:
        return []

    # Force 64-bit integers: a * shingle far exceeds 32 bits, and the
    # platform default integer type is not guaranteed to be 64-bit.
    shingle_codes = np.array(
        [
            shingle_int(value[start : start + SHINGLE_SIZE])
            for start in range(len(value) - SHINGLE_SIZE + 1)
        ],
        dtype=np.int64,
    )

    signature = np.array([
        np.min((a * shingle_codes + b) % HASH_MOD)
        for a, b in PERMUTATION_ARR_BROADCAST.value
    ])

    # NOTE(review): hash() can return values wider than 32 bits while the UDF
    # schema declares IntegerType for this field; confirm how Spark narrows it.
    band_hashes = [
        hash(tuple(signature[start : start + ROW_SIZE]))
        for start in range(0, PERMUTATION_COUNT, ROW_SIZE)
    ]

    # A concrete list (not a one-shot enumerate iterator) can be re-read and sized.
    return list(enumerate(band_hashes))

# Spark type of one element returned by getMinhashesOfBands: a struct of two
# non-nullable integer fields, (minhash_id, minhash_value).
# NOTE(review): minhash_value is filled from Python's hash(), whose range can
# exceed 32 bits; confirm IntegerType (rather than LongType) is intended.
minhash_tuple_type = StructType([
    StructField(name="minhash_id", dataType=IntegerType(), nullable=False),
    StructField(name="minhash_value", dataType=IntegerType(), nullable=False),
])

# Column-level wrapper around getMinhashesOfBands returning an array of the
# struct above, with no null elements.
udf_get_minhashes_of_bands = PySQL.udf(
    f=getMinhashesOfBands,
    returnType=ArrayType(elementType=minhash_tuple_type, containsNull=False),
)

In [None]:
###########
#   RUN   #
###########
printSparkDetails(SPARK)

# One row per (group, protein) membership.
df_group_definition = (
    loadDataFrameGroupDefinition()
    .select("group", PySQL.explode("protein_list").alias("protein"))
)

df_fasta = loadDataFrameFasta()

# One row per (group, name, band): compute the band hashes of every sequence,
# attach the group via a left join (rows without a matching protein keep a
# null group), then explode the band list.
df_lsh = (
    df_fasta
    .withColumn("minhash_band_signature_list", udf_get_minhashes_of_bands("value"))
    .select("name", "minhash_band_signature_list")
    .join(df_group_definition, df_fasta.name == df_group_definition.protein, "left")
    .select("group", "name", PySQL.explode("minhash_band_signature_list").alias("minhash"))
)

# Candidate pairs: two rows match when their (band index, band hash) structs
# are equal. The strict name ordering drops self-pairs and mirrored pairs.
same_band_bucket = PySQL.col("df_1.minhash") == PySQL.col("df_2.minhash")
ordered_pair = PySQL.col("df_1.name") < PySQL.col("df_2.name")

df_similarity = (
    df_lsh.alias("df_1")
    .join(df_lsh.alias("df_2"), same_band_bucket & ordered_pair, "inner")
    .select(
        PySQL.col("df_1.group").alias("group_1"),
        PySQL.col("df_2.group").alias("group_2"),
        PySQL.col("df_1.name").alias("name_1"),
        PySQL.col("df_2.name").alias("name_2"),
        PySQL.col("df_1.minhash").alias("minhash"),
    )
)

In [None]:
########################
#   PERFORMANCE TEST   #
########################
print("cooking time!")
gainz = ""
with open("data/bruh_fasta/A0A1P8XQ85.json") as f:
    gainz = json.load(f)["value"]
%time _ = getMinhashesOfBands(gainz)
%time _ = df_similarity.collect()
df_similarity.printSchema()
df_similarity.show(100, False)