In [8]:
# -*- coding: utf-8 -*-
# Single-script, loop-free PySpark job (tall/unpivot + single aggregation)

import os
from datetime import date
from functools import reduce

from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType, DoubleType, ArrayType
)

# Your helpers
from functions import (
    relative_success,
    spreadSheetFormatter,
    discrepancifier,
    temporary_directionOfEffect,
    buildColocData,
    gwasDataset,
)
from DoEAssessment import directionOfEffect  # noqa: F401  (kept if you need it later)

# -------------------------------
# Spark / YARN resource settings
# -------------------------------
# Resource knobs are kept as strings because they are handed verbatim to
# SparkSession.builder.config() below.
driver_memory = "4g"
executor_memory = "8g"
executor_cores = "4"
num_executors = "10"
executor_memory_overhead = "2g"
shuffle_partitions = "150"
# cores-per-executor * number-of-executors * 2
default_parallelism = str(int(executor_cores) * int(num_executors) * 2)  # 80

# Ordered Spark property -> value mapping; the same keys are read back from
# the live session afterwards so the effective values are visible in the log.
spark_settings = {
    "spark.driver.memory": driver_memory,
    "spark.executor.memory": executor_memory,
    "spark.executor.cores": executor_cores,
    "spark.executor.instances": num_executors,
    "spark.yarn.executor.memoryOverhead": executor_memory_overhead,
    "spark.sql.shuffle.partitions": shuffle_partitions,
    "spark.default.parallelism": default_parallelism,
}

session_builder = (
    SparkSession.builder
    .appName("MyOptimizedPySparkApp")
    .config("spark.master", "yarn")
)
for conf_key, conf_value in spark_settings.items():
    session_builder = session_builder.config(conf_key, conf_value)
spark = session_builder.getOrCreate()

print("SparkSession created with:")
for conf_key in spark_settings:
    print(f"  {conf_key}: {spark.conf.get(conf_key)}")
print(f"Spark UI: {spark.sparkContext.uiWebUrl}")

# --------------------------------
# 0) Load inputs
# --------------------------------
path_n = "gs://open-targets-data-releases/25.06/output/"


def _read_release_table(relative_path):
    """Read the parquet dataset stored at ``path_n`` + ``relative_path``."""
    return spark.read.parquet(f"{path_n}{relative_path}")


target = _read_release_table("target/")
diseases = _read_release_table("disease/")
evidences = _read_release_table("evidence")
credible = _read_release_table("credible_set")
new = _read_release_table("colocalisation_coloc")
index = _read_release_table("study/")
variantIndex = _read_release_table("variant")
biosample = _read_release_table("biosample")
ecaviar = _read_release_table("colocalisation_ecaviar")

# Stack the eCAVIAR and COLOC tables by column name; allowMissingColumns=True
# lets the two inputs have different column sets.
all_coloc = ecaviar.unionByName(new, allowMissingColumns=True)
print("Loaded all base tables.")

# --------------------------------
# 1) Build coloc + GWAS dataset
# --------------------------------
# Project helpers: build the colocalisation table and the GWAS evidence table
# from the raw inputs loaded above.
newColoc = buildColocData(all_coloc, credible, index)
print("Built newColoc")

gwasComplete = gwasDataset(evidences, credible)
print("Built gwasComplete")

# Disease id propagated to its ancestors: [diseaseId] + parents.
# concat() evaluates to NULL as soon as one argument is NULL, so a NULL
# `parents` (disease without parents, or no match in the disease table) used to
# turn the exploded `diseaseId` itself into NULL. Fall back to the row's own
# id in that case so the original association is never lost.
own_disease = F.array(F.col("diseaseId"))
disease_with_ancestors = (
    F.when(F.col("parents").isNull(), own_disease)
    .otherwise(F.concat(own_disease, F.col("parents")))
)

resolvedColoc = (
    # newColoc carries the gene as `geneId`; align it with the evidence naming.
    newColoc.withColumnRenamed("geneId", "targetId")
    # Keep only colocalisations whose left (GWAS) study locus has evidence for
    # the same target.
    .join(
        gwasComplete.withColumnRenamed("studyLocusId", "leftStudyLocusId"),
        on=["leftStudyLocusId", "targetId"],
        how="inner",
    )
    .join(
        diseases.selectExpr("id as diseaseId", "name", "parents", "therapeuticAreas"),
        on="diseaseId",
        how="left",
    )
    # One output row per disease in [diseaseId] + parents.
    .withColumn("diseaseId", F.explode_outer(disease_with_ancestors))
    .drop("parents", "oldDiseaseId")
    # Direction of effect from the sign of the GWAS beta combined with the sign
    # of betaRatioSignAverage. The mapping is mirrored for splicing QTLs.
    # Rows with another study type, or with a zero/NULL sign, get NULL.
    .withColumn(
        "colocDoE",
        F.when(
            F.col("rightStudyType").isin(["eqtl", "pqtl", "tuqtl", "sceqtl", "sctuqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_protect"))
        ).when(
            F.col("rightStudyType").isin(["sqtl", "scsqtl"]),
            F.when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") > 0), F.lit("LoF_risk"))
            .when((F.col("betaGwas") > 0) & (F.col("betaRatioSignAverage") < 0), F.lit("GoF_risk"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") > 0), F.lit("GoF_protect"))
            .when((F.col("betaGwas") < 0) & (F.col("betaRatioSignAverage") < 0), F.lit("LoF_protect"))
        ),
    )
)
print("Built resolvedColoc")

# --------------------------------
# 2) Direction of Effect & ChEMBL indication
# --------------------------------
# Datasource ids handed to temporary_directionOfEffect together with the
# release path.
datasource_filter = [
    "gwas_credible_sets",
    "gene_burden",
    "eva",
    "eva_somatic",
    "gene2phenotype",
    "orphanet",
    "cancer_gene_census",
    "intogen",
    "impc",
    "chembl",
]
# NOTE: `evidences` is rebound here, replacing the raw evidence table loaded
# earlier. The last two return values are not used; the mechanism-of-action
# table is rebuilt explicitly below.
assessment, evidences, actionType_unused, oncolabel_unused = temporary_directionOfEffect(path_n, datasource_filter)
print("Built temporary DoE datasets")

# Mechanism of action per (targetId, drugId): explode the drug ids, then the
# targets, and collect the distinct action types of each pair.
mecact_path = f"{path_n}drug_mechanism_of_action/"
mecact = spark.read.parquet(mecact_path)
actionType = (
    mecact.select(
        F.explode_outer("chemblIds").alias("drugId"),
        "actionType",
        "mechanismOfAction",
        "targets",
    )
    .select(
        F.explode_outer("targets").alias("targetId"),
        "drugId",
        "actionType",
        "mechanismOfAction",
    )
    .groupBy("targetId", "drugId")
    .agg(F.collect_set("actionType").alias("actionType2"))
    # Number of distinct action types for the pair.
    .withColumn("nMoA", F.size(F.col("actionType2")))
)

# ChEMBL rows of `assessment`, annotated with the action types and with the
# maximum clinical phase per (targetId, diseaseId), then pivoted on
# `homogenized` (row counts per value) and passed through `discrepancifier`.
# NOTE(review): the dropped columns (coherency*, noEvaluable, *_risk) are
# presumably produced by the pivot and/or discrepancifier - confirm against
# the helper.
analysis_chembl_indication = (
    discrepancifier(
        assessment.filter(F.col("datasourceId") == "chembl")
        .join(actionType, on=["targetId", "drugId"], how="left")
        .withColumn(
            "maxClinPhase",
            F.max("clinicalPhase").over(Window.partitionBy("targetId", "diseaseId")),
        )
        .groupBy("targetId", "diseaseId", "maxClinPhase", "actionType2")
        .pivot("homogenized")
        .agg(F.count("targetId"))
    )
    .drop("coherencyDiagonal", "coherencyOneCell", "noEvaluable", "GoF_risk", "LoF_risk")
    # Prefix the remaining DoE columns so they cannot be confused with the
    # genetics-side counts of the same name computed later.
    .withColumnRenamed("GoF_protect", "drugGoF_protect")
    .withColumnRenamed("LoF_protect", "drugLoF_protect")
)
print("Built analysis_chembl_indication")

# --------------------------------
# 3) Benchmark (filtered coloc) + clinical phase flags
# --------------------------------
# Keep colocalisations that pass either score threshold.
passes_clpp = F.col("clpp") >= 0.01
passes_h4 = F.col("h4") >= 0.8
resolvedColocFiltered = resolvedColoc.filter(passes_clpp | passes_h4)

# The drug direction is present and equals the genetics direction.
drug_agrees_with_coloc = (
    (F.col("drugGoF_protect").isNotNull() & (F.col("colocDoE") == "GoF_protect"))
    | (F.col("drugLoF_protect").isNotNull() & (F.col("colocDoE") == "LoF_protect"))
)

# (targetId, diseaseId) pairs with at least one ChEMBL row whose stop-reason
# categories contain "Negative".
negativeTD = (
    evidences.filter(F.col("datasourceId") == "chembl")
    .select("targetId", "diseaseId", "studyStopReason", "studyStopReasonCategories")
    .filter(F.array_contains(F.col("studyStopReasonCategories"), "Negative"))
    .select("targetId", "diseaseId")
    .distinct()
    .withColumn("stopReason", F.lit("Negative"))
)

# Right join: every ChEMBL target-disease row is kept, with or without
# colocalisation support.
benchmark = (
    resolvedColocFiltered.filter(F.col("name") != "COVID-19")
    .join(analysis_chembl_indication, on=["targetId", "diseaseId"], how="right")
    .withColumn("AgreeDrug", F.when(drug_agrees_with_coloc, "yes").otherwise("no"))
    .join(biosample.select("biosampleName", "biosampleId").select("biosampleId", "biosampleName"), on="biosampleId", how="left")
    .join(F.broadcast(negativeTD), on=["targetId", "diseaseId"], how="left")
    .withColumn("PhaseT", F.when(F.col("stopReason") == "Negative", "yes").otherwise("no"))
)

# A phase flag is "yes" only when the phase condition holds and the pair has
# no negative stop reason. Note that "Phase>=4" is an equality test.
not_stopped = F.col("PhaseT") == "no"
phase_conditions = {
    "Phase>=4": F.col("maxClinPhase") == 4,
    "Phase>=3": F.col("maxClinPhase") >= 3,
    "Phase>=2": F.col("maxClinPhase") >= 2,
    "Phase>=1": F.col("maxClinPhase") >= 1,
}
for flag_name, phase_reached in phase_conditions.items():
    benchmark = benchmark.withColumn(
        flag_name, F.when(phase_reached & not_stopped, "yes").otherwise("no")
    )

# --------------------------------
# 4) Replace nested loops:
#     compute DoE counts once → derive flags → unpivot → single aggregation
# --------------------------------
# Genetics-side direction-of-effect categories counted per group.
doe_cols = ["LoF_protect", "GoF_risk", "LoF_risk", "GoF_protect"]

# Grouping used for the per-category counts.
group_keys = [
    "targetId", "diseaseId", "maxClinPhase",
    "actionType2", "biosampleName", "projectId", "rightStudyType", "colocalisationMethod"
]

# One row per group with the number of benchmark rows in each DoE category.
# Kept for inspection / downstream reuse; `test2` below no longer joins it
# back (see the window note).
doe_counts = (
    benchmark.groupBy(*group_keys)
    .agg(*[F.sum(F.when(F.col("colocDoE") == c, 1).otherwise(0)).alias(c) for c in doe_cols])
)

# Same counts attached to every benchmark row through a window.
# An equi-join back on `group_keys` never matches rows whose key contains a
# NULL (e.g. a NULL actionType2 or biosampleName), which silently lost the
# counts for those groups. Window partitioning puts NULL keys in one group.
group_window = Window.partitionBy(*group_keys)
doe_counts_per_row = [
    F.sum(F.when(F.col("colocDoE") == c, 1).otherwise(0)).over(group_window).alias(c)
    for c in doe_cols
]

# Largest category count in the group, and whether the group has any
# directional genetics at all.
greatest_count = F.greatest(*[F.col(c) for c in doe_cols])
has_any_doe = F.coalesce(greatest_count, F.lit(0)) > 0

# Name(s) of the winning DoE category (several on ties). A group whose counts
# are all zero has no winner: without the `has_any_doe` guard every category
# would tie at 0 and be reported as the maximum.
max_names = F.filter(
    F.array(*[
        F.when(has_any_doe & (F.col(c) == greatest_count), F.lit(c))
        for c in doe_cols
    ]),
    lambda x: x.isNotNull()
)

# presence of drug-side signals
has_lof_ch = F.col("drugLoF_protect").isNotNull()
has_gof_ch = F.col("drugGoF_protect").isNotNull()

test2 = (
    benchmark.select(*group_keys, "drugLoF_protect", "drugGoF_protect", *doe_counts_per_row)
    # "One cell": the winning genetics category is exactly the drug's one.
    .withColumn("NoneCellYes",
        F.when(has_lof_ch & (~has_gof_ch) & F.array_contains(max_names, F.lit("LoF_protect")), "yes")
         .when(has_gof_ch & (~has_lof_ch) & F.array_contains(max_names, F.lit("GoF_protect")), "yes")
         .otherwise("no")
    )
    # "Diagonal": also accept the opposite-mechanism risk category.
    .withColumn("NdiagonalYes",
        F.when(has_lof_ch & (~has_gof_ch) & (F.array_contains(max_names, F.lit("LoF_protect")) | F.array_contains(max_names, F.lit("GoF_risk"))), "yes")
         .when(has_gof_ch & (~has_lof_ch) & (F.array_contains(max_names, F.lit("GoF_protect")) | F.array_contains(max_names, F.lit("LoF_risk"))), "yes")
         .otherwise("no")
    )
    .withColumn("drugCoherency",
        F.when(has_lof_ch & ~has_gof_ch, "coherent")
         .when(~has_lof_ch & has_gof_ch, "coherent")
         .when(has_lof_ch & has_gof_ch, "dispar")
         .otherwise("other")
    )
    # "yes" when the group has at least one row with a colocDoE category.
    # The previous test (NdiagonalYes IS NOT NULL) was always true because
    # NdiagonalYes is built with .otherwise("no").
    .withColumn("hasGenetics", F.when(has_any_doe, "yes").otherwise("no"))
)


SparkSession created with:
  spark.driver.memory: 4g
  spark.executor.memory: 8g
  spark.executor.cores: 4
  spark.executor.instances: 10
  spark.yarn.executor.memoryOverhead: 2g
  spark.sql.shuffle.partitions: 150
  spark.default.parallelism: 80
Spark UI: http://jr-temp-doe-m.c.open-targets-eu-dev.internal:45327


                                                                                

Loaded all base tables.
Built newColoc


                                                                                

loaded gwasComplete
Built gwasComplete
Built resolvedColoc


25/09/07 15:53:56 WARN CacheManager: Asked to cache already cached data.
25/09/07 15:53:56 WARN CacheManager: Asked to cache already cached data.


Built temporary DoE datasets


Exception in thread "serve-DataFrame" java.net.SocketTimeoutException: Accept timed out
	at java.base/java.net.PlainSocketImpl.socketAccept(Native Method)
	at java.base/java.net.AbstractPlainSocketImpl.accept(AbstractPlainSocketImpl.java:474)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:565)
	at java.base/java.net.ServerSocket.accept(ServerSocket.java:533)
	at org.apache.spark.security.SocketAuthServer$$anon$1.run(SocketAuthServer.scala:65)
25/09/07 15:53:57 WARN CacheManager: Asked to cache already cached data.
25/09/07 15:53:57 WARN CacheManager: Asked to cache already cached data.


Built analysis_chembl_indication


In [10]:
# Columns shared by every "long" frame (identifiers, phase flags, metrics).
common_cols = [
    "targetId","diseaseId","maxClinPhase",
    "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
    "NoneCellYes","NdiagonalYes","hasGenetics"
]

# The phase flags live on `benchmark`, not on `test2`: take one row of flags
# per (targetId, diseaseId, maxClinPhase) and attach it to `test2`.
phase_keys = ["targetId","diseaseId","maxClinPhase"]
phase_flags_lookup = (
    benchmark
    .select(*phase_keys, "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT")
    .dropDuplicates(phase_keys)
)
test2_with_phase_flags = test2.join(phase_flags_lookup, on=phase_keys, how="left")

# 1) actionType2 is an array column: explode it to one value per row.
#    explode_outer keeps rows whose array is NULL/empty (value = NULL).
long_action = (
    test2_with_phase_flags
    .select(*common_cols, F.explode_outer("actionType2").alias("value"))
    .withColumn("feature", F.lit("actionType2"))
    .select(*common_cols, "feature", "value")
)

# 2–5) the scalar columns
def longify_scalar(colname: str):
    """Return `test2` in long form for one scalar column.

    The output has the columns in `common_cols` plus ``feature`` (the literal
    column name) and ``value`` (the content of `colname`). Phase flags are
    looked up from `benchmark`, one row per
    (targetId, diseaseId, maxClinPhase).
    """
    lookup_keys = ["targetId","diseaseId","maxClinPhase"]
    phase_flags = benchmark.select(
        *lookup_keys, "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT"
    ).dropDuplicates(lookup_keys)

    with_flags = test2.join(phase_flags, on=lookup_keys, how="left")
    valued = with_flags.select(*common_cols, F.col(colname).alias("value"))
    labelled = valued.withColumn("feature", F.lit(colname))
    return labelled.select(*common_cols, "feature", "value")

# One long frame per scalar feature column.
long_biosample, long_project, long_rstype, long_colocm = (
    longify_scalar(scalar_col)
    for scalar_col in ("biosampleName", "projectId", "rightStudyType", "colocalisationMethod")
)

# Stack all feature frames by column name (starting from the exploded
# actionType2 frame) and drop rows that carry no value.
stacked_features = reduce(
    lambda stacked, frame: stacked.unionByName(frame),
    [long_biosample, long_project, long_rstype, long_colocm],
    long_action,
)
long_features = stacked_features.filter(F.col("value").isNotNull())


In [11]:

# One aggregation for every feature and every metric: a metric is 1 for a
# (keys, feature, value) group when any of its rows says "yes", else 0.
agg_key_cols = [
    "targetId","diseaseId","maxClinPhase",
    "Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT",
    "feature","value",
]
metric_cols = ["NoneCellYes", "NdiagonalYes", "hasGenetics"]

agg_once = long_features.groupBy(*agg_key_cols).agg(
    *[
        F.max(F.when(F.col(metric) == "yes", 1).otherwise(0)).alias(metric)
        for metric in metric_cols
    ]
)
# Companion "<metric>_flag" columns with the 1/0 value spelled as yes/no.
for metric in metric_cols:
    agg_once = agg_once.withColumn(
        f"{metric}_flag", F.when(F.col(metric) == 1, "yes").otherwise("no")
    )


In [None]:

# --------------------------------
# 5) Optional: persist a canonical long dataset partitioned by feature
# --------------------------------
# `agg_once` feeds several independent actions below (parquet write, collect,
# two pivots with their writes, the summary report). Without caching, each
# action re-executes the whole lineage behind it, so mark it for caching once
# before the first action materialises it.
agg_once.cache()

today = date.today().isoformat()
out_base = f"gs://ot-team/jroldan/{today}_analysis"
out_long = f"{out_base}/pivot_long"
(
    agg_once
    # Repartition by the partition column before the partitioned write.
    .repartition("feature")
    .write.mode("overwrite").partitionBy("feature").parquet(out_long)
)
print(f"Wrote canonical long dataset: {out_long}")

# --------------------------------
# 6) Build disdic (value -> feature) without looping giant frames
# --------------------------------
# NOTE: keyed by value, so a value present under two features keeps only the
# feature of the last collected row.
disdic = {r["value"]: r["feature"] for r in agg_once.select("feature","value").distinct().collect()}

# --------------------------------
# 7) On-demand “wide” view for a specific feature/metric (no Python loops)
#     Example: projectId × NoneCellYes_flag
# --------------------------------
# One column per distinct projectId, holding that project's flag.
project_wide = (
    agg_once.filter(F.col("feature")=="projectId")
            .groupBy("targetId","diseaseId","maxClinPhase","Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT")
            .pivot("value")
            .agg(F.first("NoneCellYes_flag"))
)

out_wide = f"{out_base}/wide_projectId_NoneCellYes"
project_wide.write.mode("overwrite").parquet(out_wide)
print(f"Wrote example wide view: {out_wide}")

# --------------------------------
# 8) (Optional) Another wide view: biosampleName × hasGenetics_flag
# --------------------------------
biosample_wide = (
    agg_once.filter(F.col("feature")=="biosampleName")
            .groupBy("targetId","diseaseId","maxClinPhase","Phase>=4","Phase>=3","Phase>=2","Phase>=1","PhaseT")
            .pivot("value")
            .agg(F.first("hasGenetics_flag"))
)
biosample_out = f"{out_base}/wide_biosample_hasGenetics"
biosample_wide.write.mode("overwrite").parquet(biosample_out)
print(f"Wrote example wide view: {biosample_out}")

# --------------------------------
# 9) (Optional) Export a CSV-like report using your spreadsheet helper
#     (Format like your downstream code; adjust as needed)
# --------------------------------
# Summary per (feature, value): how many aggregated rows have each metric set.
# The metric columns are 0/1, so their sum is a count of "yes" rows.
report_df = (
    agg_once
    .groupBy("feature","value")
    .agg(
        F.sum(F.col("NoneCellYes")).alias("NoneCellYes_cnt"),
        F.sum(F.col("NdiagonalYes")).alias("NdiagonalYes_cnt"),
        F.sum(F.col("hasGenetics")).alias("hasGenetics_cnt"),
    )
)

# NOTE(review): spreadSheetFormatter is a project helper; whether it accepts
# this schema is not visible here - confirm.
report_fmt = spreadSheetFormatter(report_df)

# The report is collected to the driver and written with pandas.
# NOTE(review): writing to a gs:// URL from pandas presumably relies on a GCS
# filesystem backend being installed on the driver - confirm.
csv_out = f"{out_base}/summary_counts.csv"
report_fmt.toPandas().to_csv(csv_out, index=False)
print(f"Wrote summary CSV: {csv_out}")

print("Job completed successfully.")


25/09/07 16:06:09 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_371_59 !
25/09/07 16:06:09 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_371_0 !
25/09/07 16:06:09 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_71_12 !
25/09/07 16:06:09 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_371_2 !
25/09/07 16:06:09 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_371_135 !
25/09/07 16:06:09 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_38_37 !
25/09/07 16:06:09 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_38_51 !
25/09/07 16:06:09 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_371_4 !
25/09/07 16:06:09 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_377_76 !
25/09/07 16:06:09 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_38_131 !
25/09/07 16:06:09 WARN BlockManagerMasterEndpoint: No m

Wrote canonical long dataset: gs://ot-team/jroldan/2025-09-07_analysis/pivot_long


25/09/07 16:12:29 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_71_93 !
25/09/07 16:12:29 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_371_115 !
25/09/07 16:12:29 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_71_11 !
25/09/07 16:12:29 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_71_12 !
25/09/07 16:12:29 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_377_114 !
25/09/07 16:12:29 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_377_58 !
25/09/07 16:12:29 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_38_80 !
25/09/07 16:12:29 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_371_77 !
25/09/07 16:12:29 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_371_63 !
25/09/07 16:12:29 WARN BlockManagerMasterEndpoint: No more replicas available for rdd_371_92 !
25/09/07 16:12:29 WARN BlockManagerMasterEndpoint: N

In [None]:
# ============================
# Spark-friendly ANALYSIS
# ============================
from datetime import date
import pyspark.sql.functions as F
from pyspark.sql.types import (
    StructType, StructField, StringType, DoubleType, IntegerType, ArrayType
)
from pyspark.sql import Window

# ---- 1) Build the comparison × prediction long table in one go ----
# Metric flag columns of `agg_once` that are analysed below.
metric_flags = ["NoneCellYes_flag", "NdiagonalYes_flag", "hasGenetics_flag"]

# stack() turns the five phase flag columns into five rows per input row:
# (phase_name = column label, prediction = that column's yes/no value).
phase_stack = F.expr("stack(5, "
                     "'Phase>=4', `Phase>=4`, "
                     "'Phase>=3', `Phase>=3`, "
                     "'Phase>=2', `Phase>=2`, "
                     "'Phase>=1', `Phase>=1`, "
                     "'PhaseT',  `PhaseT`"
                     ")").alias("phase_name", "prediction")

phases_unfiltered = agg_once.select(
    "targetId","diseaseId","maxClinPhase","feature","value",
    phase_stack,
)
phases_long = phases_unfiltered.where(F.col("prediction").isNotNull())

# Pair one metric flag (as `comparison`) with every phase row of the same key.
def attach_metric(metric_col: str):
    """Return the phase rows of `phases_long` joined with one metric flag.

    `metric_col` is a flag column of `agg_once`; it is exposed as
    ``comparison``. The ``metric`` column holds `metric_col` with the
    ``_flag`` suffix removed.
    """
    join_keys = ["targetId","diseaseId","maxClinPhase","feature","value"]
    metric_side = agg_once.select(*join_keys, F.col(metric_col).alias("comparison"))
    phase_side = phases_long.select(*join_keys, "phase_name", "prediction")
    pretty_label = metric_col.replace("_flag", "")
    joined = metric_side.join(phase_side, on=join_keys, how="inner")
    return joined.withColumn("metric", F.lit(pretty_label))

# Union of the per-metric frames. Resulting rows look like:
# (targetId, diseaseId, feature, value, metric, phase_name,
#  comparison='yes'/'no', prediction='yes'/'no')
metric_frames = [attach_metric(flag_col) for flag_col in metric_flags]
analysis_long = metric_frames[0]
for metric_frame in metric_frames[1:]:
    analysis_long = analysis_long.unionByName(metric_frame)

# ---- 2) Count ALL 2×2 cells in one aggregation ----
# `a` = number of rows in each (comparison, prediction) cell.
cell_counts = (
    analysis_long
    .groupBy("metric","feature","value","phase_name","comparison","prediction")
    .count()
    .withColumnRenamed("count", "a")
)

# ---- 3) Reshape to 2×2 matrices per (metric, feature, value, phase) ----
# One row per comparison value, with the prediction counts spread into the
# fixed columns `yes` and `no`; absent cells become 0.
pivoted_cells = (
    cell_counts
    .groupBy("metric","feature","value","phase_name","comparison")
    .pivot("prediction", ["yes","no"])
    .agg(F.first("a"))
)
mat_counts = pivoted_cells.na.fill(0)

# We'll compute Fisher per group using a grouped map Pandas UDF
import pandas as pd
import numpy as np
from scipy.stats import fisher_exact
from scipy.stats.contingency import odds_ratio
from pyspark.sql.functions import pandas_udf

# Schema of the one-row-per-group result returned by the Fisher step.
# Every field is nullable.
_result_fields = [
    ("group",         StringType()),   # metric name
    ("comparison",    StringType()),   # e.g., "<value>_only"
    ("phase",         StringType()),   # phase_name
    ("oddsRatio",     DoubleType()),
    ("pValue",        DoubleType()),
    ("lowerInterval", DoubleType()),
    ("upperInterval", DoubleType()),
    ("total",         StringType()),
    ("values",        ArrayType(ArrayType(IntegerType()))),
    ("relSuccess",    DoubleType()),
    ("rsLower",       DoubleType()),
    ("rsUpper",       DoubleType()),
    ("path",          StringType()),
]
result_schema = StructType(
    [StructField(field_name, field_type, True) for field_name, field_type in _result_fields]
)

# Simplified stand-in for the project's relative_success helper, usable inside
# the per-group pandas function. Swap in the exact project math if required.
def _relative_success(matrix_2x2: np.ndarray):
    """Return (difference, lower, upper) of the success rates of a 2x2 table.

    Rows are comparison yes/no, columns are prediction yes/no. The value is
    rate(comparison == yes) - rate(comparison == no), with a Wald 95%
    interval (1.96 standard errors) for a difference in proportions. A row
    without observations has rate 0.0 and adds nothing to the standard error.
    """
    import math

    def _rate_and_variance(successes, failures):
        # Success rate of one row and its contribution to the squared SE.
        n_obs = successes + failures
        if n_obs > 0:
            rate = successes / n_obs
            return rate, rate * (1 - rate) / n_obs
        return 0.0, 0.0

    rate_yes, var_yes = _rate_and_variance(matrix_2x2[0,0], matrix_2x2[0,1])
    rate_no, var_no = _rate_and_variance(matrix_2x2[1,0], matrix_2x2[1,1])

    difference = rate_yes - rate_no
    std_err = math.sqrt(0.0 + var_yes + var_no)
    half_width = 1.96*std_err
    return float(difference), float(difference - half_width), float(difference + half_width)

def fisher_by_group(pdf: pd.DataFrame) -> pd.DataFrame:
    """Compute Fisher's exact test and related statistics for one group.

    Plain Python function meant for ``GroupedData.applyInPandas``. It is
    deliberately NOT decorated with ``@pandas_udf``: with
    ``pd.DataFrame -> pd.DataFrame`` type hints the decorator builds a SCALAR
    pandas UDF, which ``groupBy(...).apply(...)`` rejects ("must be a
    pandas_udf of type GROUPED_MAP").

    Parameters
    ----------
    pdf : all rows of one (metric, feature, value, phase_name) group of
        `mat_counts`, i.e. at most one row per ``comparison`` value with the
        columns metric, feature, value, phase_name, comparison, yes, no.

    Returns
    -------
    A one-row DataFrame laid out as `result_schema`.
    """
    # 2x2 table: rows = comparison yes/no, columns = prediction yes/no.
    # A comparison value missing from the group becomes a row of zeros.
    sub = pdf[["comparison", "yes", "no"]].copy()
    sub = sub.set_index("comparison").reindex(["yes", "no"]).fillna(0)
    mat = sub[["yes", "no"]].to_numpy(dtype=int)

    total = int(mat.sum())
    or_val, p_val = fisher_exact(mat, alternative="two-sided")
    # Interval of scipy's odds_ratio estimate (exposes .low / .high).
    ci = odds_ratio(mat).confidence_interval(0.95)
    rs, rs_lo, rs_hi = _relative_success(mat)

    row = {
        "group":        pdf["metric"].iloc[0],
        "comparison":   f"{pdf['value'].iloc[0]}_only",  # to match your previous naming
        "phase":        pdf["phase_name"].iloc[0],
        "oddsRatio":    float(round(or_val, 2)),
        "pValue":       float(p_val),
        "lowerInterval": float(round(ci.low, 2)),
        "upperInterval": float(round(ci.high, 2)),
        "total":        str(total),
        "values":       mat.tolist(),
        "relSuccess":   float(round(rs, 2)),
        "rsLower":      float(round(rs_lo, 2)),
        "rsUpper":      float(round(rs_hi, 2)),
        "path":         "",   # you can fill a path pattern here if you still write per-combo parquet
    }
    return pd.DataFrame([row])

# Run the function once per (metric, feature, value, phase) group; the output
# schema is passed here instead of through a decorator.
results_df = (
    mat_counts
    .groupBy("metric","feature","value","phase_name")
    .applyInPandas(fisher_by_group, schema=result_schema)
)

# ---- 4) Optional spreadsheet formatting + annotate and export ----
from itertools import chain
from pyspark.sql.functions import create_map

# value -> feature lookup, rebuilt from agg_once so this cell does not depend
# on an earlier definition.
feature_value_rows = agg_once.select("feature","value").distinct().collect()
disdic = {row["value"]: row["feature"] for row in feature_value_rows}

# Suffixes recognised in the `comparison` label ("<value>_only" naming above).
patterns = ["_only", "_isRightTissueSignalAgreed"]
regex_pattern = "(" + "|".join(patterns) + ")"

# prefix = label with the first recognised suffix and everything after it
# removed; suffix = the recognised suffix itself (empty string when absent).
comparison_col = F.col("comparison")
df_fmt = spreadSheetFormatter(results_df)  # reuse your helper
df_fmt = df_fmt.withColumn("prefix", F.regexp_replace(comparison_col, regex_pattern + ".*", ""))
df_fmt = df_fmt.withColumn("suffix", F.regexp_extract(comparison_col, regex_pattern, 0))

# Literal map column key1, value1, key2, value2, ... built from disdic,
# used to annotate each row with the feature its prefix belongs to.
map_literals = []
for value_key, feature_name in disdic.items():
    map_literals.append(F.lit(value_key))
    map_literals.append(F.lit(feature_name))
mapping_expr = create_map(map_literals)
df_annot = df_fmt.withColumn("annotation", mapping_expr.getItem(F.col("prefix")))

# Collect to the driver and write a single CSV.
today_date = date.today().isoformat()
out_csv = f"gs://ot-team/jroldan/analysis/{today_date}_credibleSetColocDoEanalysis_filteredColocAndCaviarWithOthers4phasesTrue.csv"
df_annot.toPandas().to_csv(out_csv, index=False)
print(f"Analysis written: {out_csv}")
