# In [None]:
"""
sorghum_barrier_pipeline_simple.py

Reproduce all analysis tables for the sorghum–johnsongrass barrier paper.

INPUT (CSV, read from <BASE_DIR>/, i.e. the current working directory):
    - Figure2_SourceData.csv
        columns: host, tissue, surface, score
    - Figure3_SourceData.csv
        columns: host, round, score
    - Figure4_SourceData.csv
        columns: host_type, tissue, score
    - Figure5_SourceData.csv
        columns: clone, score

OUTPUT (CSV, written to <BASE_DIR>/results/):
    - surface_treatment_summary.csv
    - logit_surface_params.csv
    - logit_surface_predicted_probs.csv
    - regrowth_round_summary.csv
    - logit_regrowth_params.csv
    - logit_regrowth_predicted_probs.csv
    - rhizome_summary.csv
    - logit_rhizome_params.csv
    - logit_rhizome_predicted_probs.csv
    - growth_stage_GS6_summary.csv
    - defense_barriers.csv
    - strategic_indices.csv

NO FIGURES are produced in this script.
"""

import os
import math

import numpy as np
import pandas as pd
import statsmodels.api as sm
import statsmodels.formula.api as smf

# 0. Paths (Dynamic Setup)
# ----------------------------------------------------------

# BASE_DIR is the current working directory at run time (os.getcwd()),
# which is not necessarily the folder that contains this script. Launch the
# script (or notebook) from the project folder so the paths below resolve.
BASE_DIR = os.getcwd()

# Input CSVs are read directly from BASE_DIR: os.path.join() called with a
# single argument returns that argument unchanged. To read from a subfolder
# such as "master", pass it as a second argument:
#     os.path.join(BASE_DIR, "master")
MASTER_DIR = os.path.join(BASE_DIR) 
# Every output table is written into <BASE_DIR>/results.
OUT_DIR = os.path.join(BASE_DIR, "results")

# Create the output directory up front; exist_ok=True keeps re-runs from failing.
os.makedirs(OUT_DIR, exist_ok=True)

# ----------------------------------------------------------
# 1. Common Utility Functions
# ----------------------------------------------------------

def compute_p_ge4_and_mean(df, score_col, group_cols):
    """
    Summarise a severity score within each group.

    Parameters:
        df: pd.DataFrame holding `score_col` and every column named in
            `group_cols`. It is not modified; the work is done on a copy.
        score_col: name of the numeric score column.
        group_cols: a column name, or a list of column names, to group by.

    Returns:
        pd.DataFrame with one row per group and the columns
        [*group_cols, "n", "n_ge4", "p_ge4_empirical", "mean_score"]:
          - n: number of rows in the group (stored as float)
          - n_ge4: number of rows with score >= 4 (stored as float)
          - p_ge4_empirical: empirical probability P(score >= 4) = n_ge4 / n
          - mean_score: mean severity score (equivalent to Maslow index)
        An input without rows yields an empty frame that still carries these
        columns, so later column selections on the result keep working.
    """
    # A bare string would otherwise be zipped character by character when the
    # group keys are paired with their column names below.
    if isinstance(group_cols, str):
        group_cols = [group_cols]
    group_cols = list(group_cols)

    df = df.copy()
    df["is_ge4"] = (df[score_col] >= 4).astype(int)

    rows = []
    for keys, sub in df.groupby(group_cols):
        # A one-column grouping may yield a scalar key or a 1-tuple depending
        # on the pandas version; normalise so zip() pairs keys with columns.
        if not isinstance(keys, tuple):
            keys = (keys,)

        n = float(len(sub))
        n_ge4 = float(sub["is_ge4"].sum())

        row = dict(zip(group_cols, keys))
        row["n"] = n
        row["n_ge4"] = n_ge4
        row["p_ge4_empirical"] = n_ge4 / n if n > 0 else np.nan
        row["mean_score"] = float(sub[score_col].mean())  # Maslow index = E[score]
        rows.append(row)

    # Passing the column list explicitly keeps the schema stable even when
    # `rows` is empty (a plain pd.DataFrame([]) would have no columns at all).
    columns = [*group_cols, "n", "n_ge4", "p_ge4_empirical", "mean_score"]
    return pd.DataFrame(rows, columns=columns)


def add_barrier_from_prob(df, p_col, out_col="B", min_p=1e-6):
    """
    Append the defense barrier B = -ln(p) computed from column `p_col`.

    Probabilities below `min_p` (including 0) are raised to `min_p` before
    the logarithm is taken, which avoids log(0). Missing probabilities
    produce a missing barrier.

    Returns:
        A copy of `df` with the barrier stored in `out_col`; the input frame
        itself is left untouched.
    """
    def _barrier(p):
        # Propagate missing values instead of feeding them to math.log.
        if pd.isna(p):
            return np.nan
        return -math.log(max(float(p), min_p))

    result = df.copy()
    result[out_col] = [_barrier(p) for p in result[p_col]]
    return result


# ----------------------------------------------------------
# 2. Surface Integrity Analysis
# ----------------------------------------------------------

# Load the per-observation surface scores.
surface_file = os.path.join(MASTER_DIR, "Figure2_SourceData.csv")
surface = pd.read_csv(surface_file)

# 2-1) Empirical summary (n, P(score >= 4), mean score) for every
#      host x tissue x surface combination present in the data.
surface_summary = compute_p_ge4_and_mean(
    surface, score_col="score", group_cols=["host", "tissue", "surface"]
)
surface_summary.to_csv(
    os.path.join(OUT_DIR, "surface_treatment_summary.csv"), index=False
)

# 2-2) Binomial GLM on the binary outcome score >= 4, main effects only.
surface["is_ge4"] = (surface["score"] >= 4).astype(int)
surface_model = smf.glm(
    formula="is_ge4 ~ C(host) + C(tissue) + C(surface)",
    data=surface,
    family=sm.families.Binomial(),
).fit()

# Coefficient table; the term names move from the index into a "term" column.
surface_params = (
    surface_model.summary2()
    .tables[1]
    .reset_index()
    .rename(columns={"index": "term"})
)
surface_params.to_csv(
    os.path.join(OUT_DIR, "logit_surface_params.csv"), index=False
)

# Model-predicted P(score >= 4) and barrier B for each distinct design cell.
surface_pred = surface[["host", "tissue", "surface"]].drop_duplicates().copy()
surface_pred["pred_p_ge4"] = surface_model.predict(surface_pred)
surface_pred = add_barrier_from_prob(surface_pred, "pred_p_ge4", out_col="B")
surface_pred.to_csv(
    os.path.join(OUT_DIR, "logit_surface_predicted_probs.csv"), index=False
)


# ----------------------------------------------------------
# 3. Regrowth (Resilience) Analysis
# ----------------------------------------------------------

# Load the per-observation regrowth scores.
regrowth_file = os.path.join(MASTER_DIR, "Figure3_SourceData.csv")
regrowth = pd.read_csv(regrowth_file)

# Empirical summary for every host x round combination.
regrowth_summary = compute_p_ge4_and_mean(
    regrowth, score_col="score", group_cols=["host", "round"]
)
regrowth_summary.to_csv(
    os.path.join(OUT_DIR, "regrowth_round_summary.csv"), index=False
)

# Binary outcome score >= 4, added on a fresh copy of the loaded frame.
regrowth = regrowth.assign(is_ge4=(regrowth["score"] >= 4).astype(int))

# Binomial GLM with host, round and their interaction.
regrowth_model = smf.glm(
    formula="is_ge4 ~ C(host) + C(round) + C(host):C(round)",
    data=regrowth,
    family=sm.families.Binomial(),
).fit()

# Coefficient table; the term names move from the index into a "term" column.
regrowth_params = (
    regrowth_model.summary2()
    .tables[1]
    .reset_index()
    .rename(columns={"index": "term"})
)
regrowth_params.to_csv(
    os.path.join(OUT_DIR, "logit_regrowth_params.csv"), index=False
)

# Model-predicted P(score >= 4) and barrier B for each distinct host x round.
regrowth_pred = regrowth[["host", "round"]].drop_duplicates().copy()
regrowth_pred["pred_p_ge4"] = regrowth_model.predict(regrowth_pred)
regrowth_pred = add_barrier_from_prob(regrowth_pred, "pred_p_ge4", out_col="B")
regrowth_pred.to_csv(
    os.path.join(OUT_DIR, "logit_regrowth_predicted_probs.csv"), index=False
)


# ----------------------------------------------------------
# 4. Rhizome vs Leaf (Deep-tissue Fortress) Analysis
# ----------------------------------------------------------

# Load the per-observation rhizome/leaf scores.
rhizome_file = os.path.join(MASTER_DIR, "Figure4_SourceData.csv")
rhizome = pd.read_csv(rhizome_file)

# Empirical summary for every host_type x tissue combination.
rhizome_summary = compute_p_ge4_and_mean(
    rhizome, score_col="score", group_cols=["host_type", "tissue"]
)
rhizome_summary.to_csv(os.path.join(OUT_DIR, "rhizome_summary.csv"), index=False)

# Binary outcome score >= 4, added on a fresh copy of the loaded frame.
rhizome = rhizome.assign(is_ge4=(rhizome["score"] >= 4).astype(int))

# Binomial GLM with host_type and tissue main effects.
rhizome_model = smf.glm(
    formula="is_ge4 ~ C(host_type) + C(tissue)",
    data=rhizome,
    family=sm.families.Binomial(),
).fit()

# Coefficient table; the term names move from the index into a "term" column.
rhizome_params = (
    rhizome_model.summary2()
    .tables[1]
    .reset_index()
    .rename(columns={"index": "term"})
)
rhizome_params.to_csv(
    os.path.join(OUT_DIR, "logit_rhizome_params.csv"), index=False
)

# Model-predicted P(score >= 4) and barrier B for each distinct design cell.
rhizome_pred = rhizome[["host_type", "tissue"]].drop_duplicates().copy()
rhizome_pred["pred_p_ge4"] = rhizome_model.predict(rhizome_pred)
rhizome_pred = add_barrier_from_prob(rhizome_pred, "pred_p_ge4", out_col="B")
rhizome_pred.to_csv(
    os.path.join(OUT_DIR, "logit_rhizome_predicted_probs.csv"), index=False
)


# ----------------------------------------------------------
# 5. GS6 (Growth Stage Mature) Analysis
# ----------------------------------------------------------

# Load the per-observation GS6 scores.
gs6_file = os.path.join(MASTER_DIR, "Figure5_SourceData.csv")
gs6 = pd.read_csv(gs6_file)

# The identifier column may be called 'clone' or 'accession'; 'clone' wins
# when both are present.
clone_col = next(
    (name for name in ("clone", "accession") if name in gs6.columns),
    None,
)
if clone_col is None:
    raise ValueError("Column name for clone/accession not found in Figure5_SourceData.csv")

# No model is fitted here: the barrier comes straight from the empirical
# P(score >= 4) of each clone.
gs6_summary = add_barrier_from_prob(
    compute_p_ge4_and_mean(gs6, score_col="score", group_cols=[clone_col]),
    "p_ge4_empirical",
    out_col="B",
)
gs6_summary.to_csv(
    os.path.join(OUT_DIR, "growth_stage_GS6_summary.csv"), index=False
)


# ----------------------------------------------------------
# 6. Construct Unified Barrier Table (defense_barriers.csv)
# ----------------------------------------------------------

# Stack the barrier estimates of all four datasets into one long table.
# Every part is brought to the same column layout before concatenation;
# columns a dataset does not have are filled with "NA" (text) or NaN (round).
barrier_cols = ["dataset", "host", "tissue", "surface", "round", "p_ge4", "B"]

# surface
surf_b = surface_pred.copy()
surf_b["dataset"] = "surface"

# regrowth
reg_b = regrowth_pred.copy()
reg_b["dataset"] = "regrowth"

# rhizome
rhi_b = rhizome_pred.copy()
rhi_b["dataset"] = "rhizome"

# GS6 (barrier derived from the empirical probability, not from a model)
gs6_b = gs6_summary[[clone_col, "p_ge4_empirical", "B"]].copy()
gs6_b.rename(columns={"p_ge4_empirical": "p_ge4"}, inplace=True)
gs6_b["dataset"] = "GS6"

# concat with minimal common columns
surf_b2 = surf_b[["dataset", "host", "tissue", "surface", "pred_p_ge4", "B"]].copy()
surf_b2.rename(columns={"pred_p_ge4": "p_ge4"}, inplace=True)
# The surface data carry no regrowth round. Without this column the
# selection of `barrier_cols` below raises KeyError and the script stops.
surf_b2["round"] = np.nan

# Regrowth rows are labelled as leaf tissue with no surface treatment.
reg_b2 = reg_b[["dataset", "host", "round", "pred_p_ge4", "B"]].copy()
reg_b2["tissue"] = "leaf"
reg_b2["surface"] = "NA"
reg_b2.rename(columns={"pred_p_ge4": "p_ge4"}, inplace=True)

# Rhizome rows: host_type plays the role of host here.
rhi_b2 = rhi_b[["dataset", "host_type", "tissue", "pred_p_ge4", "B"]].copy()
rhi_b2.rename(columns={"host_type": "host", "pred_p_ge4": "p_ge4"}, inplace=True)
rhi_b2["surface"] = "NA"
rhi_b2["round"] = np.nan

# GS6 rows are all labelled host "JG", tissue "leaf".
# NOTE(review): the clone identifier is renamed to "condition" but is not part
# of `barrier_cols`, so GS6 rows cannot be told apart by clone in
# defense_barriers.csv. Confirm whether that column should be exported.
gs6_b2 = gs6_b.copy()
gs6_b2["host"] = "JG"
gs6_b2["tissue"] = "leaf"
gs6_b2["surface"] = "NA"
gs6_b2["round"] = np.nan
gs6_b2.rename(columns={clone_col: "condition"}, inplace=True)

defense_barriers = pd.concat(
    [part[barrier_cols] for part in (surf_b2, reg_b2, rhi_b2, gs6_b2)],
    ignore_index=True,
)
defense_barriers.to_csv(
    os.path.join(OUT_DIR, "defense_barriers.csv"),
    index=False
)


# ----------------------------------------------------------
# 7. Strategic Indices (RI & MTI)
# ----------------------------------------------------------

# 7-1) RI = B_round2 - B_round1, from the regrowth predictions.
# Reshape to one row per host with one barrier column per round.
reg_b_small = regrowth_pred[["host", "round", "B"]].copy()
reg_wide = reg_b_small.pivot(index="host", columns="round", values="B")
reg_wide.columns = [f"B_round{int(c)}" for c in reg_wide.columns]
reg_wide = reg_wide.reset_index()

# RI is only defined when both rounds are present; otherwise it is left missing.
has_both_rounds = {"B_round1", "B_round2"}.issubset(reg_wide.columns)
reg_wide["RI"] = (
    reg_wide["B_round2"] - reg_wide["B_round1"] if has_both_rounds else np.nan
)

# 7-2) MTI = B_core - B_canopy, with the round-1 barrier used as the canopy.
# The core barrier is the predicted barrier of the (host_type "JG", tissue
# "rhizome") cell, and it is applied to the regrowth host labelled "SH1152".
# NOTE(review): presumably "SH1152" is the johnsongrass entry of the regrowth
# data; confirm against the source data.
# Every other host gets B_core equal to its own round-1 barrier and MTI = 0.
jg_rhi_rows = rhizome_pred[
    (rhizome_pred["host_type"] == "JG") & (rhizome_pred["tissue"] == "rhizome")
]
B_core_JG = float(jg_rhi_rows["B"].iloc[0]) if len(jg_rhi_rows) > 0 else np.nan

core_records = []
for _, row in reg_wide.iterrows():
    host = row["host"]
    B_leaf = row.get("B_round1", np.nan)
    if host != "SH1152":
        # No rhizome strategy assumed: core equals canopy, so MTI is zero.
        B_core = B_leaf
        MTI = 0.0
    else:
        B_core = B_core_JG
        if math.isnan(B_core) or math.isnan(B_leaf):
            MTI = np.nan
        else:
            MTI = B_core - B_leaf
    core_records.append((host, B_core, MTI))

core_df = pd.DataFrame(core_records, columns=["host", "B_core", "MTI"])

# Attach B_core and MTI to the per-host round barriers and RI.
strategic = reg_wide.merge(core_df, on="host", how="left")
strategic.to_csv(os.path.join(OUT_DIR, "strategic_indices.csv"), index=False)

print("=== Done. All analysis tables saved in:", OUT_DIR)