##### Imports

In [60]:
import networkx as nx
import itertools
from scipy.special import comb
import math
from enum import Enum
import matplotlib.pyplot as plt
from typing import Any, List, Dict, Set
import time

#TODO: Mantener una consistencia entre el Camel Cases o el Snake Case, revisar que se usa en Python

#### Digraph

In [61]:
def isLeaf(node, dag : nx.DiGraph):
    return dag.out_degree(node) == 0

def isRoot(node, dag : nx.DiGraph):
    return dag.in_degree(node) == 0

#### Classify the nodes in the DAG

In [62]:

class NodeState(Enum):
    ANCESTOR = 1
    DESCENDANT = 2
    UNRELATED = 3
    FEATURE = 4


#Classify the nodes into the category descendants, ancestors and unrelated. 

def classifyNodes(dag: nx.DiGraph, x_i : Any, nodes_classification : Dict[Any, NodeState]):
    ancestors = nx.ancestors(dag, x_i)
    descendants = nx.descendants(dag, x_i)
    unrelated_roots = []
    for node in dag.nodes():
        if node in ancestors:
            nodes_classification[node] = NodeState.ANCESTOR
        elif node in descendants:
            nodes_classification[node] = NodeState.DESCENDANT
        elif node == x_i:
            nodes_classification[node] = NodeState.FEATURE
        else:
            nodes_classification[node] = NodeState.UNRELATED
            parents =  list(dag.predecessors(node))
            parentIsAncestor = (parents[0] in ancestors) if len(parents) != 0 else False
            if isRoot(node, dag) or parentIsAncestor:
                unrelated_roots.append(node)


    return unrelated_roots


#### Equivalence Class definition

In [63]:
class NodePosition:

    def __init__(self, node_name, appears_after_xi : bool) -> None:
        self.node_name = node_name
        self.appears_after_xi = appears_after_xi
        self.relative_position = 'After' if appears_after_xi else 'Before'

    def isBefore(self):
        return not self.appears_after_xi
    
    def nodeName(self):
        return self.node_name

    def __str__(self):
        return f"({self.node_name}, {self.relative_position})"

class EquivalenceClass:

    def __init__(self, unrelated_node_position : Set[NodePosition], left_topo=1, right_topo=1):
        self.position = unrelated_node_position
        self.left_topo = left_topo
        self.right_topo = right_topo
    
    def nodes_before(self): #The nodes before x_i
        positions = filter(lambda node_pos : node_pos.isBefore(), self.position)
        return list(map(lambda p : p.nodeName(),  positions))

    def num_nodes_before(self): 
        return len(self.nodes_before())
    
    def num_nodes_after(self): #The nodes after x_i
        return len(self.position) - self.num_nodes_before()
    
    def classSize(self): #Number of topological orders
        return self.left_topo * self.right_topo
    
    def __str__(self):
        positions = [str(x) for x in self.position]
        return f"Equivalence Class (Nodes={positions}, Size={self.classSize()})"
    
    def addNodes(self, nodes : Set[NodePosition]):
        self.position = self.position.union(nodes)

    def addLeftTopo(self, leftTopos : int):
        self.left_topo *= leftTopos

### Auxiliary Functions

#### Equivalence Classes manipulation

In [64]:
def multinomial_coefficient(args) -> int:
    n = sum(args)
    coeff = 1
    for k in args:
        coeff *= comb(n, k, exact=True)
        n -= k
    return int(coeff)

def unionOf(equivalence_classes : List[EquivalenceClass]) -> EquivalenceClass:
    n = len(equivalence_classes)
    positions = set()
    nodes_before = [0]*n
    nodes_after = [0]*n
    left_topos = [0]*n
    right_topos = [0]*n
    for i,eq_class in enumerate(equivalence_classes):
        nodes_before[i] = eq_class.num_nodes_before()
        nodes_after[i] = eq_class.num_nodes_after()
        left_topos[i] = eq_class.left_topo
        right_topos[i] = eq_class.right_topo
        positions = positions.union(eq_class.position)

    left_size = multinomial_coefficient(nodes_before) * math.prod(left_topos)
    right_size = multinomial_coefficient(nodes_after) * math.prod(right_topos)
    return EquivalenceClass(positions, left_size, right_size)


def lastUnionOf(unr_classes : List[List[EquivalenceClass]], ancestors : List[Any], descendants : List[Any], descendantsTopoSorts : int) -> List[EquivalenceClass]:
    classes_combinations = list(itertools.product(*unr_classes)) #Generate al the possible combinations for each eqClass of each child with the eqClass of the other children. 
    
    descendants_position = set([NodePosition(des, True) for des in descendants])
    descendants_eqClass = EquivalenceClass(descendants_position,1, descendantsTopoSorts) 
    classes = []
    # All the descendants appear after the feature node, because all of them appear before it then it has 1 rigth_topo (the empty one). 

    #TODO: Try to find a better way to do it and not just to if it
    if (len(ancestors) == 0):
        classes = list(map(lambda mix : unionOf(list(mix)), classes_combinations))
        
        if len(descendants) != 0:
            classes = [unionOf([descendants_eqClass, mix]) for mix in classes]
        #TODO : Make it more efficient, so that I don't need to traverse the list again

    elif (len(descendants) == 0):
        for unr_class in classes_combinations:
            ascendantsCombinationsWithUnrelated = 1
            ascendants_position = set([NodePosition(des, False) for des in descendants])

            eqClass = unionOf(list(unr_class))

            eqClass.addNodes(ascendants_position)
            eqClass.addLeftTopo(ascendantsCombinationsWithUnrelated)
            classes.append(eqClass)
    else:
        for unr_class in classes_combinations:
            ascendantsCombinationsWithUnrelated = 1
            ascendants_position = set([NodePosition(des, False) for des in descendants])

            eqClass = unionOf(list(unr_class))

            eqClass.addNodes(ascendants_position)
            eqClass.addLeftTopo(ascendantsCombinationsWithUnrelated)
            classes.append(eqClass)
    
    return classes


#### Topological sorts

In [65]:
#Returns a hash that is the binary number which has 0 or 1 in the i-th position if the i-th unrelated node is before or after x_i

class TopoSortHasher:
    def __init__(self, nodes_classification: Dict[Any, NodeState]):
        self._unrelated_nodes_ids = self._get_unrelated_nodes(nodes_classification)

    def _get_unrelated_nodes(self, nodes_classification: Dict[Any, NodeState]):
        unrelated_nodes = list(filter(lambda node: nodes_classification[node] == NodeState.UNRELATED, nodes_classification.keys()))
        self._unrelated_nodes_ids = {node: i for i, node in enumerate(unrelated_nodes)}
        return self._unrelated_nodes_ids

    def hashTopoSort(self, topoSort: List[Any], x_i: Any) -> int:
        unrelated_nodes = self._unrelated_nodes_ids
        hash_val = 0
        for node in topoSort:
            if node == x_i:
                break
            if node in unrelated_nodes:
                hash_val += 2 ** unrelated_nodes[node]
        return hash_val


In [66]:
#Returns the size of the tree and the number of topological sorts

def sizeAndNumberOfTopoSorts(node, dag : nx.DiGraph):
    if isLeaf(node, dag):
        return 1,1
    
    childrenSubtreeSizes = []
    children_topoSorts = []

    
    for child in dag.successors(node):
        child_size, child_topos =  sizeAndNumberOfTopoSorts(child,dag)
        children_topoSorts.append(child_topos)
        childrenSubtreeSizes.append(child_size)
        

    topos = multinomial_coefficient(childrenSubtreeSizes) * math.prod(children_topoSorts)
    return sum(childrenSubtreeSizes)+1, topos

    
def topoSortsFrom(node, dag : nx.DiGraph):
   _, topos = sizeAndNumberOfTopoSorts(node, dag)
   return topos

### Equivalence Classes Formulas

#### Recursive Equivalence Class formula

In [67]:
def unrelatedEquivalenceClassesSizes(node, dag : nx.DiGraph) -> List[EquivalenceClass]:
    if isLeaf(node, dag):
        classes = []
        for x in [True, False]:
           classes.append(EquivalenceClass({NodePosition(node, x)}))
        return classes
    
    children_classes = list(map(lambda child : unrelatedEquivalenceClassesSizes(child,dag), dag.successors(node)))



    classes_combinations = list(itertools.product(*children_classes)) #Generate al the possible combinations for each eqClass of each child with the eqClass of the other children. 
    
    # All the equivalence classes will have this node in the left part. 

    classes = list(map(lambda mix : uniteChildrenAndAddParent(node, list(mix)), classes_combinations))

    allRight = unionOf(list(classes_combinations[0]))
    if allRight.num_nodes_before() == 0:
        allRight.addNodes({NodePosition(node, True)})
        classes.append(allRight)
        # If the parent is to the right, then all of the children should be after the feature node

        # TODO: I think that this kind of union (all in the right part) will always be the first element of classes_combinations, so we can just take the 
        # first element of classes_combination to do this, I need to review this.
    return classes

def uniteChildrenAndAddParent(node, equivalence_classes : List[EquivalenceClass]) -> EquivalenceClass:
        union = unionOf(equivalence_classes)
        union.addNodes({NodePosition(node, False)})

        return union

def recursiveEquivalenceClassesSizes(dag : nx.DiGraph, unr_roots : List[Any], hasher : TopoSortHasher, feature_node):
    unr_classes = list(map(lambda child : unrelatedEquivalenceClassesSizes(child,dag), unr_roots))
    ancestors = nx.ancestors(dag, feature_node)
    descendants = nx.descendants(dag, feature_node)

    descendantsTopoSorts = topoSortsFrom(feature_node, dag)
    recursiveClassesSizes = lastUnionOf(unr_classes, ancestors, descendants, descendantsTopoSorts)

    recursiveClassesSizes = hashEquivClasses(recursiveClassesSizes, hasher)
    return recursiveClassesSizes



def hashEquivClasses(equivClasses : List[EquivalenceClass], hasher : TopoSortHasher):
    hashedClasses = {}
    for eqClass in equivClasses:
        topoSortForClass = eqClass.nodes_before()
        hash = hasher.hashTopoSort(topoSortForClass, None)
        hashedClasses[hash] = [topoSortForClass,  eqClass.classSize()]
    
    return hashedClasses
        

#### Naive Equivalence Classes

In [68]:
def naiveEquivalenceClassesSizes(all_topo_sorts : List[List[Any]], nodes_classification: Dict[Any, NodeState], x_i : Any, hasher : TopoSortHasher):
      
   result = {}
   for topoSort in all_topo_sorts:
      hash = hasher.hashTopoSort(topoSort, x_i)
      actual_value = result.get(hash, [topoSort, 0])
      result[hash] = [actual_value[0], actual_value[1] + 1]
      # It has a representative of each class and the number of topological orders that are in that class.

   return result

#TODO: Here I don't need the topological orders with the descendants of x_i, I can remove them and then multiply the number of topological orders of each class. 
# To do this I just need to calculate the "merging" of this possible topological orders as I do in the dynamic approach. 


## Examples

#### Testing functions

In [69]:
def assertEquivalenceClassesForNode(dag: nx.DiGraph, feature_node, all_topo_sorts: List[List[Any]], timing_dict: Dict[str, Dict[str, float]]):
    
    nodes_classification = {}
    unr_roots = classifyNodes(dag, feature_node, nodes_classification)
    hasher = TopoSortHasher(nodes_classification)

    # Naive approach
    start_time = time.time()
    naiveClassesSizes = naiveEquivalenceClassesSizes(all_topo_sorts, nodes_classification, feature_node, hasher)
    end_time = time.time()
    timing_dict[feature_node]['naiveEquivalenceClassesSizes'] = end_time - start_time

    # Recursive approach
    start_time = time.time()
    recursiveClassesSizes = recursiveEquivalenceClassesSizes(dag, unr_roots, hasher, feature_node)
    end_time = time.time()
    timing_dict[feature_node]['equivalenceClassesSizes'] = end_time - start_time

    # Assert that each equivalence class has the same number of elements.
    for eqClassHash in naiveClassesSizes.keys():
        clSize1 = naiveClassesSizes[eqClassHash][1]
        clTopo1 = naiveClassesSizes[eqClassHash][0]
        try: 
            clSize2 = recursiveClassesSizes[eqClassHash][1]
            clTopo2 = recursiveClassesSizes[eqClassHash][0]
        except KeyError:
            raise AssertionError(f"The equivalence class {eqClassHash} is not present in the recursive approach. \n Naive Approach: Topo {clTopo1}, Size {clSize1} \n Feature Node: {feature_node}")
        if (clSize1 != clSize2):
            raise AssertionError(f"The sizes of the equivalence classes are not equal. \n Naive Approach: Topo {clTopo1}, Size {clSize1} \n Recursive Approach: Topo {clTopo2}, Size {clSize2} \n Feature Node: {feature_node}")

def assertEquivClassesForDag(dag: nx.DiGraph) -> Dict[str, float]:
    timing_dict = {}
    
    # Measure time for all topological sorts
    start_time = time.time()
    all_topo_sorts = list(nx.all_topological_sorts(dag))
    end_time = time.time()
    timing_dict['all_topological_sorts'] = end_time - start_time
    
    for node in list(dag.nodes):
        timing_dict[node] = {}
        assertEquivalenceClassesForNode(dag, node, all_topo_sorts, timing_dict)
    
    return timing_dict

def drawGraph(dag : nx.DiGraph):
    pos = nx.spring_layout(dag)
    nx.draw(dag, pos, with_labels=True)
    plt.show()

### Empty graph

In [70]:
numNodes = 4

emptyGraph = nx.DiGraph()
nodes = [i for i in range(numNodes)]
emptyGraph.add_nodes_from(nodes)

#drawGraph(emptyGraph)
res = assertEquivClassesForDag(emptyGraph)

### Naive Bayes

In [71]:
def naiveBayesOf(numNodes : int):
    naive_bayes = nx.DiGraph()
    nodes = [i for i in range(numNodes)]
    naive_bayes.add_nodes_from(nodes)
    root = list(naive_bayes.nodes)[0]
    for node in nodes:
        if node != root:
            naive_bayes.add_edge(root, node)
    return naive_bayes



numNodes = 7
naiveBayes = naiveBayesOf(numNodes)

#drawGraph(naiveBayes)

res = assertEquivClassesForDag(naiveBayes)

### Naive Bayes with Path

In [72]:
numNodes = 3
lengthOfPath = 1

naiveBayesWithPath = naiveBayesOf(numNodes)

for node in range(numNodes,numNodes+lengthOfPath):
    naiveBayesWithPath.add_node(node)
    naiveBayesWithPath.add_edge(node-1, node)

#drawGraph(naiveBayesWithPath)
res = assertEquivClassesForDag(naiveBayesWithPath)

AssertionError: The sizes of the equivalence classes are not equal. 
 Naive Approach: Topo [0, 2, 3, 1], Size 2 
 Recursive Approach: Topo [3], Size 1 
 Feature Node: 2