In [None]:
# -*- coding: utf-8 -*-
# Single-script, loop-free PySpark job (tall/unpivot + single aggregation)

import os
from datetime import date
from functools import reduce

from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType, DoubleType, ArrayType
)

# Your helpers
from functions import (
    relative_success,
    spreadSheetFormatter,
    discrepancifier,
    temporary_directionOfEffect,
    buildColocData,
    gwasDataset,
)
from DoEAssessment import directionOfEffect  # noqa: F401  (kept if you need it later)

# -------------------------------
# Spark / YARN resource settings (Single-Node Option A)
# -------------------------------
driver_memory = "12g"                 # string with unit
executor_memory = "40g"               # string with unit (heap)
executor_cores = 10                   # int
num_executors = 1                     # int (one fat executor on single node)
executor_memory_overhead = "6g"       # string with unit (PySpark/Arrow/off-heap)
shuffle_partitions = 128              # int (~2–3x cores)
default_parallelism = 128             # int (match shuffle_partitions)

# If you later move to a multi-worker cluster, replace the values above.

spark = (
    SparkSession.builder
    .appName("MyOptimizedPySparkApp")
    .config("spark.master", "yarn")
    # core resources
    .config("spark.driver.memory", driver_memory)
    .config("spark.executor.memory", executor_memory)
    .config("spark.executor.cores", executor_cores)
    .config("spark.executor.instances", num_executors)
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead)
    # shuffle & parallelism
    .config("spark.sql.shuffle.partitions", shuffle_partitions)
    .config("spark.default.parallelism", default_parallelism)
    # adaptive query execution for better skew/partition sizing
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
)

print("SparkSession created with:")
for k in [
    "spark.driver.memory",
    "spark.executor.memory",
    "spark.executor.cores",
    "spark.executor.instances",
    "spark.yarn.executor.memoryOverhead",
    "spark.sql.shuffle.partitions",
    "spark.default.parallelism",
    "spark.sql.adaptive.enabled",
    "spark.sql.adaptive.coalescePartitions.enabled",
]:
    print(f"  {k}: {spark.conf.get(k)}")
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")



'''
# -------------------------------
# Spark / YARN resource settings
# -------------------------------
driver_memory = "16g"
executor_memory = "32g"
executor_cores = "8"
num_executors = "16"
executor_memory_overhead = "8g"
shuffle_partitions = "150"
default_parallelism = str(int(executor_cores) * int(num_executors) * 2)  # 80

spark = (
    SparkSession.builder
    .appName("MyOptimizedPySparkApp")
    .config("spark.master", "yarn")
    .config("spark.driver.memory", driver_memory)
    .config("spark.executor.memory", executor_memory)
    .config("spark.executor.cores", executor_cores)
    .config("spark.executor.instances", num_executors)
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead)
    .config("spark.sql.shuffle.partitions", shuffle_partitions)
    .config("spark.default.parallelism", default_parallelism)
    .getOrCreate()
)

print("SparkSession created with:")
for k in [
    "spark.driver.memory",
    "spark.executor.memory",
    "spark.executor.cores",
    "spark.executor.instances",
    "spark.yarn.executor.memoryOverhead",
    "spark.sql.shuffle.partitions",
    "spark.default.parallelism",
]:
    print(f"  {k}: {spark.conf.get(k)}")
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")
'''
# --------------------------------
# 0) Load inputs
# --------------------------------
path_n = "gs://open-targets-data-releases/25.06/output/"

target = spark.read.parquet(f"{path_n}target/")
diseases = spark.read.parquet(f"{path_n}disease/")
evidences = spark.read.parquet(f"{path_n}evidence")
credible = spark.read.parquet(f"{path_n}credible_set")
new = spark.read.parquet(f"{path_n}colocalisation_coloc")
index = spark.read.parquet(f"{path_n}study/")
variantIndex = spark.read.parquet(f"{path_n}variant")
biosample = spark.read.parquet(f"{path_n}biosample")
ecaviar = spark.read.parquet(f"{path_n}colocalisation_ecaviar")
all_coloc = ecaviar.unionByName(new, allowMissingColumns=True)
print("Loaded all base tables.")

# --------------------------------
# 1) Build coloc + GWAS dataset
# --------------------------------
newColoc = buildColocData(all_coloc, credible, index)
print("Built newColoc")

gwasComplete = gwasDataset(evidences, credible)
print("Built gwasComplete")

resolvedColoc = (
    newColoc.withColumnRenamed("geneId", "targetId")
    .join(
        gwasComplete.withColumnRenamed("studyLocusId", "leftStudyLocusId"),
        on=["leftStudyLocusId", "targetId"],
        how="inner",
    )
    .join(
        diseases.selectExpr("id as diseaseId", "name", "parents", "therapeuticAreas"),
        on="diseaseId",
        how="left",
    )
    .withColumn(
        "diseaseId",
        F.explode_outer(F.concat(F.array(F.col("diseaseId")), F.col("parents"))),
    )
    .drop("parents", "oldDiseaseId")
    .withColumn(
        "colocDoE",
        F.when(
            F.col("rightStudyType").isin(["eqtl", "pqtl", "tuqtl", "sceqtl", "sctuqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_protect"))
        ).when(
            F.col("rightStudyType").isin(["sqtl", "scsqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_protect"))
        ),
    )
)
print("Built resolvedColoc")

# --------------------------------
# 2) Direction of Effect & ChEMBL indication
# --------------------------------
datasource_filter = [
    "gwas_credible_sets",
    "gene_burden",
    "eva",
    "eva_somatic",
    "gene2phenotype",
    "orphanet",
    "cancer_gene_census",
    "intogen",
    "impc",
    "chembl",
]
assessment, evidences, actionType_unused, oncolabel_unused = temporary_directionOfEffect(path_n, datasource_filter)
print("Built temporary DoE datasets")

# (Optional) Add MoA to ChEMBL paths as in your later code
mecact_path = f"{path_n}drug_mechanism_of_action/"
mecact = spark.read.parquet(mecact_path)
actionType = (
    mecact.select(
        F.explode_outer("chemblIds").alias("drugId"),
        "actionType",
        "mechanismOfAction",
        "targets",
    )
    .select(
        F.explode_outer("targets").alias("targetId"),
        "drugId",
        "actionType",
        "mechanismOfAction",
    )
    .groupBy("targetId", "drugId")
    .agg(F.collect_set("actionType").alias("actionType2"))
    .withColumn("nMoA", F.size(F.col("actionType2")))
)

analysis_chembl_indication = (
    discrepancifier(
        assessment.filter(F.col("datasourceId") == "chembl")
        .join(actionType, on=["targetId", "drugId"], how="left")
        .withColumn(
            "maxClinPhase",
            F.max("clinicalPhase").over(Window.partitionBy("targetId", "diseaseId")),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase", "actionType2")
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    .drop("coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk")
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
)
print("Built analysis_chembl_indication")

# --------------------------------
# 3) Benchmark (filtered coloc) + clinical phase flags
# --------------------------------
resolvedColocFiltered = resolvedColoc.filter((F.col("clpp") >= 0.01) | (F.col("h4") >= 0.8))

negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .groupBy("targetId", "diseaseId").count()
    .withColumn("stopReason", F.lit("Negative")).drop("count")
)
benchmark = (
    resolvedColocFiltered.filter(F.col("name") != "COVID-19")
    .join(analysis_chembl_indication, on=["targetId", "diseaseId"], how="right")
    .withColumn(
        "AgreeDrug",
        F.when((F.col("drugGoF_protect").isNotNull()) & (F.col("colocDoE") == "GoF_protect"), "yes")
        .when((F.col("drugLoF_protect").isNotNull()) & (F.col("colocDoE") == "LoF_protect"), "yes")
        .otherwise("no"),
    )
    .join(biosample.select("biosampleId", "biosampleName"), on="biosampleId", how="left")
)

benchmark = (
    benchmark.join(F.broadcast(negativeTD), on=["targetId", "diseaseId"], how="left")
    .withColumn("PhaseT", F.when(F.col("stopReason") == "Negative", "yes").otherwise("no"))
    .withColumn("Phase>=4", F.when((F.col("maxClinPhase") == 4) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=3", F.when((F.col("maxClinPhase") >= 3) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=2", F.when((F.col("maxClinPhase") >= 2) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=1", F.when((F.col("maxClinPhase") >= 1) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
)

# --------------------------------
# 4) Replace nested loops:
#     compute DoE counts once → derive flags → unpivot → single aggregation
# --------------------------------
doe_cols = ["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]

# counts per colocDoE over the grouping you previously used in the loop
group_keys = [
    "targetId", "diseaseId", "maxClinPhase",
    "actionType2", "biosampleName", "projectId", "rightStudyType", "colocalisationMethod"
]

doe_counts = (
    benchmark.groupBy(*group_keys)
    .agg(*[F.sum(F.when(F.col("colocDoE") == c, 1).otherwise(0)).alias(c) for c in doe_cols])
)

# max name(s) (in case of ties) without arrays of structs
greatest_count = F.greatest(*[F.col(c) for c in doe_cols])
max_names = F.filter(
    F.array(*[F.when(F.col(c) == greatest_count, F.lit(c)) for c in doe_cols]),
    lambda x: x.isNotNull()
)

# presence of drug-side signals (equivalent to *_ch presence in your loop path)
has_lof_ch = F.col("drugLoF_protect").isNotNull()
has_gof_ch = F.col("drugGoF_protect").isNotNull()


test2.unpersist()
test2 = (
    benchmark.select(*group_keys, "drugLoF_protect", "drugGoF_protect")
    .join(doe_counts, on=group_keys, how="left")
    .withColumn("NoneCellYes",
        F.when(has_lof_ch & (~has_gof_ch) & F.array_contains(max_names, F.lit("LoF_protect")), "yes")
         .when(has_gof_ch & (~has_lof_ch) & F.array_contains(max_names, F.lit("GoF_protect")), "yes")
         .otherwise("no")
    )
    .withColumn("NdiagonalYes",
        F.when(has_lof_ch & (~has_gof_ch) & (F.array_contains(max_names, F.lit("LoF_protect")) | F.array_contains(max_names, F.lit("GoF_risk"))), "yes")
         .when(has_gof_ch & (~has_lof_ch) & (F.array_contains(max_names, F.lit("GoF_protect")) | F.array_contains(max_names, F.lit("LoF_risk"))), "yes")
         .otherwise("no")
    )
    .withColumn("drugCoherency",
        F.when(has_lof_ch & ~has_gof_ch, "coherent")
         .when(~has_lof_ch & has_gof_ch, "coherent")
         .when(has_lof_ch & has_gof_ch, "dispar")
         .otherwise("other")
    ).withColumn(
    "hasGenetics2",
    F.when(
        reduce(lambda acc, c: acc & F.col(c).isNull(), doe_cols[1:], F.col(doe_cols[0]).isNull()),
        F.lit("no")
    ).otherwise(F.lit("yes"))
)
    .withColumn("hasGenetics", F.when(F.col("NdiagonalYes").isNotNull(), "yes").otherwise("no")) #### we have to change it
)
test2.persist()

common_cols = [
    "targetId","diseaseId","maxClinPhase",
    "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
    "NoneCellYes","NdiagonalYes","hasGenetics2"
]

# 1) actionType2 is ARRAY<STRING> → explode to one value per row
long_action = (
    test2.join(
        benchmark.select(
            "targetId","diseaseId","maxClinPhase","Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT"
        ).dropDuplicates(["targetId","diseaseId","maxClinPhase"]),
        on=["targetId","diseaseId","maxClinPhase"], how="left"
    )
    .select(*common_cols, F.explode_outer("actionType2").alias("value"))
    .withColumn("feature", F.lit("actionType2"))
    .select(*common_cols, "feature", "value")
)

# 2–5) the scalar columns
def longify_scalar(colname: str):
    return (
        test2.join(
            benchmark.select(
                "targetId","diseaseId","maxClinPhase","Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT"
            ).dropDuplicates(["targetId","diseaseId","maxClinPhase"]),
            on=["targetId","diseaseId","maxClinPhase"], how="left"
        )
        .select(*common_cols, F.col(colname).alias("value"))
        .withColumn("feature", F.lit(colname))
        .select(*common_cols, "feature", "value")
    )

long_biosample = longify_scalar("biosampleName")
long_project   = longify_scalar("projectId")
long_rstype    = longify_scalar("rightStudyType")
long_colocm    = longify_scalar("colocalisationMethod")

# Union them (schema-aligned)
long_features = (
    long_action
    .unionByName(long_biosample)
    .unionByName(long_project)
    .unionByName(long_rstype)
    .unionByName(long_colocm)
).filter(F.col("value").isNotNull())


# Single aggregation for ALL features and ALL metrics (booleans as max over yes/no)
agg_once = (
    long_features
    .groupBy("targetId","diseaseId","maxClinPhase","Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT","feature","value")
    .agg(
        F.max(F.when(F.col("NoneCellYes")=="yes", 1).otherwise(0)).alias("NoneCellYes"),
        F.max(F.when(F.col("NdiagonalYes")=="yes", 1).otherwise(0)).alias("NdiagonalYes"),
        F.max(F.when(F.col("hasGenetics2")=="yes", 1).otherwise(0)).alias("hasGenetics"),
    )
    .selectExpr(
        "*",
        "CASE WHEN NoneCellYes=1 THEN 'yes' ELSE 'no' END as NoneCellYes_flag",
        "CASE WHEN NdiagonalYes=1 THEN 'yes' ELSE 'no' END as NdiagonalYes_flag",
        "CASE WHEN hasGenetics=1 THEN 'yes' ELSE 'no' END as hasGenetics_flag"
    )
)
# ---- make sure 'yes'/'no' columns exist as ints (we already did fillna(0), but be explicit)
mat_counts = (
    mat_counts
    .withColumn("yes", F.coalesce(F.col("yes"), F.lit(0)).cast("int"))
    .withColumn("no",  F.coalesce(F.col("no"),  F.lit(0)).cast("int"))
)


spark session created at 2025-09-10 14:59:00.115158
Analysis started on 2025-09-10 at  2025-09-10 14:59:00.115158


25/09/10 14:59:05 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
25/09/10 14:59:05 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


SparkSession created with:
  spark.driver.memory: 12g
  spark.executor.memory: 40g
  spark.executor.cores: 10
  spark.executor.instances: 1
  spark.yarn.executor.memoryOverhead: 6g
  spark.sql.shuffle.partitions: 128
  spark.default.parallelism: 128
  spark.sql.adaptive.enabled: true
  spark.sql.adaptive.coalescePartitions.enabled: true
Spark UI: http://jr-temp-doe-m.c.open-targets-eu-dev.internal:43343
Loaded all base tables.
Built newColoc


                                                                                

loaded gwasComplete
Built gwasComplete
Built resolvedColoc
Built temporary DoE datasets


                                                                                

Built analysis_chembl_indication


#### Actions:
##### Find how the matrix 2x2 is done
##### substitute rel_success and odds ratio with our formula

In [None]:

import pandas as pd
import numpy as np
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio

# Same schema you had:
result_schema = StructType([
    StructField("group",        StringType(),  True),
    StructField("comparison",   StringType(),  True),
    StructField("phase",        StringType(),  True),
    StructField("oddsRatio",    DoubleType(),  True),
    StructField("pValue",       DoubleType(),  True),
    StructField("lowerInterval",DoubleType(),  True),
    StructField("upperInterval",DoubleType(),  True),
    StructField("total",        StringType(),  True),
    StructField("values",       ArrayType(ArrayType(IntegerType())), True),
    StructField("relSuccess",   DoubleType(),  True),
    StructField("rsLower",      DoubleType(),  True),
    StructField("rsUpper",      DoubleType(),  True),
    StructField("path",         StringType(),  True),
])

def _relative_success(matrix_2x2: np.ndarray):
    a, b = matrix_2x2[0,0], matrix_2x2[0,1]
    c, d = matrix_2x2[1,0], matrix_2x2[1,1]
    rate_yes = a / (a + b) if (a + b) > 0 else 0.0
    rate_no  = c / (c + d) if (c + d) > 0 else 0.0
    rs = rate_yes - rate_no
    import math
    se = 0.0
    if (a+b) > 0:
        se += rate_yes * (1 - rate_yes) / (a + b)
    if (c+d) > 0:
        se += rate_no  * (1 - rate_no)  / (c + d)
    se = math.sqrt(se)
    lo, hi = rs - 1.96*se, rs + 1.96*se
    return float(rs), float(lo), float(hi)

def fisher_by_group(pdf: pd.DataFrame) -> pd.DataFrame:
    # pdf has rows for both comparison=='yes' and comparison=='no' (may be missing one)
    sub = pdf[["comparison","yes","no"]].copy()
    sub = sub.set_index("comparison").reindex(["yes","no"]).fillna(0)
    mat = sub[["yes","no"]].to_numpy(dtype=int)

    total = int(mat.sum())
    or_val, p_val = fisher_exact(mat, alternative="two-sided")
    ci = odds_ratio(mat).confidence_interval(0.95)
    rs, rs_lo, rs_hi = _relative_success(mat)

    out = pd.DataFrame([{
        "group":        pdf["metric"].iloc[0],
        "comparison":   f"{pdf['value'].iloc[0]}_only",
        "phase":        pdf["phase_name"].iloc[0],
        "oddsRatio":    round(float(or_val), 2),
        "pValue":       float(p_val),
        "lowerInterval":round(float(ci[0]), 2),
        "upperInterval":round(float(ci[1]), 2),
        "total":        str(total),
        "values":       mat.tolist(),
        "relSuccess":   round(float(rs), 2),
        "rsLower":      round(float(rs_lo), 2),
        "rsUpper":      round(float(rs_hi), 2),
        "path":         "",
    }])
    return out

# (optional) Arrow for speed
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# ✅ Use applyInPandas with the explicit schema
results_df = (
    mat_counts
    .groupBy("metric","feature","value","phase_name")
    .applyInPandas(fisher_by_group, schema=result_schema)
)


# ============================
# Fisher + reporting (applyInPandas version)
# ============================
import pandas as pd
import numpy as np
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType, IntegerType, ArrayType
)
from datetime import date
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio

# 0) Ensure 'yes'/'no' numeric columns exist and are ints (no NaN)
mat_counts = (
    mat_counts
    .fillna(0)
    .withColumn("yes", F.coalesce(F.col("yes"), F.lit(0)).cast("int"))
    .withColumn("no",  F.coalesce(F.col("no"),  F.lit(0)).cast("int"))
)

# 1) Output schema (mirrors your previous one)
result_schema = StructType([
    StructField("group",        StringType(),  True),  # metric
    StructField("comparison",   StringType(),  True),  # e.g. "<value>_only"
    StructField("phase",        StringType(),  True),  # phase_name
    StructField("oddsRatio",    DoubleType(),  True),
    StructField("pValue",       DoubleType(),  True),
    StructField("lowerInterval",DoubleType(),  True),
    StructField("upperInterval",DoubleType(),  True),
    StructField("total",        StringType(),  True),
    StructField("values",       ArrayType(ArrayType(IntegerType())), True),
    StructField("relSuccess",   DoubleType(),  True),
    StructField("rsLower",      DoubleType(),  True),
    StructField("rsUpper",      DoubleType(),  True),
    StructField("path",         StringType(),  True),
])

# 2) Relative success helper (vectorized inside pandas func)
def _relative_success(matrix_2x2: np.ndarray):
    # rows: comparison yes/no ; cols: prediction yes/no
    a, b = matrix_2x2[0,0], matrix_2x2[0,1]
    c, d = matrix_2x2[1,0], matrix_2x2[1,1]
    rate_yes = a / (a + b) if (a + b) > 0 else 0.0
    rate_no  = c / (c + d) if (c + d) > 0 else 0.0
    rs = rate_yes - rate_no
    import math
    se = 0.0
    if (a+b) > 0:
        se += rate_yes * (1 - rate_yes) / (a + b)
    if (c+d) > 0:
        se += rate_no  * (1 - rate_no)  / (c + d)
    se = math.sqrt(se)
    lo, hi = rs - 1.96*se, rs + 1.96*se
    return float(rs), float(lo), float(hi)

# 3) Plain Python function for applyInPandas (no decorator)
def fisher_by_group(pdf: pd.DataFrame) -> pd.DataFrame:
    # pdf columns: metric, feature, value, phase_name, comparison, yes, no
    sub = pdf[["comparison","yes","no"]].copy()
    # enforce both rows 'yes' and 'no'
    sub = sub.set_index("comparison").reindex(["yes","no"]).fillna(0)
    mat = sub[["yes","no"]].to_numpy(dtype=int)

    total = int(mat.sum())
    or_val, p_val = fisher_exact(mat, alternative="two-sided")
    ci = odds_ratio(mat).confidence_interval(0.95)
    rs, rs_lo, rs_hi = _relative_success(mat)

    return pd.DataFrame([{
        "group":        pdf["metric"].iloc[0],
        "comparison":   f"{pdf['value'].iloc[0]}_only",
        "phase":        pdf["phase_name"].iloc[0],
        "oddsRatio":    round(float(or_val), 2),
        "pValue":       float(p_val),
        "lowerInterval":round(float(ci[0]), 2),
        "upperInterval":round(float(ci[1]), 2),
        "total":        str(total),
        "values":       mat.tolist(),
        "relSuccess":   round(float(rs), 2),
        "rsLower":      round(float(rs_lo), 2),
        "rsUpper":      round(float(rs_hi), 2),
        "path":         "",
    }])

# (optional) Arrow for speed
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# 4) Apply per (metric, feature, value, phase_name)
results_df = (
    mat_counts
    .groupBy("metric","feature","value","phase_name")
    .applyInPandas(fisher_by_group, schema=result_schema)
)

# 5) Formatting + annotation + CSV export (unchanged logic)
from itertools import chain
from pyspark.sql.functions import create_map

# disdic from agg_once, if not already present
disdic = {r["value"]: r["feature"] for r in agg_once.select("feature","value").distinct().collect()}

patterns = ["_only", "_isRightTissueSignalAgreed"]
regex_pattern = "(" + "|".join(patterns) + ")"

df_fmt = (
    spreadSheetFormatter(results_df)
    .withColumn("prefix", F.regexp_replace(F.col("comparison"), regex_pattern + ".*", ""))
    .withColumn("suffix", F.regexp_extract(F.col("comparison"), regex_pattern, 0))
)

mapping_expr = create_map([F.lit(x) for x in chain(*disdic.items())])
df_annot = df_fmt.withColumn("annotation", mapping_expr.getItem(F.col("prefix")))

today_date = date.today().isoformat()
out_csv = f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue.csv"
df_annot.toPandas().to_csv(out_csv, index=False)
print(f"Analysis written: {out_csv}")


                                                                                

importing functions
imported functions


25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_167 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_125 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_101 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_112 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_115 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_34 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_138_98 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_49 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_128 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_59 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: 

Analysis written: gs://ot-team/jroldan/analysis/2025-09-10_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue.csv


In [6]:
# ============================
# Spark-friendly ANALYSIS
# ============================
from datetime import date
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType, IntegerType, ArrayType
)
from pyspark.sql import Window

# ---- 1) Build the comparison × prediction long table in one go ----
# metrics we want to analyze (these replace columns_to_aggregate)
metric_flags = ["NoneCellYes_flag", "NdiagonalYes_flag", "hasGenetics_flag"]

# Melt the phase flags into (phase_name, prediction) rows
phases_long = (
    agg_once.select(
        "targetId","diseaseId","maxClinPhase","feature","value",
        F.expr("stack(5, "
               "'Phase>=4', `Phase>=4`, "
               "'Phase>=3', `Phase>=3`, "
               "'Phase>=2', `Phase>=2`, "
               "'Phase>=1', `Phase>=1`, "
               "'PhaseT',  `PhaseT`"
               ")").alias("phase_name", "prediction")
    )
    .filter(F.col("prediction").isNotNull())
)

# For each metric, attach its comparison flag ("yes"/"no") and union them all
def attach_metric(metric_col: str):
    return (
        agg_once.select("targetId","diseaseId","maxClinPhase","feature","value", F.col(metric_col).alias("comparison"))
                .join(phases_long.select("targetId","diseaseId","maxClinPhase","feature","value","phase_name","prediction"),
                      on=["targetId","diseaseId","maxClinPhase","feature","value"],
                      how="inner")
                .withColumn("metric", F.lit(metric_col.replace("_flag", "")))  # pretty label
    )

analysis_long = attach_metric(metric_flags[0])
for mc in metric_flags[1:]:
    analysis_long = analysis_long.unionByName(attach_metric(mc))

# Now analysis_long has rows like:
# (targetId, diseaseId, feature, value, metric, phase_name, comparison='yes'/'no', prediction='yes'/'no')

# ---- 2) Count ALL 2×2 cells in one aggregation ----
cell_counts = (
    analysis_long
    .groupBy("metric","feature","value","phase_name","comparison","prediction")
    .agg(F.count("*").alias("a"))
)

# ---- 3) Reshape to 2×2 matrices per (metric, feature, value, phase) ----
# Create a compact 2-row frame with columns yes/no for prediction
mat_counts = (
    cell_counts
    .groupBy("metric","feature","value","phase_name","comparison")
    .pivot("prediction", ["yes","no"])
    .agg(F.first("a"))
    .fillna(0)
)

# We'll compute Fisher per group using a grouped map Pandas UDF
import pandas as pd
import numpy as np
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio
from pyspark.sql.functions import pandas_udf

# Output schema mirrors your previous 'schema'
result_schema = StructType([
    StructField("group",        StringType(),  True),  # metric name
    StructField("comparison",   StringType(),  True),  # e.g., "<value>_only"
    StructField("phase",        StringType(),  True),  # phase_name
    StructField("oddsRatio",    DoubleType(),  True),
    StructField("pValue",       DoubleType(),  True),
    StructField("lowerInterval",DoubleType(),  True),
    StructField("upperInterval",DoubleType(),  True),
    StructField("total",        StringType(),  True),
    StructField("values",       ArrayType(ArrayType(IntegerType())), True),
    StructField("relSuccess",   DoubleType(),  True),
    StructField("rsLower",      DoubleType(),  True),
    StructField("rsUpper",      DoubleType(),  True),
    StructField("path",         StringType(),  True),
])

# Relative success helper – reusing your function on driver won’t vectorize,
# so we replicate a simple version here. If you must use your exact math,
# import it and call it inside the pandas UDF.
def _relative_success(matrix_2x2: np.ndarray):
    # matrix rows: comparison yes/no; cols: prediction yes/no
    # success rate in comparison==yes vs comparison==no
    a, b = matrix_2x2[0,0], matrix_2x2[0,1]
    c, d = matrix_2x2[1,0], matrix_2x2[1,1]
    rate_yes = a / (a + b) if (a + b) > 0 else 0.0
    rate_no  = c / (c + d) if (c + d) > 0 else 0.0
    rs = rate_yes - rate_no
    # crude Wald CI for difference in proportions (you can swap with your exact function)
    import math
    se = 0.0
    if (a+b) > 0:
        se += rate_yes * (1 - rate_yes) / (a + b)
    if (c+d) > 0:
        se += rate_no  * (1 - rate_no)  / (c + d)
    se = math.sqrt(se)
    lo, hi = rs - 1.96*se, rs + 1.96*se
    return float(rs), float(lo), float(hi)

@pandas_udf(result_schema)
def fisher_by_group(pdf: pd.DataFrame) -> pd.DataFrame:
    # pdf has columns: metric, feature, value, phase_name, comparison, yes, no
    # build 2x2 with rows ordered comparison=['yes','no'] and cols ['yes','no']
    # Ensure both rows exist; fill missing with zeros
    sub = pdf[["comparison","yes","no"]].copy()
    sub = sub.set_index("comparison").reindex(["yes","no"]).fillna(0)
    mat = sub[["yes","no"]].to_numpy(dtype=int)

    total = int(mat.sum())
    or_val, p_val = fisher_exact(mat, alternative="two-sided")
    ci = odds_ratio(mat).confidence_interval(0.95)
    rs, rs_lo, rs_hi = _relative_success(mat)

    row = {
        "group":        pdf["metric"].iloc[0],
        "comparison":   f"{pdf['value'].iloc[0]}_only",  # to match your previous naming
        "phase":        pdf["phase_name"].iloc[0],
        "oddsRatio":    float(round(or_val, 2)),
        "pValue":       float(p_val),
        "lowerInterval":float(round(ci[0], 2)),
        "upperInterval":float(round(ci[1], 2)),
        "total":        str(total),
        "values":       mat.tolist(),
        "relSuccess":   float(round(rs, 2)),
        "rsLower":      float(round(rs_lo, 2)),
        "rsUpper":      float(round(rs_hi, 2)),
        "path":         "",   # you can fill a path pattern here if you still write per-combo parquet
    }
    return pd.DataFrame([row])

# Apply the grouped map per (metric, feature, value, phase)
results_df = (
    mat_counts
    .groupBy("metric","feature","value","phase_name")
    .apply(fisher_by_group)
)

# ---- 4) Optional spreadsheet formatting + annotate and export ----
from itertools import chain
from pyspark.sql.functions import create_map

# Your disdic: {value -> feature}; rebuild safely from agg_once if not present
disdic = {r["value"]: r["feature"] for r in agg_once.select("feature","value").distinct().collect()}

# If you still want to keep the regex-based prefix/suffix (works with our "value_only" naming):
patterns = ["_only", "_isRightTissueSignalAgreed"]
regex_pattern = "(" + "|".join(patterns) + ")"

df_fmt = (
    spreadSheetFormatter(results_df)  # reuse your helper
    .withColumn("prefix", F.regexp_replace(F.col("comparison"), regex_pattern + ".*", ""))
    .withColumn("suffix", F.regexp_extract(F.col("comparison"), regex_pattern, 0))
)

mapping_expr = create_map([F.lit(x) for x in chain(*disdic.items())])
df_annot = df_fmt.withColumn("annotation", mapping_expr.getItem(F.col("prefix")))

today_date = date.today().isoformat()
out_csv = f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue.csv"
df_annot.toPandas().to_csv(out_csv, index=False)
print(f"Analysis written: {out_csv}")


ValueError: Invalid udf: the udf argument must be a pandas_udf of type GROUPED_MAP.

### second try

In [None]:
# -*- coding: utf-8 -*-
# Single-script, loop-free PySpark job (tall/unpivot + single aggregation)

import os
from datetime import date
from functools import reduce

from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType, DoubleType, ArrayType
)

# Your helpers
from functions import (
    relative_success,
    spreadSheetFormatter,
    discrepancifier,
    temporary_directionOfEffect,
    buildColocData,
    gwasDataset,
)
from DoEAssessment import directionOfEffect  # noqa: F401  (kept if you need it later)

# -------------------------------
# Spark / YARN resource settings (Single-Node Option A)
# -------------------------------
driver_memory = "12g"                 # string with unit
executor_memory = "40g"               # string with unit (heap)
executor_cores = 10                   # int
num_executors = 1                     # int (one fat executor on single node)
executor_memory_overhead = "6g"       # string with unit (PySpark/Arrow/off-heap)
shuffle_partitions = 128              # int (~2–3x cores)
default_parallelism = 128             # int (match shuffle_partitions)

# If you later move to a multi-worker cluster, replace the values above.

spark = (
    SparkSession.builder
    .appName("MyOptimizedPySparkApp")
    .config("spark.master", "yarn")
    # core resources
    .config("spark.driver.memory", driver_memory)
    .config("spark.executor.memory", executor_memory)
    .config("spark.executor.cores", executor_cores)
    .config("spark.executor.instances", num_executors)
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead)
    # shuffle & parallelism
    .config("spark.sql.shuffle.partitions", shuffle_partitions)
    .config("spark.default.parallelism", default_parallelism)
    # adaptive query execution for better skew/partition sizing
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
)

print("SparkSession created with:")
for k in [
    "spark.driver.memory",
    "spark.executor.memory",
    "spark.executor.cores",
    "spark.executor.instances",
    "spark.yarn.executor.memoryOverhead",
    "spark.sql.shuffle.partitions",
    "spark.default.parallelism",
    "spark.sql.adaptive.enabled",
    "spark.sql.adaptive.coalescePartitions.enabled",
]:
    print(f"  {k}: {spark.conf.get(k)}")
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")



'''
# -------------------------------
# Spark / YARN resource settings
# -------------------------------
driver_memory = "16g"
executor_memory = "32g"
executor_cores = "8"
num_executors = "16"
executor_memory_overhead = "8g"
shuffle_partitions = "150"
default_parallelism = str(int(executor_cores) * int(num_executors) * 2)  # 80

spark = (
    SparkSession.builder
    .appName("MyOptimizedPySparkApp")
    .config("spark.master", "yarn")
    .config("spark.driver.memory", driver_memory)
    .config("spark.executor.memory", executor_memory)
    .config("spark.executor.cores", executor_cores)
    .config("spark.executor.instances", num_executors)
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead)
    .config("spark.sql.shuffle.partitions", shuffle_partitions)
    .config("spark.default.parallelism", default_parallelism)
    .getOrCreate()
)

print("SparkSession created with:")
for k in [
    "spark.driver.memory",
    "spark.executor.memory",
    "spark.executor.cores",
    "spark.executor.instances",
    "spark.yarn.executor.memoryOverhead",
    "spark.sql.shuffle.partitions",
    "spark.default.parallelism",
]:
    print(f"  {k}: {spark.conf.get(k)}")
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")
'''
# --------------------------------
# 0) Load inputs
# --------------------------------
path_n = "gs://open-targets-data-releases/25.06/output/"

target = spark.read.parquet(f"{path_n}target/")
diseases = spark.read.parquet(f"{path_n}disease/")
evidences = spark.read.parquet(f"{path_n}evidence")
credible = spark.read.parquet(f"{path_n}credible_set")
new = spark.read.parquet(f"{path_n}colocalisation_coloc")
index = spark.read.parquet(f"{path_n}study/")
variantIndex = spark.read.parquet(f"{path_n}variant")
biosample = spark.read.parquet(f"{path_n}biosample")
ecaviar = spark.read.parquet(f"{path_n}colocalisation_ecaviar")
all_coloc = ecaviar.unionByName(new, allowMissingColumns=True)
print("Loaded all base tables.")

# --------------------------------
# 1) Build coloc + GWAS dataset
# --------------------------------
newColoc = buildColocData(all_coloc, credible, index)
print("Built newColoc")

gwasComplete = gwasDataset(evidences, credible)
print("Built gwasComplete")

resolvedColoc = (
    newColoc.withColumnRenamed("geneId", "targetId")
    .join(
        gwasComplete.withColumnRenamed("studyLocusId", "leftStudyLocusId"),
        on=["leftStudyLocusId", "targetId"],
        how="inner",
    )
    .join(
        diseases.selectExpr("id as diseaseId", "name", "parents", "therapeuticAreas"),
        on="diseaseId",
        how="left",
    )
    .withColumn(
        "diseaseId",
        F.explode_outer(F.concat(F.array(F.col("diseaseId")), F.col("parents"))),
    )
    .drop("parents", "oldDiseaseId")
    .withColumn(
        "colocDoE",
        F.when(
            F.col("rightStudyType").isin(["eqtl", "pqtl", "tuqtl", "sceqtl", "sctuqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_protect"))
        ).when(
            F.col("rightStudyType").isin(["sqtl", "scsqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_protect"))
        ),
    )
)
print("Built resolvedColoc")

# --------------------------------
# 2) Direction of Effect & ChEMBL indication
# --------------------------------
datasource_filter = [
    "gwas_credible_sets",
    "gene_burden",
    "eva",
    "eva_somatic",
    "gene2phenotype",
    "orphanet",
    "cancer_gene_census",
    "intogen",
    "impc",
    "chembl",
]
assessment, evidences, actionType_unused, oncolabel_unused = temporary_directionOfEffect(path_n, datasource_filter)
print("Built temporary DoE datasets")

# (Optional) Add MoA to ChEMBL paths as in your later code
mecact_path = f"{path_n}drug_mechanism_of_action/"
mecact = spark.read.parquet(mecact_path)
actionType = (
    mecact.select(
        F.explode_outer("chemblIds").alias("drugId"),
        "actionType",
        "mechanismOfAction",
        "targets",
    )
    .select(
        F.explode_outer("targets").alias("targetId"),
        "drugId",
        "actionType",
        "mechanismOfAction",
    )
    .groupBy("targetId", "drugId")
    .agg(F.collect_set("actionType").alias("actionType2"))
    .withColumn("nMoA", F.size(F.col("actionType2")))
)

analysis_chembl_indication = (
    discrepancifier(
        assessment.filter(F.col("datasourceId") == "chembl")
        .join(actionType, on=["targetId", "drugId"], how="left")
        .withColumn(
            "maxClinPhase",
            F.max("clinicalPhase").over(Window.partitionBy("targetId", "diseaseId")),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase", "actionType2")
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    .drop("coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk")
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
)
print("Built analysis_chembl_indication")

# --------------------------------
# 3) Benchmark (filtered coloc) + clinical phase flags
# --------------------------------
resolvedColocFiltered = resolvedColoc.filter((F.col("clpp") >= 0.01) | (F.col("h4") >= 0.8))

negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .groupBy("targetId", "diseaseId").count()
    .withColumn("stopReason", F.lit("Negative")).drop("count")
)
benchmark = (
    resolvedColocFiltered.filter(F.col("name") != "COVID-19")
    .join(analysis_chembl_indication, on=["targetId", "diseaseId"], how="right")
    .withColumn(
        "AgreeDrug",
        F.when((F.col("drugGoF_protect").isNotNull()) & (F.col("colocDoE") == "GoF_protect"), "yes")
        .when((F.col("drugLoF_protect").isNotNull()) & (F.col("colocDoE") == "LoF_protect"), "yes")
        .otherwise("no"),
    )
    .join(biosample.select("biosampleId", "biosampleName"), on="biosampleId", how="left")
)

benchmark = (
    benchmark.join(F.broadcast(negativeTD), on=["targetId", "diseaseId"], how="left")
    .withColumn("PhaseT", F.when(F.col("stopReason") == "Negative", "yes").otherwise("no"))
    .withColumn("Phase>=4", F.when((F.col("maxClinPhase") == 4) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=3", F.when((F.col("maxClinPhase") >= 3) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=2", F.when((F.col("maxClinPhase") >= 2) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=1", F.when((F.col("maxClinPhase") >= 1) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
)

# --------------------------------
# 4) Replace nested loops:
#     compute DoE counts once → derive flags → unpivot → single aggregation
# --------------------------------
doe_cols = ["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]

# counts per colocDoE over the grouping you previously used in the loop
group_keys = [
    "targetId", "diseaseId", "maxClinPhase",
    "actionType2", "biosampleName", "projectId", "rightStudyType", "colocalisationMethod"
]

doe_counts = (
    benchmark.groupBy(*group_keys)
    .agg(*[F.sum(F.when(F.col("colocDoE") == c, 1).otherwise(0)).alias(c) for c in doe_cols])
)

# max name(s) (in case of ties) without arrays of structs
greatest_count = F.greatest(*[F.col(c) for c in doe_cols])
max_names = F.filter(
    F.array(*[F.when(F.col(c) == greatest_count, F.lit(c)) for c in doe_cols]),
    lambda x: x.isNotNull()
)

# presence of drug-side signals (equivalent to *_ch presence in your loop path)
has_lof_ch = F.col("drugLoF_protect").isNotNull()
has_gof_ch = F.col("drugGoF_protect").isNotNull()

test2 = (
    benchmark.select(*group_keys, "drugLoF_protect", "drugGoF_protect")
    .join(doe_counts, on=group_keys, how="left")
    .withColumn("NoneCellYes",
        F.when(has_lof_ch & (~has_gof_ch) & F.array_contains(max_names, F.lit("LoF_protect")), "yes")
         .when(has_gof_ch & (~has_lof_ch) & F.array_contains(max_names, F.lit("GoF_protect")), "yes")
         .otherwise("no")
    )
    .withColumn("NdiagonalYes",
        F.when(has_lof_ch & (~has_gof_ch) & (F.array_contains(max_names, F.lit("LoF_protect")) | F.array_contains(max_names, F.lit("GoF_risk"))), "yes")
         .when(has_gof_ch & (~has_lof_ch) & (F.array_contains(max_names, F.lit("GoF_protect")) | F.array_contains(max_names, F.lit("LoF_risk"))), "yes")
         .otherwise("no")
    )
    .withColumn("drugCoherency",
        F.when(has_lof_ch & ~has_gof_ch, "coherent")
         .when(~has_lof_ch & has_gof_ch, "coherent")
         .when(has_lof_ch & has_gof_ch, "dispar")
         .otherwise("other")
    ).withColumn(
    "hasGenetics2",
    F.when(
        reduce(lambda acc, c: acc & F.col(c).isNull(), doe_cols[1:], F.col(doe_cols[0]).isNull()),
        F.lit("no")
    ).otherwise(F.lit("yes"))
)
    .withColumn("hasGenetics", F.when(F.col("NdiagonalYes").isNotNull(), "yes").otherwise("no")) #### we have to change it
)
test2.persist()

# ---------- Guard: (re)build agg_once if not defined ----------
import pyspark.sql.functions as F

def _build_agg_once_from_test2_and_benchmark(test2, benchmark_df):
    # Columns we keep across all longified slices
    common_cols = [
        "targetId","diseaseId","maxClinPhase",
        "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
        "NoneCellYes","NdiagonalYes","hasGenetics2"  # note: hasGenetics2 from your test2
    ]

    # Join phase flags once
    phase_flags = (
        benchmark_df.select(
            "targetId","diseaseId","maxClinPhase",
            "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT"
        ).dropDuplicates(["targetId","diseaseId","maxClinPhase"])
    )

    t2_with_phase = test2.join(
        phase_flags, on=["targetId","diseaseId","maxClinPhase"], how="left"
    )

    # actionType2 is ARRAY<STRING> → explode
    long_action = (
        t2_with_phase
        .select(*common_cols, F.explode_outer("actionType2").alias("value"))
        .withColumn("feature", F.lit("actionType2"))
        .select(*common_cols, "feature", "value")
    )

    # helper for scalar columns
    def longify_scalar(colname: str):
        return (
            t2_with_phase
            .select(*common_cols, F.col(colname).alias("value"))
            .withColumn("feature", F.lit(colname))
            .select(*common_cols, "feature", "value")
        )

    long_biosample = longify_scalar("biosampleName")
    long_project   = longify_scalar("projectId")
    long_rstype    = longify_scalar("rightStudyType")
    long_colocm    = longify_scalar("colocalisationMethod")

    # union into one tall table
    long_features = (
        long_action
        .unionByName(long_biosample)
        .unionByName(long_project)
        .unionByName(long_rstype)
        .unionByName(long_colocm)
    ).filter(F.col("value").isNotNull())

    # single aggregation to compute flags
    agg_once_local = (
        long_features
        .groupBy(
            "targetId","diseaseId","maxClinPhase",
            "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
            "feature","value"
        )
        .agg(
            F.max(F.when(F.col("NoneCellYes")=="yes", 1).otherwise(0)).alias("NoneCellYes"),
            F.max(F.when(F.col("NdiagonalYes")=="yes", 1).otherwise(0)).alias("NdiagonalYes"),
            F.max(F.when(F.col("hasGenetics2")=="yes", 1).otherwise(0)).alias("hasGenetics"),
        )
        .selectExpr(
            "*",
            "CASE WHEN NoneCellYes=1 THEN 'yes' ELSE 'no' END as NoneCellYes_flag",
            "CASE WHEN NdiagonalYes=1 THEN 'yes' ELSE 'no' END as NdiagonalYes_flag",
            "CASE WHEN hasGenetics=1 THEN 'yes' ELSE 'no' END as hasGenetics_flag"
        )
    )
    return agg_once_local

if 'agg_once' not in globals():
    print("[info] agg_once not found — rebuilding it from test2/benchmark …")
    agg_once = _build_agg_once_from_test2_and_benchmark(test2, benchmark)
    print("[info] agg_once rebuilt.")


# ============================
# Denominator = ALL pairs in analysis_chembl_indication (deduped)
# Build 2x2 counts using totals, then Fisher via applyInPandas
# ============================
from datetime import date
import pandas as pd
import numpy as np
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType, IntegerType, ArrayType
)
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio

# ---- 0) Universe of pairs & phase flags (only de-dup, no other filtering)
universe = (
    analysis_chembl_indication
    .select("targetId", "diseaseId", "maxClinPhase")  # dedupe on these
    .distinct()
    .join(F.broadcast(negativeTD), on=["targetId","diseaseId"], how="left")
    .withColumn("PhaseT", F.when(F.col("stopReason")=="Negative", "yes").otherwise("no"))
    .withColumn("Phase>=4", F.when((F.col("maxClinPhase")==4) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=3", F.when((F.col("maxClinPhase")>=3) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=2", F.when((F.col("maxClinPhase")>=2) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=1", F.when((F.col("maxClinPhase")>=1) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
)
print('universe of pairs and phase flags built')

# Long view of phase flags for universe
phases_universe_long = universe.select(
    "targetId","diseaseId",
    F.expr("stack(5, "
           "'Phase>=4', `Phase>=4`, "
           "'Phase>=3', `Phase>=3`, "
           "'Phase>=2', `Phase>=2`, "
           "'Phase>=1', `Phase>=1`, "
           "'PhaseT',  `PhaseT`"
           ")").alias("phase_name","prediction")
)
print('phase_universe_long built')

# Totals per phase (denominator totals)
total_pairs_by_phase = (
    phases_universe_long
    .groupBy("phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("total_pairs"))
)
total_pred_yes_by_phase = (
    phases_universe_long
    .filter(F.col("prediction")=="yes")
    .groupBy("phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("total_pred_yes"))
)
print('phase_universe_long built')

# ---- 1) Build analysis_long from agg_once (flags) + phases (prediction)
# metrics we’ll analyze
metric_flags = ["NoneCellYes_flag", "NdiagonalYes_flag", "hasGenetics_flag"]

# phase flags per (target,disease,maxClinPhase)
phase_flags = (
    benchmark.select("targetId","diseaseId","maxClinPhase","Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT")
    .dropDuplicates(["targetId","diseaseId","maxClinPhase"])
)

# stack phases for the records present in agg_once (feature,value specific)
phases_long_for_records = (
    phase_flags.join(agg_once.select("targetId","diseaseId","maxClinPhase").dropDuplicates(),
                     on=["targetId","diseaseId","maxClinPhase"], how="inner")
    .select(
        "targetId","diseaseId","maxClinPhase",
        F.expr("stack(5, "
               "'Phase>=4', `Phase>=4`, "
               "'Phase>=3', `Phase>=3`, "
               "'Phase>=2', `Phase>=2`, "
               "'Phase>=1', `Phase>=1`, "
               "'PhaseT',  `PhaseT`"
               ")").alias("phase_name","prediction")
    )
)

def attach_metric(metric_col: str):
    # comparison = metric flag yes/no at (target,disease,feature,value)
    return (
        agg_once.select("targetId","diseaseId","maxClinPhase","feature","value",
                        F.col(metric_col).alias("comparison"))
        .join(phases_long_for_records, on=["targetId","diseaseId","maxClinPhase"], how="inner")
        .withColumn("metric", F.lit(metric_col.replace("_flag","")))  # prettier label
    )

analysis_long = attach_metric(metric_flags[0])
for mc in metric_flags[1:]:
    analysis_long = analysis_long.unionByName(attach_metric(mc))

# ---- 2) Count distinct pairs for 2x2 components using the fixed universe
# a = count of pairs with comparison=='yes' AND prediction=='yes'
yes_yes = (
    analysis_long
    .filter((F.col("comparison")=="yes") & (F.col("prediction")=="yes"))
    .groupBy("metric","feature","value","phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("a"))
)
# yes_total = count of pairs with comparison=='yes' (regardless of prediction)
yes_total = (
    analysis_long
    .filter(F.col("comparison")=="yes")
    .groupBy("metric","feature","value","phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("yes_total"))
)

# Assemble b,c,d from totals
counts = (
    yes_total
    .join(yes_yes, on=["metric","feature","value","phase_name"], how="left")
    .join(total_pairs_by_phase, on="phase_name", how="left")
    .join(total_pred_yes_by_phase, on="phase_name", how="left")
    .na.fill({"a":0})
    .withColumn("b", F.col("yes_total") - F.col("a"))
    .withColumn("c", F.col("total_pred_yes") - F.col("a"))
    .withColumn("d", F.col("total_pairs") - F.col("a") - F.col("b") - F.col("c"))
    .select(
        "metric","feature","value","phase_name",
        F.when(F.col("a")<0,0).otherwise(F.col("a")).cast("int").alias("a"),
        F.when(F.col("b")<0,0).otherwise(F.col("b")).cast("int").alias("b"),
        F.when(F.col("c")<0,0).otherwise(F.col("c")).cast("int").alias("c"),
        F.when(F.col("d")<0,0).otherwise(F.col("d")).cast("int").alias("d"),
        "total_pairs","total_pred_yes"
    )
)

# Convert to two-row format (comparison yes/no) with columns yes/no → ready for Fisher
mat_counts = (
    counts
    .select("metric","feature","value","phase_name",
            F.lit("yes").alias("comparison"),
            F.col("a").alias("yes"),
            F.col("b").alias("no"))
    .unionByName(
        counts.select("metric","feature","value","phase_name",
                      F.lit("no").alias("comparison"),
                      F.col("c").alias("yes"),
                      F.col("d").alias("no"))
    )
)

# Safety: ensure ints and no nulls
mat_counts = (
    mat_counts.fillna(0)
              .withColumn("yes", F.col("yes").cast("int"))
              .withColumn("no",  F.col("no").cast("int"))
)

# ---- 3) Fisher per group with applyInPandas
result_schema = StructType([
    StructField("group",        StringType(),  True),
    StructField("comparison",   StringType(),  True),
    StructField("phase",        StringType(),  True),
    StructField("oddsRatio",    DoubleType(),  True),
    StructField("pValue",       DoubleType(),  True),
    StructField("lowerInterval",DoubleType(),  True),
    StructField("upperInterval",DoubleType(),  True),
    StructField("total",        StringType(),  True),
    StructField("values",       ArrayType(ArrayType(IntegerType())), True),
    StructField("relSuccess",   DoubleType(),  True),
    StructField("rsLower",      DoubleType(),  True),
    StructField("rsUpper",      DoubleType(),  True),
    StructField("path",         StringType(),  True),
])

def _relative_success(matrix_2x2: np.ndarray):
    a, b = matrix_2x2[0,0], matrix_2x2[0,1]
    c, d = matrix_2x2[1,0], matrix_2x2[1,1]
    rate_yes = a / (a + b) if (a + b) > 0 else 0.0
    rate_no  = c / (c + d) if (c + d) > 0 else 0.0
    rs = rate_yes - rate_no
    import math
    se = 0.0
    if (a+b) > 0:
        se += rate_yes * (1 - rate_yes) / (a + b)
    if (c+d) > 0:
        se += rate_no  * (1 - rate_no)  / (c + d)
    se = math.sqrt(se)
    lo, hi = rs - 1.96*se, rs + 1.96*se
    return float(rs), float(lo), float(hi)

def fisher_by_group(pdf: pd.DataFrame) -> pd.DataFrame:
    sub = pdf[["comparison","yes","no"]].copy()
    sub = sub.set_index("comparison").reindex(["yes","no"]).fillna(0)
    mat = sub[["yes","no"]].to_numpy(dtype=int)

    total = int(mat.sum())
    if total == 0:
        # Return a neutral row to avoid SciPy errors on empty tables
        return pd.DataFrame([{
            "group":        pdf["metric"].iloc[0],
            "comparison":   f"{pdf['value'].iloc[0]}_only",
            "phase":        pdf["phase_name"].iloc[0],
            "oddsRatio":    float("nan"),
            "pValue":       float("nan"),
            "lowerInterval":float("nan"),
            "upperInterval":float("nan"),
            "total":        "0",
            "values":       mat.tolist(),
            "relSuccess":   float("nan"),
            "rsLower":      float("nan"),
            "rsUpper":      float("nan"),
            "path":         "",
        }])

    from scipy.stats import fisher_exact
    from scipy.stats.contingency import odds_ratio

    or_val, p_val = fisher_exact(mat, alternative="two-sided")
    ci = odds_ratio(mat).confidence_interval(0.95)
    rs, rs_lo, rs_hi = _relative_success(mat)

    return pd.DataFrame([{
        "group":        pdf["metric"].iloc[0],
        "comparison":   f"{pdf['value'].iloc[0]}_only",
        "phase":        pdf["phase_name"].iloc[0],
        "oddsRatio":    round(float(or_val), 2),
        "pValue":       float(p_val),
        "lowerInterval":round(float(ci[0]), 2),
        "upperInterval":round(float(ci[1]), 2),
        "total":        str(total),
        "values":       mat.tolist(),
        "relSuccess":   round(float(rs), 2),
        "rsLower":      round(float(rs_lo), 2),
        "rsUpper":      round(float(rs_hi), 2),
        "path":         "",
    }])

# (optional) Arrow for speed
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

results_df = (
    mat_counts
    .groupBy("metric","feature","value","phase_name")
    .applyInPandas(fisher_by_group, schema=result_schema)
)

# ---- 4) Spreadsheet formatting + annotation + CSV
from itertools import chain
from pyspark.sql.functions import create_map

# build disdic from agg_once
disdic = {r["value"]: r["feature"] for r in agg_once.select("feature","value").distinct().collect()}

patterns = ["_only", "_isRightTissueSignalAgreed"]
regex_pattern = "(" + "|".join(patterns) + ")"

df_fmt = (
    spreadSheetFormatter(results_df)
    .withColumn("prefix", F.regexp_replace(F.col("comparison"), regex_pattern + ".*", ""))
    .withColumn("suffix", F.regexp_extract(F.col("comparison"), regex_pattern, 0))
)

mapping_expr = create_map([F.lit(x) for x in chain(*disdic.items())])
df_annot = df_fmt.withColumn("annotation", mapping_expr.getItem(F.col("prefix")))

#today_date = date.today().isoformat()
#out_csv = f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try.csv"
#df_annot.toPandas().to_csv(out_csv, index=False)
#print(f"Analysis written: {out_csv}")
today_date = date.today().isoformat()

(out_path, coalesce_n) = (f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try", 1)
print('preparing df_annot to write')
(df_annot
  .coalesce(coalesce_n)                 # 1 file if feasible; increase if OOM on shuffle
  .write.mode("overwrite")
  .option("header", "true")
  .csv(out_path))

print(f"Wrote CSV shards to: {out_path}")


spark session created at 2025-09-17 14:54:25.328103
Analysis started on 2025-09-17 at  2025-09-17 14:54:25.328103


25/09/17 14:54:27 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
25/09/17 14:54:27 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


SparkSession created with:
  spark.driver.memory: 12g
  spark.executor.memory: 40g
  spark.executor.cores: 10
  spark.executor.instances: 1
  spark.yarn.executor.memoryOverhead: 6g
  spark.sql.shuffle.partitions: 128
  spark.default.parallelism: 128
  spark.sql.adaptive.enabled: true
  spark.sql.adaptive.coalescePartitions.enabled: true
Spark UI: http://jr-temp-doe-m.c.open-targets-eu-dev.internal:42205
Loaded all base tables.
Built newColoc


                                                                                

loaded gwasComplete
Built gwasComplete
Built resolvedColoc
Built temporary DoE datasets


                                                                                

Built analysis_chembl_indication
[info] agg_once not found — rebuilding it from test2/benchmark …
[info] agg_once rebuilt.
universe of pairs and phase flags built
phase_universe_long built
phase_universe_long built


                                                                                

importing functions
imported functions




NameError: name 'today_date' is not defined

In [None]:
today_date = date.today().isoformat()

(out_path, coalesce_n) = (f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try", 1)
print('preparing df_annot to write')
(df_annot
  .coalesce(coalesce_n)                 # 1 file if feasible; increase if OOM on shuffle
  .write.mode("overwrite")
  .option("header", "true")
  .csv(out_path))

print(f"Wrote CSV shards to: {out_path}")

preparing df_annot to write


In [None]:
today_date = date.today().isoformat()
out_csv = f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try.csv"
df_annot.toPandas().to_csv(out_csv, index=False)
print(f"Analysis written: {out_csv}")


25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_125 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_215_36 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_101 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_112 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_63 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_115 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_34 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_215_97 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_138_98 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_215_13 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint:

In [None]:
today_date = date.today().isoformat()
out_csv = f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try.csv"
df_annot.toPandas().to_csv(out_csv, index=False)
print(f"Analysis written: {out_csv}")


In [4]:
(analysis_long
    .filter(
        (F.col("metric") == "NoneCellYes") &
        (F.col("value") == "Alasoo_2018")   # before we added "_only" suffix
    ))

ConnectionRefusedError: [Errno 111] Connection refused

In [2]:
# Example: show rows for metric/group = 'NoneCellYes' AND value = 'Alasoo_2018'
check_df = (
    analysis_long
    .filter(
        (F.col("metric") == "NoneCellYes") &
        (F.col("value") == "Alasoo_2018")   # before we added "_only" suffix
    )
)

print("Count of rows that will enter Fisher 2×2 for Alasoo_2018 / NoneCellYes:", check_df.count())
check_df.show(50, truncate=False)


25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_101 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_112 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_63 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_34 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_49 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_138_98 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_53 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_138_10 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_138_31 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_59 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: 

Count of rows that will enter Fisher 2×2 for Alasoo_2018 / NoneCellYes: 285


25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_125 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_73 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_63 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_34 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_215_98 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_215_97 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_138_98 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_215_13 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_53 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_138_10 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: 

Py4JError: An error occurred while calling o2989.showString

### desglossing the analysis

In [3]:
# -*- coding: utf-8 -*-
# Single-script, loop-free PySpark job (tall/unpivot + single aggregation)

import os
from datetime import date
from functools import reduce

from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType, DoubleType, ArrayType
)

# Your helpers
from functions import (
    relative_success,
    spreadSheetFormatter,
    discrepancifier,
    temporary_directionOfEffect,
    buildColocData,
    gwasDataset,
)
from DoEAssessment import directionOfEffect  # noqa: F401  (kept if you need it later)

# -------------------------------
# Spark / YARN resource settings (Single-Node Option A)
# -------------------------------
driver_memory = "12g"                 # string with unit
executor_memory = "40g"               # string with unit (heap)
executor_cores = 10                   # int
num_executors = 1                     # int (one fat executor on single node)
executor_memory_overhead = "6g"       # string with unit (PySpark/Arrow/off-heap)
shuffle_partitions = 128              # int (~2–3x cores)
default_parallelism = 128             # int (match shuffle_partitions)

# If you later move to a multi-worker cluster, replace the values above.

spark = (
    SparkSession.builder
    .appName("MyOptimizedPySparkApp")
    .config("spark.master", "yarn")
    # core resources
    .config("spark.driver.memory", driver_memory)
    .config("spark.executor.memory", executor_memory)
    .config("spark.executor.cores", executor_cores)
    .config("spark.executor.instances", num_executors)
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead)
    # shuffle & parallelism
    .config("spark.sql.shuffle.partitions", shuffle_partitions)
    .config("spark.default.parallelism", default_parallelism)
    # adaptive query execution for better skew/partition sizing
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
)

print("SparkSession created with:")
for k in [
    "spark.driver.memory",
    "spark.executor.memory",
    "spark.executor.cores",
    "spark.executor.instances",
    "spark.yarn.executor.memoryOverhead",
    "spark.sql.shuffle.partitions",
    "spark.default.parallelism",
    "spark.sql.adaptive.enabled",
    "spark.sql.adaptive.coalescePartitions.enabled",
]:
    print(f"  {k}: {spark.conf.get(k)}")
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")



'''
# -------------------------------
# Spark / YARN resource settings
# -------------------------------
driver_memory = "16g"
executor_memory = "32g"
executor_cores = "8"
num_executors = "16"
executor_memory_overhead = "8g"
shuffle_partitions = "150"
default_parallelism = str(int(executor_cores) * int(num_executors) * 2)  # 80

spark = (
    SparkSession.builder
    .appName("MyOptimizedPySparkApp")
    .config("spark.master", "yarn")
    .config("spark.driver.memory", driver_memory)
    .config("spark.executor.memory", executor_memory)
    .config("spark.executor.cores", executor_cores)
    .config("spark.executor.instances", num_executors)
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead)
    .config("spark.sql.shuffle.partitions", shuffle_partitions)
    .config("spark.default.parallelism", default_parallelism)
    .getOrCreate()
)

print("SparkSession created with:")
for k in [
    "spark.driver.memory",
    "spark.executor.memory",
    "spark.executor.cores",
    "spark.executor.instances",
    "spark.yarn.executor.memoryOverhead",
    "spark.sql.shuffle.partitions",
    "spark.default.parallelism",
]:
    print(f"  {k}: {spark.conf.get(k)}")
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")
'''
# --------------------------------
# 0) Load inputs
# --------------------------------
path_n = "gs://open-targets-data-releases/25.06/output/"

target = spark.read.parquet(f"{path_n}target/")
diseases = spark.read.parquet(f"{path_n}disease/")
evidences = spark.read.parquet(f"{path_n}evidence")
credible = spark.read.parquet(f"{path_n}credible_set")
new = spark.read.parquet(f"{path_n}colocalisation_coloc")
index = spark.read.parquet(f"{path_n}study/")
variantIndex = spark.read.parquet(f"{path_n}variant")
biosample = spark.read.parquet(f"{path_n}biosample")
ecaviar = spark.read.parquet(f"{path_n}colocalisation_ecaviar")
all_coloc = ecaviar.unionByName(new, allowMissingColumns=True)
print("Loaded all base tables.")

# --------------------------------
# 1) Build coloc + GWAS dataset
# --------------------------------
newColoc = buildColocData(all_coloc, credible, index)
print("Built newColoc")

gwasComplete = gwasDataset(evidences, credible)
print("Built gwasComplete")

resolvedColoc = (
    newColoc.withColumnRenamed("geneId", "targetId")
    .join(
        gwasComplete.withColumnRenamed("studyLocusId", "leftStudyLocusId"),
        on=["leftStudyLocusId", "targetId"],
        how="inner",
    )
    .join(
        diseases.selectExpr("id as diseaseId", "name", "parents", "therapeuticAreas"),
        on="diseaseId",
        how="left",
    )
    .withColumn(
        "diseaseId",
        F.explode_outer(F.concat(F.array(F.col("diseaseId")), F.col("parents"))),
    )
    .drop("parents", "oldDiseaseId")
    .withColumn(
        "colocDoE",
        F.when(
            F.col("rightStudyType").isin(["eqtl", "pqtl", "tuqtl", "sceqtl", "sctuqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_protect"))
        ).when(
            F.col("rightStudyType").isin(["sqtl", "scsqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_protect"))
        ),
    )
)
print("Built resolvedColoc")

# --------------------------------
# 2) Direction of Effect & ChEMBL indication
# --------------------------------
datasource_filter = [
    "gwas_credible_sets",
    "gene_burden",
    "eva",
    "eva_somatic",
    "gene2phenotype",
    "orphanet",
    "cancer_gene_census",
    "intogen",
    "impc",
    "chembl",
]
assessment, evidences, actionType_unused, oncolabel_unused = temporary_directionOfEffect(path_n, datasource_filter)
print("Built temporary DoE datasets")

# (Optional) Add MoA to ChEMBL paths as in your later code
mecact_path = f"{path_n}drug_mechanism_of_action/"
mecact = spark.read.parquet(mecact_path)
actionType = (
    mecact.select(
        F.explode_outer("chemblIds").alias("drugId"),
        "actionType",
        "mechanismOfAction",
        "targets",
    )
    .select(
        F.explode_outer("targets").alias("targetId"),
        "drugId",
        "actionType",
        "mechanismOfAction",
    )
    .groupBy("targetId", "drugId")
    .agg(F.collect_set("actionType").alias("actionType2"))
    .withColumn("nMoA", F.size(F.col("actionType2")))
)

analysis_chembl_indication = (
    discrepancifier(
        assessment.filter(F.col("datasourceId") == "chembl")
        .join(actionType, on=["targetId", "drugId"], how="left")
        .withColumn(
            "maxClinPhase",
            F.max("clinicalPhase").over(Window.partitionBy("targetId", "diseaseId")),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase", "actionType2")
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    .drop("coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk")
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
)
print("Built analysis_chembl_indication")

# --------------------------------
# 3) Benchmark (filtered coloc) + clinical phase flags
# --------------------------------
resolvedColocFiltered = resolvedColoc.filter((F.col("clpp") >= 0.01) | (F.col("h4") >= 0.8))

negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .groupBy("targetId", "diseaseId").count()
    .withColumn("stopReason", F.lit("Negative")).drop("count")
)
benchmark = (
    resolvedColocFiltered.filter(F.col("name") != "COVID-19")
    .join(analysis_chembl_indication, on=["targetId", "diseaseId"], how="right")
    .withColumn(
        "AgreeDrug",
        F.when((F.col("drugGoF_protect").isNotNull()) & (F.col("colocDoE") == "GoF_protect"), "yes")
        .when((F.col("drugLoF_protect").isNotNull()) & (F.col("colocDoE") == "LoF_protect"), "yes")
        .otherwise("no"),
    )
    .join(biosample.select("biosampleId", "biosampleName"), on="biosampleId", how="left")
)

benchmark = (
    benchmark.join(F.broadcast(negativeTD), on=["targetId", "diseaseId"], how="left")
    .withColumn("PhaseT", F.when(F.col("stopReason") == "Negative", "yes").otherwise("no"))
    .withColumn("Phase>=4", F.when((F.col("maxClinPhase") == 4) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=3", F.when((F.col("maxClinPhase") >= 3) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=2", F.when((F.col("maxClinPhase") >= 2) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=1", F.when((F.col("maxClinPhase") >= 1) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
)

# --------------------------------
# 4) Replace nested loops:
#     compute DoE counts once → derive flags → unpivot → single aggregation
# --------------------------------
doe_cols = ["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]

# counts per colocDoE over the grouping you previously used in the loop
group_keys = [
    "targetId", "diseaseId", "maxClinPhase",
    "actionType2", "biosampleName", "projectId", "rightStudyType", "colocalisationMethod"
]

doe_counts = (
    benchmark.groupBy(*group_keys)
    .agg(*[F.sum(F.when(F.col("colocDoE") == c, 1).otherwise(0)).alias(c) for c in doe_cols])
)

# max name(s) (in case of ties) without arrays of structs
greatest_count = F.greatest(*[F.col(c) for c in doe_cols])
max_names = F.filter(
    F.array(*[F.when(F.col(c) == greatest_count, F.lit(c)) for c in doe_cols]),
    lambda x: x.isNotNull()
)

# presence of drug-side signals (equivalent to *_ch presence in your loop path)
has_lof_ch = F.col("drugLoF_protect").isNotNull()
has_gof_ch = F.col("drugGoF_protect").isNotNull()




SparkSession created with:
  spark.driver.memory: 12g
  spark.executor.memory: 40g
  spark.executor.cores: 10
  spark.executor.instances: 1
  spark.yarn.executor.memoryOverhead: 6g
  spark.sql.shuffle.partitions: 128
  spark.default.parallelism: 128
  spark.sql.adaptive.enabled: true
  spark.sql.adaptive.coalescePartitions.enabled: true
Spark UI: http://jr-temp-doe-m.c.open-targets-eu-dev.internal:37033
Loaded all base tables.
Built newColoc


                                                                                

loaded gwasComplete
Built gwasComplete
Built resolvedColoc


25/09/17 12:50:13 WARN CacheManager: Asked to cache already cached data.
25/09/17 12:50:14 WARN CacheManager: Asked to cache already cached data.


Built temporary DoE datasets


25/09/17 12:50:15 WARN CacheManager: Asked to cache already cached data.
25/09/17 12:50:15 WARN CacheManager: Asked to cache already cached data.


Built analysis_chembl_indication


In [4]:
# --- prerequisites used below ---
# doe_cols = ["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]

# Recompute safe maxima + names (handles nulls as 0)
safe_max = F.greatest(*[F.coalesce(F.col(c), F.lit(0)) for c in doe_cols])
max_names = F.filter(
    F.array(*[F.when(F.coalesce(F.col(c), F.lit(0)) == safe_max, F.lit(c)) for c in doe_cols]),
    lambda x: x.isNotNull()
)
max_names_set = F.array_sort(F.array_distinct(max_names))

# Define coherent cross-pairs for GWAS/coloc DoE ties (order-independent)
pair1 = F.array_sort(F.array(F.lit("GoF_protect"), F.lit("LoF_risk")))
pair2 = F.array_sort(F.array(F.lit("LoF_protect"), F.lit("GoF_risk")))

# GWAS/coloc is comparable if: single maximum OR exactly one of the coherent pairs
gwasComparable = (
    (F.size(max_names_set) == 1) |
    ((F.size(max_names_set) == 2) & ((max_names_set == pair1) | (max_names_set == pair2)))
)

# Drug is comparable if exactly one of the two protect signals is present
has_lof_ch = F.col("drugLoF_protect").isNotNull()
has_gof_ch = F.col("drugGoF_protect").isNotNull()
drugComparable = (has_lof_ch != has_gof_ch)   # XOR in Spark

# Keep your existing drugCoherency label (optional; unchanged)
drugCoherencyCol = (
    F.when(has_lof_ch & ~has_gof_ch, "coherent")
     .when(~has_lof_ch & has_gof_ch, "coherent")
     .when(has_lof_ch & has_gof_ch, "dispar")
     .otherwise("other")
)

# Apply the “compare only if both sides are coherent” rule to the flags
test2 = (
    benchmark.select(*group_keys, "drugLoF_protect", "drugGoF_protect")
    .join(doe_counts, on=group_keys, how="left")
    .withColumn("gwasComparable", gwasComparable)
    .withColumn("drugComparable", drugComparable)
    .withColumn(
        "NoneCellYes",
        F.when(
            gwasComparable & drugComparable &
            has_lof_ch & F.array_contains(max_names, F.lit("LoF_protect")),
            "yes"
        )
        .when(
            gwasComparable & drugComparable &
            has_gof_ch & F.array_contains(max_names, F.lit("GoF_protect")),
            "yes"
        )
        .otherwise("no")
    )
    .withColumn(
        "NdiagonalYes",
        F.when(
            gwasComparable & drugComparable &
            has_lof_ch &
            (F.array_contains(max_names, F.lit("LoF_protect")) | F.array_contains(max_names, F.lit("GoF_risk"))),
            "yes"
        )
        .when(
            gwasComparable & drugComparable &
            has_gof_ch &
            (F.array_contains(max_names, F.lit("GoF_protect")) | F.array_contains(max_names, F.lit("LoF_risk"))),
            "yes"
        )
        .otherwise("no")
    )
    .withColumn("drugCoherency", drugCoherencyCol)
    .withColumn(
        "hasGenetics2",
        F.when(
            reduce(lambda acc, c: acc & F.col(c).isNull(), doe_cols[1:], F.col(doe_cols[0]).isNull()),
            F.lit("no")
        ).otherwise(F.lit("yes"))
    )
    # If you want hasGenetics to reflect hasGenetics2, set it directly; otherwise keep your placeholder:
    .withColumn("hasGenetics", F.col("hasGenetics2"))
)


In [5]:
test2.persist()

DataFrame[targetId: string, diseaseId: string, maxClinPhase: double, actionType2: array<string>, biosampleName: string, projectId: string, rightStudyType: string, colocalisationMethod: string, drugLoF_protect: bigint, drugGoF_protect: bigint, LoF_protect: bigint, GoF_risk: bigint, LoF_risk: bigint, GoF_protect: bigint, gwasComparable: boolean, drugComparable: boolean, NoneCellYes: string, NdiagonalYes: string, drugCoherency: string, hasGenetics2: string, hasGenetics: string]

In [7]:
test2.filter(F.col('NoneCellYes')=='yes').count()

15854

In [None]:

# --- Max DoE names (safe to nulls) + coherency annotation only ---
# If you already have `doe_cols` defined:
# doe_cols = ["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]

# 1) Recompute maxima safely (treat nulls as 0) and collect all tied names
safe_max = F.greatest(*[F.coalesce(F.col(c), F.lit(0)) for c in doe_cols])
max_names = F.filter(
    F.array(*[F.when(F.coalesce(F.col(c), F.lit(0)) == safe_max, F.lit(c)) for c in doe_cols]),
    lambda x: x.isNotNull()
)

# 2) Build a sorted distinct array to compare sets ignoring order
max_names_set = F.array_sort(F.array_distinct(max_names))

# 3) Define the only coherent cross-pairs
pair1 = F.array_sort(F.array(F.lit("GoF_protect"), F.lit("LoF_risk")))
pair2 = F.array_sort(F.array(F.lit("LoF_protect"), F.lit("GoF_risk")))

# 4) Annotate coherency of DoE maxima (NO filtering, just a label)
test2 = (
    benchmark.select(*group_keys, "drugLoF_protect", "drugGoF_protect")
    .join(doe_counts, on=group_keys, how="left")
    .withColumn(
        "maxDoECoherency",
        F.when(F.size(max_names_set) == 1, F.lit("single"))
         .when(
             (F.size(max_names_set) == 2) &
             ((max_names_set == pair1) | (max_names_set == pair2)),
             F.lit("coherent")
         )
         .when(F.size(max_names_set) >= 2, F.lit("incoherent"))   # includes size 3 or 4
         .otherwise(F.lit("single"))
    )
)



test2 = (
    benchmark.select(*group_keys, "drugLoF_protect", "drugGoF_protect")
    .join(doe_counts, on=group_keys, how="left")
    .withColumn("NoneCellYes",
        F.when(has_lof_ch & (~has_gof_ch) & F.array_contains(max_names, F.lit("LoF_protect")), "yes")
         .when(has_gof_ch & (~has_lof_ch) & F.array_contains(max_names, F.lit("GoF_protect")), "yes")
         .otherwise("no")
    )
    .withColumn("NdiagonalYes",
        F.when(has_lof_ch & (~has_gof_ch) & (F.array_contains(max_names, F.lit("LoF_protect")) | F.array_contains(max_names, F.lit("GoF_risk"))), "yes")
         .when(has_gof_ch & (~has_lof_ch) & (F.array_contains(max_names, F.lit("GoF_protect")) | F.array_contains(max_names, F.lit("LoF_risk"))), "yes")
         .otherwise("no")
    )
    .withColumn("drugCoherency",
        F.when(has_lof_ch & ~has_gof_ch, "coherent")
         .when(~has_lof_ch & has_gof_ch, "coherent")
         .when(has_lof_ch & has_gof_ch, "dispar")
         .otherwise("other")
    ).withColumn(
    "hasGenetics2",
    F.when(
        reduce(lambda acc, c: acc & F.col(c).isNull(), doe_cols[1:], F.col(doe_cols[0]).isNull()),
        F.lit("no")
    ).otherwise(F.lit("yes"))
)
    .withColumn("hasGenetics", F.when(F.col("NdiagonalYes").isNotNull(), "yes").otherwise("no")) #### we have to change it
)
test2.persist()

In [10]:
# ---------- Guard: (re)build agg_once if not defined ----------
import pyspark.sql.functions as F

def _build_agg_once_from_test2_and_benchmark(test2_df, benchmark_df):
    # Columns we keep across all longified slices
    common_cols = [
        "targetId","diseaseId","maxClinPhase",
        "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
        "NoneCellYes","NdiagonalYes","hasGenetics2",
        # (optional diagnostics if you want them downstream)
        # "gwasComparable","drugComparable","maxDoECoherency"
    ]

    # Join phase flags once (LEFT join is correct here)
    phase_flags = (
        benchmark_df.select(
            "targetId","diseaseId","maxClinPhase",
            "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT"
        ).dropDuplicates(["targetId","diseaseId","maxClinPhase"])
    )

    # Only rows present in test2 matter for feature/value tall view
    t2_with_phase = (
        test2_df
        .join(phase_flags, on=["targetId","diseaseId","maxClinPhase"], how="right")
        # make flags explicit; null → "no"
        .withColumn("NoneCellYes",  F.coalesce(F.col("NoneCellYes"),  F.lit("no")))
        .withColumn("NdiagonalYes", F.coalesce(F.col("NdiagonalYes"), F.lit("no")))
        .withColumn("hasGenetics2", F.coalesce(F.col("hasGenetics2"), F.lit("no")))
    )

    # actionType2 is ARRAY<STRING> → explode
    long_action = (
        t2_with_phase
        .select(*common_cols, F.explode_outer("actionType2").alias("value"))
        .withColumn("feature", F.lit("actionType2"))
        .select(*common_cols, "feature", "value")
    )

    # helper for scalar columns
    def longify_scalar(colname: str):
        return (
            t2_with_phase
            .select(*common_cols, F.col(colname).alias("value"))
            .withColumn("feature", F.lit(colname))
            .select(*common_cols, "feature", "value")
        )

    long_biosample = longify_scalar("biosampleName")
    long_project   = longify_scalar("projectId")
    long_rstype    = longify_scalar("rightStudyType")
    long_colocm    = longify_scalar("colocalisationMethod")

    # union into one tall table (drop value=null)
    long_features = (
        long_action
        .unionByName(long_biosample)
        .unionByName(long_project)
        .unionByName(long_rstype)
        .unionByName(long_colocm)
        #.filter(F.col("value").isNotNull())
    )

    # single aggregation to compute flags
    agg_once_local = (
        long_features
        .groupBy(
            "targetId","diseaseId","maxClinPhase",
            "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
            "feature","value"
        )
        .agg(
            F.max(F.when(F.col("NoneCellYes")  == "yes", 1).otherwise(0)).alias("NoneCellYes"),
            F.max(F.when(F.col("NdiagonalYes") == "yes", 1).otherwise(0)).alias("NdiagonalYes"),
            F.max(F.when(F.col("hasGenetics2") == "yes", 1).otherwise(0)).alias("hasGenetics"),
        )
        .selectExpr(
            "*",
            "CASE WHEN NoneCellYes=1  THEN 'yes' ELSE 'no' END as NoneCellYes_flag",
            "CASE WHEN NdiagonalYes=1 THEN 'yes' ELSE 'no' END as NdiagonalYes_flag",
            "CASE WHEN hasGenetics=1  THEN 'yes' ELSE 'no' END as hasGenetics_flag"
        )
    )
    return agg_once_local

if 'agg_once' not in globals():
    print("[info] agg_once not found — rebuilding it from test2/benchmark …")
    agg_once = _build_agg_once_from_test2_and_benchmark(test2, benchmark)
    print("[info] agg_once rebuilt.")


In [12]:
agg_once.limit(5).show()



ERROR:root:KeyboardInterrupt while sending command.0][Stage 200:=>(9 + 5) / 17] ]
Traceback (most recent call last):
  File "/usr/lib/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/miniconda3/lib/python3.11/socket.py", line 706, in readinto
    return self._sock.recv_into(b)
           ^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt


KeyboardInterrupt: 

25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_125 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_267_90 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_101 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_192_57 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_192_59 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_59 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_267_79 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_34 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_267_30 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_157 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: N

In [3]:
test2.filter(F.col('NoneCellYes')=='yes').show()

+---------------+-----------+------------+--------------------+--------------------+----------+--------------+--------------------+---------------+---------------+-----------+--------+--------+-----------+-----------+------------+-------------+------------+-----------+
|       targetId|  diseaseId|maxClinPhase|         actionType2|       biosampleName| projectId|rightStudyType|colocalisationMethod|drugLoF_protect|drugGoF_protect|LoF_protect|GoF_risk|LoF_risk|GoF_protect|NoneCellYes|NdiagonalYes|drugCoherency|hasGenetics2|hasGenetics|
+---------------+-----------+------------+--------------------+--------------------+----------+--------------+--------------------+---------------+---------------+-----------+--------+--------+-----------+-----------+------------+-------------+------------+-----------+
|ENSG00000164116|EFO_0000537|         4.0|[POSITIVE ALLOSTE...|dorsolateral pref...|CommonMind|          eqtl|             eCAVIAR|           NULL|              1|          0|       0|      

In [None]:

# ---------- Guard: (re)build agg_once if not defined ----------
import pyspark.sql.functions as F

def _build_agg_once_from_test2_and_benchmark(test2, benchmark_df):
    # Columns we keep across all longified slices
    common_cols = [
        "targetId","diseaseId","maxClinPhase",
        "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
        "NoneCellYes","NdiagonalYes","hasGenetics2"  # note: hasGenetics2 from your test2
    ]

    # Join phase flags once
    phase_flags = (
        benchmark_df.select(
            "targetId","diseaseId","maxClinPhase",
            "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT"
        ).dropDuplicates(["targetId","diseaseId","maxClinPhase"])
    )

    t2_with_phase = test2.join(
        phase_flags, on=["targetId","diseaseId","maxClinPhase"], how="right" #### maybe this should be right
    )

    # actionType2 is ARRAY<STRING> → explode
    long_action = (
        t2_with_phase
        .select(*common_cols, F.explode_outer("actionType2").alias("value"))
        .withColumn("feature", F.lit("actionType2"))
        .select(*common_cols, "feature", "value")
    )

    # helper for scalar columns
    def longify_scalar(colname: str):
        return (
            t2_with_phase
            .select(*common_cols, F.col(colname).alias("value"))
            .withColumn("feature", F.lit(colname))
            .select(*common_cols, "feature", "value")
        )

    long_biosample = longify_scalar("biosampleName")
    long_project   = longify_scalar("projectId")
    long_rstype    = longify_scalar("rightStudyType")
    long_colocm    = longify_scalar("colocalisationMethod")

    # union into one tall table
    long_features = (
        long_action
        .unionByName(long_biosample)
        .unionByName(long_project)
        .unionByName(long_rstype)
        .unionByName(long_colocm)
    )#.filter(F.col("value").isNotNull())

    # single aggregation to compute flags
    agg_once_local = (
        long_features
        .groupBy(
            "targetId","diseaseId","maxClinPhase",
            "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
            "feature","value"
        )
        .agg(
            F.max(F.when(F.col("NoneCellYes")=="yes", 1).otherwise(0)).alias("NoneCellYes"),
            F.max(F.when(F.col("NdiagonalYes")=="yes", 1).otherwise(0)).alias("NdiagonalYes"),
            F.max(F.when(F.col("hasGenetics2")=="yes", 1).otherwise(0)).alias("hasGenetics"),
        )
        .selectExpr(
            "*",
            "CASE WHEN NoneCellYes=1 THEN 'yes' ELSE 'no' END as NoneCellYes_flag",
            "CASE WHEN NdiagonalYes=1 THEN 'yes' ELSE 'no' END as NdiagonalYes_flag",
            "CASE WHEN hasGenetics=1 THEN 'yes' ELSE 'no' END as hasGenetics_flag"
        )
    )
    return agg_once_local

if 'agg_once' not in globals():
    print("[info] agg_once not found — rebuilding it from test2/benchmark …")
    agg_once = _build_agg_once_from_test2_and_benchmark(test2, benchmark)
    print("[info] agg_once rebuilt.")


[info] agg_once not found — rebuilding it from test2/benchmark …
[info] agg_once rebuilt.


In [10]:
agg_once.persist()

25/09/10 22:01:21 WARN CacheManager: Asked to cache already cached data.


DataFrame[targetId: string, diseaseId: string, maxClinPhase: double, Phase>=4: string, Phase>=3: string, Phase>=2: string, Phase>=1: string, PhaseT: string, feature: string, value: string, NoneCellYes: int, NdiagonalYes: int, hasGenetics: int, NoneCellYes_flag: string, NdiagonalYes_flag: string, hasGenetics_flag: string]

In [11]:
agg_once.count()

85938

In [8]:
agg_once.groupBy('feature').count().show()

+--------------------+-----+
|             feature|count|
+--------------------+-----+
|         actionType2|78109|
|           projectId| 2055|
|colocalisationMethod|  937|
|       biosampleName| 3779|
|      rightStudyType| 1058|
+--------------------+-----+



In [None]:


# ============================
# Denominator = ALL pairs in analysis_chembl_indication (deduped)
# Build 2x2 counts using totals, then Fisher via applyInPandas
# ============================
from datetime import date
import pandas as pd
import numpy as np
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType, IntegerType, ArrayType
)
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio

# ---- 0) Universe of pairs & phase flags (only de-dup, no other filtering)
universe = (
    analysis_chembl_indication
    .select("targetId", "diseaseId", "maxClinPhase")  # dedupe on these
    .distinct()
    .join(F.broadcast(negativeTD), on=["targetId","diseaseId"], how="left")
    .withColumn("PhaseT", F.when(F.col("stopReason")=="Negative", "yes").otherwise("no"))
    .withColumn("Phase>=4", F.when((F.col("maxClinPhase")==4) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=3", F.when((F.col("maxClinPhase")>=3) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=2", F.when((F.col("maxClinPhase")>=2) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=1", F.when((F.col("maxClinPhase")>=1) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
)

# Long view of phase flags for universe
phases_universe_long = universe.select(
    "targetId","diseaseId",
    F.expr("stack(5, "
           "'Phase>=4', `Phase>=4`, "
           "'Phase>=3', `Phase>=3`, "
           "'Phase>=2', `Phase>=2`, "
           "'Phase>=1', `Phase>=1`, "
           "'PhaseT',  `PhaseT`"
           ")").alias("phase_name","prediction")
)

# Totals per phase (denominator totals)
total_pairs_by_phase = (
    phases_universe_long
    .groupBy("phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("total_pairs"))
)
total_pred_yes_by_phase = (
    phases_universe_long
    .filter(F.col("prediction")=="yes")
    .groupBy("phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("total_pred_yes"))
)

# ---- 1) Build analysis_long from agg_once (flags) + phases (prediction)
# metrics we’ll analyze
metric_flags = ["NoneCellYes_flag", "NdiagonalYes_flag", "hasGenetics_flag"]

# phase flags per (target,disease,maxClinPhase)
phase_flags = (
    benchmark.select("targetId","diseaseId","maxClinPhase","Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT")
    .dropDuplicates(["targetId","diseaseId","maxClinPhase"])
)

# stack phases for the records present in agg_once (feature,value specific)
phases_long_for_records = (
    phase_flags.join(agg_once.select("targetId","diseaseId","maxClinPhase").dropDuplicates(),
                     on=["targetId","diseaseId","maxClinPhase"], how="inner")
    .select(
        "targetId","diseaseId","maxClinPhase",
        F.expr("stack(5, "
               "'Phase>=4', `Phase>=4`, "
               "'Phase>=3', `Phase>=3`, "
               "'Phase>=2', `Phase>=2`, "
               "'Phase>=1', `Phase>=1`, "
               "'PhaseT',  `PhaseT`"
               ")").alias("phase_name","prediction")
    )
)

def attach_metric(metric_col: str):
    # comparison = metric flag yes/no at (target,disease,feature,value)
    return (
        agg_once.select("targetId","diseaseId","maxClinPhase","feature","value",
                        F.col(metric_col).alias("comparison"))
        .join(phases_long_for_records, on=["targetId","diseaseId","maxClinPhase"], how="inner")
        .withColumn("metric", F.lit(metric_col.replace("_flag","")))  # prettier label
    )

analysis_long = attach_metric(metric_flags[0])
for mc in metric_flags[1:]:
    analysis_long = analysis_long.unionByName(attach_metric(mc))

# ---- 2) Count distinct pairs for 2x2 components using the fixed universe
# a = count of pairs with comparison=='yes' AND prediction=='yes'
yes_yes = (
    analysis_long
    .filter((F.col("comparison")=="yes") & (F.col("prediction")=="yes"))
    .groupBy("metric","feature","value","phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("a"))
)
# yes_total = count of pairs with comparison=='yes' (regardless of prediction)
yes_total = (
    analysis_long
    .filter(F.col("comparison")=="yes")
    .groupBy("metric","feature","value","phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("yes_total"))
)

# Assemble b,c,d from totals
counts = (
    yes_total
    .join(yes_yes, on=["metric","feature","value","phase_name"], how="left")
    .join(total_pairs_by_phase, on="phase_name", how="left")
    .join(total_pred_yes_by_phase, on="phase_name", how="left")
    .na.fill({"a":0})
    .withColumn("b", F.col("yes_total") - F.col("a"))
    .withColumn("c", F.col("total_pred_yes") - F.col("a"))
    .withColumn("d", F.col("total_pairs") - F.col("a") - F.col("b") - F.col("c"))
    .select(
        "metric","feature","value","phase_name",
        F.when(F.col("a")<0,0).otherwise(F.col("a")).cast("int").alias("a"),
        F.when(F.col("b")<0,0).otherwise(F.col("b")).cast("int").alias("b"),
        F.when(F.col("c")<0,0).otherwise(F.col("c")).cast("int").alias("c"),
        F.when(F.col("d")<0,0).otherwise(F.col("d")).cast("int").alias("d"),
        "total_pairs","total_pred_yes"
    )
)

# Convert to two-row format (comparison yes/no) with columns yes/no → ready for Fisher
mat_counts = (
    counts
    .select("metric","feature","value","phase_name",
            F.lit("yes").alias("comparison"),
            F.col("a").alias("yes"),
            F.col("b").alias("no"))
    .unionByName(
        counts.select("metric","feature","value","phase_name",
                      F.lit("no").alias("comparison"),
                      F.col("c").alias("yes"),
                      F.col("d").alias("no"))
    )
)

# Safety: ensure ints and no nulls
mat_counts = (
    mat_counts.fillna(0)
              .withColumn("yes", F.col("yes").cast("int"))
              .withColumn("no",  F.col("no").cast("int"))
)

#### LAST ITERATION


In [1]:
# -*- coding: utf-8 -*-
# Single-script, loop-free PySpark job (tall/unpivot + single aggregation)

import os
from datetime import date
from functools import reduce

from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType, DoubleType, ArrayType
)

# Your helpers
from functions import (
    relative_success,
    spreadSheetFormatter,
    discrepancifier,
    temporary_directionOfEffect,
    buildColocData,
    gwasDataset,
)
from DoEAssessment import directionOfEffect  # noqa: F401  (kept if you need it later)

# -------------------------------
# Spark / YARN resource settings (Single-Node Option A)
# -------------------------------
driver_memory = "12g"                 # string with unit
executor_memory = "40g"               # string with unit (heap)
executor_cores = 10                   # int
num_executors = 1                     # int (one fat executor on single node)
executor_memory_overhead = "6g"       # string with unit (PySpark/Arrow/off-heap)
shuffle_partitions = 128              # int (~2–3x cores)
default_parallelism = 128             # int (match shuffle_partitions)

# If you later move to a multi-worker cluster, replace the values above.

spark = (
    SparkSession.builder
    .appName("MyOptimizedPySparkApp")
    .config("spark.master", "yarn")
    # core resources
    .config("spark.driver.memory", driver_memory)
    .config("spark.executor.memory", executor_memory)
    .config("spark.executor.cores", executor_cores)
    .config("spark.executor.instances", num_executors)
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead)
    # shuffle & parallelism
    .config("spark.sql.shuffle.partitions", shuffle_partitions)
    .config("spark.default.parallelism", default_parallelism)
    # adaptive query execution for better skew/partition sizing
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
)

print("SparkSession created with:")
for k in [
    "spark.driver.memory",
    "spark.executor.memory",
    "spark.executor.cores",
    "spark.executor.instances",
    "spark.yarn.executor.memoryOverhead",
    "spark.sql.shuffle.partitions",
    "spark.default.parallelism",
    "spark.sql.adaptive.enabled",
    "spark.sql.adaptive.coalescePartitions.enabled",
]:
    print(f"  {k}: {spark.conf.get(k)}")
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")



'''
# -------------------------------
# Spark / YARN resource settings
# -------------------------------
driver_memory = "16g"
executor_memory = "32g"
executor_cores = "8"
num_executors = "16"
executor_memory_overhead = "8g"
shuffle_partitions = "150"
default_parallelism = str(int(executor_cores) * int(num_executors) * 2)  # 80

spark = (
    SparkSession.builder
    .appName("MyOptimizedPySparkApp")
    .config("spark.master", "yarn")
    .config("spark.driver.memory", driver_memory)
    .config("spark.executor.memory", executor_memory)
    .config("spark.executor.cores", executor_cores)
    .config("spark.executor.instances", num_executors)
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead)
    .config("spark.sql.shuffle.partitions", shuffle_partitions)
    .config("spark.default.parallelism", default_parallelism)
    .getOrCreate()
)

print("SparkSession created with:")
for k in [
    "spark.driver.memory",
    "spark.executor.memory",
    "spark.executor.cores",
    "spark.executor.instances",
    "spark.yarn.executor.memoryOverhead",
    "spark.sql.shuffle.partitions",
    "spark.default.parallelism",
]:
    print(f"  {k}: {spark.conf.get(k)}")
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")
'''
# --------------------------------
# 0) Load inputs
# --------------------------------
path_n = "gs://open-targets-data-releases/25.06/output/"

target = spark.read.parquet(f"{path_n}target/")
diseases = spark.read.parquet(f"{path_n}disease/")
evidences = spark.read.parquet(f"{path_n}evidence")
credible = spark.read.parquet(f"{path_n}credible_set")
new = spark.read.parquet(f"{path_n}colocalisation_coloc")
index = spark.read.parquet(f"{path_n}study/")
variantIndex = spark.read.parquet(f"{path_n}variant")
biosample = spark.read.parquet(f"{path_n}biosample")
ecaviar = spark.read.parquet(f"{path_n}colocalisation_ecaviar")
all_coloc = ecaviar.unionByName(new, allowMissingColumns=True)
print("Loaded all base tables.")

# --------------------------------
# 1) Build coloc + GWAS dataset
# --------------------------------
newColoc = buildColocData(all_coloc, credible, index)
print("Built newColoc")

gwasComplete = gwasDataset(evidences, credible)
print("Built gwasComplete")

resolvedColoc = (
    newColoc.withColumnRenamed("geneId", "targetId")
    .join(
        gwasComplete.withColumnRenamed("studyLocusId", "leftStudyLocusId"),
        on=["leftStudyLocusId", "targetId"],
        how="inner",
    )
    .join(
        diseases.selectExpr("id as diseaseId", "name", "parents", "therapeuticAreas"),
        on="diseaseId",
        how="left",
    )
    .withColumn(
        "diseaseId",
        F.explode_outer(F.concat(F.array(F.col("diseaseId")), F.col("parents"))),
    )
    .drop("parents", "oldDiseaseId")
    .withColumn(
        "colocDoE",
        F.when(
            F.col("rightStudyType").isin(["eqtl", "pqtl", "tuqtl", "sceqtl", "sctuqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_protect"))
        ).when(
            F.col("rightStudyType").isin(["sqtl", "scsqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_protect"))
        ),
    )
)
print("Built resolvedColoc")

# --------------------------------
# 2) Direction of Effect & ChEMBL indication
# --------------------------------
datasource_filter = [
    "gwas_credible_sets",
    "gene_burden",
    "eva",
    "eva_somatic",
    "gene2phenotype",
    "orphanet",
    "cancer_gene_census",
    "intogen",
    "impc",
    "chembl",
]
assessment, evidences, actionType_unused, oncolabel_unused = temporary_directionOfEffect(path_n, datasource_filter)
print("Built temporary DoE datasets")

# (Optional) Add MoA to ChEMBL paths as in your later code
mecact_path = f"{path_n}drug_mechanism_of_action/"
mecact = spark.read.parquet(mecact_path)
actionType = (
    mecact.select(
        F.explode_outer("chemblIds").alias("drugId"),
        "actionType",
        "mechanismOfAction",
        "targets",
    )
    .select(
        F.explode_outer("targets").alias("targetId"),
        "drugId",
        "actionType",
        "mechanismOfAction",
    )
    .groupBy("targetId", "drugId")
    .agg(F.collect_set("actionType").alias("actionType2"))
    .withColumn("nMoA", F.size(F.col("actionType2")))
)

analysis_chembl_indication = (
    discrepancifier(
        assessment.filter(F.col("datasourceId") == "chembl")
        .join(actionType, on=["targetId", "drugId"], how="left")
        .withColumn(
            "maxClinPhase",
            F.max("clinicalPhase").over(Window.partitionBy("targetId", "diseaseId")),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase", "actionType2")
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    .drop("coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk")
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
)
print("Built analysis_chembl_indication")

# --------------------------------
# 3) Benchmark (filtered coloc) + clinical phase flags
# --------------------------------
resolvedColocFiltered = resolvedColoc.filter((F.col("clpp") >= 0.01) | (F.col("h4") >= 0.8))

negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .groupBy("targetId", "diseaseId").count()
    .withColumn("stopReason", F.lit("Negative")).drop("count")
)
benchmark = (
    resolvedColocFiltered.filter(F.col("name") != "COVID-19")
    .join(analysis_chembl_indication, on=["targetId", "diseaseId"], how="right")
    .withColumn(
        "AgreeDrug",
        F.when((F.col("drugGoF_protect").isNotNull()) & (F.col("colocDoE") == "GoF_protect"), "yes")
        .when((F.col("drugLoF_protect").isNotNull()) & (F.col("colocDoE") == "LoF_protect"), "yes")
        .otherwise("no"),
    )
    .join(biosample.select("biosampleId", "biosampleName"), on="biosampleId", how="left")
)

benchmark = (
    benchmark.join(F.broadcast(negativeTD), on=["targetId", "diseaseId"], how="left")
    .withColumn("PhaseT", F.when(F.col("stopReason") == "Negative", "yes").otherwise("no"))
    .withColumn("Phase>=4", F.when((F.col("maxClinPhase") == 4) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=3", F.when((F.col("maxClinPhase") >= 3) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=2", F.when((F.col("maxClinPhase") >= 2) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=1", F.when((F.col("maxClinPhase") >= 1) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
)

# --------------------------------
# 4) Replace nested loops:
#     compute DoE counts once → derive flags → unpivot → single aggregation
# --------------------------------
doe_cols = ["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]

# counts per colocDoE over the grouping you previously used in the loop
group_keys = [
    "targetId", "diseaseId", "maxClinPhase",
    "actionType2", "biosampleName", "projectId", "rightStudyType", "colocalisationMethod"
]

doe_counts = (
    benchmark.groupBy(*group_keys)
    .agg(*[F.sum(F.when(F.col("colocDoE") == c, 1).otherwise(0)).alias(c) for c in doe_cols])
)

# max name(s) (in case of ties) without arrays of structs
greatest_count = F.greatest(*[F.col(c) for c in doe_cols])
max_names = F.filter(
    F.array(*[F.when(F.col(c) == greatest_count, F.lit(c)) for c in doe_cols]),
    lambda x: x.isNotNull()
)

# presence of drug-side signals (equivalent to *_ch presence in your loop path)
has_lof_ch = F.col("drugLoF_protect").isNotNull()
has_gof_ch = F.col("drugGoF_protect").isNotNull()

test2 = (
    benchmark.select(*group_keys, "drugLoF_protect", "drugGoF_protect")
    .join(doe_counts, on=group_keys, how="left")
    .withColumn("NoneCellYes",
        F.when(has_lof_ch & (~has_gof_ch) & F.array_contains(max_names, F.lit("LoF_protect")), "yes")
         .when(has_gof_ch & (~has_lof_ch) & F.array_contains(max_names, F.lit("GoF_protect")), "yes")
         .otherwise("no")
    )
    .withColumn("NdiagonalYes",
        F.when(has_lof_ch & (~has_gof_ch) & (F.array_contains(max_names, F.lit("LoF_protect")) | F.array_contains(max_names, F.lit("GoF_risk"))), "yes")
         .when(has_gof_ch & (~has_lof_ch) & (F.array_contains(max_names, F.lit("GoF_protect")) | F.array_contains(max_names, F.lit("LoF_risk"))), "yes")
         .otherwise("no")
    )
    .withColumn("drugCoherency",
        F.when(has_lof_ch & ~has_gof_ch, "coherent")
         .when(~has_lof_ch & has_gof_ch, "coherent")
         .when(has_lof_ch & has_gof_ch, "dispar")
         .otherwise("other")
    ).withColumn(
    "hasGenetics2",
    F.when(
        reduce(lambda acc, c: acc & F.col(c).isNull(), doe_cols[1:], F.col(doe_cols[0]).isNull()),
        F.lit("no")
    ).otherwise(F.lit("yes"))
)
    .withColumn("hasGenetics", F.when(F.col("NdiagonalYes").isNotNull(), "yes").otherwise("no")) #### we have to change it
)
#test2.persist()

# ---------- Guard: (re)build agg_once if not defined ----------
import pyspark.sql.functions as F

def _build_agg_once_from_test2_and_benchmark(test2, benchmark_df):
    # Columns we keep across all longified slices
    common_cols = [
        "targetId","diseaseId","maxClinPhase",
        "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
        "NoneCellYes","NdiagonalYes","hasGenetics2"  # note: hasGenetics2 from your test2
    ]

    # Join phase flags once
    phase_flags = (
        benchmark_df.select(
            "targetId","diseaseId","maxClinPhase",
            "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT"
        ).dropDuplicates(["targetId","diseaseId","maxClinPhase"])
    )

    t2_with_phase = test2.join(
        phase_flags, on=["targetId","diseaseId","maxClinPhase"], how="left"
    )

    # actionType2 is ARRAY<STRING> → explode
    long_action = (
        t2_with_phase
        .select(*common_cols, F.explode_outer("actionType2").alias("value"))
        .withColumn("feature", F.lit("actionType2"))
        .select(*common_cols, "feature", "value")
    )

    # helper for scalar columns
    def longify_scalar(colname: str):
        return (
            t2_with_phase
            .select(*common_cols, F.col(colname).alias("value"))
            .withColumn("feature", F.lit(colname))
            .select(*common_cols, "feature", "value")
        )

    long_biosample = longify_scalar("biosampleName")
    long_project   = longify_scalar("projectId")
    long_rstype    = longify_scalar("rightStudyType")
    long_colocm    = longify_scalar("colocalisationMethod")

    # union into one tall table
    long_features = (
        long_action
        .unionByName(long_biosample)
        .unionByName(long_project)
        .unionByName(long_rstype)
        .unionByName(long_colocm)
    ).filter(F.col("value").isNotNull())

    # single aggregation to compute flags
    agg_once_local = (
        long_features
        .groupBy(
            "targetId","diseaseId","maxClinPhase",
            "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
            "feature","value"
        )
        .agg(
            F.max(F.when(F.col("NoneCellYes")=="yes", 1).otherwise(0)).alias("NoneCellYes"),
            F.max(F.when(F.col("NdiagonalYes")=="yes", 1).otherwise(0)).alias("NdiagonalYes"),
            F.max(F.when(F.col("hasGenetics2")=="yes", 1).otherwise(0)).alias("hasGenetics"),
        )
        .selectExpr(
            "*",
            "CASE WHEN NoneCellYes=1 THEN 'yes' ELSE 'no' END as NoneCellYes_flag",
            "CASE WHEN NdiagonalYes=1 THEN 'yes' ELSE 'no' END as NdiagonalYes_flag",
            "CASE WHEN hasGenetics=1 THEN 'yes' ELSE 'no' END as hasGenetics_flag"
        )
    )
    return agg_once_local

if 'agg_once' not in globals():
    print("[info] agg_once not found — rebuilding it from test2/benchmark …")
    agg_once = _build_agg_once_from_test2_and_benchmark(test2, benchmark)
    print("[info] agg_once rebuilt.")


# ============================
# Denominator = ALL pairs in analysis_chembl_indication (deduped)
# Build 2x2 counts using totals, then Fisher via applyInPandas
# ============================
from datetime import date
import pandas as pd
import numpy as np
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType, IntegerType, ArrayType
)
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio

# ---- 0) Universe of pairs & phase flags (only de-dup, no other filtering)
universe = (
    analysis_chembl_indication
    .select("targetId", "diseaseId", "maxClinPhase")  # dedupe on these
    .distinct()
    .join(F.broadcast(negativeTD), on=["targetId","diseaseId"], how="left")
    .withColumn("PhaseT", F.when(F.col("stopReason")=="Negative", "yes").otherwise("no"))
    .withColumn("Phase>=4", F.when((F.col("maxClinPhase")==4) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=3", F.when((F.col("maxClinPhase")>=3) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=2", F.when((F.col("maxClinPhase")>=2) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=1", F.when((F.col("maxClinPhase")>=1) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
)
print('universe of pairs and phase flags built')

# Long view of phase flags for universe
phases_universe_long = universe.select(
    "targetId","diseaseId",
    F.expr("stack(5, "
           "'Phase>=4', `Phase>=4`, "
           "'Phase>=3', `Phase>=3`, "
           "'Phase>=2', `Phase>=2`, "
           "'Phase>=1', `Phase>=1`, "
           "'PhaseT',  `PhaseT`"
           ")").alias("phase_name","prediction")
)
print('phase_universe_long built')

# Totals per phase (denominator totals)
total_pairs_by_phase = (
    phases_universe_long
    .groupBy("phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("total_pairs"))
)
total_pred_yes_by_phase = (
    phases_universe_long
    .filter(F.col("prediction")=="yes")
    .groupBy("phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("total_pred_yes"))
)
print('phase_universe_long built')

# ---- 1) Build analysis_long from agg_once (flags) + phases (prediction)
# metrics we’ll analyze
metric_flags = ["NoneCellYes_flag", "NdiagonalYes_flag", "hasGenetics_flag"]

# phase flags per (target,disease,maxClinPhase)
phase_flags = (
    benchmark.select("targetId","diseaseId","maxClinPhase","Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT")
    .dropDuplicates(["targetId","diseaseId","maxClinPhase"])
)

# stack phases for the records present in agg_once (feature,value specific)
phases_long_for_records = (
    phase_flags.join(agg_once.select("targetId","diseaseId","maxClinPhase").dropDuplicates(),
                     on=["targetId","diseaseId","maxClinPhase"], how="inner")
    .select(
        "targetId","diseaseId","maxClinPhase",
        F.expr("stack(5, "
               "'Phase>=4', `Phase>=4`, "
               "'Phase>=3', `Phase>=3`, "
               "'Phase>=2', `Phase>=2`, "
               "'Phase>=1', `Phase>=1`, "
               "'PhaseT',  `PhaseT`"
               ")").alias("phase_name","prediction")
    )
)

def attach_metric(metric_col: str):
    # comparison = metric flag yes/no at (target,disease,feature,value)
    return (
        agg_once.select("targetId","diseaseId","maxClinPhase","feature","value",
                        F.col(metric_col).alias("comparison"))
        .join(phases_long_for_records, on=["targetId","diseaseId","maxClinPhase"], how="inner")
        .withColumn("metric", F.lit(metric_col.replace("_flag","")))  # prettier label
    )

analysis_long = attach_metric(metric_flags[0])
for mc in metric_flags[1:]:
    analysis_long = analysis_long.unionByName(attach_metric(mc))

# ---- 2) Count distinct pairs for 2x2 components using the fixed universe
# a = count of pairs with comparison=='yes' AND prediction=='yes'
yes_yes = (
    analysis_long
    .filter((F.col("comparison")=="yes") & (F.col("prediction")=="yes"))
    .groupBy("metric","feature","value","phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("a"))
)
# yes_total = count of pairs with comparison=='yes' (regardless of prediction)
yes_total = (
    analysis_long
    .filter(F.col("comparison")=="yes")
    .groupBy("metric","feature","value","phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("yes_total"))
)

# Assemble b,c,d from totals
counts = (
    yes_total
    .join(yes_yes, on=["metric","feature","value","phase_name"], how="left")
    .join(total_pairs_by_phase, on="phase_name", how="left")
    .join(total_pred_yes_by_phase, on="phase_name", how="left")
    .na.fill({"a":0})
    .withColumn("b", F.col("yes_total") - F.col("a"))
    .withColumn("c", F.col("total_pred_yes") - F.col("a"))
    .withColumn("d", F.col("total_pairs") - F.col("a") - F.col("b") - F.col("c"))
    .select(
        "metric","feature","value","phase_name",
        F.when(F.col("a")<0,0).otherwise(F.col("a")).cast("int").alias("a"),
        F.when(F.col("b")<0,0).otherwise(F.col("b")).cast("int").alias("b"),
        F.when(F.col("c")<0,0).otherwise(F.col("c")).cast("int").alias("c"),
        F.when(F.col("d")<0,0).otherwise(F.col("d")).cast("int").alias("d"),
        "total_pairs","total_pred_yes"
    )
)

# Convert to two-row format (comparison yes/no) with columns yes/no → ready for Fisher
mat_counts = (
    counts
    .select("metric","feature","value","phase_name",
            F.lit("yes").alias("comparison"),
            F.col("a").alias("yes"),
            F.col("b").alias("no"))
    .unionByName(
        counts.select("metric","feature","value","phase_name",
                      F.lit("no").alias("comparison"),
                      F.col("c").alias("yes"),
                      F.col("d").alias("no"))
    )
)

# Safety: ensure ints and no nulls
mat_counts = (
    mat_counts.fillna(0)
              .withColumn("yes", F.col("yes").cast("int"))
              .withColumn("no",  F.col("no").cast("int"))
)

# ---- 3) Fisher per group with applyInPandas
result_schema = StructType([
    StructField("group",        StringType(),  True),
    StructField("comparison",   StringType(),  True),
    StructField("phase",        StringType(),  True),
    StructField("oddsRatio",    DoubleType(),  True),
    StructField("pValue",       DoubleType(),  True),
    StructField("lowerInterval",DoubleType(),  True),
    StructField("upperInterval",DoubleType(),  True),
    StructField("total",        StringType(),  True),
    StructField("values",       ArrayType(ArrayType(IntegerType())), True),
    StructField("relSuccess",   DoubleType(),  True),
    StructField("rsLower",      DoubleType(),  True),
    StructField("rsUpper",      DoubleType(),  True),
    StructField("path",         StringType(),  True),
])

def _relative_success(matrix_2x2: np.ndarray):
    a, b = matrix_2x2[0,0], matrix_2x2[0,1]
    c, d = matrix_2x2[1,0], matrix_2x2[1,1]
    rate_yes = a / (a + b) if (a + b) > 0 else 0.0
    rate_no  = c / (c + d) if (c + d) > 0 else 0.0
    rs = rate_yes - rate_no
    import math
    se = 0.0
    if (a+b) > 0:
        se += rate_yes * (1 - rate_yes) / (a + b)
    if (c+d) > 0:
        se += rate_no  * (1 - rate_no)  / (c + d)
    se = math.sqrt(se)
    lo, hi = rs - 1.96*se, rs + 1.96*se
    return float(rs), float(lo), float(hi)

def fisher_by_group(pdf: pd.DataFrame) -> pd.DataFrame:
    sub = pdf[["comparison","yes","no"]].copy()
    sub = sub.set_index("comparison").reindex(["yes","no"]).fillna(0)
    mat = sub[["yes","no"]].to_numpy(dtype=int)

    total = int(mat.sum())
    if total == 0:
        # Return a neutral row to avoid SciPy errors on empty tables
        return pd.DataFrame([{
            "group":        pdf["metric"].iloc[0],
            "comparison":   f"{pdf['value'].iloc[0]}_only",
            "phase":        pdf["phase_name"].iloc[0],
            "oddsRatio":    float("nan"),
            "pValue":       float("nan"),
            "lowerInterval":float("nan"),
            "upperInterval":float("nan"),
            "total":        "0",
            "values":       mat.tolist(),
            "relSuccess":   float("nan"),
            "rsLower":      float("nan"),
            "rsUpper":      float("nan"),
            "path":         "",
        }])

    from scipy.stats import fisher_exact
    from scipy.stats.contingency import odds_ratio

    or_val, p_val = fisher_exact(mat, alternative="two-sided")
    ci = odds_ratio(mat).confidence_interval(0.95)
    rs, rs_lo, rs_hi = _relative_success(mat)

    return pd.DataFrame([{
        "group":        pdf["metric"].iloc[0],
        "comparison":   f"{pdf['value'].iloc[0]}_only",
        "phase":        pdf["phase_name"].iloc[0],
        "oddsRatio":    round(float(or_val), 2),
        "pValue":       float(p_val),
        "lowerInterval":round(float(ci[0]), 2),
        "upperInterval":round(float(ci[1]), 2),
        "total":        str(total),
        "values":       mat.tolist(),
        "relSuccess":   round(float(rs), 2),
        "rsLower":      round(float(rs_lo), 2),
        "rsUpper":      round(float(rs_hi), 2),
        "path":         "",
    }])

# (optional) Arrow for speed
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

results_df = (
    mat_counts
    .groupBy("metric","feature","value","phase_name")
    .applyInPandas(fisher_by_group, schema=result_schema)
)

# ---- 4) Spreadsheet formatting + annotation + CSV
from itertools import chain
from pyspark.sql.functions import create_map

# build disdic from agg_once
disdic = {r["value"]: r["feature"] for r in agg_once.select("feature","value").distinct().collect()}

patterns = ["_only", "_isRightTissueSignalAgreed"]
regex_pattern = "(" + "|".join(patterns) + ")"

df_fmt = (
    spreadSheetFormatter(results_df)
    .withColumn("prefix", F.regexp_replace(F.col("comparison"), regex_pattern + ".*", ""))
    .withColumn("suffix", F.regexp_extract(F.col("comparison"), regex_pattern, 0))
)

mapping_expr = create_map([F.lit(x) for x in chain(*disdic.items())])
df_annot = df_fmt.withColumn("annotation", mapping_expr.getItem(F.col("prefix")))

#today_date = date.today().isoformat()
#out_csv = f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try.csv"
#df_annot.toPandas().to_csv(out_csv, index=False)
#print(f"Analysis written: {out_csv}")
today_date = date.today().isoformat()

(out_path, coalesce_n) = (f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try", 1)


spark session created at 2025-09-17 15:35:22.959890
Analysis started on 2025-09-17 at  2025-09-17 15:35:22.959890


25/09/17 15:35:28 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
25/09/17 15:35:28 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


SparkSession created with:
  spark.driver.memory: 12g
  spark.executor.memory: 40g
  spark.executor.cores: 10
  spark.executor.instances: 1
  spark.yarn.executor.memoryOverhead: 6g
  spark.sql.shuffle.partitions: 128
  spark.default.parallelism: 128
  spark.sql.adaptive.enabled: true
  spark.sql.adaptive.coalescePartitions.enabled: true
Spark UI: http://jr-temp-doe-m.c.open-targets-eu-dev.internal:37689


                                                                                

Loaded all base tables.
Built newColoc


                                                                                

loaded gwasComplete
Built gwasComplete
Built resolvedColoc
Built temporary DoE datasets


                                                                                

Built analysis_chembl_indication
[info] agg_once not found — rebuilding it from test2/benchmark …
[info] agg_once rebuilt.
universe of pairs and phase flags built
phase_universe_long built
phase_universe_long built


                                                                                

importing functions
imported functions




In [None]:

(out_path, coalesce_n) = (f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try", 1)
print('preparing df_annot to write')
(df_annot
  .coalesce(coalesce_n)                 # 1 file if feasible; increase if OOM on shuffle
  .write.mode("overwrite")
  .option("header", "true")
  .csv(out_path))

print(f"Wrote CSV shards to: {out_path}")

In [None]:
prev = spark.conf.get("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "false")


out_path    = f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try"
parquet_path = out_path + "_parquet"

writer_parts = 64   # try 48/64/80 depending on size

# Stage 1: write Parquet
(df_annot
  #.repartition(writer_parts, "metric", "phase")   # pick keys you have; omit if not present
  .write.mode("overwrite")
  .parquet(parquet_path))

# Free lineage
df_annot.unpersist()
spark.catalog.clearCache()

# Stage 2: read Parquet back and write CSV shards (no coalesce(1)!)
(spark.read.parquet(parquet_path)
  .repartition(writer_parts, "metric", "phase")
  .write.mode("overwrite")
  .option("header","true")
  .option("maxRecordsPerFile","200000")   # smaller chunks → lower peak memory
  .option("compression","gzip")           # optional; reduces IO/space
  .csv(out_path))

# Restore AQE setting
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", prev)

print(f"Wrote CSV shards under: {out_path}")


[Stage 317:> (9 + 1) / 80][Stage 318:(113 + 1) / 174][Stage 319:> (8 + 1) / 80]]]

In [2]:
parquet_path = out_path + "_parquet"

print("rows:", df_annot.count())
print("partitions (annot):", df_annot.rdd.getNumPartitions())
print("partitions (out parquet):", spark.read.parquet(parquet_path).rdd.getNumPartitions())


25/09/17 15:43:49 WARN TransportChannelHandler: Exception in connection from /10.132.0.18:55618
java.io.IOException: Connection reset by peer
	at java.base/sun.nio.ch.FileDispatcherImpl.read0(Native Method) ~[?:?]
	at java.base/sun.nio.ch.SocketDispatcher.read(SocketDispatcher.java:39) ~[?:?]
	at java.base/sun.nio.ch.IOUtil.readIntoNativeBuffer(IOUtil.java:276) ~[?:?]
	at java.base/sun.nio.ch.IOUtil.read(IOUtil.java:233) ~[?:?]
	at java.base/sun.nio.ch.IOUtil.read(IOUtil.java:223) ~[?:?]
	at java.base/sun.nio.ch.SocketChannelImpl.read(SocketChannelImpl.java:356) ~[?:?]
	at io.netty.buffer.PooledByteBuf.setBytes(PooledByteBuf.java:254) ~[netty-buffer-4.1.100.Final.jar:4.1.100.Final]
	at io.netty.buffer.AbstractByteBuf.writeBytes(AbstractByteBuf.java:1132) ~[netty-buffer-4.1.100.Final.jar:4.1.100.Final]
	at io.netty.channel.socket.nio.NioSocketChannel.doReadBytes(NioSocketChannel.java:357) ~[netty-transport-4.1.100.Final.jar:4.1.100.Final]
	at io.netty.channel.nio.AbstractNioByteChannel$

KeyboardInterrupt: 

[Stage 275:(141 + 1) / 174][Stage 276:(141 + 1) / 174][Stage 277:(141 + 1) / 174]

In [2]:
parquet_path = out_path + "_parquet"

# Stage 1: write Parquet (fast/compact)

(df_annot
  .repartition(64)
  .write.mode("overwrite")
  .parquet(parquet_path))
print('parquet written')
# Stage 2: read back and spill to CSV shards
(spark.read.parquet(parquet_path)
  .repartition(64)
  .write.mode("overwrite")
  .option("header","true")
  .option("maxRecordsPerFile","500000")
  .csv(out_path))
print('csv written')

25/09/17 15:29:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_73 !
25/09/17 15:29:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_101 !
25/09/17 15:29:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_112 !
25/09/17 15:29:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_115 !
25/09/17 15:29:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_34 !
25/09/17 15:29:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_138_98 !
25/09/17 15:29:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_49 !
25/09/17 15:29:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_59 !
25/09/17 15:29:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_157 !
25/09/17 15:29:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_79 !
25/09/17 15:29:13 WARN BlockManagerMasterEndpoint: No

KeyboardInterrupt: 