In [None]:
import pyspark.sql.functions as PySQL
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from random import randint
import numpy as np

In [None]:
###############
#   GLOBALS   #
###############
SHINGLE_SIZE = 5   # characters per shingle
BAND_SIZE    = 20  # number of LSH bands (signature length is BAND_SIZE * ROW_SIZE)
ROW_SIZE     = 5   # minhash values per band
RANDOM_SEED  = 42  # seeds the hash coefficients so every run is reproducible

GCLOUD = False

# Spark
SPARK = SparkSession \
    .builder \
    .master("spark://master:7077" if GCLOUD else "local[*]") \
    .config("spark.executor.memory", "2g") \
    .config("spark.driver.memory", "2g") \
    .appName("PDD-Big-Task-1") \
    .getOrCreate()

# Shingles
SHINGLE_BASE = ord("Z") - ord("A") + 1

# Minhash
HASH_MOD = 1_000_000_007
PERMUTATION_COUNT = BAND_SIZE * ROW_SIZE
# Former coefficient bound; no longer used in this cell, kept in case other cells reference it.
RAND_MAX = (2 ** 32) - 1
# One (a, b) row per permutation for the hash (a * x + b) % HASH_MOD.
# a is drawn from [1, HASH_MOD - 1] and b from [0, HASH_MOD - 1]: drawing from a
# wider range lets a be a multiple of HASH_MOD (a constant, useless hash) and
# makes the residues non-uniform. The seeded generator makes results repeatable.
_permutation_rng = np.random.default_rng(RANDOM_SEED)
PERMUTATION_ARR_BROADCAST = SPARK.sparkContext.broadcast(
    np.column_stack((
        _permutation_rng.integers(1, HASH_MOD, size=PERMUTATION_COUNT),  # a coefficients
        _permutation_rng.integers(0, HASH_MOD, size=PERMUTATION_COUNT),  # b coefficients
    ))
)

# Group definition file
GROUP_DEFINITION_PATH   = "hdfs://master:9000/data/group_definition.json" if GCLOUD else "data/group_definition.json"
# Expected shape of the unpivoted group definition.
# NOTE(review): not passed to any reader in the visible cells.
GROUP_DEFINITION_SCHEMA = StructType([
    StructField("group", StringType(), False),
    StructField("protein_list", ArrayType(StringType(), False), False)
])

# Fasta directory
FASTA_PATH   = "hdfs://master:9000/data/fasta" if GCLOUD else "data/fasta"
FASTA_SCHEMA = StructType([
    StructField("name", StringType(), False),
    StructField("value", StringType(), False)
])

In [None]:
########################
#   HELPER FUNCTIONS   #
########################
def loadDataFrameGroupDefinition():
    """Load the group definition JSON as one row per group.

    The file is read with an inferred schema, so each top-level JSON key becomes
    its own column. All of those columns are then unpivoted: the former column
    name lands in ``group`` and the column's value in ``protein_list``
    (presumably an array of protein names — TODO confirm against the data file).

    Returns:
        DataFrame with columns ``group`` and ``protein_list``.
    """
    df_wide = SPARK.read.json(GROUP_DEFINITION_PATH)
    return df_wide.melt(
        ids=[],
        values=df_wide.columns,
        variableColumnName="group",
        valueColumnName="protein_list",
    )

def loadDataFrameFasta():
    """Read the FASTA JSON records found at FASTA_PATH.

    The explicit FASTA_SCHEMA is applied instead of letting Spark infer one.

    Returns:
        DataFrame with the ``name`` and ``value`` columns declared in FASTA_SCHEMA.
    """
    fasta_reader = SPARK.read.schema(FASTA_SCHEMA)
    return fasta_reader.json(FASTA_PATH)

def shingle_int(shingle):
    """Encode a shingle as an integer in base SHINGLE_BASE.

    Each character contributes the digit ``ord(ch) - ord("A")`` (so "A" -> 0,
    "B" -> 1, ...), and the first character is the most significant digit.
    Evaluated with Horner's scheme: acc = acc * base + digit.

    Args:
        shingle: the substring to encode.

    Returns:
        The integer code; 0 for an empty shingle.
    """
    code = 0
    for letter in shingle:
        code = code * SHINGLE_BASE + (ord(letter) - ord("A"))
    return code

In [None]:
###################
#   MINHASH UDF   #
###################
def getMinhashesOfBands(value):
    """Compute the LSH band hashes of one sequence.

    The sequence is cut into overlapping shingles of SHINGLE_SIZE characters.
    Each shingle is encoded with ``shingle_int`` and hashed by every (a, b) row
    of PERMUTATION_ARR_BROADCAST as ``(a * x + b) % HASH_MOD``. The minimum per
    row forms the MinHash signature. The signature is split into consecutive
    bands of ROW_SIZE values, and each band is collapsed with Python's ``hash``.

    Args:
        value: the sequence string; may be None.

    Returns:
        A list of ``(band_index, band_hash)`` tuples, one per band. The list is
        empty when ``value`` is None or shorter than SHINGLE_SIZE: there are no
        shingles, and ``np.min`` over an empty array would raise.
    """
    if value is None or len(value) < SHINGLE_SIZE:
        return []

    # Reduce every shingle code modulo HASH_MOD first. The hash is unchanged,
    # since (a*x + b) mod p == (a*(x mod p) + b) mod p. The reduction keeps the
    # int64 products below overflow even if SHINGLE_SIZE is raised.
    shingle_int_arr = np.array(
        [
            shingle_int(value[i : i + SHINGLE_SIZE]) % HASH_MOD
            for i in range(len(value) - SHINGLE_SIZE + 1)
        ],
        dtype=np.int64,
    )

    signature_arr = np.array(
        [
            np.min((a * shingle_int_arr + b) % HASH_MOD)
            for a, b in PERMUTATION_ARR_BROADCAST.value
        ],
        dtype=np.int64,
    )

    # Return a concrete list (not a lazy enumerate). .tolist() yields plain
    # Python ints, whose hash equals that of the numpy scalars.
    return [
        (band_index, hash(tuple(signature_arr[start : start + ROW_SIZE].tolist())))
        for band_index, start in enumerate(range(0, len(signature_arr), ROW_SIZE))
    ]

# Element type of the UDF result: one (band index, band hash) pair per band.
# minhash_value must be LongType. The band hash comes from Python's built-in
# hash(), which is 64 bits wide on 64-bit CPython (sys.hash_info.width), so a
# 32-bit IntegerType cannot hold it faithfully.
minhash_tuple_type = StructType([
    StructField("minhash_id", IntegerType(), False),
    StructField("minhash_value", LongType(), False)
])

udf_get_minhashes_of_bands = PySQL.udf(getMinhashesOfBands, ArrayType(minhash_tuple_type, False))

In [None]:
###################
#   DATA FRAMES   #
###################

df_group_definition = loadDataFrameGroupDefinition()
df_fasta = loadDataFrameFasta()

# Per-group statistics: group size n and the number of unordered protein pairs
# inside the group, n * (n - 1) / 2.
df_group_statistics = (
    df_group_definition
    .select(
        "group",
        PySQL.size(PySQL.col("protein_list")).cast(LongType()).alias("group_count"),
    )
    .withColumn(
        "group_pairs",
        (PySQL.col("group_count") * (PySQL.col("group_count") - 1) / 2).cast(LongType()),
    )
)

# Total number of protein entries over all groups, collected to the driver.
protein_count_total = (
    df_group_statistics
    .agg(PySQL.sum("group_count").alias("protein_count_sum"))
    .collect()[0]["protein_count_sum"]
)

# mixed_pairs: pairs made of one protein of this group and one protein from
# any other group, n * (total - n).
df_group_statistics = df_group_statistics.withColumn(
    "mixed_pairs",
    PySQL.col("group_count") * (protein_count_total - PySQL.col("group_count")),
)

# One row per (group, protein) membership.
df_proteins = df_group_definition.select(
    "group",
    PySQL.explode("protein_list").alias("protein"),
)

# LSH DF: one row per (group, protein name, band hash). The left join keeps
# FASTA records that have no group; their group is null.
df_lsh = (
    df_fasta
    .select(
        "name",
        udf_get_minhashes_of_bands("value").alias("minhash_band_signature_list"),
    )
    .join(df_proteins, df_fasta.name == df_proteins.protein, "left")
    .select(
        "group",
        "name",
        PySQL.explode("minhash_band_signature_list").alias("minhash"),
    )
)

# Similarity DF: candidate pairs sharing at least one band hash. name_1 < name_2
# drops self-pairs and mirrored duplicates; dropDuplicates collapses pairs that
# match on several bands.
df_similarity = (
    df_lsh.alias("df_1")
    .join(
        df_lsh.alias("df_2"),
        on=(PySQL.col("df_1.minhash") == PySQL.col("df_2.minhash"))
        & (PySQL.col("df_1.name") < PySQL.col("df_2.name")),
        how="inner",
    )
    .select(
        PySQL.col("df_1.group").alias("group_1"),
        PySQL.col("df_2.group").alias("group_2"),
        PySQL.col("df_1.name").alias("name_1"),
        PySQL.col("df_2.name").alias("name_2"),
        PySQL.col("df_1.minhash").alias("minhash"),
    )
    .dropDuplicates(["group_1", "group_2", "name_1", "name_2"])
)

In [None]:
##################
#   STATISTICS   #
##################
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is 0 or None."""
    return numerator / denominator if denominator else float("nan")

# Count true and false positives in ONE pass over df_similarity, so the UDF and
# self-join behind it are evaluated once instead of once per count.
# count(when(cond, True)) counts rows where cond is true. Rows with a null group
# compare as null and land in neither column, matching filter(==)/filter(!=).
similarity_counts = df_similarity \
    .agg(
        PySQL.count(PySQL.when(PySQL.col("group_1") == PySQL.col("group_2"), True)).alias("true_positive"),
        PySQL.count(PySQL.when(PySQL.col("group_1") != PySQL.col("group_2"), True)).alias("false_positive"),
    ) \
    .first()

true_positive_total  = similarity_counts["true_positive"]
false_positive_total = similarity_counts["false_positive"]

group_pair_sums = df_group_statistics \
    .agg(
        PySQL.sum("group_pairs").alias("group_pairs_sum"),
        PySQL.sum("mixed_pairs").alias("mixed_pairs_sum"),
    ) \
    .first()

# sum() over an empty frame is null (None); treat that as zero pairs.
single_group_pairs_total = group_pair_sums["group_pairs_sum"] or 0
# mixed_pairs = n_i * (total - n_i) counts every cross-group pair once from
# each of its two groups, hence the halving.
mixed_group_pairs_total = (group_pair_sums["mixed_pairs_sum"] or 0) // 2

print(f"No. of true positive pairs:  {true_positive_total}")
print(f"No. of false positive pairs: {false_positive_total}")
print(f"No. of single group pairs:   {single_group_pairs_total}")
print(f"No. of mixed group pairs:    {mixed_group_pairs_total}")
print()

# Guarded divisions: an empty denominator yields NaN instead of raising.
true_positive_rate  = _safe_ratio(true_positive_total, single_group_pairs_total)
false_positive_rate = _safe_ratio(false_positive_total, mixed_group_pairs_total)
precision           = _safe_ratio(true_positive_total, true_positive_total + false_positive_total)

print(f"True positive rate:  {true_positive_rate}")
print(f"False positive rate: {false_positive_rate}")
print(f"Precision:           {precision}")