In [2]:
# =========================================================
# Temporal Causal Discovery using DAGitty (Simplified Import)
# =========================================================

import pandas as pd
from collections import defaultdict, deque
from itertools import product
from typing import List, Tuple, Set
import rpy2.robjects as ro

# Load required R library
# Evaluates `library(dagitty)` in the R session managed by rpy2 so the dagitty
# functions looked up later through `ro.r[...]` can be resolved.
ro.r('library(dagitty)')

import warnings
# Silences every Python warning category for all cells executed after this one.
warnings.filterwarnings('ignore')

In [3]:
def extract_edge_list_from_csv(
    csv_path: str,
    sample_frac: float,
    existence_threshold: int = 50,
    direction_count_threshold: int = 10,
    orientation_threshold: float = 0.85,
) -> list:
    """
    Load causal discovery edge results from CSV and return a cleaned edge list.

    The CSV must provide the columns ``sample_frac``, ``from_var``, ``to_var``,
    ``freq`` and ``edge_type``. Every row whose ``sample_frac`` equals the
    requested value contributes ``int(round(freq * 100))`` to the statistics of
    its unordered variable pair:

        * ``Nab`` - "directed" rows pointing from the alphabetically first
          variable of the pair to the second,
        * ``Nba`` - "directed" rows pointing the other way,
        * ``Nun`` - "undirected" rows.

    Rows with any other ``edge_type`` still register the pair but add nothing.

    All matching rows are accumulated. In particular, a pair that appears both
    as a "directed" and as an "undirected" row with the same
    (from_var, to_var) keeps both contributions instead of only the last row.

    Args:
        csv_path: path to the causal discovery results CSV file.
        sample_frac: fraction value to filter (e.g., 0.8).
        existence_threshold: minimum total count (Nab + Nba + Nun) for a pair
            to be kept at all.
        direction_count_threshold: minimum count of the dominant direction
            required to call an edge directed.
        orientation_threshold: minimum share of the dominant direction among
            all directed counts (e.g., 0.85) required to call an edge directed.

    Returns:
        A list of (a, b, edge_type) tuples such as ('X', 'Y', '->') or
        ('X', 'Y', '--'), ordered by first appearance of each pair in the CSV.
        Undirected tuples list the pair in sorted order.
    """
    # ---------- Step 1: Load and filter ----------
    df = pd.read_csv(csv_path)
    df_filtered = df.loc[df["sample_frac"] == sample_frac]

    # ---------- Step 2: Aggregate counts per unordered pair ----------
    # Accumulate straight from the rows: collapsing rows into a dict keyed by
    # (from_var, to_var) first would silently drop all but the last row of a
    # pair that is reported with more than one edge_type.
    edge_stat = {}
    rows = zip(
        df_filtered["from_var"],
        df_filtered["to_var"],
        df_filtered["freq"],
        df_filtered["edge_type"],
    )
    for a, b, freq, typ in rows:
        count = int(round(freq * 100))
        key = tuple(sorted([a, b]))
        stat = edge_stat.setdefault(key, {"Nab": 0, "Nba": 0, "Nun": 0})

        if typ == "directed":
            stat["Nab" if (a, b) == key else "Nba"] += count
        elif typ == "undirected":
            stat["Nun"] += count

    # ---------- Step 3: Build edge list ----------
    edge_list = []
    for (left, right), stat in edge_stat.items():
        n_ab, n_ba, n_un = stat["Nab"], stat["Nba"], stat["Nun"]

        if n_ab + n_ba + n_un < existence_threshold:
            continue

        dominant, minority = max(n_ab, n_ba), min(n_ab, n_ba)
        # A strict majority is required first, so the ratio below never
        # divides by zero for non-negative counts.
        robust_direction = (
            dominant > minority
            and dominant >= direction_count_threshold
            and dominant / (dominant + minority) >= orientation_threshold
        )

        if not robust_direction:
            edge_list.append((left, right, "--"))
        elif n_ab > n_ba:
            edge_list.append((left, right, "->"))
        else:
            edge_list.append((right, left, "->"))

    return edge_list

In [None]:
# ---------- 1) Graph parsing from edge list ----------
def parse_dagitty_graph(edge_list):
    """
    Split an edge list into its node set and a normalized list of edges.

    Args:
        edge_list: iterable of (src, dst, type) triples.

    Returns:
        (nodes, edges) where `nodes` is the set of every name that occurs as
        an endpoint and `edges` holds the same triples, in input order, as
        tuples.

    Example:
        input = [
            ('age_binned', 'promis_dep_sum_tm1', '->'),
            ('age_binned', 'promis_anx_sum_tm1', '->'),
            ('promis_anx_sum_tm1', 'promis_dep_sum_tm1', '--')
        ]

        output:
            nodes = {'age_binned', 'promis_dep_sum_tm1', 'promis_anx_sum_tm1'}
            edges = same as input
    """
    edges: List[Tuple[str, str, str]] = [
        (src, dst, kind) for src, dst, kind in edge_list
    ]
    nodes: Set[str] = {
        endpoint for src, dst, _ in edges for endpoint in (src, dst)
    }
    return nodes, edges

# ---------- 2) Ancestor computation (ignoring undirected edges) ----------
def ancestors_directed_only(edges, query_nodes, include_self=True):
    """
    Collect every node from which a query node can be reached along '->' edges.

    Undirected ('--') edges are ignored entirely.

    Args:
        edges: list of (src, dst, type) tuples, where type is '->' or '--'.
        query_nodes: list or set of nodes whose ancestors are wanted.
        include_self: when True the query nodes are part of the result
                      (the convention DAGitty's ancestors() follows).
                      When False no query node is ever reported, even if it
                      is an ancestor of another query node.

    Returns:
        The set of ancestor nodes found by walking directed edges backwards
        from the query nodes.
    """
    # child -> direct parents, built from directed edges only
    parent_map = {}
    for tail, head, kind in edges:
        if kind == '->':
            parent_map.setdefault(head, set()).add(tail)

    result: Set[str] = set(query_nodes) if include_self else set()
    frontier = list(query_nodes)
    visited = set(query_nodes)

    # Walk parent links until no unseen node remains
    while frontier:
        node = frontier.pop()
        for parent in parent_map.get(node, ()):
            if parent in visited:
                continue
            visited.add(parent)
            result.add(parent)
            frontier.append(parent)
    return result

# ---------- 3) Edge filtering ----------
def filter_edges_by_nodes(edges, keep_nodes):
    """
    Keep only the edges that lie completely inside a given node set.

    Args:
        edges: list of (src, dst, type) tuples representing edges in the graph.
        keep_nodes: collection of node names that may appear as endpoints.

    Returns:
        A new list, in input order, of the (src, dst, type) tuples whose
        source and destination are both contained in keep_nodes.
    """
    kept = []
    for src, dst, kind in edges:
        if src not in keep_nodes or dst not in keep_nodes:
            continue
        kept.append((src, dst, kind))
    return kept

# ---------- 4) Cycle and name checking ----------
def creates_cycle(adj, u, v) -> bool:
    """
    Tell whether inserting the directed edge u -> v would close a cycle.

    Args:
        adj: dict[str, set[str]] mapping each node to its direct successors
             (e.g., adj[u] = {v1, v2, ...}). Only read, never modified.
        u: source node of the candidate edge.
        v: destination node of the candidate edge.

    Returns:
        True when u == v (self-loop) or when u is already reachable from v
        through the edges in `adj`; False otherwise.
    """
    if u == v:
        return True

    # Explore everything reachable from v; meeting u means a path v ~> u exists.
    frontier = [v]
    expanded = set()
    while frontier:
        node = frontier.pop()
        if node == u:
            return True
        if node in expanded:
            continue
        expanded.add(node)
        frontier.extend(adj.get(node, ()))
    return False


def is_temporal(name: str) -> bool:
    """Return True when `name` ends with the suffix '_t' or '_tm1'."""
    return name.endswith(('_t', '_tm1'))

# ---------- 5) Orientation rules for undirected edges ----------
def orient_pdag_to_dag_all(edges):
    """
    Enumerate the DAGs obtainable by orienting every undirected edge of a PDAG.

    An orientation of an undirected edge is admissible only if
        1. the node it points at satisfies `is_temporal` (its name ends with
           '_t' or '_tm1'), and
        2. adding it to the edges accepted so far does not close a directed
           cycle (checked with `creates_cycle`).

    Args:
        edges: list of (src, dst, type) tuples representing the PDAG.
               'type' can be '->' (directed) or '--' (undirected).

    Returns:
        A list of DAGs. Each DAG is a list of (src, dst) pairs: first the
        original directed edges in input order, then one oriented pair per
        undirected edge in input order. The list is empty when no admissible
        combination exists.

    Notes:
        - For an undirected edge (u, v) the orientation u→v is tried before
          v→u, and combinations are produced with the first undirected edge
          varying slowest.
        - The already directed edges are taken as given; only the newly
          oriented edges are checked for cycles.
    """
    fixed = [(src, dst) for (src, dst, kind) in edges if kind == '->']
    open_pairs = [(a, b) for (a, b, kind) in edges if kind == '--']

    # Per undirected edge: the orientations whose head is a temporal variable,
    # forward orientation first.
    choices = [
        [(tail, head) for (tail, head) in ((a, b), (b, a)) if is_temporal(head)]
        for (a, b) in open_pairs
    ]

    completions = []
    for assignment in product(*choices):
        # Successor map seeded with the edges that were directed from the start
        successors = defaultdict(set)
        for src, dst in fixed:
            successors[src].add(dst)
        oriented = list(fixed)

        for src, dst in assignment:
            if creates_cycle(successors, src, dst):
                break
            successors[src].add(dst)
            oriented.append((src, dst))
        else:
            # Every undirected edge received an admissible orientation
            completions.append(oriented)

    return completions

# ---------- 6) Build DAGitty-compatible graph string ----------
def build_dagitty_dag_string(nodes, directed_edges):
    """
    Render a node set and directed edges as a DAGitty `dag { ... }` string.

    Args:
        nodes: set of node names (strings) to include in the DAG.
        directed_edges: list of (src, dst) tuples representing directed edges.

    Returns:
        A string of the form

            dag {
            A
            B
            C
            A -> B ; B -> C ;
            }

    Notes:
        - Node names are sorted and written one per line.
        - All edges share a single line; each is written as 'A -> B ;' and
          consecutive edges are separated by one space.
        - Names are inserted verbatim (no quoting or escaping).
    """
    node_block = "\n".join(sorted(nodes))
    edge_statements = [f"{src} -> {dst} ;" for (src, dst) in directed_edges]
    edge_block = " ".join(edge_statements)
    return "dag {\n" + node_block + "\n" + edge_block + "\n}"

# ---------- 7) Compute adjustment sets for a single DAG ----------
def adjustment_sets_for_dag(nodes, directed_edges, treatment, outcome):
    """
    Compute minimal adjustment sets for a single DAG using the R 'dagitty' package.

    Args:
        nodes: set of node names (strings).
        directed_edges: list of (src, dst) tuples representing directed edges.
        treatment: name of the treatment variable (passed to dagitty as the
                   exposure).
        outcome: name of the outcome variable.

    Returns:
        dag_str: DAGitty-compatible DAG string.
        py_sets: list of adjustment sets (each set represented as a list of variable names).
                 Example: [] represents an empty adjustment set {}.
        min_size: size (int) of the smallest minimal adjustment set,
                  or None if no adjustment set exists.

    Notes:
        - The function constructs the DAG in DAGitty syntax, passes it to R via `rpy2`,
          and retrieves the minimal adjustment sets using `adjustmentSets()`.
        - DAGitty may return multiple valid minimal adjustment sets; all are included.
        - Requires the R package to be attached beforehand (`library(dagitty)`).
    """
    dag_str = build_dagitty_dag_string(nodes, directed_edges)
    g = ro.r['dagitty'](dag_str)

    # dagitty's adjustmentSets() calls this argument `exposure`; there is no
    # `treatment` formal, so the keyword must be translated here or R rejects
    # the call as an unused argument.
    adj = ro.r['adjustmentSets'](g, exposure=treatment, outcome=outcome, type="minimal")

    # Each element in 'adj' is an R character vector (possibly empty)
    py_sets = [list(s) for s in adj]  # [] == {} → empty adjustment set
    min_size = min((len(s) for s in py_sets), default=None)
    return dag_str, py_sets, min_size


# ---------- 8) Evaluate all possible DAG completions ----------
def evaluate_all_dags(dag_edge_sets, nodes, treatment, outcome):
    """
    Run `adjustment_sets_for_dag` on every DAG completion and collect the results.

    Args:
        dag_edge_sets: list of DAGs, where each DAG is represented as a list of (src, dst) edges
                       (e.g., the output of `orient_pdag_to_dag_all()`).
        nodes: set of all node names (strings), shared by every DAG.
        treatment: name of the treatment variable.
        outcome: name of the outcome variable.

    Returns:
        A list with one dictionary per DAG, in input order, holding:
            - "idx": position of the DAG in `dag_edge_sets`.
            - "edges": the DAG's list of directed edges (same object as passed in).
            - "dag_str": the DAGitty string representation.
            - "adj_sets": list of minimal adjustment sets (possibly empty).
            - "min_size": size of the smallest minimal adjustment set, or None.
            - "has_sets": True when at least one adjustment set was returned.
    """
    def _summarize(position, dag_edges):
        # One dagitty query per DAG completion
        graph_text, found_sets, smallest = adjustment_sets_for_dag(
            nodes, dag_edges, treatment, outcome
        )
        return {
            "idx": position,
            "edges": dag_edges,
            "dag_str": graph_text,
            "adj_sets": found_sets,   # [[], ['a','b'], ...]
            "min_size": smallest,     # None (no set) or int
            "has_sets": len(found_sets) > 0,
        }

    return [
        _summarize(position, dag_edges)
        for position, dag_edges in enumerate(dag_edge_sets)
    ]


In [5]:
# Build the (src, dst, type) edge list from the discovery results recorded for
# sample_frac == 0.8. The thresholds repeat the function defaults explicitly.
edge_list = extract_edge_list_from_csv(
    csv_path="causal_discovery_results/all_edges_sample_frac_with_vars.csv",
    sample_frac=0.8,
    existence_threshold=50,
    direction_count_threshold=10,
    orientation_threshold=0.85,
)

In [None]:
# Collect the node names and the edge tuples of the full graph.
# NOTE(review): `nodes` and `edges` are not used by the visible cells below,
# which pass `edge_list` directly - confirm whether they are needed elsewhere.
nodes, edges = parse_dagitty_graph(edge_list)
treatment = "promis_dep_sum_t"
outcome  = "rem_std_t"
# Keep treatment, outcome and every node that reaches them through '->' edges,
# then restrict the edge list to that node set.
anc_nodes = ancestors_directed_only(edge_list, [treatment, outcome], include_self=True)
sub_edges = filter_edges_by_nodes(edge_list, anc_nodes)
# Enumerate the admissible orientations of the remaining '--' edges.
dag_edge_sets = orient_pdag_to_dag_all(sub_edges)

In [None]:
# Query dagitty once per DAG completion, using the ancestor subgraph's node set.
results = evaluate_all_dags(dag_edge_sets, anc_nodes, treatment, outcome)

In [8]:
# Display the full list of per-DAG result dictionaries.
results

[{'idx': 0,
  'edges': [('age_binned', 'promis_dep_sum_tm1'),
   ('age_binned', 'promis_anx_sum_tm1'),
   ('age_binned', 'rmssd_std_tm1'),
   ('sex_encoded', 'promis_anx_sum_tm1'),
   ('promis_dep_sum_tm1', 'promis_dep_sum_t'),
   ('promis_anx_sum_tm1', 'promis_dep_sum_t'),
   ('rmssd_std_tm1', 'rem_std_t'),
   ('rem_mean_tm1', 'rem_std_t'),
   ('rem_std_tm1', 'rem_std_t'),
   ('temperature_max_mean_tm1', 'rem_mean_tm1'),
   ('promis_dep_sum_t', 'rem_std_t'),
   ('promis_anx_sum_tm1', 'promis_dep_sum_tm1')],
  'dag_str': 'dag {\nage_binned\npromis_anx_sum_tm1\npromis_dep_sum_t\npromis_dep_sum_tm1\nrem_mean_tm1\nrem_std_t\nrem_std_tm1\nrmssd_std_tm1\nsex_encoded\ntemperature_max_mean_tm1\nage_binned -> promis_dep_sum_tm1 ; age_binned -> promis_anx_sum_tm1 ; age_binned -> rmssd_std_tm1 ; sex_encoded -> promis_anx_sum_tm1 ; promis_dep_sum_tm1 -> promis_dep_sum_t ; promis_anx_sum_tm1 -> promis_dep_sum_t ; rmssd_std_tm1 -> rem_std_t ; rem_mean_tm1 -> rem_std_t ; rem_std_tm1 -> rem_std_t ; t