In [1]:
import pyspark.sql.functions as f
from pyspark.sql import SparkSession
from pyspark.sql.types import StructField, StructType, IntegerType, StringType

In [2]:
spark = SparkSession.builder.getOrCreate()

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/04/26 10:45:37 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
evidence = (
    spark.read.parquet("../../evidence")
    .filter(f.col("variantId").isNotNull())
    .withColumn('chr', f.split(f.col('variantId'), '_').getItem(0))
    .withColumn('genomicLocation', f.split(f.col('variantId'), '_').getItem(1))
    .groupBy('chr', 'genomicLocation')
    .agg(
        f.collect_set(f.struct(f.col('variantId'), f.col('diseaseId'), f.col('diseaseFromSource'))).alias('evidenceInfo')
    )
)

In [4]:
evidence.count()

                                                                                

756504

In [5]:
evidence.show(truncate=False)



+---+---------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|chr|genomicLocation|evidenceInfo                                                                                                                                                                                                                       |
+---+---------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|1  |100188948      |[{1_100188948_C_G, Orphanet_511, Maple syrup urine disease}]                                                                                                                                                                       |


                                                                                

In [6]:
gen_location = (
    spark.read.json("residue_gen_pos_output/residue_genomic_position.json")
)

gen_location = (
    gen_location
    .withColumn("pdbCompound", gen_location["resInfos.compound"])
    .withColumn("resNb", gen_location["resInfos.res_nb"])
    .withColumn("chain", gen_location["resInfos.chain"])
    .withColumn("resType", gen_location["resInfos.res_type"])
    .withColumn("interType", gen_location["resInfos.inter_type"])
    .withColumn("chr", gen_location["resInfos.chromosome"])
    .withColumn("genLocation_1", gen_location["resInfos.genLocation.res_pos_1"])
    .withColumn("genLocation_2", gen_location["resInfos.genLocation.res_pos_2"])
    .withColumn("genLocation_3", gen_location["resInfos.genLocation.res_pos_3"])
    .drop("resInfos")
)

                                                                                

In [7]:
unpivot_expression = '''stack(3, 'genLocation_1', genLocation_1, 'genLocation_2', genLocation_2, 'genLocation_3', genLocation_3) as (genLocation_label, genLocation_val)'''

gen_location_unpivot = (
    gen_location
        .select('geneId', 'pdbStructId', 'pdbCompound', 'resNb', 'chain', 'resType', 'chr', f.expr(unpivot_expression))
)

In [8]:
gen_location.count()

                                                                                

262855

In [9]:
gen_location_unpivot.count()  # More lines but less operations

                                                                                

788565

In [8]:
gen_location_unpivot.show()

[Stage 11:>                                                         (0 + 1) / 1]

+---------------+-----------+-----------+-----+-----+-------+---+-----------------+---------------+
|         geneId|pdbStructId|pdbCompound|resNb|chain|resType|chr|genLocation_label|genLocation_val|
+---------------+-----------+-----------+-----+-----+-------+---+-----------------+---------------+
|ENSG00000001626|       1xmj|        ATP|  466|    A|    SER|  7|    genLocation_1|      117559467|
|ENSG00000001626|       1xmj|        ATP|  466|    A|    SER|  7|    genLocation_2|      117559468|
|ENSG00000001626|       1xmj|        ATP|  466|    A|    SER|  7|    genLocation_3|      117559469|
|ENSG00000001626|       1xmj|        ATP|  465|    A|    THR|  7|    genLocation_1|      117559464|
|ENSG00000001626|       1xmj|        ATP|  465|    A|    THR|  7|    genLocation_2|      117559465|
|ENSG00000001626|       1xmj|        ATP|  465|    A|    THR|  7|    genLocation_3|      117559466|
|ENSG00000001626|       1xmj|        ATP|  464|    A|    LYS|  7|    genLocation_1|      117548821|


                                                                                

In [9]:
res_with_disease = (
    gen_location_unpivot.join(
        evidence, 
        (gen_location_unpivot.chr == evidence.chr) &
        (gen_location_unpivot.genLocation_val == evidence.genomicLocation)
    )
)

In [12]:
res_with_disease.count()

                                                                                

12728

In [10]:
res_with_disease.show()

22/04/26 10:46:09 ERROR Executor: Exception in task 39.0 in stage 13.0 (TID 413)
java.lang.OutOfMemoryError: Java heap space
22/04/26 10:46:09 ERROR Executor: Exception in task 104.0 in stage 13.0 (TID 478)
java.lang.RuntimeException: Cannot reserve additional contiguous bytes in the vectorized reader (requested 16392 bytes). As a workaround, you can reduce the vectorized reader batch size, or disable the vectorized reader, or disable spark.sql.sources.bucketing.enabled if you read from bucket table. For Parquet file format, refer to spark.sql.parquet.columnarReaderBatchSize (default 4096) and spark.sql.parquet.enableVectorizedReader; for ORC file format, refer to spark.sql.orc.columnarReaderBatchSize (default 4096) and spark.sql.orc.enableVectorizedReader.
	at org.apache.spark.sql.execution.vectorized.WritableColumnVector.throwUnsupportedException(WritableColumnVector.java:113)
	at org.apache.spark.sql.execution.vectorized.WritableColumnVector.reserve(WritableColumnVector.java:93)
	at

ConnectionRefusedError: [Errno 111] Connection refused

In [38]:
res_with_disease_cleaned = (

        res_with_disease

        .drop("genLocation_1", "genLocation_2", "genLocation_3", "targetId", "chr", "genomicLocation")

        .groupby([f.col('geneId'),
                f.col('pdbStructId'),
                f.col("resNb"),
                f.col("resType"),
                f.col("interType"),
                f.col("diseaseId"),
                f.col("diseaseFromSource")
                ])

        .agg(
                f.collect_set(f.struct(f.col('variantId'))).alias("variantIds"),
                f.collect_set(f.col('pdbStructId')).alias("pdbStructIds"),
        )
)

In [37]:
res_with_disease_cleaned.show()

+---------------+-----------+-----+-------+-------------+-------------+--------------------+--------------------+------------+
|         geneId|pdbStructId|resNb|resType|    interType|    diseaseId|   diseaseFromSource|          variantIds|pdbStructIds|
+---------------+-----------+-----+-------+-------------+-------------+--------------------+--------------------+------------+
|ENSG00000001626|       1xmi|  401|    TRP|      pistack| Orphanet_586|     Cystic fibrosis|[{7_117542101_G_A...|    [{1xmi}]|
|ENSG00000001626|       1xmi|  462|    ALA|        hbond| Orphanet_586|     Cystic fibrosis| [{7_117548815_G_A}]|    [{1xmi}]|
|ENSG00000001626|       1xmi|  464|    LYS|        hbond| Orphanet_586|     Cystic fibrosis|[{7_117542108_G_C...|    [{1xmi}]|
|ENSG00000001626|       1xmi|  464|    LYS|        hbond| Orphanet_676|Hereditary pancre...| [{7_117548823_G_T}]|    [{1xmi}]|
|ENSG00000001626|       1xmi|  464|    LYS|   saltbridge| Orphanet_586|     Cystic fibrosis|[{7_117542108_G_C..

In [25]:
# MOLECULES
molecule_df = (
    spark.read
    .parquet("../../molecule/")
    .select(
        f.col('inchiKey').alias('inchikey'), 'name'
        #, 'linkedTargets', 'linkedDiseases'
    )
    .persist()
)

In [27]:
# INCHIKEY MOLECULES
inchikey_df = (
    spark.read
        .csv("../../inchikey/components_inchikeys.csv", sep=',', header=True, comment='#')
        .select(
            f.col('InChIKey').alias('inchikey'), 
            f.col('CCD_ID').alias('pdbCompound')
        )
    .persist()
)

In [28]:
# MOLECULE WITH COMPOUND ID
molecules_inchikey_join = (
    molecule_df
    .join(inchikey_df, on='inchikey')
    .drop("inchikey")
    .persist()
)

In [29]:
molecules_inchikey_join.show()

+--------------------+-----------+
|                name|pdbCompound|
+--------------------+-----------+
|(1-Phenylcyclopen...|        007|
|        CHEMBL381806|        008|
|      BENZYL ALCOHOL|        010|
|           DARUNAVIR|        017|
|               N6022|        022|
|        CHEMBL243940|        024|
|         CHEMBL55264|        028|
|         VEMURAFENIB|        032|
|       CHEMBL1213083|        039|
|             TAK-285|        03P|
|          PRINABEREL|        041|
|        CHEMBL478524|        047|
|         AMINOPTERIN|        04J|
|       CHEMBL1229525|        057|
|          LASMIDITAN|        05X|
|          GSK-256066|        066|
|        INFIGRATINIB|        07J|
|        CHEMBL305178|        084|
|    SULFAMETHOXAZOLE|        08D|
|          ALPRAZOLAM|        08H|
+--------------------+-----------+
only showing top 20 rows



In [43]:
# COMPOUND NAME
disease_ass_comp_name = (
    molecules_inchikey_join
    .join(res_with_disease, on='pdbCompound')
)

                                                                                

+-----------+------------+---------------+-----------+-----+-----+-------+-------------+---+-------------+-------------+-------------+------------+---------------+-------------+-----------+--------------------+---+---------------+
|pdbCompound|        name|         geneId|pdbStructId|resNb|chain|resType|    interType|chr|genLocation_1|genLocation_2|genLocation_3|datasourceId|       targetId|    variantId|  diseaseId|   diseaseFromSource|chr|genomicLocation|
+-----------+------------+---------------+-----------+-----+-----+-------+-------------+---+-------------+-------------+-------------+------------+---------------+-------------+-----------+--------------------+---+---------------+
|        GDP|CHEMBL384759|ENSG00000136238|       1g4u|   17|    R|    THR|        hbond|  7|      6387227|      6387228|      6387229|         eva|ENSG00000136238|7_6387229_G_A|EFO_0009156|Intellectual disa...|  7|        6387229|
|        GDP|CHEMBL384759|ENSG00000136238|       1g4u|   16|    R|    LYS|  

In [44]:
disease_ass_comp_name.show(2, True, True)

ERROR:root:KeyboardInterrupt while sending command.>               (8 + 3) / 11]
Traceback (most recent call last):
  File "/Users/marinegirardey/miniforge3/envs/plip_env/lib/python3.8/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "/Users/marinegirardey/miniforge3/envs/plip_env/lib/python3.8/site-packages/py4j/clientserver.py", line 475, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
  File "/Users/marinegirardey/miniforge3/envs/plip_env/lib/python3.8/socket.py", line 669, in readinto
    return self._sock.recv_into(b)
KeyboardInterrupt


KeyboardInterrupt: 