In [119]:
import pyspark.sql.functions as PySQL
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from random import randint
import numpy as np
import json

In [120]:
###############
#   GLOBALS   #
###############
# LSH layout: a signature holds BAND_SIZE bands of ROW_SIZE minhash values each.
SHINGLE_SIZE = 5
BAND_SIZE    = 20
ROW_SIZE     = 5

# Dataset switches; they select which input paths are used further down.
TEST = False
PERF = False

# Shingles
# Radix used to read a shingle as an integer: one digit per letter A..Z.
SHINGLE_BASE = ord("Z") - ord("A") + 1

# Minhash
# Every hash function has the form h(x) = (a * x + b) % HASH_MOD.
HASH_MOD = 1_000_000_007
PERMUTATION_COUNT = BAND_SIZE * ROW_SIZE
# No longer bounds the coefficients below; kept in case other code reads it.
RAND_MAX = (2 ** 32) - 1
# Fixed seed so that the hash functions, and therefore every signature, are
# identical across kernel restarts.
PERMUTATION_SEED = 0
_permutation_rng = np.random.default_rng(PERMUTATION_SEED)
# Column 0 holds a, column 1 holds b. Drawing a from [1, HASH_MOD) rules out
# multiples of HASH_MOD, for which h(x) would collapse to the constant b.
PERMUTATION_ARR = np.column_stack((
    _permutation_rng.integers(1, HASH_MOD, size=PERMUTATION_COUNT),
    _permutation_rng.integers(0, HASH_MOD, size=PERMUTATION_COUNT),
))

# Spark
# Shared session, created (or reused) against the local[*] master.
SPARK = (
    SparkSession.builder
    .master("local[*]")
    .appName("mySession")
    .getOrCreate()
)

# Input locations: TEST wins over PERF; with neither set the full data set is used.
if TEST:
    GROUP_DEFINITION_PATH = "data/test.json"
    FASTA_PATH = "data/test_fasta"
elif PERF:
    GROUP_DEFINITION_PATH = "data/bruh.json"
    FASTA_PATH = "data/bruh_fasta"
else:
    GROUP_DEFINITION_PATH = "data/group_definition.json"
    FASTA_PATH = "data/fasta"

# Group definition file: one row per group with its list of proteins.
GROUP_DEFINITION_SCHEMA = StructType([
    StructField("group", StringType(), nullable=False),
    StructField(
        "protein_list",
        ArrayType(StringType(), containsNull=False),
        nullable=False,
    ),
])

# Fasta directory: one record per sequence with the fields name and value.
FASTA_SCHEMA = StructType([
    StructField("name", StringType(), nullable=False),
    StructField("value", StringType(), nullable=False),
])

In [121]:
########################
#   HELPER FUNCTIONS   #
########################
def printSparkDetails(spark):
    """Print the application name and master URL of a Spark session.

    Args:
        spark: Object exposing sparkContext.appName and sparkContext.master.
    """
    print("Details of SparkContext:")
    context = spark.sparkContext
    print(f"App Name : {context.appName}")
    print(f"Master : {context.master}")

def loadDataFrameGroupDefinition():
    """Load the group definition JSON file into a Spark DataFrame.

    The file at GROUP_DEFINITION_PATH is parsed as a JSON object and each
    top-level (key, value) pair becomes one row of a DataFrame created with
    GROUP_DEFINITION_SCHEMA.

    Returns:
        The DataFrame built by SPARK.createDataFrame.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform's default
    # text encoding.
    with open(GROUP_DEFINITION_PATH, encoding="utf-8") as group_definitions_file:
        group_definitions = json.load(group_definitions_file)

    # The file is fully parsed and closed before Spark is involved.
    return SPARK.createDataFrame(
        list(group_definitions.items()),
        GROUP_DEFINITION_SCHEMA
    )

def loadDataFrameFasta():
    """Read the JSON records under FASTA_PATH into a DataFrame using FASTA_SCHEMA.

    Returns:
        The DataFrame produced by Spark's JSON reader.
    """
    # NOTE(review): the reader runs with its default options; confirm that the
    # files under FASTA_PATH have the layout those defaults expect.
    fasta_reader = SPARK.read.schema(FASTA_SCHEMA)
    return fasta_reader.json(FASTA_PATH)

def shingle_int(shingle):
    """Encode a shingle as an integer in base SHINGLE_BASE.

    Each character contributes the digit ord(character) - ord("A"); the first
    character is the most significant digit.

    Args:
        shingle: String of characters to encode.

    Returns:
        The encoded integer; 0 for an empty shingle.
    """
    first_letter = ord("A")
    encoded = 0
    # Horner evaluation: shift the digits seen so far, then append the next one.
    for aminoacid in shingle:
        encoded = encoded * SHINGLE_BASE + (ord(aminoacid) - first_letter)
    return encoded

In [122]:
################
#   WORKLOAD   #
################
def getMinhashesOfBands(value):
    """Compute one bucket hash per band of a sequence's minhash signature.

    The sequence is cut into overlapping shingles of SHINGLE_SIZE characters
    and each shingle is encoded with shingle_int. Every (a, b) pair in
    PERMUTATION_ARR contributes the minimum of (a * x + b) % HASH_MOD over all
    shingle codes x. The signature is then split into consecutive bands of
    ROW_SIZE values and each band is reduced to a single hash.

    Args:
        value: Sequence string; None is accepted.

    Returns:
        A list with one int per band, each the built-in hash() of the band's
        tuple (on 64-bit Python these values exceed the 32-bit integer range).
        An empty list when value is None or shorter than SHINGLE_SIZE, because
        such a sequence has no shingles.
    """
    shingle_count = 0 if value is None else len(value) - SHINGLE_SIZE + 1
    if shingle_count <= 0:
        # np.min raises ValueError on an empty array; there is nothing to hash.
        return []

    # Reduce each code modulo HASH_MOD with exact Python ints before it enters
    # the int64 array. (a * x + b) % p equals (a * (x % p) + b) % p, so results
    # are unchanged, but a * x can no longer overflow int64 when SHINGLE_SIZE
    # is raised.
    shingle_residues = np.array(
        [
            shingle_int(value[start : start + SHINGLE_SIZE]) % HASH_MOD
            for start in range(shingle_count)
        ],
        dtype=np.int64,
    )

    # One minhash per hash function, converted to plain Python ints.
    signature = [
        int(np.min((a * shingle_residues + b) % HASH_MOD))
        for a, b in PERMUTATION_ARR
    ]

    # Sequences that agree on every row of a band get the same band hash.
    return [
        hash(tuple(signature[start : start + ROW_SIZE]))
        for start in range(0, PERMUTATION_COUNT, ROW_SIZE)
    ]

In [None]:
###########
#   RUN   #
###########
printSparkDetails(SPARK)

df_group_definition = loadDataFrameGroupDefinition()
df_group_definition.printSchema()
df_group_definition.show()

udf_get_minhashes_of_bands = PySQL.udf(getMinhashesOfBands, ArrayType(IntegerType(), False))

df_fasta = loadDataFrameFasta()
df_fasta = df_fasta.withColumn("minhash_band_signature_list", udf_get_minhashes_of_bands("value"))
df_fasta.printSchema()
df_fasta.show()

%time _ = df_fasta.collect()

#################
#   SPEEDTEST   #
#################
# print("smokin' time!")
# gainz = ""
# with open("data/bruh_fasta/A0A1P8XQ85.json") as f:
#     gainz = json.load(f)["value"]
# %time _ = getMinhashesOfBands(gainz)
