In [22]:
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
from datetime import datetime
from datetime import date
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.types import (
    StructType,
    StructField,
    DoubleType,
    DecimalType,
    StringType,
    FloatType,
)
from functions import directionOfEffect
from pyspark.ml.functions import array_to_vector, vector_to_array

# Obtain the Spark session used by every read and DataFrame operation in this notebook.
spark = SparkSession.builder.getOrCreate()

#### Can we use mouse phenotypes to predict:
    >> Safety liabilities
    >> Drug warnings (Black Box Warnings)
    >> Withdrawn drugs

In [23]:
# Input locations: two Open Targets 24.09 parquet tables and a mouse-phenotype score CSV.
mophe_path = "gs://open-targets-data-releases/24.09/output/etl/parquet/mousePhenotypes"
target_path = "gs://open-targets-data-releases/24.09/output/etl/parquet/targets/"
mopheScore_path = "gs://ot-team/jroldan/20230825_mousePheScores.csv"

# The CSV is read with its header row as column names and without schema
# inference, so its columns arrive as strings.
mophe = spark.read.parquet(mophe_path)
target = spark.read.parquet(target_path)
mopheScore = spark.read.option("header", True).csv(mopheScore_path)

In [24]:
def harmonic_sum(evidence_scores):
    """Return the position-weighted sum of ``evidence_scores``.

    The score at zero-based position ``i`` contributes ``score / (i + 1) ** 2``.
    Scores are consumed in the order given; no sorting is done here, so the
    caller decides which scores receive the largest weights.

    Args:
        evidence_scores: Iterable of numeric scores.

    Returns:
        The weighted sum, or ``0`` for an empty iterable.
    """
    total = 0
    for position, score in enumerate(evidence_scores):
        total += score / ((position + 1) ** 2)
    return total


def max_harmonic_sum(evidence_scores):
    """Return the harmonic sum the input would reach if every score were 1.

    Only the length of ``evidence_scores`` is used: the result is
    ``sum(1 / k ** 2 for k in 1..len(evidence_scores))``.

    Args:
        evidence_scores: Sized collection of scores (values are ignored).

    Returns:
        The upper bound for a harmonic sum over that many scores in [0, 1],
        or ``0`` when the collection is empty.
    """
    n_scores = len(evidence_scores)
    return sum(1 / (rank ** 2) for rank in range(1, n_scores + 1))

#### Make dataset of targets with safety Liabilities.
    There are multiple safety liability types

>> See how many different datasources of safety liabilities there are

In [12]:
# One row per (target, safety-liability record). explode_outer keeps targets
# whose safetyLiabilities array is null or empty as a single all-null record,
# which is why a null datasource bucket shows up in the counts.
safety_records = target.select(
    "id", F.explode_outer(F.col("safetyLiabilities")).alias("safeLiable")
).select("id", "safeLiable.*")

safety_records.groupBy("datasource").count().show(truncate=False)



+---------------------+-----+
|datasource           |count|
+---------------------+-----+
|Force et al. (2011)  |47   |
|null                 |62180|
|Lamore et al. (2017) |30   |
|Brennan et al. (2024)|210  |
|AOP-Wiki             |227  |
|Lynch et al. (2017)  |1341 |
|Bowes et al. (2012)  |313  |
|Urban et al. (2012)  |254  |
|ToxCast              |375  |
|PharmGKB             |1661 |
+---------------------+-----+



                                                                                

#### Remove datasources of Brennan et al (2024) and PharmGKB (Pharmacogenetics)


#### safety liabilities without Brennan and PGx

In [27]:
# Datasources left out of this safety-liability set.
excluded_sources = ["Brennan et al. (2024)", "PharmGKB"]

# Number of remaining safety events per target. Rows with a null datasource
# do not satisfy the negated isin() filter, so targets without any liability
# record are dropped here as well.
safetyLiability = (
    target.select("id", F.explode_outer(F.col("safetyLiabilities")).alias("safeLiable"))
    .select("id", "safeLiable.*")
    .where(~F.col("datasource").isin(excluded_sources))
    .groupBy("id")
    .agg(F.count("event").alias("nr"))
    .orderBy(F.col("nr").asc())
)

n = safetyLiability.count()
print("There are", n, "targets with safety liabilities excluding Brennan 2024 and PGx")



There are 503 targets with safety liabilities excluding Brennan 2024 and PGx


                                                                                

#### genetic constraint and lof_tolerance score

In [28]:
# Explode each target's `constraint` array once and keep only the
# loss-of-function ("lof") entries. Everything below derives from this single
# frame instead of repeating the same explode + filter pipeline.
lof_constraint = (
    target.select(F.col("id").alias("constr_id"), F.explode("constraint"))
    .select("constr_id", "col.*")
    .filter(F.col("constraintType") == "lof")
)

# LoF constraint metrics per target.
constr = lof_constraint.select(
    F.col("constr_id").alias("id"),
    "upperRank",
    "oe",
    F.col("score").alias("scoreConstraint"),
)

# Smallest and largest LoF upperRank, fetched together in one aggregation
# (a single Spark job rather than one per bound).
rank_bounds = lof_constraint.agg(
    F.min("upperRank").alias("minUpperRank"),
    F.max("upperRank").alias("maxUpperRank"),
).first()
minUpperRank = rank_bounds["minUpperRank"]
maxUpperRank = rank_bounds["maxUpperRank"]
if minUpperRank is None or maxUpperRank is None:
    # Fail with a clear message instead of an IndexError / null arithmetic.
    raise ValueError(
        "No 'lof' constraint rows with an upperRank were found in `target`"
    )

# Rescale upperRank linearly so that minUpperRank maps to -1 and maxUpperRank
# maps to +1.
loftolerance = lof_constraint.select(
    F.col("constr_id").alias("id"),
    (
        2 * ((F.col("upperRank") - minUpperRank) / (maxUpperRank - minUpperRank))
        - 1
    ).alias("cal_score"),
    "constraintType",
)

                                                                                

#### mouse score

In [29]:
# Rename the score-sheet columns and add `curatedScore`: a score that compares
# equal to 0.0 is written as the literal 0, every other score passes through.
score_is_zero = F.col("score") == 0.0
mousePhenoScoreFilter = mopheScore.select(
    F.col("id").alias("idLabel"),
    F.col("label").alias("phenoLabel"),
    F.col("score"),
    F.when(score_is_zero, F.lit(0)).otherwise(F.col("score")).alias("curatedScore"),
)

#### mouse score per target

In [30]:
# One row per (target, phenotype class id) taken from the mouse phenotype table.
phenotype_class_ids = mophe.select(
    "targetFromSourceId",
    F.explode_outer(F.col("modelPhenotypeClasses.id")).alias("id"),
)

# Attach the score of each phenotype class (left join keeps classes without a
# score) and cast the curated score to float.
scored_classes = phenotype_class_ids.join(
    mousePhenoScoreFilter,
    on=F.col("id") == mousePhenoScoreFilter.idLabel,
    how="left",
).withColumn("score", F.col("curatedScore").cast(FloatType()))

# Gather every score of a target into one vector column.
scoreCalc_list = scored_classes.groupBy("targetFromSourceId").agg(
    array_to_vector(F.collect_list("score")).alias("score")
)

#### Safety WO ToxCast

In [32]:
# Datasources left out of this stricter safety-liability set.
sources_without_toxcast = ["ToxCast", "Brennan et al. (2024)", "PharmGKB"]

# Number of remaining safety events per target.
safetyWOToxCast = (
    target.select("id", F.explode_outer(F.col("safetyLiabilities")).alias("safeLiable"))
    .select("id", "safeLiable.*")
    .where(~F.col("datasource").isin(sources_without_toxcast))
    .groupBy("id")
    .agg(F.count("event").alias("noToxCast"))
    .orderBy(F.col("noToxCast").asc())
)

n = safetyWOToxCast.count()
print(
    "There are",
    n,
    " targets with safety liabilities excluding ToxCast, Brennan 2024 and PGx",
)



There are 262  targets with safety liabilities excluding ToxCast, Brennan 2024 and PGx


                                                                                

#### approved Drugs

In [33]:
# Approvals table: the header row provides column names and the "_c0" column
# is discarded.
approvals_csv = "gs://ot-team/jroldan/2013-2022_approvals.csv"
approvedTargets = spark.read.option("header", True).csv(approvals_csv).drop("_c0")

# Keep only the column holding the target identifiers.
approved = approvedTargets.select("targetIds")


def remove_all_whitespace(col):
    """Return a column expression with every run of whitespace in ``col`` removed.

    Args:
        col: Column (or column name) passed straight to ``F.regexp_replace``.

    Returns:
        The ``F.regexp_replace`` expression that substitutes ``\\s+`` matches
        with the empty string.
    """
    whitespace_run = "\\s+"
    return F.regexp_replace(col, whitespace_run, "")


# Split the ";"-separated `targetIds` into one row per target, trim each id,
# tag it as approved and de-duplicate.
approvedTargets = (
    approved.select(F.explode(F.split(F.col("targetIds"), ";")).alias("targets"))
    .select(F.trim(F.col("targets")).alias("id"))
    .withColumn("approved", F.lit("approved"))
    .distinct()
)

### Calculate Score from mouse phenotypes using harmonic sum

In [34]:
# Bring the per-target score vectors to the driver as a pandas DataFrame.
df_py = scoreCalc_list.toPandas()

# harmonic_sum weights scores by position, so each target's scores are put in
# descending order before being summed.
values = [harmonic_sum(sorted(scores, reverse=True)) for scores in df_py["score"]]
df_py["harmonicSum"] = values

# Divisor used to normalise the harmonic sums.
# NOTE(review): 1.644 looks like a truncation of pi**2 / 6 (the limit of
# sum(1 / k**2)) — confirm that this is the intended maximum.
maximumScore = 1.644
normalised = [raw_sum / (maximumScore) for raw_sum in df_py["harmonicSum"]]
df_py["harmonicSumNorm"] = normalised

## convert pandas to spark dataframe
df = spark.createDataFrame(df_py).withColumnRenamed("targetFromSourceId", "id")

                                                                                

#### Run DoE to have LoF and GoF drugs

In [7]:
# Single release identifier: the evidence table is read from the same release
# that is handed to directionOfEffect, so the two cannot drift apart.
platform_v = "24.09"
evidence_path = (
    f"gs://open-targets-data-releases/{platform_v}/output/etl/parquet/evidence"
)

# Evidence datasources kept for the direction-of-effect assessment.
doe_datasources = [
    "ot_genetics_portal",
    "gene_burden",
    "eva",
    "eva_somatic",
    "gene2phenotype",
    "orphanet",
    "cancer_gene_census",
    "intogen",
    "impc",
    "chembl",
]

# Cached because the filtered evidence is reused by the call below.
evidences = (
    spark.read.parquet(evidence_path)
    .filter(F.col("datasourceId").isin(doe_datasources))
    .persist()
)

# directionOfEffect comes from the project-local `functions` module; later
# cells read its `homogenized`, `targetId`, `drugId` and `datasourceId` columns.
dataset = directionOfEffect(evidences, platform_v).persist()

24/09/27 19:57:58 WARN CacheManager: Asked to cache already cached data.
24/09/27 19:57:59 WARN CacheManager: Asked to cache already cached data.
24/09/27 19:57:59 WARN CacheManager: Asked to cache already cached data.


#### Load Drug warnings to take Black Box Warning Drugs and Withdrawn Drugs

In [35]:
drugWarning_path = (
    "gs://open-targets-data-releases/24.09/output/etl/parquet/drugWarnings"
)
drugwarnings = spark.read.parquet(drugWarning_path)

#### filter Drugs by inhibitors (LoF)
# Distinct (target, drug) pairs from ChEMBL evidence labelled "LoF_protect",
# with the number of matching evidence rows in `count`.
is_chembl = F.col("datasourceId") == "chembl"
is_lof_protect = F.col("homogenized") == "LoF_protect"
lofDrugs = dataset.where(is_chembl & is_lof_protect).groupBy("targetId", "drugId").count()

# One warning row per ChEMBL id listed in `chemblIds`.
warnings_by_drug = drugwarnings.withColumn(
    "drugId", F.explode_outer("chemblIds")
).select("drugId", "toxicityClass", "warningType")

# Left join keeps drugs that have no warning at all.
lofDrugsWarnings = lofDrugs.join(warnings_by_drug, on="drugId", how="left").distinct()

#### From Drugs with BBW to Targets with BBW

In [36]:
suffix_bbw = "_BBW"

# The value to check in the arrays
value_to_check = "Black Box Warning"

# One row per target and one array column per toxicity class, each holding the
# set of warning types seen for that class among Black Box Warning rows.
lofDrugWarningBBX = (
    lofDrugsWarnings.where(F.col("warningType") == "Black Box Warning")
    .groupBy("targetId")
    .pivot("toxicityClass")
    .agg(F.collect_set("warningType"))
)

# Every column after `targetId` is a pivoted toxicity-class array.
array_columns = lofDrugWarningBBX.columns[1:]

# 1/0 indicator per toxicity class: 1 when its array contains the warning type.
bbw_flags = [
    F.when(F.array_contains(F.col(col_name), value_to_check), 1)
    .otherwise(0)
    .alias(f"{col_name}_BBW")
    for col_name in array_columns
]

# Create a new DataFrame with transformed columns
df_transformed = lofDrugWarningBBX.select("*", *bbw_flags)

# Keep only the indicators and mark every target in this table with allBBW = 1.
lofDrugWarningBBX_format = df_transformed.drop(*array_columns).withColumn(
    "allBBW", F.lit(1)
)

                                                                                

#### From withdrawn drugs to targets with withdrawn drugs

In [37]:
# The value to check in the arrays
value_to_check = "Withdrawn"

# One row per target and one array column per toxicity class, each holding the
# set of warning types seen for that class among Withdrawn rows.
lofDrugWarningWD = (
    lofDrugsWarnings.where(F.col("warningType") == "Withdrawn")
    .groupBy("targetId")
    .pivot("toxicityClass")
    .agg(F.collect_set("warningType"))
)

# Every column after `targetId` is a pivoted toxicity-class array.
array_columns = lofDrugWarningWD.columns[1:]

# 1/0 indicator per toxicity class: 1 when its array contains the warning type.
wd_flags = [
    F.when(F.array_contains(F.col(col_name), value_to_check), 1)
    .otherwise(0)
    .alias(f"{col_name}_WD")
    for col_name in array_columns
]

# Create a new DataFrame with transformed columns
df_transformed = lofDrugWarningWD.select("*", *wd_flags)

# Keep only the indicators and mark every target in this table with allWD = 1.
lofDrugWarningWD_format = df_transformed.drop(*array_columns).withColumn(
    "allWD", F.lit(1)
)

                                                                                

#### Join Drugs BBW and Drugs WithDrawn

In [38]:
# Full outer join so a target with only withdrawn drugs or only black-box
# drugs is still kept; indicators missing on one side become 0.
bbw_wd_joined = lofDrugWarningWD_format.join(
    lofDrugWarningBBX_format, on="targetId", how="outer"
)
lofDrugsBBw_WD = bbw_wd_joined.withColumn("bwwAndwd", F.lit(1)).fillna(0)

### Make dataset and prepare normalized harmonic sum 

In [39]:
### make dataset for comparisons
df_comparisons = (
    df.join(loftolerance, on="id", how="left")
    .join(safetyLiability, on="id", how="left")
    .join(safetyWOToxCast, on="id", how="left")
    .join(approvedTargets, on="id", how="left")
    .join(lofDrugsBBw_WD.withColumnRenamed("targetId", "id"), on="id", how="left")
    .withColumn("allBBW", F.when(F.col("allBBW") == 1, F.lit(1)).otherwise(F.lit(0)))
    .withColumn("allWD", F.when(F.col("allWD") == 1, F.lit(1)).otherwise(F.lit(0)))
    .persist()
)

#### Make deciles for harmonic sum
# Calculate quartiles or deciles
quantiles = df_comparisons.approxQuantile(
    "harmonicSumNorm", [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], 0.01
)
# Define quartile or decile bins
bins = [float("-inf")] + quantiles + [float("inf")]
window_spec = Window.orderBy("harmonicSumNorm")
df_with_labels = (
    df_comparisons.withColumn(
        "decilesHarmonicSumNorm",
        F.when(F.col("harmonicSumNorm").isNull(), None).otherwise(
            sum(
                F.when(F.col("harmonicSumNorm") >= bin_val, 1).otherwise(0)
                for bin_val in bins
            )
        ),
    )
    .withColumn(
        "safetyLiabilities", F.when(F.col("nr").isNotNull(), F.lit(1)).otherwise(0)
    )
    .withColumn(
        "noToxCastLiab", F.when(F.col("noToxCast").isNotNull(), F.lit(1)).otherwise(0)
    )
    .drop("nr", "noToxCast")
)

                                                                                

#### Use human genetics to define the Severity Score

## ROC Curves

In [44]:
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.datasets import make_classification
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.ensemble import RandomForestClassifier
import random

# Fix the seeds of Python's `random` module and NumPy's legacy global
# generator so repeated runs draw the same random numbers.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)