# In [None]:
import networkx as nx
import matplotlib.pyplot as plt
import time

def is_terminal(state):
    """Return True when no heap in *state* holds anything, i.e. no move is left."""
    # Equivalent to "every heap equals zero", phrased via De Morgan.
    return not any(heap != 0 for heap in state)

def get_moves(state):
    """Lazily yield every legal move from *state* as ``(state, successor)``.

    A move takes one or more objects from a single non-empty heap. Heaps are
    visited left to right; within a heap, the smallest removal comes first.
    The first element of each pair is the very *state* object passed in, and
    the successor is always a new tuple.
    """
    for index, size in enumerate(state):
        if size <= 0:
            continue  # empty heaps offer no moves
        # Leaving size-1, size-2, ..., 0 objects == removing 1, 2, ..., size.
        for remaining in range(size - 1, -1, -1):
            successor = list(state)
            successor[index] = remaining
            yield state, tuple(successor)

def minimax(state, maximizing=True, graph=None, parent=None, best_path=None):
    """Exhaustively search the Nim game tree below *state* with plain minimax.

    Every line of play is explored (no pruning, no caching), and each visited
    position is recorded in *graph* so the tree can be plotted afterwards.

    Args:
        state: Heap sizes, e.g. ``(3, 4, 5)``.
        maximizing: True if MAX is the player to move at *state*.
        graph: ``nx.DiGraph`` accumulating visited positions; created when
            None. Nodes are keyed by ``str(state)`` and carry a "label".
        parent: Position this call was reached from; when given, an edge
            ``str(parent) -> str(state)`` is added.
        best_path: Dict mapping a position to the child chosen there;
            created when None.

    Returns:
        ``(value, graph, best_path)`` where *value* is +1 if MAX wins from
        *state* under best play by both sides and -1 if MIN does.

    NOTE: positions are keyed by ``str(state)`` alone, so a position reached
    by different move orders (with either side to move) shares one node and
    one ``best_path`` entry; its label and choice reflect the latest visit.
    """
    if graph is None:
        graph = nx.DiGraph()
    if best_path is None:
        best_path = {}
    node = str(state)
    # Register the node explicitly: a root has no incoming edge, and without
    # this a terminal root would make graph.nodes[node] raise KeyError.
    graph.add_node(node)
    if parent is not None:
        graph.add_edge(str(parent), node)
    if is_terminal(state):
        # All heaps empty: the side to move cannot move and has lost.
        value = -1 if maximizing else 1
        graph.nodes[node]["label"] = f"{state}\nVal={value}"
        return value, graph, best_path
    if maximizing:
        value, best_child = float("-inf"), None
        for _, child in get_moves(state):
            child_val, graph, best_path = minimax(child, False, graph, state, best_path)
            # Strict '>' keeps the first child that reaches the best value.
            if child_val > value:
                value, best_child = child_val, child
        graph.nodes[node]["label"] = f"{state}\nMAX={value}"
    else:
        value, best_child = float("inf"), None
        for _, child in get_moves(state):
            child_val, graph, best_path = minimax(child, True, graph, state, best_path)
            if child_val < value:
                value, best_child = child_val, child
        graph.nodes[node]["label"] = f"{state}\nMIN={value}"
    if best_child is not None:
        best_path[state] = best_child
    return value, graph, best_path

def alpha_beta(state, alpha=float("-inf"), beta=float("inf"), maximizing=True, graph=None, parent=None, best_path=None):
    """Search the Nim game tree below *state* with fail-soft alpha-beta pruning.

    Same contract as a plain minimax search, but subtrees that cannot affect
    the result are skipped. Every position actually visited is recorded in
    *graph*.

    Args:
        state: Heap sizes, e.g. ``(3, 4, 5)``.
        alpha: Value MAX is already guaranteed elsewhere on the current path.
        beta: Value MIN is already guaranteed elsewhere on the current path.
        maximizing: True if MAX is the player to move at *state*.
        graph: ``nx.DiGraph`` accumulating visited positions; created when
            None. Nodes are keyed by ``str(state)`` and carry a "label".
        parent: Position this call was reached from; when given, an edge
            ``str(parent) -> str(state)`` is added.
        best_path: Dict mapping a position to its best child; created when
            None. Only positions whose value was determined exactly get an
            entry, so following it from a full-window root yields the
            principal variation.

    Returns:
        ``(value, graph, best_path)``. *value* is exact (+1: MAX wins,
        -1: MIN wins) only if it lies strictly between the incoming *alpha*
        and *beta*. If it is <= alpha, it is an upper bound on the true value.
        If it is >= beta, it is a lower bound.

    NOTE: positions are keyed by ``str(state)`` alone, so a position visited
    several times shares one node; its label reflects the latest visit.
    """
    if graph is None:
        graph = nx.DiGraph()
    if best_path is None:
        best_path = {}
    node = str(state)
    # Register the node explicitly: a root has no incoming edge, and without
    # this a terminal root would make graph.nodes[node] raise KeyError.
    graph.add_node(node)
    if parent is not None:
        graph.add_edge(str(parent), node)
    if is_terminal(state):
        # All heaps empty: the side to move cannot move and has lost.
        value = -1 if maximizing else 1
        graph.nodes[node]["label"] = f"{state}\nVal={value}"
        return value, graph, best_path
    # Keep the window this node was searched with; it decides below whether
    # the result is an exact value or only a bound.
    alpha_in, beta_in = alpha, beta
    if maximizing:
        value, best_child = float("-inf"), None
        for _, child in get_moves(state):
            child_val, graph, best_path = alpha_beta(child, alpha, beta, False, graph, state, best_path)
            if child_val > value:
                value, best_child = child_val, child
            alpha = max(alpha, value)
            if alpha >= beta:
                break  # beta cutoff: MIN already has a better option elsewhere
        tag = "MAX"
    else:
        value, best_child = float("inf"), None
        for _, child in get_moves(state):
            child_val, graph, best_path = alpha_beta(child, alpha, beta, True, graph, state, best_path)
            if child_val < value:
                value, best_child = child_val, child
            beta = min(beta, value)
            if alpha >= beta:
                break  # alpha cutoff: MAX already has a better option elsewhere
        tag = "MIN"
    exact = alpha_in < value < beta_in
    if exact:
        relation = "="
    elif value <= alpha_in:
        relation = "<="  # fail-low: the true value is at most this
    else:
        relation = ">="  # fail-high: the true value is at least this
    graph.nodes[node]["label"] = f"{state}\n{tag}{relation}{value}"
    # On a cutoff or fail-low, best_child came from a partial search or from
    # comparing bounds. Recording it would overwrite a correct entry left by
    # an earlier exact search of the same position and corrupt the path.
    if exact and best_child is not None:
        best_path[state] = best_child
    return value, graph, best_path

def draw_tree(graph, best_path, root, title):
    """Plot *graph* with each node's "label" text and the chosen line in red.

    The red edges come from following ``best_path`` from ``root`` until a
    position with no recorded choice is reached. Graph nodes are addressed
    as ``str(state)``.
    """
    layout = nx.spring_layout(graph, seed=42)  # fixed seed -> reproducible picture
    node_text = nx.get_node_attributes(graph, "label")

    def _chosen_edges(start):
        # Walk position -> chosen child until the chain runs out.
        here = start
        while here in best_path:
            there = best_path[here]
            yield str(here), str(there)
            here = there

    red_edges = list(_chosen_edges(root))

    plt.figure(figsize=(10, 7))
    base_style = dict(
        with_labels=False,
        node_size=2000,
        node_color="lightblue",
        edge_color="gray",
    )
    nx.draw(graph, layout, **base_style)
    nx.draw_networkx_labels(graph, layout, node_text, font_size=8)
    nx.draw_networkx_edges(
        graph,
        layout,
        edgelist=red_edges,
        edge_color="red",
        width=2,
    )
    plt.title(title)
    plt.show()

if __name__ == "__main__":
    initial_state = (3, 4, 5)

    # perf_counter is monotonic and high-resolution (unlike time.time()),
    # so it is the appropriate clock for measuring elapsed run time.
    t1 = time.perf_counter()
    val1, g1, path1 = minimax(initial_state, True)
    t_minimax = time.perf_counter() - t1

    t2 = time.perf_counter()
    val2, g2, path2 = alpha_beta(initial_state, float("-inf"), float("inf"), True)
    t_alphabeta = time.perf_counter() - t2

    # path2 may hold choices for many searched positions; follow it from the
    # root to recover the actual line of play.
    best_line = [initial_state]
    while best_line[-1] in path2:
        best_line.append(path2[best_line[-1]])
    best_line_text = " -> ".join(map(str, best_line))

    # NOTE: graph nodes are keyed by str(state), so "Nodes" counts distinct
    # positions visited, not the number of recursive search calls.
    # Print before plotting: plt.show() typically blocks until each window is
    # closed, which would otherwise delay the textual summary.
    print(f"Initial State       : {initial_state}")
    print(f"Minimax Result      : Value = {val1}, Nodes = {len(g1.nodes)}")
    print(f"Alpha-Beta Result   : Value = {val2}, Nodes = {len(g2.nodes)}")
    print(f"Minimax Time        : {t_minimax:.4f} s")
    print(f"Alpha-Beta Time     : {t_alphabeta:.4f} s")
    print("--------------------------------")
    print(f"Best Move Sequence (Alpha-Beta): {best_line_text}")
    print("================================")

    draw_tree(g1, path1, initial_state, f"Full Minimax Tree (Nodes: {len(g1.nodes)})")
    draw_tree(g2, path2, initial_state, f"Alpha-Beta Pruned Tree (Nodes: {len(g2.nodes)})")
