## Game tree plotting

This notebook uses networkx to generate a node-graph visualisation of the game's state space from a given starting position.

Because the number of reachable states grows rapidly, this is recommended only for starting states that represent a game close to being finished.

In [None]:
from collections import deque
from typing import Optional

import matplotlib.pylab as plt
import networkx as nx

from cardgame import Game

We load a game position that is small enough to draw the full state space; the example provided here generates a tree with 143 nodes.

In [None]:
# Serialised game position in the cardgame save format; "??" presumably marks
# face-down cards (TODO confirm against cardgame.Game.load).
# This string is reused later as the root node name of the game tree, since
# get_game_tree keys nodes by Game.save() (assumes load/save round-trips exactly).
save = '2♥3♦8♣5♣A♣3♠/7♠6♣7♦A♦5♠2♦/6♦K♠????2♠3♥/5♥5♦7♣4♠6♥A♥/K♥K♣K♦4♦4♣2♣/??8♥8♠6♠8♦7♥//4♥A♠3♣//5261770853308142601498873863'
state = Game.load(save)
# Bare expression so the notebook displays the loaded game.
state

### Tree and layout generation

To create the graph visualisation, we define some basic functions that crawl the game's state space, adding nodes and edges as we go.

To represent moves onto a face-down card, which have a purely random outcome, we add extra, distinguished "chance" nodes whose outcome neither player can influence.

In [None]:
def _skip_forced_moves(game: Game) -> Game:
    """Follow a chain of forced, deterministic moves and return where it ends.

    Advances while the position has exactly one legal move and that move has
    exactly one possible resulting state, returning the first position where
    either condition no longer holds.
    """
    while len(game.legal_moves) == 1:
        outcomes = next(game.all_moves())
        if len(outcomes) != 1:
            break
        game = outcomes[0]
    return game


def get_game_tree(game: Game, simplified: bool=True) -> nx.DiGraph:
    """Generate a node graph for reachable game states

    Generates a networkx DiGraph representing the reachable states of the cardgame from
    a given starting point. Position nodes are keyed by their ``Game.save()`` string and
    carry ``type="position"`` and a ``depth``. Non-terminal positions also record
    ``is_player_1`` (whose turn it is). Terminal positions record their ``score``.
    A move with more than one possible resulting state is routed through an extra node
    with ``type="chance"``, labelled ``"<parent save>_<marker of first outcome>"``.

    Args:
        game: a cardgame.Game instance.
        simplified: Contract branchless chains in the graph, defaults to True. When
            contracting, ``depth`` counts position levels in the tree rather than the
            number of individual moves played.

    Returns:
        A networkx DiGraph representation of the state space.
    """
    graph = nx.DiGraph()
    # FIFO work queue of (parent node name or None, state, depth). A deque (instead of
    # a set) keeps the traversal order deterministic. Edges are added as states are
    # dequeued, and that order fixes the children order used by tree_layout. A deque
    # also does not require Game to be hashable.
    pending = deque([(None, game, 0)])
    while pending:
        predecessor, curr, depth = pending.popleft()
        node_name = curr.save()
        if curr.legal_moves:
            graph.add_node(node_name, type="position", depth=depth, is_player_1=curr.is_p1_turn)
            for move in curr.all_moves():
                # `move` is the sequence of states this move can lead to.
                if len(move) > 1:
                    # Several random outcomes: hang them off an intermediate chance node.
                    parent = node_name + '_' + str(move[0].marker)
                    graph.add_node(parent, type="chance")
                    graph.add_edge(node_name, parent)
                else:
                    parent = node_name
                for res in move:
                    if simplified:
                        res = _skip_forced_moves(res)
                    pending.append((parent, res, depth + 1))
        else:
            graph.add_node(node_name, type="position", score=curr.score, depth=depth)
        if predecessor is not None:
            graph.add_edge(predecessor, node_name)
    return graph

def tree_layout(graph: nx.DiGraph, root: str) -> tuple[dict[str, tuple[float, float]], int]:
    """Generate a hierarchical tree layout for a rooted graph.

    Places the root at the top centre (x=0.5, y=0) and each level of descendants one
    unit lower. The horizontal span [0, 1] is split between the root's children in
    proportion to the number of leaves in each child's subtree. Each subtree root is
    centred within its own slice.

    Args:
        graph: a directed graph that represents a rooted tree.
        root: the name of the root node in the graph.

    Returns:
        A ``(pos, leaf_count)`` pair:
        - ``pos`` is a networkx style layout dictionary mapping each node reachable
          from ``root`` to ``(x, y)``.
        - ``leaf_count`` is the number of leaves in the subtree rooted at ``root``
          (1 if ``root`` itself is a leaf).
    """
    pos = {root: (0.5, 0)}
    children = list(graph[root])
    if not children:
        return pos, 1
    subtrees = [tree_layout(graph, child) for child in children]
    leaf_count = sum(subleaf_count for _, subleaf_count in subtrees)
    offset = 0
    for subtree_pos, subleaf_count in subtrees:
        # Squash the child's [0, 1] span into its slice of this level and shift it down one row.
        pos.update({
            node: ((x * subleaf_count + offset) / leaf_count, y - 1)
            for node, (x, y) in subtree_pos.items()
        })
        offset += subleaf_count
    return pos, leaf_count

def draw_tree(graph: nx.DiGraph, root: str, figsize: Optional[tuple[int, int]]=None, best_path: bool=False) -> None:
    """Render a node graph that represents the state space for the card game

    Draws and displays a networkx graph to represent the reachable game states, including colors to
    represent which nodes are chance nodes and which nodes/edges represent the best move.
    Terminal positions are labelled with their score. Non-terminal positions are labelled
    ▲ when it is player 1's turn and ▼ otherwise.

    Args:
        graph: a directed graph that represents a rooted tree, as produced by get_game_tree.
        root: the name of the root node in the graph (a Game save string).
        figsize: The dimensions to draw the matplotlib figure. By default they are calculated
            from the number of leaves and the number of levels in the layout.
        best_path: Renders the optimal moves in a different color, defaults to False.
    """
    maximiser_symbol, minimiser_symbol = "▲▼"

    # 0 for position nodes, 1 for chance nodes, mapped through the colormap.
    node_clrs = {k: (0 if v == "position" else 1) for k, v in nx.get_node_attributes(graph, "type").items()}
    # Pin the colour scale so a draw call that contains only one node type is not
    # autoscaled onto the wrong colour.
    color_range = {"vmin": 0, "vmax": 1}

    pos, leaf_count = tree_layout(graph, root)
    # The root sits at y=0 and each level is one unit lower. Unlike the "depth"
    # attribute, this also counts the levels taken up by chance nodes.
    num_levels = 1 - min(y for _, y in pos.values())

    _, ax = plt.subplots(figsize=figsize or (leaf_count / 2, num_levels))

    if best_path:
        bp_graph = graph.subgraph(nodes_in_optimal_path(Game.load(root)))
        bp_nodes = list(bp_graph.nodes)
        bp_edges = list(bp_graph.edges)
        bp_edge_set = set(bp_edges)
        other_edges = [edge for edge in graph.edges if edge not in bp_edge_set]
        nx.draw_networkx_edges(graph, pos, edgelist=bp_edges, edge_color="green", width=2, ax=ax)
        nx.draw_networkx_edges(graph, pos, edgelist=other_edges, ax=ax)
        # Colours are listed in the same order as `nodelist` so each node gets its own colour.
        nx.draw_networkx_nodes(
            graph, pos, nodelist=bp_nodes, node_color=[node_clrs[k] for k in bp_nodes],
            edgecolors="green", linewidths=2, ax=ax, **color_range,
        )
    else:
        bp_nodes = []
        nx.draw_networkx_edges(graph, pos, ax=ax)

    bp_node_set = set(bp_nodes)
    other_nodes = [node for node in graph.nodes if node not in bp_node_set]
    nx.draw_networkx_nodes(
        graph, pos, nodelist=other_nodes, node_color=[node_clrs[k] for k in other_nodes],
        ax=ax, **color_range,
    )
    nx.draw_networkx_labels(graph, pos, labels=nx.get_node_attributes(graph, "score"), font_color="white", ax=ax)
    player_symbols = {k: (maximiser_symbol if v else minimiser_symbol) for k, v in nx.get_node_attributes(graph, "is_player_1").items()}
    nx.draw_networkx_labels(graph, pos, labels=player_symbols, font_color="white", ax=ax)

    ax.axis("off")
    # Pad by one leaf slot either side so the outermost nodes are not clipped.
    bound = 1 / leaf_count
    ax.set_xlim(-bound, 1 + bound)
    plt.show()

def nodes_in_optimal_path(game: Game) -> set[str]:
    """Collect the names of every position reachable under optimal play.

    Starting from ``game``, repeatedly follow the first move reported in
    ``game.evaluate()["Deterministic optimal moves"]``. When several moves tie as
    optimal, only that first one is followed. A move whose value has no ``rank``
    attribute is treated as a chance move: its chance-node label
    (``"<save>_<move>"``) is recorded and every resulting state is explored. Any
    other move is matched against the face-up cards on the board, and the first
    resulting state of each matching legal move is explored.

    Args:
        game: The starting state for the cardgame.

    Returns:
        A set of node names (save strings and chance-node labels) that can appear in
        a game played optimally from ``game``.
    """
    root_name = game.save()
    collected = {root_name}
    if not game.legal_moves:
        # Terminal position: nothing further to explore.
        return collected

    choice = game.evaluate()["Deterministic optimal moves"][0]

    if hasattr(choice, "rank"):
        # A card was chosen: play each legal move whose board card equals it.
        successors = (
            game.move(row, col)[0]
            for row, col in game.legal_moves
            if game.board[row][col] == choice
        )
    else:
        # A random-outcome move. This label must match the chance-node naming used
        # when building the graph (presumably `choice` equals the outcome's `marker`;
        # TODO confirm in cardgame).
        collected.add(root_name + '_' + str(choice))
        successors = game.move(*choice)

    for successor in successors:
        collected |= nodes_in_optimal_path(successor)
    return collected

### Drawing the tree

We can now draw the tree. In this example we use a simplified tree, which means that any branchless chain of minimiser/maximiser nodes is contracted to a single node.

Chance nodes are drawn in yellow, and minimiser and maximiser nodes are distinguished with `▲` for the maximising player and `▼` for the minimising player.

Leaf nodes, representing terminal game states, are labelled with the final score for that branch, and the path representing optimal play, including where it branches due to random chance, is highlighted in green.

In [None]:
# Build the game tree from the loaded state, contracting branchless chains.
G = get_game_tree(state, simplified=True)
# Nodes are keyed by Game.save(), so the original save string names the root node.
# best_path=True highlights the positions reachable under optimal play in green.
draw_tree(G, save, best_path=True)