# Seldon Labs Game Engine V1

In [None]:
import itertools
import random
from networkx.drawing.nx_agraph import graphviz_layout  # Improved layout
import networkx as nx
import matplotlib.pyplot as plt

### Prototype: Simultaneous-Move Game Tree (demonstration only)

In [None]:
class Strategy:
    """Represents a player's strategy, supporting both pure and mixed strategies.

    A strategy is a named mapping from actions to selection weights. A pure
    strategy is the special case where one action carries all the weight.
    """
    def __init__(self, strategy_name, probabilities=None):
        self.strategy_name = strategy_name
        # Maps actions to probabilities (weights); a fresh dict when none given.
        self.probabilities = probabilities or {}

    def choose_action(self):
        """Sample one action according to the mixed strategy probabilities.

        Returns:
            One key of ``self.probabilities``, drawn with ``random.choices``.

        Raises:
            ValueError: if the strategy has no actions to choose from (or, via
                ``random.choices``, if all weights are zero).
        """
        if not self.probabilities:
            # Without this guard, unpacking an empty zip fails with an opaque
            # "not enough values to unpack" message.
            raise ValueError(
                f"Strategy {self.strategy_name!r} has no actions to choose from."
            )
        actions = list(self.probabilities)
        weights = list(self.probabilities.values())
        return random.choices(actions, weights=weights, k=1)[0]  # Probabilistic choice

In [None]:
class Node:
    """A single game state; one or more players may choose here simultaneously."""
    def __init__(self, players):
        # Players choosing at this state (a set, as passed by the caller).
        self.players = players
        # Joint-action tuple -> child Node reached by playing that profile.
        self.actions = {}
        # Outcome for terminal states; stays None on decision states.
        self.payoff = None

    def add_action(self, actions, child_node):
        """Attach ``child_node`` as the result of the joint action ``actions``."""
        profile = tuple(actions)
        self.actions[profile] = child_node

class Game:
    """Represents a game supporting both mixed strategies and simultaneous moves.

    Moves are stored as joint-action tuples, so several players can act at the
    same node. The tree is grown level by level with ``add_moves`` and closed
    off with ``add_outcomes``.

    NOTE(review): ``add_moves`` attaches the new actions to the current
    frontier nodes but labels the *newly created children* with ``players``;
    a node's label therefore comes from ``root_players`` or the previous
    ``add_moves`` call. Confirm this is the intended reading before relying
    on the labels.
    """
    def __init__(self, root_players):
        if isinstance(root_players, str):
            root_players = [root_players]  # Convert to list if single player
        else:
            root_players = list(root_players)  # Allow one-shot iterables
        self.root = Node(set(root_players))
        self.current_nodes = [self.root]  # Track nodes that need expansion
        self.players = set(root_players)

    def add_moves(self, players, actions_list):
        """
        Expands the game tree with moves for one or more players.
        - `players`: Single player (str) or multiple players (list).
        - `actions_list`: If multiple players, provide a list of lists of actions;
          a single flat list of action strings is also accepted.

        Every frontier node receives one child per joint action in the
        Cartesian product of ``actions_list``; those children become the new
        frontier.
        """
        if isinstance(players, str):
            players = [players]  # Convert to list
        else:
            players = list(players)  # Allow one-shot iterables
        if isinstance(actions_list[0], str):
            actions_list = [actions_list]  # Wrap in list for consistency

        # The joint-action profiles are identical at every frontier node, so
        # compute them once. This also keeps generator inputs from being
        # exhausted after the first node (which left later nodes childless).
        action_combinations = list(itertools.product(*actions_list))

        new_nodes = []
        for node in self.current_nodes:
            for actions in action_combinations:
                child_node = Node(set(players))
                node.add_action(actions, child_node)
                new_nodes.append(child_node)

        self.current_nodes = new_nodes  # Update frontier
        self.players.update(players)

    def add_outcomes(self, payoffs):
        """Assigns payoffs to terminal nodes, in frontier order.

        Raises:
            ValueError: if the number of payoffs differs from the frontier size.
        """
        if len(self.current_nodes) != len(payoffs):
            raise ValueError("Number of outcomes must match terminal nodes.")

        for node, payoff in zip(self.current_nodes, payoffs):
            node.payoff = payoff

    def display_tree(self, node=None, depth=0):
        """Recursively prints the game tree for debugging (root by default)."""
        if node is None:
            node = self.root
        indent = "    " * depth
        if node.actions:
            print(f"{indent}{', '.join(node.players)} moves:")
            for actions, child in node.actions.items():
                action_str = " | ".join(actions)
                print(f"{indent}  ├─ {action_str}")
                self.display_tree(child, depth + 1)
        else:
            print(f"{indent}  (Payoff: {node.payoff})")

    def visualize_tree(self):
        """Visualizes the game tree with improved spacing using Graphviz.

        Decision nodes are labelled with their players, terminal nodes with
        their payoff, and edges with the joint action leading to the child.
        """
        graph = nx.DiGraph()
        node_labels = {}

        def add_edges(node, parent=None, action_label=None):
            """Recursively add *node*, its incoming edge, and its subtree."""
            node_id = id(node)  # Unique identifier
            if node.actions:
                # sorted() makes the label independent of set iteration order.
                label = ", ".join(sorted(node.players))
            else:
                label = f"Payoff: {node.payoff}"
            node_labels[node_id] = label
            # Add the node explicitly: with add_edge alone, a tree consisting of
            # only the root would leave the graph empty and label drawing fails.
            graph.add_node(node_id)

            if parent is not None:
                graph.add_edge(parent, node_id, action=action_label)

            for action, child in node.actions.items():
                add_edges(child, node_id, action)

        add_edges(self.root)

        # Use Graphviz DOT layout for better hierarchy
        pos = graphviz_layout(graph, prog="dot")

        # Draw graph
        plt.figure(figsize=(10, 6))
        nx.draw(graph, pos, with_labels=True, labels=node_labels, node_color="lightblue", edge_color="black",
                node_size=3000, font_size=8, font_weight="bold", arrowsize=15)

        # Add action labels on edges
        edge_labels = {(u, v): data["action"] for u, v, data in graph.edges(data=True)}
        nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, font_size=8,
                                     bbox=dict(facecolor="white", edgecolor="none", alpha=0.8))

        plt.title("Game Tree Visualization")
        plt.show()

In [None]:
# Example: simultaneous-move tariff game -- China and the US choose at the root.
game = Game(["China", "US"])
game.add_moves(["China", "US"], [["Tariff", "No Tariff"], ["Tariff", "No Tariff"]])

# Payoffs as (China, US), one per joint action in itertools.product order
# (China's choice first, the US's choice second).
outcomes = [
    (-6, -6),   # (Tariff, Tariff)
    (0, -10),   # (Tariff, No Tariff)
    (-10, 0),   # (No Tariff, Tariff)
    (-1, -1)    # (No Tariff, No Tariff)
]
game.add_outcomes(outcomes)

game.display_tree()

In [None]:
# Draw the tree built in the previous cell (matplotlib, Graphviz "dot" layout).
game.visualize_tree()

In [None]:
# Example: sequential (not simultaneous) version of the tariff game.
# NOTE(review): the root is labelled "China", but the actions attached to it
# come from the add_moves(["US"], ...) call, because add_moves labels the
# children it creates. Confirm this labelling is the intended reading.
game = Game(["China"])
game.add_moves(["US"], [["Tariff", "No Tariff"]])
game.add_moves(["China"], [["Tariff", "No Tariff"]])

# Terminal payoffs as (China, US), in leaf (frontier) order.
outcomes = [
    (-6, -6),   # both tariff
    (0, -10),   # China tariffs, the US does not
    (-10, 0),   # the US tariffs, China does not
    (-1, -1)    # neither tariffs
]
game.add_outcomes(outcomes)

game.display_tree()

In [None]:
# Draw the sequential tree built in the previous cell.
game.visualize_tree()

### Sequential Game Tree and Backward-Induction Solver

In [None]:
import matplotlib.pyplot as plt
import networkx as nx
from networkx.drawing.nx_agraph import graphviz_layout

class Node:
    """One state in a sequential game tree."""
    def __init__(self, players=None):
        # Players who act here; a missing or empty argument gets a fresh set.
        self.players = players or set()
        self.actions = {}   # action name -> resulting child Node
        self.payoff = None  # payoff, assigned only on terminal nodes

    def add_action(self, action, child_node):
        """Record that choosing *action* at this state leads to *child_node*."""
        self.actions.update({action: child_node})

class Game:
    """Represents a game theory structure with players, actions, and payoffs.

    The tree is grown one level at a time: each ``add_moves`` call makes one
    player act at every current leaf with the same list of actions, and
    ``add_outcomes`` assigns payoffs to the final leaves.
    """
    def __init__(self):
        self.root = Node()
        self.current_nodes = [self.root]  # Track leaf nodes for expansion

    def add_moves(self, player, actions):
        """Adds moves for a player at all current leaf nodes.

        *player* is recorded on each current leaf; each leaf gets one child per
        action, and those children become the new leaves.
        """
        # Materialize once so a generator is not exhausted after the first leaf.
        actions = list(actions)
        new_nodes = []
        for node in self.current_nodes:
            node.players.add(player)
            for action in actions:
                child_node = Node()
                node.add_action(action, child_node)
                new_nodes.append(child_node)
        self.current_nodes = new_nodes

    def add_outcomes(self, outcomes):
        """Assigns payoffs to the current leaf nodes, in leaf order.

        Raises:
            ValueError: if the number of outcomes differs from the leaf count.
        """
        if len(outcomes) != len(self.current_nodes):
            raise ValueError("Number of outcomes must match the number of terminal nodes.")
        for node, payoff in zip(self.current_nodes, outcomes):
            node.payoff = payoff

    def display_tree(self):
        """Recursively prints the game tree."""
        def recurse(node, depth=0):
            payoff_text = f", Payoff: {node.payoff}" if node.payoff is not None else ""
            print("  " * depth + f"Players: {node.players}{payoff_text}")
            for action, child in node.actions.items():
                print("  " * depth + f"Action: {action}")
                recurse(child, depth + 1)
        recurse(self.root)

    def visualize_tree(self):
        """Visualizes the game tree with improved spacing using Graphviz.

        Decision nodes are labelled with their players, terminal nodes with
        their payoff, and edges with the action leading to the child.
        """
        graph = nx.DiGraph()
        node_labels = {}

        def add_edges(node, parent=None, action_label=None):
            """Recursively add *node*, its incoming edge, and its subtree."""
            node_id = id(node)  # Unique identifier
            if node.actions:
                # sorted() makes the label independent of set iteration order.
                label = ", ".join(sorted(node.players))
            else:
                label = f"Payoff: {node.payoff}"
            node_labels[node_id] = label
            # Add the node explicitly: with add_edge alone, a root-only tree
            # leaves the graph empty and drawing the labels fails.
            graph.add_node(node_id)

            if parent is not None:
                graph.add_edge(parent, node_id, action=action_label)

            for action, child in node.actions.items():
                add_edges(child, node_id, action)

        add_edges(self.root)

        # Use Graphviz DOT layout for better hierarchy
        pos = graphviz_layout(graph, prog="dot")

        # Draw graph
        plt.figure(figsize=(10, 6))
        nx.draw(graph, pos, with_labels=True, labels=node_labels, node_color="lightblue", edge_color="black",
                node_size=3000, font_size=8, font_weight="bold", arrowsize=15)

        # Add action labels on edges
        edge_labels = {(u, v): data["action"] for u, v, data in graph.edges(data=True)}
        nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, font_size=8,
                                     bbox=dict(facecolor="white", edgecolor="none", alpha=0.8))

        plt.title("Game Tree Visualization")
        plt.show()


In [None]:
# Build the sequential tariff game: China moves first, then the US.
game = Game()
game.add_moves("China", ["Tariff", "No Tariff"])
game.add_moves("US", ["Tariff", "No Tariff"])

# Terminal payoffs in leaf order (China's choice outer, the US's inner).
# NOTE(review): these differ from the symmetric payoffs used in
# debug_trade_war_game ((-6, -6), (0, -10), ...) -- confirm this is intentional.
outcomes = [
    (-6, -10),  # China: Tariff,    US: Tariff
    (0, -6),    # China: Tariff,    US: No Tariff
    (-10, 0),   # China: No Tariff, US: Tariff
    (-1, -1),   # China: No Tariff, US: No Tariff
]
game.add_outcomes(outcomes)

game.display_tree()    # text dump of the tree
game.visualize_tree()  # graphical rendering

In [None]:
class Solver:
    """Abstract base for game solvers; concrete subclasses implement ``solve``."""

    def __init__(self, game):
        """Bind the solver to *game*, which must be a ``Game`` instance."""
        if isinstance(game, Game):
            self.game = game
            self.equilibrium = {}
        else:
            raise TypeError("Solver expects an instance of Game.")

    def solve(self):
        """Compute the equilibrium; every concrete solver must override this."""
        raise NotImplementedError("Solve method must be implemented in subclasses.")

    def get_equilibrium(self):
        """Return the equilibrium most recently stored on this solver."""
        result = self.equilibrium
        return result

class BackwardInductionSolver(Solver):
    """Solve a finite sequential game tree by backward induction.

    At each decision node the acting player picks the child whose payoff tuple
    gives that player the highest value. Ties keep the first action in
    ``node.actions`` order, because the comparison is a strict ``>``.

    Results are keyed by ``id(node)``:
      - ``optimal_actions``: the chosen action at each solved decision node;
      - ``node_values``: the payoff tuple reached from each visited node.
    """

    def __init__(self, game, player_indices=None):
        """Initialize the backward induction solver with a game.

        Parameters:
        game: Game instance to solve.
        player_indices: optional mapping from player name to that player's
            position in payoff tuples. If omitted, the game's own
            ``get_player_index`` is used when it provides one; otherwise the
            original two-player convention applies ("China" -> 0, any other
            player -> 1).
        """
        super().__init__(game)
        self.optimal_actions = {}  # id(node) -> optimal action at that node
        self.node_values = {}      # id(node) -> payoff tuple reached from that node
        self.player_indices = None if player_indices is None else dict(player_indices)
        # Debug mode to print detailed information during solving
        self.debug = False

    def _player_index(self, player):
        """Return the index of *player*'s payoff inside payoff tuples.

        Raises KeyError if an explicit ``player_indices`` mapping lacks *player*.
        """
        if self.player_indices is not None:
            return self.player_indices[player]
        get_index = getattr(self.game, "get_player_index", None)
        if callable(get_index):
            return get_index(player)
        # Legacy fallback, hard-coded for the China/US trade-war examples.
        return 0 if player == "China" else 1

    def solve(self):
        """Solve the game using backward induction.

        Returns:
        dict mapping id(decision node) -> optimal action.
        """
        # Start from the root and solve recursively
        self._backward_induction(self.game.root)
        return self.optimal_actions

    def _backward_induction(self, node, depth=0):
        """
        Recursive backward induction algorithm.
        Returns the value (payoff tuple) reached from the current node, or None
        when the node has actions but no players, or no child yields a value.

        Parameters:
        node: Current game node
        depth: Current depth in the tree (for debugging)
        """
        node_id = id(node)

        # Base case: terminal node (no actions)
        if not node.actions:
            if self.debug:
                print("  " * depth + f"Terminal node with payoff: {node.payoff}")
            self.node_values[node_id] = node.payoff
            return node.payoff

        # Get the player making the decision at this node
        if not node.players:
            if self.debug:
                print("  " * depth + "No players at this node")
            return None

        # With several players recorded on one node, an arbitrary one decides.
        current_player = next(iter(node.players))
        if self.debug:
            print("  " * depth + f"Player {current_player} at depth {depth}")

        # Position of the current player's payoff within payoff tuples
        player_idx = self._player_index(current_player)

        if self.debug:
            print("  " * depth + f"Player index: {player_idx}")

        # Get values of all children
        best_payoff = float('-inf')
        best_action = None
        best_value = None

        for action, child_node in node.actions.items():
            if self.debug:
                print("  " * depth + f"Trying action: {action}")

            child_value = self._backward_induction(child_node, depth + 1)

            if child_value is None:
                continue

            # Extract the current player's payoff from the tuple
            player_payoff = child_value[player_idx]

            if self.debug:
                print("  " * depth + f"Action {action} gives payoff {player_payoff} to {current_player}")

            if player_payoff > best_payoff:
                best_payoff = player_payoff
                best_action = action
                best_value = child_value

                if self.debug:
                    print("  " * depth + f"New best action: {best_action} with payoff {best_payoff}")

        # Store optimal action for this node
        if best_action is not None:
            self.optimal_actions[node_id] = best_action
            if self.debug:
                print("  " * depth + f"Optimal action for node {node_id}: {best_action}")

        # Store node value
        self.node_values[node_id] = best_value

        return best_value

    def get_subgame_perfect_equilibrium(self):
        """Return the equilibrium strategies along the equilibrium path.

        Returns a dict ``{player: {path_tuple: action}}`` where ``path_tuple``
        is the sequence of actions leading from the root to the decision node.
        Only decision nodes on the equilibrium path are recorded; optimal
        actions at off-path nodes remain available in ``self.optimal_actions``.
        """
        if not self.optimal_actions:
            self.solve()

        # Format the equilibrium strategies by player
        equilibrium = {}

        def traverse(node, path=()):
            """Walk the equilibrium path from *node*, recording decisions."""
            node_id = id(node)

            # Skip terminal nodes
            if not node.actions:
                return

            optimal_action = self.optimal_actions.get(node_id)
            # For each player at this node, record their optimal action
            for player in node.players:
                player_strategy = equilibrium.setdefault(player, {})
                if optimal_action is not None:
                    player_strategy[path] = optimal_action

            # Continue traversal with the optimal child
            if optimal_action is not None:
                child_node = node.actions.get(optimal_action)
                if child_node:
                    traverse(child_node, path + (optimal_action,))

        # Start traversal from the root
        traverse(self.game.root)

        self.equilibrium = equilibrium
        return equilibrium

    def visualize_equilibrium(self):
        """Visualize the game tree with equilibrium strategies highlighted.

        Edges matching the optimal action at their parent are drawn red and
        thicker; decision-node labels show the optimal action.
        """
        if not self.optimal_actions:
            self.solve()

        # Create a new directed graph
        graph = nx.DiGraph()
        node_labels = {}

        # Add nodes and edges to the graph
        def add_nodes_and_edges(node, parent=None, action=None):
            node_id = id(node)

            # Create label for the node
            if not node.actions:  # Terminal node
                label = f"Payoff: {node.payoff}"
            else:
                players_str = ", ".join(sorted(node.players))
                optimal = self.optimal_actions.get(node_id, "N/A")
                label = f"{players_str}\nOptimal: {optimal}"

            node_labels[node_id] = label
            graph.add_node(node_id)

            # Add edge from parent if applicable
            if parent is not None:
                # Check if this edge is part of the equilibrium path
                parent_optimal = self.optimal_actions.get(parent)
                is_optimal = (parent_optimal == action)

                # Add the edge with attributes
                graph.add_edge(parent, node_id,
                               action=action,
                               color="red" if is_optimal else "black",
                               width=2.0 if is_optimal else 1.0)

            # Process all children of this node
            for act, child in node.actions.items():
                add_nodes_and_edges(child, node_id, act)

        # Build the graph starting from the root
        add_nodes_and_edges(self.game.root)

        # Draw the graph
        plt.figure(figsize=(12, 8))
        pos = graphviz_layout(graph, prog="dot")

        # Draw nodes
        nx.draw_networkx_nodes(graph, pos, node_size=3000, node_color="lightblue")
        nx.draw_networkx_labels(graph, pos, labels=node_labels, font_size=10)

        # Draw edges with appropriate colors and widths
        for (u, v, data) in graph.edges(data=True):
            nx.draw_networkx_edges(graph, pos, edgelist=[(u, v)],
                                   edge_color=data["color"],
                                   width=data["width"])

        # Add edge labels
        edge_labels = {(u, v): data["action"] for u, v, data in graph.edges(data=True)}
        nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, font_size=8)

        plt.title("Game Tree with Equilibrium Path Highlighted")
        plt.axis("off")
        plt.show()

    def print_equilibrium(self):
        """Print the equilibrium strategies and the equilibrium path."""
        if not self.equilibrium:
            self.get_subgame_perfect_equilibrium()

        print("Subgame Perfect Equilibrium Strategies:")
        for player, strategies in self.equilibrium.items():
            print(f"Player {player}:")
            for path, action in strategies.items():
                path_str = " → ".join(["Root"] + list(path)) if path else "Root"
                print(f"  At '{path_str}', choose '{action}'")

        print("\nEquilibrium Path:")
        node = self.game.root
        path = ["Root"]

        # Follow optimal actions from the root until a terminal/unsolved node.
        while node and node.actions:
            node_id = id(node)
            if node_id in self.optimal_actions:
                next_action = self.optimal_actions[node_id]
                path.append(next_action)

                # Find the child with this action
                node = node.actions.get(next_action)
            else:
                break

        print(" → ".join(path))

        if node and node.payoff is not None:
            print(f"Terminal payoffs: {node.payoff}")

    def record_equilibrium(self):
        """Create a dictionary mapping each player to one equilibrium action.

        If a player decides at several nodes on the equilibrium path, the last
        (deepest) recorded decision is kept. A player with no recorded
        decision maps to an empty dict.
        """
        if not self.equilibrium:
            self.get_subgame_perfect_equilibrium()

        player_actions = {}
        for player, strategies in self.equilibrium.items():
            actions = list(strategies.values())
            player_actions[player] = actions[-1] if actions else {}

        return player_actions

In [None]:
# Solve the game built above, then report and draw the equilibrium.
solver = BackwardInductionSolver(game=game)
solver.solve()
solver.print_equilibrium()
solver.visualize_equilibrium()

In [None]:
def debug_trade_war_game():
    """
    Create the trade war game and manually trace through the backward induction algorithm
    to verify it works correctly.

    Returns:
        (game, solver): the constructed Game and the BackwardInductionSolver
        after solving, for further inspection by the caller.
    """
    # Create a new game instance
    game = Game()

    # China moves first, then the US
    game.add_moves(player="China", actions=["Tariff", "No Tariff"])
    game.add_moves(player="US", actions=["Tariff", "No Tariff"])

    # Payoffs as (China, US), in leaf order: China's choice outer, US's inner.
    outcomes = [
        (-6, -6),  # Both impose tariffs
        (0, -10),  # China tariffs, US does not
        (-10, 0),  # China does not tariff, US does
        (-1, -1),   # Neither imposes tariffs
    ]
    game.add_outcomes(outcomes)

    # Print the game structure for verification
    print("Game Structure:")
    game.display_tree()

    print("\nTerminal nodes and their payoffs:")
    leaf_payoffs = []

    def collect_leaf_payoffs(node):
        """Append the payoffs of terminal nodes under *node*, depth-first."""
        if not node.actions:
            leaf_payoffs.append(node.payoff)
            return
        for child in node.actions.values():
            collect_leaf_payoffs(child)

    collect_leaf_payoffs(game.root)

    for i, payoff in enumerate(leaf_payoffs, start=1):
        print(f"Leaf {i}: Payoff = {payoff}")

    # Create solver with debug mode on
    print("\nRunning backward induction:")
    solver = BackwardInductionSolver(game)
    solver.debug = True
    solver.solve()

    print("\nOptimal actions by node ID:")
    for node_id, action in solver.optimal_actions.items():
        print(f"Node {node_id}: {action}")

    print("\nNode values:")
    for node_id, value in solver.node_values.items():
        print(f"Node {node_id}: {value}")

    # Populate solver.equilibrium, then print it
    solver.get_subgame_perfect_equilibrium()
    print("\nSubgame perfect equilibrium:")
    solver.print_equilibrium()

    return game, solver

# Run the diagnostic function
if __name__ == "__main__":
    game, solver = debug_trade_war_game()

### Generalized Game Tree with Player-Indexed Payoffs (testing)

In [None]:
class Node:
    """A state in a sequential game tree: who acts, what they can do, the outcome."""
    def __init__(self, players=None):
        # Acting players; a falsy argument (None or empty) yields a new set.
        self.players = set() if not players else players
        # Maps each available action name to the child Node it leads to.
        self.actions = {}
        # Payoff for terminal nodes; None elsewhere.
        self.payoff = None

    def add_action(self, action, child_node):
        """Link *action* at this node to the resulting *child_node*."""
        actions = self.actions
        actions[action] = child_node

class Game:
    """Represents a game theory structure with players, actions, and payoffs.

    Players are indexed in the order they first move; that index is the
    player's position in every payoff tuple passed to ``add_outcomes``.
    """
    def __init__(self):
        self.root = Node()
        self.current_nodes = [self.root]  # Track leaf nodes for expansion
        self.players = []  # List to track players in the order they're added
        self.player_indices = {}  # Maps player names to their indices

    def add_player(self, player):
        """Add a player to the game if not already present; return its index."""
        if player not in self.player_indices:
            self.players.append(player)
            self.player_indices[player] = len(self.players) - 1
        return self.player_indices[player]

    def get_player_index(self, player):
        """Return the index of the player in payoff tuples.

        Raises:
            ValueError: if *player* has not been added to the game.
        """
        try:
            return self.player_indices[player]
        except KeyError:
            raise ValueError(f"Player {player} not found in game") from None

    def add_moves(self, player, actions):
        """Adds moves for a player at all current leaf nodes.

        *player* is registered (if new) and recorded on each current leaf;
        each leaf gets one child per action, and those children become the
        new leaves.
        """
        # Add player to the tracking system if not already added
        self.add_player(player)
        # Materialize once so a generator is not exhausted after the first leaf.
        actions = list(actions)

        new_nodes = []
        for node in self.current_nodes:
            node.players.add(player)
            for action in actions:
                child_node = Node()
                node.add_action(action, child_node)
                new_nodes.append(child_node)
        self.current_nodes = new_nodes

    def add_outcomes(self, outcomes):
        """Assigns payoffs to the current leaf nodes, in leaf order.

        Raises:
            ValueError: if the number of outcomes differs from the leaf count.
        """
        if len(outcomes) != len(self.current_nodes):
            raise ValueError("Number of outcomes must match the number of terminal nodes.")
        for node, payoff in zip(self.current_nodes, outcomes):
            node.payoff = payoff

    def display_tree(self):
        """Recursively prints the game tree."""
        def recurse(node, depth=0):
            payoff_text = f", Payoff: {node.payoff}" if node.payoff is not None else ""
            print("  " * depth + f"Players: {node.players}{payoff_text}")
            for action, child in node.actions.items():
                print("  " * depth + f"Action: {action}")
                recurse(child, depth + 1)
        recurse(self.root)

    def visualize_tree(self):
        """Visualizes the game tree with improved spacing using Graphviz.

        Decision nodes are labelled with their players, terminal nodes with
        their payoff, and edges with the action leading to the child.
        """
        graph = nx.DiGraph()
        node_labels = {}

        def add_edges(node, parent=None, action_label=None):
            """Recursively add *node*, its incoming edge, and its subtree."""
            node_id = id(node)  # Unique identifier
            if node.actions:
                # sorted() makes the label independent of set iteration order.
                label = ", ".join(sorted(node.players))
            else:
                label = f"Payoff: {node.payoff}"
            node_labels[node_id] = label
            # Add the node explicitly: with add_edge alone, a root-only tree
            # leaves the graph empty and drawing the labels fails.
            graph.add_node(node_id)

            if parent is not None:
                graph.add_edge(parent, node_id, action=action_label)

            for action, child in node.actions.items():
                add_edges(child, node_id, action)

        add_edges(self.root)

        # Use Graphviz DOT layout for better hierarchy
        pos = graphviz_layout(graph, prog="dot")

        # Draw graph
        plt.figure(figsize=(10, 6))
        nx.draw(graph, pos, with_labels=True, labels=node_labels, node_color="lightblue", edge_color="black",
                node_size=3000, font_size=8, font_weight="bold", arrowsize=15)

        # Add action labels on edges
        edge_labels = {(u, v): data["action"] for u, v, data in graph.edges(data=True)}
        nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, font_size=8,
                                     bbox=dict(facecolor="white", edgecolor="none", alpha=0.8))

        plt.title("Game Tree Visualization")
        plt.show()

In [None]:
class Solver:
    """Base class for solvers operating on a ``Game``; override ``solve``."""

    def __init__(self, game):
        """Store *game* after checking that it is a ``Game`` instance."""
        is_game = isinstance(game, Game)
        if not is_game:
            raise TypeError("Solver expects an instance of Game.")
        self.game, self.equilibrium = game, {}

    def solve(self):
        """Compute the equilibrium -- must be provided by subclasses."""
        raise NotImplementedError("Solve method must be implemented in subclasses.")

    def get_equilibrium(self):
        """Hand back the equilibrium currently stored on the solver."""
        return getattr(self, "equilibrium")
        
class BackwardInductionSolver(Solver):
    def __init__(self, game):
        """Initialize the backward induction solver with a game."""
        super().__init__(game)
        self.optimal_actions = {}  # Dictionary to store optimal actions at each node
        self.node_values = {}      # Dictionary to store computed values for each node
        # Debug mode to print detailed information during solving
        self.debug = False

    def solve(self):
        """Solve the game using backward induction."""
        # Start from the root and solve recursively
        self._backward_induction(self.game.root)           
        return self.optimal_actions

    def _backward_induction(self, node, depth=0):
        """
        Recursive backward induction algorithm.
        Returns the value (payoff) of the current node.
        
        Parameters:
        node: Current game node
        depth: Current depth in the tree (for debugging)
        """
        node_id = id(node)
        
        # Base case: terminal node (no actions)
        if not node.actions:
            if self.debug:
                print("  " * depth + f"Terminal node with payoff: {node.payoff}")
            self.node_values[node_id] = node.payoff
            return node.payoff
        
        # Get the player making the decision at this node
        if not node.players:
            if self.debug:
                print("  " * depth + "No players at this node")
            return None
        
        current_player = next(iter(node.players))
        if self.debug:
            print("  " * depth + f"Player {current_player} at depth {depth}")
        
        # Use the game's player indexing system instead of hardcoded values
        player_idx = self.game.get_player_index(current_player)
        
        if self.debug:
            print("  " * depth + f"Player index: {player_idx}")
        
        # Get values of all children
        best_payoff = float('-inf')
        best_action = None
        best_value = None
        
        for action, child_node in node.actions.items():
            if self.debug:
                print("  " * depth + f"Trying action: {action}")
            
            child_value = self._backward_induction(child_node, depth + 1)
            
            if child_value is None:
                continue
            
            # Extract the current player's payoff from the tuple
            player_payoff = child_value[player_idx]
            
            if self.debug:
                print("  " * depth + f"Action {action} gives payoff {player_payoff} to {current_player}")
            
            if player_payoff > best_payoff:
                best_payoff = player_payoff
                best_action = action
                best_value = child_value
                
                if self.debug:
                    print("  " * depth + f"New best action: {best_action} with payoff {best_payoff}")
        
        # Store optimal action for this node
        if best_action is not None:
            self.optimal_actions[node_id] = best_action
            if self.debug:
                print("  " * depth + f"Optimal action for node {node_id}: {best_action}")
        
        # Store node value
        self.node_values[node_id] = best_value
        
        return best_value

    def get_subgame_perfect_equilibrium(self):
        """Return the subgame perfect equilibrium strategies."""
        if not self.optimal_actions:
            self.solve()
        
        # Format the equilibrium strategies by player
        equilibrium = {}
        
        # Traverse the tree to determine which nodes are reachable
        def traverse(node, path=[]):
            node_id = id(node)
            
            # Skip terminal nodes
            if not node.actions:
                return
            
            # For each player at this node, record their optimal action
            for player in node.players:
                if player not in equilibrium:
                    equilibrium[player] = {}
                
                # Store the optimal action for this player at this information set
                if node_id in self.optimal_actions:
                    equilibrium[player][tuple(path)] = self.optimal_actions[node_id]
            
            # Continue traversal with the optimal child
            if node_id in self.optimal_actions:
                optimal_action = self.optimal_actions[node_id]
                child_node = node.actions.get(optimal_action)
                if child_node:
                    traverse(child_node, path + [optimal_action])
        
        # Start traversal from the root
        traverse(self.game.root)
        
        self.equilibrium = equilibrium
        return equilibrium
    
    def visualize_equilibrium(self):
        """Visualize the game tree with equilibrium strategies highlighted."""
        if not self.optimal_actions:
            self.solve()
            
        # Create a new directed graph
        graph = nx.DiGraph()
        node_labels = {}
        
        # Add nodes and edges to the graph
        def add_nodes_and_edges(node, parent=None, action=None):
            node_id = id(node)
            
            # Create label for the node
            if not node.actions:  # Terminal node
                label = f"Payoff: {node.payoff}"
            else:
                players_str = ", ".join(sorted(node.players))
                optimal = self.optimal_actions.get(node_id, "N/A")
                label = f"{players_str}\nOptimal: {optimal}"
            
            node_labels[node_id] = label
            graph.add_node(node_id)
            
            # Add edge from parent if applicable
            if parent is not None:
                # Check if this edge is part of the equilibrium path
                parent_optimal = self.optimal_actions.get(parent)
                is_optimal = (parent_optimal == action)
                
                # Add the edge with attributes
                graph.add_edge(parent, node_id, 
                               action=action,
                               color="red" if is_optimal else "black",
                               width=2.0 if is_optimal else 1.0)
            
            # Process all children of this node
            for act, child in node.actions.items():
                add_nodes_and_edges(child, node_id, act)
        
        # Build the graph starting from the root
        add_nodes_and_edges(self.game.root)
        
        # Draw the graph
        plt.figure(figsize=(12, 8))
        pos = graphviz_layout(graph, prog="dot")
        
        # Draw nodes
        nx.draw_networkx_nodes(graph, pos, node_size=3000, node_color="lightblue")
        nx.draw_networkx_labels(graph, pos, labels=node_labels, font_size=10)
        
        # Draw edges with appropriate colors and widths
        for (u, v, data) in graph.edges(data=True):
            nx.draw_networkx_edges(graph, pos, edgelist=[(u, v)], 
                                  edge_color=data["color"], 
                                  width=data["width"])
        
        # Add edge labels
        edge_labels = {(u, v): data["action"] for u, v, data in graph.edges(data=True)}
        nx.draw_networkx_edge_labels(graph, pos, edge_labels=edge_labels, font_size=8)
        
        plt.title("Game Tree with Equilibrium Path Highlighted")
        plt.axis("off")
        plt.show()
        
    def print_equilibrium(self):
        """Print the equilibrium strategies in a readable format."""
        if not self.equilibrium:
            self.get_subgame_perfect_equilibrium()
        
        print("Subgame Perfect Equilibrium Strategies:")
        for player, strategies in self.equilibrium.items():
            print(f"Player {player}:")
            for path, action in strategies.items():
                path_str = " → ".join(["Root"] + list(path)) if path else "Root"
                print(f"  At '{path_str}', choose '{action}'")
        
        print("\nEquilibrium Path:")
        node = self.game.root
        path = ["Root"]
        
        while node and node.actions:
            node_id = id(node)
            if node_id in self.optimal_actions:
                next_action = self.optimal_actions[node_id]
                path.append(next_action)
                
                # Find the child with this action
                node = node.actions.get(next_action)
            else:
                break
        
        print(" → ".join(path))
        
        if node and node.payoff is not None:
            print(f"Terminal payoffs: {node.payoff}")

    def record_equilibrium(self): 
        """Create dictionary of the equilibrium.""" 
        if not self.equilibrium:
            self.get_subgame_perfect_equilibrium()

        player_actions = {}
        for player, strategies in self.equilibrium.items():
            player_actions[player] = {}
            for path, action in strategies.items():
                player_actions[player] = action
                
        return player_actions

In [None]:
# Two-player sequential tariff game: China moves first, then the US.
game = Game()

# Add players and their actions. Call order matters: payoffs inside every
# outcome tuple are positioned by the order in which players were added
# (see the player-index checks at the end of this cell).
game.add_moves("China", ["Tariff", "No Tariff"])
game.add_moves("US", ["Tariff", "No Tariff"])

# Payoffs for all terminal nodes, each as (China's payoff, US's payoff).
# NOTE(review): the row labels assume add_outcomes fills leaves with China's
# action varying slowest — confirm against Game.add_outcomes.

outcomes = [
    (-6, -6),  # Both impose tariffs
    (0, -10),  # China tariffs, US does not
    (-10, 0),  # China does not tariff, US does
    (-1, -1),   # Neither imposes tariffs
]
game.add_outcomes(outcomes)

# Create and run solver, then print the on-path strategies and payoffs
solver = BackwardInductionSolver(game)
solver.solve()
solver.print_equilibrium()

# Sanity-check the player indices that position payoffs in the tuples above
print(f"China's index: {game.get_player_index('China')}")  # Should print 0
print(f"US's index: {game.get_player_index('US')}")       # Should print 1

In [None]:
# Draw the solved two-player tree; each node's optimal action edge is drawn in red.
solver.visualize_equilibrium()

In [None]:
# Three-player sequential tariff game: China, then the US, then Europe.
game = Game()

# Add players and their actions. Call order matters: payoffs inside every
# outcome tuple are positioned by the order in which players were added.
game.add_moves("China", ["Tariff", "No Tariff"])
game.add_moves("US", ["Tariff", "No Tariff"])
game.add_moves("Europe", ["Tariff", "No Tariff"])

# Payoffs for all 2 x 2 x 2 = 8 terminal nodes, each as
# (China's payoff, US's payoff, Europe's payoff).
# NOTE(review): row labels assume the same leaf ordering as the two-player
# cell (China's action varies slowest, Europe's fastest); confirm against
# Game.add_outcomes.
# NOTE(review): these rows repeat the two-player payoffs with Europe
# mirroring the US, so they look like placeholders, not modelled values.
outcomes = [
    (-6, -6, -6),   # China Tariff,    US Tariff,    Europe Tariff
    (0, -10, -10),  # China Tariff,    US Tariff,    Europe No Tariff
    (-10, 0, 0),    # China Tariff,    US No Tariff, Europe Tariff
    (-1, -1, -1),   # China Tariff,    US No Tariff, Europe No Tariff
    (-6, -6, -6),   # China No Tariff, US Tariff,    Europe Tariff
    (0, -10, -10),  # China No Tariff, US Tariff,    Europe No Tariff
    (-10, 0, 0),    # China No Tariff, US No Tariff, Europe Tariff
    (-1, -1, -1),   # China No Tariff, US No Tariff, Europe No Tariff
]
game.add_outcomes(outcomes)

# Create and run solver, then print the on-path strategies and payoffs
solver = BackwardInductionSolver(game)
solver.solve()
solver.print_equilibrium()

# Sanity-check the player indices that position payoffs in the tuples above
print(f"China's index: {game.get_player_index('China')}")    # Should print 0
print(f"US's index: {game.get_player_index('US')}")          # Should print 1
print(f"Europe's index: {game.get_player_index('Europe')}")  # Should print 2

In [None]:
# Draw the solved three-player tree; each node's optimal action edge is drawn in red.
solver.visualize_equilibrium() 