In [None]:
# -*- coding: utf-8 -*-
# Single-script, loop-free PySpark job (tall/unpivot + single aggregation)

import os
from datetime import date
from functools import reduce

from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType, DoubleType, ArrayType
)

# Your helpers
from functions import (
    relative_success,
    spreadSheetFormatter,
    discrepancifier,
    temporary_directionOfEffect,
    buildColocData,
    gwasDataset,
)
from DoEAssessment import directionOfEffect  # noqa: F401  (kept if you need it later)

# -------------------------------
# Spark / YARN resource settings (Single-Node Option A)
# -------------------------------
driver_memory = "12g"                 # string with unit
executor_memory = "40g"               # string with unit (heap)
executor_cores = 10                   # int
num_executors = 1                     # int (one fat executor on single node)
executor_memory_overhead = "6g"       # string with unit (PySpark/Arrow/off-heap)
shuffle_partitions = 128              # int (~2–3x cores)
default_parallelism = 128             # int (match shuffle_partitions)

# If you later move to a multi-worker cluster, replace the values above.

spark = (
    SparkSession.builder
    .appName("MyOptimizedPySparkApp")
    .config("spark.master", "yarn")
    # core resources
    .config("spark.driver.memory", driver_memory)
    .config("spark.executor.memory", executor_memory)
    .config("spark.executor.cores", executor_cores)
    .config("spark.executor.instances", num_executors)
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead)
    # shuffle & parallelism
    .config("spark.sql.shuffle.partitions", shuffle_partitions)
    .config("spark.default.parallelism", default_parallelism)
    # adaptive query execution for better skew/partition sizing
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
)

print("SparkSession created with:")
for k in [
    "spark.driver.memory",
    "spark.executor.memory",
    "spark.executor.cores",
    "spark.executor.instances",
    "spark.yarn.executor.memoryOverhead",
    "spark.sql.shuffle.partitions",
    "spark.default.parallelism",
    "spark.sql.adaptive.enabled",
    "spark.sql.adaptive.coalescePartitions.enabled",
]:
    print(f"  {k}: {spark.conf.get(k)}")
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")



'''
# -------------------------------
# Spark / YARN resource settings
# -------------------------------
driver_memory = "16g"
executor_memory = "32g"
executor_cores = "8"
num_executors = "16"
executor_memory_overhead = "8g"
shuffle_partitions = "150"
default_parallelism = str(int(executor_cores) * int(num_executors) * 2)  # 80

spark = (
    SparkSession.builder
    .appName("MyOptimizedPySparkApp")
    .config("spark.master", "yarn")
    .config("spark.driver.memory", driver_memory)
    .config("spark.executor.memory", executor_memory)
    .config("spark.executor.cores", executor_cores)
    .config("spark.executor.instances", num_executors)
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead)
    .config("spark.sql.shuffle.partitions", shuffle_partitions)
    .config("spark.default.parallelism", default_parallelism)
    .getOrCreate()
)

print("SparkSession created with:")
for k in [
    "spark.driver.memory",
    "spark.executor.memory",
    "spark.executor.cores",
    "spark.executor.instances",
    "spark.yarn.executor.memoryOverhead",
    "spark.sql.shuffle.partitions",
    "spark.default.parallelism",
]:
    print(f"  {k}: {spark.conf.get(k)}")
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")
'''
# --------------------------------
# 0) Load inputs
# --------------------------------
path_n = "gs://open-targets-data-releases/25.06/output/"

target = spark.read.parquet(f"{path_n}target/")
diseases = spark.read.parquet(f"{path_n}disease/")
evidences = spark.read.parquet(f"{path_n}evidence")
credible = spark.read.parquet(f"{path_n}credible_set")
new = spark.read.parquet(f"{path_n}colocalisation_coloc")
index = spark.read.parquet(f"{path_n}study/")
variantIndex = spark.read.parquet(f"{path_n}variant")
biosample = spark.read.parquet(f"{path_n}biosample")
ecaviar = spark.read.parquet(f"{path_n}colocalisation_ecaviar")
all_coloc = ecaviar.unionByName(new, allowMissingColumns=True)
print("Loaded all base tables.")

# --------------------------------
# 1) Build coloc + GWAS dataset
# --------------------------------
newColoc = buildColocData(all_coloc, credible, index)
print("Built newColoc")

gwasComplete = gwasDataset(evidences, credible)
print("Built gwasComplete")

resolvedColoc = (
    newColoc.withColumnRenamed("geneId", "targetId")
    .join(
        gwasComplete.withColumnRenamed("studyLocusId", "leftStudyLocusId"),
        on=["leftStudyLocusId", "targetId"],
        how="inner",
    )
    .join(
        diseases.selectExpr("id as diseaseId", "name", "parents", "therapeuticAreas"),
        on="diseaseId",
        how="left",
    )
    .withColumn(
        "diseaseId",
        F.explode_outer(F.concat(F.array(F.col("diseaseId")), F.col("parents"))),
    )
    .drop("parents", "oldDiseaseId")
    .withColumn(
        "colocDoE",
        F.when(
            F.col("rightStudyType").isin(["eqtl", "pqtl", "tuqtl", "sceqtl", "sctuqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_protect"))
        ).when(
            F.col("rightStudyType").isin(["sqtl", "scsqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_protect"))
        ),
    )
)
print("Built resolvedColoc")

# --------------------------------
# 2) Direction of Effect & ChEMBL indication
# --------------------------------
datasource_filter = [
    "gwas_credible_sets",
    "gene_burden",
    "eva",
    "eva_somatic",
    "gene2phenotype",
    "orphanet",
    "cancer_gene_census",
    "intogen",
    "impc",
    "chembl",
]
assessment, evidences, actionType_unused, oncolabel_unused = temporary_directionOfEffect(path_n, datasource_filter)
print("Built temporary DoE datasets")

# (Optional) Add MoA to ChEMBL paths as in your later code
mecact_path = f"{path_n}drug_mechanism_of_action/"
mecact = spark.read.parquet(mecact_path)
actionType = (
    mecact.select(
        F.explode_outer("chemblIds").alias("drugId"),
        "actionType",
        "mechanismOfAction",
        "targets",
    )
    .select(
        F.explode_outer("targets").alias("targetId"),
        "drugId",
        "actionType",
        "mechanismOfAction",
    )
    .groupBy("targetId", "drugId")
    .agg(F.collect_set("actionType").alias("actionType2"))
    .withColumn("nMoA", F.size(F.col("actionType2")))
)

analysis_chembl_indication = (
    discrepancifier(
        assessment.filter(F.col("datasourceId") == "chembl")
        .join(actionType, on=["targetId", "drugId"], how="left")
        .withColumn(
            "maxClinPhase",
            F.max("clinicalPhase").over(Window.partitionBy("targetId", "diseaseId")),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase", "actionType2")
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    .drop("coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk")
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
)
print("Built analysis_chembl_indication")

# --------------------------------
# 3) Benchmark (filtered coloc) + clinical phase flags
# --------------------------------
resolvedColocFiltered = resolvedColoc.filter((F.col("clpp") >= 0.01) | (F.col("h4") >= 0.8))

negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .groupBy("targetId", "diseaseId").count()
    .withColumn("stopReason", F.lit("Negative")).drop("count")
)
benchmark = (
    resolvedColocFiltered.filter(F.col("name") != "COVID-19")
    .join(analysis_chembl_indication, on=["targetId", "diseaseId"], how="right")
    .withColumn(
        "AgreeDrug",
        F.when((F.col("drugGoF_protect").isNotNull()) & (F.col("colocDoE") == "GoF_protect"), "yes")
        .when((F.col("drugLoF_protect").isNotNull()) & (F.col("colocDoE") == "LoF_protect"), "yes")
        .otherwise("no"),
    )
    .join(biosample.select("biosampleId", "biosampleName"), on="biosampleId", how="left")
)

benchmark = (
    benchmark.join(F.broadcast(negativeTD), on=["targetId", "diseaseId"], how="left")
    .withColumn("PhaseT", F.when(F.col("stopReason") == "Negative", "yes").otherwise("no"))
    .withColumn("Phase>=4", F.when((F.col("maxClinPhase") == 4) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=3", F.when((F.col("maxClinPhase") >= 3) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=2", F.when((F.col("maxClinPhase") >= 2) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=1", F.when((F.col("maxClinPhase") >= 1) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
)

# --------------------------------
# 4) Replace nested loops:
#     compute DoE counts once → derive flags → unpivot → single aggregation
# --------------------------------
doe_cols = ["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]

# counts per colocDoE over the grouping you previously used in the loop
group_keys = [
    "targetId", "diseaseId", "maxClinPhase",
    "actionType2", "biosampleName", "projectId", "rightStudyType", "colocalisationMethod"
]

doe_counts = (
    benchmark.groupBy(*group_keys)
    .agg(*[F.sum(F.when(F.col("colocDoE") == c, 1).otherwise(0)).alias(c) for c in doe_cols])
)

# max name(s) (in case of ties) without arrays of structs
greatest_count = F.greatest(*[F.col(c) for c in doe_cols])
max_names = F.filter(
    F.array(*[F.when(F.col(c) == greatest_count, F.lit(c)) for c in doe_cols]),
    lambda x: x.isNotNull()
)

# presence of drug-side signals (equivalent to *_ch presence in your loop path)
has_lof_ch = F.col("drugLoF_protect").isNotNull()
has_gof_ch = F.col("drugGoF_protect").isNotNull()


test2.unpersist()
test2 = (
    benchmark.select(*group_keys, "drugLoF_protect", "drugGoF_protect")
    .join(doe_counts, on=group_keys, how="left")
    .withColumn("NoneCellYes",
        F.when(has_lof_ch & (~has_gof_ch) & F.array_contains(max_names, F.lit("LoF_protect")), "yes")
         .when(has_gof_ch & (~has_lof_ch) & F.array_contains(max_names, F.lit("GoF_protect")), "yes")
         .otherwise("no")
    )
    .withColumn("NdiagonalYes",
        F.when(has_lof_ch & (~has_gof_ch) & (F.array_contains(max_names, F.lit("LoF_protect")) | F.array_contains(max_names, F.lit("GoF_risk"))), "yes")
         .when(has_gof_ch & (~has_lof_ch) & (F.array_contains(max_names, F.lit("GoF_protect")) | F.array_contains(max_names, F.lit("LoF_risk"))), "yes")
         .otherwise("no")
    )
    .withColumn("drugCoherency",
        F.when(has_lof_ch & ~has_gof_ch, "coherent")
         .when(~has_lof_ch & has_gof_ch, "coherent")
         .when(has_lof_ch & has_gof_ch, "dispar")
         .otherwise("other")
    ).withColumn(
    "hasGenetics2",
    F.when(
        reduce(lambda acc, c: acc & F.col(c).isNull(), doe_cols[1:], F.col(doe_cols[0]).isNull()),
        F.lit("no")
    ).otherwise(F.lit("yes"))
)
    .withColumn("hasGenetics", F.when(F.col("NdiagonalYes").isNotNull(), "yes").otherwise("no")) #### we have to change it
)
test2.persist()

common_cols = [
    "targetId","diseaseId","maxClinPhase",
    "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
    "NoneCellYes","NdiagonalYes","hasGenetics2"
]

# 1) actionType2 is ARRAY<STRING> → explode to one value per row
long_action = (
    test2.join(
        benchmark.select(
            "targetId","diseaseId","maxClinPhase","Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT"
        ).dropDuplicates(["targetId","diseaseId","maxClinPhase"]),
        on=["targetId","diseaseId","maxClinPhase"], how="left"
    )
    .select(*common_cols, F.explode_outer("actionType2").alias("value"))
    .withColumn("feature", F.lit("actionType2"))
    .select(*common_cols, "feature", "value")
)

# 2–5) the scalar columns
def longify_scalar(colname: str):
    return (
        test2.join(
            benchmark.select(
                "targetId","diseaseId","maxClinPhase","Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT"
            ).dropDuplicates(["targetId","diseaseId","maxClinPhase"]),
            on=["targetId","diseaseId","maxClinPhase"], how="left"
        )
        .select(*common_cols, F.col(colname).alias("value"))
        .withColumn("feature", F.lit(colname))
        .select(*common_cols, "feature", "value")
    )

long_biosample = longify_scalar("biosampleName")
long_project   = longify_scalar("projectId")
long_rstype    = longify_scalar("rightStudyType")
long_colocm    = longify_scalar("colocalisationMethod")

# Union them (schema-aligned)
long_features = (
    long_action
    .unionByName(long_biosample)
    .unionByName(long_project)
    .unionByName(long_rstype)
    .unionByName(long_colocm)
).filter(F.col("value").isNotNull())


# Single aggregation for ALL features and ALL metrics (booleans as max over yes/no)
agg_once = (
    long_features
    .groupBy("targetId","diseaseId","maxClinPhase","Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT","feature","value")
    .agg(
        F.max(F.when(F.col("NoneCellYes")=="yes", 1).otherwise(0)).alias("NoneCellYes"),
        F.max(F.when(F.col("NdiagonalYes")=="yes", 1).otherwise(0)).alias("NdiagonalYes"),
        F.max(F.when(F.col("hasGenetics2")=="yes", 1).otherwise(0)).alias("hasGenetics"),
    )
    .selectExpr(
        "*",
        "CASE WHEN NoneCellYes=1 THEN 'yes' ELSE 'no' END as NoneCellYes_flag",
        "CASE WHEN NdiagonalYes=1 THEN 'yes' ELSE 'no' END as NdiagonalYes_flag",
        "CASE WHEN hasGenetics=1 THEN 'yes' ELSE 'no' END as hasGenetics_flag"
    )
)
# ---- make sure 'yes'/'no' columns exist as ints (we already did fillna(0), but be explicit)
mat_counts = (
    mat_counts
    .withColumn("yes", F.coalesce(F.col("yes"), F.lit(0)).cast("int"))
    .withColumn("no",  F.coalesce(F.col("no"),  F.lit(0)).cast("int"))
)


spark session created at 2025-09-10 14:59:00.115158
Analysis started on 2025-09-10 at  2025-09-10 14:59:00.115158


25/09/10 14:59:05 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
25/09/10 14:59:05 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


SparkSession created with:
  spark.driver.memory: 12g
  spark.executor.memory: 40g
  spark.executor.cores: 10
  spark.executor.instances: 1
  spark.yarn.executor.memoryOverhead: 6g
  spark.sql.shuffle.partitions: 128
  spark.default.parallelism: 128
  spark.sql.adaptive.enabled: true
  spark.sql.adaptive.coalescePartitions.enabled: true
Spark UI: http://jr-temp-doe-m.c.open-targets-eu-dev.internal:43343
Loaded all base tables.
Built newColoc


                                                                                

loaded gwasComplete
Built gwasComplete
Built resolvedColoc
Built temporary DoE datasets


                                                                                

Built analysis_chembl_indication


#### Actions:
##### Find how the matrix 2x2 is done
##### substitute rel_success and odds ratio with our formula

In [None]:

import pandas as pd
import numpy as np
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio

# Same schema you had:
result_schema = StructType([
    StructField("group",        StringType(),  True),
    StructField("comparison",   StringType(),  True),
    StructField("phase",        StringType(),  True),
    StructField("oddsRatio",    DoubleType(),  True),
    StructField("pValue",       DoubleType(),  True),
    StructField("lowerInterval",DoubleType(),  True),
    StructField("upperInterval",DoubleType(),  True),
    StructField("total",        StringType(),  True),
    StructField("values",       ArrayType(ArrayType(IntegerType())), True),
    StructField("relSuccess",   DoubleType(),  True),
    StructField("rsLower",      DoubleType(),  True),
    StructField("rsUpper",      DoubleType(),  True),
    StructField("path",         StringType(),  True),
])

def _relative_success(matrix_2x2: np.ndarray):
    a, b = matrix_2x2[0,0], matrix_2x2[0,1]
    c, d = matrix_2x2[1,0], matrix_2x2[1,1]
    rate_yes = a / (a + b) if (a + b) > 0 else 0.0
    rate_no  = c / (c + d) if (c + d) > 0 else 0.0
    rs = rate_yes - rate_no
    import math
    se = 0.0
    if (a+b) > 0:
        se += rate_yes * (1 - rate_yes) / (a + b)
    if (c+d) > 0:
        se += rate_no  * (1 - rate_no)  / (c + d)
    se = math.sqrt(se)
    lo, hi = rs - 1.96*se, rs + 1.96*se
    return float(rs), float(lo), float(hi)

def fisher_by_group(pdf: pd.DataFrame) -> pd.DataFrame:
    # pdf has rows for both comparison=='yes' and comparison=='no' (may be missing one)
    sub = pdf[["comparison","yes","no"]].copy()
    sub = sub.set_index("comparison").reindex(["yes","no"]).fillna(0)
    mat = sub[["yes","no"]].to_numpy(dtype=int)

    total = int(mat.sum())
    or_val, p_val = fisher_exact(mat, alternative="two-sided")
    ci = odds_ratio(mat).confidence_interval(0.95)
    rs, rs_lo, rs_hi = _relative_success(mat)

    out = pd.DataFrame([{
        "group":        pdf["metric"].iloc[0],
        "comparison":   f"{pdf['value'].iloc[0]}_only",
        "phase":        pdf["phase_name"].iloc[0],
        "oddsRatio":    round(float(or_val), 2),
        "pValue":       float(p_val),
        "lowerInterval":round(float(ci[0]), 2),
        "upperInterval":round(float(ci[1]), 2),
        "total":        str(total),
        "values":       mat.tolist(),
        "relSuccess":   round(float(rs), 2),
        "rsLower":      round(float(rs_lo), 2),
        "rsUpper":      round(float(rs_hi), 2),
        "path":         "",
    }])
    return out

# (optional) Arrow for speed
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# ✅ Use applyInPandas with the explicit schema
results_df = (
    mat_counts
    .groupBy("metric","feature","value","phase_name")
    .applyInPandas(fisher_by_group, schema=result_schema)
)


# ============================
# Fisher + reporting (applyInPandas version)
# ============================
import pandas as pd
import numpy as np
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType, IntegerType, ArrayType
)
from datetime import date
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio

# 0) Ensure 'yes'/'no' numeric columns exist and are ints (no NaN)
mat_counts = (
    mat_counts
    .fillna(0)
    .withColumn("yes", F.coalesce(F.col("yes"), F.lit(0)).cast("int"))
    .withColumn("no",  F.coalesce(F.col("no"),  F.lit(0)).cast("int"))
)

# 1) Output schema (mirrors your previous one)
result_schema = StructType([
    StructField("group",        StringType(),  True),  # metric
    StructField("comparison",   StringType(),  True),  # e.g. "<value>_only"
    StructField("phase",        StringType(),  True),  # phase_name
    StructField("oddsRatio",    DoubleType(),  True),
    StructField("pValue",       DoubleType(),  True),
    StructField("lowerInterval",DoubleType(),  True),
    StructField("upperInterval",DoubleType(),  True),
    StructField("total",        StringType(),  True),
    StructField("values",       ArrayType(ArrayType(IntegerType())), True),
    StructField("relSuccess",   DoubleType(),  True),
    StructField("rsLower",      DoubleType(),  True),
    StructField("rsUpper",      DoubleType(),  True),
    StructField("path",         StringType(),  True),
])

# 2) Relative success helper (vectorized inside pandas func)
def _relative_success(matrix_2x2: np.ndarray):
    # rows: comparison yes/no ; cols: prediction yes/no
    a, b = matrix_2x2[0,0], matrix_2x2[0,1]
    c, d = matrix_2x2[1,0], matrix_2x2[1,1]
    rate_yes = a / (a + b) if (a + b) > 0 else 0.0
    rate_no  = c / (c + d) if (c + d) > 0 else 0.0
    rs = rate_yes - rate_no
    import math
    se = 0.0
    if (a+b) > 0:
        se += rate_yes * (1 - rate_yes) / (a + b)
    if (c+d) > 0:
        se += rate_no  * (1 - rate_no)  / (c + d)
    se = math.sqrt(se)
    lo, hi = rs - 1.96*se, rs + 1.96*se
    return float(rs), float(lo), float(hi)

# 3) Plain Python function for applyInPandas (no decorator)
def fisher_by_group(pdf: pd.DataFrame) -> pd.DataFrame:
    # pdf columns: metric, feature, value, phase_name, comparison, yes, no
    sub = pdf[["comparison","yes","no"]].copy()
    # enforce both rows 'yes' and 'no'
    sub = sub.set_index("comparison").reindex(["yes","no"]).fillna(0)
    mat = sub[["yes","no"]].to_numpy(dtype=int)

    total = int(mat.sum())
    or_val, p_val = fisher_exact(mat, alternative="two-sided")
    ci = odds_ratio(mat).confidence_interval(0.95)
    rs, rs_lo, rs_hi = _relative_success(mat)

    return pd.DataFrame([{
        "group":        pdf["metric"].iloc[0],
        "comparison":   f"{pdf['value'].iloc[0]}_only",
        "phase":        pdf["phase_name"].iloc[0],
        "oddsRatio":    round(float(or_val), 2),
        "pValue":       float(p_val),
        "lowerInterval":round(float(ci[0]), 2),
        "upperInterval":round(float(ci[1]), 2),
        "total":        str(total),
        "values":       mat.tolist(),
        "relSuccess":   round(float(rs), 2),
        "rsLower":      round(float(rs_lo), 2),
        "rsUpper":      round(float(rs_hi), 2),
        "path":         "",
    }])

# (optional) Arrow for speed
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

# 4) Apply per (metric, feature, value, phase_name)
results_df = (
    mat_counts
    .groupBy("metric","feature","value","phase_name")
    .applyInPandas(fisher_by_group, schema=result_schema)
)

# 5) Formatting + annotation + CSV export (unchanged logic)
from itertools import chain
from pyspark.sql.functions import create_map

# disdic from agg_once, if not already present
disdic = {r["value"]: r["feature"] for r in agg_once.select("feature","value").distinct().collect()}

patterns = ["_only", "_isRightTissueSignalAgreed"]
regex_pattern = "(" + "|".join(patterns) + ")"

df_fmt = (
    spreadSheetFormatter(results_df)
    .withColumn("prefix", F.regexp_replace(F.col("comparison"), regex_pattern + ".*", ""))
    .withColumn("suffix", F.regexp_extract(F.col("comparison"), regex_pattern, 0))
)

mapping_expr = create_map([F.lit(x) for x in chain(*disdic.items())])
df_annot = df_fmt.withColumn("annotation", mapping_expr.getItem(F.col("prefix")))

today_date = date.today().isoformat()
out_csv = f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue.csv"
df_annot.toPandas().to_csv(out_csv, index=False)
print(f"Analysis written: {out_csv}")


                                                                                

importing functions
imported functions


25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_167 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_125 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_101 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_112 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_115 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_34 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_138_98 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_49 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_128 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_59 !
25/09/10 10:23:11 WARN BlockManagerMasterEndpoint: 

Analysis written: gs://ot-team/jroldan/analysis/2025-09-10_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue.csv


In [6]:
# ============================
# Spark-friendly ANALYSIS
# ============================
from datetime import date
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType, IntegerType, ArrayType
)
from pyspark.sql import Window

# ---- 1) Build the comparison × prediction long table in one go ----
# metrics we want to analyze (these replace columns_to_aggregate)
metric_flags = ["NoneCellYes_flag", "NdiagonalYes_flag", "hasGenetics_flag"]

# Melt the phase flags into (phase_name, prediction) rows
phases_long = (
    agg_once.select(
        "targetId","diseaseId","maxClinPhase","feature","value",
        F.expr("stack(5, "
               "'Phase>=4', `Phase>=4`, "
               "'Phase>=3', `Phase>=3`, "
               "'Phase>=2', `Phase>=2`, "
               "'Phase>=1', `Phase>=1`, "
               "'PhaseT',  `PhaseT`"
               ")").alias("phase_name", "prediction")
    )
    .filter(F.col("prediction").isNotNull())
)

# For each metric, attach its comparison flag ("yes"/"no") and union them all
def attach_metric(metric_col: str):
    return (
        agg_once.select("targetId","diseaseId","maxClinPhase","feature","value", F.col(metric_col).alias("comparison"))
                .join(phases_long.select("targetId","diseaseId","maxClinPhase","feature","value","phase_name","prediction"),
                      on=["targetId","diseaseId","maxClinPhase","feature","value"],
                      how="inner")
                .withColumn("metric", F.lit(metric_col.replace("_flag", "")))  # pretty label
    )

analysis_long = attach_metric(metric_flags[0])
for mc in metric_flags[1:]:
    analysis_long = analysis_long.unionByName(attach_metric(mc))

# Now analysis_long has rows like:
# (targetId, diseaseId, feature, value, metric, phase_name, comparison='yes'/'no', prediction='yes'/'no')

# ---- 2) Count ALL 2×2 cells in one aggregation ----
cell_counts = (
    analysis_long
    .groupBy("metric","feature","value","phase_name","comparison","prediction")
    .agg(F.count("*").alias("a"))
)

# ---- 3) Reshape to 2×2 matrices per (metric, feature, value, phase) ----
# Create a compact 2-row frame with columns yes/no for prediction
mat_counts = (
    cell_counts
    .groupBy("metric","feature","value","phase_name","comparison")
    .pivot("prediction", ["yes","no"])
    .agg(F.first("a"))
    .fillna(0)
)

# We'll compute Fisher per group using a grouped map Pandas UDF
import pandas as pd
import numpy as np
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio
from pyspark.sql.functions import pandas_udf

# Output schema mirrors your previous 'schema'
result_schema = StructType([
    StructField("group",        StringType(),  True),  # metric name
    StructField("comparison",   StringType(),  True),  # e.g., "<value>_only"
    StructField("phase",        StringType(),  True),  # phase_name
    StructField("oddsRatio",    DoubleType(),  True),
    StructField("pValue",       DoubleType(),  True),
    StructField("lowerInterval",DoubleType(),  True),
    StructField("upperInterval",DoubleType(),  True),
    StructField("total",        StringType(),  True),
    StructField("values",       ArrayType(ArrayType(IntegerType())), True),
    StructField("relSuccess",   DoubleType(),  True),
    StructField("rsLower",      DoubleType(),  True),
    StructField("rsUpper",      DoubleType(),  True),
    StructField("path",         StringType(),  True),
])

# Relative success helper – reusing your function on driver won’t vectorize,
# so we replicate a simple version here. If you must use your exact math,
# import it and call it inside the pandas UDF.
def _relative_success(matrix_2x2: np.ndarray):
    # matrix rows: comparison yes/no; cols: prediction yes/no
    # success rate in comparison==yes vs comparison==no
    a, b = matrix_2x2[0,0], matrix_2x2[0,1]
    c, d = matrix_2x2[1,0], matrix_2x2[1,1]
    rate_yes = a / (a + b) if (a + b) > 0 else 0.0
    rate_no  = c / (c + d) if (c + d) > 0 else 0.0
    rs = rate_yes - rate_no
    # crude Wald CI for difference in proportions (you can swap with your exact function)
    import math
    se = 0.0
    if (a+b) > 0:
        se += rate_yes * (1 - rate_yes) / (a + b)
    if (c+d) > 0:
        se += rate_no  * (1 - rate_no)  / (c + d)
    se = math.sqrt(se)
    lo, hi = rs - 1.96*se, rs + 1.96*se
    return float(rs), float(lo), float(hi)

@pandas_udf(result_schema)
def fisher_by_group(pdf: pd.DataFrame) -> pd.DataFrame:
    # pdf has columns: metric, feature, value, phase_name, comparison, yes, no
    # build 2x2 with rows ordered comparison=['yes','no'] and cols ['yes','no']
    # Ensure both rows exist; fill missing with zeros
    sub = pdf[["comparison","yes","no"]].copy()
    sub = sub.set_index("comparison").reindex(["yes","no"]).fillna(0)
    mat = sub[["yes","no"]].to_numpy(dtype=int)

    total = int(mat.sum())
    or_val, p_val = fisher_exact(mat, alternative="two-sided")
    ci = odds_ratio(mat).confidence_interval(0.95)
    rs, rs_lo, rs_hi = _relative_success(mat)

    row = {
        "group":        pdf["metric"].iloc[0],
        "comparison":   f"{pdf['value'].iloc[0]}_only",  # to match your previous naming
        "phase":        pdf["phase_name"].iloc[0],
        "oddsRatio":    float(round(or_val, 2)),
        "pValue":       float(p_val),
        "lowerInterval":float(round(ci[0], 2)),
        "upperInterval":float(round(ci[1], 2)),
        "total":        str(total),
        "values":       mat.tolist(),
        "relSuccess":   float(round(rs, 2)),
        "rsLower":      float(round(rs_lo, 2)),
        "rsUpper":      float(round(rs_hi, 2)),
        "path":         "",   # you can fill a path pattern here if you still write per-combo parquet
    }
    return pd.DataFrame([row])

# Apply the grouped map per (metric, feature, value, phase)
results_df = (
    mat_counts
    .groupBy("metric","feature","value","phase_name")
    .apply(fisher_by_group)
)

# ---- 4) Optional spreadsheet formatting + annotate and export ----
from itertools import chain
from pyspark.sql.functions import create_map

# Your disdic: {value -> feature}; rebuild safely from agg_once if not present
disdic = {r["value"]: r["feature"] for r in agg_once.select("feature","value").distinct().collect()}

# If you still want to keep the regex-based prefix/suffix (works with our "value_only" naming):
patterns = ["_only", "_isRightTissueSignalAgreed"]
regex_pattern = "(" + "|".join(patterns) + ")"

df_fmt = (
    spreadSheetFormatter(results_df)  # reuse your helper
    .withColumn("prefix", F.regexp_replace(F.col("comparison"), regex_pattern + ".*", ""))
    .withColumn("suffix", F.regexp_extract(F.col("comparison"), regex_pattern, 0))
)

mapping_expr = create_map([F.lit(x) for x in chain(*disdic.items())])
df_annot = df_fmt.withColumn("annotation", mapping_expr.getItem(F.col("prefix")))

today_date = date.today().isoformat()
out_csv = f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue.csv"
df_annot.toPandas().to_csv(out_csv, index=False)
print(f"Analysis written: {out_csv}")


ValueError: Invalid udf: the udf argument must be a pandas_udf of type GROUPED_MAP.

### second try

In [None]:
# -*- coding: utf-8 -*-
# Single-script, loop-free PySpark job (tall/unpivot + single aggregation)

import os
from datetime import date
from functools import reduce

from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType, DoubleType, ArrayType
)

# Your helpers
from functions import (
    relative_success,
    spreadSheetFormatter,
    discrepancifier,
    temporary_directionOfEffect,
    buildColocData,
    gwasDataset,
)
from DoEAssessment import directionOfEffect  # noqa: F401  (kept if you need it later)

# -------------------------------
# Spark / YARN resource settings (Single-Node Option A)
# -------------------------------
driver_memory = "12g"                 # string with unit
executor_memory = "40g"               # string with unit (heap)
executor_cores = 10                   # int
num_executors = 1                     # int (one fat executor on single node)
executor_memory_overhead = "6g"       # string with unit (PySpark/Arrow/off-heap)
shuffle_partitions = 128              # int (~2–3x cores)
default_parallelism = 128             # int (match shuffle_partitions)

# If you later move to a multi-worker cluster, replace the values above.

spark = (
    SparkSession.builder
    .appName("MyOptimizedPySparkApp")
    .config("spark.master", "yarn")
    # core resources
    .config("spark.driver.memory", driver_memory)
    .config("spark.executor.memory", executor_memory)
    .config("spark.executor.cores", executor_cores)
    .config("spark.executor.instances", num_executors)
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead)
    # shuffle & parallelism
    .config("spark.sql.shuffle.partitions", shuffle_partitions)
    .config("spark.default.parallelism", default_parallelism)
    # adaptive query execution for better skew/partition sizing
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
)

print("SparkSession created with:")
for k in [
    "spark.driver.memory",
    "spark.executor.memory",
    "spark.executor.cores",
    "spark.executor.instances",
    "spark.yarn.executor.memoryOverhead",
    "spark.sql.shuffle.partitions",
    "spark.default.parallelism",
    "spark.sql.adaptive.enabled",
    "spark.sql.adaptive.coalescePartitions.enabled",
]:
    print(f"  {k}: {spark.conf.get(k)}")
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")



'''
# -------------------------------
# Spark / YARN resource settings
# -------------------------------
driver_memory = "16g"
executor_memory = "32g"
executor_cores = "8"
num_executors = "16"
executor_memory_overhead = "8g"
shuffle_partitions = "150"
default_parallelism = str(int(executor_cores) * int(num_executors) * 2)  # 80

spark = (
    SparkSession.builder
    .appName("MyOptimizedPySparkApp")
    .config("spark.master", "yarn")
    .config("spark.driver.memory", driver_memory)
    .config("spark.executor.memory", executor_memory)
    .config("spark.executor.cores", executor_cores)
    .config("spark.executor.instances", num_executors)
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead)
    .config("spark.sql.shuffle.partitions", shuffle_partitions)
    .config("spark.default.parallelism", default_parallelism)
    .getOrCreate()
)

print("SparkSession created with:")
for k in [
    "spark.driver.memory",
    "spark.executor.memory",
    "spark.executor.cores",
    "spark.executor.instances",
    "spark.yarn.executor.memoryOverhead",
    "spark.sql.shuffle.partitions",
    "spark.default.parallelism",
]:
    print(f"  {k}: {spark.conf.get(k)}")
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")
'''
# --------------------------------
# 0) Load inputs
# --------------------------------
path_n = "gs://open-targets-data-releases/25.06/output/"

target = spark.read.parquet(f"{path_n}target/")
diseases = spark.read.parquet(f"{path_n}disease/")
evidences = spark.read.parquet(f"{path_n}evidence")
credible = spark.read.parquet(f"{path_n}credible_set")
new = spark.read.parquet(f"{path_n}colocalisation_coloc")
index = spark.read.parquet(f"{path_n}study/")
variantIndex = spark.read.parquet(f"{path_n}variant")
biosample = spark.read.parquet(f"{path_n}biosample")
ecaviar = spark.read.parquet(f"{path_n}colocalisation_ecaviar")
all_coloc = ecaviar.unionByName(new, allowMissingColumns=True)
print("Loaded all base tables.")

# --------------------------------
# 1) Build coloc + GWAS dataset
# --------------------------------
newColoc = buildColocData(all_coloc, credible, index)
print("Built newColoc")

gwasComplete = gwasDataset(evidences, credible)
print("Built gwasComplete")

resolvedColoc = (
    newColoc.withColumnRenamed("geneId", "targetId")
    .join(
        gwasComplete.withColumnRenamed("studyLocusId", "leftStudyLocusId"),
        on=["leftStudyLocusId", "targetId"],
        how="inner",
    )
    .join(
        diseases.selectExpr("id as diseaseId", "name", "parents", "therapeuticAreas"),
        on="diseaseId",
        how="left",
    )
    .withColumn(
        "diseaseId",
        F.explode_outer(F.concat(F.array(F.col("diseaseId")), F.col("parents"))),
    )
    .drop("parents", "oldDiseaseId")
    .withColumn(
        "colocDoE",
        F.when(
            F.col("rightStudyType").isin(["eqtl", "pqtl", "tuqtl", "sceqtl", "sctuqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_protect"))
        ).when(
            F.col("rightStudyType").isin(["sqtl", "scsqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_protect"))
        ),
    )
)
print("Built resolvedColoc")

# --------------------------------
# 2) Direction of Effect & ChEMBL indication
# --------------------------------
datasource_filter = [
    "gwas_credible_sets",
    "gene_burden",
    "eva",
    "eva_somatic",
    "gene2phenotype",
    "orphanet",
    "cancer_gene_census",
    "intogen",
    "impc",
    "chembl",
]
assessment, evidences, actionType_unused, oncolabel_unused = temporary_directionOfEffect(path_n, datasource_filter)
print("Built temporary DoE datasets")

# (Optional) Add MoA to ChEMBL paths as in your later code
mecact_path = f"{path_n}drug_mechanism_of_action/"
mecact = spark.read.parquet(mecact_path)
actionType = (
    mecact.select(
        F.explode_outer("chemblIds").alias("drugId"),
        "actionType",
        "mechanismOfAction",
        "targets",
    )
    .select(
        F.explode_outer("targets").alias("targetId"),
        "drugId",
        "actionType",
        "mechanismOfAction",
    )
    .groupBy("targetId", "drugId")
    .agg(F.collect_set("actionType").alias("actionType2"))
    .withColumn("nMoA", F.size(F.col("actionType2")))
)

analysis_chembl_indication = (
    discrepancifier(
        assessment.filter(F.col("datasourceId") == "chembl")
        .join(actionType, on=["targetId", "drugId"], how="left")
        .withColumn(
            "maxClinPhase",
            F.max("clinicalPhase").over(Window.partitionBy("targetId", "diseaseId")),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase", "actionType2")
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    .drop("coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk")
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
)
print("Built analysis_chembl_indication")

# --------------------------------
# 3) Benchmark (filtered coloc) + clinical phase flags
# --------------------------------
resolvedColocFiltered = resolvedColoc.filter((F.col("clpp") >= 0.01) | (F.col("h4") >= 0.8))

negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .groupBy("targetId", "diseaseId").count()
    .withColumn("stopReason", F.lit("Negative")).drop("count")
)
benchmark = (
    resolvedColocFiltered.filter(F.col("name") != "COVID-19")
    .join(analysis_chembl_indication, on=["targetId", "diseaseId"], how="right")
    .withColumn(
        "AgreeDrug",
        F.when((F.col("drugGoF_protect").isNotNull()) & (F.col("colocDoE") == "GoF_protect"), "yes")
        .when((F.col("drugLoF_protect").isNotNull()) & (F.col("colocDoE") == "LoF_protect"), "yes")
        .otherwise("no"),
    )
    .join(biosample.select("biosampleId", "biosampleName"), on="biosampleId", how="left")
)

benchmark = (
    benchmark.join(F.broadcast(negativeTD), on=["targetId", "diseaseId"], how="left")
    .withColumn("PhaseT", F.when(F.col("stopReason") == "Negative", "yes").otherwise("no"))
    .withColumn("Phase>=4", F.when((F.col("maxClinPhase") == 4) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=3", F.when((F.col("maxClinPhase") >= 3) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=2", F.when((F.col("maxClinPhase") >= 2) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=1", F.when((F.col("maxClinPhase") >= 1) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
)

# --------------------------------
# 4) Replace nested loops:
#     compute DoE counts once → derive flags → unpivot → single aggregation
# --------------------------------
doe_cols = ["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]

# counts per colocDoE over the grouping you previously used in the loop
group_keys = [
    "targetId", "diseaseId", "maxClinPhase",
    "actionType2", "biosampleName", "projectId", "rightStudyType", "colocalisationMethod"
]

doe_counts = (
    benchmark.groupBy(*group_keys)
    .agg(*[F.sum(F.when(F.col("colocDoE") == c, 1).otherwise(0)).alias(c) for c in doe_cols])
)

# max name(s) (in case of ties) without arrays of structs
greatest_count = F.greatest(*[F.col(c) for c in doe_cols])
max_names = F.filter(
    F.array(*[F.when(F.col(c) == greatest_count, F.lit(c)) for c in doe_cols]),
    lambda x: x.isNotNull()
)

# presence of drug-side signals (equivalent to *_ch presence in your loop path)
has_lof_ch = F.col("drugLoF_protect").isNotNull()
has_gof_ch = F.col("drugGoF_protect").isNotNull()

test2 = (
    benchmark.select(*group_keys, "drugLoF_protect", "drugGoF_protect")
    .join(doe_counts, on=group_keys, how="left")
    .withColumn("NoneCellYes",
        F.when(has_lof_ch & (~has_gof_ch) & F.array_contains(max_names, F.lit("LoF_protect")), "yes")
         .when(has_gof_ch & (~has_lof_ch) & F.array_contains(max_names, F.lit("GoF_protect")), "yes")
         .otherwise("no")
    )
    .withColumn("NdiagonalYes",
        F.when(has_lof_ch & (~has_gof_ch) & (F.array_contains(max_names, F.lit("LoF_protect")) | F.array_contains(max_names, F.lit("GoF_risk"))), "yes")
         .when(has_gof_ch & (~has_lof_ch) & (F.array_contains(max_names, F.lit("GoF_protect")) | F.array_contains(max_names, F.lit("LoF_risk"))), "yes")
         .otherwise("no")
    )
    .withColumn("drugCoherency",
        F.when(has_lof_ch & ~has_gof_ch, "coherent")
         .when(~has_lof_ch & has_gof_ch, "coherent")
         .when(has_lof_ch & has_gof_ch, "dispar")
         .otherwise("other")
    ).withColumn(
    "hasGenetics2",
    F.when(
        reduce(lambda acc, c: acc & F.col(c).isNull(), doe_cols[1:], F.col(doe_cols[0]).isNull()),
        F.lit("no")
    ).otherwise(F.lit("yes"))
)
    .withColumn("hasGenetics", F.when(F.col("NdiagonalYes").isNotNull(), "yes").otherwise("no")) #### we have to change it
)
test2.persist()

# ---------- Guard: (re)build agg_once if not defined ----------
import pyspark.sql.functions as F

def _build_agg_once_from_test2_and_benchmark(test2, benchmark_df):
    # Columns we keep across all longified slices
    common_cols = [
        "targetId","diseaseId","maxClinPhase",
        "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
        "NoneCellYes","NdiagonalYes","hasGenetics2"  # note: hasGenetics2 from your test2
    ]

    # Join phase flags once
    phase_flags = (
        benchmark_df.select(
            "targetId","diseaseId","maxClinPhase",
            "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT"
        ).dropDuplicates(["targetId","diseaseId","maxClinPhase"])
    )

    t2_with_phase = test2.join(
        phase_flags, on=["targetId","diseaseId","maxClinPhase"], how="left"
    )

    # actionType2 is ARRAY<STRING> → explode
    long_action = (
        t2_with_phase
        .select(*common_cols, F.explode_outer("actionType2").alias("value"))
        .withColumn("feature", F.lit("actionType2"))
        .select(*common_cols, "feature", "value")
    )

    # helper for scalar columns
    def longify_scalar(colname: str):
        return (
            t2_with_phase
            .select(*common_cols, F.col(colname).alias("value"))
            .withColumn("feature", F.lit(colname))
            .select(*common_cols, "feature", "value")
        )

    long_biosample = longify_scalar("biosampleName")
    long_project   = longify_scalar("projectId")
    long_rstype    = longify_scalar("rightStudyType")
    long_colocm    = longify_scalar("colocalisationMethod")

    # union into one tall table
    long_features = (
        long_action
        .unionByName(long_biosample)
        .unionByName(long_project)
        .unionByName(long_rstype)
        .unionByName(long_colocm)
    ).filter(F.col("value").isNotNull())

    # single aggregation to compute flags
    agg_once_local = (
        long_features
        .groupBy(
            "targetId","diseaseId","maxClinPhase",
            "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
            "feature","value"
        )
        .agg(
            F.max(F.when(F.col("NoneCellYes")=="yes", 1).otherwise(0)).alias("NoneCellYes"),
            F.max(F.when(F.col("NdiagonalYes")=="yes", 1).otherwise(0)).alias("NdiagonalYes"),
            F.max(F.when(F.col("hasGenetics2")=="yes", 1).otherwise(0)).alias("hasGenetics"),
        )
        .selectExpr(
            "*",
            "CASE WHEN NoneCellYes=1 THEN 'yes' ELSE 'no' END as NoneCellYes_flag",
            "CASE WHEN NdiagonalYes=1 THEN 'yes' ELSE 'no' END as NdiagonalYes_flag",
            "CASE WHEN hasGenetics=1 THEN 'yes' ELSE 'no' END as hasGenetics_flag"
        )
    )
    return agg_once_local

if 'agg_once' not in globals():
    print("[info] agg_once not found — rebuilding it from test2/benchmark …")
    agg_once = _build_agg_once_from_test2_and_benchmark(test2, benchmark)
    print("[info] agg_once rebuilt.")


# ============================
# Denominator = ALL pairs in analysis_chembl_indication (deduped)
# Build 2x2 counts using totals, then Fisher via applyInPandas
# ============================
from datetime import date
import pandas as pd
import numpy as np
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType, IntegerType, ArrayType
)
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio

# ---- 0) Universe of pairs & phase flags (only de-dup, no other filtering)
universe = (
    analysis_chembl_indication
    .select("targetId", "diseaseId", "maxClinPhase")  # dedupe on these
    .distinct()
    .join(F.broadcast(negativeTD), on=["targetId","diseaseId"], how="left")
    .withColumn("PhaseT", F.when(F.col("stopReason")=="Negative", "yes").otherwise("no"))
    .withColumn("Phase>=4", F.when((F.col("maxClinPhase")==4) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=3", F.when((F.col("maxClinPhase")>=3) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=2", F.when((F.col("maxClinPhase")>=2) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=1", F.when((F.col("maxClinPhase")>=1) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
)
print('universe of pairs and phase flags built')

# Long view of phase flags for universe
phases_universe_long = universe.select(
    "targetId","diseaseId",
    F.expr("stack(5, "
           "'Phase>=4', `Phase>=4`, "
           "'Phase>=3', `Phase>=3`, "
           "'Phase>=2', `Phase>=2`, "
           "'Phase>=1', `Phase>=1`, "
           "'PhaseT',  `PhaseT`"
           ")").alias("phase_name","prediction")
)
print('phase_universe_long built')

# Totals per phase (denominator totals)
total_pairs_by_phase = (
    phases_universe_long
    .groupBy("phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("total_pairs"))
)
total_pred_yes_by_phase = (
    phases_universe_long
    .filter(F.col("prediction")=="yes")
    .groupBy("phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("total_pred_yes"))
)
print('phase_universe_long built')

# ---- 1) Build analysis_long from agg_once (flags) + phases (prediction)
# metrics we’ll analyze
metric_flags = ["NoneCellYes_flag", "NdiagonalYes_flag", "hasGenetics_flag"]

# phase flags per (target,disease,maxClinPhase)
phase_flags = (
    benchmark.select("targetId","diseaseId","maxClinPhase","Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT")
    .dropDuplicates(["targetId","diseaseId","maxClinPhase"])
)

# stack phases for the records present in agg_once (feature,value specific)
phases_long_for_records = (
    phase_flags.join(agg_once.select("targetId","diseaseId","maxClinPhase").dropDuplicates(),
                     on=["targetId","diseaseId","maxClinPhase"], how="inner")
    .select(
        "targetId","diseaseId","maxClinPhase",
        F.expr("stack(5, "
               "'Phase>=4', `Phase>=4`, "
               "'Phase>=3', `Phase>=3`, "
               "'Phase>=2', `Phase>=2`, "
               "'Phase>=1', `Phase>=1`, "
               "'PhaseT',  `PhaseT`"
               ")").alias("phase_name","prediction")
    )
)

def attach_metric(metric_col: str):
    # comparison = metric flag yes/no at (target,disease,feature,value)
    return (
        agg_once.select("targetId","diseaseId","maxClinPhase","feature","value",
                        F.col(metric_col).alias("comparison"))
        .join(phases_long_for_records, on=["targetId","diseaseId","maxClinPhase"], how="inner")
        .withColumn("metric", F.lit(metric_col.replace("_flag","")))  # prettier label
    )

analysis_long = attach_metric(metric_flags[0])
for mc in metric_flags[1:]:
    analysis_long = analysis_long.unionByName(attach_metric(mc))

# ---- 2) Count distinct pairs for 2x2 components using the fixed universe
# a = count of pairs with comparison=='yes' AND prediction=='yes'
yes_yes = (
    analysis_long
    .filter((F.col("comparison")=="yes") & (F.col("prediction")=="yes"))
    .groupBy("metric","feature","value","phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("a"))
)
# yes_total = count of pairs with comparison=='yes' (regardless of prediction)
yes_total = (
    analysis_long
    .filter(F.col("comparison")=="yes")
    .groupBy("metric","feature","value","phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("yes_total"))
)

# Assemble b,c,d from totals
counts = (
    yes_total
    .join(yes_yes, on=["metric","feature","value","phase_name"], how="left")
    .join(total_pairs_by_phase, on="phase_name", how="left")
    .join(total_pred_yes_by_phase, on="phase_name", how="left")
    .na.fill({"a":0})
    .withColumn("b", F.col("yes_total") - F.col("a"))
    .withColumn("c", F.col("total_pred_yes") - F.col("a"))
    .withColumn("d", F.col("total_pairs") - F.col("a") - F.col("b") - F.col("c"))
    .select(
        "metric","feature","value","phase_name",
        F.when(F.col("a")<0,0).otherwise(F.col("a")).cast("int").alias("a"),
        F.when(F.col("b")<0,0).otherwise(F.col("b")).cast("int").alias("b"),
        F.when(F.col("c")<0,0).otherwise(F.col("c")).cast("int").alias("c"),
        F.when(F.col("d")<0,0).otherwise(F.col("d")).cast("int").alias("d"),
        "total_pairs","total_pred_yes"
    )
)

# Convert to two-row format (comparison yes/no) with columns yes/no → ready for Fisher
mat_counts = (
    counts
    .select("metric","feature","value","phase_name",
            F.lit("yes").alias("comparison"),
            F.col("a").alias("yes"),
            F.col("b").alias("no"))
    .unionByName(
        counts.select("metric","feature","value","phase_name",
                      F.lit("no").alias("comparison"),
                      F.col("c").alias("yes"),
                      F.col("d").alias("no"))
    )
)

# Safety: ensure ints and no nulls
mat_counts = (
    mat_counts.fillna(0)
              .withColumn("yes", F.col("yes").cast("int"))
              .withColumn("no",  F.col("no").cast("int"))
)

# ---- 3) Fisher per group with applyInPandas
result_schema = StructType([
    StructField("group",        StringType(),  True),
    StructField("comparison",   StringType(),  True),
    StructField("phase",        StringType(),  True),
    StructField("oddsRatio",    DoubleType(),  True),
    StructField("pValue",       DoubleType(),  True),
    StructField("lowerInterval",DoubleType(),  True),
    StructField("upperInterval",DoubleType(),  True),
    StructField("total",        StringType(),  True),
    StructField("values",       ArrayType(ArrayType(IntegerType())), True),
    StructField("relSuccess",   DoubleType(),  True),
    StructField("rsLower",      DoubleType(),  True),
    StructField("rsUpper",      DoubleType(),  True),
    StructField("path",         StringType(),  True),
])

def _relative_success(matrix_2x2: np.ndarray):
    a, b = matrix_2x2[0,0], matrix_2x2[0,1]
    c, d = matrix_2x2[1,0], matrix_2x2[1,1]
    rate_yes = a / (a + b) if (a + b) > 0 else 0.0
    rate_no  = c / (c + d) if (c + d) > 0 else 0.0
    rs = rate_yes - rate_no
    import math
    se = 0.0
    if (a+b) > 0:
        se += rate_yes * (1 - rate_yes) / (a + b)
    if (c+d) > 0:
        se += rate_no  * (1 - rate_no)  / (c + d)
    se = math.sqrt(se)
    lo, hi = rs - 1.96*se, rs + 1.96*se
    return float(rs), float(lo), float(hi)

def fisher_by_group(pdf: pd.DataFrame) -> pd.DataFrame:
    sub = pdf[["comparison","yes","no"]].copy()
    sub = sub.set_index("comparison").reindex(["yes","no"]).fillna(0)
    mat = sub[["yes","no"]].to_numpy(dtype=int)

    total = int(mat.sum())
    if total == 0:
        # Return a neutral row to avoid SciPy errors on empty tables
        return pd.DataFrame([{
            "group":        pdf["metric"].iloc[0],
            "comparison":   f"{pdf['value'].iloc[0]}_only",
            "phase":        pdf["phase_name"].iloc[0],
            "oddsRatio":    float("nan"),
            "pValue":       float("nan"),
            "lowerInterval":float("nan"),
            "upperInterval":float("nan"),
            "total":        "0",
            "values":       mat.tolist(),
            "relSuccess":   float("nan"),
            "rsLower":      float("nan"),
            "rsUpper":      float("nan"),
            "path":         "",
        }])

    from scipy.stats import fisher_exact
    from scipy.stats.contingency import odds_ratio

    or_val, p_val = fisher_exact(mat, alternative="two-sided")
    ci = odds_ratio(mat).confidence_interval(0.95)
    rs, rs_lo, rs_hi = _relative_success(mat)

    return pd.DataFrame([{
        "group":        pdf["metric"].iloc[0],
        "comparison":   f"{pdf['value'].iloc[0]}_only",
        "phase":        pdf["phase_name"].iloc[0],
        "oddsRatio":    round(float(or_val), 2),
        "pValue":       float(p_val),
        "lowerInterval":round(float(ci[0]), 2),
        "upperInterval":round(float(ci[1]), 2),
        "total":        str(total),
        "values":       mat.tolist(),
        "relSuccess":   round(float(rs), 2),
        "rsLower":      round(float(rs_lo), 2),
        "rsUpper":      round(float(rs_hi), 2),
        "path":         "",
    }])

# (optional) Arrow for speed
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

results_df = (
    mat_counts
    .groupBy("metric","feature","value","phase_name")
    .applyInPandas(fisher_by_group, schema=result_schema)
)

# ---- 4) Spreadsheet formatting + annotation + CSV
from itertools import chain
from pyspark.sql.functions import create_map

# build disdic from agg_once
disdic = {r["value"]: r["feature"] for r in agg_once.select("feature","value").distinct().collect()}

patterns = ["_only", "_isRightTissueSignalAgreed"]
regex_pattern = "(" + "|".join(patterns) + ")"

df_fmt = (
    spreadSheetFormatter(results_df)
    .withColumn("prefix", F.regexp_replace(F.col("comparison"), regex_pattern + ".*", ""))
    .withColumn("suffix", F.regexp_extract(F.col("comparison"), regex_pattern, 0))
)

mapping_expr = create_map([F.lit(x) for x in chain(*disdic.items())])
df_annot = df_fmt.withColumn("annotation", mapping_expr.getItem(F.col("prefix")))

#today_date = date.today().isoformat()
#out_csv = f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try.csv"
#df_annot.toPandas().to_csv(out_csv, index=False)
#print(f"Analysis written: {out_csv}")
today_date = date.today().isoformat()

(out_path, coalesce_n) = (f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try", 1)
print('preparing df_annot to write')
(df_annot
  .coalesce(coalesce_n)                 # 1 file if feasible; increase if OOM on shuffle
  .write.mode("overwrite")
  .option("header", "true")
  .csv(out_path))

print(f"Wrote CSV shards to: {out_path}")


spark session created at 2025-09-17 14:54:25.328103
Analysis started on 2025-09-17 at  2025-09-17 14:54:25.328103


25/09/17 14:54:27 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
25/09/17 14:54:27 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


SparkSession created with:
  spark.driver.memory: 12g
  spark.executor.memory: 40g
  spark.executor.cores: 10
  spark.executor.instances: 1
  spark.yarn.executor.memoryOverhead: 6g
  spark.sql.shuffle.partitions: 128
  spark.default.parallelism: 128
  spark.sql.adaptive.enabled: true
  spark.sql.adaptive.coalescePartitions.enabled: true
Spark UI: http://jr-temp-doe-m.c.open-targets-eu-dev.internal:42205
Loaded all base tables.
Built newColoc


                                                                                

loaded gwasComplete
Built gwasComplete
Built resolvedColoc
Built temporary DoE datasets


                                                                                

Built analysis_chembl_indication
[info] agg_once not found — rebuilding it from test2/benchmark …
[info] agg_once rebuilt.
universe of pairs and phase flags built
phase_universe_long built
phase_universe_long built


                                                                                

importing functions
imported functions




NameError: name 'today_date' is not defined

In [None]:
today_date = date.today().isoformat()

(out_path, coalesce_n) = (f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try", 1)
print('preparing df_annot to write')
(df_annot
  .coalesce(coalesce_n)                 # 1 file if feasible; increase if OOM on shuffle
  .write.mode("overwrite")
  .option("header", "true")
  .csv(out_path))

print(f"Wrote CSV shards to: {out_path}")

preparing df_annot to write


In [None]:
today_date = date.today().isoformat()
out_csv = f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try.csv"
df_annot.toPandas().to_csv(out_csv, index=False)
print(f"Analysis written: {out_csv}")


25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_125 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_215_36 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_101 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_112 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_63 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_115 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_34 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_215_97 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_138_98 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_215_13 !
25/09/17 14:50:13 WARN BlockManagerMasterEndpoint:

In [None]:
today_date = date.today().isoformat()
out_csv = f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try.csv"
df_annot.toPandas().to_csv(out_csv, index=False)
print(f"Analysis written: {out_csv}")


In [4]:
(analysis_long
    .filter(
        (F.col("metric") == "NoneCellYes") &
        (F.col("value") == "Alasoo_2018")   # before we added "_only" suffix
    ))

ConnectionRefusedError: [Errno 111] Connection refused

In [2]:
# Example: show rows for metric/group = 'NoneCellYes' AND value = 'Alasoo_2018'
check_df = (
    analysis_long
    .filter(
        (F.col("metric") == "NoneCellYes") &
        (F.col("value") == "Alasoo_2018")   # before we added "_only" suffix
    )
)

print("Count of rows that will enter Fisher 2×2 for Alasoo_2018 / NoneCellYes:", check_df.count())
check_df.show(50, truncate=False)


25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_101 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_112 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_63 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_34 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_49 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_138_98 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_53 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_138_10 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_138_31 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_59 !
25/09/10 16:24:34 WARN BlockManagerMasterEndpoint: 

Count of rows that will enter Fisher 2×2 for Alasoo_2018 / NoneCellYes: 285


25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_125 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_73 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_63 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_34 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_215_98 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_215_97 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_138_98 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_215_13 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_144_53 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_138_10 !
25/09/10 16:30:33 WARN BlockManagerMasterEndpoint: 

Py4JError: An error occurred while calling o2989.showString

### desglossing the analysis

In [3]:
# -*- coding: utf-8 -*-
# Single-script, loop-free PySpark job (tall/unpivot + single aggregation)

import os
from datetime import date
from functools import reduce

from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType, DoubleType, ArrayType
)

# Your helpers
from functions import (
    relative_success,
    spreadSheetFormatter,
    discrepancifier,
    temporary_directionOfEffect,
    buildColocData,
    gwasDataset,
)
from DoEAssessment import directionOfEffect  # noqa: F401  (kept if you need it later)

# -------------------------------
# Spark / YARN resource settings (Single-Node Option A)
# -------------------------------
driver_memory = "12g"                 # string with unit
executor_memory = "40g"               # string with unit (heap)
executor_cores = 10                   # int
num_executors = 1                     # int (one fat executor on single node)
executor_memory_overhead = "6g"       # string with unit (PySpark/Arrow/off-heap)
shuffle_partitions = 128              # int (~2–3x cores)
default_parallelism = 128             # int (match shuffle_partitions)

# If you later move to a multi-worker cluster, replace the values above.

spark = (
    SparkSession.builder
    .appName("MyOptimizedPySparkApp")
    .config("spark.master", "yarn")
    # core resources
    .config("spark.driver.memory", driver_memory)
    .config("spark.executor.memory", executor_memory)
    .config("spark.executor.cores", executor_cores)
    .config("spark.executor.instances", num_executors)
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead)
    # shuffle & parallelism
    .config("spark.sql.shuffle.partitions", shuffle_partitions)
    .config("spark.default.parallelism", default_parallelism)
    # adaptive query execution for better skew/partition sizing
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
)

print("SparkSession created with:")
for k in [
    "spark.driver.memory",
    "spark.executor.memory",
    "spark.executor.cores",
    "spark.executor.instances",
    "spark.yarn.executor.memoryOverhead",
    "spark.sql.shuffle.partitions",
    "spark.default.parallelism",
    "spark.sql.adaptive.enabled",
    "spark.sql.adaptive.coalescePartitions.enabled",
]:
    print(f"  {k}: {spark.conf.get(k)}")
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")



'''
# -------------------------------
# Spark / YARN resource settings
# -------------------------------
driver_memory = "16g"
executor_memory = "32g"
executor_cores = "8"
num_executors = "16"
executor_memory_overhead = "8g"
shuffle_partitions = "150"
default_parallelism = str(int(executor_cores) * int(num_executors) * 2)  # 80

spark = (
    SparkSession.builder
    .appName("MyOptimizedPySparkApp")
    .config("spark.master", "yarn")
    .config("spark.driver.memory", driver_memory)
    .config("spark.executor.memory", executor_memory)
    .config("spark.executor.cores", executor_cores)
    .config("spark.executor.instances", num_executors)
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead)
    .config("spark.sql.shuffle.partitions", shuffle_partitions)
    .config("spark.default.parallelism", default_parallelism)
    .getOrCreate()
)

print("SparkSession created with:")
for k in [
    "spark.driver.memory",
    "spark.executor.memory",
    "spark.executor.cores",
    "spark.executor.instances",
    "spark.yarn.executor.memoryOverhead",
    "spark.sql.shuffle.partitions",
    "spark.default.parallelism",
]:
    print(f"  {k}: {spark.conf.get(k)}")
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")
'''
# --------------------------------
# 0) Load inputs
# --------------------------------
path_n = "gs://open-targets-data-releases/25.06/output/"

target = spark.read.parquet(f"{path_n}target/")
diseases = spark.read.parquet(f"{path_n}disease/")
evidences = spark.read.parquet(f"{path_n}evidence")
credible = spark.read.parquet(f"{path_n}credible_set")
new = spark.read.parquet(f"{path_n}colocalisation_coloc")
index = spark.read.parquet(f"{path_n}study/")
variantIndex = spark.read.parquet(f"{path_n}variant")
biosample = spark.read.parquet(f"{path_n}biosample")
ecaviar = spark.read.parquet(f"{path_n}colocalisation_ecaviar")
all_coloc = ecaviar.unionByName(new, allowMissingColumns=True)
print("Loaded all base tables.")

# --------------------------------
# 1) Build coloc + GWAS dataset
# --------------------------------
newColoc = buildColocData(all_coloc, credible, index)
print("Built newColoc")

gwasComplete = gwasDataset(evidences, credible)
print("Built gwasComplete")

resolvedColoc = (
    newColoc.withColumnRenamed("geneId", "targetId")
    .join(
        gwasComplete.withColumnRenamed("studyLocusId", "leftStudyLocusId"),
        on=["leftStudyLocusId", "targetId"],
        how="inner",
    )
    .join(
        diseases.selectExpr("id as diseaseId", "name", "parents", "therapeuticAreas"),
        on="diseaseId",
        how="left",
    )
    .withColumn(
        "diseaseId",
        F.explode_outer(F.concat(F.array(F.col("diseaseId")), F.col("parents"))),
    )
    .drop("parents", "oldDiseaseId")
    .withColumn(
        "colocDoE",
        F.when(
            F.col("rightStudyType").isin(["eqtl", "pqtl", "tuqtl", "sceqtl", "sctuqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_protect"))
        ).when(
            F.col("rightStudyType").isin(["sqtl", "scsqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_protect"))
        ),
    )
)
print("Built resolvedColoc")

# --------------------------------
# 2) Direction of Effect & ChEMBL indication
# --------------------------------
datasource_filter = [
    "gwas_credible_sets",
    "gene_burden",
    "eva",
    "eva_somatic",
    "gene2phenotype",
    "orphanet",
    "cancer_gene_census",
    "intogen",
    "impc",
    "chembl",
]
assessment, evidences, actionType_unused, oncolabel_unused = temporary_directionOfEffect(path_n, datasource_filter)
print("Built temporary DoE datasets")

# (Optional) Add MoA to ChEMBL paths as in your later code
mecact_path = f"{path_n}drug_mechanism_of_action/"
mecact = spark.read.parquet(mecact_path)
actionType = (
    mecact.select(
        F.explode_outer("chemblIds").alias("drugId"),
        "actionType",
        "mechanismOfAction",
        "targets",
    )
    .select(
        F.explode_outer("targets").alias("targetId"),
        "drugId",
        "actionType",
        "mechanismOfAction",
    )
    .groupBy("targetId", "drugId")
    .agg(F.collect_set("actionType").alias("actionType2"))
    .withColumn("nMoA", F.size(F.col("actionType2")))
)

analysis_chembl_indication = (
    discrepancifier(
        assessment.filter(F.col("datasourceId") == "chembl")
        .join(actionType, on=["targetId", "drugId"], how="left")
        .withColumn(
            "maxClinPhase",
            F.max("clinicalPhase").over(Window.partitionBy("targetId", "diseaseId")),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase", "actionType2")
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    .drop("coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk")
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
)
print("Built analysis_chembl_indication")

# --------------------------------
# 3) Benchmark (filtered coloc) + clinical phase flags
# --------------------------------
resolvedColocFiltered = resolvedColoc.filter((F.col("clpp") >= 0.01) | (F.col("h4") >= 0.8))

negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .groupBy("targetId", "diseaseId").count()
    .withColumn("stopReason", F.lit("Negative")).drop("count")
)
benchmark = (
    resolvedColocFiltered.filter(F.col("name") != "COVID-19")
    .join(analysis_chembl_indication, on=["targetId", "diseaseId"], how="right")
    .withColumn(
        "AgreeDrug",
        F.when((F.col("drugGoF_protect").isNotNull()) & (F.col("colocDoE") == "GoF_protect"), "yes")
        .when((F.col("drugLoF_protect").isNotNull()) & (F.col("colocDoE") == "LoF_protect"), "yes")
        .otherwise("no"),
    )
    .join(biosample.select("biosampleId", "biosampleName"), on="biosampleId", how="left")
)

benchmark = (
    benchmark.join(F.broadcast(negativeTD), on=["targetId", "diseaseId"], how="left")
    .withColumn("PhaseT", F.when(F.col("stopReason") == "Negative", "yes").otherwise("no"))
    .withColumn("Phase>=4", F.when((F.col("maxClinPhase") == 4) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=3", F.when((F.col("maxClinPhase") >= 3) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=2", F.when((F.col("maxClinPhase") >= 2) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=1", F.when((F.col("maxClinPhase") >= 1) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
)

# --------------------------------
# 4) Replace nested loops:
#     compute DoE counts once → derive flags → unpivot → single aggregation
# --------------------------------
doe_cols = ["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]

# counts per colocDoE over the grouping you previously used in the loop
group_keys = [
    "targetId", "diseaseId", "maxClinPhase",
    "actionType2", "biosampleName", "projectId", "rightStudyType", "colocalisationMethod"
]

doe_counts = (
    benchmark.groupBy(*group_keys)
    .agg(*[F.sum(F.when(F.col("colocDoE") == c, 1).otherwise(0)).alias(c) for c in doe_cols])
)

# max name(s) (in case of ties) without arrays of structs
greatest_count = F.greatest(*[F.col(c) for c in doe_cols])
max_names = F.filter(
    F.array(*[F.when(F.col(c) == greatest_count, F.lit(c)) for c in doe_cols]),
    lambda x: x.isNotNull()
)

# presence of drug-side signals (equivalent to *_ch presence in your loop path)
has_lof_ch = F.col("drugLoF_protect").isNotNull()
has_gof_ch = F.col("drugGoF_protect").isNotNull()




SparkSession created with:
  spark.driver.memory: 12g
  spark.executor.memory: 40g
  spark.executor.cores: 10
  spark.executor.instances: 1
  spark.yarn.executor.memoryOverhead: 6g
  spark.sql.shuffle.partitions: 128
  spark.default.parallelism: 128
  spark.sql.adaptive.enabled: true
  spark.sql.adaptive.coalescePartitions.enabled: true
Spark UI: http://jr-temp-doe-m.c.open-targets-eu-dev.internal:37033
Loaded all base tables.
Built newColoc


                                                                                

loaded gwasComplete
Built gwasComplete
Built resolvedColoc


25/09/17 12:50:13 WARN CacheManager: Asked to cache already cached data.
25/09/17 12:50:14 WARN CacheManager: Asked to cache already cached data.


Built temporary DoE datasets


25/09/17 12:50:15 WARN CacheManager: Asked to cache already cached data.
25/09/17 12:50:15 WARN CacheManager: Asked to cache already cached data.


Built analysis_chembl_indication


In [4]:
# --- prerequisites used below ---
# doe_cols = ["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]

# Recompute safe maxima + names (handles nulls as 0)
safe_max = F.greatest(*[F.coalesce(F.col(c), F.lit(0)) for c in doe_cols])
max_names = F.filter(
    F.array(*[F.when(F.coalesce(F.col(c), F.lit(0)) == safe_max, F.lit(c)) for c in doe_cols]),
    lambda x: x.isNotNull()
)
max_names_set = F.array_sort(F.array_distinct(max_names))

# Define coherent cross-pairs for GWAS/coloc DoE ties (order-independent)
pair1 = F.array_sort(F.array(F.lit("GoF_protect"), F.lit("LoF_risk")))
pair2 = F.array_sort(F.array(F.lit("LoF_protect"), F.lit("GoF_risk")))

# GWAS/coloc is comparable if: single maximum OR exactly one of the coherent pairs
gwasComparable = (
    (F.size(max_names_set) == 1) |
    ((F.size(max_names_set) == 2) & ((max_names_set == pair1) | (max_names_set == pair2)))
)

# Drug is comparable if exactly one of the two protect signals is present
has_lof_ch = F.col("drugLoF_protect").isNotNull()
has_gof_ch = F.col("drugGoF_protect").isNotNull()
drugComparable = (has_lof_ch != has_gof_ch)   # XOR in Spark

# Keep your existing drugCoherency label (optional; unchanged)
drugCoherencyCol = (
    F.when(has_lof_ch & ~has_gof_ch, "coherent")
     .when(~has_lof_ch & has_gof_ch, "coherent")
     .when(has_lof_ch & has_gof_ch, "dispar")
     .otherwise("other")
)

# Apply the “compare only if both sides are coherent” rule to the flags
test2 = (
    benchmark.select(*group_keys, "drugLoF_protect", "drugGoF_protect")
    .join(doe_counts, on=group_keys, how="left")
    .withColumn("gwasComparable", gwasComparable)
    .withColumn("drugComparable", drugComparable)
    .withColumn(
        "NoneCellYes",
        F.when(
            gwasComparable & drugComparable &
            has_lof_ch & F.array_contains(max_names, F.lit("LoF_protect")),
            "yes"
        )
        .when(
            gwasComparable & drugComparable &
            has_gof_ch & F.array_contains(max_names, F.lit("GoF_protect")),
            "yes"
        )
        .otherwise("no")
    )
    .withColumn(
        "NdiagonalYes",
        F.when(
            gwasComparable & drugComparable &
            has_lof_ch &
            (F.array_contains(max_names, F.lit("LoF_protect")) | F.array_contains(max_names, F.lit("GoF_risk"))),
            "yes"
        )
        .when(
            gwasComparable & drugComparable &
            has_gof_ch &
            (F.array_contains(max_names, F.lit("GoF_protect")) | F.array_contains(max_names, F.lit("LoF_risk"))),
            "yes"
        )
        .otherwise("no")
    )
    .withColumn("drugCoherency", drugCoherencyCol)
    .withColumn(
        "hasGenetics2",
        F.when(
            reduce(lambda acc, c: acc & F.col(c).isNull(), doe_cols[1:], F.col(doe_cols[0]).isNull()),
            F.lit("no")
        ).otherwise(F.lit("yes"))
    )
    # If you want hasGenetics to reflect hasGenetics2, set it directly; otherwise keep your placeholder:
    .withColumn("hasGenetics", F.col("hasGenetics2"))
)


In [5]:
test2.persist()

DataFrame[targetId: string, diseaseId: string, maxClinPhase: double, actionType2: array<string>, biosampleName: string, projectId: string, rightStudyType: string, colocalisationMethod: string, drugLoF_protect: bigint, drugGoF_protect: bigint, LoF_protect: bigint, GoF_risk: bigint, LoF_risk: bigint, GoF_protect: bigint, gwasComparable: boolean, drugComparable: boolean, NoneCellYes: string, NdiagonalYes: string, drugCoherency: string, hasGenetics2: string, hasGenetics: string]

In [7]:
test2.filter(F.col('NoneCellYes')=='yes').count()

15854

In [None]:

# --- Max DoE names (safe to nulls) + coherency annotation only ---
# If you already have `doe_cols` defined:
# doe_cols = ["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]

# 1) Recompute maxima safely (treat nulls as 0) and collect all tied names
safe_max = F.greatest(*[F.coalesce(F.col(c), F.lit(0)) for c in doe_cols])
max_names = F.filter(
    F.array(*[F.when(F.coalesce(F.col(c), F.lit(0)) == safe_max, F.lit(c)) for c in doe_cols]),
    lambda x: x.isNotNull()
)

# 2) Build a sorted distinct array to compare sets ignoring order
max_names_set = F.array_sort(F.array_distinct(max_names))

# 3) Define the only coherent cross-pairs
pair1 = F.array_sort(F.array(F.lit("GoF_protect"), F.lit("LoF_risk")))
pair2 = F.array_sort(F.array(F.lit("LoF_protect"), F.lit("GoF_risk")))

# 4) Annotate coherency of DoE maxima (NO filtering, just a label)
test2 = (
    benchmark.select(*group_keys, "drugLoF_protect", "drugGoF_protect")
    .join(doe_counts, on=group_keys, how="left")
    .withColumn(
        "maxDoECoherency",
        F.when(F.size(max_names_set) == 1, F.lit("single"))
         .when(
             (F.size(max_names_set) == 2) &
             ((max_names_set == pair1) | (max_names_set == pair2)),
             F.lit("coherent")
         )
         .when(F.size(max_names_set) >= 2, F.lit("incoherent"))   # includes size 3 or 4
         .otherwise(F.lit("single"))
    )
)



test2 = (
    benchmark.select(*group_keys, "drugLoF_protect", "drugGoF_protect")
    .join(doe_counts, on=group_keys, how="left")
    .withColumn("NoneCellYes",
        F.when(has_lof_ch & (~has_gof_ch) & F.array_contains(max_names, F.lit("LoF_protect")), "yes")
         .when(has_gof_ch & (~has_lof_ch) & F.array_contains(max_names, F.lit("GoF_protect")), "yes")
         .otherwise("no")
    )
    .withColumn("NdiagonalYes",
        F.when(has_lof_ch & (~has_gof_ch) & (F.array_contains(max_names, F.lit("LoF_protect")) | F.array_contains(max_names, F.lit("GoF_risk"))), "yes")
         .when(has_gof_ch & (~has_lof_ch) & (F.array_contains(max_names, F.lit("GoF_protect")) | F.array_contains(max_names, F.lit("LoF_risk"))), "yes")
         .otherwise("no")
    )
    .withColumn("drugCoherency",
        F.when(has_lof_ch & ~has_gof_ch, "coherent")
         .when(~has_lof_ch & has_gof_ch, "coherent")
         .when(has_lof_ch & has_gof_ch, "dispar")
         .otherwise("other")
    ).withColumn(
    "hasGenetics2",
    F.when(
        reduce(lambda acc, c: acc & F.col(c).isNull(), doe_cols[1:], F.col(doe_cols[0]).isNull()),
        F.lit("no")
    ).otherwise(F.lit("yes"))
)
    .withColumn("hasGenetics", F.when(F.col("NdiagonalYes").isNotNull(), "yes").otherwise("no")) #### we have to change it
)
test2.persist()

In [10]:
# ---------- Guard: (re)build agg_once if not defined ----------
import pyspark.sql.functions as F

def _build_agg_once_from_test2_and_benchmark(test2_df, benchmark_df):
    # Columns we keep across all longified slices
    common_cols = [
        "targetId","diseaseId","maxClinPhase",
        "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
        "NoneCellYes","NdiagonalYes","hasGenetics2",
        # (optional diagnostics if you want them downstream)
        # "gwasComparable","drugComparable","maxDoECoherency"
    ]

    # Join phase flags once (LEFT join is correct here)
    phase_flags = (
        benchmark_df.select(
            "targetId","diseaseId","maxClinPhase",
            "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT"
        ).dropDuplicates(["targetId","diseaseId","maxClinPhase"])
    )

    # Only rows present in test2 matter for feature/value tall view
    t2_with_phase = (
        test2_df
        .join(phase_flags, on=["targetId","diseaseId","maxClinPhase"], how="right")
        # make flags explicit; null → "no"
        .withColumn("NoneCellYes",  F.coalesce(F.col("NoneCellYes"),  F.lit("no")))
        .withColumn("NdiagonalYes", F.coalesce(F.col("NdiagonalYes"), F.lit("no")))
        .withColumn("hasGenetics2", F.coalesce(F.col("hasGenetics2"), F.lit("no")))
    )

    # actionType2 is ARRAY<STRING> → explode
    long_action = (
        t2_with_phase
        .select(*common_cols, F.explode_outer("actionType2").alias("value"))
        .withColumn("feature", F.lit("actionType2"))
        .select(*common_cols, "feature", "value")
    )

    # helper for scalar columns
    def longify_scalar(colname: str):
        return (
            t2_with_phase
            .select(*common_cols, F.col(colname).alias("value"))
            .withColumn("feature", F.lit(colname))
            .select(*common_cols, "feature", "value")
        )

    long_biosample = longify_scalar("biosampleName")
    long_project   = longify_scalar("projectId")
    long_rstype    = longify_scalar("rightStudyType")
    long_colocm    = longify_scalar("colocalisationMethod")

    # union into one tall table (drop value=null)
    long_features = (
        long_action
        .unionByName(long_biosample)
        .unionByName(long_project)
        .unionByName(long_rstype)
        .unionByName(long_colocm)
        #.filter(F.col("value").isNotNull())
    )

    # single aggregation to compute flags
    agg_once_local = (
        long_features
        .groupBy(
            "targetId","diseaseId","maxClinPhase",
            "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
            "feature","value"
        )
        .agg(
            F.max(F.when(F.col("NoneCellYes")  == "yes", 1).otherwise(0)).alias("NoneCellYes"),
            F.max(F.when(F.col("NdiagonalYes") == "yes", 1).otherwise(0)).alias("NdiagonalYes"),
            F.max(F.when(F.col("hasGenetics2") == "yes", 1).otherwise(0)).alias("hasGenetics"),
        )
        .selectExpr(
            "*",
            "CASE WHEN NoneCellYes=1  THEN 'yes' ELSE 'no' END as NoneCellYes_flag",
            "CASE WHEN NdiagonalYes=1 THEN 'yes' ELSE 'no' END as NdiagonalYes_flag",
            "CASE WHEN hasGenetics=1  THEN 'yes' ELSE 'no' END as hasGenetics_flag"
        )
    )
    return agg_once_local

if 'agg_once' not in globals():
    print("[info] agg_once not found — rebuilding it from test2/benchmark …")
    agg_once = _build_agg_once_from_test2_and_benchmark(test2, benchmark)
    print("[info] agg_once rebuilt.")


In [12]:
agg_once.limit(5).show()



ERROR:root:KeyboardInterrupt while sending command.0][Stage 200:=>(9 + 5) / 17] ]
Traceback (most recent call last):
  File "/usr/lib/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/miniconda3/lib/python3.11/socket.py", line 706, in readinto
    return self._sock.recv_into(b)
           ^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt


KeyboardInterrupt: 

25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_125 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_267_90 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_101 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_192_57 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_192_59 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_73_59 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_267_79 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_34 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_267_30 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_40_157 !
25/09/17 13:11:25 WARN BlockManagerMasterEndpoint: N

In [3]:
test2.filter(F.col('NoneCellYes')=='yes').show()

+---------------+-----------+------------+--------------------+--------------------+----------+--------------+--------------------+---------------+---------------+-----------+--------+--------+-----------+-----------+------------+-------------+------------+-----------+
|       targetId|  diseaseId|maxClinPhase|         actionType2|       biosampleName| projectId|rightStudyType|colocalisationMethod|drugLoF_protect|drugGoF_protect|LoF_protect|GoF_risk|LoF_risk|GoF_protect|NoneCellYes|NdiagonalYes|drugCoherency|hasGenetics2|hasGenetics|
+---------------+-----------+------------+--------------------+--------------------+----------+--------------+--------------------+---------------+---------------+-----------+--------+--------+-----------+-----------+------------+-------------+------------+-----------+
|ENSG00000164116|EFO_0000537|         4.0|[POSITIVE ALLOSTE...|dorsolateral pref...|CommonMind|          eqtl|             eCAVIAR|           NULL|              1|          0|       0|      

In [None]:

# ---------- Guard: (re)build agg_once if not defined ----------
import pyspark.sql.functions as F

def _build_agg_once_from_test2_and_benchmark(test2, benchmark_df):
    # Columns we keep across all longified slices
    common_cols = [
        "targetId","diseaseId","maxClinPhase",
        "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
        "NoneCellYes","NdiagonalYes","hasGenetics2"  # note: hasGenetics2 from your test2
    ]

    # Join phase flags once
    phase_flags = (
        benchmark_df.select(
            "targetId","diseaseId","maxClinPhase",
            "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT"
        ).dropDuplicates(["targetId","diseaseId","maxClinPhase"])
    )

    t2_with_phase = test2.join(
        phase_flags, on=["targetId","diseaseId","maxClinPhase"], how="right" #### maybe this should be right
    )

    # actionType2 is ARRAY<STRING> → explode
    long_action = (
        t2_with_phase
        .select(*common_cols, F.explode_outer("actionType2").alias("value"))
        .withColumn("feature", F.lit("actionType2"))
        .select(*common_cols, "feature", "value")
    )

    # helper for scalar columns
    def longify_scalar(colname: str):
        return (
            t2_with_phase
            .select(*common_cols, F.col(colname).alias("value"))
            .withColumn("feature", F.lit(colname))
            .select(*common_cols, "feature", "value")
        )

    long_biosample = longify_scalar("biosampleName")
    long_project   = longify_scalar("projectId")
    long_rstype    = longify_scalar("rightStudyType")
    long_colocm    = longify_scalar("colocalisationMethod")

    # union into one tall table
    long_features = (
        long_action
        .unionByName(long_biosample)
        .unionByName(long_project)
        .unionByName(long_rstype)
        .unionByName(long_colocm)
    )#.filter(F.col("value").isNotNull())

    # single aggregation to compute flags
    agg_once_local = (
        long_features
        .groupBy(
            "targetId","diseaseId","maxClinPhase",
            "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
            "feature","value"
        )
        .agg(
            F.max(F.when(F.col("NoneCellYes")=="yes", 1).otherwise(0)).alias("NoneCellYes"),
            F.max(F.when(F.col("NdiagonalYes")=="yes", 1).otherwise(0)).alias("NdiagonalYes"),
            F.max(F.when(F.col("hasGenetics2")=="yes", 1).otherwise(0)).alias("hasGenetics"),
        )
        .selectExpr(
            "*",
            "CASE WHEN NoneCellYes=1 THEN 'yes' ELSE 'no' END as NoneCellYes_flag",
            "CASE WHEN NdiagonalYes=1 THEN 'yes' ELSE 'no' END as NdiagonalYes_flag",
            "CASE WHEN hasGenetics=1 THEN 'yes' ELSE 'no' END as hasGenetics_flag"
        )
    )
    return agg_once_local

if 'agg_once' not in globals():
    print("[info] agg_once not found — rebuilding it from test2/benchmark …")
    agg_once = _build_agg_once_from_test2_and_benchmark(test2, benchmark)
    print("[info] agg_once rebuilt.")


[info] agg_once not found — rebuilding it from test2/benchmark …
[info] agg_once rebuilt.


In [10]:
agg_once.persist()

25/09/10 22:01:21 WARN CacheManager: Asked to cache already cached data.


DataFrame[targetId: string, diseaseId: string, maxClinPhase: double, Phase>=4: string, Phase>=3: string, Phase>=2: string, Phase>=1: string, PhaseT: string, feature: string, value: string, NoneCellYes: int, NdiagonalYes: int, hasGenetics: int, NoneCellYes_flag: string, NdiagonalYes_flag: string, hasGenetics_flag: string]

In [11]:
agg_once.count()

85938

In [8]:
agg_once.groupBy('feature').count().show()

+--------------------+-----+
|             feature|count|
+--------------------+-----+
|         actionType2|78109|
|           projectId| 2055|
|colocalisationMethod|  937|
|       biosampleName| 3779|
|      rightStudyType| 1058|
+--------------------+-----+



In [None]:


# ============================
# Denominator = ALL pairs in analysis_chembl_indication (deduped)
# Build 2x2 counts using totals, then Fisher via applyInPandas
# ============================
from datetime import date
import pandas as pd
import numpy as np
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType, IntegerType, ArrayType
)
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio

# ---- 0) Universe of pairs & phase flags (only de-dup, no other filtering)
universe = (
    analysis_chembl_indication
    .select("targetId", "diseaseId", "maxClinPhase")  # dedupe on these
    .distinct()
    .join(F.broadcast(negativeTD), on=["targetId","diseaseId"], how="left")
    .withColumn("PhaseT", F.when(F.col("stopReason")=="Negative", "yes").otherwise("no"))
    .withColumn("Phase>=4", F.when((F.col("maxClinPhase")==4) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=3", F.when((F.col("maxClinPhase")>=3) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=2", F.when((F.col("maxClinPhase")>=2) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=1", F.when((F.col("maxClinPhase")>=1) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
)

# Long view of phase flags for universe
phases_universe_long = universe.select(
    "targetId","diseaseId",
    F.expr("stack(5, "
           "'Phase>=4', `Phase>=4`, "
           "'Phase>=3', `Phase>=3`, "
           "'Phase>=2', `Phase>=2`, "
           "'Phase>=1', `Phase>=1`, "
           "'PhaseT',  `PhaseT`"
           ")").alias("phase_name","prediction")
)

# Totals per phase (denominator totals)
total_pairs_by_phase = (
    phases_universe_long
    .groupBy("phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("total_pairs"))
)
total_pred_yes_by_phase = (
    phases_universe_long
    .filter(F.col("prediction")=="yes")
    .groupBy("phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("total_pred_yes"))
)

# ---- 1) Build analysis_long from agg_once (flags) + phases (prediction)
# metrics we’ll analyze
metric_flags = ["NoneCellYes_flag", "NdiagonalYes_flag", "hasGenetics_flag"]

# phase flags per (target,disease,maxClinPhase)
phase_flags = (
    benchmark.select("targetId","diseaseId","maxClinPhase","Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT")
    .dropDuplicates(["targetId","diseaseId","maxClinPhase"])
)

# stack phases for the records present in agg_once (feature,value specific)
phases_long_for_records = (
    phase_flags.join(agg_once.select("targetId","diseaseId","maxClinPhase").dropDuplicates(),
                     on=["targetId","diseaseId","maxClinPhase"], how="inner")
    .select(
        "targetId","diseaseId","maxClinPhase",
        F.expr("stack(5, "
               "'Phase>=4', `Phase>=4`, "
               "'Phase>=3', `Phase>=3`, "
               "'Phase>=2', `Phase>=2`, "
               "'Phase>=1', `Phase>=1`, "
               "'PhaseT',  `PhaseT`"
               ")").alias("phase_name","prediction")
    )
)

def attach_metric(metric_col: str):
    # comparison = metric flag yes/no at (target,disease,feature,value)
    return (
        agg_once.select("targetId","diseaseId","maxClinPhase","feature","value",
                        F.col(metric_col).alias("comparison"))
        .join(phases_long_for_records, on=["targetId","diseaseId","maxClinPhase"], how="inner")
        .withColumn("metric", F.lit(metric_col.replace("_flag","")))  # prettier label
    )

analysis_long = attach_metric(metric_flags[0])
for mc in metric_flags[1:]:
    analysis_long = analysis_long.unionByName(attach_metric(mc))

# ---- 2) Count distinct pairs for 2x2 components using the fixed universe
# a = count of pairs with comparison=='yes' AND prediction=='yes'
yes_yes = (
    analysis_long
    .filter((F.col("comparison")=="yes") & (F.col("prediction")=="yes"))
    .groupBy("metric","feature","value","phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("a"))
)
# yes_total = count of pairs with comparison=='yes' (regardless of prediction)
yes_total = (
    analysis_long
    .filter(F.col("comparison")=="yes")
    .groupBy("metric","feature","value","phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("yes_total"))
)

# Assemble b,c,d from totals
counts = (
    yes_total
    .join(yes_yes, on=["metric","feature","value","phase_name"], how="left")
    .join(total_pairs_by_phase, on="phase_name", how="left")
    .join(total_pred_yes_by_phase, on="phase_name", how="left")
    .na.fill({"a":0})
    .withColumn("b", F.col("yes_total") - F.col("a"))
    .withColumn("c", F.col("total_pred_yes") - F.col("a"))
    .withColumn("d", F.col("total_pairs") - F.col("a") - F.col("b") - F.col("c"))
    .select(
        "metric","feature","value","phase_name",
        F.when(F.col("a")<0,0).otherwise(F.col("a")).cast("int").alias("a"),
        F.when(F.col("b")<0,0).otherwise(F.col("b")).cast("int").alias("b"),
        F.when(F.col("c")<0,0).otherwise(F.col("c")).cast("int").alias("c"),
        F.when(F.col("d")<0,0).otherwise(F.col("d")).cast("int").alias("d"),
        "total_pairs","total_pred_yes"
    )
)

# Convert to two-row format (comparison yes/no) with columns yes/no → ready for Fisher
mat_counts = (
    counts
    .select("metric","feature","value","phase_name",
            F.lit("yes").alias("comparison"),
            F.col("a").alias("yes"),
            F.col("b").alias("no"))
    .unionByName(
        counts.select("metric","feature","value","phase_name",
                      F.lit("no").alias("comparison"),
                      F.col("c").alias("yes"),
                      F.col("d").alias("no"))
    )
)

# Safety: ensure ints and no nulls
mat_counts = (
    mat_counts.fillna(0)
              .withColumn("yes", F.col("yes").cast("int"))
              .withColumn("no",  F.col("no").cast("int"))
)

#### LAST ITERATION


In [1]:
# -*- coding: utf-8 -*-
# Single-script, loop-free PySpark job (tall/unpivot + single aggregation)

import os
from datetime import date
from functools import reduce

from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType, DoubleType, ArrayType
)

# Your helpers
from functions import (
    relative_success,
    spreadSheetFormatter,
    discrepancifier,
    temporary_directionOfEffect,
    buildColocData,
    gwasDataset,
)
from DoEAssessment import directionOfEffect  # noqa: F401  (kept if you need it later)

# -------------------------------
# Spark / YARN resource settings (Single-Node Option A)
# -------------------------------
driver_memory = "24g"                 # plenty for planning & small collects

executor_cores = 4                    # sweet spot for GC + Python workers
num_executors  = 12                   # 12 * 4 = 48 cores for executors; ~16 cores left for driver/OS
executor_memory = "32g"               # per executor heap
executor_memory_overhead = "8g"       # ~20% overhead for PySpark/Arrow/off-heap

# Totals: (32+8) * 12 = 480 GB executors + 24 GB driver ≈ 504 GB (adjust down if your hard cap is <500 GB)
# If you must stay strictly ≤ 500 GB, use executor_memory="30g", overhead="6g"  → (36 * 12) + 24 = 456 + 24 = 480 GB

shuffle_partitions   = 192            # ≈ 2–4× total cores (48) → start with 192
default_parallelism  = 192

# If you later move to a multi-worker cluster, replace the values above.

spark = (
    SparkSession.builder
    .appName("MyOptimizedPySparkApp")
    .config("spark.master", "yarn")
    # core resources
    .config("spark.driver.memory", driver_memory)
    .config("spark.executor.memory", executor_memory)
    .config("spark.executor.cores", executor_cores)
    .config("spark.executor.instances", num_executors)
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead)
    # shuffle & parallelism
    .config("spark.sql.shuffle.partitions", shuffle_partitions)
    .config("spark.default.parallelism", default_parallelism)
    # adaptive query execution for better skew/partition sizing
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
)

print("SparkSession created with:")
for k in [
    "spark.driver.memory",
    "spark.executor.memory",
    "spark.executor.cores",
    "spark.executor.instances",
    "spark.yarn.executor.memoryOverhead",
    "spark.sql.shuffle.partitions",
    "spark.default.parallelism",
    "spark.sql.adaptive.enabled",
    "spark.sql.adaptive.coalescePartitions.enabled",
]:
    print(f"  {k}: {spark.conf.get(k)}")
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")



'''
# -------------------------------
# Spark / YARN resource settings
# -------------------------------
driver_memory = "16g"
executor_memory = "32g"
executor_cores = "8"
num_executors = "16"
executor_memory_overhead = "8g"
shuffle_partitions = "150"
default_parallelism = str(int(executor_cores) * int(num_executors) * 2)  # 80

spark = (
    SparkSession.builder
    .appName("MyOptimizedPySparkApp")
    .config("spark.master", "yarn")
    .config("spark.driver.memory", driver_memory)
    .config("spark.executor.memory", executor_memory)
    .config("spark.executor.cores", executor_cores)
    .config("spark.executor.instances", num_executors)
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead)
    .config("spark.sql.shuffle.partitions", shuffle_partitions)
    .config("spark.default.parallelism", default_parallelism)
    .getOrCreate()
)

print("SparkSession created with:")
for k in [
    "spark.driver.memory",
    "spark.executor.memory",
    "spark.executor.cores",
    "spark.executor.instances",
    "spark.yarn.executor.memoryOverhead",
    "spark.sql.shuffle.partitions",
    "spark.default.parallelism",
]:
    print(f"  {k}: {spark.conf.get(k)}")
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")
'''
# --------------------------------
# 0) Load inputs
# --------------------------------
path_n = "gs://open-targets-data-releases/25.06/output/"

target = spark.read.parquet(f"{path_n}target/")
diseases = spark.read.parquet(f"{path_n}disease/")
evidences = spark.read.parquet(f"{path_n}evidence")
credible = spark.read.parquet(f"{path_n}credible_set")
new = spark.read.parquet(f"{path_n}colocalisation_coloc")
index = spark.read.parquet(f"{path_n}study/")
variantIndex = spark.read.parquet(f"{path_n}variant")
biosample = spark.read.parquet(f"{path_n}biosample")
ecaviar = spark.read.parquet(f"{path_n}colocalisation_ecaviar")
all_coloc = ecaviar.unionByName(new, allowMissingColumns=True)
print("Loaded all base tables.")

# --------------------------------
# 1) Build coloc + GWAS dataset
# --------------------------------
newColoc = buildColocData(all_coloc, credible, index)
print("Built newColoc")

gwasComplete = gwasDataset(evidences, credible)
print("Built gwasComplete")

resolvedColoc = (
    newColoc.withColumnRenamed("geneId", "targetId")
    .join(
        gwasComplete.withColumnRenamed("studyLocusId", "leftStudyLocusId"),
        on=["leftStudyLocusId", "targetId"],
        how="inner",
    )
    .join(
        diseases.selectExpr("id as diseaseId", "name", "parents", "therapeuticAreas"),
        on="diseaseId",
        how="left",
    )
    .withColumn(
        "diseaseId",
        F.explode_outer(F.concat(F.array(F.col("diseaseId")), F.col("parents"))),
    )
    .drop("parents", "oldDiseaseId")
    .withColumn(
        "colocDoE",
        F.when(
            F.col("rightStudyType").isin(["eqtl", "pqtl", "tuqtl", "sceqtl", "sctuqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_protect"))
        ).when(
            F.col("rightStudyType").isin(["sqtl", "scsqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_protect"))
        ),
    )
)
print("Built resolvedColoc")

# --------------------------------
# 2) Direction of Effect & ChEMBL indication
# --------------------------------
datasource_filter = [
    "gwas_credible_sets",
    "gene_burden",
    "eva",
    "eva_somatic",
    "gene2phenotype",
    "orphanet",
    "cancer_gene_census",
    "intogen",
    "impc",
    "chembl",
]
assessment, evidences, actionType_unused, oncolabel_unused = temporary_directionOfEffect(path_n, datasource_filter)
print("Built temporary DoE datasets")

# (Optional) Add MoA to ChEMBL paths as in your later code
mecact_path = f"{path_n}drug_mechanism_of_action/"
mecact = spark.read.parquet(mecact_path)
actionType = (
    mecact.select(
        F.explode_outer("chemblIds").alias("drugId"),
        "actionType",
        "mechanismOfAction",
        "targets",
    )
    .select(
        F.explode_outer("targets").alias("targetId"),
        "drugId",
        "actionType",
        "mechanismOfAction",
    )
    .groupBy("targetId", "drugId")
    .agg(F.collect_set("actionType").alias("actionType2"))
    .withColumn("nMoA", F.size(F.col("actionType2")))
)

analysis_chembl_indication = (
    discrepancifier(
        assessment.filter(F.col("datasourceId") == "chembl")
        .join(actionType, on=["targetId", "drugId"], how="left")
        .withColumn(
            "maxClinPhase",
            F.max("clinicalPhase").over(Window.partitionBy("targetId", "diseaseId")),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase", "actionType2")
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    .drop("coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk")
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
)
print("Built analysis_chembl_indication")

# --------------------------------
# 3) Benchmark (filtered coloc) + clinical phase flags
# --------------------------------
resolvedColocFiltered = resolvedColoc.filter((F.col("clpp") >= 0.01) | (F.col("h4") >= 0.8))

negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .groupBy("targetId", "diseaseId").count()
    .withColumn("stopReason", F.lit("Negative")).drop("count")
)
benchmark = (
    resolvedColocFiltered.filter(F.col("name") != "COVID-19")
    .join(analysis_chembl_indication, on=["targetId", "diseaseId"], how="right")
    .withColumn(
        "AgreeDrug",
        F.when((F.col("drugGoF_protect").isNotNull()) & (F.col("colocDoE") == "GoF_protect"), "yes")
        .when((F.col("drugLoF_protect").isNotNull()) & (F.col("colocDoE") == "LoF_protect"), "yes")
        .otherwise("no"),
    )
    .join(biosample.select("biosampleId", "biosampleName"), on="biosampleId", how="left")
)

benchmark = (
    benchmark.join(F.broadcast(negativeTD), on=["targetId", "diseaseId"], how="left")
    .withColumn("PhaseT", F.when(F.col("stopReason") == "Negative", "yes").otherwise("no"))
    .withColumn("Phase>=4", F.when((F.col("maxClinPhase") == 4) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=3", F.when((F.col("maxClinPhase") >= 3) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=2", F.when((F.col("maxClinPhase") >= 2) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
    .withColumn("Phase>=1", F.when((F.col("maxClinPhase") >= 1) & (F.col("PhaseT") == "no"), "yes").otherwise("no"))
)

# --------------------------------
# 4) Replace nested loops:
#     compute DoE counts once → derive flags → unpivot → single aggregation
# --------------------------------
doe_cols = ["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]

# counts per colocDoE over the grouping you previously used in the loop
group_keys = [
    "targetId", "diseaseId", "maxClinPhase",
    "actionType2", "biosampleName", "projectId", "rightStudyType", "colocalisationMethod"
]

doe_counts = (
    benchmark.groupBy(*group_keys)
    .agg(*[F.sum(F.when(F.col("colocDoE") == c, 1).otherwise(0)).alias(c) for c in doe_cols])
)

# max name(s) (in case of ties) without arrays of structs
greatest_count = F.greatest(*[F.col(c) for c in doe_cols])
max_names = F.filter(
    F.array(*[F.when(F.col(c) == greatest_count, F.lit(c)) for c in doe_cols]),
    lambda x: x.isNotNull()
)

# presence of drug-side signals (equivalent to *_ch presence in your loop path)
has_lof_ch = F.col("drugLoF_protect").isNotNull()
has_gof_ch = F.col("drugGoF_protect").isNotNull()

test2 = (
    benchmark.select(*group_keys, "drugLoF_protect", "drugGoF_protect")
    .join(doe_counts, on=group_keys, how="left")
    .withColumn("NoneCellYes",
        F.when(has_lof_ch & (~has_gof_ch) & F.array_contains(max_names, F.lit("LoF_protect")), "yes")
         .when(has_gof_ch & (~has_lof_ch) & F.array_contains(max_names, F.lit("GoF_protect")), "yes")
         .otherwise("no")
    )
    .withColumn("NdiagonalYes",
        F.when(has_lof_ch & (~has_gof_ch) & (F.array_contains(max_names, F.lit("LoF_protect")) | F.array_contains(max_names, F.lit("GoF_risk"))), "yes")
         .when(has_gof_ch & (~has_lof_ch) & (F.array_contains(max_names, F.lit("GoF_protect")) | F.array_contains(max_names, F.lit("LoF_risk"))), "yes")
         .otherwise("no")
    )
    .withColumn("drugCoherency",
        F.when(has_lof_ch & ~has_gof_ch, "coherent")
         .when(~has_lof_ch & has_gof_ch, "coherent")
         .when(has_lof_ch & has_gof_ch, "dispar")
         .otherwise("other")
    ).withColumn(
    "hasGenetics2",
    F.when(
        reduce(lambda acc, c: acc & F.col(c).isNull(), doe_cols[1:], F.col(doe_cols[0]).isNull()),
        F.lit("no")
    ).otherwise(F.lit("yes"))
)
    .withColumn("hasGenetics", F.when(F.col("NdiagonalYes").isNotNull(), "yes").otherwise("no")) #### we have to change it
)
#test2.persist()

# ---------- Guard: (re)build agg_once if not defined ----------
import pyspark.sql.functions as F

def _build_agg_once_from_test2_and_benchmark(test2, benchmark_df):
    # Columns we keep across all longified slices
    common_cols = [
        "targetId","diseaseId","maxClinPhase",
        "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
        "NoneCellYes","NdiagonalYes","hasGenetics2"  # note: hasGenetics2 from your test2
    ]

    # Join phase flags once
    phase_flags = (
        benchmark_df.select(
            "targetId","diseaseId","maxClinPhase",
            "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT"
        ).dropDuplicates(["targetId","diseaseId","maxClinPhase"])
    )

    t2_with_phase = test2.join(
        phase_flags, on=["targetId","diseaseId","maxClinPhase"], how="left"
    )

    # actionType2 is ARRAY<STRING> → explode
    long_action = (
        t2_with_phase
        .select(*common_cols, F.explode_outer("actionType2").alias("value"))
        .withColumn("feature", F.lit("actionType2"))
        .select(*common_cols, "feature", "value")
    )

    # helper for scalar columns
    def longify_scalar(colname: str):
        return (
            t2_with_phase
            .select(*common_cols, F.col(colname).alias("value"))
            .withColumn("feature", F.lit(colname))
            .select(*common_cols, "feature", "value")
        )

    long_biosample = longify_scalar("biosampleName")
    long_project   = longify_scalar("projectId")
    long_rstype    = longify_scalar("rightStudyType")
    long_colocm    = longify_scalar("colocalisationMethod")

    # union into one tall table
    long_features = (
        long_action
        .unionByName(long_biosample)
        .unionByName(long_project)
        .unionByName(long_rstype)
        .unionByName(long_colocm)
    ).filter(F.col("value").isNotNull())

    # single aggregation to compute flags
    agg_once_local = (
        long_features
        .groupBy(
            "targetId","diseaseId","maxClinPhase",
            "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
            "feature","value"
        )
        .agg(
            F.max(F.when(F.col("NoneCellYes")=="yes", 1).otherwise(0)).alias("NoneCellYes"),
            F.max(F.when(F.col("NdiagonalYes")=="yes", 1).otherwise(0)).alias("NdiagonalYes"),
            F.max(F.when(F.col("hasGenetics2")=="yes", 1).otherwise(0)).alias("hasGenetics"),
        )
        .selectExpr(
            "*",
            "CASE WHEN NoneCellYes=1 THEN 'yes' ELSE 'no' END as NoneCellYes_flag",
            "CASE WHEN NdiagonalYes=1 THEN 'yes' ELSE 'no' END as NdiagonalYes_flag",
            "CASE WHEN hasGenetics=1 THEN 'yes' ELSE 'no' END as hasGenetics_flag"
        )
    )
    return agg_once_local

if 'agg_once' not in globals():
    print("[info] agg_once not found — rebuilding it from test2/benchmark …")
    agg_once = _build_agg_once_from_test2_and_benchmark(test2, benchmark)
    print("[info] agg_once rebuilt.")

'''
# ============================
# Denominator = ALL pairs in analysis_chembl_indication (deduped)
# Build 2x2 counts using totals, then Fisher via applyInPandas
# ============================
from datetime import date
import pandas as pd
import numpy as np
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType, IntegerType, ArrayType
)
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio

# ---- 0) Universe of pairs & phase flags (only de-dup, no other filtering)
universe = (
    analysis_chembl_indication
    .select("targetId", "diseaseId", "maxClinPhase")  # dedupe on these
    .distinct()
    .join(F.broadcast(negativeTD), on=["targetId","diseaseId"], how="left")
    .withColumn("PhaseT", F.when(F.col("stopReason")=="Negative", "yes").otherwise("no"))
    .withColumn("Phase>=4", F.when((F.col("maxClinPhase")==4) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=3", F.when((F.col("maxClinPhase")>=3) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=2", F.when((F.col("maxClinPhase")>=2) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=1", F.when((F.col("maxClinPhase")>=1) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
)
print('universe of pairs and phase flags built')

# Long view of phase flags for universe
phases_universe_long = universe.select(
    "targetId","diseaseId",
    F.expr("stack(5, "
           "'Phase>=4', `Phase>=4`, "
           "'Phase>=3', `Phase>=3`, "
           "'Phase>=2', `Phase>=2`, "
           "'Phase>=1', `Phase>=1`, "
           "'PhaseT',  `PhaseT`"
           ")").alias("phase_name","prediction")
)
print('phase_universe_long built')

# Totals per phase (denominator totals)
total_pairs_by_phase = (
    phases_universe_long
    .groupBy("phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("total_pairs"))
)
total_pred_yes_by_phase = (
    phases_universe_long
    .filter(F.col("prediction")=="yes")
    .groupBy("phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("total_pred_yes"))
)
print('phase_universe_long built')

# ---- 1) Build analysis_long from agg_once (flags) + phases (prediction)
# metrics we’ll analyze
metric_flags = ["NoneCellYes_flag", "NdiagonalYes_flag", "hasGenetics_flag"]

# phase flags per (target,disease,maxClinPhase)
phase_flags = (
    benchmark.select("targetId","diseaseId","maxClinPhase","Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT")
    .dropDuplicates(["targetId","diseaseId","maxClinPhase"])
)

# stack phases for the records present in agg_once (feature,value specific)
phases_long_for_records = (
    phase_flags.join(agg_once.select("targetId","diseaseId","maxClinPhase").dropDuplicates(),
                     on=["targetId","diseaseId","maxClinPhase"], how="inner")
    .select(
        "targetId","diseaseId","maxClinPhase",
        F.expr("stack(5, "
               "'Phase>=4', `Phase>=4`, "
               "'Phase>=3', `Phase>=3`, "
               "'Phase>=2', `Phase>=2`, "
               "'Phase>=1', `Phase>=1`, "
               "'PhaseT',  `PhaseT`"
               ")").alias("phase_name","prediction")
    )
)

def attach_metric(metric_col: str):
    # comparison = metric flag yes/no at (target,disease,feature,value)
    return (
        agg_once.select("targetId","diseaseId","maxClinPhase","feature","value",
                        F.col(metric_col).alias("comparison"))
        .join(phases_long_for_records, on=["targetId","diseaseId","maxClinPhase"], how="inner")
        .withColumn("metric", F.lit(metric_col.replace("_flag","")))  # prettier label
    )

analysis_long = attach_metric(metric_flags[0])
for mc in metric_flags[1:]:
    analysis_long = analysis_long.unionByName(attach_metric(mc))

# ---- 2) Count distinct pairs for 2x2 components using the fixed universe
# a = count of pairs with comparison=='yes' AND prediction=='yes'
yes_yes = (
    analysis_long
    .filter((F.col("comparison")=="yes") & (F.col("prediction")=="yes"))
    .groupBy("metric","feature","value","phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("a"))
)
# yes_total = count of pairs with comparison=='yes' (regardless of prediction)
yes_total = (
    analysis_long
    .filter(F.col("comparison")=="yes")
    .groupBy("metric","feature","value","phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("yes_total"))
)

# Assemble b,c,d from totals
counts = (
    yes_total
    .join(yes_yes, on=["metric","feature","value","phase_name"], how="left")
    .join(total_pairs_by_phase, on="phase_name", how="left")
    .join(total_pred_yes_by_phase, on="phase_name", how="left")
    .na.fill({"a":0})
    .withColumn("b", F.col("yes_total") - F.col("a"))
    .withColumn("c", F.col("total_pred_yes") - F.col("a"))
    .withColumn("d", F.col("total_pairs") - F.col("a") - F.col("b") - F.col("c"))
    .select(
        "metric","feature","value","phase_name",
        F.when(F.col("a")<0,0).otherwise(F.col("a")).cast("int").alias("a"),
        F.when(F.col("b")<0,0).otherwise(F.col("b")).cast("int").alias("b"),
        F.when(F.col("c")<0,0).otherwise(F.col("c")).cast("int").alias("c"),
        F.when(F.col("d")<0,0).otherwise(F.col("d")).cast("int").alias("d"),
        "total_pairs","total_pred_yes"
    )
)

# Convert to two-row format (comparison yes/no) with columns yes/no → ready for Fisher
mat_counts = (
    counts
    .select("metric","feature","value","phase_name",
            F.lit("yes").alias("comparison"),
            F.col("a").alias("yes"),
            F.col("b").alias("no"))
    .unionByName(
        counts.select("metric","feature","value","phase_name",
                      F.lit("no").alias("comparison"),
                      F.col("c").alias("yes"),
                      F.col("d").alias("no"))
    )
)

# Safety: ensure ints and no nulls
mat_counts = (
    mat_counts.fillna(0)
              .withColumn("yes", F.col("yes").cast("int"))
              .withColumn("no",  F.col("no").cast("int"))
)

# ---- 3) Fisher per group with applyInPandas
result_schema = StructType([
    StructField("group",        StringType(),  True),
    StructField("comparison",   StringType(),  True),
    StructField("phase",        StringType(),  True),
    StructField("oddsRatio",    DoubleType(),  True),
    StructField("pValue",       DoubleType(),  True),
    StructField("lowerInterval",DoubleType(),  True),
    StructField("upperInterval",DoubleType(),  True),
    StructField("total",        StringType(),  True),
    StructField("values",       ArrayType(ArrayType(IntegerType())), True),
    StructField("relSuccess",   DoubleType(),  True),
    StructField("rsLower",      DoubleType(),  True),
    StructField("rsUpper",      DoubleType(),  True),
    StructField("path",         StringType(),  True),
])

def _relative_success(matrix_2x2: np.ndarray):
    a, b = matrix_2x2[0,0], matrix_2x2[0,1]
    c, d = matrix_2x2[1,0], matrix_2x2[1,1]
    rate_yes = a / (a + b) if (a + b) > 0 else 0.0
    rate_no  = c / (c + d) if (c + d) > 0 else 0.0
    rs = rate_yes - rate_no
    import math
    se = 0.0
    if (a+b) > 0:
        se += rate_yes * (1 - rate_yes) / (a + b)
    if (c+d) > 0:
        se += rate_no  * (1 - rate_no)  / (c + d)
    se = math.sqrt(se)
    lo, hi = rs - 1.96*se, rs + 1.96*se
    return float(rs), float(lo), float(hi)

def fisher_by_group(pdf: pd.DataFrame) -> pd.DataFrame:
    sub = pdf[["comparison","yes","no"]].copy()
    sub = sub.set_index("comparison").reindex(["yes","no"]).fillna(0)
    mat = sub[["yes","no"]].to_numpy(dtype=int)

    total = int(mat.sum())
    if total == 0:
        # Return a neutral row to avoid SciPy errors on empty tables
        return pd.DataFrame([{
            "group":        pdf["metric"].iloc[0],
            "comparison":   f"{pdf['value'].iloc[0]}_only",
            "phase":        pdf["phase_name"].iloc[0],
            "oddsRatio":    float("nan"),
            "pValue":       float("nan"),
            "lowerInterval":float("nan"),
            "upperInterval":float("nan"),
            "total":        "0",
            "values":       mat.tolist(),
            "relSuccess":   float("nan"),
            "rsLower":      float("nan"),
            "rsUpper":      float("nan"),
            "path":         "",
        }])

    from scipy.stats import fisher_exact
    from scipy.stats.contingency import odds_ratio

    or_val, p_val = fisher_exact(mat, alternative="two-sided")
    ci = odds_ratio(mat).confidence_interval(0.95)
    rs, rs_lo, rs_hi = _relative_success(mat)

    return pd.DataFrame([{
        "group":        pdf["metric"].iloc[0],
        "comparison":   f"{pdf['value'].iloc[0]}_only",
        "phase":        pdf["phase_name"].iloc[0],
        "oddsRatio":    round(float(or_val), 2),
        "pValue":       float(p_val),
        "lowerInterval":round(float(ci[0]), 2),
        "upperInterval":round(float(ci[1]), 2),
        "total":        str(total),
        "values":       mat.tolist(),
        "relSuccess":   round(float(rs), 2),
        "rsLower":      round(float(rs_lo), 2),
        "rsUpper":      round(float(rs_hi), 2),
        "path":         "",
    }])

# (optional) Arrow for speed
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

results_df = (
    mat_counts
    .groupBy("metric","feature","value","phase_name")
    .applyInPandas(fisher_by_group, schema=result_schema)
)

# ---- 4) Spreadsheet formatting + annotation + CSV
from itertools import chain
from pyspark.sql.functions import create_map

# build disdic from agg_once
disdic = {r["value"]: r["feature"] for r in agg_once.select("feature","value").distinct().collect()}

patterns = ["_only", "_isRightTissueSignalAgreed"]
regex_pattern = "(" + "|".join(patterns) + ")"

df_fmt = (
    spreadSheetFormatter(results_df)
    .withColumn("prefix", F.regexp_replace(F.col("comparison"), regex_pattern + ".*", ""))
    .withColumn("suffix", F.regexp_extract(F.col("comparison"), regex_pattern, 0))
)

mapping_expr = create_map([F.lit(x) for x in chain(*disdic.items())])
df_annot = df_fmt.withColumn("annotation", mapping_expr.getItem(F.col("prefix")))

today_date = date.today().isoformat()
out_csv = f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try.csv"
df_annot.toPandas().to_csv(out_csv, index=False)
print(f"Analysis written: {out_csv}")
today_date = date.today().isoformat()

(out_path, coalesce_n) = (f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try", 1)
'''


spark session created at 2025-09-17 18:35:42.737737
Analysis started on 2025-09-17 at  2025-09-17 18:35:42.737737


25/09/17 18:35:48 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
25/09/17 18:35:48 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


SparkSession created with:
  spark.driver.memory: 24g
  spark.executor.memory: 32g
  spark.executor.cores: 4
  spark.executor.instances: 12
  spark.yarn.executor.memoryOverhead: 8g
  spark.sql.shuffle.partitions: 192
  spark.default.parallelism: 192
  spark.sql.adaptive.enabled: true
  spark.sql.adaptive.coalescePartitions.enabled: true
Spark UI: http://jr-doe-temp1-m.c.open-targets-eu-dev.internal:46761
Loaded all base tables.
Built newColoc


                                                                                

loaded gwasComplete
Built gwasComplete
Built resolvedColoc
Built temporary DoE datasets


                                                                                

Built analysis_chembl_indication
[info] agg_once not found — rebuilding it from test2/benchmark …
[info] agg_once rebuilt.




#### second part

In [None]:

# ============================
# Simple, single-node Fisher pipeline
# Denominator = ALL (target,disease,maxClinPhase) in analysis_chembl_indication
# ============================
from datetime import date
import numpy as np
import pandas as pd

import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType, IntegerType, ArrayType
)

# ---- 0) Universe of pairs & phase flags (dedupe only; no extra filtering)
universe = (
    analysis_chembl_indication
    .select("targetId", "diseaseId", "maxClinPhase")
    .distinct()
    .join(negativeTD.select("targetId","diseaseId","stopReason"), ["targetId","diseaseId"], "left")
    .withColumn("PhaseT",  F.when(F.col("stopReason")=="Negative", "yes").otherwise("no"))
    .withColumn("Phase>=4",F.when((F.col("maxClinPhase")==4) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=3",F.when((F.col("maxClinPhase")>=3) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=2",F.when((F.col("maxClinPhase")>=2) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
    .withColumn("Phase>=1",F.when((F.col("maxClinPhase")>=1) & (F.col("PhaseT")=="no"), "yes").otherwise("no"))
)

# Long view of phase flags for universe (used for denominators)
phases_universe_long = universe.select(
    "targetId","diseaseId",
    F.expr(
      "stack(5,"
      " 'Phase>=4', `Phase>=4`,"
      " 'Phase>=3', `Phase>=3`,"
      " 'Phase>=2', `Phase>=2`,"
      " 'Phase>=1', `Phase>=1`,"
      " 'PhaseT',  `PhaseT`)"
    ).alias("phase_name","prediction")
)

total_pairs_by_phase = (
    phases_universe_long
    .groupBy("phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("total_pairs"))
)
total_pred_yes_by_phase = (
    phases_universe_long
    .filter(F.col("prediction")=="yes")
    .groupBy("phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("total_pred_yes"))
)

# ---- 1) Build analysis_long = (metric flag yes/no) × (phase yes/no)
# We assume 'agg_once' exists with flags per (targetId,diseaseId,maxClinPhase,feature,value)
metric_flags = ["NoneCellYes_flag", "NdiagonalYes_flag", "hasGenetics_flag"]

phase_flags = (
    benchmark
    .select("targetId","diseaseId","maxClinPhase","Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT")
    .dropDuplicates(["targetId","diseaseId","maxClinPhase"])
)

phases_long_for_records = (
    phase_flags
    .join(agg_once.select("targetId","diseaseId","maxClinPhase").dropDuplicates(),
          ["targetId","diseaseId","maxClinPhase"], "inner")
    .select(
        "targetId","diseaseId","maxClinPhase",
        F.expr(
          "stack(5,"
          " 'Phase>=4', `Phase>=4`,"
          " 'Phase>=3', `Phase>=3`,"
          " 'Phase>=2', `Phase>=2`,"
          " 'Phase>=1', `Phase>=1`,"
          " 'PhaseT',  `PhaseT`)"
        ).alias("phase_name","prediction")
    )
)

def attach_metric(metric_col: str):
    return (
        agg_once
        .select("targetId","diseaseId","maxClinPhase","feature","value",
                F.col(metric_col).alias("comparison"))
        .join(phases_long_for_records, ["targetId","diseaseId","maxClinPhase"], "inner")
        .withColumn("metric", F.lit(metric_col.replace("_flag","")))
    )

analysis_long = attach_metric(metric_flags[0])
for m in metric_flags[1:]:
    analysis_long = analysis_long.unionByName(attach_metric(m))

# ---- 2) Build 2×2 counts with the fixed denominators
# a = yes/yes, b = yes/no, c = no/yes, d = no/no
yes_yes = (
    analysis_long
    .filter((F.col("comparison")=="yes") & (F.col("prediction")=="yes"))
    .groupBy("metric","feature","value","phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("a"))
)
yes_total = (
    analysis_long
    .filter(F.col("comparison")=="yes")
    .groupBy("metric","feature","value","phase_name")
    .agg(F.countDistinct(F.struct("targetId","diseaseId")).alias("yes_total"))
)

counts = (
    yes_total
    .join(yes_yes, ["metric","feature","value","phase_name"], "left")
    .join(total_pairs_by_phase, ["phase_name"], "left")
    .join(total_pred_yes_by_phase, ["phase_name"], "left")
    .na.fill({"a":0})
    .withColumn("b", F.col("yes_total") - F.col("a"))
    .withColumn("c", F.col("total_pred_yes") - F.col("a"))
    .withColumn("d", F.col("total_pairs") - F.col("a") - F.col("b") - F.col("c"))
    .select(
        "metric","feature","value","phase_name",
        F.greatest(F.lit(0), F.col("a")).cast("int").alias("a"),
        F.greatest(F.lit(0), F.col("b")).cast("int").alias("b"),
        F.greatest(F.lit(0), F.col("c")).cast("int").alias("c"),
        F.greatest(F.lit(0), F.col("d")).cast("int").alias("d")
    )
)

# Convert to two-row format (comparison yes/no) for applyInPandas
mat_counts = (
    counts
    .select("metric","feature","value","phase_name",
            F.lit("yes").alias("comparison"),
            F.col("a").alias("yes"),
            F.col("b").alias("no"))
    .unionByName(
        counts.select("metric","feature","value","phase_name",
                      F.lit("no").alias("comparison"),
                      F.col("c").alias("yes"),
                      F.col("d").alias("no"))
    )
    .fillna(0)
    .withColumn("yes", F.col("yes").cast("int"))
    .withColumn("no",  F.col("no").cast("int"))
)

# ---- 3) Fisher per (metric,feature,value,phase)
result_schema = StructType([
    StructField("group",        StringType(),  True),
    StructField("comparison",   StringType(),  True),
    StructField("phase",        StringType(),  True),
    StructField("oddsRatio",    DoubleType(),  True),
    StructField("pValue",       DoubleType(),  True),
    StructField("lowerInterval",DoubleType(),  True),
    StructField("upperInterval",DoubleType(),  True),
    StructField("total",        StringType(),  True),
    StructField("values",       ArrayType(ArrayType(IntegerType())), True),
    StructField("relSuccess",   DoubleType(),  True),
    StructField("rsLower",      DoubleType(),  True),
    StructField("rsUpper",      DoubleType(),  True),
    StructField("path",         StringType(),  True),
])

def _relative_success(mat: np.ndarray):
    a, b = mat[0,0], mat[0,1]
    c, d = mat[1,0], mat[1,1]
    rate_yes = a / (a + b) if (a + b) > 0 else 0.0
    rate_no  = c / (c + d) if (c + d) > 0 else 0.0
    rs = rate_yes - rate_no
    # simple Wald CI
    import math
    se = 0.0
    if (a+b) > 0:
        se += rate_yes * (1 - rate_yes) / (a + b)
    if (c+d) > 0:
        se += rate_no  * (1 - rate_no)  / (c + d)
    se = math.sqrt(se)
    lo, hi = rs - 1.96*se, rs + 1.96*se
    return float(rs), float(lo), float(hi)

def fisher_by_group(pdf: pd.DataFrame) -> pd.DataFrame:
    sub = pdf[["comparison","yes","no"]].copy()
    sub = sub.set_index("comparison").reindex(["yes","no"]).fillna(0)
    mat = sub[["yes","no"]].to_numpy(dtype=int)
    total = int(mat.sum())

    if total == 0:
        return pd.DataFrame([{
            "group":      pdf["metric"].iloc[0],
            "comparison": f"{pdf['value'].iloc[0]}_only",
            "phase":      pdf["phase_name"].iloc[0],
            "oddsRatio":  float("nan"),
            "pValue":     float("nan"),
            "lowerInterval": float("nan"),
            "upperInterval": float("nan"),
            "total":      "0",
            "values":     mat.tolist(),
            "relSuccess": float("nan"),
            "rsLower":    float("nan"),
            "rsUpper":    float("nan"),
            "path":       "",
        }])

    from scipy.stats import fisher_exact
    from scipy.stats.contingency import odds_ratio

    or_val, p_val = fisher_exact(mat, alternative="two-sided")
    ci = odds_ratio(mat).confidence_interval(0.95)
    rs, rs_lo, rs_hi = _relative_success(mat)

    return pd.DataFrame([{
        "group":        pdf["metric"].iloc[0],
        "comparison":   f"{pdf['value'].iloc[0]}_only",
        "phase":        pdf["phase_name"].iloc[0],
        "oddsRatio":    round(float(or_val), 2),
        "pValue":       float(p_val),
        "lowerInterval":round(float(ci[0]), 2),
        "upperInterval":round(float(ci[1]), 2),
        "total":        str(total),
        "values":       mat.tolist(),
        "relSuccess":   round(float(rs), 2),
        "rsLower":      round(float(rs_lo), 2),
        "rsUpper":      round(float(rs_hi), 2),
        "path":         "",
    }])

# Keep Arrow on; limit batch size for stability on one node
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "20000")

results_df = (
    mat_counts
    .repartition(96, "metric", "phase_name")  # moderate parallelism on single node
    .groupBy("metric","feature","value","phase_name")
    .applyInPandas(fisher_by_group, schema=result_schema)
)

# ---- 4) Write out (Parquet → CSV shards); no toPandas
today_date = date.today().isoformat()
base = f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_denominator_allPairs"

parquet_path = base + "_parquet"
csv_path     = base + "_csv"

# Stage to parquet (break lineage)
'''
(results_df
  .repartition(96, "group", "phase")   # smallish files, even spread
  .write.mode("overwrite")
  .parquet(parquet_path))

# Read back and spill to CSV shards
prev = spark.conf.get("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "false")
(spark.read.parquet(parquet_path)
  .repartition(96, "group", "phase")
  .write.mode("overwrite")
  .option("header","true")
  .option("maxRecordsPerFile","250000")
  .csv(csv_path))
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", prev)
'''
print(f"Wrote results to:\n  Parquet: {parquet_path}\n  CSV:     {csv_path}")


25/09/17 18:40:42 ERROR AsyncEventQueue: Dropping event from queue appStatus. This likely means one of the listeners is too slow and cannot keep up with the rate at which tasks are being started by the scheduler.
25/09/17 18:40:42 WARN AsyncEventQueue: Dropped 1 events from appStatus since the application started.
ERROR:root:Exception while sending command.+ 0) / 192][Stage 102:(127 + 3) / 192]
Traceback (most recent call last):
  File "/usr/lib/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 516, in send_command
    raise Py4JNetworkError("Answer from Java side is empty")
py4j.protocol.Py4JNetworkError: Answer from Java side is empty

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/spark/python/l

Py4JError: An error occurred while calling o2521.parquet

In [2]:
today_date = date.today().isoformat()
out_csv = f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try.csv"
df_annot.toPandas().to_csv(out_csv, index=False)
print(f"Analysis written: {out_csv}")

25/09/17 17:35:09 ERROR AsyncEventQueue: Dropping event from queue appStatus. This likely means one of the listeners is too slow and cannot keep up with the rate at which tasks are being started by the scheduler.
25/09/17 17:35:09 WARN AsyncEventQueue: Dropped 1 events from appStatus since the application started.
ERROR:root:Exception while sending command.+ 2) / 192][Stage 125:(182 + 2) / 192]
Traceback (most recent call last):
  File "/usr/lib/spark/python/pyspark/sql/pandas/conversion.py", line 280, in _collect_as_arrow
    results = list(batch_stream)
              ^^^^^^^^^^^^^^^^^^
  File "/usr/lib/spark/python/pyspark/sql/pandas/serializers.py", line 69, in load_stream
    for batch in self.serializer.load_stream(stream):
  File "/usr/lib/spark/python/pyspark/sql/pandas/serializers.py", line 111, in load_stream
    reader = pa.ipc.open_stream(stream)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/miniconda3/lib/python3.11/site-packages/pyarrow/ipc.py", line 190, in o

Py4JError: An error occurred while calling o2980.getResult

ERROR:root:Exception while sending command.
Traceback (most recent call last):
  File "/usr/lib/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 516, in send_command
    raise Py4JNetworkError("Answer from Java side is empty")
py4j.protocol.Py4JNetworkError: Answer from Java side is empty

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/lib/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 539, in send_command
    raise Py4JNetworkError(
py4j.protocol.Py4JNetworkError: Error while sending or receiving


In [2]:

(out_path, coalesce_n) = (f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try", 1)
print('preparing df_annot to write')
(df_annot
  .coalesce(coalesce_n)                 # 1 file if feasible; increase if OOM on shuffle
  .write.mode("overwrite")
  .option("header", "true")
  .csv(out_path))

print(f"Wrote CSV shards to: {out_path}")

preparing df_annot to write


25/09/17 17:23:09 ERROR AsyncEventQueue: Dropping event from queue appStatus. This likely means one of the listeners is too slow and cannot keep up with the rate at which tasks are being started by the scheduler.
25/09/17 17:23:09 WARN AsyncEventQueue: Dropped 1 events from appStatus since the application started.
ERROR:root:KeyboardInterrupt while sending command.28][Stage 137:(126 + -1) / 128]
Traceback (most recent call last):
  File "/usr/lib/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/miniconda3/lib/python3.11/socket.py", line 706, in readinto
    return self._sock.recv_into(b)
           ^^^^^^^^^^^^^^^^^^^^^^^
Keyboa

KeyboardInterrupt: 

In [None]:
prev = spark.conf.get("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "false")


out_path    = f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_try"
parquet_path = out_path + "_parquet"

writer_parts = 64   # try 48/64/80 depending on size

# Stage 1: write Parquet
(df_annot
  #.repartition(writer_parts, "metric", "phase")   # pick keys you have; omit if not present
  .write.mode("overwrite")
  .parquet(parquet_path))

# Free lineage
df_annot.unpersist()
spark.catalog.clearCache()

# Stage 2: read Parquet back and write CSV shards (no coalesce(1)!)
(spark.read.parquet(parquet_path)
  .repartition(writer_parts, "metric", "phase")
  .write.mode("overwrite")
  .option("header","true")
  .option("maxRecordsPerFile","200000")   # smaller chunks → lower peak memory
  .option("compression","gzip")           # optional; reduces IO/space
  .csv(out_path))

# Restore AQE setting
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", prev)

print(f"Wrote CSV shards under: {out_path}")


[Stage 317:> (9 + 1) / 80][Stage 318:(113 + 1) / 174][Stage 319:> (8 + 1) / 80]]]

##### compare with the code we already had

In [2]:
import time
from array import ArrayType
from functions import (
    relative_success,
    spreadSheetFormatter,
    discrepancifier,
    temporary_directionOfEffect,
    buildColocData,
    gwasDataset,
)
# from stoppedTrials import terminated_td
from DoEAssessment import directionOfEffect
# from membraneTargets import target_membrane
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
from datetime import datetime
from datetime import date
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.types import (
    StructType,
    StructField,
    DoubleType,
    DecimalType,
    StringType,
    FloatType,
)
import pandas as pd
from functools import reduce

spark = SparkSession.builder.getOrCreate()
spark.conf.set(
    "spark.sql.shuffle.partitions", "400"
)  # Default is 200, increase if needed

print('joint groups')
path_n='gs://open-targets-data-releases/25.06/output/'

target = spark.read.parquet(f"{path_n}target/")

diseases = spark.read.parquet(f"{path_n}disease/")

evidences = spark.read.parquet(f"{path_n}evidence")

credible = spark.read.parquet(f"{path_n}credible_set")

new = spark.read.parquet(f"{path_n}colocalisation_coloc") 

index=spark.read.parquet(f"{path_n}study/")

variantIndex = spark.read.parquet(f"{path_n}variant")

biosample = spark.read.parquet(f"{path_n}biosample")

ecaviar=spark.read.parquet(f"{path_n}colocalisation_ecaviar")

all_coloc=ecaviar.unionByName(new, allowMissingColumns=True)

print("loaded files")

#### FIRST MODULE: BUILDING COLOC 
newColoc=buildColocData(all_coloc,credible,index)

print("loaded newColoc")

### SECOND MODULE: PROCESS EVIDENCES TO AVOID EXCESS OF COLUMNS 
gwasComplete = gwasDataset(evidences,credible)

#### THIRD MODULE: INCLUDE COLOC IN THE 
resolvedColoc = (
    (
        newColoc.withColumnRenamed("geneId", "targetId")
        .join(
            gwasComplete.withColumnRenamed("studyLocusId", "leftStudyLocusId"),
            on=["leftStudyLocusId", "targetId"],
            how="right",
        )

        .join(  ### propagated using parent terms
            diseases.selectExpr(
                "id as diseaseId", "name", "parents", "therapeuticAreas"
            ),
            on="diseaseId",
            how="left",
        )
        .withColumn(
            "diseaseId",
            F.explode_outer(F.concat(F.array(F.col("diseaseId")), F.col("parents"))),
        )
        .drop("parents", "oldDiseaseId")
      
    ).withColumn(
        "colocDoE",
        F.when(
            F.col("rightStudyType").isin(
                ["eqtl", "pqtl", "tuqtl", "sceqtl", "sctuqtl"]
            ),
            F.when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("GoF_risk"),
            )
            .when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("LoF_risk"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("LoF_protect"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("GoF_protect"),
            ),
        ).when(
            F.col("rightStudyType").isin(
                ["sqtl", "scsqtl"]
            ),  ### opposite directionality than sqtl
            F.when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("LoF_risk"),
            )
            .when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("GoF_risk"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("GoF_protect"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("LoF_protect"),
            ),
        ),
    )
    # .persist()
)
print("loaded resolvedColloc")

datasource_filter = [
#   "ot_genetics_portal",
    "gwas_credible_sets",
    "gene_burden",
    "eva",
    "eva_somatic",
    "gene2phenotype",
    "orphanet",
    "cancer_gene_census",
    "intogen",
    "impc",
    "chembl",
]

assessment, evidences, actionType, oncolabel = temporary_directionOfEffect(
    path_n, datasource_filter
)

print("run temporary direction of effect")


print("built drugApproved dataset")


#### FOURTH MODULE BUILDING CHEMBL ASSOCIATIONS - HERE TAKE CARE WITH FILTERING STEP 
analysis_chembl_indication = (
    discrepancifier(
        assessment.filter((F.col("datasourceId") == "chembl"))
        .withColumn(
            "maxClinPhase",
            F.max(F.col("clinicalPhase")).over(
                Window.partitionBy("targetId", "diseaseId")
            ),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase")
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    #.filter(F.col("coherencyDiagonal") == "coherent")
    .drop(
        "coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk"
    )
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
    # .persist()
)

####2 Define agregation function
import pandas as pd
import numpy as np
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio
from pyspark.sql.types import *


def convertTuple(tup):
    st = ",".join(map(str, tup))
    return st


#####3 run in a function
def aggregations_original(
    df,
    data,
    listado,
    comparisonColumn,
    comparisonType,
    predictionColumn,
    predictionType,
    today_date,
):
    wComparison = Window.partitionBy(comparisonColumn)
    wPrediction = Window.partitionBy(predictionColumn)
    wPredictionComparison = Window.partitionBy(comparisonColumn, predictionColumn)
    results = []
    # uniqIds = df.select("targetId", "diseaseId").distinct().count()
    out = (
        df.withColumn("comparisonType", F.lit(comparisonType))
        .withColumn("dataset", F.lit(data))
        .withColumn("predictionType", F.lit(predictionType))
        # .withColumn("total", F.lit(uniqIds))
        .withColumn("a", F.count("targetId").over(wPredictionComparison))
        .withColumn("comparisonColumn", F.lit(comparisonColumn))
        .withColumn("predictionColumnValue", F.lit(predictionColumn))
        .withColumn(
            "predictionTotal",
            F.count("targetId").over(wPrediction),
        )
        .withColumn(
            "comparisonTotal",
            F.count("targetId").over(wComparison),
        )
        .select(
            F.col(predictionColumn).alias("prediction"),
            F.col(comparisonColumn).alias("comparison"),
            "dataset",
            "comparisonColumn",
            "predictionColumnValue",
            "comparisonType",
            "predictionType",
            "a",
            "predictionTotal",
            "comparisonTotal",
        )
        .filter(F.col("prediction").isNotNull())
        .filter(F.col("comparison").isNotNull())
        .distinct()
    )
    '''
    out.write.mode("overwrite").parquet(
        "gs://ot-team/jroldan/"
        + str(
            today_date
            + "_"
            + "analysis/"
            + data
            # + "_propagated"
            + "/"
            + comparisonColumn
            + "_"
            + comparisonType
            + "_"
            + predictionColumn
            + ".parquet"
        )
    )
    '''

    listado.append(
        "gs://ot-team/jroldan/"
        + str(
            today_date
            + "_"
            + "analysis/"
            + data
            # + "_propagated"
            + "/"
            + comparisonColumn
            + "_"
            + comparisonType
            + "_"
            + predictionColumn
            + ".parquet"
        )
    )
    path = "gs://ot-team/jroldan/" + str(
        today_date
        + "_"
        + "analysis/"
        + data
        # + "_propagated"
        + "/"
        + comparisonColumn
        + "_"
        + comparisonType
        + "_"
        + predictionColumn
        + ".parquet"
    )
    print(path)
    
    ### making analysis
    array1 = np.delete(
        out.join(full_data, on=["prediction", "comparison"], how="outer")
        .groupBy("comparison")
        .pivot("prediction")
        .agg(F.first("a"))
        .sort(F.col("comparison").desc())
        .select("comparison", "yes", "no")
        .fillna(0)
        .toPandas()
        .to_numpy(),
        [0],
        1,
    )
    total = np.sum(array1)
    res_npPhaseX = np.array(array1, dtype=int)
    resX = convertTuple(fisher_exact(res_npPhaseX, alternative="two-sided"))
    resx_CI = convertTuple(
        odds_ratio(res_npPhaseX).confidence_interval(confidence_level=0.95)
    )

    result_st.append(resX)
    result_ci.append(resx_CI)
    (rs_result, rs_ci) = relative_success(array1)
    results.extend(
        [
            comparisonType,
            comparisonColumn,
            predictionColumn,
            round(float(resX.split(",")[0]), 2),
            float(resX.split(",")[1]),
            round(float(resx_CI.split(",")[0]), 2),
            round(float(resx_CI.split(",")[1]), 2),
            str(total),
            np.array(res_npPhaseX).tolist(),
            round(float(rs_result), 2),
            round(float(rs_ci[0]), 2),
            round(float(rs_ci[1]), 2),
            # studies,
            # tissues,
            path,
        ]
    )
    return results


#### 3 Loop over different datasets (as they will have different rows and columns)


def comparisons_df_iterative(elements):
    # toAnalysis = [(key, value) for key, value in disdic.items() if value == projectId]
    toAnalysis = [(col, "predictor") for col in elements]
    schema = StructType(
        [
            StructField("comparison", StringType(), True),
            StructField("comparisonType", StringType(), True),
        ]
    )

    comparisons = spark.createDataFrame(toAnalysis, schema=schema)
    ### include all the columns as predictor

    predictions = spark.createDataFrame(
        data=[
            ("Phase>=4", "clinical"),
            ('Phase>=3','clinical'),
            ('Phase>=2','clinical'),
            ('Phase>=1','clinical'),
            ("PhaseT", "clinical"),
        ]
    )
    return comparisons.join(predictions, how="full").collect()


print("load comparisons_df_iterative function")


full_data = spark.createDataFrame(
    data=[
        ("yes", "yes"),
        ("yes", "no"),
        ("no", "yes"),
        ("no", "no"),
    ],
    schema=StructType(
        [
            StructField("prediction", StringType(), True),
            StructField("comparison", StringType(), True),
        ]
    ),
)
print("created full_data and lists")

#rightTissue = spark.read.csv(
#    'gs://ot-team/jroldan/analysis/20250526_rightTissue.csv',
#    header=True,
#).drop("_c0")

print("loaded rightTissue dataset")

negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .groupBy("targetId", "diseaseId")
    .count()
    .withColumn("stopReason", F.lit("Negative"))
    .drop("count")
)

print("built negativeTD dataset")

print("built bench2 dataset")

###### cut from here
print("looping for variables_study")

#### new part with chatgpt -- TEST

## QUESTIONS TO ANSWER:
# HAVE ECAVIAR >=0.8
# HAVE COLOC 
# HAVE COLOC >= 0.8
# HAVE COLOC + ECAVIAR >= 0.01
# HAVE COLOC >= 0.8 + ECAVIAR >= 0.01
# RIGHT JOING WITH CHEMBL 

### FIFTH MODULE: BUILDING BENCHMARK OF THE DATASET TO EXTRACT EHE ANALYSIS 

resolvedColocFiltered = resolvedColoc.filter((F.col('clpp')>=0.01) | (F.col('h4')>=0.8))
benchmark = (
    (
        resolvedColocFiltered.filter(F.col("betaGwas") < 0).filter(
        F.col("name") != "COVID-19"
    )
        .join(  ### select just GWAS giving protection
            analysis_chembl_indication, on=["targetId", "diseaseId"], how="right"  ### RIGHT SIDE
        )
        .withColumn(
            "AgreeDrug",
            F.when(
                (F.col("drugGoF_protect").isNotNull())
                & (F.col("colocDoE") == "GoF_protect"),
                F.lit("yes"),
            )
            .when(
                (F.col("drugLoF_protect").isNotNull())
                & (F.col("colocDoE") == "LoF_protect"),
                F.lit("yes"),
            )
            .otherwise(F.lit("no")),
        )
    )  #### remove COVID-19 associations
).join(biosample.select("biosampleId", "biosampleName"), on="biosampleId", how="left")

#bench2 = benchmark.join(
#    rightTissue, on=["name", "bioSampleName"], how="left"
#).withColumn(
#    "rightTissue",
#    F.when(F.col("rightTissue1") == "yes", F.lit("yes")).otherwise(F.lit("no")),
#)

print("built benchmark dataset")

## write the benchmark 
#name='benchmark'
#output_partitioned_path = f"gs://ot-team/jroldan/analysis/parquetFiles/{name}"
#benchmark.write.mode("overwrite").parquet(output_partitioned_path)
#print(f'written {name}')
#### Analysis

#### 1 Build a dictionary with the distinct values as key and column names as value
#variables_study = ["projectId", "biosampleName", "rightStudyType", "colocDoE","colocalisationMethod"]
variables_study = ["projectId"]

# List to hold temporary DataFrames
temp_dfs_for_union = []

# Iterate over the column names to prepare DataFrames for union
for col_name in variables_study:
    # Select the current column, alias it to 'distinct_value' for consistent schema
    # Filter out nulls, then get distinct values
    # Add a literal column with the original 'col_name'
    df_temp = (
        benchmark.select(F.col(col_name).alias("distinct_value"))
        .filter(F.col("distinct_value").isNotNull()) # Exclude None (null) values
        .distinct()
        .withColumn("column_name", F.lit(col_name))
    )
    temp_dfs_for_union.append(df_temp)

disdic = {}

if temp_dfs_for_union:
    # Union all the temporary DataFrames.
    # unionByName is crucial to handle potential schema differences (e.g., if columns have same name but different types)
    # and ensures columns are matched by name.
    combined_distinct_values_df = temp_dfs_for_union[0]
    for i in range(1, len(temp_dfs_for_union)):
        combined_distinct_values_df = combined_distinct_values_df.unionByName(temp_dfs_for_union[i])

    # Now, collect the combined distinct values.
    # This is a single collect operation on the aggregated DataFrame.
    print("Collecting combined distinct values from the cluster...")
    collected_rows = combined_distinct_values_df.collect()

    # Populate the dictionary from the collected rows
    for row in collected_rows:
        disdic[row.distinct_value] = row.column_name
else:
    print("variables_study list is empty, disdic will be empty.")


print("\nFinal disdic:", disdic)

# Assuming 'spark' session, 'benchmark' DataFrame, 'negativeTD' DataFrame, and 'disdic' dictionary are defined

# --- Step 1: Pre-compute 'hasboth' ONCE ---
# This is a shuffle, but only happens once.
print("Pre-computing 'hasboth' column...")
window_target_disease_only = Window.partitionBy('targetId', 'diseaseId')
benchmark_processed = benchmark.withColumn(
    'hasboth',
    F.size(F.collect_set('colocalisationMethod').over(window_target_disease_only))
)

# You might consider caching this intermediate result if 'benchmark' is very large
# and you have enough memory, to avoid re-reading from source if possible.
# benchmark_processed.cache() # or .persist(StorageLevel.MEMORY_AND_DISK)
# benchmark_processed.count() # Force computation if you cache

pivoted_dfs = {}

# --- Step 2: Loop for each variable_study column ---
for col_name in variables_study:
    print(f"Processing pivot for: {col_name}")

    # Define window specs for the current iteration, including 'col_name' in partition
    # (This shuffle is still per iteration, but unavoidable if 'resolvedAgreeDrug' depends on 'col_name' values)
    current_col_window_spec_qtl = Window.partitionBy("targetId", "diseaseId", col_name).orderBy(F.col("qtlPValueExponent").asc())
    current_col_pvalue_order_window = Window.partitionBy("targetId", "diseaseId", col_name).orderBy(F.col('colocalisationMethod').asc(), F.col("qtlPValueExponent").asc())

    # Calculate 'resolvedAgreeDrug' for the current 'col_name'
    # This involves a shuffle per iteration.
    temp_df_with_resolved = benchmark_processed.withColumn('resolvedAgreeDrug',
        F.when(F.col('hasboth') > 1,
            F.first(F.col('AgreeDrug'), ignorenulls=True).over(current_col_pvalue_order_window)
        ).otherwise(F.first(F.col('AgreeDrug'), ignorenulls=True).over(current_col_window_spec_qtl))
    )

    # --- Step 3: Perform the pivot and join ---
    # This is an expensive operation (shuffle, potential wide dataframe)
    pivoted_df = (
        temp_df_with_resolved
        .groupBy(
            "targetId",
            "diseaseId",
            "maxClinPhase",
        )
        .pivot(col_name) # Pivoting on values of the 'col_name' column
        .agg(F.collect_set("resolvedAgreeDrug"))
        .join(negativeTD, on=["targetId", "diseaseId"], how="left") # Ensure negativeTD is broadcast if small
    )

    # --- Step 4: Add derived columns (these are generally cheap) ---
    for phase in [1, 2, 3, 4]:
        pivoted_df = pivoted_df.withColumn(
            f"Phase>={phase}",
            F.when(F.col("maxClinPhase") >= phase, F.lit("yes")).otherwise(F.lit("no")),
        )

    pivoted_df = pivoted_df.withColumn(
        "PhaseT",
        F.when(F.col("stopReason") == "Negative", F.lit("yes")).otherwise(F.lit("no")),
    ).withColumn(
        "Phase>=4",
        F.when(
            (F.col("maxClinPhase") == 4) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "Phase>=3",
        F.when(
            (F.col("maxClinPhase") >= 3) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "Phase>=2",
        F.when(
            (F.col("maxClinPhase") >= 2) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "Phase>=1",
        F.when(
            (F.col("maxClinPhase") >= 1) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    )

    # Add _only columns dynamically based on disdic values matching current column
    matching_keys = [key for key, val in disdic.items() if val == col_name]

    for key in matching_keys:
        # F.col(key) assumes 'key' refers to a column that exists in pivoted_df after the pivot.
        pivoted_df = pivoted_df.withColumn(
            f"{key}_only",
            F.when(F.array_contains(F.col(key), "yes"), F.lit("yes")).otherwise(F.lit("no")),
        )

### making columns for the 

    # --- Step 5: Store result. Consider writing to GCS to break lineage if memory is an issue ---
    # This is highly recommended if 'variables_study' is very large.
    # Write to Parquet for efficient storage and schema preservation.
    # output_path = f"gs://your-bucket/temp_pivoted_results/{col_name}"
    # print(f"Writing results for {col_name} to {output_path}")
    # pivoted_df.write.mode("overwrite").parquet(output_path)
    # pivoted_dfs[col_name] = spark.read.parquet(output_path) # Read back if needed later
    # output_partitioned_path = f"gs://ot-team/jroldan/analysis/parquetFiles/pivoted_df_{col_name}"
    # pivoted_df.write.mode("overwrite").parquet(output_partitioned_path)
    # print(f"DataFrame successfully written and partitioned to {output_partitioned_path}")
    # If not writing to GCS, just store the DF in memory (be cautious for large number of DFs)

    pivoted_dfs[col_name] = pivoted_df

##### PROJECTID
project_keys=[f"{k}_only" for k,v in disdic.items() if v == 'projectId']
main=['GTEx_only', 'UKB_PPP_EUR_only']
#stimulated=['Alasoo_2018_only','Cytoimmgen_only','Fairfax_2014_only','Kim-Hellmuth_2017_only','Nathan_2022_only','Nedelec_2016_only','Quach_2016_only','Randolph_2021_only','Schmiedel_2018_only']
cellLine=['CAP_only','HipSci_only','iPSCORE_only','Jerber_2021_only','PhLiPS_only','Schwartzentruber_2018_only','TwinsUK_only']
stimulated=['Schmiedel_2018_only','Bossini-Castillo_2019_only','Alasoo_2018_only','Cytoimmgen_only','Gilchrist_2021_only','CAP_only','Quach_2016_only','Randolph_2021_only','Sun_2018_only','Nedelec_2016_only','Kim-Hellmuth_2017_only']
others=[item for item in project_keys if item not in main]
nonStimulated=[item for item in project_keys if item not in stimulated]
otherCellLine=[item for item in project_keys if item not in cellLine]

# First condition: any "yes" in list1
condition1 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), others[1:], F.col(others[0]) == "yes")
# estimulated
condition2 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), stimulated[1:], F.col(stimulated[0]) == "yes")
## non estimulated:
condition3 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), nonStimulated[1:], F.col(nonStimulated[0]) == "yes")
# cellLine
condition4 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), cellLine[1:], F.col(cellLine[0]) == "yes")
# non cellline
condition5 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), otherCellLine[1:], F.col(otherCellLine[0]) == "yes")
# non cellline
condition6 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), main[1:], F.col(main[0]) == "yes")

# Add both columns
pivoted_dfs['projectId'] = pivoted_dfs['projectId'].withColumn("othersProjectId_only", F.when(condition1, "yes").otherwise("no")) 
pivoted_dfs['projectId'] = pivoted_dfs['projectId'].withColumn("estimulated_only", F.when(condition2, "yes").otherwise("no")) 
pivoted_dfs['projectId'] = pivoted_dfs['projectId'].withColumn("nonStimulated_only", F.when(condition3, "yes").otherwise("no")) 
pivoted_dfs['projectId'] = pivoted_dfs['projectId'].withColumn("cellLine", F.when(condition4, "yes").otherwise("no")) 
pivoted_dfs['projectId'] = pivoted_dfs['projectId'].withColumn("nonCellLine", F.when(condition5, "yes").otherwise("no")) 
pivoted_dfs['projectId'] = pivoted_dfs['projectId'].withColumn("GTExUKB", F.when(condition6, "yes").otherwise("no")) 


###append to dictionary

disdic.update({'othersProjectId': 'projectId','Stimulated': 'projectId','cellLine': 'projectId', 'othersBiosampleName_only': 'biosampleName', 'otherRightStudyType':'rightStudyType'})


result = []
result_st = []
result_ci = []
array2 = []
listado = []
result_all = []
today_date = str(date.today())


spark session created at 2025-09-17 20:24:13.866899
Analysis started on 2025-09-17 at  2025-09-17 20:24:13.866899


25/09/17 20:24:18 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


joint groups
loaded files
loaded newColoc


                                                                                

loaded gwasComplete
loaded resolvedColloc
run temporary direction of effect
built drugApproved dataset


                                                                                

load comparisons_df_iterative function
created full_data and lists
loaded rightTissue dataset
built negativeTD dataset
built bench2 dataset
looping for variables_study
built benchmark dataset
Collecting combined distinct values from the cluster...


                                                                                


Final disdic: {'HipSci': 'projectId', 'van_de_Bunt_2015': 'projectId', 'GTEx': 'projectId', 'Lepik_2017': 'projectId', 'Bossini-Castillo_2019': 'projectId', 'ROSMAP': 'projectId', 'BLUEPRINT': 'projectId', 'TwinsUK': 'projectId', 'FUSION': 'projectId', 'Cytoimmgen': 'projectId', 'Gilchrist_2021': 'projectId', 'PhLiPS': 'projectId', 'Fairfax_2014': 'projectId', 'BrainSeq': 'projectId', 'GEUVADIS': 'projectId', 'Kim-Hellmuth_2017': 'projectId', 'Schmiedel_2018': 'projectId', 'Peng_2018': 'projectId', 'CEDAR': 'projectId', 'Nathan_2022': 'projectId', 'UKB_PPP_EUR': 'projectId', 'Quach_2016': 'projectId', 'iPSCORE': 'projectId', 'Jerber_2021': 'projectId', 'Alasoo_2018': 'projectId', 'Perez_2022': 'projectId', 'CommonMind': 'projectId', 'CAP': 'projectId', 'Walker_2019': 'projectId', 'GENCORD': 'projectId', 'Nedelec_2016': 'projectId', 'Steinberg_2020': 'projectId', 'OneK1K': 'projectId', 'Fairfax_2012': 'projectId', 'Aygun_2021': 'projectId', 'Schwartzentruber_2018': 'projectId', 'Kasela

                                                                                

In [None]:
columns_to_aggregate = ['NoneCellYes', 'NdiagonalYes','hasGenetics'] # The values you want to collect in the pivoted cells
all_pivoted_dfs = {}

doe_columns=["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]
diagonal_lof=['LoF_protect','GoF_risk']
diagonal_gof=['LoF_risk','GoF_protect']

conditions = [
    F.when(F.col(c) == F.col("maxDoE"), F.lit(c)).otherwise(F.lit(None)) for c in doe_columns
    ]

# --- Nested Loops for Dynamic Pivoting ---
for agg_col_name in columns_to_aggregate:
    for pivot_col_name in columns_to_pivot_on:
        print(f"\n--- Creating DataFrame for Aggregation: '{agg_col_name}' and Pivot: '{pivot_col_name}' ---")
        current_col_pvalue_order_window = Window.partitionBy("targetId", "diseaseId", "maxClinPhase", pivot_col_name).orderBy(F.col('colocalisationMethod').asc(), F.col("qtlPValueExponent").asc())
        test2=discrepancifier(benchmark.withColumn('actionType2', F.concat_ws(",", F.col("actionType2"))).withColumn('qtlColocDoE',F.first('colocDoE').over(current_col_pvalue_order_window)).groupBy(
        "targetId", "diseaseId", "maxClinPhase", "drugLoF_protect", "drugGoF_protect",pivot_col_name)
        .pivot("colocDoE")
        .count()
        .withColumnRenamed('drugLoF_protect', 'LoF_protect_ch')
        .withColumnRenamed('drugGoF_protect', 'GoF_protect_ch')).withColumn( ## .filter(F.col('coherencyDiagonal')!='noEvid')
    "arrayN", F.array(*[F.col(c) for c in doe_columns])
    ).withColumn(
        "maxDoE", F.array_max(F.col("arrayN"))
    ).withColumn("maxDoE_names", F.array(*conditions)
    ).withColumn("maxDoE_names", F.expr("filter(maxDoE_names, x -> x is not null)")
    ).withColumn(
        "NoneCellYes",
        F.when((F.col("LoF_protect_ch").isNotNull() & (F.col('GoF_protect_ch').isNull())) & (F.array_contains(F.col("maxDoE_names"), F.lit("LoF_protect")))==True, F.lit('yes'))
        .when((F.col("GoF_protect_ch").isNotNull() & (F.col('LoF_protect_ch').isNull())) & (F.array_contains(F.col("maxDoE_names"), F.lit("GoF_protect")))==True, F.lit('yes')
            ).otherwise(F.lit('no'))  # If the value is null, return null # Otherwise, check if name is in array
    ).withColumn(
        "NdiagonalYes",
        F.when((F.col("LoF_protect_ch").isNotNull() & (F.col('GoF_protect_ch').isNull())) & 
            (F.size(F.array_intersect(F.col("maxDoE_names"), F.array([F.lit(x) for x in diagonal_lof]))) > 0),
            F.lit("yes")
        ).when((F.col("GoF_protect_ch").isNotNull() & (F.col('LoF_protect_ch').isNull())) & 
            (F.size(F.array_intersect(F.col("maxDoE_names"), F.array([F.lit(x) for x in diagonal_gof]))) > 0),
            F.lit("yes")
        ).otherwise(F.lit('no'))
    ).withColumn(
        "drugCoherency",
        F.when(
            (F.col("LoF_protect_ch").isNotNull())
            & (F.col("GoF_protect_ch").isNull()), F.lit("coherent")
        )
        .when(
            (F.col("LoF_protect_ch").isNull())
            & (F.col("GoF_protect_ch").isNotNull()), F.lit("coherent")
        )
        .when(
            (F.col("LoF_protect_ch").isNotNull())
            & (F.col("GoF_protect_ch").isNotNull()), F.lit("dispar")
        )
        .otherwise(F.lit("other")),
    ).join(negativeTD, on=["targetId", "diseaseId"], how="left").withColumn(
        "PhaseT",
        F.when(F.col("stopReason") == "Negative", F.lit("yes")).otherwise(F.lit("no")),
    ).withColumn(
        "phase4Clean",
        F.when(
            (F.col("maxClinPhase") == 4) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "phase3Clean",
        F.when(
            (F.col("maxClinPhase") >= 3) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "phase2Clean",
        F.when(
            (F.col("maxClinPhase") >= 2) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "phase1Clean",
        F.when(
            (F.col("maxClinPhase") >= 1) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "hasGenetics",
        F.when(F.col("coherencyDiagonal") != "noEvid", F.lit("yes")).otherwise(F.lit("no")),
    )

In [5]:
temp_df_with_resolved.show()



+-----------+---------------+-------------+----------------+------------+-----------------+----------+--------------+--------------------------+----+--------------------+--------------------+----+----+----+----+----+-----------+-------------+---------------------+--------------+----------------------+-----------------+----------+---------+--------------+---------+------------+----------+-------------------------+-------------+------------------+----+-----+--------+-------+---------+--------+--------------+----+----------------+--------+------------+---------------+---------------+---------+-------------+-------+-----------------+
|biosampleId|       targetId|    diseaseId|leftStudyLocusId|rightStudyId|rightStudyLocusId|chromosome|rightStudyType|numberColocalisingVariants|clpp|colocalisationMethod|betaRatioSignAverage|  h0|  h1|  h2|  h3|  h4|leftStudyId|leftVariantId|credibleLeftStudyType|rightVariantId|credibleRightStudyType|qtlPValueExponent|isTransQtl|projectId|indexStudyType|condit

                                                                                

In [3]:
pivoted_dfs['projectId'].persist()

DataFrame[targetId: string, diseaseId: string, maxClinPhase: double, null: array<string>, Alasoo_2018: array<string>, Aygun_2021: array<string>, BLUEPRINT: array<string>, Bossini-Castillo_2019: array<string>, BrainSeq: array<string>, Braineac2: array<string>, CAP: array<string>, CEDAR: array<string>, CommonMind: array<string>, Cytoimmgen: array<string>, FUSION: array<string>, Fairfax_2012: array<string>, Fairfax_2014: array<string>, GENCORD: array<string>, GEUVADIS: array<string>, GTEx: array<string>, Gilchrist_2021: array<string>, HipSci: array<string>, Jerber_2021: array<string>, Kasela_2017: array<string>, Kim-Hellmuth_2017: array<string>, Lepik_2017: array<string>, Naranbhai_2015: array<string>, Nathan_2022: array<string>, Nedelec_2016: array<string>, OneK1K: array<string>, PISA: array<string>, Peng_2018: array<string>, Perez_2022: array<string>, PhLiPS: array<string>, Quach_2016: array<string>, ROSMAP: array<string>, Randolph_2021: array<string>, Schmiedel_2018: array<string>, Sch

In [4]:
pivoted_dfs['projectId'].show()



+---------------+-------------+------------+----+-----------+----------+---------+---------------------+--------+---------+---+-----+----------+----------+------+------------+------------+-------+--------+----+--------------+------+-----------+-----------+-----------------+----------+--------------+-----------+------------+------+----+---------+----------+------+----------+------+-------------+--------------+---------------------+--------------+--------+-------+-----------+-----------+----------+-------+----------------+----------+--------+--------+--------+--------+------+-----------+---------------------+---------+---------------+--------------------------+-----------+--------------+------------+-----------+---------------+-------------------+-----------+-----------------+-------------+-------------+----------------------+-------------------+--------------+----------+----------------+----------------+---------------+------------+----------------+----------------+---------------+-----

                                                                                

In [4]:
pivoted_dfs['projectId'].unpersist()
##### PROJECT ID ###### 
print('working with projectId')
pivoted_dfs['projectId'].persist()
#unique_values = benchmark.select('projectId').filter(F.col('projectId').isNotNull()).distinct().rdd.flatMap(lambda x: x).collect()
#filter = len(pivoted_dfs['projectId'].drop(*unique_values).columns[10:])
print('There are ', filter, 'columns to analyse with phases')
rows = comparisons_df_iterative(pivoted_dfs['projectId'].columns[-6:])

# If needed, now process the rest
for row in rows:
    results = aggregations_original(
        pivoted_dfs['projectId'], "propagated", listado, *row, today_date
    )
    result_all.append(results)

pivoted_dfs['projectId'].unpersist()
print('df unpersisted')


schema = StructType(
    [
        StructField("group", StringType(), True),
        StructField("comparison", StringType(), True),
        StructField("phase", StringType(), True),
        StructField("oddsRatio", DoubleType(), True),
        StructField("pValue", DoubleType(), True),
        StructField("lowerInterval", DoubleType(), True),
        StructField("upperInterval", DoubleType(), True),
        StructField("total", StringType(), True),
        StructField("values", ArrayType(ArrayType(IntegerType())), True),
        StructField("relSuccess", DoubleType(), True),
        StructField("rsLower", DoubleType(), True),
        StructField("rsUpper", DoubleType(), True),
        StructField("path", StringType(), True),
    ]
)
import re

# Define the list of patterns to search for
patterns = [
    "_only",
    #"_tissue",
    #"_isSignalFromRightTissue",
    "_isRightTissueSignalAgreed",
]
# Create a regex pattern to match any of the substrings
regex_pattern = "(" + "|".join(map(re.escape, patterns)) + ")"

# Convert list of lists to DataFrame
df = (
    spreadSheetFormatter(spark.createDataFrame(result_all, schema=schema))
    .withColumn(
        "prefix",
        F.regexp_replace(
            F.col("comparison"), regex_pattern + ".*", ""
        ),  # Extract part before the pattern
    )
    .withColumn(
        "suffix",
        F.regexp_extract(
            F.col("comparison"), regex_pattern, 0
        ),  # Extract the pattern itself
    )
)
### annotate projectId, tissue, qtl type and doe type:

from pyspark.sql.functions import create_map
from itertools import chain

mapping_expr=create_map([F.lit(x) for x in chain(*disdic.items())])

df_annot=df.withColumn('annotation',mapping_expr.getItem(F.col('prefix')))

df_annot.toPandas().to_csv(
    f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phases_jointProjects_rightJoin.csv"
)

print("dataframe written \n Analysis finished")


working with projectId
There are  <class 'filter'> columns to analyse with phases
gs://ot-team/jroldan/2025-09-17_analysis/propagated/othersProjectId_only_predictor_Phase>=4.parquet


                                                                                ]

gs://ot-team/jroldan/2025-09-17_analysis/propagated/othersProjectId_only_predictor_Phase>=3.parquet
gs://ot-team/jroldan/2025-09-17_analysis/propagated/othersProjectId_only_predictor_Phase>=2.parquet
gs://ot-team/jroldan/2025-09-17_analysis/propagated/othersProjectId_only_predictor_Phase>=1.parquet
gs://ot-team/jroldan/2025-09-17_analysis/propagated/othersProjectId_only_predictor_PhaseT.parquet
gs://ot-team/jroldan/2025-09-17_analysis/propagated/estimulated_only_predictor_Phase>=4.parquet
gs://ot-team/jroldan/2025-09-17_analysis/propagated/estimulated_only_predictor_Phase>=3.parquet
gs://ot-team/jroldan/2025-09-17_analysis/propagated/estimulated_only_predictor_Phase>=2.parquet
gs://ot-team/jroldan/2025-09-17_analysis/propagated/estimulated_only_predictor_Phase>=1.parquet
gs://ot-team/jroldan/2025-09-17_analysis/propagated/estimulated_only_predictor_PhaseT.parquet
gs://ot-team/jroldan/2025-09-17_analysis/propagated/nonStimulated_only_predictor_Phase>=4.parquet
gs://ot-team/jroldan/2025-



dataframe written 
 Analysis finished


#### TRY WITH THE DRUG DATASET TO SEE

In [1]:
import time
from array import ArrayType
from functions import (
    relative_success,
    spreadSheetFormatter,
    discrepancifier,
    temporary_directionOfEffect,
    buildColocData,
    gwasDataset,
)
# from stoppedTrials import terminated_td
from DoEAssessment import directionOfEffect
# from membraneTargets import target_membrane
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
from datetime import datetime
from datetime import date
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.types import (
    StructType,
    StructField,
    DoubleType,
    DecimalType,
    StringType,
    FloatType,
)
import pandas as pd
from functools import reduce


# --- Build the SparkSession ---
# Use the .config() method to set these parameters before calling .getOrCreate()
# This ensures Spark requests the correct resources from YARN at the start.
driver_memory = "24g"                 # plenty for planning & small collects
executor_cores = 4                    # sweet spot for GC + Python workers
num_executors  = 12                   # 12 * 4 = 48 cores for executors; ~16 cores left for driver/OS
executor_memory = "32g"               # per executor heap
executor_memory_overhead = "8g"       # ~20% overhead for PySpark/Arrow/off-heap
# Totals: (32+8) * 12 = 480 GB executors + 24 GB driver ≈ 504 GB (adjust down if your hard cap is <500 GB)
# If you must stay strictly ≤ 500 GB, use executor_memory="30g", overhead="6g"  → (36 * 12) + 24 = 456 + 24 = 480 GB

shuffle_partitions   = 192            # ≈ 2–4× total cores (48) → start with 192
default_parallelism  = 192

spark = SparkSession.builder \
    .appName("MyOptimizedPySparkApp") \
    .config("spark.master", "yarn") \
    .config("spark.driver.memory", driver_memory) \
    .config("spark.executor.memory", executor_memory) \
    .config("spark.executor.cores", executor_cores) \
    .config("spark.executor.instances", num_executors) \
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead) \
    .config("spark.sql.shuffle.partitions", shuffle_partitions) \
    .config("spark.default.parallelism", default_parallelism) \
    .getOrCreate()

print(f"SparkSession created successfully with the following configurations:")
print(f"  spark.driver.memory: {spark.conf.get('spark.driver.memory')}")
print(f"  spark.executor.memory: {spark.conf.get('spark.executor.memory')}")
print(f"  spark.executor.cores: {spark.conf.get('spark.executor.cores')}")
print(f"  spark.executor.instances: {spark.conf.get('spark.executor.instances')}")
print(f"  spark.yarn.executor.memoryOverhead: {spark.conf.get('spark.yarn.executor.memoryOverhead')}")
print(f"  spark.sql.shuffle.partitions: {spark.conf.get('spark.sql.shuffle.partitions')}")
print(f"  spark.default.parallelism: {spark.conf.get('spark.default.parallelism')}")
print(f"Spark UI available at: {spark.sparkContext.uiWebUrl}")

# --- Your PySpark Code Here ---
# Now you can proceed with your data loading and processing.
# Example:
# df = spark.read.parquet("hdfs:///user/your_user/your_large_data.parquet")
# print(f"Number of rows in DataFrame: {df.count()}")
# df.groupBy("some_column").agg({"another_column": "sum"}).show()

# Remember to stop the SparkSession when you are done
# spark.stop()

path_n='gs://open-targets-data-releases/25.06/output/'

target = spark.read.parquet(f"{path_n}target/")

diseases = spark.read.parquet(f"{path_n}disease/")

evidences = spark.read.parquet(f"{path_n}evidence")

credible = spark.read.parquet(f"{path_n}credible_set")

new = spark.read.parquet(f"{path_n}colocalisation_coloc") 

index=spark.read.parquet(f"{path_n}study/")

variantIndex = spark.read.parquet(f"{path_n}variant")

biosample = spark.read.parquet(f"{path_n}biosample")

ecaviar=spark.read.parquet(f"{path_n}colocalisation_ecaviar")

all_coloc=ecaviar.unionByName(new, allowMissingColumns=True)

print("loaded files")

#### FIRST MODULE: BUILDING COLOC 
newColoc=buildColocData(all_coloc,credible,index)

print("loaded newColoc")

### SECOND MODULE: PROCESS EVIDENCES TO AVOID EXCESS OF COLUMNS 
gwasComplete = gwasDataset(evidences,credible)

#### THIRD MODULE: INCLUDE COLOC IN THE 
resolvedColoc = (
    (
        newColoc.withColumnRenamed("geneId", "targetId")
        .join(
            gwasComplete.withColumnRenamed("studyLocusId", "leftStudyLocusId"),
            on=["leftStudyLocusId", "targetId"],
            how="inner",
        )
        .join(  ### propagated using parent terms
            diseases.selectExpr(
                "id as diseaseId", "name", "parents", "therapeuticAreas"
            ),
            on="diseaseId",
            how="left",
        )
        .withColumn(
            "diseaseId",
            F.explode_outer(F.concat(F.array(F.col("diseaseId")), F.col("parents"))),
        )
        .drop("parents", "oldDiseaseId")
    ).withColumn(
        "colocDoE",
        F.when(
            F.col("rightStudyType").isin(
                ["eqtl", "pqtl", "tuqtl", "sceqtl", "sctuqtl"]
            ),
            F.when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("GoF_risk"),
            )
            .when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("LoF_risk"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("LoF_protect"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("GoF_protect"),
            ),
        ).when(
            F.col("rightStudyType").isin(
                ["sqtl", "scsqtl"]
            ),  ### opposite directionality than sqtl
            F.when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("LoF_risk"),
            )
            .when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("GoF_risk"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("GoF_protect"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("LoF_protect"),
            ),
        ),
    )
    # .persist()
)
print("loaded resolvedColloc")

datasource_filter = [
#   "ot_genetics_portal",
    "gwas_credible_sets",
    "gene_burden",
    "eva",
    "eva_somatic",
    "gene2phenotype",
    "orphanet",
    "cancer_gene_census",
    "intogen",
    "impc",
    "chembl",
]

assessment, evidences, actionType, oncolabel = temporary_directionOfEffect(
    path_n, datasource_filter
)

print("run temporary direction of effect")


print("built drugApproved dataset")


#### FOURTH MODULE BUILDING CHEMBL ASSOCIATIONS - HERE TAKE CARE WITH FILTERING STEP 
analysis_chembl_indication = (
    discrepancifier(
        assessment.filter((F.col("datasourceId") == "chembl"))
        .withColumn(
            "maxClinPhase",
            F.max(F.col("clinicalPhase")).over(
                Window.partitionBy("targetId", "diseaseId")
            ),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase")
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    #.filter(F.col("coherencyDiagonal") == "coherent")
    .drop(
        "coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk"
    )
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
    # .persist()
)

####2 Define agregation function
import pandas as pd
import numpy as np
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio
from pyspark.sql.types import *


def convertTuple(tup):
    st = ",".join(map(str, tup))
    return st


#####3 run in a function
def aggregations_original(
    df,
    data,
    listado,
    comparisonColumn,
    comparisonType,
    predictionColumn,
    predictionType,
    today_date,
):
    wComparison = Window.partitionBy(comparisonColumn)
    wPrediction = Window.partitionBy(predictionColumn)
    wPredictionComparison = Window.partitionBy(comparisonColumn, predictionColumn)
    results = []
    # uniqIds = df.select("targetId", "diseaseId").distinct().count()
    out = (
        df.withColumn("comparisonType", F.lit(comparisonType))
        .withColumn("dataset", F.lit(data))
        .withColumn("predictionType", F.lit(predictionType))
        # .withColumn("total", F.lit(uniqIds))
        .withColumn("a", F.count("targetId").over(wPredictionComparison))
        .withColumn("comparisonColumn", F.lit(comparisonColumn))
        .withColumn("predictionColumnValue", F.lit(predictionColumn))
        .withColumn(
            "predictionTotal",
            F.count("targetId").over(wPrediction),
        )
        .withColumn(
            "comparisonTotal",
            F.count("targetId").over(wComparison),
        )
        .select(
            F.col(predictionColumn).alias("prediction"),
            F.col(comparisonColumn).alias("comparison"),
            "dataset",
            "comparisonColumn",
            "predictionColumnValue",
            "comparisonType",
            "predictionType",
            "a",
            "predictionTotal",
            "comparisonTotal",
        )
        .filter(F.col("prediction").isNotNull())
        .filter(F.col("comparison").isNotNull())
        .distinct()
    )
    '''
    out.write.mode("overwrite").parquet(
        "gs://ot-team/jroldan/"
        + str(
            today_date
            + "_"
            + "analysis/"
            + data
            # + "_propagated"
            + "/"
            + comparisonColumn
            + "_"
            + comparisonType
            + "_"
            + predictionColumn
            + ".parquet"
        )
    )
    '''

    listado.append(
        "gs://ot-team/jroldan/"
        + str(
            today_date
            + "_"
            + "analysis/"
            + data
            # + "_propagated"
            + "/"
            + comparisonColumn
            + "_"
            + comparisonType
            + "_"
            + predictionColumn
            + ".parquet"
        )
    )
    path = "gs://ot-team/jroldan/" + str(
        today_date
        + "_"
        + "analysis/"
        + data
        # + "_propagated"
        + "/"
        + comparisonColumn
        + "_"
        + comparisonType
        + "_"
        + predictionColumn
        + ".parquet"
    )
    print(path)
    
    ### making analysis
    array1 = np.delete(
        out.join(full_data, on=["prediction", "comparison"], how="outer")
        .groupBy("comparison")
        .pivot("prediction")
        .agg(F.first("a"))
        .sort(F.col("comparison").desc())
        .select("comparison", "yes", "no")
        .fillna(0)
        .toPandas()
        .to_numpy(),
        [0],
        1,
    )
    total = np.sum(array1)
    res_npPhaseX = np.array(array1, dtype=int)
    resX = convertTuple(fisher_exact(res_npPhaseX, alternative="two-sided"))
    resx_CI = convertTuple(
        odds_ratio(res_npPhaseX).confidence_interval(confidence_level=0.95)
    )

    result_st.append(resX)
    result_ci.append(resx_CI)
    (rs_result, rs_ci) = relative_success(array1)
    results.extend(
        [
            comparisonType,
            comparisonColumn,
            predictionColumn,
            round(float(resX.split(",")[0]), 2),
            float(resX.split(",")[1]),
            round(float(resx_CI.split(",")[0]), 2),
            round(float(resx_CI.split(",")[1]), 2),
            str(total),
            np.array(res_npPhaseX).tolist(),
            round(float(rs_result), 2),
            round(float(rs_ci[0]), 2),
            round(float(rs_ci[1]), 2),
            # studies,
            # tissues,
            path,
        ]
    )
    return results


#### 3 Loop over different datasets (as they will have different rows and columns)


def comparisons_df_iterative(elements):
    #toAnalysis = [(key, value) for key, value in disdic.items() if value == projectId]
    toAnalysis = [(col, "predictor") for col in elements]
    schema = StructType(
        [
            StructField("comparison", StringType(), True),
            StructField("comparisonType", StringType(), True),
        ]
    )

    comparisons = spark.createDataFrame(toAnalysis, schema=schema)
    ### include all the columns as predictor

    predictions = spark.createDataFrame(
        data=[
            ("Phase>=4", "clinical"),
            ('Phase>=3','clinical'),
            ('Phase>=2','clinical'),
            ('Phase>=1','clinical'),
            ("PhaseT", "clinical"),
        ]
    )
    return comparisons.join(predictions, how="full").collect()


print("load comparisons_df_iterative function")


full_data = spark.createDataFrame(
    data=[
        ("yes", "yes"),
        ("yes", "no"),
        ("no", "yes"),
        ("no", "no"),
    ],
    schema=StructType(
        [
            StructField("prediction", StringType(), True),
            StructField("comparison", StringType(), True),
        ]
    ),
)
print("created full_data and lists")

#rightTissue = spark.read.csv(
#    'gs://ot-team/jroldan/analysis/20250526_rightTissue.csv',
#    header=True,
#).drop("_c0")

print("loaded rightTissue dataset")

negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .groupBy("targetId", "diseaseId")
    .count()
    .withColumn("stopReason", F.lit("Negative"))
    .drop("count")
)

print("built negativeTD dataset")

print("built bench2 dataset")

###### cut from here
print("looping for variables_study")

#### new part with chatgpt -- TEST

## QUESTIONS TO ANSWER:
# HAVE ECAVIAR >=0.8
# HAVE COLOC 
# HAVE COLOC >= 0.8
# HAVE COLOC + ECAVIAR >= 0.01
# HAVE COLOC >= 0.8 + ECAVIAR >= 0.01
# RIGHT JOING WITH CHEMBL 

### FIFTH MODULE: BUILDING BENCHMARK OF THE DATASET TO EXTRACT EHE ANALYSIS 

resolvedColocFiltered = resolvedColoc.filter((F.col('clpp')>=0.01) | (F.col('h4')>=0.8))
benchmark = (
    (
        resolvedColocFiltered.filter( ## .filter(F.col("betaGwas") < 0)
        F.col("name") != "COVID-19"
    )
        .join(  ### select just GWAS giving protection
            analysis_chembl_indication, on=["targetId", "diseaseId"], how="right"  ### RIGHT SIDE
        )
        .withColumn(
            "AgreeDrug",
            F.when(
                (F.col("drugGoF_protect").isNotNull())
                & (F.col("colocDoE") == "GoF_protect"),
                F.lit("yes"),
            )
            .when(
                (F.col("drugLoF_protect").isNotNull())
                & (F.col("colocDoE") == "LoF_protect"),
                F.lit("yes"),
            )
            .otherwise(F.lit("no")),
        )
    )  #### remove COVID-19 associations
).join(biosample.select("biosampleId", "biosampleName"), on="biosampleId", how="left")


### drug mechanism of action
mecact_path = f"{path_n}drug_mechanism_of_action/" #  mechanismOfAction == old version
mecact = spark.read.parquet(mecact_path)

inhibitors = [
    "RNAI INHIBITOR",
    "NEGATIVE MODULATOR",
    "NEGATIVE ALLOSTERIC MODULATOR",
    "ANTAGONIST",
    "ANTISENSE INHIBITOR",
    "BLOCKER",
    "INHIBITOR",
    "DEGRADER",
    "INVERSE AGONIST",
    "ALLOSTERIC ANTAGONIST",
    "DISRUPTING AGENT",
]

activators = [
    "PARTIAL AGONIST",
    "ACTIVATOR",
    "POSITIVE ALLOSTERIC MODULATOR",
    "POSITIVE MODULATOR",
    "AGONIST",
    "SEQUESTERING AGENT",  ## lost at 31.01.2025
    "STABILISER",
    # "EXOGENOUS GENE", ## added 24.06.2025
    # "EXOGENOUS PROTEIN" ## added 24.06.2025
]


actionType = (
        mecact.select(
            F.explode_outer("chemblIds").alias("drugId"),
            "actionType",
            "mechanismOfAction",
            "targets",
        )
        .select(
            F.explode_outer("targets").alias("targetId"),
            "drugId",
            "actionType",
            "mechanismOfAction",
        )
        .groupBy("targetId", "drugId")
        .agg(F.collect_set("actionType").alias("actionType2"))
    ).withColumn('nMoA', F.size(F.col('actionType2')))

analysis_chembl_indication = (
    discrepancifier(
        assessment.filter((F.col("datasourceId") == "chembl")).join(actionType, on=['targetId','drugId'], how='left')
        .withColumn(
            "maxClinPhase",
            F.max(F.col("clinicalPhase")).over(
                Window.partitionBy("targetId", "diseaseId")
            ),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase",'actionType2')
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    #.filter(F.col("coherencyDiagonal") == "coherent")
    .drop(
        "coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk"
    )
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
)

benchmark = (
    (
        resolvedColocFiltered.filter( ## .filter(F.col("betaGwas") < 0)
        F.col("name") != "COVID-19"
    )
        .join(  ### select just GWAS giving protection
            analysis_chembl_indication, on=["targetId", "diseaseId"], how="right"  ### RIGHT SIDE
        )
        .withColumn(
            "AgreeDrug",
            F.when(
                (F.col("drugGoF_protect").isNotNull())
                & (F.col("colocDoE") == "GoF_protect"),
                F.lit("yes"),
            )
            .when(
                (F.col("drugLoF_protect").isNotNull())
                & (F.col("colocDoE") == "LoF_protect"),
                F.lit("yes"),
            )
            .otherwise(F.lit("no")),
        )
    )  #### remove COVID-19 associations
).join(biosample.select("biosampleId", "biosampleName"), on="biosampleId", how="left")

negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .groupBy("targetId", "diseaseId")
    .count()
    .withColumn("stopReason", F.lit("Negative"))
    .drop("count")
)

### create disdic dictionary
disdic={}

# --- Configuration for your iterative pivoting ---
group_by_columns = ['targetId', 'diseaseId','phase4Clean','phase3Clean','phase2Clean','phase1Clean','PhaseT']
#columns_to_pivot_on = ['actionType2', 'biosampleName', 'projectId', 'rightStudyType','colocalisationMethod']
columns_to_pivot_on = ['projectId']
columns_to_aggregate = ['NoneCellYes', 'NdiagonalYes','hasGenetics'] # The values you want to collect in the pivoted cells
all_pivoted_dfs = {}

doe_columns=["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]
diagonal_lof=['LoF_protect','GoF_risk']
diagonal_gof=['LoF_risk','GoF_protect']

conditions = [
    F.when(F.col(c) == F.col("maxDoE"), F.lit(c)).otherwise(F.lit(None)) for c in doe_columns
    ]
print('entering the big loops')
# --- Nested Loops for Dynamic Pivoting ---
for agg_col_name in columns_to_aggregate:
    for pivot_col_name in columns_to_pivot_on:
        print(f"\n--- Creating DataFrame for Aggregation: '{agg_col_name}' and Pivot: '{pivot_col_name}' ---")
        current_col_pvalue_order_window = Window.partitionBy("targetId", "diseaseId", "maxClinPhase", pivot_col_name).orderBy(F.col('colocalisationMethod').asc(), F.col("qtlPValueExponent").asc())
        test2=discrepancifier(benchmark.withColumn('actionType2', F.concat_ws(",", F.col("actionType2"))).withColumn('qtlColocDoE',F.first('colocDoE').over(current_col_pvalue_order_window)).groupBy(
        "targetId", "diseaseId", "maxClinPhase", "drugLoF_protect", "drugGoF_protect",pivot_col_name)
        .pivot("colocDoE")
        .count()
        .withColumnRenamed('drugLoF_protect', 'LoF_protect_ch')
        .withColumnRenamed('drugGoF_protect', 'GoF_protect_ch')).withColumn( ## .filter(F.col('coherencyDiagonal')!='noEvid')
    "arrayN", F.array(*[F.col(c) for c in doe_columns])
    ).withColumn(
        "maxDoE", F.array_max(F.col("arrayN"))
    ).withColumn("maxDoE_names", F.array(*conditions)
    ).withColumn("maxDoE_names", F.expr("filter(maxDoE_names, x -> x is not null)")
    ).withColumn(
        "NoneCellYes",
        F.when((F.col("LoF_protect_ch").isNotNull() & (F.col('GoF_protect_ch').isNull())) & (F.array_contains(F.col("maxDoE_names"), F.lit("LoF_protect")))==True, F.lit('yes'))
        .when((F.col("GoF_protect_ch").isNotNull() & (F.col('LoF_protect_ch').isNull())) & (F.array_contains(F.col("maxDoE_names"), F.lit("GoF_protect")))==True, F.lit('yes')
            ).otherwise(F.lit('no'))  # If the value is null, return null # Otherwise, check if name is in array
    ).withColumn(
        "NdiagonalYes",
        F.when((F.col("LoF_protect_ch").isNotNull() & (F.col('GoF_protect_ch').isNull())) & 
            (F.size(F.array_intersect(F.col("maxDoE_names"), F.array([F.lit(x) for x in diagonal_lof]))) > 0),
            F.lit("yes")
        ).when((F.col("GoF_protect_ch").isNotNull() & (F.col('LoF_protect_ch').isNull())) & 
            (F.size(F.array_intersect(F.col("maxDoE_names"), F.array([F.lit(x) for x in diagonal_gof]))) > 0),
            F.lit("yes")
        ).otherwise(F.lit('no'))
    ).withColumn(
        "drugCoherency",
        F.when(
            (F.col("LoF_protect_ch").isNotNull())
            & (F.col("GoF_protect_ch").isNull()), F.lit("coherent")
        )
        .when(
            (F.col("LoF_protect_ch").isNull())
            & (F.col("GoF_protect_ch").isNotNull()), F.lit("coherent")
        )
        .when(
            (F.col("LoF_protect_ch").isNotNull())
            & (F.col("GoF_protect_ch").isNotNull()), F.lit("dispar")
        )
        .otherwise(F.lit("other")),
    ).join(negativeTD, on=["targetId", "diseaseId"], how="left").withColumn(
        "PhaseT",
        F.when(F.col("stopReason") == "Negative", F.lit("yes")).otherwise(F.lit("no")),
    ).withColumn(
        "phase4Clean",
        F.when(
            (F.col("maxClinPhase") == 4) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "phase3Clean",
        F.when(
            (F.col("maxClinPhase") >= 3) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "phase2Clean",
        F.when(
            (F.col("maxClinPhase") >= 2) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "phase1Clean",
        F.when(
            (F.col("maxClinPhase") >= 1) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "hasGenetics",
        F.when(F.col("coherencyDiagonal") != "noEvid", F.lit("yes")).otherwise(F.lit("no")),
    )
        # 1. Get distinct values for the pivot column (essential for pivot())
        # This brings a small amount of data to the driver, but is necessary for the pivot schema.
        #distinct_pivot_values = [row[0] for row in test2.select(pivot_col_name).distinct().collect()]
        # print(f"Distinct values for '{pivot_col_name}': {distinct_pivot_values}")

        # 2. Perform the groupBy, pivot, and aggregate operations
        # The .pivot() function requires the list of distinct values for better performance
        # and correct schema inference.
        pivoted_df = (
            test2.groupBy(*group_by_columns)
            .pivot(pivot_col_name) # Provide distinct values distinct_pivot_values
            .agg(F.collect_set(F.col(agg_col_name))) # Collect all values into a set
            .fillna(0) # Fill cells that have no data with an empty list instead of null
        )
        # 3. Add items to dictionary to map the columns:
        # filter out None and 'null':
        datasetColumns=pivoted_df.columns
        filtered = [x for x in datasetColumns if x is not None and x != 'null']
        # using list comprehension
        for item in filtered:
            disdic[item] = pivot_col_name

        # 3. Add the 'data' literal column dynamically
        # This column indicates which aggregation column was used.
        #pivoted_df = pivoted_df.withColumn('data', F.lit(f'Drug_{agg_col_name}'))

        array_columns_to_convert = [
            field.name for field in pivoted_df.schema.fields
            if isinstance(field.dataType, ArrayType)
        ]
        print(f"Identified ArrayType columns for conversion: {array_columns_to_convert}")

        # 4. Apply the conversion logic to each identified array column
        df_after_conversion = pivoted_df # Start with the pivoted_df
        for col_to_convert in array_columns_to_convert:
            df_after_conversion = df_after_conversion.withColumn(
                col_to_convert,
                F.when(F.col(col_to_convert).isNull(), F.lit('no'))          # Handle NULLs (from pivot for no data)
                .when(F.size(F.col(col_to_convert)) == 0, F.lit('no'))       # Empty array -> 'no'
                .when(F.array_contains(F.col(col_to_convert), F.lit('yes')), F.lit('yes')) # Contains 'yes' -> 'yes'
                .when(F.array_contains(F.col(col_to_convert), F.lit('no')), F.lit('no'))   # Contains 'no' -> 'no'
                .otherwise(F.lit('no')) # Fallback for unexpected array content (e.g., ['other'], ['yes','no'])
            )

        # 4. Generate a unique name for this DataFrame and store it
        df_key = f"df_pivot_{agg_col_name.lower()}_by_{pivot_col_name.lower()}"
        all_pivoted_dfs[df_key] = df_after_conversion.withColumnRenamed( 'phase4Clean','Phase>=4'
        ).withColumnRenamed('phase3Clean','Phase>=3'
        ).withColumnRenamed('phase2Clean','Phase>=2'
        ).withColumnRenamed('phase1Clean','Phase>=1')


# --- Accessing your generated DataFrames ---
print("\n--- All generated DataFrames are stored in 'all_pivoted_dfs' dictionary ---")
print("Keys available:", all_pivoted_dfs.keys())

spark session created at 2025-09-17 20:45:50.891922
Analysis started on 2025-09-17 at  2025-09-17 20:45:50.891922


25/09/17 20:45:55 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
25/09/17 20:45:55 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


SparkSession created successfully with the following configurations:
  spark.driver.memory: 24g
  spark.executor.memory: 32g
  spark.executor.cores: 4
  spark.executor.instances: 12
  spark.yarn.executor.memoryOverhead: 8g
  spark.sql.shuffle.partitions: 192
  spark.default.parallelism: 192
Spark UI available at: http://jr-doe-temp1-m.c.open-targets-eu-dev.internal:35471


                                                                                

loaded files
loaded newColoc


                                                                                

loaded gwasComplete
loaded resolvedColloc
run temporary direction of effect
built drugApproved dataset


                                                                                

load comparisons_df_iterative function
created full_data and lists
loaded rightTissue dataset
built negativeTD dataset
built bench2 dataset
looping for variables_study
entering the big loops

--- Creating DataFrame for Aggregation: 'NoneCellYes' and Pivot: 'projectId' ---


                                                                                

Identified ArrayType columns for conversion: ['null', 'Alasoo_2018', 'Aygun_2021', 'BLUEPRINT', 'Bossini-Castillo_2019', 'BrainSeq', 'Braineac2', 'CAP', 'CEDAR', 'CommonMind', 'Cytoimmgen', 'FUSION', 'Fairfax_2012', 'Fairfax_2014', 'GENCORD', 'GEUVADIS', 'GTEx', 'Gilchrist_2021', 'HipSci', 'Jerber_2021', 'Kasela_2017', 'Kim-Hellmuth_2017', 'Lepik_2017', 'Naranbhai_2015', 'Nathan_2022', 'Nedelec_2016', 'OneK1K', 'PISA', 'Peng_2018', 'Perez_2022', 'PhLiPS', 'Quach_2016', 'ROSMAP', 'Randolph_2021', 'Schmiedel_2018', 'Schwartzentruber_2018', 'Steinberg_2020', 'Sun_2018', 'TwinsUK', 'UKB_PPP_EUR', 'Walker_2019', 'Young_2019', 'iPSCORE', 'van_de_Bunt_2015']

--- Creating DataFrame for Aggregation: 'NdiagonalYes' and Pivot: 'projectId' ---


25/09/17 20:48:41 WARN CacheManager: Asked to cache already cached data.        


Identified ArrayType columns for conversion: ['null', 'Alasoo_2018', 'Aygun_2021', 'BLUEPRINT', 'Bossini-Castillo_2019', 'BrainSeq', 'Braineac2', 'CAP', 'CEDAR', 'CommonMind', 'Cytoimmgen', 'FUSION', 'Fairfax_2012', 'Fairfax_2014', 'GENCORD', 'GEUVADIS', 'GTEx', 'Gilchrist_2021', 'HipSci', 'Jerber_2021', 'Kasela_2017', 'Kim-Hellmuth_2017', 'Lepik_2017', 'Naranbhai_2015', 'Nathan_2022', 'Nedelec_2016', 'OneK1K', 'PISA', 'Peng_2018', 'Perez_2022', 'PhLiPS', 'Quach_2016', 'ROSMAP', 'Randolph_2021', 'Schmiedel_2018', 'Schwartzentruber_2018', 'Steinberg_2020', 'Sun_2018', 'TwinsUK', 'UKB_PPP_EUR', 'Walker_2019', 'Young_2019', 'iPSCORE', 'van_de_Bunt_2015']

--- Creating DataFrame for Aggregation: 'hasGenetics' and Pivot: 'projectId' ---


25/09/17 20:49:15 WARN CacheManager: Asked to cache already cached data.        


Identified ArrayType columns for conversion: ['null', 'Alasoo_2018', 'Aygun_2021', 'BLUEPRINT', 'Bossini-Castillo_2019', 'BrainSeq', 'Braineac2', 'CAP', 'CEDAR', 'CommonMind', 'Cytoimmgen', 'FUSION', 'Fairfax_2012', 'Fairfax_2014', 'GENCORD', 'GEUVADIS', 'GTEx', 'Gilchrist_2021', 'HipSci', 'Jerber_2021', 'Kasela_2017', 'Kim-Hellmuth_2017', 'Lepik_2017', 'Naranbhai_2015', 'Nathan_2022', 'Nedelec_2016', 'OneK1K', 'PISA', 'Peng_2018', 'Perez_2022', 'PhLiPS', 'Quach_2016', 'ROSMAP', 'Randolph_2021', 'Schmiedel_2018', 'Schwartzentruber_2018', 'Steinberg_2020', 'Sun_2018', 'TwinsUK', 'UKB_PPP_EUR', 'Walker_2019', 'Young_2019', 'iPSCORE', 'van_de_Bunt_2015']

--- All generated DataFrames are stored in 'all_pivoted_dfs' dictionary ---
Keys available: dict_keys(['df_pivot_nonecellyes_by_projectid', 'df_pivot_ndiagonalyes_by_projectid', 'df_pivot_hasgenetics_by_projectid'])


In [11]:
all_pivoted_dfs

{'df_pivot_nonecellyes_by_projectid': DataFrame[targetId: string, diseaseId: string, Phase>=4: string, Phase>=3: string, Phase>=2: string, Phase>=1: string, PhaseT: string, null: string, Alasoo_2018: string, Aygun_2021: string, BLUEPRINT: string, Bossini-Castillo_2019: string, BrainSeq: string, Braineac2: string, CAP: string, CEDAR: string, CommonMind: string, Cytoimmgen: string, FUSION: string, Fairfax_2012: string, Fairfax_2014: string, GENCORD: string, GEUVADIS: string, GTEx: string, Gilchrist_2021: string, HipSci: string, Jerber_2021: string, Kasela_2017: string, Kim-Hellmuth_2017: string, Lepik_2017: string, Naranbhai_2015: string, Nathan_2022: string, Nedelec_2016: string, OneK1K: string, PISA: string, Peng_2018: string, Perez_2022: string, PhLiPS: string, Quach_2016: string, ROSMAP: string, Randolph_2021: string, Schmiedel_2018: string, Schwartzentruber_2018: string, Steinberg_2020: string, Sun_2018: string, TwinsUK: string, UKB_PPP_EUR: string, Walker_2019: string, Young_2019: 

In [14]:
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].drop('null').columns[-6:]

['othersProjectId_only',
 'estimulated_only',
 'nonStimulated_only',
 'cellLine',
 'nonCellLine',
 'GTExUKB']

In [20]:
############## HYBRID ##############
####################################
def strip_only(lst):
    return [x.removesuffix("_only") for x in lst]  # Python 3.9+
    # or: return [x[:-5] if x.endswith("_only") else x for x in lst]


##### PROJECTID
project_keys=[f"{k}_only" for k,v in disdic.items() if v == 'projectId']
main=['GTEx_only', 'UKB_PPP_EUR_only']
#stimulated=['Alasoo_2018_only','Cytoimmgen_only','Fairfax_2014_only','Kim-Hellmuth_2017_only','Nathan_2022_only','Nedelec_2016_only','Quach_2016_only','Randolph_2021_only','Schmiedel_2018_only']
#cellLine=['CAP_only','HipSci_only','iPSCORE_only','Jerber_2021_only','PhLiPS_only','Schwartzentruber_2018_only','TwinsUK_only']

derivedCellLine=['TwinsUK_only','PhLiPS_only','CAP_only','GENCORD_only','Sun_2018_only','Nedelec_2016_only']
canonicalCellLine=['Alasoo_2018_only','Jerber_2021_only','GEUVADIS_only','iPSCORE_only','Aygun_2021_only','Schwartzentruber_2018_only']
stimulated=['Schmiedel_2018_only','Bossini-Castillo_2019_only','Alasoo_2018_only','Cytoimmgen_only','Gilchrist_2021_only','CAP_only','Quach_2016_only','Randolph_2021_only','Sun_2018_only','Nedelec_2016_only','Kim-Hellmuth_2017_only']

# Apply
main = strip_only(main)
canonicalCellLine = strip_only(canonicalCellLine)
derivedCellLine = strip_only(derivedCellLine)
stimulated = strip_only(stimulated)

others=[item for item in strip_only(project_keys) if item not in main]
nonStimulated=[item for item in strip_only(project_keys) if item not in stimulated]
nonCanonicalCellLine = [item for item in strip_only(project_keys) if item not in canonicalCellLine]
nonDerivedCellLine = [item for item in strip_only(project_keys) if item not in derivedCellLine]

#otherCellLine=[item for item in strip_only(project_keys) if item not in cellLine]


def _or_yes(df, cols):
    """Return a Column that is TRUE if any of the given columns == 'yes'.
       Ignores columns not present in df. If none present, returns FALSE.
    """
    present = [c for c in cols if c in df.columns]
    if not present:
        return F.lit(False)
    # (col == 'yes') OR (col == 'yes') OR ...
    exprs = [F.col(c) == "yes" for c in present]
    return reduce(lambda a, b: a | b, exprs)

def add_project_group_flags(df, main, canonicalCellLine, derivedCellLine,stimulated):
    # project keys = every *_only column in this DF
    project_keys = [c for c in df.columns if c.endswith("_only")]

    # Derived buckets
    others=[item for item in strip_only(project_keys) if item not in main]
    nonStimulated=[item for item in strip_only(project_keys) if item not in stimulated]
    nonCanonicalCellLine = [item for item in strip_only(project_keys) if item not in canonicalCellLine]
    nonDerivedCellLine = [item for item in strip_only(project_keys) if item not in derivedCellLine]

    # Conditions
    condition1 = _or_yes(df, others)  
    condition1 = _or_yes(df, main)  
    condition2 = _or_yes(df, stimulated)    
    condition3 = _or_yes(df, nonStimulated)  
    condition4 = _or_yes(df, canonicalCellLine)     
    condition5 = _or_yes(df, nonCanonicalCellLine)  
    condition7 = _or_yes(df, derivedCellLine)  
    condition8 = _or_yes(df, nonDerivedCellLine)         

    # Add columns (write 'yes'/'no')
    return (
        df.withColumn("othersProjectId_only", F.when(condition1, "yes").otherwise("no"))
          .withColumn("GTExUKB_only",     F.when(condition2, "yes").otherwise("no"))
          .withColumn("stimulated_only",   F.when(condition3, "yes").otherwise("no"))
          .withColumn("nonStimulated",             F.when(condition4, "yes").otherwise("no"))
          .withColumn("canonicalCellLine",          F.when(condition5, "yes").otherwise("no"))
          .withColumn("nonCanonicalCellLine",              F.when(condition6, "yes").otherwise("no"))
          .withColumn("derivedCellLine",          F.when(condition7, "yes").otherwise("no"))
          .withColumn("nonDerivedCellLine",              F.when(condition8, "yes").otherwise("no"))
    )

# --- Apply to the dict entry you mentioned ---
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = add_project_group_flags(
    all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'],main=main, canonicalCellLine=canonicalCellLine, derivedCellLine=derivedCellLine,stimulated=stimulated)
all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = add_project_group_flags(
    all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'], main, canonicalCellLine, derivedCellLine,stimulated)
all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = add_project_group_flags(
    all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'], main, canonicalCellLine, derivedCellLine,stimulated)
# If you wanted to apply to every DF in the dict (only if they all share *_only columns):
# for k, df in all_pivoted_dfs.items():
#     all_pivoted_dfs[k] = add_project_group_flags(df, main, stimulated, cellLine)

###append to dictionary

disdic.update({'othersProjectId': 'projectId','Stimulated': 'projectId','cellLine': 'projectId', 'othersBiosampleName_only': 'biosampleName', 'otherRightStudyType':'rightStudyType'})

###################################
###################################
result = []
result_st = []
result_ci = []
array2 = []
listado = []
result_all = []
today_date = str(date.today())

for key,df in all_pivoted_dfs.items():

    print(f'working with {key}')
    parts = key.split('_by_') ### take the part of key belonging to column name
    column_name = parts[1] ### take the last part which is column name
    all_pivoted_dfs[key].persist()
    #unique_values = all_pivoted_dfs[key].drop('null').columns[7:]
    unique_values = all_pivoted_dfs[key].drop('null').columns[-6:] ### just the interesting columns for us 
    filtered_unique_values = [x for x in unique_values if x is not None and x != 'null']
    print('There are ', len(filtered_unique_values), 'columns to analyse with phases')
    rows = comparisons_df_iterative(filtered_unique_values)

    # If needed, now process the rest
    for row in rows:
        print('performing', row)
        results = aggregations_original(
            all_pivoted_dfs[key], key, listado, *row, today_date
        )
        result_all.append(results)
        print('results appended')
    all_pivoted_dfs[key].unpersist()
    print('df unpersisted')


schema = StructType(
    [
        StructField("group", StringType(), True),
        StructField("comparison", StringType(), True),
        StructField("phase", StringType(), True),
        StructField("oddsRatio", DoubleType(), True),
        StructField("pValue", DoubleType(), True),
        StructField("lowerInterval", DoubleType(), True),
        StructField("upperInterval", DoubleType(), True),
        StructField("total", StringType(), True),
        StructField("values", ArrayType(ArrayType(IntegerType())), True),
        StructField("relSuccess", DoubleType(), True),
        StructField("rsLower", DoubleType(), True),
        StructField("rsUpper", DoubleType(), True),
        StructField("path", StringType(), True),
    ]
)
import re

# Define the list of patterns to search for
patterns = [
    "_only",
    #"_tissue",
    #"_isSignalFromRightTissue",
    "_isRightTissueSignalAgreed",
]
# Create a regex pattern to match any of the substrings
regex_pattern = "(" + "|".join(map(re.escape, patterns)) + ")"

# Convert list of lists to DataFrame
df = (
    spreadSheetFormatter(spark.createDataFrame(result_all, schema=schema))
    .withColumn(
        "prefix",
        F.regexp_replace(
            F.col("comparison"), regex_pattern + ".*", ""
        ),  # Extract part before the pattern
    )
    .withColumn(
        "suffix",
        F.regexp_extract(
            F.col("comparison"), regex_pattern, 0
        ),  # Extract the pattern itself
    )
)

### annotate projectId, tissue, qtl type and doe type:

from pyspark.sql.functions import create_map
from itertools import chain

mapping_expr=create_map([F.lit(x) for x in chain(*disdic.items())])

df_annot=df.withColumn('annotation',mapping_expr.getItem(F.col('prefix')))

df_annot.toPandas().to_csv(
    f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_AllPhases.csv"
)

print("dataframe written \n Analysis finished")

working with df_pivot_nonecellyes_by_projectid
There are  6 columns to analyse with phases
performing Row(comparison='stimulated_only', comparisonType='predictor', _1='Phase>=4', _2='clinical')
gs://ot-team/jroldan/2025-09-17_analysis/df_pivot_nonecellyes_by_projectid/stimulated_only_predictor_Phase>=4.parquet
results appended
performing Row(comparison='stimulated_only', comparisonType='predictor', _1='Phase>=3', _2='clinical')
gs://ot-team/jroldan/2025-09-17_analysis/df_pivot_nonecellyes_by_projectid/stimulated_only_predictor_Phase>=3.parquet
results appended
performing Row(comparison='stimulated_only', comparisonType='predictor', _1='Phase>=2', _2='clinical')
gs://ot-team/jroldan/2025-09-17_analysis/df_pivot_nonecellyes_by_projectid/stimulated_only_predictor_Phase>=2.parquet
results appended
performing Row(comparison='stimulated_only', comparisonType='predictor', _1='Phase>=1', _2='clinical')
gs://ot-team/jroldan/2025-09-17_analysis/df_pivot_nonecellyes_by_projectid/stimulated_only_p

                                                                                

results appended
performing Row(comparison='stimulated_only', comparisonType='predictor', _1='Phase>=3', _2='clinical')
gs://ot-team/jroldan/2025-09-17_analysis/df_pivot_ndiagonalyes_by_projectid/stimulated_only_predictor_Phase>=3.parquet
results appended
performing Row(comparison='stimulated_only', comparisonType='predictor', _1='Phase>=2', _2='clinical')
gs://ot-team/jroldan/2025-09-17_analysis/df_pivot_ndiagonalyes_by_projectid/stimulated_only_predictor_Phase>=2.parquet
results appended
performing Row(comparison='stimulated_only', comparisonType='predictor', _1='Phase>=1', _2='clinical')
gs://ot-team/jroldan/2025-09-17_analysis/df_pivot_ndiagonalyes_by_projectid/stimulated_only_predictor_Phase>=1.parquet
results appended
performing Row(comparison='stimulated_only', comparisonType='predictor', _1='PhaseT', _2='clinical')
gs://ot-team/jroldan/2025-09-17_analysis/df_pivot_ndiagonalyes_by_projectid/stimulated_only_predictor_PhaseT.parquet
results appended
performing Row(comparison='nonS

                                                                                

results appended
performing Row(comparison='stimulated_only', comparisonType='predictor', _1='Phase>=3', _2='clinical')
gs://ot-team/jroldan/2025-09-17_analysis/df_pivot_hasgenetics_by_projectid/stimulated_only_predictor_Phase>=3.parquet
results appended
performing Row(comparison='stimulated_only', comparisonType='predictor', _1='Phase>=2', _2='clinical')
gs://ot-team/jroldan/2025-09-17_analysis/df_pivot_hasgenetics_by_projectid/stimulated_only_predictor_Phase>=2.parquet
results appended
performing Row(comparison='stimulated_only', comparisonType='predictor', _1='Phase>=1', _2='clinical')
gs://ot-team/jroldan/2025-09-17_analysis/df_pivot_hasgenetics_by_projectid/stimulated_only_predictor_Phase>=1.parquet
results appended
performing Row(comparison='stimulated_only', comparisonType='predictor', _1='PhaseT', _2='clinical')
gs://ot-team/jroldan/2025-09-17_analysis/df_pivot_hasgenetics_by_projectid/stimulated_only_predictor_PhaseT.parquet
results appended
performing Row(comparison='nonStimu



dataframe written 
 Analysis finished


##### at once

In [1]:
import time
from array import ArrayType
from functions import (
    relative_success,
    spreadSheetFormatter,
    discrepancifier,
    temporary_directionOfEffect,
    buildColocData,
    gwasDataset,
)
# from stoppedTrials import terminated_td
from DoEAssessment import directionOfEffect
# from membraneTargets import target_membrane
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
from datetime import datetime
from datetime import date
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.types import (
    StructType,
    StructField,
    DoubleType,
    DecimalType,
    StringType,
    FloatType,
)
import pandas as pd
from functools import reduce


# --- Build the SparkSession ---
# Use the .config() method to set these parameters before calling .getOrCreate()
# This ensures Spark requests the correct resources from YARN at the start.
driver_memory = "24g"                 # plenty for planning & small collects
executor_cores = 4                    # sweet spot for GC + Python workers
num_executors  = 12                   # 12 * 4 = 48 cores for executors; ~16 cores left for driver/OS
executor_memory = "32g"               # per executor heap
executor_memory_overhead = "8g"       # ~20% overhead for PySpark/Arrow/off-heap
# Totals: (32+8) * 12 = 480 GB executors + 24 GB driver ≈ 504 GB (adjust down if your hard cap is <500 GB)
# If you must stay strictly ≤ 500 GB, use executor_memory="30g", overhead="6g"  → (36 * 12) + 24 = 456 + 24 = 480 GB

shuffle_partitions   = 192            # ≈ 2–4× total cores (48) → start with 192
default_parallelism  = 192

spark = SparkSession.builder \
    .appName("MyOptimizedPySparkApp") \
    .config("spark.master", "yarn") \
    .config("spark.driver.memory", driver_memory) \
    .config("spark.executor.memory", executor_memory) \
    .config("spark.executor.cores", executor_cores) \
    .config("spark.executor.instances", num_executors) \
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead) \
    .config("spark.sql.shuffle.partitions", shuffle_partitions) \
    .config("spark.default.parallelism", default_parallelism) \
    .getOrCreate()

print(f"SparkSession created successfully with the following configurations:")
print(f"  spark.driver.memory: {spark.conf.get('spark.driver.memory')}")
print(f"  spark.executor.memory: {spark.conf.get('spark.executor.memory')}")
print(f"  spark.executor.cores: {spark.conf.get('spark.executor.cores')}")
print(f"  spark.executor.instances: {spark.conf.get('spark.executor.instances')}")
print(f"  spark.yarn.executor.memoryOverhead: {spark.conf.get('spark.yarn.executor.memoryOverhead')}")
print(f"  spark.sql.shuffle.partitions: {spark.conf.get('spark.sql.shuffle.partitions')}")
print(f"  spark.default.parallelism: {spark.conf.get('spark.default.parallelism')}")
print(f"Spark UI available at: {spark.sparkContext.uiWebUrl}")

# --- Your PySpark Code Here ---
# Now you can proceed with your data loading and processing.
# Example:
# df = spark.read.parquet("hdfs:///user/your_user/your_large_data.parquet")
# print(f"Number of rows in DataFrame: {df.count()}")
# df.groupBy("some_column").agg({"another_column": "sum"}).show()

# Remember to stop the SparkSession when you are done
# spark.stop()

path_n='gs://open-targets-data-releases/25.06/output/'

target = spark.read.parquet(f"{path_n}target/")

diseases = spark.read.parquet(f"{path_n}disease/")

evidences = spark.read.parquet(f"{path_n}evidence")

credible = spark.read.parquet(f"{path_n}credible_set")

new = spark.read.parquet(f"{path_n}colocalisation_coloc") 

index=spark.read.parquet(f"{path_n}study/")

variantIndex = spark.read.parquet(f"{path_n}variant")

biosample = spark.read.parquet(f"{path_n}biosample")

ecaviar=spark.read.parquet(f"{path_n}colocalisation_ecaviar")

all_coloc=ecaviar.unionByName(new, allowMissingColumns=True)

print("loaded files")

#### FIRST MODULE: BUILDING COLOC 
newColoc=buildColocData(all_coloc,credible,index)

print("loaded newColoc")

### SECOND MODULE: PROCESS EVIDENCES TO AVOID EXCESS OF COLUMNS 
gwasComplete = gwasDataset(evidences,credible)

#### THIRD MODULE: INCLUDE COLOC IN THE 
resolvedColoc = (
    (
        newColoc.withColumnRenamed("geneId", "targetId")
        .join(
            gwasComplete.withColumnRenamed("studyLocusId", "leftStudyLocusId"),
            on=["leftStudyLocusId", "targetId"],
            how="inner",
        )
        .join(  ### propagated using parent terms
            diseases.selectExpr(
                "id as diseaseId", "name", "parents", "therapeuticAreas"
            ),
            on="diseaseId",
            how="left",
        )
        .withColumn(
            "diseaseId",
            F.explode_outer(F.concat(F.array(F.col("diseaseId")), F.col("parents"))),
        )
        .drop("parents", "oldDiseaseId")
    ).withColumn(
        "colocDoE",
        F.when(
            F.col("rightStudyType").isin(
                ["eqtl", "pqtl", "tuqtl", "sceqtl", "sctuqtl"]
            ),
            F.when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("GoF_risk"),
            )
            .when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("LoF_risk"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("LoF_protect"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("GoF_protect"),
            ),
        ).when(
            F.col("rightStudyType").isin(
                ["sqtl", "scsqtl"]
            ),  ### opposite directionality than sqtl
            F.when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("LoF_risk"),
            )
            .when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("GoF_risk"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("GoF_protect"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("LoF_protect"),
            ),
        ),
    )
    # .persist()
)
print("loaded resolvedColloc")

datasource_filter = [
#   "ot_genetics_portal",
    "gwas_credible_sets",
    "gene_burden",
    "eva",
    "eva_somatic",
    "gene2phenotype",
    "orphanet",
    "cancer_gene_census",
    "intogen",
    "impc",
    "chembl",
]

assessment, evidences, actionType, oncolabel = temporary_directionOfEffect(
    path_n, datasource_filter
)

print("run temporary direction of effect")


print("built drugApproved dataset")


#### FOURTH MODULE BUILDING CHEMBL ASSOCIATIONS - HERE TAKE CARE WITH FILTERING STEP 
analysis_chembl_indication = (
    discrepancifier(
        assessment.filter((F.col("datasourceId") == "chembl"))
        .withColumn(
            "maxClinPhase",
            F.max(F.col("clinicalPhase")).over(
                Window.partitionBy("targetId", "diseaseId")
            ),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase")
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    #.filter(F.col("coherencyDiagonal") == "coherent")
    .drop(
        "coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk"
    )
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
    # .persist()
)

####2 Define agregation function
import pandas as pd
import numpy as np
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio
from pyspark.sql.types import *


def convertTuple(tup):
    st = ",".join(map(str, tup))
    return st


#####3 run in a function
def aggregations_original(
    df,
    data,
    listado,
    comparisonColumn,
    comparisonType,
    predictionColumn,
    predictionType,
    today_date,
):
    wComparison = Window.partitionBy(comparisonColumn)
    wPrediction = Window.partitionBy(predictionColumn)
    wPredictionComparison = Window.partitionBy(comparisonColumn, predictionColumn)
    results = []
    # uniqIds = df.select("targetId", "diseaseId").distinct().count()
    out = (
        df.withColumn("comparisonType", F.lit(comparisonType))
        .withColumn("dataset", F.lit(data))
        .withColumn("predictionType", F.lit(predictionType))
        # .withColumn("total", F.lit(uniqIds))
        .withColumn("a", F.count("targetId").over(wPredictionComparison))
        .withColumn("comparisonColumn", F.lit(comparisonColumn))
        .withColumn("predictionColumnValue", F.lit(predictionColumn))
        .withColumn(
            "predictionTotal",
            F.count("targetId").over(wPrediction),
        )
        .withColumn(
            "comparisonTotal",
            F.count("targetId").over(wComparison),
        )
        .select(
            F.col(predictionColumn).alias("prediction"),
            F.col(comparisonColumn).alias("comparison"),
            "dataset",
            "comparisonColumn",
            "predictionColumnValue",
            "comparisonType",
            "predictionType",
            "a",
            "predictionTotal",
            "comparisonTotal",
        )
        .filter(F.col("prediction").isNotNull())
        .filter(F.col("comparison").isNotNull())
        .distinct()
    )
    '''
    out.write.mode("overwrite").parquet(
        "gs://ot-team/jroldan/"
        + str(
            today_date
            + "_"
            + "analysis/"
            + data
            # + "_propagated"
            + "/"
            + comparisonColumn
            + "_"
            + comparisonType
            + "_"
            + predictionColumn
            + ".parquet"
        )
    )
    '''

    listado.append(
        "gs://ot-team/jroldan/"
        + str(
            today_date
            + "_"
            + "analysis/"
            + data
            # + "_propagated"
            + "/"
            + comparisonColumn
            + "_"
            + comparisonType
            + "_"
            + predictionColumn
            + ".parquet"
        )
    )
    path = "gs://ot-team/jroldan/" + str(
        today_date
        + "_"
        + "analysis/"
        + data
        # + "_propagated"
        + "/"
        + comparisonColumn
        + "_"
        + comparisonType
        + "_"
        + predictionColumn
        + ".parquet"
    )
    print(path)
    
    ### making analysis
    array1 = np.delete(
        out.join(full_data, on=["prediction", "comparison"], how="outer")
        .groupBy("comparison")
        .pivot("prediction")
        .agg(F.first("a"))
        .sort(F.col("comparison").desc())
        .select("comparison", "yes", "no")
        .fillna(0)
        .toPandas()
        .to_numpy(),
        [0],
        1,
    )
    total = np.sum(array1)
    res_npPhaseX = np.array(array1, dtype=int)
    resX = convertTuple(fisher_exact(res_npPhaseX, alternative="two-sided"))
    resx_CI = convertTuple(
        odds_ratio(res_npPhaseX).confidence_interval(confidence_level=0.95)
    )

    result_st.append(resX)
    result_ci.append(resx_CI)
    (rs_result, rs_ci) = relative_success(array1)
    results.extend(
        [
            comparisonType,
            comparisonColumn,
            predictionColumn,
            round(float(resX.split(",")[0]), 2),
            float(resX.split(",")[1]),
            round(float(resx_CI.split(",")[0]), 2),
            round(float(resx_CI.split(",")[1]), 2),
            str(total),
            np.array(res_npPhaseX).tolist(),
            round(float(rs_result), 2),
            round(float(rs_ci[0]), 2),
            round(float(rs_ci[1]), 2),
            # studies,
            # tissues,
            path,
        ]
    )
    return results


#### 3 Loop over different datasets (as they will have different rows and columns)


def comparisons_df_iterative(elements):
    #toAnalysis = [(key, value) for key, value in disdic.items() if value == projectId]
    toAnalysis = [(col, "predictor") for col in elements]
    schema = StructType(
        [
            StructField("comparison", StringType(), True),
            StructField("comparisonType", StringType(), True),
        ]
    )

    comparisons = spark.createDataFrame(toAnalysis, schema=schema)
    ### include all the columns as predictor

    predictions = spark.createDataFrame(
        data=[
            ("Phase>=4", "clinical"),
            ('Phase>=3','clinical'),
            ('Phase>=2','clinical'),
            ('Phase>=1','clinical'),
            ("PhaseT", "clinical"),
        ]
    )
    return comparisons.join(predictions, how="full").collect()


print("load comparisons_df_iterative function")


full_data = spark.createDataFrame(
    data=[
        ("yes", "yes"),
        ("yes", "no"),
        ("no", "yes"),
        ("no", "no"),
    ],
    schema=StructType(
        [
            StructField("prediction", StringType(), True),
            StructField("comparison", StringType(), True),
        ]
    ),
)
print("created full_data and lists")

#rightTissue = spark.read.csv(
#    'gs://ot-team/jroldan/analysis/20250526_rightTissue.csv',
#    header=True,
#).drop("_c0")

print("loaded rightTissue dataset")

negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .groupBy("targetId", "diseaseId")
    .count()
    .withColumn("stopReason", F.lit("Negative"))
    .drop("count")
)

print("built negativeTD dataset")

print("built bench2 dataset")

###### cut from here
print("looping for variables_study")

#### new part with chatgpt -- TEST

## QUESTIONS TO ANSWER:
# HAVE ECAVIAR >=0.8
# HAVE COLOC 
# HAVE COLOC >= 0.8
# HAVE COLOC + ECAVIAR >= 0.01
# HAVE COLOC >= 0.8 + ECAVIAR >= 0.01
# RIGHT JOING WITH CHEMBL 

### FIFTH MODULE: BUILDING BENCHMARK OF THE DATASET TO EXTRACT EHE ANALYSIS 

resolvedColocFiltered = resolvedColoc.filter((F.col('clpp')>=0.01) | (F.col('h4')>=0.8))
benchmark = (
    (
        resolvedColocFiltered.filter( ## .filter(F.col("betaGwas") < 0)
        F.col("name") != "COVID-19"
    )
        .join(  ### select just GWAS giving protection
            analysis_chembl_indication, on=["targetId", "diseaseId"], how="right"  ### RIGHT SIDE
        )
        .withColumn(
            "AgreeDrug",
            F.when(
                (F.col("drugGoF_protect").isNotNull())
                & (F.col("colocDoE") == "GoF_protect"),
                F.lit("yes"),
            )
            .when(
                (F.col("drugLoF_protect").isNotNull())
                & (F.col("colocDoE") == "LoF_protect"),
                F.lit("yes"),
            )
            .otherwise(F.lit("no")),
        )
    )  #### remove COVID-19 associations
).join(biosample.select("biosampleId", "biosampleName"), on="biosampleId", how="left")


### drug mechanism of action
mecact_path = f"{path_n}drug_mechanism_of_action/" #  mechanismOfAction == old version
mecact = spark.read.parquet(mecact_path)

inhibitors = [
    "RNAI INHIBITOR",
    "NEGATIVE MODULATOR",
    "NEGATIVE ALLOSTERIC MODULATOR",
    "ANTAGONIST",
    "ANTISENSE INHIBITOR",
    "BLOCKER",
    "INHIBITOR",
    "DEGRADER",
    "INVERSE AGONIST",
    "ALLOSTERIC ANTAGONIST",
    "DISRUPTING AGENT",
]

activators = [
    "PARTIAL AGONIST",
    "ACTIVATOR",
    "POSITIVE ALLOSTERIC MODULATOR",
    "POSITIVE MODULATOR",
    "AGONIST",
    "SEQUESTERING AGENT",  ## lost at 31.01.2025
    "STABILISER",
    # "EXOGENOUS GENE", ## added 24.06.2025
    # "EXOGENOUS PROTEIN" ## added 24.06.2025
]


actionType = (
        mecact.select(
            F.explode_outer("chemblIds").alias("drugId"),
            "actionType",
            "mechanismOfAction",
            "targets",
        )
        .select(
            F.explode_outer("targets").alias("targetId"),
            "drugId",
            "actionType",
            "mechanismOfAction",
        )
        .groupBy("targetId", "drugId")
        .agg(F.collect_set("actionType").alias("actionType2"))
    ).withColumn('nMoA', F.size(F.col('actionType2')))

analysis_chembl_indication = (
    discrepancifier(
        assessment.filter((F.col("datasourceId") == "chembl")).join(actionType, on=['targetId','drugId'], how='left')
        .withColumn(
            "maxClinPhase",
            F.max(F.col("clinicalPhase")).over(
                Window.partitionBy("targetId", "diseaseId")
            ),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase",'actionType2')
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    #.filter(F.col("coherencyDiagonal") == "coherent")
    .drop(
        "coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk"
    )
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
)

benchmark = (
    (
        resolvedColocFiltered.filter( ## .filter(F.col("betaGwas") < 0)
        F.col("name") != "COVID-19"
    )
        .join(  ### select just GWAS giving protection
            analysis_chembl_indication, on=["targetId", "diseaseId"], how="right"  ### RIGHT SIDE
        )
        .withColumn(
            "AgreeDrug",
            F.when(
                (F.col("drugGoF_protect").isNotNull())
                & (F.col("colocDoE") == "GoF_protect"),
                F.lit("yes"),
            )
            .when(
                (F.col("drugLoF_protect").isNotNull())
                & (F.col("colocDoE") == "LoF_protect"),
                F.lit("yes"),
            )
            .otherwise(F.lit("no")),
        )
    )  #### remove COVID-19 associations
).join(biosample.select("biosampleId", "biosampleName"), on="biosampleId", how="left")

negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .groupBy("targetId", "diseaseId")
    .count()
    .withColumn("stopReason", F.lit("Negative"))
    .drop("count")
)

### create disdic dictionary
disdic={}

# --- Configuration for your iterative pivoting ---
group_by_columns = ['targetId', 'diseaseId','phase4Clean','phase3Clean','phase2Clean','phase1Clean','PhaseT']
#columns_to_pivot_on = ['actionType2', 'biosampleName', 'projectId', 'rightStudyType','colocalisationMethod']
columns_to_pivot_on = ['projectId']
columns_to_aggregate = ['NoneCellYes', 'NdiagonalYes','hasGenetics'] # The values you want to collect in the pivoted cells
all_pivoted_dfs = {}

doe_columns=["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]
diagonal_lof=['LoF_protect','GoF_risk']
diagonal_gof=['LoF_risk','GoF_protect']

conditions = [
    F.when(F.col(c) == F.col("maxDoE"), F.lit(c)).otherwise(F.lit(None)) for c in doe_columns
    ]
print('entering the big loops')
# --- Nested Loops for Dynamic Pivoting ---
for agg_col_name in columns_to_aggregate:
    for pivot_col_name in columns_to_pivot_on:
        print(f"\n--- Creating DataFrame for Aggregation: '{agg_col_name}' and Pivot: '{pivot_col_name}' ---")
        current_col_pvalue_order_window = Window.partitionBy("targetId", "diseaseId", "maxClinPhase", pivot_col_name).orderBy(F.col('colocalisationMethod').asc(), F.col("qtlPValueExponent").asc())
        test2=discrepancifier(benchmark.withColumn('actionType2', F.concat_ws(",", F.col("actionType2"))).withColumn('qtlColocDoE',F.first('colocDoE').over(current_col_pvalue_order_window)).groupBy(
        "targetId", "diseaseId", "maxClinPhase", "drugLoF_protect", "drugGoF_protect",pivot_col_name)
        .pivot("colocDoE")
        .count()
        .withColumnRenamed('drugLoF_protect', 'LoF_protect_ch')
        .withColumnRenamed('drugGoF_protect', 'GoF_protect_ch')).withColumn( ## .filter(F.col('coherencyDiagonal')!='noEvid')
    "arrayN", F.array(*[F.col(c) for c in doe_columns])
    ).withColumn(
        "maxDoE", F.array_max(F.col("arrayN"))
    ).withColumn("maxDoE_names", F.array(*conditions)
    ).withColumn("maxDoE_names", F.expr("filter(maxDoE_names, x -> x is not null)")
    ).withColumn(
        "NoneCellYes",
        F.when((F.col("LoF_protect_ch").isNotNull() & (F.col('GoF_protect_ch').isNull())) & (F.array_contains(F.col("maxDoE_names"), F.lit("LoF_protect")))==True, F.lit('yes'))
        .when((F.col("GoF_protect_ch").isNotNull() & (F.col('LoF_protect_ch').isNull())) & (F.array_contains(F.col("maxDoE_names"), F.lit("GoF_protect")))==True, F.lit('yes')
            ).otherwise(F.lit('no'))  # If the value is null, return null # Otherwise, check if name is in array
    ).withColumn(
        "NdiagonalYes",
        F.when((F.col("LoF_protect_ch").isNotNull() & (F.col('GoF_protect_ch').isNull())) & 
            (F.size(F.array_intersect(F.col("maxDoE_names"), F.array([F.lit(x) for x in diagonal_lof]))) > 0),
            F.lit("yes")
        ).when((F.col("GoF_protect_ch").isNotNull() & (F.col('LoF_protect_ch').isNull())) & 
            (F.size(F.array_intersect(F.col("maxDoE_names"), F.array([F.lit(x) for x in diagonal_gof]))) > 0),
            F.lit("yes")
        ).otherwise(F.lit('no'))
    ).withColumn(
        "drugCoherency",
        F.when(
            (F.col("LoF_protect_ch").isNotNull())
            & (F.col("GoF_protect_ch").isNull()), F.lit("coherent")
        )
        .when(
            (F.col("LoF_protect_ch").isNull())
            & (F.col("GoF_protect_ch").isNotNull()), F.lit("coherent")
        )
        .when(
            (F.col("LoF_protect_ch").isNotNull())
            & (F.col("GoF_protect_ch").isNotNull()), F.lit("dispar")
        )
        .otherwise(F.lit("other")),
    ).join(negativeTD, on=["targetId", "diseaseId"], how="left").withColumn(
        "PhaseT",
        F.when(F.col("stopReason") == "Negative", F.lit("yes")).otherwise(F.lit("no")),
    ).withColumn(
        "phase4Clean",
        F.when(
            (F.col("maxClinPhase") == 4) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "phase3Clean",
        F.when(
            (F.col("maxClinPhase") >= 3) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "phase2Clean",
        F.when(
            (F.col("maxClinPhase") >= 2) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "phase1Clean",
        F.when(
            (F.col("maxClinPhase") >= 1) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "hasGenetics",
        F.when(F.col("coherencyDiagonal") != "noEvid", F.lit("yes")).otherwise(F.lit("no")),
    )
        # 1. Get distinct values for the pivot column (essential for pivot())
        # This brings a small amount of data to the driver, but is necessary for the pivot schema.
        #distinct_pivot_values = [row[0] for row in test2.select(pivot_col_name).distinct().collect()]
        # print(f"Distinct values for '{pivot_col_name}': {distinct_pivot_values}")

        # 2. Perform the groupBy, pivot, and aggregate operations
        # The .pivot() function requires the list of distinct values for better performance
        # and correct schema inference.
        pivoted_df = (
            test2.groupBy(*group_by_columns)
            .pivot(pivot_col_name) # Provide distinct values distinct_pivot_values
            .agg(F.collect_set(F.col(agg_col_name))) # Collect all values into a set
            .fillna(0) # Fill cells that have no data with an empty list instead of null
        )
        # 3. Add items to dictionary to map the columns:
        # filter out None and 'null':
        datasetColumns=pivoted_df.columns
        filtered = [x for x in datasetColumns if x is not None and x != 'null']
        # using list comprehension
        for item in filtered:
            disdic[item] = pivot_col_name

        # 3. Add the 'data' literal column dynamically
        # This column indicates which aggregation column was used.
        #pivoted_df = pivoted_df.withColumn('data', F.lit(f'Drug_{agg_col_name}'))

        array_columns_to_convert = [
            field.name for field in pivoted_df.schema.fields
            if isinstance(field.dataType, ArrayType)
        ]
        print(f"Identified ArrayType columns for conversion: {array_columns_to_convert}")

        # 4. Apply the conversion logic to each identified array column
        df_after_conversion = pivoted_df # Start with the pivoted_df
        for col_to_convert in array_columns_to_convert:
            df_after_conversion = df_after_conversion.withColumn(
                col_to_convert,
                F.when(F.col(col_to_convert).isNull(), F.lit('no'))          # Handle NULLs (from pivot for no data)
                .when(F.size(F.col(col_to_convert)) == 0, F.lit('no'))       # Empty array -> 'no'
                .when(F.array_contains(F.col(col_to_convert), F.lit('yes')), F.lit('yes')) # Contains 'yes' -> 'yes'
                .when(F.array_contains(F.col(col_to_convert), F.lit('no')), F.lit('no'))   # Contains 'no' -> 'no'
                .otherwise(F.lit('no')) # Fallback for unexpected array content (e.g., ['other'], ['yes','no'])
            )

        # 4. Generate a unique name for this DataFrame and store it
        df_key = f"df_pivot_{agg_col_name.lower()}_by_{pivot_col_name.lower()}"
        all_pivoted_dfs[df_key] = df_after_conversion.withColumnRenamed( 'phase4Clean','Phase>=4'
        ).withColumnRenamed('phase3Clean','Phase>=3'
        ).withColumnRenamed('phase2Clean','Phase>=2'
        ).withColumnRenamed('phase1Clean','Phase>=1')


# --- Accessing your generated DataFrames ---
print("\n--- All generated DataFrames are stored in 'all_pivoted_dfs' dictionary ---")
print("Keys available:", all_pivoted_dfs.keys())
############## HYBRID ##############
####################################
def strip_only(lst):
    return [x.removesuffix("_only") for x in lst]  # Python 3.9+
    # or: return [x[:-5] if x.endswith("_only") else x for x in lst]


##### PROJECTID
project_keys=[f"{k}_only" for k,v in disdic.items() if v == 'projectId']
main=['GTEx_only', 'UKB_PPP_EUR_only']
#stimulated=['Alasoo_2018_only','Cytoimmgen_only','Fairfax_2014_only','Kim-Hellmuth_2017_only','Nathan_2022_only','Nedelec_2016_only','Quach_2016_only','Randolph_2021_only','Schmiedel_2018_only']
#cellLine=['CAP_only','HipSci_only','iPSCORE_only','Jerber_2021_only','PhLiPS_only','Schwartzentruber_2018_only','TwinsUK_only']

derivedCellLine=['TwinsUK_only','PhLiPS_only','CAP_only','GENCORD_only','Sun_2018_only','Nedelec_2016_only']
canonicalCellLine=['Alasoo_2018_only','Jerber_2021_only','GEUVADIS_only','iPSCORE_only','Aygun_2021_only','Schwartzentruber_2018_only']
stimulated=['Schmiedel_2018_only','Bossini-Castillo_2019_only','Alasoo_2018_only','Cytoimmgen_only','Gilchrist_2021_only','CAP_only','Quach_2016_only','Randolph_2021_only','Sun_2018_only','Nedelec_2016_only','Kim-Hellmuth_2017_only']

# Apply
main = strip_only(main)
canonicalCellLine = strip_only(canonicalCellLine)
derivedCellLine = strip_only(derivedCellLine)
stimulated = strip_only(stimulated)

others=[item for item in strip_only(project_keys) if item not in main]
nonStimulated=[item for item in strip_only(project_keys) if item not in stimulated]
nonCanonicalCellLine = [item for item in strip_only(project_keys) if item not in canonicalCellLine]
nonDerivedCellLine = [item for item in strip_only(project_keys) if item not in derivedCellLine]

#otherCellLine=[item for item in strip_only(project_keys) if item not in cellLine]



spark session created at 2025-09-18 05:40:10.768588
Analysis started on 2025-09-18 at  2025-09-18 05:40:10.768588


25/09/18 05:40:17 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
25/09/18 05:40:17 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


SparkSession created successfully with the following configurations:
  spark.driver.memory: 24g
  spark.executor.memory: 32g
  spark.executor.cores: 4
  spark.executor.instances: 12
  spark.yarn.executor.memoryOverhead: 8g
  spark.sql.shuffle.partitions: 192
  spark.default.parallelism: 192
Spark UI available at: http://jr-doe-temp1-m.c.open-targets-eu-dev.internal:45511
loaded files
loaded newColoc


                                                                                

loaded gwasComplete
loaded resolvedColloc
run temporary direction of effect
built drugApproved dataset


                                                                                

load comparisons_df_iterative function
created full_data and lists
loaded rightTissue dataset
built negativeTD dataset
built bench2 dataset
looping for variables_study
entering the big loops

--- Creating DataFrame for Aggregation: 'NoneCellYes' and Pivot: 'projectId' ---


                                                                                

Identified ArrayType columns for conversion: ['null', 'Alasoo_2018', 'Aygun_2021', 'BLUEPRINT', 'Bossini-Castillo_2019', 'BrainSeq', 'Braineac2', 'CAP', 'CEDAR', 'CommonMind', 'Cytoimmgen', 'FUSION', 'Fairfax_2012', 'Fairfax_2014', 'GENCORD', 'GEUVADIS', 'GTEx', 'Gilchrist_2021', 'HipSci', 'Jerber_2021', 'Kasela_2017', 'Kim-Hellmuth_2017', 'Lepik_2017', 'Naranbhai_2015', 'Nathan_2022', 'Nedelec_2016', 'OneK1K', 'PISA', 'Peng_2018', 'Perez_2022', 'PhLiPS', 'Quach_2016', 'ROSMAP', 'Randolph_2021', 'Schmiedel_2018', 'Schwartzentruber_2018', 'Steinberg_2020', 'Sun_2018', 'TwinsUK', 'UKB_PPP_EUR', 'Walker_2019', 'Young_2019', 'iPSCORE', 'van_de_Bunt_2015']

--- Creating DataFrame for Aggregation: 'NdiagonalYes' and Pivot: 'projectId' ---


25/09/18 05:44:18 WARN CacheManager: Asked to cache already cached data.        


Identified ArrayType columns for conversion: ['null', 'Alasoo_2018', 'Aygun_2021', 'BLUEPRINT', 'Bossini-Castillo_2019', 'BrainSeq', 'Braineac2', 'CAP', 'CEDAR', 'CommonMind', 'Cytoimmgen', 'FUSION', 'Fairfax_2012', 'Fairfax_2014', 'GENCORD', 'GEUVADIS', 'GTEx', 'Gilchrist_2021', 'HipSci', 'Jerber_2021', 'Kasela_2017', 'Kim-Hellmuth_2017', 'Lepik_2017', 'Naranbhai_2015', 'Nathan_2022', 'Nedelec_2016', 'OneK1K', 'PISA', 'Peng_2018', 'Perez_2022', 'PhLiPS', 'Quach_2016', 'ROSMAP', 'Randolph_2021', 'Schmiedel_2018', 'Schwartzentruber_2018', 'Steinberg_2020', 'Sun_2018', 'TwinsUK', 'UKB_PPP_EUR', 'Walker_2019', 'Young_2019', 'iPSCORE', 'van_de_Bunt_2015']

--- Creating DataFrame for Aggregation: 'hasGenetics' and Pivot: 'projectId' ---


25/09/18 05:44:53 WARN CacheManager: Asked to cache already cached data.        


Identified ArrayType columns for conversion: ['null', 'Alasoo_2018', 'Aygun_2021', 'BLUEPRINT', 'Bossini-Castillo_2019', 'BrainSeq', 'Braineac2', 'CAP', 'CEDAR', 'CommonMind', 'Cytoimmgen', 'FUSION', 'Fairfax_2012', 'Fairfax_2014', 'GENCORD', 'GEUVADIS', 'GTEx', 'Gilchrist_2021', 'HipSci', 'Jerber_2021', 'Kasela_2017', 'Kim-Hellmuth_2017', 'Lepik_2017', 'Naranbhai_2015', 'Nathan_2022', 'Nedelec_2016', 'OneK1K', 'PISA', 'Peng_2018', 'Perez_2022', 'PhLiPS', 'Quach_2016', 'ROSMAP', 'Randolph_2021', 'Schmiedel_2018', 'Schwartzentruber_2018', 'Steinberg_2020', 'Sun_2018', 'TwinsUK', 'UKB_PPP_EUR', 'Walker_2019', 'Young_2019', 'iPSCORE', 'van_de_Bunt_2015']

--- All generated DataFrames are stored in 'all_pivoted_dfs' dictionary ---
Keys available: dict_keys(['df_pivot_nonecellyes_by_projectid', 'df_pivot_ndiagonalyes_by_projectid', 'df_pivot_hasgenetics_by_projectid'])


In [2]:
# Apply
main = strip_only(main)
canonicalCellLine = strip_only(canonicalCellLine)
derivedCellLine = strip_only(derivedCellLine)
stimulated = strip_only(stimulated)

others=[item for item in strip_only(project_keys) if item not in main]
nonStimulated=[item for item in strip_only(project_keys) if item not in stimulated]
nonCanonicalCellLine = [item for item in strip_only(project_keys) if item not in canonicalCellLine]
nonDerivedCellLine = [item for item in strip_only(project_keys) if item not in derivedCellLine]

In [14]:

def _or_yes(df, cols):
    """Return a Column that is TRUE if any of the given columns == 'yes'.
       Ignores columns not present in df. If none present, returns FALSE.
    """
    present = [c for c in cols if c in df.columns]
    if not present:
        return F.lit(False)
    # (col == 'yes') OR (col == 'yes') OR ...
    exprs = [F.col(c) == "yes" for c in present]
    return reduce(lambda a, b: a | b, exprs)
def add_project_group_flags(df, main, canonicalCellLine, derivedCellLine, stimulated):
    # project keys = every *_only column in this DF
    #project_keys = [c for c in df.columns if c.endswith("_only")]
    project_keys=[f"{k}_only" for k,v in disdic.items() if v == 'projectId']
    # Buckets (suffix preserved)
    others=[item for item in strip_only(project_keys) if item not in main]
    nonStimulated=[item for item in strip_only(project_keys) if item not in stimulated]
    nonCanonicalCellLine = [item for item in strip_only(project_keys) if item not in canonicalCellLine]
    nonDerivedCellLine = [item for item in strip_only(project_keys) if item not in derivedCellLine]

    # Conditions (each independent)
    cond_others          = _or_yes(df, others)  
    cond_main            = _or_yes(df, main)  
    cond_stimulated      = _or_yes(df, stimulated)    
    cond_nonStimulated   = _or_yes(df, nonStimulated)  
    cond_canonical       = _or_yes(df, canonicalCellLine)     
    cond_nonCanonical    = _or_yes(df, nonCanonicalCellLine)  
    cond_derived         = _or_yes(df, derivedCellLine)  
    cond_nonDerived      = _or_yes(df, nonDerivedCellLine)         

    return (
        df.withColumn("othersProjectId_only",   F.when(cond_others, "yes").otherwise("no"))
          .withColumn("GTExUKB_only",           F.when(cond_main, "yes").otherwise("no"))
          .withColumn("stimulated_only",        F.when(cond_stimulated, "yes").otherwise("no"))
          .withColumn("nonStimulated",          F.when(cond_nonStimulated, "yes").otherwise("no"))
          .withColumn("canonicalCellLine",      F.when(cond_canonical, "yes").otherwise("no"))
          .withColumn("nonCanonicalCellLine",   F.when(cond_nonCanonical, "yes").otherwise("no"))
          .withColumn("derivedCellLine",        F.when(cond_derived, "yes").otherwise("no"))
          .withColumn("nonDerivedCellLine",     F.when(cond_nonDerived, "yes").otherwise("no"))
    )


# --- Apply to the dict entry you mentioned ---
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = add_project_group_flags(
    df=all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'],
    main=main,
    canonicalCellLine=canonicalCellLine,
    derivedCellLine=derivedCellLine,
    stimulated=stimulated
)

all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = add_project_group_flags(
    df=all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'],
    main=main,
    canonicalCellLine=canonicalCellLine,
    derivedCellLine=derivedCellLine,
    stimulated=stimulated
)

all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = add_project_group_flags(
    df=all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'],
    main=main,
    canonicalCellLine=canonicalCellLine,
    derivedCellLine=derivedCellLine,
    stimulated=stimulated
)

# If you wanted to apply to every DF in the dict (only if they all share *_only columns):
# for k, df in all_pivoted_dfs.items():
#     all_pivoted_dfs[k] = add_project_group_flags(df, main, stimulated, cellLine)

###append to dictionary

disdic.update({'othersProjectId': 'projectId','Stimulated': 'projectId','cellLine': 'projectId', 'othersBiosampleName_only': 'biosampleName', 'otherRightStudyType':'rightStudyType'})

###################################
###################################
result = []
result_st = []
result_ci = []
array2 = []
listado = []
result_all = []
today_date = str(date.today())

for key,df in all_pivoted_dfs.items():

    print(f'working with {key}')
    parts = key.split('_by_') ### take the part of key belonging to column name
    column_name = parts[1] ### take the last part which is column name
    all_pivoted_dfs[key].persist()
    #unique_values = all_pivoted_dfs[key].drop('null').columns[7:]
    unique_values = all_pivoted_dfs[key].drop('null').columns[-8:] ### just the interesting columns for us 
    filtered_unique_values = [x for x in unique_values if x is not None and x != 'null']
    print('There are ', len(filtered_unique_values), 'columns to analyse with phases')
    rows = comparisons_df_iterative(filtered_unique_values)

    # If needed, now process the rest
    for row in rows:
        print('performing', row)
        results = aggregations_original(
            all_pivoted_dfs[key], key, listado, *row, today_date
        )
        result_all.append(results)
        print('results appended')
    all_pivoted_dfs[key].unpersist()
    print('df unpersisted')


schema = StructType(
    [
        StructField("group", StringType(), True),
        StructField("comparison", StringType(), True),
        StructField("phase", StringType(), True),
        StructField("oddsRatio", DoubleType(), True),
        StructField("pValue", DoubleType(), True),
        StructField("lowerInterval", DoubleType(), True),
        StructField("upperInterval", DoubleType(), True),
        StructField("total", StringType(), True),
        StructField("values", ArrayType(ArrayType(IntegerType())), True),
        StructField("relSuccess", DoubleType(), True),
        StructField("rsLower", DoubleType(), True),
        StructField("rsUpper", DoubleType(), True),
        StructField("path", StringType(), True),
    ]
)
import re

# Define the list of patterns to search for
patterns = [
    "_only",
    #"_tissue",
    #"_isSignalFromRightTissue",
    "_isRightTissueSignalAgreed",
]
# Create a regex pattern to match any of the substrings
regex_pattern = "(" + "|".join(map(re.escape, patterns)) + ")"

# Convert list of lists to DataFrame
df = (
    spreadSheetFormatter(spark.createDataFrame(result_all, schema=schema))
    .withColumn(
        "prefix",
        F.regexp_replace(
            F.col("comparison"), regex_pattern + ".*", ""
        ),  # Extract part before the pattern
    )
    .withColumn(
        "suffix",
        F.regexp_extract(
            F.col("comparison"), regex_pattern, 0
        ),  # Extract the pattern itself
    )
)

### annotate projectId, tissue, qtl type and doe type:

from pyspark.sql.functions import create_map
from itertools import chain

mapping_expr=create_map([F.lit(x) for x in chain(*disdic.items())])

df_annot=df.withColumn('annotation',mapping_expr.getItem(F.col('prefix')))

df_annot.toPandas().to_csv(
    f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_AllPhasesMixtures2.csv"
)

print("dataframe written \n Analysis finished")

working with df_pivot_nonecellyes_by_projectid
There are  8 columns to analyse with phases
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=4', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_nonecellyes_by_projectid/othersProjectId_only_predictor_Phase>=4.parquet
results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=3', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_nonecellyes_by_projectid/othersProjectId_only_predictor_Phase>=3.parquet
results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=2', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_nonecellyes_by_projectid/othersProjectId_only_predictor_Phase>=2.parquet
results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=1', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_nonecel

                                                                                

results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=3', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_ndiagonalyes_by_projectid/othersProjectId_only_predictor_Phase>=3.parquet
results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=2', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_ndiagonalyes_by_projectid/othersProjectId_only_predictor_Phase>=2.parquet
results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=1', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_ndiagonalyes_by_projectid/othersProjectId_only_predictor_Phase>=1.parquet
results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='PhaseT', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_ndiagonalyes_by_projectid/othersProjectId_only_predictor_PhaseT.parquet
results 

                                                                                

results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=3', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_hasgenetics_by_projectid/othersProjectId_only_predictor_Phase>=3.parquet
results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=2', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_hasgenetics_by_projectid/othersProjectId_only_predictor_Phase>=2.parquet
results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=1', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_hasgenetics_by_projectid/othersProjectId_only_predictor_Phase>=1.parquet
results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='PhaseT', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_hasgenetics_by_projectid/othersProjectId_only_predictor_PhaseT.parquet
results appe



dataframe written 
 Analysis finished


In [16]:
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].groupBy('UKB_PPP_EUR','othersProjectId_only').count().show()

+-----------+--------------------+-----+
|UKB_PPP_EUR|othersProjectId_only|count|
+-----------+--------------------+-----+
|         no|                 yes| 4461|
|         no|                  no|69699|
|        yes|                  no|   22|
|        yes|                 yes|    5|
+-----------+--------------------+-----+



In [17]:
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].filter((F.col('UKB_PPP_EUR')=='yes') & (F.col('othersProjectId_only')=='yes')).show()

+---------------+-------------+--------+--------+--------+--------+------+----+-----------+----------+---------+---------------------+--------+---------+---+-----+----------+----------+------+------------+------------+-------+--------+----+--------------+------+-----------+-----------+-----------------+----------+--------------+-----------+------------+------+----+---------+----------+------+----------+------+-------------+--------------+---------------------+--------------+--------+-------+-----------+-----------+----------+-------+----------------+--------------------+------------+---------------+-------------+-----------------+--------------------+---------------+------------------+
|       targetId|    diseaseId|Phase>=4|Phase>=3|Phase>=2|Phase>=1|PhaseT|null|Alasoo_2018|Aygun_2021|BLUEPRINT|Bossini-Castillo_2019|BrainSeq|Braineac2|CAP|CEDAR|CommonMind|Cytoimmgen|FUSION|Fairfax_2012|Fairfax_2014|GENCORD|GEUVADIS|GTEx|Gilchrist_2021|HipSci|Jerber_2021|Kasela_2017|Kim-Hellmuth_2017|L

In [33]:
project_list = benchmark.select("projectId").distinct().rdd.map(lambda r: r[0]).collect()


Exception in thread "serve-DataFrame" java.net.SocketTimeoutException: Accept timed out]
	at java.base/java.net.PlainSocketImpl.socketAccept(Native Method)
	at java.base/java.net.AbstractPlainSocketImpl.accept(AbstractPlainSocketImpl.java:474)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:565)
	at java.base/java.net.ServerSocket.accept(ServerSocket.java:533)
	at org.apache.spark.security.SocketAuthServer$$anon$1.run(SocketAuthServer.scala:65)
                                                                                

In [34]:
project_list

['Quach_2016',
 'UKB_PPP_EUR',
 'FUSION',
 'BLUEPRINT',
 'GTEx',
 'GEUVADIS',
 'BrainSeq',
 'TwinsUK',
 'Lepik_2017',
 'HipSci',
 'Schmiedel_2018',
 'Fairfax_2014',
 'Bossini-Castillo_2019',
 'ROSMAP',
 'Peng_2018',
 'CommonMind',
 'Cytoimmgen',
 'Alasoo_2018',
 'Kim-Hellmuth_2017',
 'CAP',
 'Nedelec_2016',
 'Walker_2019',
 'Jerber_2021',
 'Kasela_2017',
 'Fairfax_2012',
 'CEDAR',
 'OneK1K',
 'Aygun_2021',
 'GENCORD',
 'PhLiPS',
 'Schwartzentruber_2018',
 'Young_2019',
 'Sun_2018',
 'van_de_Bunt_2015',
 'PISA',
 'Perez_2022',
 'Steinberg_2020',
 'Gilchrist_2021',
 'iPSCORE',
 'Braineac2',
 'Nathan_2022',
 'Naranbhai_2015',
 'Randolph_2021',
 None]

In [None]:
##### PROJECTID

project_keys=[f"{k}_only" for k,v in disdic.items() if v == 'projectId']
main=['GTEx_only', 'UKB_PPP_EUR_only']
#stimulated=['Alasoo_2018_only','Cytoimmgen_only','Fairfax_2014_only','Kim-Hellmuth_2017_only','Nathan_2022_only','Nedelec_2016_only','Quach_2016_only','Randolph_2021_only','Schmiedel_2018_only']
#cellLine=['CAP_only','HipSci_only','iPSCORE_only','Jerber_2021_only','PhLiPS_only','Schwartzentruber_2018_only','TwinsUK_only']

derivedCellLine=['TwinsUK_only','PhLiPS_only','CAP_only','GENCORD_only','Sun_2018_only','Nedelec_2016_only']
canonicalCellLine=['Alasoo_2018_only','Jerber_2021_only','GEUVADIS_only','iPSCORE_only','Aygun_2021_only','Schwartzentruber_2018_only']
stimulated=['Schmiedel_2018_only','Bossini-Castillo_2019_only','Alasoo_2018_only','Cytoimmgen_only','Gilchrist_2021_only','CAP_only','Quach_2016_only','Randolph_2021_only','Sun_2018_only','Nedelec_2016_only','Kim-Hellmuth_2017_only']

# Apply
main = strip_only(main)
canonicalCellLine = strip_only(canonicalCellLine)
derivedCellLine = strip_only(derivedCellLine)
stimulated = strip_only(stimulated)

others=[item for item in strip_only(project_keys[7:]) if item not in main]
nonStimulated=[item for item in strip_only(project_keys[7:]) if item not in stimulated]
nonCanonicalCellLine = [item for item in strip_only(project_keys[7:]) if item not in canonicalCellLine]
nonDerivedCellLine = [item for item in strip_only(project_keys[7:]) if item not in derivedCellLine]

In [37]:
others[0:]

['Kasela_2017',
 'Schmiedel_2018',
 'Fairfax_2012',
 'Cytoimmgen',
 'Bossini-Castillo_2019',
 'CEDAR',
 'OneK1K',
 'BLUEPRINT',
 'ROSMAP',
 'CommonMind',
 'BrainSeq',
 'HipSci',
 'Quach_2016',
 'Nathan_2022',
 'Steinberg_2020',
 'CAP',
 'TwinsUK',
 'iPSCORE',
 'GENCORD',
 'Peng_2018',
 'Nedelec_2016',
 'Alasoo_2018',
 'Schwartzentruber_2018',
 'Aygun_2021',
 'Walker_2019',
 'GEUVADIS',
 'FUSION',
 'Lepik_2017',
 'van_de_Bunt_2015',
 'Perez_2022',
 'Fairfax_2014',
 'PISA',
 'PhLiPS',
 'Sun_2018',
 'Jerber_2021',
 'Kim-Hellmuth_2017',
 'Gilchrist_2021',
 'Braineac2',
 'Young_2019',
 'Naranbhai_2015',
 'Randolph_2021',
 None]

In [1]:
import time
from array import ArrayType
from functions import (
    relative_success,
    spreadSheetFormatter,
    discrepancifier,
    temporary_directionOfEffect,
    buildColocData,
    gwasDataset,
)
# from stoppedTrials import terminated_td
from DoEAssessment import directionOfEffect
# from membraneTargets import target_membrane
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
from datetime import datetime
from datetime import date
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.types import (
    StructType,
    StructField,
    DoubleType,
    DecimalType,
    StringType,
    FloatType,
)
import pandas as pd
from functools import reduce


# --- Build the SparkSession ---
# Use the .config() method to set these parameters before calling .getOrCreate()
# This ensures Spark requests the correct resources from YARN at the start.
driver_memory = "24g"                 # plenty for planning & small collects
executor_cores = 4                    # sweet spot for GC + Python workers
num_executors  = 12                   # 12 * 4 = 48 cores for executors; ~16 cores left for driver/OS
executor_memory = "32g"               # per executor heap
executor_memory_overhead = "8g"       # ~20% overhead for PySpark/Arrow/off-heap
# Totals: (32+8) * 12 = 480 GB executors + 24 GB driver ≈ 504 GB (adjust down if your hard cap is <500 GB)
# If you must stay strictly ≤ 500 GB, use executor_memory="30g", overhead="6g"  → (36 * 12) + 24 = 456 + 24 = 480 GB

shuffle_partitions   = 192            # ≈ 2–4× total cores (48) → start with 192
default_parallelism  = 192

spark = SparkSession.builder \
    .appName("MyOptimizedPySparkApp") \
    .config("spark.master", "yarn") \
    .config("spark.driver.memory", driver_memory) \
    .config("spark.executor.memory", executor_memory) \
    .config("spark.executor.cores", executor_cores) \
    .config("spark.executor.instances", num_executors) \
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead) \
    .config("spark.sql.shuffle.partitions", shuffle_partitions) \
    .config("spark.default.parallelism", default_parallelism) \
    .getOrCreate()

print(f"SparkSession created successfully with the following configurations:")
print(f"  spark.driver.memory: {spark.conf.get('spark.driver.memory')}")
print(f"  spark.executor.memory: {spark.conf.get('spark.executor.memory')}")
print(f"  spark.executor.cores: {spark.conf.get('spark.executor.cores')}")
print(f"  spark.executor.instances: {spark.conf.get('spark.executor.instances')}")
print(f"  spark.yarn.executor.memoryOverhead: {spark.conf.get('spark.yarn.executor.memoryOverhead')}")
print(f"  spark.sql.shuffle.partitions: {spark.conf.get('spark.sql.shuffle.partitions')}")
print(f"  spark.default.parallelism: {spark.conf.get('spark.default.parallelism')}")
print(f"Spark UI available at: {spark.sparkContext.uiWebUrl}")

# --- Your PySpark Code Here ---
# Now you can proceed with your data loading and processing.
# Example:
# df = spark.read.parquet("hdfs:///user/your_user/your_large_data.parquet")
# print(f"Number of rows in DataFrame: {df.count()}")
# df.groupBy("some_column").agg({"another_column": "sum"}).show()

# Remember to stop the SparkSession when you are done
# spark.stop()

path_n='gs://open-targets-data-releases/25.06/output/'

target = spark.read.parquet(f"{path_n}target/")

diseases = spark.read.parquet(f"{path_n}disease/")

evidences = spark.read.parquet(f"{path_n}evidence")

credible = spark.read.parquet(f"{path_n}credible_set")

new = spark.read.parquet(f"{path_n}colocalisation_coloc") 

index=spark.read.parquet(f"{path_n}study/")

variantIndex = spark.read.parquet(f"{path_n}variant")

biosample = spark.read.parquet(f"{path_n}biosample")

ecaviar=spark.read.parquet(f"{path_n}colocalisation_ecaviar")

all_coloc=ecaviar.unionByName(new, allowMissingColumns=True)

print("loaded files")

#### FIRST MODULE: BUILDING COLOC 
newColoc=buildColocData(all_coloc,credible,index)

print("loaded newColoc")

### SECOND MODULE: PROCESS EVIDENCES TO AVOID EXCESS OF COLUMNS 
gwasComplete = gwasDataset(evidences,credible)

#### THIRD MODULE: INCLUDE COLOC IN THE 
resolvedColoc = (
    (
        newColoc.withColumnRenamed("geneId", "targetId")
        .join(
            gwasComplete.withColumnRenamed("studyLocusId", "leftStudyLocusId"),
            on=["leftStudyLocusId", "targetId"],
            how="inner",
        )
        .join(  ### propagated using parent terms
            diseases.selectExpr(
                "id as diseaseId", "name", "parents", "therapeuticAreas"
            ),
            on="diseaseId",
            how="left",
        )
        .withColumn(
            "diseaseId",
            F.explode_outer(F.concat(F.array(F.col("diseaseId")), F.col("parents"))),
        )
        .drop("parents", "oldDiseaseId")
    ).withColumn(
        "colocDoE",
        F.when(
            F.col("rightStudyType").isin(
                ["eqtl", "pqtl", "tuqtl", "sceqtl", "sctuqtl"]
            ),
            F.when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("GoF_risk"),
            )
            .when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("LoF_risk"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("LoF_protect"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("GoF_protect"),
            ),
        ).when(
            F.col("rightStudyType").isin(
                ["sqtl", "scsqtl"]
            ),  ### opposite directionality than sqtl
            F.when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("LoF_risk"),
            )
            .when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("GoF_risk"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("GoF_protect"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("LoF_protect"),
            ),
        ),
    )
    # .persist()
)
print("loaded resolvedColloc")

datasource_filter = [
#   "ot_genetics_portal",
    "gwas_credible_sets",
    "gene_burden",
    "eva",
    "eva_somatic",
    "gene2phenotype",
    "orphanet",
    "cancer_gene_census",
    "intogen",
    "impc",
    "chembl",
]

assessment, evidences, actionType, oncolabel = temporary_directionOfEffect(
    path_n, datasource_filter
)

print("run temporary direction of effect")


print("built drugApproved dataset")


#### FOURTH MODULE BUILDING CHEMBL ASSOCIATIONS - HERE TAKE CARE WITH FILTERING STEP 
analysis_chembl_indication = (
    discrepancifier(
        assessment.filter((F.col("datasourceId") == "chembl"))
        .withColumn(
            "maxClinPhase",
            F.max(F.col("clinicalPhase")).over(
                Window.partitionBy("targetId", "diseaseId")
            ),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase")
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    #.filter(F.col("coherencyDiagonal") == "coherent")
    .drop(
        "coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk"
    )
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
    # .persist()
)

####2 Define agregation function
import pandas as pd
import numpy as np
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio
from pyspark.sql.types import *


def convertTuple(tup):
    st = ",".join(map(str, tup))
    return st


#####3 run in a function
def aggregations_original(
    df,
    data,
    listado,
    comparisonColumn,
    comparisonType,
    predictionColumn,
    predictionType,
    today_date,
):
    wComparison = Window.partitionBy(comparisonColumn)
    wPrediction = Window.partitionBy(predictionColumn)
    wPredictionComparison = Window.partitionBy(comparisonColumn, predictionColumn)
    results = []
    # uniqIds = df.select("targetId", "diseaseId").distinct().count()
    out = (
        df.withColumn("comparisonType", F.lit(comparisonType))
        .withColumn("dataset", F.lit(data))
        .withColumn("predictionType", F.lit(predictionType))
        # .withColumn("total", F.lit(uniqIds))
        .withColumn("a", F.count("targetId").over(wPredictionComparison))
        .withColumn("comparisonColumn", F.lit(comparisonColumn))
        .withColumn("predictionColumnValue", F.lit(predictionColumn))
        .withColumn(
            "predictionTotal",
            F.count("targetId").over(wPrediction),
        )
        .withColumn(
            "comparisonTotal",
            F.count("targetId").over(wComparison),
        )
        .select(
            F.col(predictionColumn).alias("prediction"),
            F.col(comparisonColumn).alias("comparison"),
            "dataset",
            "comparisonColumn",
            "predictionColumnValue",
            "comparisonType",
            "predictionType",
            "a",
            "predictionTotal",
            "comparisonTotal",
        )
        .filter(F.col("prediction").isNotNull())
        .filter(F.col("comparison").isNotNull())
        .distinct()
    )
    '''
    out.write.mode("overwrite").parquet(
        "gs://ot-team/jroldan/"
        + str(
            today_date
            + "_"
            + "analysis/"
            + data
            # + "_propagated"
            + "/"
            + comparisonColumn
            + "_"
            + comparisonType
            + "_"
            + predictionColumn
            + ".parquet"
        )
    )
    '''

    listado.append(
        "gs://ot-team/jroldan/"
        + str(
            today_date
            + "_"
            + "analysis/"
            + data
            # + "_propagated"
            + "/"
            + comparisonColumn
            + "_"
            + comparisonType
            + "_"
            + predictionColumn
            + ".parquet"
        )
    )
    path = "gs://ot-team/jroldan/" + str(
        today_date
        + "_"
        + "analysis/"
        + data
        # + "_propagated"
        + "/"
        + comparisonColumn
        + "_"
        + comparisonType
        + "_"
        + predictionColumn
        + ".parquet"
    )
    print(path)
    
    ### making analysis
    array1 = np.delete(
        out.join(full_data, on=["prediction", "comparison"], how="outer")
        .groupBy("comparison")
        .pivot("prediction")
        .agg(F.first("a"))
        .sort(F.col("comparison").desc())
        .select("comparison", "yes", "no")
        .fillna(0)
        .toPandas()
        .to_numpy(),
        [0],
        1,
    )
    total = np.sum(array1)
    res_npPhaseX = np.array(array1, dtype=int)
    resX = convertTuple(fisher_exact(res_npPhaseX, alternative="two-sided"))
    resx_CI = convertTuple(
        odds_ratio(res_npPhaseX).confidence_interval(confidence_level=0.95)
    )

    result_st.append(resX)
    result_ci.append(resx_CI)
    (rs_result, rs_ci) = relative_success(array1)
    results.extend(
        [
            comparisonType,
            comparisonColumn,
            predictionColumn,
            round(float(resX.split(",")[0]), 2),
            float(resX.split(",")[1]),
            round(float(resx_CI.split(",")[0]), 2),
            round(float(resx_CI.split(",")[1]), 2),
            str(total),
            np.array(res_npPhaseX).tolist(),
            round(float(rs_result), 2),
            round(float(rs_ci[0]), 2),
            round(float(rs_ci[1]), 2),
            # studies,
            # tissues,
            path,
        ]
    )
    return results


#### 3 Loop over different datasets (as they will have different rows and columns)


def comparisons_df_iterative(elements):
    #toAnalysis = [(key, value) for key, value in disdic.items() if value == projectId]
    toAnalysis = [(col, "predictor") for col in elements]
    schema = StructType(
        [
            StructField("comparison", StringType(), True),
            StructField("comparisonType", StringType(), True),
        ]
    )

    comparisons = spark.createDataFrame(toAnalysis, schema=schema)
    ### include all the columns as predictor

    predictions = spark.createDataFrame(
        data=[
            ("Phase>=4", "clinical"),
            ('Phase>=3','clinical'),
            ('Phase>=2','clinical'),
            ('Phase>=1','clinical'),
            ("PhaseT", "clinical"),
        ]
    )
    return comparisons.join(predictions, how="full").collect()


print("load comparisons_df_iterative function")


full_data = spark.createDataFrame(
    data=[
        ("yes", "yes"),
        ("yes", "no"),
        ("no", "yes"),
        ("no", "no"),
    ],
    schema=StructType(
        [
            StructField("prediction", StringType(), True),
            StructField("comparison", StringType(), True),
        ]
    ),
)
print("created full_data and lists")

#rightTissue = spark.read.csv(
#    'gs://ot-team/jroldan/analysis/20250526_rightTissue.csv',
#    header=True,
#).drop("_c0")

print("loaded rightTissue dataset")

negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .groupBy("targetId", "diseaseId")
    .count()
    .withColumn("stopReason", F.lit("Negative"))
    .drop("count")
)

print("built negativeTD dataset")

print("built bench2 dataset")

###### cut from here
print("looping for variables_study")

#### new part with chatgpt -- TEST

## QUESTIONS TO ANSWER:
# HAVE ECAVIAR >=0.8
# HAVE COLOC 
# HAVE COLOC >= 0.8
# HAVE COLOC + ECAVIAR >= 0.01
# HAVE COLOC >= 0.8 + ECAVIAR >= 0.01
# RIGHT JOING WITH CHEMBL 

### FIFTH MODULE: BUILDING BENCHMARK OF THE DATASET TO EXTRACT EHE ANALYSIS 

resolvedColocFiltered = resolvedColoc.filter((F.col('clpp')>=0.01) | (F.col('h4')>=0.8))
benchmark = (
    (
        resolvedColocFiltered.filter( ## .filter(F.col("betaGwas") < 0)
        F.col("name") != "COVID-19"
    )
        .join(  ### select just GWAS giving protection
            analysis_chembl_indication, on=["targetId", "diseaseId"], how="right"  ### RIGHT SIDE
        )
        .withColumn(
            "AgreeDrug",
            F.when(
                (F.col("drugGoF_protect").isNotNull())
                & (F.col("colocDoE") == "GoF_protect"),
                F.lit("yes"),
            )
            .when(
                (F.col("drugLoF_protect").isNotNull())
                & (F.col("colocDoE") == "LoF_protect"),
                F.lit("yes"),
            )
            .otherwise(F.lit("no")),
        )
    )  #### remove COVID-19 associations
).join(biosample.select("biosampleId", "biosampleName"), on="biosampleId", how="left")


### drug mechanism of action
mecact_path = f"{path_n}drug_mechanism_of_action/" #  mechanismOfAction == old version
mecact = spark.read.parquet(mecact_path)

inhibitors = [
    "RNAI INHIBITOR",
    "NEGATIVE MODULATOR",
    "NEGATIVE ALLOSTERIC MODULATOR",
    "ANTAGONIST",
    "ANTISENSE INHIBITOR",
    "BLOCKER",
    "INHIBITOR",
    "DEGRADER",
    "INVERSE AGONIST",
    "ALLOSTERIC ANTAGONIST",
    "DISRUPTING AGENT",
]

activators = [
    "PARTIAL AGONIST",
    "ACTIVATOR",
    "POSITIVE ALLOSTERIC MODULATOR",
    "POSITIVE MODULATOR",
    "AGONIST",
    "SEQUESTERING AGENT",  ## lost at 31.01.2025
    "STABILISER",
    # "EXOGENOUS GENE", ## added 24.06.2025
    # "EXOGENOUS PROTEIN" ## added 24.06.2025
]


actionType = (
        mecact.select(
            F.explode_outer("chemblIds").alias("drugId"),
            "actionType",
            "mechanismOfAction",
            "targets",
        )
        .select(
            F.explode_outer("targets").alias("targetId"),
            "drugId",
            "actionType",
            "mechanismOfAction",
        )
        .groupBy("targetId", "drugId")
        .agg(F.collect_set("actionType").alias("actionType2"))
    ).withColumn('nMoA', F.size(F.col('actionType2')))

analysis_chembl_indication = (
    discrepancifier(
        assessment.filter((F.col("datasourceId") == "chembl")).join(actionType, on=['targetId','drugId'], how='left')
        .withColumn(
            "maxClinPhase",
            F.max(F.col("clinicalPhase")).over(
                Window.partitionBy("targetId", "diseaseId")
            ),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase",'actionType2')
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    #.filter(F.col("coherencyDiagonal") == "coherent")
    .drop(
        "coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk"
    )
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
)

benchmark = (
    (
        resolvedColocFiltered.filter( ## .filter(F.col("betaGwas") < 0)
        F.col("name") != "COVID-19"
    )
        .join(  ### select just GWAS giving protection
            analysis_chembl_indication, on=["targetId", "diseaseId"], how="right"  ### RIGHT SIDE
        )
        .withColumn(
            "AgreeDrug",
            F.when(
                (F.col("drugGoF_protect").isNotNull())
                & (F.col("colocDoE") == "GoF_protect"),
                F.lit("yes"),
            )
            .when(
                (F.col("drugLoF_protect").isNotNull())
                & (F.col("colocDoE") == "LoF_protect"),
                F.lit("yes"),
            )
            .otherwise(F.lit("no")),
        )
    )  #### remove COVID-19 associations
).join(biosample.select("biosampleId", "biosampleName"), on="biosampleId", how="left")

negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .groupBy("targetId", "diseaseId")
    .count()
    .withColumn("stopReason", F.lit("Negative"))
    .drop("count")
)

### create disdic dictionary
disdic={}

# --- Configuration for your iterative pivoting ---
group_by_columns = ['targetId', 'diseaseId','phase4Clean','phase3Clean','phase2Clean','phase1Clean','PhaseT']
#columns_to_pivot_on = ['actionType2', 'biosampleName', 'projectId', 'rightStudyType','colocalisationMethod']
columns_to_pivot_on = ['projectId']
columns_to_aggregate = ['NoneCellYes', 'NdiagonalYes','hasGenetics'] # The values you want to collect in the pivoted cells
all_pivoted_dfs = {}

doe_columns=["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]
diagonal_lof=['LoF_protect','GoF_risk']
diagonal_gof=['LoF_risk','GoF_protect']

conditions = [
    F.when(F.col(c) == F.col("maxDoE"), F.lit(c)).otherwise(F.lit(None)) for c in doe_columns
    ]
print('entering the big loops')
# --- Nested Loops for Dynamic Pivoting ---
for agg_col_name in columns_to_aggregate:
    for pivot_col_name in columns_to_pivot_on:
        print(f"\n--- Creating DataFrame for Aggregation: '{agg_col_name}' and Pivot: '{pivot_col_name}' ---")
        current_col_pvalue_order_window = Window.partitionBy("targetId", "diseaseId", "maxClinPhase", pivot_col_name).orderBy(F.col('colocalisationMethod').asc(), F.col("qtlPValueExponent").asc())
        test2=discrepancifier(benchmark.withColumn('actionType2', F.concat_ws(",", F.col("actionType2"))).withColumn('qtlColocDoE',F.first('colocDoE').over(current_col_pvalue_order_window)).groupBy(
        "targetId", "diseaseId", "maxClinPhase", "drugLoF_protect", "drugGoF_protect",pivot_col_name)
        .pivot("colocDoE")
        .count()
        .withColumnRenamed('drugLoF_protect', 'LoF_protect_ch')
        .withColumnRenamed('drugGoF_protect', 'GoF_protect_ch')).withColumn( ## .filter(F.col('coherencyDiagonal')!='noEvid')
    "arrayN", F.array(*[F.col(c) for c in doe_columns])
    ).withColumn(
        "maxDoE", F.array_max(F.col("arrayN"))
    ).withColumn("maxDoE_names", F.array(*conditions)
    ).withColumn("maxDoE_names", F.expr("filter(maxDoE_names, x -> x is not null)")
    ).withColumn(
        "NoneCellYes",
        F.when((F.col("LoF_protect_ch").isNotNull() & (F.col('GoF_protect_ch').isNull())) & (F.array_contains(F.col("maxDoE_names"), F.lit("LoF_protect")))==True, F.lit('yes'))
        .when((F.col("GoF_protect_ch").isNotNull() & (F.col('LoF_protect_ch').isNull())) & (F.array_contains(F.col("maxDoE_names"), F.lit("GoF_protect")))==True, F.lit('yes')
            ).otherwise(F.lit('no'))  # If the value is null, return null # Otherwise, check if name is in array
    ).withColumn(
        "NdiagonalYes",
        F.when((F.col("LoF_protect_ch").isNotNull() & (F.col('GoF_protect_ch').isNull())) & 
            (F.size(F.array_intersect(F.col("maxDoE_names"), F.array([F.lit(x) for x in diagonal_lof]))) > 0),
            F.lit("yes")
        ).when((F.col("GoF_protect_ch").isNotNull() & (F.col('LoF_protect_ch').isNull())) & 
            (F.size(F.array_intersect(F.col("maxDoE_names"), F.array([F.lit(x) for x in diagonal_gof]))) > 0),
            F.lit("yes")
        ).otherwise(F.lit('no'))
    ).withColumn(
        "drugCoherency",
        F.when(
            (F.col("LoF_protect_ch").isNotNull())
            & (F.col("GoF_protect_ch").isNull()), F.lit("coherent")
        )
        .when(
            (F.col("LoF_protect_ch").isNull())
            & (F.col("GoF_protect_ch").isNotNull()), F.lit("coherent")
        )
        .when(
            (F.col("LoF_protect_ch").isNotNull())
            & (F.col("GoF_protect_ch").isNotNull()), F.lit("dispar")
        )
        .otherwise(F.lit("other")),
    ).join(negativeTD, on=["targetId", "diseaseId"], how="left").withColumn(
        "PhaseT",
        F.when(F.col("stopReason") == "Negative", F.lit("yes")).otherwise(F.lit("no")),
    ).withColumn(
        "phase4Clean",
        F.when(
            (F.col("maxClinPhase") == 4) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "phase3Clean",
        F.when(
            (F.col("maxClinPhase") >= 3) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "phase2Clean",
        F.when(
            (F.col("maxClinPhase") >= 2) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "phase1Clean",
        F.when(
            (F.col("maxClinPhase") >= 1) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "hasGenetics",
        F.when(F.col("coherencyDiagonal") != "noEvid", F.lit("yes")).otherwise(F.lit("no")),
    )
        # 1. Get distinct values for the pivot column (essential for pivot())
        # This brings a small amount of data to the driver, but is necessary for the pivot schema.
        #distinct_pivot_values = [row[0] for row in test2.select(pivot_col_name).distinct().collect()]
        # print(f"Distinct values for '{pivot_col_name}': {distinct_pivot_values}")

        # 2. Perform the groupBy, pivot, and aggregate operations
        # The .pivot() function requires the list of distinct values for better performance
        # and correct schema inference.
        pivoted_df = (
            test2.groupBy(*group_by_columns)
            .pivot(pivot_col_name) # Provide distinct values distinct_pivot_values
            .agg(F.collect_set(F.col(agg_col_name))) # Collect all values into a set
            .fillna(0) # Fill cells that have no data with an empty list instead of null
        )
        # 3. Add items to dictionary to map the columns:
        # filter out None and 'null':
        datasetColumns=pivoted_df.columns
        filtered = [x for x in datasetColumns if x is not None and x != 'null']
        # using list comprehension
        for item in filtered:
            disdic[item] = pivot_col_name

        # 3. Add the 'data' literal column dynamically
        # This column indicates which aggregation column was used.
        #pivoted_df = pivoted_df.withColumn('data', F.lit(f'Drug_{agg_col_name}'))

        array_columns_to_convert = [
            field.name for field in pivoted_df.schema.fields
            if isinstance(field.dataType, ArrayType)
        ]
        print(f"Identified ArrayType columns for conversion: {array_columns_to_convert}")

        # 4. Apply the conversion logic to each identified array column
        df_after_conversion = pivoted_df # Start with the pivoted_df
        for col_to_convert in array_columns_to_convert:
            df_after_conversion = df_after_conversion.withColumn(
                col_to_convert,
                F.when(F.col(col_to_convert).isNull(), F.lit('no'))          # Handle NULLs (from pivot for no data)
                .when(F.size(F.col(col_to_convert)) == 0, F.lit('no'))       # Empty array -> 'no'
                .when(F.array_contains(F.col(col_to_convert), F.lit('yes')), F.lit('yes')) # Contains 'yes' -> 'yes'
                .when(F.array_contains(F.col(col_to_convert), F.lit('no')), F.lit('no'))   # Contains 'no' -> 'no'
                .otherwise(F.lit('no')) # Fallback for unexpected array content (e.g., ['other'], ['yes','no'])
            )

        # 4. Generate a unique name for this DataFrame and store it
        df_key = f"df_pivot_{agg_col_name.lower()}_by_{pivot_col_name.lower()}"
        all_pivoted_dfs[df_key] = df_after_conversion.withColumnRenamed( 'phase4Clean','Phase>=4'
        ).withColumnRenamed('phase3Clean','Phase>=3'
        ).withColumnRenamed('phase2Clean','Phase>=2'
        ).withColumnRenamed('phase1Clean','Phase>=1')


# --- Accessing your generated DataFrames ---
print("\n--- All generated DataFrames are stored in 'all_pivoted_dfs' dictionary ---")
print("Keys available:", all_pivoted_dfs.keys())

spark session created at 2025-09-18 06:58:53.503052
Analysis started on 2025-09-18 at  2025-09-18 06:58:53.503052


25/09/18 06:58:58 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
25/09/18 06:58:58 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


SparkSession created successfully with the following configurations:
  spark.driver.memory: 24g
  spark.executor.memory: 32g
  spark.executor.cores: 4
  spark.executor.instances: 12
  spark.yarn.executor.memoryOverhead: 8g
  spark.sql.shuffle.partitions: 192
  spark.default.parallelism: 192
Spark UI available at: http://jr-doe-temp1-m.c.open-targets-eu-dev.internal:38947


                                                                                

loaded files
loaded newColoc


                                                                                

loaded gwasComplete
loaded resolvedColloc
run temporary direction of effect
built drugApproved dataset


                                                                                

load comparisons_df_iterative function
created full_data and lists
loaded rightTissue dataset
built negativeTD dataset
built bench2 dataset
looping for variables_study
entering the big loops

--- Creating DataFrame for Aggregation: 'NoneCellYes' and Pivot: 'projectId' ---


                                                                                

Identified ArrayType columns for conversion: ['null', 'Alasoo_2018', 'Aygun_2021', 'BLUEPRINT', 'Bossini-Castillo_2019', 'BrainSeq', 'Braineac2', 'CAP', 'CEDAR', 'CommonMind', 'Cytoimmgen', 'FUSION', 'Fairfax_2012', 'Fairfax_2014', 'GENCORD', 'GEUVADIS', 'GTEx', 'Gilchrist_2021', 'HipSci', 'Jerber_2021', 'Kasela_2017', 'Kim-Hellmuth_2017', 'Lepik_2017', 'Naranbhai_2015', 'Nathan_2022', 'Nedelec_2016', 'OneK1K', 'PISA', 'Peng_2018', 'Perez_2022', 'PhLiPS', 'Quach_2016', 'ROSMAP', 'Randolph_2021', 'Schmiedel_2018', 'Schwartzentruber_2018', 'Steinberg_2020', 'Sun_2018', 'TwinsUK', 'UKB_PPP_EUR', 'Walker_2019', 'Young_2019', 'iPSCORE', 'van_de_Bunt_2015']

--- Creating DataFrame for Aggregation: 'NdiagonalYes' and Pivot: 'projectId' ---


25/09/18 07:01:57 WARN CacheManager: Asked to cache already cached data.        


Identified ArrayType columns for conversion: ['null', 'Alasoo_2018', 'Aygun_2021', 'BLUEPRINT', 'Bossini-Castillo_2019', 'BrainSeq', 'Braineac2', 'CAP', 'CEDAR', 'CommonMind', 'Cytoimmgen', 'FUSION', 'Fairfax_2012', 'Fairfax_2014', 'GENCORD', 'GEUVADIS', 'GTEx', 'Gilchrist_2021', 'HipSci', 'Jerber_2021', 'Kasela_2017', 'Kim-Hellmuth_2017', 'Lepik_2017', 'Naranbhai_2015', 'Nathan_2022', 'Nedelec_2016', 'OneK1K', 'PISA', 'Peng_2018', 'Perez_2022', 'PhLiPS', 'Quach_2016', 'ROSMAP', 'Randolph_2021', 'Schmiedel_2018', 'Schwartzentruber_2018', 'Steinberg_2020', 'Sun_2018', 'TwinsUK', 'UKB_PPP_EUR', 'Walker_2019', 'Young_2019', 'iPSCORE', 'van_de_Bunt_2015']

--- Creating DataFrame for Aggregation: 'hasGenetics' and Pivot: 'projectId' ---


25/09/18 07:02:31 WARN CacheManager: Asked to cache already cached data.        


Identified ArrayType columns for conversion: ['null', 'Alasoo_2018', 'Aygun_2021', 'BLUEPRINT', 'Bossini-Castillo_2019', 'BrainSeq', 'Braineac2', 'CAP', 'CEDAR', 'CommonMind', 'Cytoimmgen', 'FUSION', 'Fairfax_2012', 'Fairfax_2014', 'GENCORD', 'GEUVADIS', 'GTEx', 'Gilchrist_2021', 'HipSci', 'Jerber_2021', 'Kasela_2017', 'Kim-Hellmuth_2017', 'Lepik_2017', 'Naranbhai_2015', 'Nathan_2022', 'Nedelec_2016', 'OneK1K', 'PISA', 'Peng_2018', 'Perez_2022', 'PhLiPS', 'Quach_2016', 'ROSMAP', 'Randolph_2021', 'Schmiedel_2018', 'Schwartzentruber_2018', 'Steinberg_2020', 'Sun_2018', 'TwinsUK', 'UKB_PPP_EUR', 'Walker_2019', 'Young_2019', 'iPSCORE', 'van_de_Bunt_2015']

--- All generated DataFrames are stored in 'all_pivoted_dfs' dictionary ---
Keys available: dict_keys(['df_pivot_nonecellyes_by_projectid', 'df_pivot_ndiagonalyes_by_projectid', 'df_pivot_hasgenetics_by_projectid'])


In [4]:
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].show()

+---------------+--------------+--------+--------+--------+--------+------+----+-----------+----------+---------+---------------------+--------+---------+---+-----+----------+----------+------+------------+------------+-------+--------+----+--------------+------+-----------+-----------+-----------------+----------+--------------+-----------+------------+------+----+---------+----------+------+----------+------+-------------+--------------+---------------------+--------------+--------+-------+-----------+-----------+----------+-------+----------------+--------------------+----------------+--------------+-----------------+--------------------+---------------+------------------+-------+
|       targetId|     diseaseId|Phase>=4|Phase>=3|Phase>=2|Phase>=1|PhaseT|null|Alasoo_2018|Aygun_2021|BLUEPRINT|Bossini-Castillo_2019|BrainSeq|Braineac2|CAP|CEDAR|CommonMind|Cytoimmgen|FUSION|Fairfax_2012|Fairfax_2014|GENCORD|GEUVADIS|GTEx|Gilchrist_2021|HipSci|Jerber_2021|Kasela_2017|Kim-Hellmuth_2017|Le

In [3]:
##### PROJECTID
project_keys = (
    benchmark
    .select("projectId")
    .distinct()
    .rdd
    .map(lambda r: r[0])
    .filter(lambda x: x is not None)  # <- remove NULLs
    .collect()
)
#project_keys=[f"{k}_only" for k,v in disdic.items() if v == 'projectId']
main=['GTEx_only', 'UKB_PPP_EUR_only']
#stimulated=['Alasoo_2018_only','Cytoimmgen_only','Fairfax_2014_only','Kim-Hellmuth_2017_only','Nathan_2022_only','Nedelec_2016_only','Quach_2016_only','Randolph_2021_only','Schmiedel_2018_only']
#cellLine=['CAP_only','HipSci_only','iPSCORE_only','Jerber_2021_only','PhLiPS_only','Schwartzentruber_2018_only','TwinsUK_only']
derivedCellLine=['TwinsUK_only','PhLiPS_only','CAP_only','GENCORD_only','Sun_2018_only','Nedelec_2016_only']
canonicalCellLine=['Alasoo_2018_only','Jerber_2021_only','GEUVADIS_only','iPSCORE_only','Aygun_2021_only','Schwartzentruber_2018_only']
stimulated=['Schmiedel_2018_only','Bossini-Castillo_2019_only','Alasoo_2018_only','Cytoimmgen_only','Gilchrist_2021_only','CAP_only','Quach_2016_only','Randolph_2021_only','Sun_2018_only','Nedelec_2016_only','Kim-Hellmuth_2017_only']

def strip_only(lst):
    return [x.removesuffix("_only") for x in lst]  # Python 3.9+
    # or: return [x[:-5] if x.endswith("_only") else x for x in lst]

# Apply
main = strip_only(main)
canonicalCellLine = strip_only(canonicalCellLine)
derivedCellLine = strip_only(derivedCellLine)
stimulated = strip_only(stimulated)

others=[item for item in project_keys if item not in main]
nonStimulated=[item for item in project_keys if item not in stimulated]
nonCanonicalCellLine = [item for item in project_keys if item not in canonicalCellLine]
nonDerivedCellLine = [item for item in project_keys if item not in derivedCellLine]


# First condition: any "yes" in list1
# others
condition1 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), others[1:], F.col(others[0]) == "yes")
# estimulated
condition2 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), stimulated[1:], F.col(stimulated[0]) == "yes")
## non estimulated:
condition3 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), nonStimulated[1:], F.col(nonStimulated[0]) == "yes")
# canonical cellLine
condition4 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), canonicalCellLine[1:], F.col(canonicalCellLine[0]) == "yes")
# non canonical cellline
condition5 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), nonCanonicalCellLine[1:], F.col(nonCanonicalCellLine[0]) == "yes")
# derived cell line 
condition6 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), derivedCellLine[1:], F.col(derivedCellLine[0]) == "yes")
# non derived cellline
condition7 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), nonDerivedCellLine[1:], F.col(nonDerivedCellLine[0]) == "yes")
# mainprojects
condition8 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), main[1:], F.col(main[0]) == "yes")


# Add both columns
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].withColumn("othersProjectId_only", F.when(condition1, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].withColumn("estimulated_only", F.when(condition2, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].withColumn("nonEstimulated", F.when(condition3, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].withColumn("canonicalCellLine", F.when(condition4, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].withColumn("nonCanonicalCellLine", F.when(condition5, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].withColumn("derivedCellLine", F.when(condition6, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].withColumn("nonDerivedCellLine", F.when(condition7, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].withColumn("GTExUKB", F.when(condition8, "yes").otherwise("no")) 

# Add both columns
all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'].withColumn("othersProjectId_only", F.when(condition1, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'].withColumn("estimulated_only", F.when(condition2, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'].withColumn("nonEstimulated", F.when(condition3, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'].withColumn("canonicalCellLine", F.when(condition4, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'].withColumn("nonCanonicalCellLine", F.when(condition5, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'].withColumn("derivedCellLine", F.when(condition6, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'].withColumn("nonDerivedCellLine", F.when(condition7, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'].withColumn("GTExUKB", F.when(condition8, "yes").otherwise("no")) 

# Add both columns
all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'].withColumn("othersProjectId_only", F.when(condition1, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'].withColumn("estimulated_only", F.when(condition2, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'].withColumn("nonEstimulated", F.when(condition3, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'].withColumn("canonicalCellLine", F.when(condition4, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'].withColumn("nonCanonicalCellLine", F.when(condition5, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'].withColumn("derivedCellLine", F.when(condition6, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'].withColumn("nonDerivedCellLine", F.when(condition7, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'].withColumn("GTExUKB", F.when(condition8, "yes").otherwise("no")) 


# If you wanted to apply to every DF in the dict (only if they all share *_only columns):
# for k, df in all_pivoted_dfs.items():
#     all_pivoted_dfs[k] = add_project_group_flags(df, main, stimulated, cellLine)

###append to dictionary

disdic.update({'othersProjectId': 'projectId','Stimulated': 'projectId','cellLine': 'projectId', 'othersBiosampleName_only': 'biosampleName', 'otherRightStudyType':'rightStudyType'})

###################################
###################################
result = []
result_st = []
result_ci = []
array2 = []
listado = []
result_all = []
today_date = str(date.today())

for key,df in all_pivoted_dfs.items():

    print(f'working with {key}')
    parts = key.split('_by_') ### take the part of key belonging to column name
    column_name = parts[1] ### take the last part which is column name
    all_pivoted_dfs[key].persist()
    #unique_values = all_pivoted_dfs[key].drop('null').columns[7:]
    unique_values = all_pivoted_dfs[key].drop('null').columns[-8:] ### just the interesting columns for us 
    filtered_unique_values = [x for x in unique_values if x is not None and x != 'null']
    print('There are ', len(filtered_unique_values), 'columns to analyse with phases')
    rows = comparisons_df_iterative(filtered_unique_values)

    # If needed, now process the rest
    for row in rows:
        print('performing', row)
        results = aggregations_original(
            all_pivoted_dfs[key], key, listado, *row, today_date
        )
        result_all.append(results)
        print('results appended')
    all_pivoted_dfs[key].unpersist()
    print('df unpersisted')


schema = StructType(
    [
        StructField("group", StringType(), True),
        StructField("comparison", StringType(), True),
        StructField("phase", StringType(), True),
        StructField("oddsRatio", DoubleType(), True),
        StructField("pValue", DoubleType(), True),
        StructField("lowerInterval", DoubleType(), True),
        StructField("upperInterval", DoubleType(), True),
        StructField("total", StringType(), True),
        StructField("values", ArrayType(ArrayType(IntegerType())), True),
        StructField("relSuccess", DoubleType(), True),
        StructField("rsLower", DoubleType(), True),
        StructField("rsUpper", DoubleType(), True),
        StructField("path", StringType(), True),
    ]
)
import re

# Define the list of patterns to search for
patterns = [
    "_only",
    #"_tissue",
    #"_isSignalFromRightTissue",
    "_isRightTissueSignalAgreed",
]
# Create a regex pattern to match any of the substrings
regex_pattern = "(" + "|".join(map(re.escape, patterns)) + ")"

# Convert list of lists to DataFrame
df = (
    spreadSheetFormatter(spark.createDataFrame(result_all, schema=schema))
    .withColumn(
        "prefix",
        F.regexp_replace(
            F.col("comparison"), regex_pattern + ".*", ""
        ),  # Extract part before the pattern
    )
    .withColumn(
        "suffix",
        F.regexp_extract(
            F.col("comparison"), regex_pattern, 0
        ),  # Extract the pattern itself
    )
)

### annotate projectId, tissue, qtl type and doe type:

from pyspark.sql.functions import create_map
from itertools import chain

mapping_expr=create_map([F.lit(x) for x in chain(*disdic.items())])

df_annot=df.withColumn('annotation',mapping_expr.getItem(F.col('prefix')))

df_annot.toPandas().to_csv(
    f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_AllPhasesMixtures3.csv"
)

print("dataframe written \n Analysis finished")

                                                                                

working with df_pivot_nonecellyes_by_projectid
There are  8 columns to analyse with phases


                                                                                

performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=4', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_nonecellyes_by_projectid/othersProjectId_only_predictor_Phase>=4.parquet


25/09/18 07:12:17 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_595_84 !
25/09/18 07:12:17 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_155_27 !
25/09/18 07:12:17 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_260_43 !
25/09/18 07:12:17 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_38_96 !
25/09/18 07:12:17 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_71_11 !
25/09/18 07:12:17 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_595_43 !
25/09/18 07:12:17 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_38_37 !
25/09/18 07:12:17 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_155_133 !
25/09/18 07:12:17 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_260_122 !
25/09/18 07:12:17 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_260_11 !
25/09/18 07:12:17 WARN BlockManagerMasterEndpoint: 

results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=3', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_nonecellyes_by_projectid/othersProjectId_only_predictor_Phase>=3.parquet


                                                                                

results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=2', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_nonecellyes_by_projectid/othersProjectId_only_predictor_Phase>=2.parquet


                                                                                

results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=1', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_nonecellyes_by_projectid/othersProjectId_only_predictor_Phase>=1.parquet
results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='PhaseT', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_nonecellyes_by_projectid/othersProjectId_only_predictor_PhaseT.parquet
results appended
performing Row(comparison='estimulated_only', comparisonType='predictor', _1='Phase>=4', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_nonecellyes_by_projectid/estimulated_only_predictor_Phase>=4.parquet
results appended
performing Row(comparison='estimulated_only', comparisonType='predictor', _1='Phase>=3', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_nonecellyes_by_projectid/estimulated_only_predictor_Phase>=3.parquet
results appended
performing 

                                                                                

results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=3', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_ndiagonalyes_by_projectid/othersProjectId_only_predictor_Phase>=3.parquet
results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=2', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_ndiagonalyes_by_projectid/othersProjectId_only_predictor_Phase>=2.parquet
results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=1', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_ndiagonalyes_by_projectid/othersProjectId_only_predictor_Phase>=1.parquet
results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='PhaseT', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_ndiagonalyes_by_projectid/othersProjectId_only_predictor_PhaseT.parquet
results 

                                                                                

results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=3', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_hasgenetics_by_projectid/othersProjectId_only_predictor_Phase>=3.parquet
results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=2', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_hasgenetics_by_projectid/othersProjectId_only_predictor_Phase>=2.parquet
results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='Phase>=1', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_hasgenetics_by_projectid/othersProjectId_only_predictor_Phase>=1.parquet
results appended
performing Row(comparison='othersProjectId_only', comparisonType='predictor', _1='PhaseT', _2='clinical')
gs://ot-team/jroldan/2025-09-18_analysis/df_pivot_hasgenetics_by_projectid/othersProjectId_only_predictor_PhaseT.parquet
results appe



dataframe written 
 Analysis finished


In [40]:
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].columns[-8:]

['nonStimulated',
 'canonicalCellLine',
 'nonCanonicalCellLine',
 'derivedCellLine',
 'nonDerivedCellLine',
 'estimulated_only',
 'nonEstimulated',
 'GTExUKB']

#### I put the code here to notice that is the one working for studiesJoint, is the same as above, but altogether in a chunk:

In [None]:
import time
from array import ArrayType
from functions import (
    relative_success,
    spreadSheetFormatter,
    discrepancifier,
    temporary_directionOfEffect,
    buildColocData,
    gwasDataset,
)
# from stoppedTrials import terminated_td
from DoEAssessment import directionOfEffect
# from membraneTargets import target_membrane
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
from datetime import datetime
from datetime import date
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.types import (
    StructType,
    StructField,
    DoubleType,
    DecimalType,
    StringType,
    FloatType,
)
import pandas as pd
from functools import reduce


# --- Build the SparkSession ---
# Use the .config() method to set these parameters before calling .getOrCreate()
# This ensures Spark requests the correct resources from YARN at the start.
driver_memory = "24g"                 # plenty for planning & small collects
executor_cores = 4                    # sweet spot for GC + Python workers
num_executors  = 12                   # 12 * 4 = 48 cores for executors; ~16 cores left for driver/OS
executor_memory = "32g"               # per executor heap
executor_memory_overhead = "8g"       # ~20% overhead for PySpark/Arrow/off-heap
# Totals: (32+8) * 12 = 480 GB executors + 24 GB driver ≈ 504 GB (adjust down if your hard cap is <500 GB)
# If you must stay strictly ≤ 500 GB, use executor_memory="30g", overhead="6g"  → (36 * 12) + 24 = 456 + 24 = 480 GB

shuffle_partitions   = 192            # ≈ 2–4× total cores (48) → start with 192
default_parallelism  = 192

spark = SparkSession.builder \
    .appName("MyOptimizedPySparkApp") \
    .config("spark.master", "yarn") \
    .config("spark.driver.memory", driver_memory) \
    .config("spark.executor.memory", executor_memory) \
    .config("spark.executor.cores", executor_cores) \
    .config("spark.executor.instances", num_executors) \
    .config("spark.yarn.executor.memoryOverhead", executor_memory_overhead) \
    .config("spark.sql.shuffle.partitions", shuffle_partitions) \
    .config("spark.default.parallelism", default_parallelism) \
    .getOrCreate()

print(f"SparkSession created successfully with the following configurations:")
print(f"  spark.driver.memory: {spark.conf.get('spark.driver.memory')}")
print(f"  spark.executor.memory: {spark.conf.get('spark.executor.memory')}")
print(f"  spark.executor.cores: {spark.conf.get('spark.executor.cores')}")
print(f"  spark.executor.instances: {spark.conf.get('spark.executor.instances')}")
print(f"  spark.yarn.executor.memoryOverhead: {spark.conf.get('spark.yarn.executor.memoryOverhead')}")
print(f"  spark.sql.shuffle.partitions: {spark.conf.get('spark.sql.shuffle.partitions')}")
print(f"  spark.default.parallelism: {spark.conf.get('spark.default.parallelism')}")
print(f"Spark UI available at: {spark.sparkContext.uiWebUrl}")

# --- Your PySpark Code Here ---
# Now you can proceed with your data loading and processing.
# Example:
# df = spark.read.parquet("hdfs:///user/your_user/your_large_data.parquet")
# print(f"Number of rows in DataFrame: {df.count()}")
# df.groupBy("some_column").agg({"another_column": "sum"}).show()

# Remember to stop the SparkSession when you are done
# spark.stop()

path_n='gs://open-targets-data-releases/25.06/output/'

target = spark.read.parquet(f"{path_n}target/")

diseases = spark.read.parquet(f"{path_n}disease/")

evidences = spark.read.parquet(f"{path_n}evidence")

credible = spark.read.parquet(f"{path_n}credible_set")

new = spark.read.parquet(f"{path_n}colocalisation_coloc") 

index=spark.read.parquet(f"{path_n}study/")

variantIndex = spark.read.parquet(f"{path_n}variant")

biosample = spark.read.parquet(f"{path_n}biosample")

ecaviar=spark.read.parquet(f"{path_n}colocalisation_ecaviar")

all_coloc=ecaviar.unionByName(new, allowMissingColumns=True)

print("loaded files")

#### FIRST MODULE: BUILDING COLOC 
newColoc=buildColocData(all_coloc,credible,index)

print("loaded newColoc")

### SECOND MODULE: PROCESS EVIDENCES TO AVOID EXCESS OF COLUMNS 
gwasComplete = gwasDataset(evidences,credible)

#### THIRD MODULE: INCLUDE COLOC IN THE 
resolvedColoc = (
    (
        newColoc.withColumnRenamed("geneId", "targetId")
        .join(
            gwasComplete.withColumnRenamed("studyLocusId", "leftStudyLocusId"),
            on=["leftStudyLocusId", "targetId"],
            how="inner",
        )
        .join(  ### propagated using parent terms
            diseases.selectExpr(
                "id as diseaseId", "name", "parents", "therapeuticAreas"
            ),
            on="diseaseId",
            how="left",
        )
        .withColumn(
            "diseaseId",
            F.explode_outer(F.concat(F.array(F.col("diseaseId")), F.col("parents"))),
        )
        .drop("parents", "oldDiseaseId")
    ).withColumn(
        "colocDoE",
        F.when(
            F.col("rightStudyType").isin(
                ["eqtl", "pqtl", "tuqtl", "sceqtl", "sctuqtl"]
            ),
            F.when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("GoF_risk"),
            )
            .when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("LoF_risk"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("LoF_protect"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("GoF_protect"),
            ),
        ).when(
            F.col("rightStudyType").isin(
                ["sqtl", "scsqtl"]
            ),  ### opposite directionality than sqtl
            F.when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("LoF_risk"),
            )
            .when(
                (F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("GoF_risk"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0),
                F.lit("GoF_protect"),
            )
            .when(
                (F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0),
                F.lit("LoF_protect"),
            ),
        ),
    )
    # .persist()
)
print("loaded resolvedColloc")

datasource_filter = [
#   "ot_genetics_portal",
    "gwas_credible_sets",
    "gene_burden",
    "eva",
    "eva_somatic",
    "gene2phenotype",
    "orphanet",
    "cancer_gene_census",
    "intogen",
    "impc",
    "chembl",
]

assessment, evidences, actionType, oncolabel = temporary_directionOfEffect(
    path_n, datasource_filter
)

print("run temporary direction of effect")


print("built drugApproved dataset")


#### FOURTH MODULE BUILDING CHEMBL ASSOCIATIONS - HERE TAKE CARE WITH FILTERING STEP 
analysis_chembl_indication = (
    discrepancifier(
        assessment.filter((F.col("datasourceId") == "chembl"))
        .withColumn(
            "maxClinPhase",
            F.max(F.col("clinicalPhase")).over(
                Window.partitionBy("targetId", "diseaseId")
            ),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase")
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    #.filter(F.col("coherencyDiagonal") == "coherent")
    .drop(
        "coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk"
    )
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
    # .persist()
)

####2 Define agregation function
import pandas as pd
import numpy as np
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio
from pyspark.sql.types import *


def convertTuple(tup):
    st = ",".join(map(str, tup))
    return st


#####3 run in a function
def aggregations_original(
    df,
    data,
    listado,
    comparisonColumn,
    comparisonType,
    predictionColumn,
    predictionType,
    today_date,
):
    wComparison = Window.partitionBy(comparisonColumn)
    wPrediction = Window.partitionBy(predictionColumn)
    wPredictionComparison = Window.partitionBy(comparisonColumn, predictionColumn)
    results = []
    # uniqIds = df.select("targetId", "diseaseId").distinct().count()
    out = (
        df.withColumn("comparisonType", F.lit(comparisonType))
        .withColumn("dataset", F.lit(data))
        .withColumn("predictionType", F.lit(predictionType))
        # .withColumn("total", F.lit(uniqIds))
        .withColumn("a", F.count("targetId").over(wPredictionComparison))
        .withColumn("comparisonColumn", F.lit(comparisonColumn))
        .withColumn("predictionColumnValue", F.lit(predictionColumn))
        .withColumn(
            "predictionTotal",
            F.count("targetId").over(wPrediction),
        )
        .withColumn(
            "comparisonTotal",
            F.count("targetId").over(wComparison),
        )
        .select(
            F.col(predictionColumn).alias("prediction"),
            F.col(comparisonColumn).alias("comparison"),
            "dataset",
            "comparisonColumn",
            "predictionColumnValue",
            "comparisonType",
            "predictionType",
            "a",
            "predictionTotal",
            "comparisonTotal",
        )
        .filter(F.col("prediction").isNotNull())
        .filter(F.col("comparison").isNotNull())
        .distinct()
    )
    '''
    out.write.mode("overwrite").parquet(
        "gs://ot-team/jroldan/"
        + str(
            today_date
            + "_"
            + "analysis/"
            + data
            # + "_propagated"
            + "/"
            + comparisonColumn
            + "_"
            + comparisonType
            + "_"
            + predictionColumn
            + ".parquet"
        )
    )
    '''

    listado.append(
        "gs://ot-team/jroldan/"
        + str(
            today_date
            + "_"
            + "analysis/"
            + data
            # + "_propagated"
            + "/"
            + comparisonColumn
            + "_"
            + comparisonType
            + "_"
            + predictionColumn
            + ".parquet"
        )
    )
    path = "gs://ot-team/jroldan/" + str(
        today_date
        + "_"
        + "analysis/"
        + data
        # + "_propagated"
        + "/"
        + comparisonColumn
        + "_"
        + comparisonType
        + "_"
        + predictionColumn
        + ".parquet"
    )
    print(path)
    
    ### making analysis
    array1 = np.delete(
        out.join(full_data, on=["prediction", "comparison"], how="outer")
        .groupBy("comparison")
        .pivot("prediction")
        .agg(F.first("a"))
        .sort(F.col("comparison").desc())
        .select("comparison", "yes", "no")
        .fillna(0)
        .toPandas()
        .to_numpy(),
        [0],
        1,
    )
    total = np.sum(array1)
    res_npPhaseX = np.array(array1, dtype=int)
    resX = convertTuple(fisher_exact(res_npPhaseX, alternative="two-sided"))
    resx_CI = convertTuple(
        odds_ratio(res_npPhaseX).confidence_interval(confidence_level=0.95)
    )

    result_st.append(resX)
    result_ci.append(resx_CI)
    (rs_result, rs_ci) = relative_success(array1)
    results.extend(
        [
            comparisonType,
            comparisonColumn,
            predictionColumn,
            round(float(resX.split(",")[0]), 2),
            float(resX.split(",")[1]),
            round(float(resx_CI.split(",")[0]), 2),
            round(float(resx_CI.split(",")[1]), 2),
            str(total),
            np.array(res_npPhaseX).tolist(),
            round(float(rs_result), 2),
            round(float(rs_ci[0]), 2),
            round(float(rs_ci[1]), 2),
            # studies,
            # tissues,
            path,
        ]
    )
    return results


#### 3 Loop over different datasets (as they will have different rows and columns)


def comparisons_df_iterative(elements):
    #toAnalysis = [(key, value) for key, value in disdic.items() if value == projectId]
    toAnalysis = [(col, "predictor") for col in elements]
    schema = StructType(
        [
            StructField("comparison", StringType(), True),
            StructField("comparisonType", StringType(), True),
        ]
    )

    comparisons = spark.createDataFrame(toAnalysis, schema=schema)
    ### include all the columns as predictor

    predictions = spark.createDataFrame(
        data=[
            ("Phase>=4", "clinical"),
            ('Phase>=3','clinical'),
            ('Phase>=2','clinical'),
            ('Phase>=1','clinical'),
            ("PhaseT", "clinical"),
        ]
    )
    return comparisons.join(predictions, how="full").collect()


print("load comparisons_df_iterative function")


full_data = spark.createDataFrame(
    data=[
        ("yes", "yes"),
        ("yes", "no"),
        ("no", "yes"),
        ("no", "no"),
    ],
    schema=StructType(
        [
            StructField("prediction", StringType(), True),
            StructField("comparison", StringType(), True),
        ]
    ),
)
print("created full_data and lists")

#rightTissue = spark.read.csv(
#    'gs://ot-team/jroldan/analysis/20250526_rightTissue.csv',
#    header=True,
#).drop("_c0")

print("loaded rightTissue dataset")

negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .groupBy("targetId", "diseaseId")
    .count()
    .withColumn("stopReason", F.lit("Negative"))
    .drop("count")
)

print("built negativeTD dataset")

print("built bench2 dataset")

###### cut from here
print("looping for variables_study")

#### new part with chatgpt -- TEST

## QUESTIONS TO ANSWER:
# HAVE ECAVIAR >=0.8
# HAVE COLOC 
# HAVE COLOC >= 0.8
# HAVE COLOC + ECAVIAR >= 0.01
# HAVE COLOC >= 0.8 + ECAVIAR >= 0.01
# RIGHT JOING WITH CHEMBL 

### FIFTH MODULE: BUILDING BENCHMARK OF THE DATASET TO EXTRACT EHE ANALYSIS 

resolvedColocFiltered = resolvedColoc.filter((F.col('clpp')>=0.01) | (F.col('h4')>=0.8))
benchmark = (
    (
        resolvedColocFiltered.filter( ## .filter(F.col("betaGwas") < 0)
        F.col("name") != "COVID-19"
    )
        .join(  ### select just GWAS giving protection
            analysis_chembl_indication, on=["targetId", "diseaseId"], how="right"  ### RIGHT SIDE
        )
        .withColumn(
            "AgreeDrug",
            F.when(
                (F.col("drugGoF_protect").isNotNull())
                & (F.col("colocDoE") == "GoF_protect"),
                F.lit("yes"),
            )
            .when(
                (F.col("drugLoF_protect").isNotNull())
                & (F.col("colocDoE") == "LoF_protect"),
                F.lit("yes"),
            )
            .otherwise(F.lit("no")),
        )
    )  #### remove COVID-19 associations
).join(biosample.select("biosampleId", "biosampleName"), on="biosampleId", how="left")


### drug mechanism of action
mecact_path = f"{path_n}drug_mechanism_of_action/" #  mechanismOfAction == old version
mecact = spark.read.parquet(mecact_path)

inhibitors = [
    "RNAI INHIBITOR",
    "NEGATIVE MODULATOR",
    "NEGATIVE ALLOSTERIC MODULATOR",
    "ANTAGONIST",
    "ANTISENSE INHIBITOR",
    "BLOCKER",
    "INHIBITOR",
    "DEGRADER",
    "INVERSE AGONIST",
    "ALLOSTERIC ANTAGONIST",
    "DISRUPTING AGENT",
]

activators = [
    "PARTIAL AGONIST",
    "ACTIVATOR",
    "POSITIVE ALLOSTERIC MODULATOR",
    "POSITIVE MODULATOR",
    "AGONIST",
    "SEQUESTERING AGENT",  ## lost at 31.01.2025
    "STABILISER",
    # "EXOGENOUS GENE", ## added 24.06.2025
    # "EXOGENOUS PROTEIN" ## added 24.06.2025
]


actionType = (
        mecact.select(
            F.explode_outer("chemblIds").alias("drugId"),
            "actionType",
            "mechanismOfAction",
            "targets",
        )
        .select(
            F.explode_outer("targets").alias("targetId"),
            "drugId",
            "actionType",
            "mechanismOfAction",
        )
        .groupBy("targetId", "drugId")
        .agg(F.collect_set("actionType").alias("actionType2"))
    ).withColumn('nMoA', F.size(F.col('actionType2')))

analysis_chembl_indication = (
    discrepancifier(
        assessment.filter((F.col("datasourceId") == "chembl")).join(actionType, on=['targetId','drugId'], how='left')
        .withColumn(
            "maxClinPhase",
            F.max(F.col("clinicalPhase")).over(
                Window.partitionBy("targetId", "diseaseId")
            ),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase",'actionType2')
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    #.filter(F.col("coherencyDiagonal") == "coherent")
    .drop(
        "coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk"
    )
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
)

benchmark = (
    (
        resolvedColocFiltered.filter( ## .filter(F.col("betaGwas") < 0)
        F.col("name") != "COVID-19"
    )
        .join(  ### select just GWAS giving protection
            analysis_chembl_indication, on=["targetId", "diseaseId"], how="right"  ### RIGHT SIDE
        )
        .withColumn(
            "AgreeDrug",
            F.when(
                (F.col("drugGoF_protect").isNotNull())
                & (F.col("colocDoE") == "GoF_protect"),
                F.lit("yes"),
            )
            .when(
                (F.col("drugLoF_protect").isNotNull())
                & (F.col("colocDoE") == "LoF_protect"),
                F.lit("yes"),
            )
            .otherwise(F.lit("no")),
        )
    )  #### remove COVID-19 associations
).join(biosample.select("biosampleId", "biosampleName"), on="biosampleId", how="left")

negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .groupBy("targetId", "diseaseId")
    .count()
    .withColumn("stopReason", F.lit("Negative"))
    .drop("count")
)

### create disdic dictionary
disdic={}

# --- Configuration for your iterative pivoting ---
group_by_columns = ['targetId', 'diseaseId','phase4Clean','phase3Clean','phase2Clean','phase1Clean','PhaseT']
#columns_to_pivot_on = ['actionType2', 'biosampleName', 'projectId', 'rightStudyType','colocalisationMethod']
columns_to_pivot_on = ['projectId']
columns_to_aggregate = ['NoneCellYes', 'NdiagonalYes','hasGenetics'] # The values you want to collect in the pivoted cells
all_pivoted_dfs = {}

doe_columns=["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]
diagonal_lof=['LoF_protect','GoF_risk']
diagonal_gof=['LoF_risk','GoF_protect']

conditions = [
    F.when(F.col(c) == F.col("maxDoE"), F.lit(c)).otherwise(F.lit(None)) for c in doe_columns
    ]
print('entering the big loops')
# --- Nested Loops for Dynamic Pivoting ---
for agg_col_name in columns_to_aggregate:
    for pivot_col_name in columns_to_pivot_on:
        print(f"\n--- Creating DataFrame for Aggregation: '{agg_col_name}' and Pivot: '{pivot_col_name}' ---")
        current_col_pvalue_order_window = Window.partitionBy("targetId", "diseaseId", "maxClinPhase", pivot_col_name).orderBy(F.col('colocalisationMethod').asc(), F.col("qtlPValueExponent").asc())
        test2=discrepancifier(benchmark.withColumn('actionType2', F.concat_ws(",", F.col("actionType2"))).withColumn('qtlColocDoE',F.first('colocDoE').over(current_col_pvalue_order_window)).groupBy(
        "targetId", "diseaseId", "maxClinPhase", "drugLoF_protect", "drugGoF_protect",pivot_col_name)
        .pivot("colocDoE")
        .count()
        .withColumnRenamed('drugLoF_protect', 'LoF_protect_ch')
        .withColumnRenamed('drugGoF_protect', 'GoF_protect_ch')).withColumn( ## .filter(F.col('coherencyDiagonal')!='noEvid')
    "arrayN", F.array(*[F.col(c) for c in doe_columns])
    ).withColumn(
        "maxDoE", F.array_max(F.col("arrayN"))
    ).withColumn("maxDoE_names", F.array(*conditions)
    ).withColumn("maxDoE_names", F.expr("filter(maxDoE_names, x -> x is not null)")
    ).withColumn(
        "NoneCellYes",
        F.when((F.col("LoF_protect_ch").isNotNull() & (F.col('GoF_protect_ch').isNull())) & (F.array_contains(F.col("maxDoE_names"), F.lit("LoF_protect")))==True, F.lit('yes'))
        .when((F.col("GoF_protect_ch").isNotNull() & (F.col('LoF_protect_ch').isNull())) & (F.array_contains(F.col("maxDoE_names"), F.lit("GoF_protect")))==True, F.lit('yes')
            ).otherwise(F.lit('no'))  # If the value is null, return null # Otherwise, check if name is in array
    ).withColumn(
        "NdiagonalYes",
        F.when((F.col("LoF_protect_ch").isNotNull() & (F.col('GoF_protect_ch').isNull())) & 
            (F.size(F.array_intersect(F.col("maxDoE_names"), F.array([F.lit(x) for x in diagonal_lof]))) > 0),
            F.lit("yes")
        ).when((F.col("GoF_protect_ch").isNotNull() & (F.col('LoF_protect_ch').isNull())) & 
            (F.size(F.array_intersect(F.col("maxDoE_names"), F.array([F.lit(x) for x in diagonal_gof]))) > 0),
            F.lit("yes")
        ).otherwise(F.lit('no'))
    ).withColumn(
        "drugCoherency",
        F.when(
            (F.col("LoF_protect_ch").isNotNull())
            & (F.col("GoF_protect_ch").isNull()), F.lit("coherent")
        )
        .when(
            (F.col("LoF_protect_ch").isNull())
            & (F.col("GoF_protect_ch").isNotNull()), F.lit("coherent")
        )
        .when(
            (F.col("LoF_protect_ch").isNotNull())
            & (F.col("GoF_protect_ch").isNotNull()), F.lit("dispar")
        )
        .otherwise(F.lit("other")),
    ).join(negativeTD, on=["targetId", "diseaseId"], how="left").withColumn(
        "PhaseT",
        F.when(F.col("stopReason") == "Negative", F.lit("yes")).otherwise(F.lit("no")),
    ).withColumn(
        "phase4Clean",
        F.when(
            (F.col("maxClinPhase") == 4) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "phase3Clean",
        F.when(
            (F.col("maxClinPhase") >= 3) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "phase2Clean",
        F.when(
            (F.col("maxClinPhase") >= 2) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "phase1Clean",
        F.when(
            (F.col("maxClinPhase") >= 1) & (F.col("PhaseT") == "no"), F.lit("yes")
        ).otherwise(F.lit("no")),
    ).withColumn(
        "hasGenetics",
        F.when(F.col("coherencyDiagonal") != "noEvid", F.lit("yes")).otherwise(F.lit("no")),
    )
        # 1. Get distinct values for the pivot column (essential for pivot())
        # This brings a small amount of data to the driver, but is necessary for the pivot schema.
        #distinct_pivot_values = [row[0] for row in test2.select(pivot_col_name).distinct().collect()]
        # print(f"Distinct values for '{pivot_col_name}': {distinct_pivot_values}")

        # 2. Perform the groupBy, pivot, and aggregate operations
        # The .pivot() function requires the list of distinct values for better performance
        # and correct schema inference.
        pivoted_df = (
            test2.groupBy(*group_by_columns)
            .pivot(pivot_col_name) # Provide distinct values distinct_pivot_values
            .agg(F.collect_set(F.col(agg_col_name))) # Collect all values into a set
            .fillna(0) # Fill cells that have no data with an empty list instead of null
        )
        # 3. Add items to dictionary to map the columns:
        # filter out None and 'null':
        datasetColumns=pivoted_df.columns
        filtered = [x for x in datasetColumns if x is not None and x != 'null']
        # using list comprehension
        for item in filtered:
            disdic[item] = pivot_col_name

        # 3. Add the 'data' literal column dynamically
        # This column indicates which aggregation column was used.
        #pivoted_df = pivoted_df.withColumn('data', F.lit(f'Drug_{agg_col_name}'))

        array_columns_to_convert = [
            field.name for field in pivoted_df.schema.fields
            if isinstance(field.dataType, ArrayType)
        ]
        print(f"Identified ArrayType columns for conversion: {array_columns_to_convert}")

        # 4. Apply the conversion logic to each identified array column
        df_after_conversion = pivoted_df # Start with the pivoted_df
        for col_to_convert in array_columns_to_convert:
            df_after_conversion = df_after_conversion.withColumn(
                col_to_convert,
                F.when(F.col(col_to_convert).isNull(), F.lit('no'))          # Handle NULLs (from pivot for no data)
                .when(F.size(F.col(col_to_convert)) == 0, F.lit('no'))       # Empty array -> 'no'
                .when(F.array_contains(F.col(col_to_convert), F.lit('yes')), F.lit('yes')) # Contains 'yes' -> 'yes'
                .when(F.array_contains(F.col(col_to_convert), F.lit('no')), F.lit('no'))   # Contains 'no' -> 'no'
                .otherwise(F.lit('no')) # Fallback for unexpected array content (e.g., ['other'], ['yes','no'])
            )

        # 4. Generate a unique name for this DataFrame and store it
        df_key = f"df_pivot_{agg_col_name.lower()}_by_{pivot_col_name.lower()}"
        all_pivoted_dfs[df_key] = df_after_conversion.withColumnRenamed( 'phase4Clean','Phase>=4'
        ).withColumnRenamed('phase3Clean','Phase>=3'
        ).withColumnRenamed('phase2Clean','Phase>=2'
        ).withColumnRenamed('phase1Clean','Phase>=1')


# --- Accessing your generated DataFrames ---
print("\n--- All generated DataFrames are stored in 'all_pivoted_dfs' dictionary ---")
print("Keys available:", all_pivoted_dfs.keys())
##### PROJECTID
project_keys = (
    benchmark
    .select("projectId")
    .distinct()
    .rdd
    .map(lambda r: r[0])
    .filter(lambda x: x is not None)  # <- remove NULLs
    .collect()
)
#project_keys=[f"{k}_only" for k,v in disdic.items() if v == 'projectId']
main=['GTEx_only', 'UKB_PPP_EUR_only']
#stimulated=['Alasoo_2018_only','Cytoimmgen_only','Fairfax_2014_only','Kim-Hellmuth_2017_only','Nathan_2022_only','Nedelec_2016_only','Quach_2016_only','Randolph_2021_only','Schmiedel_2018_only']
#cellLine=['CAP_only','HipSci_only','iPSCORE_only','Jerber_2021_only','PhLiPS_only','Schwartzentruber_2018_only','TwinsUK_only']
derivedCellLine=['TwinsUK_only','PhLiPS_only','CAP_only','GENCORD_only','Sun_2018_only','Nedelec_2016_only']
canonicalCellLine=['Alasoo_2018_only','Jerber_2021_only','GEUVADIS_only','iPSCORE_only','Aygun_2021_only','Schwartzentruber_2018_only']
stimulated=['Schmiedel_2018_only','Bossini-Castillo_2019_only','Alasoo_2018_only','Cytoimmgen_only','Gilchrist_2021_only','CAP_only','Quach_2016_only','Randolph_2021_only','Sun_2018_only','Nedelec_2016_only','Kim-Hellmuth_2017_only']

def strip_only(lst):
    return [x.removesuffix("_only") for x in lst]  # Python 3.9+
    # or: return [x[:-5] if x.endswith("_only") else x for x in lst]

# Apply
main = strip_only(main)
canonicalCellLine = strip_only(canonicalCellLine)
derivedCellLine = strip_only(derivedCellLine)
stimulated = strip_only(stimulated)

others=[item for item in project_keys if item not in main]
nonStimulated=[item for item in project_keys if item not in stimulated]
nonCanonicalCellLine = [item for item in project_keys if item not in canonicalCellLine]
nonDerivedCellLine = [item for item in project_keys if item not in derivedCellLine]


# First condition: any "yes" in list1
# others
condition1 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), others[1:], F.col(others[0]) == "yes")
# estimulated
condition2 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), stimulated[1:], F.col(stimulated[0]) == "yes")
## non estimulated:
condition3 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), nonStimulated[1:], F.col(nonStimulated[0]) == "yes")
# canonical cellLine
condition4 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), canonicalCellLine[1:], F.col(canonicalCellLine[0]) == "yes")
# non canonical cellline
condition5 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), nonCanonicalCellLine[1:], F.col(nonCanonicalCellLine[0]) == "yes")
# derived cell line 
condition6 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), derivedCellLine[1:], F.col(derivedCellLine[0]) == "yes")
# non derived cellline
condition7 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), nonDerivedCellLine[1:], F.col(nonDerivedCellLine[0]) == "yes")
# mainprojects
condition8 = reduce(lambda acc, col: acc | (F.col(col) == "yes"), main[1:], F.col(main[0]) == "yes")


# Add both columns
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].withColumn("othersProjectId_only", F.when(condition1, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].withColumn("estimulated_only", F.when(condition2, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].withColumn("nonEstimulated", F.when(condition3, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].withColumn("canonicalCellLine", F.when(condition4, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].withColumn("nonCanonicalCellLine", F.when(condition5, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].withColumn("derivedCellLine", F.when(condition6, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].withColumn("nonDerivedCellLine", F.when(condition7, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'] = all_pivoted_dfs['df_pivot_nonecellyes_by_projectid'].withColumn("GTExUKB", F.when(condition8, "yes").otherwise("no")) 

# Add both columns
all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'].withColumn("othersProjectId_only", F.when(condition1, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'].withColumn("estimulated_only", F.when(condition2, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'].withColumn("nonEstimulated", F.when(condition3, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'].withColumn("canonicalCellLine", F.when(condition4, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'].withColumn("nonCanonicalCellLine", F.when(condition5, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'].withColumn("derivedCellLine", F.when(condition6, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'].withColumn("nonDerivedCellLine", F.when(condition7, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'] = all_pivoted_dfs['df_pivot_ndiagonalyes_by_projectid'].withColumn("GTExUKB", F.when(condition8, "yes").otherwise("no")) 

# Add both columns
all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'].withColumn("othersProjectId_only", F.when(condition1, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'].withColumn("estimulated_only", F.when(condition2, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'].withColumn("nonEstimulated", F.when(condition3, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'].withColumn("canonicalCellLine", F.when(condition4, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'].withColumn("nonCanonicalCellLine", F.when(condition5, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'].withColumn("derivedCellLine", F.when(condition6, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'].withColumn("nonDerivedCellLine", F.when(condition7, "yes").otherwise("no")) 
all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'] = all_pivoted_dfs['df_pivot_hasgenetics_by_projectid'].withColumn("GTExUKB", F.when(condition8, "yes").otherwise("no")) 


# If you wanted to apply to every DF in the dict (only if they all share *_only columns):
# for k, df in all_pivoted_dfs.items():
#     all_pivoted_dfs[k] = add_project_group_flags(df, main, stimulated, cellLine)

###append to dictionary

disdic.update({'othersProjectId': 'projectId','Stimulated': 'projectId','cellLine': 'projectId', 'othersBiosampleName_only': 'biosampleName', 'otherRightStudyType':'rightStudyType'})

###################################
###################################
result = []
result_st = []
result_ci = []
array2 = []
listado = []
result_all = []
today_date = str(date.today())

for key,df in all_pivoted_dfs.items():

    print(f'working with {key}')
    parts = key.split('_by_') ### take the part of key belonging to column name
    column_name = parts[1] ### take the last part which is column name
    all_pivoted_dfs[key].persist()
    #unique_values = all_pivoted_dfs[key].drop('null').columns[7:]
    unique_values = all_pivoted_dfs[key].drop('null').columns[-8:] ### just the interesting columns for us 
    filtered_unique_values = [x for x in unique_values if x is not None and x != 'null']
    print('There are ', len(filtered_unique_values), 'columns to analyse with phases')
    rows = comparisons_df_iterative(filtered_unique_values)

    # If needed, now process the rest
    for row in rows:
        print('performing', row)
        results = aggregations_original(
            all_pivoted_dfs[key], key, listado, *row, today_date
        )
        result_all.append(results)
        print('results appended')
    all_pivoted_dfs[key].unpersist()
    print('df unpersisted')


schema = StructType(
    [
        StructField("group", StringType(), True),
        StructField("comparison", StringType(), True),
        StructField("phase", StringType(), True),
        StructField("oddsRatio", DoubleType(), True),
        StructField("pValue", DoubleType(), True),
        StructField("lowerInterval", DoubleType(), True),
        StructField("upperInterval", DoubleType(), True),
        StructField("total", StringType(), True),
        StructField("values", ArrayType(ArrayType(IntegerType())), True),
        StructField("relSuccess", DoubleType(), True),
        StructField("rsLower", DoubleType(), True),
        StructField("rsUpper", DoubleType(), True),
        StructField("path", StringType(), True),
    ]
)
import re

# Define the list of patterns to search for
patterns = [
    "_only",
    #"_tissue",
    #"_isSignalFromRightTissue",
    "_isRightTissueSignalAgreed",
]
# Create a regex pattern to match any of the substrings
regex_pattern = "(" + "|".join(map(re.escape, patterns)) + ")"

# Convert list of lists to DataFrame
df = (
    spreadSheetFormatter(spark.createDataFrame(result_all, schema=schema))
    .withColumn(
        "prefix",
        F.regexp_replace(
            F.col("comparison"), regex_pattern + ".*", ""
        ),  # Extract part before the pattern
    )
    .withColumn(
        "suffix",
        F.regexp_extract(
            F.col("comparison"), regex_pattern, 0
        ),  # Extract the pattern itself
    )
)

### annotate projectId, tissue, qtl type and doe type:

from pyspark.sql.functions import create_map
from itertools import chain

mapping_expr=create_map([F.lit(x) for x in chain(*disdic.items())])

df_annot=df.withColumn('annotation',mapping_expr.getItem(F.col('prefix')))

df_annot.toPandas().to_csv(
    f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue_AllPhasesMixtures3.csv"
)

print("dataframe written \n Analysis finished")