In [None]:
#### Locate the local Spark installation (Homebrew, Spark 3.3.0) before importing pyspark ####
# NOTE(review): the path is machine-specific; adjust it when running on another machine.
import findspark
findspark.init("/opt/homebrew/Cellar/apache-spark/3.3.0/libexec")

import pyspark
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql import Window

from psutil import virtual_memory
from pyspark import SparkFiles
from pyspark.conf import SparkConf
from pyspark.sql.functions import col
from pyspark.sql.types import StructType, StructField, StringType


def detect_spark_memory_limit(total_bytes=None, fraction=0.9):
    """Return the amount of memory, in whole GiB, that Spark should be allowed to use.

    Spark does not automatically use all available memory on a machine. When working
    on large datasets this may cause Java heap space errors even though there is
    plenty of RAM available. To avoid that, the total physical memory is detected and
    Spark is allowed to use (almost) all of it.

    Args:
        total_bytes: Physical memory in bytes. When ``None`` (the default) it is read
            from ``psutil.virtual_memory().total``.
        fraction: Share of the whole-GiB total handed to Spark (default 0.9).

    Returns:
        int: Memory limit in GiB, never lower than 1 so that the derived setting
        (e.g. ``"1g"``) is always a usable heap size. Without the floor, a machine
        with less than 2 GiB would yield ``0`` and therefore ``"0g"``.
    """
    if total_bytes is None:
        total_bytes = virtual_memory().total
    # Integer shift by 30 bits == floor division by 1024**3 (bytes -> whole GiB).
    mem_gib = total_bytes >> 30
    return max(1, int(mem_gib * fraction))


# Heap size (in GiB) derived from the physical memory of this machine.
spark_mem_limit = detect_spark_memory_limit()

# Build the configuration one setting at a time; `SparkConf.set` mutates the object.
spark_conf = SparkConf()
# Driver and executor both get the detected memory limit.
spark_conf.set("spark.driver.memory", f"{spark_mem_limit}g")
spark_conf.set("spark.executor.memory", f"{spark_mem_limit}g")
spark_conf.set("spark.driver.maxResultSize", "0")
spark_conf.set("spark.debug.maxToStringFields", "2000000000")
spark_conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "500000")
# Disabled setting kept for reference: spark.executor.heartbeatInterval = "3600s"
# Arrow is enabled because of:
# https://stackoverflow.com/questions/69973790/pyspark-spark-sparkexception-job-aborted-due-to-stage-failure-task-0-in-stage
spark_conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
spark_conf.set("spark.ui.showConsoleProgress", "false")


# Create (or reuse) the session: local master using all cores, with the driver
# bound to the loopback interface so everything runs on this machine.
_session_builder = SparkSession.builder.config(conf=spark_conf).master("local[*]")
_session_builder = _session_builder.config("spark.driver.bindAddress", "127.0.0.1")
_session_builder = _session_builder.config("spark.driver.host", "localhost")
spark = _session_builder.getOrCreate()

In [None]:
import pyspark.sql.functions as F
import pandas as pd

### Load datasets (interaction db, diseases, targets, molecule, indication, indirect associations).
### interaction / diseases / targets are read from the december2022 download folder;
### molecule / indication / associationByDatasourceIndirect from the Downloaded_20230110 folder.
# The two root folders are defined once so that relocating the data is a single edit.
# NOTE(review): these are machine-specific absolute paths.
OT_RELEASE_DIR = "/Users/juanr/Desktop/Target_Engine/data_download/december2022/"
MR_DOWNLOAD_DIR = "/Users/juanr/Desktop/MR_Maya/Downloaded_20230110/"

interactors = OT_RELEASE_DIR + "interaction/"
interact_db = spark.read.parquet(interactors)

disease_path = OT_RELEASE_DIR + "diseases/"
diseases = spark.read.parquet(disease_path)

target_path = OT_RELEASE_DIR + "targets/"
target = spark.read.parquet(target_path)

molecule_path = MR_DOWNLOAD_DIR + "molecule/"
molecule = spark.read.parquet(molecule_path)

indication_path = MR_DOWNLOAD_DIR + "indication/"
indication = spark.read.parquet(indication_path)

indirecAssoc_path = MR_DOWNLOAD_DIR + "associationByDatasourceIndirect/"
indirecAssoc = spark.read.parquet(indirecAssoc_path)

## Target id -> approved symbol lookup, used later to label interactors
symbol = target.select("id", "approvedSymbol")

## Disease id, name and therapeutic areas, renamed for the join with the query EFO terms
diseases_name = diseases.select(
    F.col("id").alias("efoDisease"),
    F.col("name").alias("diseaseName"),
    F.col("therapeuticAreas"),
)

##### 

#### Overall direct association score per target-disease pair
overallDirecAssocScore_path = "/Users/juanr/Desktop/MR_Maya/Downloaded_20230110/assocOverallDirectJanuary2023/associationByOverallDirect"
overallDirecAssocScore = spark.read.parquet(overallDirecAssocScore_path)
#### Keep target, disease and score under names that will not clash in later joins
oaDirectScore = overallDirecAssocScore.select(
    *[
        F.col(source_name).alias(new_name)
        for source_name, new_name in (
            ("targetId", "targetIdoaDirect"),
            ("diseaseId", "diseaseIdDirect"),
            ("score", "scoreDirect"),
        )
    ]
)

#### Direct associations broken down by datasource
direcAssoc_path ="/Users/juanr/Desktop/MR_Maya/Downloaded_20230110/associationByDatasourceDirect/"
direcAssoc = spark.read.parquet(direcAssoc_path)
#### One row per target-disease pair with the set of datasources supporting it
targetDirectAssoc = direcAssoc.groupBy(
    F.col("targetId").alias("targetIdDirect"),
    F.col("diseaseId").alias("diseaseIdDirect"),
).agg(F.collect_set("datasourceId").alias("datasourceIdDirect"))

##### Dataset format curation #####

#### Original dataset with all target-trait pairs per study (tab-separated, header row)
path = "/Users/juanr/Desktop/MR_Maya/MRdrug20230123_Analysis.tsv"
df2 = spark.read.csv(path, sep=r"\t", header=True)

## Columns holding ";"-separated values: converted from string to array
convert = [
    "adverse_effects",
    "adverse_effects_studies",
    "adverse_effects_bxy",
    "mech",
    "mech_studies",
    "mech_coloc_h4",
    "mech_bxy",
]
## Columns wrapped in literal double quotes so a multi-valued string stays one value
transform = ["adverse_effects_coloc_h4", "outcome_trait"]

# Two passes over the same column list: first split, then quote; all other
# columns pass through untouched and column order is preserved.
df2 = df2.select(
    *(
        F.split(name, ";").alias(name) if name in convert else name
        for name in df2.columns
    )
).select(
    *(
        F.concat(F.lit('"'), name, F.lit('"')).alias(name) if name in transform else name
        for name in df2.columns
    )
)

# Slim view used to build the query set; `curated_ensid` is exposed as `ensid`.
df = df2.select(
    "curated_ensid",
    "indice",
    "mergedOutcomeTraitEfo2",
    "protein_datasets",
    "outcome_datasets",
).withColumnRenamed("curated_ensid", "ensid")

##### Join therapeutic areas #####

### The comma-separated EFO terms are turned into an array and exploded to one
### term per row, then annotated with the therapeutic areas from diseases_name.
_efo_term_is_disease = F.col("efoDisease") == F.col("efo_ensid_individual_final")

prequeryset = (
    df.withColumn("efo_array", F.split(F.col("mergedOutcomeTraitEfo2"), ", "))
    .withColumn("efo_ensid_individual_final", F.explode_outer(F.col("efo_array")))
    .join(diseases_name, on=_efo_term_is_disease, how="left")
    .select(
        "ensid",
        "mergedOutcomeTraitEfo2",
        "indice",
        F.col("efo_ensid_individual_final").alias("efo_ensid"),
        # one row per therapeutic area (rows without any area are kept)
        F.explode_outer(F.col("therapeuticAreas")).alias("therapeuticAreas"),
        "protein_datasets",
        "outcome_datasets",
        "efo_array",
    )
)

## Query set: one row per target / row index / EFO term, with the set of
## therapeutic areas of that term, to be evaluated against the drug annotations.
_queryset_keys = ["ensid", "indice", "outcome_datasets", "protein_datasets", "efo_ensid"]
queryset = prequeryset.groupBy(*_queryset_keys).agg(
    F.collect_set("therapeuticAreas").alias("therapeuticAreas")
)

# Diseases with therapeutic areas, suffixed "B" for the join with target B (the interactor)
diseasesB = diseases.select(
    *[
        F.col(source_name).alias(new_name)
        for source_name, new_name in (
            ("id", "id_B"),
            ("therapeuticAreas", "therAreasB"),
            ("name", "diseaseNameB"),
        )
    ]
)

# Same lookup, suffixed "A" for the join with target A (the query target)
diseasesA = diseases.select(
    *[
        F.col(source_name).alias(new_name)
        for source_name, new_name in (
            ("id", "id_A"),
            ("therapeuticAreas", "therAreasA"),
            ("name", "diseaseNameA"),
        )
    ]
)

### Max clinical phase per indication: one row per (drug id, indication)
indicationsToJoin = indication.select(
    "id", F.explode_outer(F.col("indications")).alias("indications")
).select(
    "id",
    *[
        F.col(f"indications.{field_name}").alias(new_name)
        for field_name, new_name in (
            ("disease", "indicatedDisease"),
            ("efoName", "indicatedEfoName"),
            ("maxPhaseForIndication", "indicatedMaxPhaseIndication"),
        )
    ],
)
## Obtain the drugs linked to targets, together with the diseases (and their
## therapeutic areas) those drugs are linked to.
def _drugs_linked_to_targets(molecule_df, indications_df, diseases_df, side):
    """Build the drug / linked-target / linked-disease table for one side ("A" or "B").

    The two tables built below were previously two copy-pasted pipelines that only
    differed in the column suffix and in the disease lookup used; this helper
    produces exactly the same columns for either side.

    Args:
        molecule_df: DataFrame read from the molecule dataset; the columns used are
            ``id``, ``name``, ``maximumClinicalTrialPhase``, ``linkedTargets.rows``
            and ``linkedDiseases.rows``.
        indications_df: DataFrame with ``id``, ``indicatedDisease`` and
            ``indicatedMaxPhaseIndication`` (``indicationsToJoin``).
        diseases_df: Disease lookup with columns ``id_<side>`` and
            ``therAreas<side>`` (``diseasesA`` or ``diseasesB``).
        side: Suffix appended to every output column name, "A" or "B".

    Returns:
        DataFrame with one row per distinct combination of linked target, linked
        disease id, therapeutic areas, drug name, indicated disease, the drug's
        maximum clinical trial phase and the per-indication maximum phase, plus
        the ``count(chemblLinkedTarget<side>)`` column produced by the aggregation.
    """
    chembl_id = f"chemblIdTarget{side}"
    drug_name = f"drugNameTarget{side}"
    max_phase = f"maxClinTrialPhaseTarget{side}"
    linked_target = f"chemblLinkedTarget{side}"
    linked_disease = f"diseaseLinkedChemblTarget{side}"
    disease_id = f"id_{side}"

    return (
        # one row per (drug, linked target)
        molecule_df.select(
            F.col("id").alias(chembl_id),
            F.col("name").alias(drug_name),
            F.col("maximumClinicalTrialPhase").alias(max_phase),
            F.col("linkedDiseases"),
            F.explode_outer("linkedTargets.rows").alias(linked_target),
        )
        # one row per (drug, linked target, linked disease)
        .select(
            F.col(chembl_id),
            F.col(drug_name),
            F.col(max_phase),
            F.col(linked_target),
            F.explode_outer("linkedDiseases.rows").alias(linked_disease),
        )
        # attach the per-indication max phase for that drug and disease
        .join(
            indications_df,
            (indications_df.id == F.col(chembl_id))
            & (indications_df.indicatedDisease == F.col(linked_disease)),
            "left",
        )
        # attach the therapeutic areas of the linked disease
        .join(diseases_df, F.col(linked_disease) == diseases_df[disease_id], "left")
        .groupBy(
            linked_target,
            disease_id,
            f"therAreas{side}",
            drug_name,
            F.col("indicatedDisease").alias(f"indicatedDisease{side}"),
            max_phase,
            F.col("indicatedMaxPhaseIndication").alias(
                f"indicatedMaxPhaseIndication{side}"
            ),
        )
        .agg(F.count(linked_target))
    )


# Drugs for target B (the interaction partner)
tar_group = _drugs_linked_to_targets(molecule, indicationsToJoin, diseasesB, "B")
### Drugs for target A (the query target). In theory not required.
tar_group2 = _drugs_linked_to_targets(molecule, indicationsToJoin, diseasesA, "A")

### Get interactors using the IntAct rows of the interaction database.
# Keep interactions scoring above 0.42, attach them to every query row, and add
# the drugs linked to the partner (tar_group) and to the query target (tar_group2).
look_other = (
    interact_db.filter(F.col("sourceDatabase") == "intact")
    .select("sourceDatabase", "targetA", "targetB", "scoring")
    # Numeric threshold: the original compared against the string "0.42", which
    # only works through implicit casting.
    .filter(F.col("scoring") > 0.42)
    # right join: every queryset row is kept, with or without an interactor
    .join(queryset, queryset.ensid == F.col("targetA"), "right")
    .join(tar_group, F.col("targetB") == tar_group.chemblLinkedTargetB, "left")
    .join(tar_group2, F.col("targetA") == tar_group2.chemblLinkedTargetA, "left")  ### remove
    # one row per therapeutic area of the query disease / drug A disease / drug B disease
    .withColumn("therAreasDiseases", F.explode_outer(F.col("therapeuticAreas")))
    .withColumn("therAreasTargetA", F.explode_outer(F.col("therAreasA")))  ### remove
    .withColumn("therAreasTargetB", F.explode_outer(F.col("therAreasB")))
)

In [None]:
# Annotate every row of `look_other` with:
#   * whether the disease linked to the drug of target A / B is the query EFO term,
#   * whether the therapeutic area of that disease coincides with the query's,
#   * drug labels built with "_" as separator, filled only when the relevant
#     condition holds (NULL otherwise, so collect_set later ignores them),
#   * array intersections / differences of the therapeutic areas.
first_step = (
    # approved symbol of the interactor (target B)
    look_other.join(symbol, F.col("targetB") == symbol.id, "left")
    ### Assess whether the disease linked to the drug is the query EFO term
    .withColumn(
        "sameEfoensid_A",
        F.when(
            (F.col("id_A").isNotNull()) & (F.col("efo_ensid") == F.col("id_A")),
            F.lit("sameDisease"),
        )
        .when(
            (F.col("id_A").isNotNull()) & (F.col("efo_ensid") != F.col("id_A")),
            F.lit("difDisease"),
        )
        .otherwise(F.lit("noDisease")),
    )
    .withColumn(
        "sameEfoensid_B",
        F.when(
            (F.col("id_B").isNotNull()) & (F.col("efo_ensid") == F.col("id_B")),
            F.lit("sameDisease"),
        )
        .when(
            (F.col("id_B").isNotNull()) & (F.col("efo_ensid") != F.col("id_B")),
            F.lit("difDisease"),
        )
        .otherwise(F.lit("noDisease")),
    )
    ## Assess whether the therapeutic areas are the same
    .withColumn(
        "coincident_A_string",
        F.when(F.col("therAreasTargetA").isNull(), F.lit("noTherAreaTargetA"))
        .when(
            (F.col("therAreasTargetA").isNotNull())
            & (F.col("therAreasDiseases") == F.col("therAreasTargetA")),
            F.lit("coincident"),
        )
        .when(
            F.col("therAreasDiseases") != F.col("therAreasTargetA"),
            F.lit("nonCoincident"),
        )
        # reached when the comparison is NULL (query therapeutic area missing)
        .otherwise(F.lit("dif")),
    )
    .withColumn(
        "coincident_B_string",
        F.when(F.col("therAreasTargetB").isNull(), F.lit("noTherAreaTargetB"))
        .when(
            (F.col("therAreasTargetB").isNotNull())
            & (F.col("therAreasDiseases") == F.col("therAreasTargetB")),
            F.lit("coincident"),
        )
        .when(
            F.col("therAreasDiseases") != F.col("therAreasTargetB"),
            F.lit("nonCoincident"),
        )
        # reached when the comparison is NULL (query therapeutic area missing)
        .otherwise(F.lit("dif")),
    )
    ## Drug of A when its therapeutic area is the same as the query disease's
    .withColumn(
        "drugFromA_string",
        F.when(
            F.col("coincident_A_string") == "coincident",
            F.concat_ws(
                "_", F.col("drugNameTargetA"), F.col("indicatedMaxPhaseIndicationA")
            ),
        ).otherwise(F.lit(None)),
    )
    ## Drug of B (a partner different from the query target) when its
    ## therapeutic area is the same as the query disease's
    .withColumn(
        "drugFromB_string",
        F.when(
            (F.col("targetB").isNotNull())
            & (F.col("ensid") != F.col("targetB"))
            & (F.col("coincident_B_string") == "coincident"),
            F.concat_ws(
                "_",
                F.col("approvedSymbol"),
                F.col("drugNameTargetB"),
                F.col("indicatedMaxPhaseIndicationB"),
            ),
        ).otherwise(F.lit(None)),
    )
    ## Drug of A in a non-coincident therapeutic area
    .withColumn(
        "drugFromA_repurposing_string",
        F.when(
            (F.col("coincident_A_string") == "nonCoincident")
            & (F.col("drugNameTargetA").isNotNull()),
            F.concat_ws(
                "_", F.col("drugNameTargetA"), F.col("indicatedMaxPhaseIndicationA")
            ),
        ).otherwise(F.lit(None)),
    )
    ## Drug of B in a non-coincident therapeutic area
    .withColumn(
        "drugFromBNovel_string",
        F.when(
            (F.col("targetB").isNotNull())
            & (F.col("ensid") != F.col("targetB"))
            & (F.col("coincident_B_string") == "nonCoincident")
            & (F.col("therAreasB").isNotNull()),
            F.concat_ws(
                "_",
                F.col("approvedSymbol"),
                F.col("drugNameTargetB"),
                F.col("indicatedMaxPhaseIndicationB"),
            ),
        ).otherwise(F.lit(None)),
    )
    ## Drug of A when its linked disease is exactly the query EFO term
    .withColumn(
        "drugFromA_efo",
        F.when(
            F.col("sameEfoensid_A") == "sameDisease",
            F.concat_ws(
                "_",
                F.col("drugNameTargetA"),
                F.col("indicatedMaxPhaseIndicationA"),
                F.col("id_A"),
            ),
        ).otherwise(F.lit(None)),
    )
    ## Drug of B when its linked disease is exactly the query EFO term
    .withColumn(
        "drugFromB_efo",
        F.when(
            (F.col("targetB").isNotNull())
            & (F.col("ensid") != F.col("targetB"))
            & (F.col("sameEfoensid_B") == "sameDisease"),
            F.concat_ws(
                "_",
                F.col("approvedSymbol"),
                F.col("drugNameTargetB"),
                # Fixed: this used indicatedMaxPhaseIndicationA, i.e. the phase of
                # target A's drug was written into the label of target B's drug.
                F.col("indicatedMaxPhaseIndicationB"),
                F.col("id_B"),
            ),
        ).otherwise(F.lit(None)),
    )
    ## Array-level comparison of therapeutic areas (drug disease vs. query disease)
    .withColumn(
        "coincident_A",
        F.array_intersect(F.col("therAreasA"), F.col("therapeuticAreas")),
    )
    .withColumn(
        "coincident_B",
        F.array_intersect(F.col("therAreasB"), F.col("therapeuticAreas")),
    )
    .withColumn(
        "nonCoincidentA", F.array_except(F.col("therapeuticAreas"), F.col("therAreasA"))
    )
    .withColumn(
        "nonCoincidentB", F.array_except(F.col("therapeuticAreas"), F.col("therAreasB"))
    )
)

In [None]:
# Collapse first_step to one row per (ensid, indice), gathering the distinct drug
# labels of each kind. collect_set skips NULLs, so only rows where the label was
# actually written contribute.
second_step = (
    first_step.groupBy(
        F.col("ensid").alias("ensid2"),
        F.col("indice").alias("indice2"),
    )
    .agg(
        *[
            F.collect_set(F.col(label_column)).alias(output_name)
            for label_column, output_name in (
                ("drugFromA_efo", "drugFromA_trait"),
                ("drugFromB_efo", "drugFromB_trait"),
                ("drugFromA_string", "drugFromA_therArea"),
                ("drugFromA_repurposing_string", "drugFromA_repurposing_string"),
                ("drugFromB_string", "drugFromB_therArea"),
                ("drugFromBNovel_string", "drugFromBNovel_string"),
            )
        ]
    )
    # Labels seen only in non-coincident therapeutic areas: remove those that
    # also appear in the coincident set.
    .withColumn(
        "drugFromARepurposing_therArea",
        F.array_except("drugFromA_repurposing_string", "drugFromA_therArea"),
    )
    .withColumn(
        "drugFromBNovel_therArea",
        F.array_except("drugFromBNovel_string", "drugFromB_therArea"),
    )
)

In [None]:
# Direct-association support per (ensid, indice): the datasources backing the
# target-disease pair and the overall direct score(s).
support = (
    queryset.join(
        targetDirectAssoc,
        on=(
            (targetDirectAssoc.targetIdDirect == F.col("ensid"))
            & (targetDirectAssoc.diseaseIdDirect == F.col("efo_ensid"))
        ),
        how="left",
    )
    .join(
        oaDirectScore,
        on=(
            (F.col("ensid") == oaDirectScore.targetIdoaDirect)
            & (F.col("efo_ensid") == oaDirectScore.diseaseIdDirect)
        ),
        how="left",
    )
    # NOTE(review): F.explode (not explode_outer) drops rows whose datasource
    # array is NULL, so pairs without direct evidence are absent from `support`.
    .withColumn("datasourceIdDirect_take", F.explode(F.col("datasourceIdDirect")))
    .groupBy(F.col("ensid").alias("ensid3"), F.col("indice"))
    .agg(
        F.collect_set(F.col("datasourceIdDirect_take")).alias("datasourceIdDirect"),
        F.collect_set(F.col("scoreDirect")).alias("scoreDirect"),
    )
).repartition(20)

In [None]:
# Attach the direct-association support to the drug annotations. It is a left
# join, so rows of second_step without support keep NULLs in the support
# columns. The `indice` column coming from `support` is renamed to `indiceJoin`.
third_step = second_step.join(
    support,
    on=(support.ensid3 == F.col("ensid2")) & (support.indice == F.col("indice2")),
    how="left",
).withColumnRenamed("indice", "indiceJoin")

In [None]:
# Bring the annotations back onto the full curated table, one row per `indice`.
# The join key is `indice2` (the grouping key of second_step, always populated).
# The previous key, `indiceJoin`, comes from the right side of the left join in
# third_step and is NULL for every row without direct-association support, so
# those rows never matched and silently lost their drug annotations.
fourth_step = df2.join(
    third_step,
    df2.indice == third_step.indice2,
    "left",
)

# Remove join helpers and intermediate columns. Names that are not present in
# the frame are ignored by `drop`.
fourth_step = fourth_step.drop(
    "ensid3",
    "ensid2",
    "indice2",
    "indiceJoin",
    "protein_datasets2",
    "protein_datasets3",
    "outcome_datasets2",
    "outcome_datasets3",
    "drugFromA_repurposing_string",
    "drugFromBNovel_string",
)