In [1]:
import pyspark.sql.functions as PySQL
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from random import randint
import numpy as np
import json

In [2]:
###############
#   GLOBALS   #
###############
SHINGLE_SIZE = 5
BAND_SIZE    = 20
ROW_SIZE     = 5

TEST = False
PERF = False

# Spark
SPARK = SparkSession \
    .builder \
    .master("local[*]") \
    .appName("mySession") \
    .getOrCreate()

# Shingles
SHINGLE_BASE = ord("Z") - ord("A") + 1

# Minhash
HASH_MOD = 1_000_000_007
PERMUTATION_COUNT = BAND_SIZE * ROW_SIZE
RAND_MAX = (2 ** 32) - 1
PERMUTATION_ARR_BROADCAST = SPARK.sparkContext.broadcast(
    np.array([
        (randint(1, RAND_MAX), randint(0, RAND_MAX))
        for _ in range(PERMUTATION_COUNT)
    ])
)

# Group definition file
GROUP_DEFINITION_PATH   = "data/test.json" if TEST else "data/bruh.json" if PERF else "data/group_definition.json"
GROUP_DEFINITION_SCHEMA = StructType([
    StructField("group", StringType(), False),
    StructField("protein_list", ArrayType(StringType(), False), False)
])

# Fasta directory
FASTA_PATH   = "data/test_fasta" if TEST else "data/bruh_fasta" if PERF else "data/fasta"
FASTA_SCHEMA = StructType([
    StructField("name", StringType(), False),
    StructField("value", StringType(), False)
])

your 131072x1 screen size is bogus. expect trouble
24/05/15 19:14:59 WARN Utils: Your hostname, DELL-laptop-14-5401 resolves to a loopback address: 127.0.1.1; using 172.20.97.216 instead (on interface eth0)
24/05/15 19:14:59 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/15 19:15:01 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
########################
#   HELPER FUNCTIONS   #
########################
def printSparkDetails(spark):
    print("Details of SparkContext:")
    print(f"App Name : {spark.sparkContext.appName}")
    print(f"Master : {spark.sparkContext.master}")

def loadDataFrameGroupDefinition():
    with open(GROUP_DEFINITION_PATH) as group_definitions_file:
        return SPARK.createDataFrame(
            json.load(group_definitions_file).items(),
            GROUP_DEFINITION_SCHEMA
        )

def loadDataFrameFasta():
    return SPARK.read.schema(FASTA_SCHEMA).json(FASTA_PATH)

def shingle_int(shingle):
    return sum(
        (ord(aminoacid) - ord("A")) * (SHINGLE_BASE ** exp)
        for exp, aminoacid in enumerate(shingle[::-1])
    )

In [4]:
################
#   WORKLOAD   #
################
def getMinhashesOfBands(value):
    shingle_int_arr = np.array([
        shingle_int(value[i : i + SHINGLE_SIZE])
        for i in range(len(value) - SHINGLE_SIZE + 1)
    ])

    signature_arr = np.array([
        np.min((a * shingle_int_arr + b) % HASH_MOD)
        for a, b in PERMUTATION_ARR_BROADCAST.value
    ])

    signature_batch_hash_arr = np.array([
        hash(tuple(signature_arr[i : i + ROW_SIZE]))
        for i in range(0, PERMUTATION_COUNT, ROW_SIZE)
    ])

    return enumerate(signature_batch_hash_arr.tolist())

minhash_tuple_type = StructType([
    StructField("minhash_id", IntegerType(), False),
    StructField("minhash_value", IntegerType(), False)
])

udf_get_minhashes_of_bands = PySQL.udf(getMinhashesOfBands, ArrayType(minhash_tuple_type,False))

In [5]:
###########
#   RUN   #
###########
printSparkDetails(SPARK)

df_group_definition = loadDataFrameGroupDefinition() \
    .select("group", PySQL.explode("protein_list").alias("protein"))

df_group_definition.printSchema()
df_group_definition.show()

df_fasta = loadDataFrameFasta()

df_fasta.printSchema()
df_fasta.show()

df_lsh = df_fasta \
    .withColumn("minhash_band_signature_list", udf_get_minhashes_of_bands("value")) \
    .select("name", "minhash_band_signature_list")

df_lsh.printSchema()
df_lsh.show(50)

df_lsh = df_lsh \
    .join(df_group_definition, df_fasta.name == df_group_definition.protein, "left") \
    .select("group", "name", "minhash_band_signature_list")

df_lsh.printSchema()
df_lsh.show(50)

df_lsh = df_lsh \
    .select("group", "name", PySQL.explode("minhash_band_signature_list").alias("minhash"))

# df_lsh = df_fasta \
    # .withColumn("minhash_band_signature_list", udf_get_minhashes_of_bands("value")) \
    # .select("name", "minhash_band_signature_list") \
    # .join(df_group_definition, df_fasta.name == df_group_definition.protein, "left") \
    # .select("group", "name", PySQL.explode("minhash_band_signature_list").alias("minhash"))

df_lsh.printSchema()
df_lsh.show(50)

Details of SparkContext:
App Name : mySession
Master : local[*]
root
 |-- group: string (nullable = false)
 |-- protein: string (nullable = false)



                                                                                

+---------------+----------+
|          group|   protein|
+---------------+----------+
|UniRef50_Q8WZ42|    A2ASS6|
|UniRef50_Q8WZ42|  Q8WZ42-8|
|UniRef50_Q8WZ42|A0A2J8PRG4|
|UniRef50_Q8WZ42|A0A2J8PRH0|
|UniRef50_Q8WZ42|A0A2J8VRI6|
|UniRef50_Q8WZ42|A0A2J8VRF7|
|UniRef50_Q8WZ42|A0A8I5U7Y9|
|UniRef50_Q8WZ42|  Q8WZ42-2|
|UniRef50_Q8WZ42|A0A0C4DG59|
|UniRef50_Q8WZ42|  Q8WZ42-7|
|UniRef50_Q8WZ42|A0A2J8PRG6|
|UniRef50_Q8WZ42|A0A2J8VRH1|
|UniRef50_Q8WZ42|    C0JYZ2|
|UniRef50_Q8WZ42| Q8WZ42-11|
|UniRef50_Q8WZ42|    H2P803|
|UniRef50_Q8WZ42|A0A6P8QXT8|
|UniRef50_Q8WZ42|A0A6P8RJ11|
|UniRef50_Q8WZ42|A0A8B7X843|
|UniRef50_Q8WZ42|A0A091R2T7|
|UniRef50_Q8WZ42|A0A7L3SFD7|
+---------------+----------+
only showing top 20 rows

root
 |-- name: string (nullable = true)
 |-- value: string (nullable = true)

+-------------+--------------------+
|         name|               value|
+-------------+--------------------+
|   A0AA41SNZ1|GKWVQLQLAESQPNLLE...|
|   A0A6B0RPA5|MSSQESPAVEFSTTTVS...|
|UPI00295AB97A

                                                                                

+-------------+---------------------------+
|         name|minhash_band_signature_list|
+-------------+---------------------------+
|   A0AA41SNZ1|       [{0, 681660500}, ...|
|   A0A6B0RPA5|       [{0, -988763031},...|
|UPI00295AB97A|       [{0, 575693548}, ...|
|UPI0011814A49|       [{0, -1865848111}...|
|   A0A6P7YNV3|       [{0, 1495276144},...|
|   A0A6P8RG40|       [{0, -642710800},...|
|UPI001C675CD7|       [{0, 684897736}, ...|
|UPI000D72069A|       [{0, -1135013108}...|
|UPI0023A8EE64|       [{0, 684897736}, ...|
|UPI00202717A7|       [{0, -968338580},...|
|   A0A6P8QZN4|       [{0, -642710800},...|
|UPI002AC85FA8|       [{0, -1315202270}...|
|UPI002AC82E1F|       [{0, -1315202270}...|
|UPI002AC7EF71|       [{0, -1315202270}...|
|UPI002AC886BD|       [{0, -1315202270}...|
|UPI002AC7FA2A|       [{0, -1315202270}...|
|UPI002AC846FB|       [{0, -1315202270}...|
|UPI002AC86ACB|       [{0, -1315202270}...|
|UPI002AC88D55|       [{0, -1315202270}...|
|UPI002AC852B2|       [{0, -1315

                                                                                

+-------------------+-------------+---------------------------+
|              group|         name|minhash_band_signature_list|
+-------------------+-------------+---------------------------+
|    UniRef50_Q8WZ42|UPI002AC7EE33|       [{0, -1315202270}...|
|    UniRef50_Q8WZ42|   A0A6B0RPA5|       [{0, -988763031},...|
|    UniRef50_Q8WZ42|   A0A6P8RJ11|       [{0, -642710800},...|
|    UniRef50_Q8WZ42|UPI000D72069A|       [{0, -1135013108}...|
|    UniRef50_Q8WZ42|UPI002AC8541B|       [{0, -1315202270}...|
|    UniRef50_Q8WZ42|UPI00202717A7|       [{0, -968338580},...|
|    UniRef50_Q8WZ42|UPI002AC846FB|       [{0, -1315202270}...|
|    UniRef50_Q8WZ42|UPI002AC84FB1|       [{0, -1315202270}...|
|    UniRef50_Q8WZ42|UPI002AC80B36|       [{0, -1315202270}...|
|    UniRef50_Q8WZ42|UPI002AC852B2|       [{0, -1315202270}...|
|    UniRef50_Q8WZ42|   A0A6P7YNV3|       [{0, 1495276144},...|
|    UniRef50_Q8WZ42|UPI002AC7FA2A|       [{0, -1315202270}...|
|    UniRef50_Q8WZ42|UPI002AC80DEC|     

                                                                                

+---------------+-------------+-----------------+
|          group|         name|          minhash|
+---------------+-------------+-----------------+
|UniRef50_Q8WZ42|UPI002AC7EE33| {0, -1315202270}|
|UniRef50_Q8WZ42|UPI002AC7EE33|   {1, 459268878}|
|UniRef50_Q8WZ42|UPI002AC7EE33|  {2, 1730365586}|
|UniRef50_Q8WZ42|UPI002AC7EE33|  {3, -455073897}|
|UniRef50_Q8WZ42|UPI002AC7EE33| {4, -1389168993}|
|UniRef50_Q8WZ42|UPI002AC7EE33|  {5, 2053515549}|
|UniRef50_Q8WZ42|UPI002AC7EE33|  {6, 1407119065}|
|UniRef50_Q8WZ42|UPI002AC7EE33| {7, -1534999881}|
|UniRef50_Q8WZ42|UPI002AC7EE33|  {8, 1925077339}|
|UniRef50_Q8WZ42|UPI002AC7EE33|  {9, -131755876}|
|UniRef50_Q8WZ42|UPI002AC7EE33|   {10, 78398550}|
|UniRef50_Q8WZ42|UPI002AC7EE33|{11, -1588537335}|
|UniRef50_Q8WZ42|UPI002AC7EE33|{12, -1873022898}|
|UniRef50_Q8WZ42|UPI002AC7EE33|  {13, 901077080}|
|UniRef50_Q8WZ42|UPI002AC7EE33|  {14, 607353332}|
|UniRef50_Q8WZ42|UPI002AC7EE33| {15, -527860956}|
|UniRef50_Q8WZ42|UPI002AC7EE33| {16, -778278285}|


In [6]:
########################
#   PERFORMANCE TEST   #
########################
print("cooking time!")
gainz = ""
with open("data/bruh_fasta/A0A1P8XQ85.json") as f:
    gainz = json.load(f)["value"]
%timeit _ = getMinhashesOfBands(gainz)
# %time   _ = df_fasta.collect()

cooking time!
42.3 ms ± 3.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
