In [13]:
import enum
import random
import copy
import math
import numpy as np
from typing import List, Dict, Optional, Tuple, Callable
from collections import defaultdict
import matplotlib.pyplot as plt


class ActionType(enum.IntEnum):
    PASS = 0
    # BET actions will be numbers from 1 to 100

class PlayerType(enum.IntEnum):
    CHANCE = -2  # For card dealing
    TERMINAL = -1  # Game is over
    # Regular players are 0, 1, 2, ...


class ForwardSearchPlayer:
    """Player that uses forward search with policy extraction."""
    
    def __init__(self, num_simulations: int = 100):
        self.num_simulations = num_simulations
        self.policy_dict = defaultdict(lambda: defaultdict(float))
        self.value_dict = defaultdict(lambda: defaultdict(float))
        
    def choose_action(self, game: 'KuhnPoker', player_id: int) -> int:
        legal_actions = game.legal_actions()
        if not legal_actions:
            return 0
            
        my_card = game.get_player_card(player_id)
        all_cards = set(range(game.num_players + 1))
        possible_opponent_cards = all_cards - {my_card}
        
        # Get game state key for policy recording
        game_state = self.get_state_key(game, player_id)
        
        best_action = None
        best_value = float('-inf')
        action_values = {}
        visit_counts = defaultdict(int)
        
        for action in legal_actions:
            total_value = 0
            visit_counts[action] = 0
            
            for opponent_card in possible_opponent_cards:
                sim_game = copy.deepcopy(game)
                value = self.simulate_action(sim_game, action, my_card, opponent_card, player_id)
                total_value += value
                visit_counts[action] += 1
                
            avg_value = total_value / len(possible_opponent_cards)
            action_values[action] = avg_value
            
            if avg_value > best_value:
                best_value = avg_value
                best_action = action
        
        # Record policy and values
        total_visits = sum(visit_counts.values())
        for action, visits in visit_counts.items():
            self.policy_dict[game_state][action] = visits / total_visits
            self.value_dict[game_state][action] = action_values[action]
        
        return best_action
    def simulate_action(self, game: 'KuhnPoker', action: int, my_card: int, 
                       opponent_card: int, player_id: int) -> float:
        game.apply_action(action)
        if game.is_terminal():
            returns = game.returns()
            return returns[player_id]
            
        total_value = 0
        opponent_actions = [0]  # fold
        bet_amounts = [20, 40, 60, 80, 100]
        opponent_actions.extend(bet_amounts)
        
        for opp_action in opponent_actions:
            sim_game = copy.deepcopy(game)
            sim_game.apply_action(opp_action)
            
            if sim_game.is_terminal():
                returns = sim_game.returns()
                value = returns[player_id]
            else:
                value = self.evaluate_position(sim_game, my_card, opponent_card, player_id)
            
            total_value += value / len(opponent_actions)
            
        return total_value
    
    def evaluate_position(self, game: 'KuhnPoker', my_card: int, opponent_card: int, 
                         player_id: int) -> float:
        if my_card > opponent_card:
            return game.pot * 0.8
        elif my_card < opponent_card:
            return -game.pot * 0.8
        else:
            return -game.pot * 0.1
    
    def get_state_key(self, game: 'KuhnPoker', player_id: int) -> str:
        """Create a key representing the current game state."""
        my_card = game.get_player_card(player_id)
        betting_history = ''.join('b' if h[1] > 0 else 'p' 
                                for h in game.history[game.num_players:])
        return f"Card:{my_card}|Hist:{betting_history}"
    
    def get_policy(self) -> Dict:
        """Return the extracted policy."""
        return dict(self.policy_dict)
    
    def get_values(self) -> Dict:
        """Return the action values."""
        return dict(self.value_dict)

def extract_policy(num_games: int = 1000) -> Tuple[Dict, Dict]:
    """Extract policy by simulating many games."""
    forward_player = ForwardSearchPlayer()
    random_player = RandomPlayer()
    
    # Simulate games to build policy
    for _ in range(num_games):
        game = KuhnPoker()
        
        # Deal cards
        available_cards = list(range(3))
        random.shuffle(available_cards)
        game.apply_action(available_cards[0])
        game.apply_action(available_cards[1])
        
        # Play the game
        while not game.is_terminal():
            current_player = game.current_player()
            if current_player == 0:
                action = forward_player.choose_action(game, current_player)
            else:
                action = random_player.choose_action(game, current_player)
            game.apply_action(action)
    
    return forward_player.get_policy(), forward_player.get_values()

def visualize_policy_simple(policy: Dict, values: Dict):
    """Create a simple visualization of the policy using matplotlib."""
    # Organize data for visualization
    states = list(policy.keys())
    
    # Create figure with subplots for different card values
    fig, axes = plt.subplots(3, 1, figsize=(15, 12))
    fig.suptitle('Kuhn Poker Forward Search Policy', fontsize=16)
    
    # Group states by card
    card_states = {
        '0': [],
        '1': [],
        '2': []
    }
    
    for state in states:
        card = state.split('|')[0].split(':')[1]
        card_states[card].append(state)
    
    # Plot each card's policy
    for card_idx, (card, card_state_list) in enumerate(card_states.items()):
        ax = axes[card_idx]
        ax.set_title(f'Card {card}')
        
        if not card_state_list:
            continue
            
        # Prepare data for this card
        y_positions = np.arange(len(card_state_list))
        y_labels = []
        
        # Plot bars for each state
        for idx, state in enumerate(card_state_list):
            history = state.split('|')[1].split(':')[1]
            y_labels.append(f'History: {history}')
            
            # Plot bars for each action
            x_position = 0
            for action, prob in policy[state].items():
                value = values[state][action]
                color = 'green' if value > 0 else 'red'
                alpha = min(abs(value) / 10, 1.0)  # Scale alpha by value
                
                # Create bar
                width = prob * 0.8  # Scale width by probability
                ax.barh(idx, width, left=x_position, 
                       color=color, alpha=alpha, height=0.5)
                
                # Add action label
                action_str = 'fold' if action == 0 else f'bet {action}'
                if prob > 0.1:  # Only show label if probability is significant
                    ax.text(x_position + width/2, idx,
                           f'{action_str}\n{prob:.2f}',
                           va='center', ha='center', fontsize=8)
                x_position += width
        
        ax.set_yticks(y_positions)
        ax.set_yticklabels(y_labels)
        ax.set_xlim(0, 1)
        ax.grid(True, alpha=0.3)
        
    plt.tight_layout()
    return fig

def print_detailed_policy(policy: Dict, values: Dict):
    """Print a detailed view of the policy."""
    print("\nDetailed Policy Analysis")
    print("=" * 80)
    
    # Group by card
    for card in ['0', '1', '2']:
        print(f"\nCard {card} Policies:")
        print("-" * 80)
        print(f"{'History':15} | {'Action':10} | {'Probability':12} | {'Value':8}")
        print("-" * 80)
        
        # Find all states for this card
        card_states = [s for s in policy.keys() 
                      if s.split('|')[0].split(':')[1] == card]
        
        for state in sorted(card_states):
            history = state.split('|')[1].split(':')[1]
            first = True
            
            # Sort actions by probability
            actions = sorted(policy[state].items(), 
                           key=lambda x: x[1], reverse=True)
            
            for action, prob in actions:
                value = values[state][action]
                action_str = 'fold' if action == 0 else f'bet {action}'
                
                if first:
                    print(f"{history:15} | {action_str:10} | {prob:11.2f} | {value:8.1f}")
                    first = False
                else:
                    print(f"{' ':15} | {action_str:10} | {prob:11.2f} | {value:8.1f}")
            
            print("-" * 80)


class RandomPlayer:
    """Player that makes random decisions."""
    
    def __init__(self, fold_probability: float = 0.5):
        self.fold_probability = fold_probability
        
    def choose_action(self, game: 'KuhnPoker', player_id: int) -> int:
        """Choose a random action based on defined probabilities."""
        legal_actions = game.legal_actions()
        
        # First determine if we fold
        if 0 in legal_actions and random.random() < self.fold_probability:
            return 0
            
        # If we didn't fold, choose a random bet
        betting_actions = [action for action in legal_actions if action > 0]
        if betting_actions:
            return random.choice(betting_actions)
        return 0

class Node:
    def __init__(self, game_state: Optional['KuhnPoker'] = None, parent=None, action=None):
        self.game_state = game_state
        self.parent = parent
        self.action = action
        self.children: Dict[int, Node] = {}
        self.visits = 0
        self.value = 0.0
        self.untried_actions = game_state.legal_actions() if game_state else []
        
    def ucb1(self, exploration_weight: float) -> float:
        if self.visits == 0:
            return float('inf')
        exploitation = self.value / self.visits
        exploration = exploration_weight * math.sqrt(math.log(self.parent.visits) / self.visits)
        return exploitation + exploration

class MCTSPlayer:
    def __init__(self, num_simulations: int = 100, exploration_weight: float = 1.414):
        self.num_simulations = num_simulations
        self.exploration_weight = exploration_weight
        
    def choose_action(self, game: 'KuhnPoker', player_id: int) -> int:
        if not game.legal_actions():
            return 0
            
        root = Node(game_state=copy.deepcopy(game))
        
        for _ in range(self.num_simulations):
            node = root
            sim_game = copy.deepcopy(game)
            
            # Selection
            while node.untried_actions == [] and node.children:
                node = self.select_child(node)
                sim_game.apply_action(node.action)
            
            # Expansion
            if node.untried_actions:
                action = random.choice(node.untried_actions)
                sim_game.apply_action(action)
                node.untried_actions.remove(action)
                child = Node(game_state=copy.deepcopy(sim_game), parent=node, action=action)
                node.children[action] = child
                node = child
            
            # Simulation
            while not sim_game.is_terminal():
                sim_action = random.choice(sim_game.legal_actions())
                sim_game.apply_action(sim_action)
            
            # Backpropagation
            returns = sim_game.returns()
            while node:
                node.visits += 1
                node.value += returns[player_id]
                node = node.parent
        
        # Choose best action based on visit count
        return max(root.children.items(), 
                  key=lambda x: (x[1].visits, x[1].value))[0]
    
    def select_child(self, node: Node) -> Node:
        return max(node.children.values(),
                  key=lambda child: child.ucb1(self.exploration_weight))

def simulate_mcts_vs_random(num_games: int = 1000) -> Dict:
    """Simulate games between MCTS and random players."""
    mcts_player = MCTSPlayer(num_simulations=50)  # Reduced simulations for faster execution
    random_player = RandomPlayer()
    
    stats = defaultdict(float)
    card_stats = defaultdict(lambda: defaultdict(int))
    
    for game_num in range(num_games):
        game = KuhnPoker()
        
        # Deal cards
        available_cards = list(range(3))
        random.shuffle(available_cards)
        game.apply_action(available_cards[0])  # MCTS player's card
        game.apply_action(available_cards[1])  # Random player's card
        
        mcts_card = game.get_player_card(0)
        random_card = game.get_player_card(1)
        
        # Play the game
        while not game.is_terminal():
            current_player = game.current_player()
            if current_player == 0:
                action = mcts_player.choose_action(game, current_player)
            else:
                action = random_player.choose_action(game, current_player)
            game.apply_action(action)
        
        # Record results
        returns = game.returns()
        stats['mcts_total_profit'] += returns[0]
        
        if returns[0] > 0:
            stats['mcts_wins'] += 1
            card_stats[mcts_card]['wins'] += 1
        elif returns[0] < 0:
            stats['random_wins'] += 1
            card_stats[mcts_card]['losses'] += 1
            
        stats['total_pots'] += game.pot
        
        if game_num % 100 == 0:
            win_rate = (stats['mcts_wins'] / (game_num + 1)) * 100
            print(f"Game {game_num}: MCTS Player Win Rate: {win_rate:.1f}%")
    
    # Calculate final statistics
    stats['num_games'] = num_games
    stats['mcts_win_rate'] = (stats['mcts_wins'] / num_games) * 100
    stats['average_pot'] = stats['total_pots'] / num_games
    stats['mcts_profit_per_game'] = stats['mcts_total_profit'] / num_games
    
    # Calculate card-specific statistics
    stats['card_stats'] = {}
    for card in range(3):
        wins = card_stats[card]['wins']
        total = card_stats[card]['wins'] + card_stats[card]['losses']
        if total > 0:
            stats['card_stats'][card] = {
                'win_rate': (wins / total) * 100,
                'total_games': total
            }
    
    return stats


class KuhnPoker:
    """Implementation of Kuhn Poker with automated player support."""
    
    # Game constants
    ANTE = 1
    INVALID_PLAYER = -1
    MIN_BET = 1
    MAX_BET = 100

    def __init__(self, num_players: int = 2):
        """Initialize a new game of Kuhn Poker.
        
        Args:
            num_players: Number of players (default: 2)
        """
        if not (2 <= num_players <= 10):
            raise ValueError("Number of players must be between 2 and 10")
            
        self.num_players = num_players
        self.random_player = RandomPlayer(fold_probability=0.5)
        self.reset()

    def reset(self):
        """Reset the game state to the beginning of a new game."""
        self.history = []  # List of (player, action) tuples
        self.card_dealt = [self.INVALID_PLAYER] * (self.num_players + 1)
        self.first_bettor = self.INVALID_PLAYER
        self.winner = self.INVALID_PLAYER
        self.pot = self.ANTE * self.num_players
        self.ante = [self.ANTE] * self.num_players
        self.current_bet = 0
        self.player_bets = [0] * self.num_players

    def current_player(self) -> int:
        """Returns the current player's id, or special values for chance/terminal."""
        if self.is_terminal():
            return PlayerType.TERMINAL
        
        if len(self.history) < self.num_players:
            return PlayerType.CHANCE
        else:
            return len(self.history) % self.num_players

    def get_player_card(self, player: int) -> Optional[int]:
        """Get the card dealt to a specific player.
        
        Args:
            player: Player index (0 to num_players-1)
            
        Returns:
            The card value or None if no card has been dealt
        """
        if len(self.history) > player:
            return self.history[player][1]
        return None

    def get_card_holder(self, card: int) -> Optional[int]:
        """Get which player holds a specific card.
        
        Args:
            card: Card value (0 to num_players)
            
        Returns:
            The player index who has this card, or None if not dealt
        """
        player = self.card_dealt[card]
        return None if player == self.INVALID_PLAYER else player

    def is_chance_node(self) -> bool:
        """Returns True if cards are still being dealt."""
        return len(self.history) < self.num_players

    def is_terminal(self) -> bool:
        """Returns True if the game is over."""
        return self.winner != self.INVALID_PLAYER

    def legal_actions(self) -> List[int]:
        """Returns a list of legal actions for the current player.
        
        Returns:
            - If dealing cards: list of undealt card values
            - If player acting: [PASS] + list of valid bet amounts
            - If game over: empty list
        """
        if self.is_terminal():
            return []
        
        if self.is_chance_node():
            return [card for card in range(len(self.card_dealt)) 
                   if self.card_dealt[card] == self.INVALID_PLAYER]
        else:
            actions = [ActionType.PASS]  # Always allow fold
            min_bet = max(self.MIN_BET, self.current_bet)
            actions.extend(range(min_bet, self.MAX_BET + 1))
            return actions

    def apply_action(self, action: int):
        """Apply an action to the game state.
        
        Args:
            action: 
                - If at chance node: the card value being dealt
                - If at player node: 0 for PASS, or bet amount (1-100)
        """
        current_player = self.current_player()
        
        if current_player == PlayerType.CHANCE:
            # Dealing phase - record the card being dealt
            player_receiving_card = len(self.history)
            self.card_dealt[action] = player_receiving_card
            self.history.append((current_player, action))
        else:
            # Betting phase
            if action > ActionType.PASS:  # Betting
                bet_amount = action
                if self.first_bettor == self.INVALID_PLAYER:
                    self.first_bettor = current_player
                self.pot += bet_amount
                self.player_bets[current_player] += bet_amount
                self.current_bet = bet_amount
                self.ante[current_player] += bet_amount
            
            self.history.append((current_player, action))
            
            if self.should_end_game():
                self.determine_winner()

    def should_end_game(self) -> bool:
        """Determines if the game should end based on the current state."""
        if len(self.history) <= self.num_players:
            return False
            
        num_actions = len(self.history) - self.num_players
        
        # Game ends if everyone passed
        if self.first_bettor == self.INVALID_PLAYER and num_actions == self.num_players:
            return True
            
        # Game ends if everyone responded to the first bet
        if (self.first_bettor != self.INVALID_PLAYER and 
            num_actions == self.num_players + self.first_bettor):
            return True
            
        # Game ends if someone passed after a bet
        if self.first_bettor != self.INVALID_PLAYER:
            last_action = self.history[-1][1]
            if last_action == ActionType.PASS:
                return True
                
        return False

    def determine_winner(self):
        """Determines the winner of the game."""
        if self.first_bettor == self.INVALID_PLAYER:
            # Nobody bet - highest card wins
            self.winner = self.card_dealt[self.num_players]
            if self.winner == self.INVALID_PLAYER:
                self.winner = self.card_dealt[self.num_players - 1]
        else:
            # Someone bet - highest remaining card who didn't fold wins
            for card in range(self.num_players, -1, -1):
                player = self.card_dealt[card]
                if player != self.INVALID_PLAYER and self.player_bets[player] == self.current_bet:
                    self.winner = player
                    break

    def returns(self) -> List[float]:
        """Returns a list of payoffs for each player."""
        if not self.is_terminal():
            return [0.0] * self.num_players
        
        returns = []
        for player in range(self.num_players):
            bet = self.player_bets[player] + self.ANTE
            returns.append(self.pot - bet if player == self.winner else -bet)
        return returns

    def auto_play_round(self, human_action: int) -> Tuple[List[float], str]:
        """Play one round with human action and automated second player.
        
        Args:
            human_action: The action chosen by the human player
            
        Returns:
            Tuple of (payoffs, game_log)
        """
        game_log = ""
        
        # Apply human's action
        self.apply_action(human_action)
        game_log += f"Player 0 {'folds' if human_action == 0 else f'bets {human_action}'}\n"
        
        # If game not over, let random player act
        if not self.is_terminal():
            bot_action = self.random_player.choose_action(KuhnPoker(),1)
            self.apply_action(bot_action)
            game_log += f"Player 1 {'folds' if bot_action == 0 else f'bets {bot_action}'}\n"
        
        if self.is_terminal():
            game_log += f"Winner: Player {self.winner}\n"
        
        returns = self.returns() if self.is_terminal() else [0, 0]
        return returns, game_log

    def __str__(self) -> str:
        """Returns a string representation of the current state."""
        # Show dealt cards
        dealt = " ".join(
            f"P{i}:{h[1]}" for i, h in enumerate(self.history[:min(len(self.history), self.num_players)])
        )
        
        # Show betting amounts
        if len(self.history) > self.num_players:
            betting = " Bets:" + " ".join(
                f"P{h[0]}:{'fold' if h[1] == 0 else h[1]}"
                for h in self.history[self.num_players:]
            )
            return dealt + betting
            
        return dealt

def simulate_forward_vs_random(num_games: int = 1000) -> Dict:
    """Simulate games between forward search and random players."""
    forward_player = ForwardSearchPlayer()
    random_player = RandomPlayer()
    
    stats = defaultdict(float)
    card_stats = defaultdict(lambda: defaultdict(int))
    
    for game_num in range(num_games):
        game = KuhnPoker()
        
        # Deal cards
        available_cards = list(range(3))
        random.shuffle(available_cards)
        game.apply_action(available_cards[0])  # Forward search player's card
        game.apply_action(available_cards[1])  # Random player's card
        
        forward_card = game.get_player_card(0)
        random_card = game.get_player_card(1)
        
        # Play the game
        while not game.is_terminal():
            current_player = game.current_player()
            if current_player == 0:
                action = forward_player.choose_action(game, current_player)
            else:
                action = random_player.choose_action(game, current_player)
            game.apply_action(action)
        
        # Record results
        returns = game.returns()
        stats['forward_total_profit'] += returns[0]
        
        if returns[0] > 0:
            stats['forward_wins'] += 1
            card_stats[forward_card]['wins'] += 1
        elif returns[0] < 0:
            stats['random_wins'] += 1
            card_stats[forward_card]['losses'] += 1
            
        stats['total_pots'] += game.pot
        
        if game_num % 100 == 0:
            win_rate = (stats['forward_wins'] / (game_num + 1)) * 100
            print(f"Game {game_num}: Forward Search Player Win Rate: {win_rate:.1f}%")
    
    # Calculate final statistics
    stats['num_games'] = num_games
    stats['forward_win_rate'] = (stats['forward_wins'] / num_games) * 100
    stats['average_pot'] = stats['total_pots'] / num_games
    stats['forward_profit_per_game'] = stats['forward_total_profit'] / num_games
    
    # Calculate card-specific statistics
    stats['card_stats'] = {}
    for card in range(3):
        wins = card_stats[card]['wins']
        total = card_stats[card]['wins'] + card_stats[card]['losses']
        if total > 0:
            stats['card_stats'][card] = {
                'win_rate': (wins / total) * 100,
                'total_games': total
            }
    
    return stats


class FixedWidthMCTSPlayer:
    def __init__(self, num_simulations: int = 100, width: int = 3):
        self.num_simulations = num_simulations
        self.width = width
        
    def choose_action(self, game: 'KuhnPoker', player_id: int) -> int:
        if not game.legal_actions():
            return 0
            
        legal_actions = game.legal_actions()
        action_values = defaultdict(float)
        action_visits = defaultdict(int)
        
        for _ in range(self.num_simulations):
            for action in legal_actions[:self.width]:  # Only explore top k actions
                sim_game = copy.deepcopy(game)
                sim_game.apply_action(action)
                
                # Random playout
                while not sim_game.is_terminal():
                    sim_action = random.choice(sim_game.legal_actions())
                    sim_game.apply_action(sim_action)
                
                # Update statistics
                returns = sim_game.returns()
                action_values[action] += returns[player_id]
                action_visits[action] += 1
        
        # Choose best action based on average value
        return max(action_values.items(),
                  key=lambda x: (x[1] / action_visits[x[0]] if action_visits[x[0]] > 0 else float('-inf')))[0]

def simulate_fixed_mcts_vs_random(num_games: int = 1000) -> Dict:
    """Simulate games between Fixed-Width MCTS and random players."""
    fixed_mcts_player = FixedWidthMCTSPlayer(num_simulations=50, width=3)
    random_player = RandomPlayer()
    
    stats = defaultdict(float)
    card_stats = defaultdict(lambda: defaultdict(int))
    
    for game_num in range(num_games):
        game = KuhnPoker()
        
        # Deal cards
        available_cards = list(range(3))
        random.shuffle(available_cards)
        game.apply_action(available_cards[0])
        game.apply_action(available_cards[1])
        
        fixed_mcts_card = game.get_player_card(0)
        
        # Play the game
        while not game.is_terminal():
            current_player = game.current_player()
            if current_player == 0:
                action = fixed_mcts_player.choose_action(game, current_player)
            else:
                action = random_player.choose_action(game, current_player)
            game.apply_action(action)
        
        # Record results
        returns = game.returns()
        stats['fixed_mcts_total_profit'] += returns[0]
        
        if returns[0] > 0:
            stats['fixed_mcts_wins'] += 1
            card_stats[fixed_mcts_card]['wins'] += 1
        elif returns[0] < 0:
            stats['random_wins'] += 1
            card_stats[fixed_mcts_card]['losses'] += 1
            
        stats['total_pots'] += game.pot
        
        if game_num % 100 == 0:
            win_rate = (stats['fixed_mcts_wins'] / (game_num + 1)) * 100
            print(f"Game {game_num}: Fixed-Width MCTS Win Rate: {win_rate:.1f}%")
    
    # Calculate final statistics
    stats['num_games'] = num_games
    stats['fixed_mcts_win_rate'] = (stats['fixed_mcts_wins'] / num_games) * 100
    stats['average_pot'] = stats['total_pots'] / num_games
    stats['fixed_mcts_profit_per_game'] = stats['fixed_mcts_total_profit'] / num_games
    
    # Calculate card-specific statistics
    stats['card_stats'] = {}
    for card in range(3):
        wins = card_stats[card]['wins']
        total = card_stats[card]['wins'] + card_stats[card]['losses']
        if total > 0:
            stats['card_stats'][card] = {
                'win_rate': (wins / total) * 100,
                'total_games': total
            }
    
    return stats


def play_interactive_game():
    """Play an interactive game against the random player."""
    game = KuhnPoker(2)
    
    # Deal cards
    available_cards = list(range(3))
    random.shuffle(available_cards)
    game.apply_action(available_cards[0])
    game.apply_action(available_cards[1])
    
    print(f"Your card is: {game.get_player_card(0)}")
    print("Choose your action:")
    print("0: Fold")
    print("1-100: Bet amount")
    
    try:
        human_action = int(input("Your action: "))
        if human_action not in game.legal_actions():
            print("Illegal action! Defaulting to fold.")
            human_action = 0
    except ValueError:
        print("Invalid input! Defaulting to fold.")
        human_action = 0
    
    returns, game_log = game.auto_play_round(human_action)
    
    print("\nGame Log:")
    print(game_log)
    print(f"\nOpponent's card was: {game.get_player_card(1)}")
    print(f"Final returns: {returns}")
    
    return game

if __name__ == "__main__":
    simulate = True
    if simulate:

        print("Simulating Forward Search vs Random Player...")
        stats = simulate_forward_vs_random(100)
        
        print("\nSimulation Results:")
        print(f"Number of games: {stats['num_games']}")
        print(f"Forward Search Win Rate: {stats['forward_win_rate']:.1f}%")
        print(f"Average Profit per Game: {stats['forward_profit_per_game']:.2f}")
        print(f"Average Pot Size: {stats['average_pot']:.2f}")
        
        print("\nCard-specific Performance:")
        for card, card_stats in stats['card_stats'].items():
            print(f"Card {card}:")
            print(f"  Win Rate: {card_stats['win_rate']:.1f}%")
            print(f"  Games Played: {card_stats['total_games']}")
        

    
    # Offer to play interactive game
    #print("\nWould you like to play a game against the random player? (y/n)")
    #if input().lower().startswith('y'):
    #    play_interactive_game()


Simulating Forward Search vs Random Player...
Game 0: Forward Search Player Win Rate: 100.0%

Simulation Results:
Number of games: 100
Forward Search Win Rate: 74.0%
Average Profit per Game: 12.43
Average Pot Size: 32.67

Card-specific Performance:
Card 0:
  Win Rate: 57.7%
  Games Played: 26
Card 1:
  Win Rate: 65.9%
  Games Played: 44
Card 2:
  Win Rate: 100.0%
  Games Played: 30


In [None]:
if __name__ == "__main__":
    print("Simulating MCTS vs Random Player...")
    mcts_stats = simulate_mcts_vs_random(100)
    
    print("\nMCTS Simulation Results:")
    print(f"Number of games: {mcts_stats['num_games']}")
    print(f"MCTS Win Rate: {mcts_stats['mcts_win_rate']:.1f}%")
    print(f"Average Profit per Game: {mcts_stats['mcts_profit_per_game']:.2f}")
    print(f"Average Pot Size: {mcts_stats['average_pot']:.2f}")
    
    print("\nCard-specific Performance:")
    for card, card_stats in mcts_stats['card_stats'].items():
        print(f"Card {card}:")
        print(f"  Win Rate: {card_stats['win_rate']:.1f}%")
        print(f"  Games Played: {card_stats['total_games']}")

Simulating MCTS vs Random Player...
Game 0: MCTS Player Win Rate: 0.0%

MCTS Simulation Results:
Number of games: 100
Fixed mcts win rate: 0.0%
Average Profit per Game: 0.49
Average Pot Size: 107.40

Card-specific Performance:
Card 0:
  Win Rate: 56.4%
  Games Played: 39
Card 1:
  Win Rate: 57.1%
  Games Played: 35
Card 2:
  Win Rate: 84.6%
  Games Played: 26


In [18]:
print("\nSimulating Fixed-Width MCTS vs Random Player...")
fixed_mcts_stats = simulate_fixed_mcts_vs_random(100)
print("\nFixed width Simulation Results:")
print(f"Number of games: {fixed_mcts_stats['num_games']}")
print(f"Fixed width MCTS Win Rate: {fixed_mcts_stats['fixed_mcts_win_rate']:.1f}%")
print(f"Average Profit per Game: {fixed_mcts_stats['mcts_profit_per_game']:.2f}")
print(f"Average Pot Size: {fixed_mcts_stats['average_pot']:.2f}")

print("\nCard-specific Performance:")
for card, card_stats in fixed_mcts_stats['card_stats'].items():
    print(f"Card {card}:")
    print(f"  Win Rate: {card_stats['win_rate']:.1f}%")
    print(f"  Games Played: {card_stats['total_games']}")


Simulating Fixed-Width MCTS vs Random Player...
Game 0: Fixed-Width MCTS Win Rate: 100.0%

Fixed width Simulation Results:
Number of games: 100
Fixed width MCTS Win Rate: 82.0%
Average Profit per Game: 0.00
Average Pot Size: 49.43

Card-specific Performance:
Card 0:
  Win Rate: 70.0%
  Games Played: 30
Card 1:
  Win Rate: 71.9%
  Games Played: 32
Card 2:
  Win Rate: 100.0%
  Games Played: 38


In [18]:
print("Extracting policy from forward search...")
forward_player = ForwardSearchPlayer()
policy, values = extract_policy(100)  # Assuming this function exists

# Print detailed policy analysis
print_detailed_policy(policy, values)


# Create visualization
print("\nCreating policy visualization...")
fig = visualize_policy_simple(policy, values)
plt.plot(fig)

Extracting policy from forward search...

Detailed Policy Analysis

Card 0 Policies:
--------------------------------------------------------------------------------
History         | Action     | Probability  | Value   
--------------------------------------------------------------------------------
                | fold       |        0.01 |    -41.5
                | bet 1      |        0.01 |     -1.5
                | bet 2      |        0.01 |     -2.3
                | bet 3      |        0.01 |     -3.2
                | bet 4      |        0.01 |     -4.0
                | bet 5      |        0.01 |     -4.8
                | bet 6      |        0.01 |     -5.7
                | bet 7      |        0.01 |     -6.5
                | bet 8      |        0.01 |     -7.3
                | bet 9      |        0.01 |     -8.2
                | bet 10     |        0.01 |     -9.0
                | bet 11     |        0.01 |     -9.8
                | bet 12     |        0.01 |    -1

Traceback (most recent call last):
  File "a:\spyder\spyder\lib\site-packages\IPython\core\interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "C:\Users\15483\AppData\Local\Temp\ipykernel_18376\564834427.py", line 12, in <module>
    plt.plot(fig)
  File "a:\spyder\spyder\lib\site-packages\matplotlib\pyplot.py", line 2761, in plot
    vmin=vmin, vmax=vmax,
  File "a:\spyder\spyder\lib\site-packages\matplotlib\axes\_axes.py", line 1649, in plot
  File "a:\spyder\spyder\lib\site-packages\matplotlib\axes\_base.py", line 1850, in add_line
    raise ValueError('argument must be among %s' %
  File "a:\spyder\spyder\lib\site-packages\matplotlib\axes\_base.py", line 1872, in _update_line_limits
    xsize = max(abs(txmax - txmin), 1e-30)
  File "a:\spyder\spyder\lib\site-packages\matplotlib\lines.py", line 1027, in get_path
    if self._invalidy or self._invalidx:
  File "a:\spyder\spyder\lib\site-packages\matplotlib\lines.py", line 675, in 

ImportError: cannot import name '_png' from 'matplotlib' (a:\spyder\spyder\lib\site-packages\matplotlib\__init__.py)

<Figure size 1500x1200 with 3 Axes>

In [19]:
play_interactive_game()


Your card is: 1
Choose your action:
0: Fold
1-100: Bet amount

Game Log:
Player 0 bets 50
Player 1 bets 1
Winner: Player 1


Opponent's card was: 2
Final returns: [-51, 51]


<__main__.KuhnPoker at 0x20cb513f4d0>