##### Imports

In [8]:
import networkx as nx
import itertools
from scipy.special import comb
import math
from typing import Any, List

#### Equivalence Class definition

In [10]:
class nodePosition:
    """A DAG node tagged with the side of x_i on which it is placed."""

    def __init__(self, node_name, appears_after_xi : bool) -> None:
        # Label of the node (used when printing).
        self.node_name = node_name
        # True -> the node is placed after x_i; False -> before x_i.
        self.appears_after_xi = appears_after_xi

    def isBefore(self):
        """Return True when the node is placed before x_i."""
        return False if self.appears_after_xi else True

    def __str__(self):
        return f"({self.node_name}, {self.appears_after_xi})"

class equivalenceClass:
    """Group of topological orders that share one before/after-x_i placement.

    `position` holds objects exposing `isBefore()` (nodePosition instances).
    The number of orders in the class is the number of orderings of the part
    before x_i (`left_topo`) times that of the part after x_i (`right_topo`).
    """

    def __init__(self, unrelated_node_position, left_topo=1, right_topo=1):
        self.position = unrelated_node_position
        self.left_topo, self.right_topo = left_topo, right_topo

    def nodes_before(self):
        """Count the positions placed before x_i."""
        return sum(1 for pos in self.position if pos.isBefore())

    def nodes_after(self):
        """Count the positions placed after x_i."""
        before = self.nodes_before()
        return len(self.position) - before

    def classSize(self):
        """Number of topological orders represented by this class."""
        return self.left_topo * self.right_topo

    def __str__(self):
        positions = list(map(str, self.position))
        return f"Equivalence Class (Nodes={positions}, Size={self.classSize()})"

#### Auxiliary Functions

In [9]:
def isLeaf(node, dag : nx.DiGraph):
    """Return True when `node` has no outgoing edges in `dag`."""
    outgoing = dag.out_degree(node)
    return outgoing == 0

def isRoot(node, dag : nx.DiGraph):
    """Return True when `node` has no incoming edges in `dag`."""
    incoming = dag.in_degree(node)
    return incoming == 0

def multinomial_coefficient(args):
    """Return (sum(args))! / prod(k! for k in args) as an exact integer.

    This is the number of ways to interleave sequences of the given lengths
    while keeping each sequence's internal order. Built as a product of
    binomials over growing prefix sums: C(k1, k1) * C(k1+k2, k2) * ...
    An empty `args` gives 1.
    """
    total = 0
    result = 1
    for part in args:
        total += part
        result *= comb(total, part, exact=True)
    return result

def unionOf(equivalence_classes : List[equivalenceClass]):
    """Merge classes of independent sub-structures into one combined class.

    The before-x_i parts of the inputs are interleaved freely (multinomial
    coefficient over their sizes) and each keeps its own internal orderings
    (product of `left_topo`); the after-x_i parts are combined the same way.
    The resulting positions are the set union of all input positions.
    """
    befores = [eq.nodes_before() for eq in equivalence_classes]
    afters = [eq.nodes_after() for eq in equivalence_classes]
    merged_positions = set().union(*(eq.position for eq in equivalence_classes))

    left_size = multinomial_coefficient(befores) * math.prod(eq.left_topo for eq in equivalence_classes)
    right_size = multinomial_coefficient(afters) * math.prod(eq.right_topo for eq in equivalence_classes)
    return equivalenceClass(merged_positions, left_size, right_size)

#### Equivalence Class formula

In [16]:
def _subtreeClasses(node, dag : nx.DiGraph):
    """Return the equivalence classes of the subtree rooted at `node`, with `node` itself placed.

    Each class records, for `node` and every node reachable from it, whether it
    falls before or after x_i, plus how many orderings realise that split.
    """
    label = str(node)
    if isLeaf(node, dag):
        # A lone node is either after x_i or before it, with one ordering each.
        return [equivalenceClass({nodePosition(label, after)}) for after in (True, False)]

    children_classes = [_subtreeClasses(child, dag) for child in dag.successors(node)]
    classes = []
    all_after = None
    for mix in itertools.product(*children_classes):
        merged = unionOf(mix)
        if merged.nodes_before() == 0:
            # The single combination in which every descendant is after x_i.
            all_after = merged
        # `node` before x_i: its descendants may fall on either side. On the left
        # side `node` must come first (exactly one slot), and everything on the
        # right side already follows it, so the counts are unchanged.
        classes.append(equivalenceClass(merged.position | {nodePosition(label, False)},
                                        merged.left_topo, merged.right_topo))
    # `node` after x_i forces every descendant after x_i too; `node` must be the
    # first of them on the right side, so again the counts are unchanged.
    classes.append(equivalenceClass(all_after.position | {nodePosition(label, True)},
                                    all_after.left_topo, all_after.right_topo))
    return classes


def equivalenceClassesSizes(node, dag : nx.DiGraph):
    """Enumerate the equivalence classes of topological orders below `node`.

    Each returned equivalenceClass fixes, for every node reachable from `node`,
    whether it appears before or after x_i, and stores the number of orderings
    of the before-part (`left_topo`) and of the after-part (`right_topo`).

    `node` is treated as a (possibly virtual) root: when it has successors it is
    NOT placed in any class, only its descendants are. A leaf `node` is placed,
    yielding one "after" class and one "before" class.

    NOTE(review): the recursion follows every successor edge without
    memoisation, so it assumes the graph below `node` is a forest. A node
    reachable through two parents would be enumerated twice. It also assumes
    x_i is not reachable from `node`; confirm both for general DAG inputs.
    """
    if isLeaf(node, dag):
        return _subtreeClasses(node, dag)
    # Generate all the possible combinations of one class per child subtree.
    children_classes = [_subtreeClasses(child, dag) for child in dag.successors(node)]
    return [unionOf(mix) for mix in itertools.product(*children_classes)]

    



# Connect a virtual root to all roots of the unrelated nodes, so the recursion starts from one node.
def setUp(dag : nx.DiGraph, unrelated_roots : List[Any], new_root='r0'):
    """Add a virtual root `new_root` with an edge to every node in `unrelated_roots`.

    Mutates `dag` in place; use tearDown to undo it. The virtual root is added
    even when there is a single unrelated root. equivalenceClassesSizes does not
    place a top-level node that has successors, so a real root returned here
    would be left out of the count.

    Returns the name of the virtual root.

    Raises:
        ValueError: if `unrelated_roots` is empty, or if `new_root` already
            exists in `dag` (e.g. left over from an earlier run). Reusing it
            could create a self-loop.
    """
    if not unrelated_roots:
        raise ValueError("setUp needs at least one unrelated root")
    if dag.has_node(new_root):
        raise ValueError(f"virtual root {new_root!r} already exists in the DAG")
    for root in unrelated_roots:
        dag.add_edge(new_root, root)
    return new_root

def tearDown(dag : nx.DiGraph, new_root='r0'):
    """Remove the virtual root `new_root` and all of its edges from `dag`, if present.

    The root name is an explicit parameter rather than a global. The previous
    version read a global that callers could rebind to a real node, which then
    had its out-edges stripped. The node itself is also removed, so no isolated
    leftover remains in the graph.

    NOTE(review): a genuine node named like `new_root` would be removed as well;
    pick a name that cannot collide with real node names.
    """
    if dag.has_node(new_root):
        dag.remove_node(new_root)


#### First example

In [23]:
Naive_Bayes = nx.DiGraph()
Naive_Bayes.add_nodes_from([1, 2, 3, 4, 5, 6])
# Uncomment to try a DAG with edges (node 1 is the parent of 2, 3 and 4):
# Naive_Bayes.add_edges_from([
#     (1, 2),
#     (1, 3),
#     (1, 4)
# ])

# Ground truth: enumerate every topological order explicitly, before setUp adds the virtual root.
all_topo_sorts = list(nx.all_topological_sorts(Naive_Bayes))

x_i = 5
# Name of the virtual root used by setUp/tearDown.
new_root = 'r0'
unr_roots = [node for node in Naive_Bayes.nodes if isRoot(node, Naive_Bayes) and node != x_i]
# Keep the start node in its own variable. Rebinding `new_root` to it could make
# tearDown strip the out-edges of a real node when setUp returns an existing root.
start_node = setUp(Naive_Bayes, unr_roots)
try:
    classesSize = equivalenceClassesSizes(start_node, Naive_Bayes)
finally:
    # Always restore the graph, even if counting fails.
    tearDown(Naive_Bayes)

number_of_topos = sum(eq_class.classSize() for eq_class in classesSize)

print(number_of_topos == len(all_topo_sorts))


True
