In [None]:
import numpy as np
import pandas as pd

# -----------------------
# 1) Add phenotype from ARD
# -----------------------
def add_phenotype_from_ard(df, ard_col="ard", threshold=0.30):
    """
    Return a copy of ``df`` with a ``phenotype`` column derived from ARD.

    Labels:
      - "UNKNOWN"  where ``df[ard_col]`` is missing (NaN/None)
      - "ARDS"     where the value, cast to float, is >= ``threshold``
      - "NON_ARDS" otherwise

    The input frame is not modified. Values that cannot be cast to float make
    ``astype(float)`` raise.
    """
    result = df.copy()
    ard_values = result[ard_col]

    is_missing = ard_values.isna().to_numpy()
    meets_threshold = (ard_values.astype(float) >= threshold).to_numpy()

    # Start from the threshold verdict, then overwrite missing rows; NaN
    # compares False above, so without this they would read as NON_ARDS.
    labels = np.where(meets_threshold, "ARDS", "NON_ARDS")
    labels[is_missing] = "UNKNOWN"

    result["phenotype"] = labels
    return result


# -----------------------
# 2) Replay orchestrator using merged-only columns
# -----------------------
class OrchestratorReplay:
    """
    Replay the orchestrator state machine over pre-merged agent outputs.

    Uses ONLY merged columns:
      - current_state_label (chronologist)
      - ard (agent 1)
      - recommended_total_peep, confidence_score (agent 2)
      - success_48h (agent 3 outcome label)

    Orchestrator-followed states (requested renames):
      Acute -> Distress
      Recovery -> Controlled
      Liberation -> Weaning (collapsed)
      Weaning -> Weaning

    Switching is debounced. A chronologist label that differs from the current
    state must be seen on ``K_SWITCH`` consecutive recognised rows before the
    orchestrator follows it; values <= 1 switch immediately.

    Rows whose label is missing or not in LABEL_MAP are skipped: they neither
    advance nor reset a pending switch. A row that repeats the current state
    cancels the pending switch.
    """

    # Chronologist label (or an already-renamed state) -> orchestrator state.
    LABEL_MAP = {
        "Acute": "Distress",
        "Distress": "Distress",
        "Recovery": "Controlled",
        "Controlled": "Controlled",
        "Weaning": "Weaning",
        "Liberation": "Weaning",
    }

    def __init__(self, k_switch=2, start_state="Distress"):
        """
        Args:
            k_switch: consecutive recognised observations of a new label
                required before switching to it.
            start_state: state used at construction and after every reset().
                Accepts an orchestrator state ("Distress", "Controlled",
                "Weaning") or a chronologist alias ("Acute", "Recovery",
                "Liberation"); aliases are mapped through LABEL_MAP.

        Raises:
            ValueError: if ``start_state`` is not a known state or alias.
        """
        self.K_SWITCH = int(k_switch)
        normalized = self._normalize_label(start_state)
        if normalized is None:
            # Fail here rather than on the first step(), where
            # _select_decision would raise "Unknown state".
            raise ValueError(
                f"Unknown start_state: {start_state!r}; "
                f"expected one of {sorted(self.LABEL_MAP)}"
            )
        self.start_state = normalized
        self.reset()

    def reset(self):
        """Return to ``start_state`` and discard any pending switch."""
        self.current_state = self.start_state
        self._pending_label = None
        self._pending_count = 0

    def _normalize_label(self, x):
        """
        Map a raw chronologist label to an orchestrator state.

        Returns None for None or any label not in LABEL_MAP. A pandas NaN
        stringifies to "nan", which is not a key, so it also maps to None.
        """
        if x is None:
            return None
        x = str(x).strip()
        return self.LABEL_MAP.get(x, None)

    def step(self, row):
        """
        Consume one merged row and return the replay record for it.

        Args:
            row: mapping of merged column name -> value (e.g. a record dict).
                Any key it lacks is reported as None.

        Returns:
            dict containing:
              - identifiers from the row;
              - the normalised chronologist label;
              - the state before (``prev_state_label``) and after
                (``orchestrator_state``) this row;
              - the active-agent ``decision`` for the post-update state;
              - the raw agent-2 and agent-3 columns.
        """
        obs = self._normalize_label(row.get("current_state_label", None))
        prev = self.current_state
        self._update_state(obs)

        # choose which agent output is "active" based on orchestrator-followed state
        decision = self._select_decision(self.current_state, row)

        return {
            "visit_occurrence_id": row.get("visit_occurrence_id"),
            "measure_time": row.get("measure_time"),
            "person_id": row.get("person_id"),
            "phenotype": row.get("phenotype"),
            "chronologist_obs_label": obs,
            "prev_state_label": prev,
            "orchestrator_state": self.current_state,
            "decision": decision,
            "recommended_total_peep": row.get("recommended_total_peep"),
            "confidence_score": row.get("confidence_score"),
            "success_48h": row.get("success_48h"),  # outcome label
        }

    def _update_state(self, obs):
        """Apply one normalised observation to the debounced state machine."""
        if obs is None:
            # Unrecognised/missing label: leave state and pending switch as-is.
            return

        if obs == self.current_state:
            # Current state confirmed again: cancel any pending switch.
            self._pending_label = None
            self._pending_count = 0
            return

        if self._pending_label != obs:
            # A different candidate restarts the count.
            self._pending_label = obs
            self._pending_count = 1
        else:
            self._pending_count += 1

        if self._pending_count >= self.K_SWITCH:
            self.current_state = obs
            self._pending_label = None
            self._pending_count = 0

    def _select_decision(self, state, row):
        """
        Build the decision dict for the agent that owns ``state``.

        Raises:
            ValueError: if ``state`` is not Distress/Controlled/Weaning.
        """
        if state == "Distress":
            return {"active_agent": "Agent1", "action": "ASSESS", "ard": row.get("ard")}

        if state == "Controlled":
            return {
                "active_agent": "Agent2_RL",
                "action": "ADJUST_PEEP",
                "recommended_total_peep": row.get("recommended_total_peep"),
                "confidence_score": row.get("confidence_score"),
            }

        if state == "Weaning":
            # success_48h is an outcome label; we can't use it prospectively as readiness probability
            return {"active_agent": "Agent3", "action": "WEAN_PROTOCOL", "label_success_48h": row.get("success_48h")}

        raise ValueError(f"Unknown state: {state}")


def replay_on_merged(merged_df, k_switch=2):
    """
    Replay OrchestratorReplay over every visit in ``merged_df``.

    Rows are sorted by (visit_occurrence_id, measure_time). The orchestrator
    is reset at the start of each visit, so debounce state never carries over
    from one visit to the next.

    Args:
        merged_df: DataFrame with at least ``visit_occurrence_id`` and
            ``measure_time``. The other columns read by
            ``OrchestratorReplay.step`` are optional; missing ones come out
            as None.
        k_switch: forwarded to ``OrchestratorReplay``.

    Returns:
        DataFrame with one row per replayed input row (in replay order) and
        the columns produced by ``OrchestratorReplay.step``. If nothing is
        replayed, an empty DataFrame that still carries those columns is
        returned, so downstream column access does not raise KeyError.

    Note:
        ``groupby`` uses pandas' default ``dropna=True``, so rows whose
        ``visit_occurrence_id`` is NaN are not replayed.
    """
    d = merged_df.sort_values(["visit_occurrence_id", "measure_time"])
    orch = OrchestratorReplay(k_switch=k_switch, start_state="Distress")
    outs = []

    for _, g in d.groupby("visit_occurrence_id", sort=False):
        orch.reset()
        # to_dict("records") avoids building a Series per row: iterrows()
        # does not preserve dtypes across a row and is much slower.
        for r in g.to_dict("records"):
            outs.append(orch.step(r))

    if not outs:
        # Derive the schema from step() itself so it cannot drift from the
        # record layout.
        columns = list(OrchestratorReplay(k_switch=k_switch).step({}))
        return pd.DataFrame(columns=columns)

    return pd.DataFrame(outs)


# -----------------------
# 3) Run
# -----------------------
# Label each merged row with its ARD phenotype, then replay the orchestrator
# visit by visit with a 2-observation debounce.
merged2 = add_phenotype_from_ard(df=merged, ard_col="ard", threshold=0.30)
out_df = replay_on_merged(merged_df=merged2, k_switch=2)

# How often each orchestrator-followed state occurs across all replayed rows.
print(out_df.loc[:, "orchestrator_state"].value_counts())
# First 20 replayed rows, limited to the columns needed to inspect transitions.
print(
    out_df.loc[
        :,
        [
            "visit_occurrence_id",
            "measure_time",
            "phenotype",
            "chronologist_obs_label",
            "orchestrator_state",
        ],
    ].head(20)
)
