In [56]:
from collections import defaultdict
from pprint import pprint

class Graph:
    """AND-OR graph solved with the AO* algorithm.

    ``graph`` maps a node to a list of alternatives (OR-branches). Each
    alternative is a list of ``(child, edge_cost)`` pairs that must all be
    solved together (an AND-arc). A node with no entry in ``graph`` has no
    successors: AO* revises its estimate to 0 and marks it solved.

    ``hueristic`` maps every node to its initial cost estimate. The name keeps
    the original spelling for backward compatibility. The dict is copied, so
    the revised estimates live in ``self.hueristic`` and the caller's dict is
    never modified.

    Status codes kept in ``self.status``:
      * ``-1``: the node is solved;
      * ``>= 0``: the node is open. The value is the number of children in its
        currently preferred alternative (0 for a node not yet evaluated).
    """

    def __init__(self, graph: dict[str, list], hueristic: dict[str, int], start_node: str):
        self.graph = graph
        # Work on a copy: the search overwrites estimates via set_hueristic, and
        # mutating the caller's dict would make a notebook cell that re-runs the
        # search depend on the previous run's revised values.
        self.hueristic = dict(hueristic)
        self.start_node = start_node
        self.parent = defaultdict(str)          # child -> node it was reached from
        self.status = defaultdict(int)          # node -> status code (see class docstring)
        self.solution_graph = defaultdict(str)  # solved node -> list of chosen children

    def apply_ao_star(self):
        """Run AO* from ``start_node``; the result is left in ``solution_graph``."""
        self.ao_star(self.start_node, backtracking=False)

    def neighbors(self, node: str):
        """Return the alternatives leaving ``node``, or ``[]`` if it has none."""
        return self.graph.get(node, [])

    def get_status(self, node: str):
        """Return the status code of ``node`` (0 if never set)."""
        return self.status[node]

    def set_status(self, node: str, val: int):
        """Set the status code of ``node``."""
        self.status[node] = val

    def get_hueristic(self, node: str):
        """Return the current (possibly revised) estimate for ``node``.

        Raises:
            KeyError: if ``node`` has no estimate.
        """
        return self.hueristic[node]

    def set_hueristic(self, node: str, val: int):
        """Overwrite the estimate for ``node``."""
        self.hueristic[node] = val

    def print_solution(self):
        """Print the solution graph found by the last search."""
        print("FOR GRAPH SOLUTION, TRAVERSE THE GRAPH FROM THE START NODE:", self.start_node)
        print("------------------------------------------------------------")
        pprint(self.solution_graph)
        print("------------------------------------------------------------")

    def minimum_cost_child_nodes(self, node: str):
        """Select the cheapest alternative (AND-arc) leaving ``node``.

        The cost of an alternative is the sum, over its children, of the
        child's current estimate plus the edge cost. On a tie, the alternative
        listed first wins.

        Returns:
            ``(cost, children)`` for the cheapest alternative, or ``(0, [])``
            when ``node`` has no successors.

        Raises:
            KeyError: if a child has no heuristic estimate.
        """
        best_cost = 0
        best_children = []
        found = False

        for alternative in self.neighbors(node):
            cost = 0
            children = []
            for child, weight in alternative:
                cost = cost + self.get_hueristic(child) + weight
                children.append(child)

            # Strict '<' keeps the first alternative on ties.
            if not found or cost < best_cost:
                best_cost = cost
                best_children = children
                found = True

        return best_cost, best_children

    def ao_star(self, node: str, backtracking: bool):
        """Revise ``node``, propagate the change upward and (optionally) expand.

        For an unsolved node (status >= 0):
          1. Its estimate becomes the cost of its cheapest alternative.
          2. It is marked solved (status -1) and recorded in ``solution_graph``
             once every child of that alternative is solved.
          3. Its ancestors are revised recursively up to ``start_node``.
          4. If ``backtracking`` is False, each chosen child is reset to
             status 0 and expanded in turn.

        Already-solved nodes are left untouched.
        """
        if self.get_status(node) >= 0:
            minimum_cost, child_nodes = self.minimum_cost_child_nodes(node)
            self.set_hueristic(node, minimum_cost)
            self.set_status(node, len(child_nodes))
            solved = True

            for child in child_nodes:
                self.parent[child] = node
                if self.get_status(child) != -1:
                    solved = False

            if solved:
                self.set_status(node, -1)
                self.solution_graph[node] = child_nodes

            # Propagate the revised estimate upward, but only along a recorded
            # parent. Indexing the defaultdict directly would yield '' for a node
            # reached without one and recurse into a bogus '' node.
            parent = self.parent.get(node)
            if node != self.start_node and parent is not None:
                self.ao_star(parent, True)

            if not backtracking:
                for child in child_nodes:
                    self.set_status(child, 0)
                    self.ao_star(child, False)

In [57]:
# AND-OR graph: each node maps to its alternatives (OR); every alternative is a
# list of (child, edge_cost) pairs that must all be solved together (AND).
# Nodes without an entry (E, F, H, I, J) have no successors.
GRAPH = dict(
    A=[[('B', 1), ('C', 1)],   # A -> (B AND C)
       [('D', 1)]],            #   OR D
    B=[[('G', 1)],             # B -> G
       [('H', 1)]],            #   OR H
    C=[[('J', 1)]],
    D=[[('E', 1), ('F', 1)]],
    G=[[('I', 1)]],
)

In [58]:
# Initial heuristic cost estimates, keyed by node name (one per node in GRAPH).
HUERISTIC = dict(
    A=1, B=6, C=2, D=12, E=2,
    F=1, G=5, H=7, I=7, J=1,
)

In [59]:
# Pass a copy of HUERISTIC so the AO* search can never overwrite the
# module-level estimates; this keeps the cell idempotent when it is re-run
# without re-running the cell that defines HUERISTIC.
G = Graph(GRAPH, dict(HUERISTIC), 'A')
G.apply_ao_star()
G.print_solution()

FOR GRAPH SOLUTION, TRAVERSE THE GRAPH FROM THE START NODE: A
------------------------------------------------------------
defaultdict(<class 'str'>,
            {'A': ['B', 'C'],
             'B': ['G'],
             'C': ['J'],
             'G': ['I'],
             'I': [],
             'J': []})
------------------------------------------------------------
