In [1]:
# -*- coding: utf-8 -*-
import ast
import itertools
import re
import rpy2.robjects as ro

# ---------- parsing ----------
def parse_r_list_item(item: str):
    """
    Robustly parse DAGitty adjustment set strings returned as R characters.

    Supports the following formats:
      - 'c("K","Z")'           -> ['K', 'Z']
      - 'c("Age")'             -> ['Age']               (one-element vector literal)
      - '{K, Z}' or '{Z}'      -> ['K', 'Z'] / ['Z']    (brace style sometimes used by DAGitty)
      - 'character(0)' / '{}'  -> []                    (empty set; '' and 'NULL' too)
      - 'K' / 'age.group'      -> ['K'] / ['age.group'] (single bare name without quotes)
      - '"K"'                  -> ['K']                 (quoted single)

    Args:
        item: One element of an R character vector. It is converted with
            ``str()`` and stripped; ``None`` is treated as the empty set.

    Returns:
        list[str]: A Python list of variable names.

    Raises:
        ValueError: If the string matches none of the supported formats,
            including a malformed ``c(...)`` literal.
    """
    # Exceptions ast.literal_eval documents for malformed input.
    literal_errors = (ValueError, TypeError, SyntaxError, MemoryError, RecursionError)

    if item is None:
        return []
    s = str(item).strip()

    # Empty set variants from DAGitty or R printing
    if s in ("character(0)", "", "NULL", "{}"):
        return []

    # R vector literal, e.g., c("A", "B"). Dropping the leading 'c' leaves a
    # Python-compatible parenthesised literal. Whitespace is deliberately kept
    # so quoted names with internal spaces are not altered.
    if s.startswith("c("):
        try:
            val = ast.literal_eval(s[1:])
        except literal_errors as err:
            raise ValueError(f"Unrecognized adjustment-set string format: {s}") from err
        if isinstance(val, str):
            # ("A") is a parenthesised string, not a 1-tuple: wrap it instead
            # of letting list() split it into single characters.
            return [val]
        return list(val)

    # Brace style, e.g., {A, B}
    if s.startswith("{") and s.endswith("}"):
        inner = s[1:-1].strip()
        if inner == "":
            return []
        parts = [p.strip() for p in inner.split(",")]
        return [p for p in parts if p]

    # Single bare symbol, e.g., K or age.group (R names may contain dots)
    if re.fullmatch(r"[A-Za-z_.][A-Za-z0-9_.]*", s):
        return [s]

    # Last resort: try literal_eval for quoted single or list
    try:
        val = ast.literal_eval(s)
    except literal_errors:
        val = None
    if isinstance(val, str):
        return [val]
    if isinstance(val, (list, tuple)):
        return list(val)

    raise ValueError(f"Unrecognized adjustment-set string format: {s}")

# ---------- combine bi-directional into directed/undirected ----------
def combine_bidirectional_edges(edge_freqs, strict=True, threshold=1):
    """
    Collapse bootstrap direction frequencies into a mixed edge list for a PDAG.

    Each unordered pair {A, B} is examined once, looking up both the A->B and
    the B->A counts. A direction is "present" when:
      - strict=True:  its count is > 0 (it appeared in at least one run);
      - strict=False: its count is >= `threshold` (filters rare noise).

    Outcome per pair:
      - both directions present -> undirected edge '--' (orientation unstable);
      - exactly one present     -> that directed edge '->';
      - neither present         -> the pair is dropped.

    Args:
        edge_freqs (dict[tuple[str, str], int]): key (A, B) means A->B; value is its count.
        strict (bool): use the "> 0" presence rule instead of the threshold rule.
        threshold (int): minimum count for a direction to be present when strict=False.

    Returns:
        list[tuple[str, str, str]]: (a, b, etype) with etype '->' or '--', in the
        order pairs are first encountered in `edge_freqs`. Undirected edges are
        canonicalized with the two endpoints in sorted order.
    """
    def is_present(count):
        # Presence rule depends on the chosen policy.
        return count > 0 if strict else count >= threshold

    handled_pairs = set()
    combined = []
    for (src, dst), forward in edge_freqs.items():
        pair_key = frozenset((src, dst))
        if pair_key in handled_pairs:
            continue
        handled_pairs.add(pair_key)

        backward = edge_freqs.get((dst, src), 0)
        fwd_ok = is_present(forward)
        bwd_ok = is_present(backward)

        if fwd_ok and bwd_ok:
            first, second = sorted((src, dst))
            combined.append((first, second, '--'))
        elif fwd_ok:
            combined.append((src, dst, '->'))
        elif bwd_ok:
            combined.append((dst, src, '->'))
        # Otherwise neither direction qualifies and the pair is omitted.
    return combined

# ---------- enumerate DAG completions from mixed edges ----------
def enumerate_dag_strings_from_mixed(mixed_edges, max_combinations=1024):
    """
    Expand a PDAG with unoriented edges ('--') into all possible DAG completions.

    Each '--' edge (u, v) is replaced by either u->v or v->u. This enumerates
    every orientation combination and returns a list of DAGitty strings.
    Edges whose type is neither '->' nor '--' are ignored.

    Note: no acyclicity check is performed. An orientation combination that
    closes a directed cycle is still emitted.

    Args:
        mixed_edges (list[(str, str, str)]): edges with types '->' or '--'.
        max_combinations (int): safety cap to avoid combinatorial explosion.
            It is checked only when at least one '--' edge is present.

    Returns:
        list[str]: DAGitty DAG strings like 'dag { A -> B ; C -> D }'.

    Raises:
        RuntimeError: if the number of completions exceeds max_combinations.
    """
    directed = [(a, b) for a, b, t in mixed_edges if t == '->']
    undirected = [(a, b) for a, b, t in mixed_edges if t == '--']

    def _to_dagitty(edge_list):
        # Serialize directed edges into DAGitty's textual 'dag { ... }' syntax.
        body = " ; ".join(f"{a} -> {b}" for a, b in edge_list)
        return f"dag {{ {body} }}"

    if not undirected:
        return [_to_dagitty(directed)]

    # Each undirected edge doubles the count. Compute it arithmetically
    # BEFORE generating anything, so the cap actually prevents the
    # combinatorial blow-up instead of materializing it first.
    n_completions = 2 ** len(undirected)
    if n_completions > max_combinations:
        raise RuntimeError(
            f"Too many DAG completions ({n_completions}). "
            f"Reduce undirected edges or increase max_combinations."
        )

    # For each undirected edge, create two choices: u->v or v->u
    choices_per_edge = [((u, v), (v, u)) for u, v in undirected]
    return [
        _to_dagitty(directed + list(orientation))
        for orientation in itertools.product(*choices_per_edge)
    ]

# ---------- compute adjustment sets across all DAG completions ----------
def compute_adjustment_sets_across_dags(dag_strings, exposure, outcome):
    """
    Compute adjustment sets that are valid across all possible DAG completions.

    Steps:
      1) For each DAG completion, call DAGitty::adjustmentSets(type="minimal")
         and collect the union of all candidate sets.
      2) For each candidate set, verify with DAGitty::isAdjustmentSet that it is
         valid for every single DAG completion.
      3) Return only those sets that pass all checks (intersection over completions).

    This is a conservative approach that respects orientation uncertainty in the PDAG.
    Only sets that are minimal in at least one completion are considered as
    candidates. A set that is valid in every completion but minimal in none
    is therefore not reported.

    Candidates are de-duplicated as sets: ['K', 'Z'] and ['Z', 'K'] count once,
    and the first-seen ordering is kept.

    Args:
        dag_strings (list[str]): list of 'dag { ... }' strings (any iterable;
            it is materialized once).
        exposure (str): treatment variable name.
        outcome (str): outcome variable name.

    Returns:
        list[list[str]]: a list of valid adjustment sets (each is a list of
        variables). Empty when `dag_strings` is empty.
    """
    # Materialize once: both passes below need to iterate the completions.
    dag_strings = list(dag_strings)
    if not dag_strings:
        return []

    # Load DAGitty once and use R functions as callables. Variable names are
    # then passed as R arguments instead of being spliced into R source text.
    ro.r("library(dagitty)")
    r_dagitty = ro.r["dagitty"]
    r_adjustment_sets = ro.r["adjustmentSets"]
    r_is_adjustment_set = ro.r["isAdjustmentSet"]
    r_as_character = ro.r["as.character"]

    # Parse each completion into an R graph object exactly once.
    graphs = [r_dagitty(dag) for dag in dag_strings]

    # 1) Union of candidates across all DAG completions.
    # Each entry is like 'c("A", "Z")', 'Z', '{A, Z}', or 'character(0)'.
    # parse_r_list_item maps 'character(0)' to [].
    candidates = []
    seen = set()
    for g in graphs:
        sets = r_adjustment_sets(g, exposure=exposure, outcome=outcome, type="minimal")
        for raw in r_as_character(sets):
            z = parse_r_list_item(raw)
            key = frozenset(z)  # order-insensitive identity of the set
            if key not in seen:
                seen.add(key)
                candidates.append(z)

    # 2) Keep only candidates that are valid for ALL DAG completions.
    # StrVector([]) is R's character(0), i.e. the empty adjustment set.
    def _valid_everywhere(z):
        z_vec = ro.StrVector(z)
        return all(
            bool(r_is_adjustment_set(g, Z=z_vec, exposure=exposure, outcome=outcome)[0])
            for g in graphs
        )

    return [z for z in candidates if _valid_everywhere(z)]

# ---------- Example usage ----------
# Illustrative bootstrap counts: key (A, B) is how often A->B was found
# across repeated PC runs.
edges_freq = {
    ('T', 'Y'): 93,
    ('Z', 'T'): 83,
    ('Z', 'Y'): 67,
    ('U', 'T'): 60,
    ('U', 'V'): 100,
    ('V', 'Y'): 100,
    ('Y', 'V'): 80,  # This creates orientation uncertainty between V and Y
}

treatment = "T"
outcome = "Y"

# 1) Convert bootstrap frequencies into a PDAG edge list (-> and --)
mixed_edges = combine_bidirectional_edges(edges_freq, strict=True)

# 2) Enumerate all DAG completions by orienting every '--' both ways
dag_strings = enumerate_dag_strings_from_mixed(mixed_edges, max_combinations=2048)
print("DAG completions:", len(dag_strings))

# Sanity check: show the first few completions, and print an ellipsis only
# when some completions were actually left out.
PREVIEW_N = 3
print("\n".join(dag_strings[:PREVIEW_N]))
if len(dag_strings) > PREVIEW_N:
    print("...")

# 3) Compute adjustment sets that remain valid across all DAG completions
adj_sets_all = compute_adjustment_sets_across_dags(
    dag_strings, exposure=treatment, outcome=outcome
)
print("Adjustment sets valid in ALL DAG completions:", adj_sets_all)


DAG completions: 2
dag { T -> Y ; Z -> T ; Z -> Y ; U -> T ; U -> V ; V -> Y }
dag { T -> Y ; Z -> T ; Z -> Y ; U -> T ; U -> V ; Y -> V } 
...
Adjustment sets valid in ALL DAG completions: [['U', 'Z']]
