# In [8]:
from itertools import combinations


import networkx as nx

def is_d_separated(graph, node_a, node_b, conditioned_on):
    """
    Determine whether two nodes are d-separated given a conditioning set.

    A path between ``node_a`` and ``node_b`` is blocked by the conditioning
    set Z when one of its interior nodes m is either
      * a non-collider (chain ``-> m ->`` or fork ``<- m ->``) with m in Z, or
      * a collider (``-> m <-``) such that neither m nor any descendant of m
        is in Z.
    The two nodes are d-separated when every path between them is blocked.

    Args:
        graph (nx.DiGraph): The causal graph as a directed acyclic graph.
        node_a (str): First node.
        node_b (str): Second node.
        conditioned_on (Iterable): Nodes conditioned on (converted to a set).
    Returns:
        bool: True if node_a and node_b are d-separated, False otherwise.
    """
    conditioning_set = set(conditioned_on)

    # A collider m is "opened" when m itself or one of its descendants is
    # conditioned on, i.e. when m is in Z or is an ancestor of a node in Z.
    # Nodes of Z that are not in the graph are ignored, as before.
    opens_colliders = set(conditioning_set)
    for node in conditioning_set:
        if node in graph:
            opens_colliders |= nx.ancestors(graph, node)

    def blocked(path):
        """Return True if ``path`` (a sequence of nodes) is blocked by Z."""
        # Only interior nodes can block a path, so walk (prev, mid, next)
        # triples; the endpoints node_a / node_b are never inspected.
        for prev, mid, nxt in zip(path, path[1:], path[2:]):
            is_collider = graph.has_edge(prev, mid) and graph.has_edge(nxt, mid)
            if is_collider:
                if mid not in opens_colliders:
                    return True
            elif mid in conditioning_set:
                return True
        return False

    # Paths are enumerated in the undirected skeleton; edge directions are
    # recovered from the original DiGraph inside ``blocked``.
    undirected = graph.to_undirected()
    return all(
        blocked(path)
        for path in nx.all_simple_paths(undirected, source=node_a, target=node_b)
    )


def all_d_separation_sets(graph, node_a, node_b):
    """
    Find every conditioning set that d-separates two nodes.

    Candidate sets are drawn from all graph nodes other than ``node_a`` and
    ``node_b`` and are tried in order of increasing size, so smaller
    separating sets appear first in the result. All 2**k subsets of the k
    candidate nodes are tested, so this is only practical for small graphs.

    Args:
        graph (nx.DiGraph): The causal graph as a directed acyclic graph.
        node_a: First node.
        node_b: Second node.
    Returns:
        list[set]: Every subset Z of the remaining nodes for which
        ``is_d_separated(graph, node_a, node_b, Z)`` is True.
    """
    # Iterate in the graph's node order rather than over a set: the iteration
    # order of a set of str nodes changes between interpreter runs (hash
    # randomization), which made the returned list's order non-reproducible.
    candidates = [n for n in graph.nodes if n not in (node_a, node_b)]
    separating_sets = []
    for size in range(len(candidates) + 1):
        for subset in combinations(candidates, size):
            conditioning_set = set(subset)
            if is_d_separated(graph, node_a, node_b, conditioning_set):
                separating_sets.append(conditioning_set)
    return separating_sets


# Define the causal graph: A -> C <- B -> D -> E, plus C -> E.
edges = [
    ("A", "C"),
    ("B", "C"),
    ("B", "D"),
    ("C", "E"),
    ("D", "E"),
]
graph = nx.DiGraph(edges)

# Find all d-separation sets for A and E. The query nodes are bound once so
# the call and the printed message cannot disagree (the message previously
# referenced undefined names and raised NameError).
node_a, node_b = "A", "E"
d_separation_sets = all_d_separation_sets(graph, node_a, node_b)
print(f"D-separation sets for {node_a} and {node_b}: {d_separation_sets}")


# Recorded output of an earlier run of this cell:
# D-separation sets for X and Y: [{'B', 'C'}, {'D', 'C'}, {'B', 'D', 'C'}]
