##### Imports

In [1]:
import networkx as nx
import itertools
from scipy.special import comb
import math
from enum import Enum
import matplotlib.pyplot as plt
from typing import Any, List, Dict
import time

#### Classify the nodes in the DAG

In [2]:

class NodeState(Enum):
    """Role of a DAG node relative to the feature node x_i (assigned by classifyNodes)."""

    ANCESTOR = 1    # the node is in nx.ancestors(dag, x_i)
    DESCENDANT = 2  # the node is in nx.descendants(dag, x_i)
    UNRELATED = 3   # neither an ancestor nor a descendant, and not x_i itself
    FEATURE = 4     # the node is x_i itself


# Classify every node as an ancestor of x_i, a descendant of x_i, x_i itself (FEATURE) or unrelated.

def classifyNodes(dag: nx.DiGraph, x_i : Any, nodes_classification : Dict[Any, NodeState]):
    """Label every node of `dag` relative to the feature node `x_i`.

    Args:
        dag: the graph whose nodes are classified.
        x_i: the feature node.
        nodes_classification: filled in place with node -> NodeState for every
            node of `dag`.

    Returns:
        The UNRELATED nodes that have in-degree 0 in `dag`, in `dag.nodes()` order.

    NOTE(review): only nodes with no parent in the whole DAG are returned, so an
    unrelated node whose parent is an ancestor of x_i is never returned — confirm
    this is intended.
    """
    above = nx.ancestors(dag, x_i)
    below = nx.descendants(dag, x_i)
    roots_outside = []
    for vertex in dag.nodes():
        # x_i is never in its own ancestor/descendant sets, so testing it first is equivalent.
        if vertex == x_i:
            state = NodeState.FEATURE
        elif vertex in above:
            state = NodeState.ANCESTOR
        elif vertex in below:
            state = NodeState.DESCENDANT
        else:
            state = NodeState.UNRELATED
        nodes_classification[vertex] = state
        if state is NodeState.UNRELATED and isRoot(vertex, dag):
            roots_outside.append(vertex)
    return roots_outside


#### Equivalence Class definition

In [3]:
class nodePosition:
    """Placement of one node relative to the feature node x_i."""

    def __init__(self, node_name, appears_after_xi : bool) -> None:
        # Identifier of the node in the DAG.
        self.node_name = node_name
        # True when the node is placed after x_i, False when it is placed before it.
        self.appears_after_xi = appears_after_xi

    def isBefore(self):
        """Return True when the node is placed before x_i."""
        after = self.appears_after_xi
        return not after

    def nodeName(self):
        """Return the node's identifier."""
        name = self.node_name
        return name

    def __str__(self):
        return "({}, {})".format(self.node_name, self.appears_after_xi)

class equivalenceClass:
    """Topological orders that agree on which nodes are placed before x_i.

    `position` holds objects exposing isBefore()/nodeName() (nodePosition
    instances in this file). The class size is left_topo (number of orders of
    the part before x_i) times right_topo (number of orders of the part after x_i).
    """

    def __init__(self, unrelated_node_position, left_topo=1, right_topo=1):
        self.position = unrelated_node_position  # collection of node placements
        self.left_topo = left_topo    # number of orders of the nodes placed before x_i
        self.right_topo = right_topo  # number of orders of the nodes placed after x_i

    def nodes_before(self):
        """Return the names of the nodes placed before x_i."""
        return [pos.nodeName() for pos in self.position if pos.isBefore()]

    def num_nodes_before(self):
        """Return how many nodes are placed before x_i."""
        return len(self.nodes_before())

    def num_nodes_after(self):
        """Return how many nodes are placed after x_i."""
        total = len(self.position)
        return total - self.num_nodes_before()

    def classSize(self):
        """Return the number of topological orders in this class."""
        left, right = self.left_topo, self.right_topo
        return left * right

    def __str__(self):
        rendered = list(map(str, self.position))
        return f"Equivalence Class (Nodes={rendered}, Size={self.classSize()})"

#### Auxiliary Functions

In [4]:
def isLeaf(node, dag : nx.DiGraph):
    """Return True when `node` has no outgoing edges in `dag`."""
    outgoing = dag.out_degree(node)
    return outgoing == 0

def isRoot(node, dag : nx.DiGraph):
    """Return True when `node` has no incoming edges in `dag`."""
    incoming = dag.in_degree(node)
    return incoming == 0

def multinomial_coefficient(args):
    """Return (k1 + k2 + ...)! / (k1! * k2! * ...) for the counts in `args`.

    This is the number of ways to interleave sequences of lengths k1, k2, ...
    while keeping each sequence's internal order. Computed as a product of
    binomials C(remaining, k), shrinking `remaining` after each count.
    """
    remaining = sum(args)
    factors = []
    for count in args:
        factors.append(comb(remaining, count, exact=True))
        remaining -= count
    return math.prod(factors)

def unionOf(equivalence_classes : List[equivalenceClass]):
    """Merge independent equivalence classes into one combined class.

    The parts placed before x_i by the different classes can be interleaved
    freely. The merged left size is therefore the multinomial coefficient of
    the per-class "before" counts times the product of the per-class left
    sizes. The right size is computed the same way from the "after" parts.
    The positions of all classes are collected into one set.
    """
    before_counts, after_counts = [], []
    left_sizes, right_sizes = [], []
    merged_positions = set()
    for cls in equivalence_classes:
        before_counts.append(cls.num_nodes_before())
        after_counts.append(cls.num_nodes_after())
        left_sizes.append(cls.left_topo)
        right_sizes.append(cls.right_topo)
        merged_positions.update(cls.position)

    merged_left = multinomial_coefficient(before_counts) * math.prod(left_sizes)
    merged_right = multinomial_coefficient(after_counts) * math.prod(right_sizes)
    return equivalenceClass(merged_positions, merged_left, merged_right)

In [5]:
# Returns a hash: a binary number whose i-th bit is 1 if the i-th unrelated node appears before x_i, and 0 otherwise.

class TopoSortHasher:
    """Map a (partial) topological order to an integer identifying its equivalence class.

    Every node classified as UNRELATED gets a bit index, in the iteration order
    of `nodes_classification`. The hash of an order has bit i set exactly when
    the unrelated node with index i appears before x_i.
    """

    def __init__(self, nodes_classification: Dict[Any, NodeState]):
        self._unrelated_nodes_ids = self._get_unrelated_nodes(nodes_classification)

    def _get_unrelated_nodes(self, nodes_classification: Dict[Any, NodeState]):
        """Give the unrelated nodes consecutive bit indices; store and return the mapping."""
        ids = {}
        for node, state in nodes_classification.items():
            if state == NodeState.UNRELATED:
                ids[node] = len(ids)
        self._unrelated_nodes_ids = ids
        return ids

    def hashTopoSort(self, topoSort: List[Any], x_i: Any) -> int:
        """Return the bitmask of unrelated nodes that occur before the first `x_i` in `topoSort`.

        Nodes that are not unrelated are ignored. If `x_i` does not occur,
        every unrelated node in `topoSort` contributes its bit.
        """
        ids = self._unrelated_nodes_ids
        prefix = itertools.takewhile(lambda node: not node == x_i, topoSort)
        return sum(1 << ids[node] for node in prefix if node in ids)


### Recursive Equivalence Class formula

In [6]:


def equivalenceClassesSizes(node, dag : nx.DiGraph, added_node):
    """Enumerate the equivalence classes of the nodes reachable from `node`.

    Each returned equivalenceClass records, for every node reachable from
    `node` (except `added_node`), whether it is placed before or after x_i.
    It also records the number of orders of the "before" part (left_topo) and
    of the "after" part (right_topo).

    Args:
        node: current node of the traversal.
        dag: the graph, with the virtual root already attached.
        added_node: the virtual root. It is not a real node, so it gets no
            position and is never placed after x_i.

    Returns:
        List of equivalenceClass, one per admissible before/after split of the
        nodes reachable from `node`.

    Assumes the part of `dag` reachable from `node` is a tree. A node reachable
    through two different parents is enumerated once per path.
    NOTE(review): successors are not filtered by NodeState, so descendants of
    x_i reachable from an unrelated node are treated like unrelated nodes here —
    confirm this is intended.
    """
    children_classes = [
        equivalenceClassesSizes(child, dag, added_node)
        for child in dag.successors(node)
    ]

    classes = []

    # Case 1: `node` is placed before x_i (the virtual root always is).
    # Combine one class per child; itertools.product() of no iterables yields a
    # single empty combination, so leaves need no special case.
    # `node` must precede all of its descendants, so it is fixed as the first
    # element of the left part. It is NOT interleaved with the children's parts,
    # and therefore does not change the counts.
    for mix in itertools.product(*children_classes):
        merged = unionOf(mix)
        if node == added_node:
            classes.append(merged)
        else:
            classes.append(equivalenceClass(
                merged.position | {nodePosition(node, False)},
                merged.left_topo,
                merged.right_topo,
            ))

    # Case 2: `node` is placed after x_i, which forces its whole subtree after x_i.
    # Each child contributes its all-after class: the only class of that child
    # with no node before x_i (every Case-1 class has the child itself before x_i).
    if node != added_node:
        all_after = [
            next(c for c in child_classes if c.num_nodes_before() == 0)
            for child_classes in children_classes
        ]
        merged = unionOf(all_after)
        classes.append(equivalenceClass(
            merged.position | {nodePosition(node, True)},
            merged.left_topo,
            merged.right_topo,
        ))

    return classes

# The unrelated nodes may have several roots, so we add one new root and connect it to every unrelated root; the recursion then starts from this single root.

def setNewRoot(dag : nx.DiGraph, unrelated_roots : List[Any], root : Any):
    """Attach a virtual node `root` with an edge to every node in `unrelated_roots`.

    `root` is added explicitly, so it exists even when `unrelated_roots` is
    empty. Previously the node was only created implicitly by add_edge, so with
    no unrelated roots the later traversal from `root` (dag.successors) and
    removeNewRoot (dag.remove_node) failed on a missing node.
    Mutates `dag` in place.
    """
    dag.add_node(root)
    dag.add_edges_from((root, ur) for ur in unrelated_roots)


def removeNewRoot(dag : nx.DiGraph, root : Any):
    """Remove the virtual node `root` (and its edges) that setNewRoot attached.

    Mutates `dag` in place. networkx raises NetworkXError if `root` is not in `dag`.
    """
    dag.remove_node(root)

    

def hashEquivClasses(equivClasses : List[equivalenceClass], hasher : TopoSortHasher):
    """Index equivalence classes by the hash of the nodes they place before x_i.

    Returns:
        dict mapping hash -> [nodes_before list, class size]. If two classes
        share a hash, the later one wins.

    The list of nodes before x_i is hashed with x_i=None, so the whole list is
    used (assuming no node is named None).
    """
    befores = ((cls.nodes_before(), cls) for cls in equivClasses)
    return {
        hasher.hashTopoSort(before, None): [before, cls.classSize()]
        for before, cls in befores
    }
        
        
        

### Naive Equivalence Classes

In [7]:
def naiveEquivalenceClassesSizes(all_topo_sorts : List[List[Any]], nodes_classification: Dict[Any, NodeState], x_i : Any, hasher : TopoSortHasher):
    """Count the topological orders in each equivalence class by brute force.

    Every order in `all_topo_sorts` is hashed with `hasher` relative to `x_i`.
    `nodes_classification` is accepted but not used here.

    Returns:
        dict mapping class hash -> [representative, count]. The representative
        is the first order seen with that hash; count is the number of orders
        with that hash.
    """
    buckets = {}
    for order in all_topo_sorts:
        key = hasher.hashTopoSort(order, x_i)
        representative, seen = buckets.get(key, [order, 0])
        buckets[key] = [representative, seen + 1]
    return buckets

#TODO: Here I don't need the topological orders that contain the descendants of x_i; I can remove them and then multiply the number of topological orders of each class.
# To do this I just need to calculate the "merging" of these possible topological orders, as I do in the dynamic approach.


## Examples

#### Testing functions

In [8]:
def assertEquivalenceClassesForNode(dag: nx.DiGraph, feature_node, all_topo_sorts: List[List[Any]], timing_dict: Dict[Any, Dict[str, float]]):
    """Check that the naive and the recursive equivalence-class sizes agree for one feature node.

    Args:
        dag: graph under test. A virtual root is attached temporarily and is
            always removed again, even if the recursion raises.
        feature_node: the node playing the role of x_i.
        all_topo_sorts: every topological order of `dag`, precomputed by the caller.
        timing_dict: receives the elapsed time of both approaches under
            `timing_dict[feature_node]` (the inner dict is created if missing).

    Raises:
        AssertionError: if the two approaches find different classes or sizes.
    """
    nodes_classification = {}
    unr_roots = classifyNodes(dag, feature_node, nodes_classification)
    hasher = TopoSortHasher(nodes_classification)
    node_timings = timing_dict.setdefault(feature_node, {})

    # Naive approach: bucket every full topological order by its class hash.
    start_time = time.perf_counter()
    naiveClassesSizes = naiveEquivalenceClassesSizes(all_topo_sorts, nodes_classification, feature_node, hasher)
    node_timings['naiveEquivalenceClassesSizes'] = time.perf_counter() - start_time

    # Recursive approach: hang the unrelated roots under one temporary virtual root.
    # Choose a name that is not already a node, so removing it can never delete a real node.
    new_root = 'r0'
    while new_root in dag:
        new_root = '_' + new_root
    start_time = time.perf_counter()
    setNewRoot(dag, unr_roots, new_root)
    try:
        equivClassesSizes = equivalenceClassesSizes(new_root, dag, new_root)
    finally:
        # Restore the caller's graph even if the recursion fails; the membership
        # guard covers the case where setNewRoot did not create the node.
        if new_root in dag:
            removeNewRoot(dag, new_root)
    recursiveClassesSizes = hashEquivClasses(equivClassesSizes, hasher)
    node_timings['equivalenceClassesSizes'] = time.perf_counter() - start_time

    # Both approaches must find the same classes, and each class must have the same size.
    missing = naiveClassesSizes.keys() - recursiveClassesSizes.keys()
    extra = recursiveClassesSizes.keys() - naiveClassesSizes.keys()
    assert not missing and not extra, (
        f"feature {feature_node!r}: classes missing from the recursive result {sorted(missing)}, "
        f"unexpected recursive classes {sorted(extra)}"
    )
    for eqClassHash, (_, naive_size) in naiveClassesSizes.items():
        recursive_size = recursiveClassesSizes[eqClassHash][1]
        assert naive_size == recursive_size, (
            f"feature {feature_node!r}, class {eqClassHash}: "
            f"naive size {naive_size} != recursive size {recursive_size}"
        )

def assertEquivClassesForDag(dag: nx.DiGraph) -> Dict[Any, Any]:
    """Run assertEquivalenceClassesForNode for every node of `dag` and collect timings.

    Returns:
        dict with key 'all_topological_sorts' -> seconds spent enumerating all
        topological orders, plus one entry per node -> dict of per-approach
        timings (filled by assertEquivalenceClassesForNode).

    Raises:
        AssertionError: propagated from the per-node check on the first mismatch.
    """
    timing_dict = {}

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    all_topo_sorts = list(nx.all_topological_sorts(dag))
    timing_dict['all_topological_sorts'] = time.perf_counter() - start_time

    # Snapshot the node list: the per-node check temporarily adds and removes a
    # virtual root, which would invalidate a live view of dag.nodes.
    for node in list(dag.nodes):
        timing_dict[node] = {}
        assertEquivalenceClassesForNode(dag, node, all_topo_sorts, timing_dict)

    return timing_dict

### Empty graph example

In [None]:
# Nine isolated nodes: for each feature node, every other node is an unrelated root.
dag = nx.DiGraph()
nodes = list(range(9))
dag.add_nodes_from(nodes)

assertEquivClassesForDag(dag)

### Naive Bayes Example

In [11]:
# Star graph: the first node is the single parent of every other node.
naive_bayes = nx.DiGraph()
nodes = list(range(3))
naive_bayes.add_nodes_from(nodes)
root = list(naive_bayes.nodes)[0]
for node in nodes:
    if node == root:
        continue
    naive_bayes.add_edge(root, node)

# NOTE(review): presumably disabled because the recursive count does not yet model
# ancestors/descendants of x_i. For example, for x_i = 0 there are no unrelated nodes,
# yet the naive approach counts 2 orders (0,1,2 and 0,2,1) in the single class.
#assertEquivClassesForDag(naive_bayes)