# Game environment

We first define the game environment here

In [111]:
import enum
import random
import copy
import math
import numpy as np
from typing import Hashable, List, Dict, Optional, Tuple, Callable
from collections import defaultdict
import matplotlib.pyplot as plt
from abc import abstractmethod, ABC
from dataclasses import dataclass



    
class State(ABC):
    @abstractmethod
    def apply_action(self, action: int, player: int):
        pass

    @abstractmethod
    def get_observation(self, player: int) -> Hashable:
        pass

    @abstractmethod
    def current_player(self) -> int:
        pass

    @abstractmethod
    def is_terminal(self) -> bool:
        pass

    @abstractmethod
    def legal_actions(self) -> List[int]:
        pass

    @abstractmethod
    def get_returns(self) -> dict[int, float]:
        pass

    @abstractmethod
    def __str__(self) -> str:
        pass

    @abstractmethod
    def __hash__(self) -> int:
        pass

class History(ABC):

    @abstractmethod
    def get_legal_actions(self) -> List[int]:
        pass

class ActionType(enum.IntEnum):
    FOLD = -1
    CHECK = 0
    # BET actions will be numbers from 1 to 100


class PlayerType(enum.IntEnum):
    CHANCE = -2  # For card dealing
    TERMINAL = -1  # Game is over
    # Regular players are 0, 1, 2, ...


@dataclass
class KuhnPokerObservation:
    player_hand: int
    player_index: int
    bets: List[int]
    current_player: int
    folded: List[bool]
    bet_amount: Optional[int]
    winner: Optional[int] = None

    def get_legal_actions(self) -> List[int]:
        if self.is_terminal():
            return []
        elif self.bet_amount is None:
            return [ActionType.CHECK] + list(range(1, KuhnPokerState.MAX_BET + 1))
        else:
            return [ActionType.FOLD, self.bet_amount]
        
    def is_terminal(self) -> bool:
        return self.current_player == PlayerType.TERMINAL
    
    def __hash__(self) -> int:
        return hash((self.player_hand, self.player_index, tuple(self.bets), self.current_player, tuple(self.folded), self.bet_amount, self.winner))

@dataclass
class KuhnPokerHistory:
    observations: List[KuhnPokerObservation]

    def get_legal_actions(self) -> List[int]:
        return self.observations[-1].get_legal_actions()
    
    def copy(self) -> 'KuhnPokerHistory':
        return copy.deepcopy(self)
    
    def __hash__(self) -> int:
        return hash(tuple(self.observations))
    
    def get_current_player(self) -> int:
        return self.observations[-1].current_player
    
    def get_last_observation(self) -> KuhnPokerObservation:
        return self.observations[-1]

class KuhnPokerState:
    DECK = tuple([0, 1, 2])  # Cards are dealt from this deck
    MAX_BET = 100  # Maximum bet amount

    def __init__(self):
        self.players_hands = random.sample(KuhnPokerState.DECK, 2)  # Deal two cards to two players
        self.bets = [1, 1]  # Individual player bets. Ante is 1
        self.folded = [False, False]  # Track if players have folded
        self.current_player_index = random.randint(0, 1)  # Random starting player
        self.bet_amount = None  # Amount of the bet. None if no bet has been made
        self.winner = None  # Winner of the game

    @staticmethod
    def init_from_observation(observation: KuhnPokerObservation, opponent_card: int) -> 'KuhnPokerState':
        state = KuhnPokerState()
        state.players_hands = [0, 0]
        state.players_hands[observation.player_index] = observation.player_hand
        state.players_hands[1 - observation.player_index] = opponent_card
        state.bets = observation.bets
        state.folded = observation.folded
        state.current_player_index = observation.current_player
        state.bet_amount = observation.bet_amount
        state.winner = observation.winner
        return state

    def apply_action(self, action: int, player: int):
        if self.current_player_index != player:
            raise ValueError("Not this player's turn!")

        if action not in self.get_legal_actions():
            raise ValueError("Illegal action!")

        if action == ActionType.FOLD:
            self.folded[player] = True
            self.current_player_index = PlayerType.TERMINAL
            self.winner = 1 - player  # The other player wins
        elif action == ActionType.CHECK:
            if self.bet_amount is None:  # First check
                self.bet_amount = 0
                self.current_player_index = 1 - player
            else:  # Both players checked
                self.current_player_index = PlayerType.TERMINAL
                self.determine_winner()
        else:  # BET or CALL
            if self.bet_amount is None:  # First bet
                self.bet_amount = action
                self.bets[player] += action
                self.current_player_index = 1 - player
            else:  # CALL
                self.bets[player] += self.bet_amount
                self.current_player_index = PlayerType.TERMINAL
                self.determine_winner()

    def determine_winner(self):
        if self.folded[0]:
            self.winner = 1
        elif self.folded[1]:
            self.winner = 0
        else:  # Compare hands
            if self.players_hands[0] > self.players_hands[1]:
                self.winner = 0
            else:
                self.winner = 1

    def get_observation(self, player: int) -> KuhnPokerObservation:
        return KuhnPokerObservation(
            player_hand=self.players_hands[player],
            player_index=player,
            bets=self.bets,
            current_player=self.current_player_index,
            folded=self.folded,
            bet_amount=self.bet_amount,
            winner=self.winner
        )

    def current_player(self) -> int:
        return self.current_player_index

    def is_terminal(self) -> bool:
        return self.current_player_index == PlayerType.TERMINAL

    def get_legal_actions(self) -> List[int]:
        if self.is_terminal():
            return []
        elif self.bet_amount is None:
            return [ActionType.CHECK] + list(range(1, KuhnPokerState.MAX_BET + 1))
        else:
            return [ActionType.FOLD, self.bet_amount]

    def get_returns(self) -> dict[int, float]:
        if not self.is_terminal():
            return {0: 0.0, 1: 0.0}  # No payoff if the game is not over

        if self.winner == 0:
            out = {0: float(self.bets[1]), 1: float(self.bets[1])}
            assert out[0] <= KuhnPokerState.MAX_BET + 1, f"Player 0 bet {self.bets[1]}"
            return out
        else:
            out = {0: -float(self.bets[0]), 1: float(self.bets[0])}
            assert out[1] <= KuhnPokerState.MAX_BET + 1, f"Player 1 bet {self.bets[0]}"
            return out

    def __str__(self) -> str:
        return (f"Hands: {self.players_hands}, Bets: {self.bets}, Folded: {self.folded}, "
                f"Current Player: {self.current_player_index}, Bet Amount: {self.bet_amount}, "
                f"Winner: {self.winner}")

    def __hash__(self) -> int:
        return hash((tuple(self.players_hands), tuple(self.bets), tuple(self.folded), self.current_player_index, self.bet_amount))

    def copy(self) -> 'KuhnPokerState':
        return copy.deepcopy(self)
    
    def get_pot(self) -> int:
        return sum(self.bets)

    

class Player:

    @abstractmethod
    def choose_action(self, history: KuhnPokerHistory, player_id: int) -> int:
        pass

    @abstractmethod
    def get_policy(self) -> Dict:
        pass

class RandomPlayer(Player):
    
    def choose_action(self, history: KuhnPokerHistory, player_id: int) -> int:
        return random.choice(history.get_legal_actions())

    def get_policy(self) -> Dict:
        return {}
    
class HumanPlayer(Player):

    def choose_action(self, history: KuhnPokerHistory, player_id: int) -> int:
        print(f"Player {player_id}, choose an action:\n {history.get_legal_actions()} \n\n Observation: \n {history.get_last_observation()}")
        return int(input())
    
    def get_policy(self) -> Dict:
        return {}

@dataclass
class SimulatorResults:
    player_0_wins: int
    player_1_wins: int
    draws: int
    average_pot: float
    player_0_episodes_by_card: Dict[int, int]
    player_1_episodes_by_card: Dict[int, int]
    player_0_conditional_winrate_by_card: Dict[int, float]
    player_1_conditional_winrate_by_card: Dict[int, float]
    player_0_average_profit: float
    player_1_average_profit: float
    player_0_average_profit_by_card: Dict[int, float]
    player_1_average_profit_by_card: Dict[int, float]
    total_episodes: int


class Simulator:
    def __init__(self, players: List[Player]):
        self.players = players

    def simulate_episodes(self, num_episodes: int) -> SimulatorResults:
        player_0_wins = 0
        player_1_wins = 0
        draws = 0
        total_pot = 0
        player_0_episodes_by_card = defaultdict(int)
        player_1_episodes_by_card = defaultdict(int)
        player_0_wins_by_card = defaultdict(int)
        player_1_wins_by_card = defaultdict(int)
        player_0_total_profit = 0
        player_1_total_profit = 0
        player_0_total_profit_by_card = defaultdict(float)
        player_1_total_profit_by_card = defaultdict(float)

        for _ in range(num_episodes):
            # Initialize the state and histories
            state = KuhnPokerState()
            player_histories = [
                KuhnPokerHistory(observations=[state.get_observation(player_id)])
                for player_id in range(len(self.players))
            ]

            # Track the current game
            while not state.is_terminal():
                current_player = state.current_player()
                current_history = player_histories[current_player]

                print('state', state)

                # Current player chooses an action
                action = self.players[current_player].choose_action(current_history, current_player)
                assert action in state.get_legal_actions()
                print(f'Player {current_player} chooses action {action} in state {state}')


                # Apply the action to the state
                state.apply_action(action, current_player)

                # Update all players' histories
                for player_id in range(len(self.players)):
                    observation = state.get_observation(player_id)
                    player_histories[player_id].observations.append(observation)

            # Calculate returns and update metrics
            returns = state.get_returns()
            pot = state.get_pot()
            total_pot += pot

            # Update per-player metrics
            player_0_card = state.players_hands[0]
            player_1_card = state.players_hands[1]
            player_0_episodes_by_card[player_0_card] += 1
            player_1_episodes_by_card[player_1_card] += 1
            player_0_total_profit += returns[0]
            player_1_total_profit += returns[1]
            player_0_total_profit_by_card[player_0_card] += returns[0]
            player_1_total_profit_by_card[player_1_card] += returns[1]

            if returns[0] > 0:
                player_0_wins += 1
                player_0_wins_by_card[player_0_card] += 1
            elif returns[1] > 0:
                player_1_wins += 1
                player_1_wins_by_card[player_1_card] += 1
            else:
                draws += 1

        # Calculate conditional win rates and average profits
        player_0_conditional_winrate_by_card = {
            card: player_0_wins_by_card[card] / player_0_episodes_by_card[card]
            if player_0_episodes_by_card[card] > 0 else 0.0
            for card in KuhnPokerState.DECK
        }
        player_1_conditional_winrate_by_card = {
            card: player_1_wins_by_card[card] / player_1_episodes_by_card[card]
            if player_1_episodes_by_card[card] > 0 else 0.0
            for card in KuhnPokerState.DECK
        }
        player_0_average_profit_by_card = {
            card: player_0_total_profit_by_card[card] / player_0_episodes_by_card[card]
            if player_0_episodes_by_card[card] > 0 else 0.0
            for card in KuhnPokerState.DECK
        }
        player_1_average_profit_by_card = {
            card: player_1_total_profit_by_card[card] / player_1_episodes_by_card[card]
            if player_1_episodes_by_card[card] > 0 else 0.0
            for card in KuhnPokerState.DECK
        }

        return SimulatorResults(
            player_0_wins=player_0_wins,
            player_1_wins=player_1_wins,
            draws=draws,
            average_pot=total_pot / num_episodes,
            player_0_episodes_by_card=player_0_episodes_by_card,
            player_1_episodes_by_card=player_1_episodes_by_card,
            player_0_conditional_winrate_by_card=player_0_conditional_winrate_by_card,
            player_1_conditional_winrate_by_card=player_1_conditional_winrate_by_card,
            player_0_average_profit=player_0_total_profit / num_episodes,
            player_1_average_profit=player_1_total_profit / num_episodes,
            player_0_average_profit_by_card=player_0_average_profit_by_card,
            player_1_average_profit_by_card=player_1_average_profit_by_card,
            total_episodes=num_episodes
        )



In [62]:
# simulate random player vs random player
simulator = Simulator([RandomPlayer(), RandomPlayer()])
results = simulator.simulate_episodes(10)
print(results)

[<ActionType.CHECK: 0>, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100]
[<ActionType.FOLD: -1>, 80]
[<ActionType.CHECK: 0>, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100]
[<ActionType.FOLD: -1>, 5]
[<ActionType.CHECK: 0>, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25

In [108]:
# simulate random player vs human player
simulator = Simulator([RandomPlayer(), HumanPlayer()])
results = simulator.simulate_episodes(1)
print(results)

state Hands: [1, 0], Bets: [1, 1], Folded: [False, False], Current Player: 0, Bet Amount: None, Winner: None
Player 0 chooses action 58 in state Hands: [1, 0], Bets: [1, 1], Folded: [False, False], Current Player: 0, Bet Amount: None, Winner: None
state Hands: [1, 0], Bets: [59, 1], Folded: [False, False], Current Player: 1, Bet Amount: 58, Winner: None
Player 1, choose an action:
 [<ActionType.FOLD: -1>, 58] 

 Observation: 
 KuhnPokerObservation(player_hand=0, player_index=1, bets=[59, 1], current_player=1, folded=[False, False], bet_amount=58, winner=None)
Player 1 chooses action -1 in state Hands: [1, 0], Bets: [59, 1], Folded: [False, False], Current Player: 1, Bet Amount: 58, Winner: None
SimulatorResults(player_0_wins=1, player_1_wins=0, draws=0, average_pot=60.0, player_0_episodes_by_card=defaultdict(<class 'int'>, {1: 1, 0: 0, 2: 0}), player_1_episodes_by_card=defaultdict(<class 'int'>, {0: 1, 1: 0, 2: 0}), player_0_conditional_winrate_by_card={0: 0.0, 1: 1.0, 2: 0.0}, player_

## Players

Next we will define a couple of different players that use different strategies to play the game.

In [25]:
class ForwardSearchPlayer(Player):
    """Player that uses forward search with policy extraction."""
    
    def __init__(self, num_simulations: int = 100):
        self.num_simulations = num_simulations
        self.policy_dict = defaultdict(lambda: defaultdict(float))
        self.value_dict = defaultdict(lambda: defaultdict(float))
        
    def choose_action(self, history: KuhnPokerHistory, player_id: int) -> int:
        legal_actions = history.get_legal_actions()
        
        if not legal_actions:
            return ActionType.CHECK  # Default to CHECK if no legal actions
        
        # Get the player's current observation
        observation = history.observations[-1]
        my_card = observation.player_hand
        all_cards = set(KuhnPokerState.DECK)
        possible_opponent_cards = all_cards - {my_card}
        
        # Get game state key for policy recording
        game_state = self.get_state_key(history)
        
        best_action = None
        best_value = float('-inf')
        action_values = {}
        visit_counts = defaultdict(int)
        
        for action in legal_actions:
            total_value = 0
            visit_counts[action] = 0
            
            for opponent_card in possible_opponent_cards:
                # Reconstruct the state from the history for simulations
                sim_state = self.reconstruct_state_from_history(history, opponent_card)
                value = self.simulate_action(sim_state, action, my_card, opponent_card, player_id)
                total_value += value
                visit_counts[action] += 1
                
            avg_value = total_value / len(possible_opponent_cards)
            action_values[action] = avg_value
            
            if avg_value > best_value:
                best_value = avg_value
                best_action = action
        
        # Record policy and values
        total_visits = sum(visit_counts.values())
        for action, visits in visit_counts.items():
            self.policy_dict[game_state][action] = visits / total_visits
            self.value_dict[game_state][action] = action_values[action]
        
        return best_action
    
    def simulate_action(self, state: KuhnPokerState, action: int, my_card: int, 
                        opponent_card: int, player_id: int) -> float:
        state.apply_action(action, player_id)
        if state.is_terminal():
            returns = state.get_returns()
            return returns[player_id]
        
        total_value = 0
        opponent_actions = state.get_legal_actions()
        
        for opp_action in opponent_actions:
            sim_state = state.copy()
            sim_state.apply_action(opp_action, 1 - player_id)
            
            if sim_state.is_terminal():
                returns = sim_state.get_returns()
                value = returns[player_id]
            else:
                value = self.evaluate_position(sim_state, my_card, opponent_card, player_id)
            
            total_value += value / len(opponent_actions)
        
        return total_value
    
    def evaluate_position(self, state: KuhnPokerState, my_card: int, opponent_card: int, 
                          player_id: int) -> float:
        if my_card > opponent_card:
            return state.get_pot() * 0.8
        elif my_card < opponent_card:
            return -state.get_pot() * 0.8
        else:
            return -state.get_pot() * 0.1
    
    def get_state_key(self, history: KuhnPokerHistory) -> str:
        """Create a key representing the current game state."""
        observation = history.observations[-1]
        betting_history = f"Bets:{observation.bets}|Folded:{observation.folded}"
        return f"Card:{observation.player_hand}|{betting_history}"
    
    def reconstruct_state_from_history(self, history: KuhnPokerHistory, opponent_card: int) -> KuhnPokerState:
        """Reconstruct the state from the given history and opponent's card."""
        latest_observation = history.observations[-1]
        return KuhnPokerState.init_from_observation(latest_observation, opponent_card)
    
    def get_policy(self) -> Dict:
        """Return the extracted policy."""
        return dict(self.policy_dict)
    
    def get_values(self) -> Dict:
        """Return the action values."""
        return dict(self.value_dict)


In [None]:

class HistoryMCTSPlayer(Player):
    def __init__(self, num_simulations: int, exploration_constant: float):
        self.num_simulations = num_simulations
        self.exploration_constant = exploration_constant
        self.visit_counts = {}  # Maps (history, action) -> visit count
        self.action_value_estimates = {}  # Maps (history, action) -> Q value estimates for each player
        self.beliefs = {card: 1/len(KuhnPokerState.DECK) for card in KuhnPokerState.DECK}  # Maps history -> belief distribution over opponent cards
        self.has_visited = set()  # Set of histories that have been visited

    def estimate_values(self, state: KuhnPokerState,) -> dict[int, float]:
        '''
        Estimates the values of a state using random rollouts.
        '''
        print(f'Estimating values for state:\n {state}')
        while not state.is_terminal():
            legal_actions = state.get_legal_actions()
            print(f'The legal actions are {legal_actions} at state:\n {state}')
            action = random.choice(legal_actions)
            print(f'Player {state.current_player()} chooses action {action} in state {state}')
            state.apply_action(action, state.current_player())
            print(f'The state after the action is:\n {state}')
        returns = state.get_returns()
        return returns
    
    def explore(self, history: KuhnPokerHistory) -> int:
        '''
        Selects an action to explore using UCB for the current history and current player.
        '''
        current_player = history.get_current_player()
        # calculate total visits to this history
        total_visits = sum(self.visit_counts.get((history, a), 0) for a in history.get_legal_actions())
        best_action = None
        best_value = float('-inf')
        for action in history.get_legal_actions():
            q_value = self.action_value_estimates.get((history, action), defaultdict(float))[current_player]
            # if action has not been visited, return it
            if self.visit_counts.get((history, action), 0) == 0:
                return action
            # calculate UCB value
            ucb_value = q_value + self.exploration_constant * math.sqrt(math.log(total_visits) / self.visit_counts[(history, action)])
            if ucb_value > best_value:
                best_value = ucb_value
                best_action = action
        assert best_action is not None, "No best action found"
        return best_action
        

    def simulate(self, history: KuhnPokerHistory, state: KuhnPokerState,) -> dict[int, float]:
        '''
        Runs a Monte Carlo Tree Search simulation from the given state.

        Returns the simulated values of the state (one for each player).
        '''
        # if state is terminal, return the returns
        if state.is_terminal():
            return state.get_returns()
        
        # if history has not been visited at all, return the estimated values
        if history not in self.has_visited:
            self.has_visited.add(history)
            # set q values and visit counts to 0 for each action
            for action in history.get_legal_actions():
                self.action_value_estimates[(history, action)] = {0: 0, 1: 0}
                self.visit_counts[(history, action)] = 0
            return self.estimate_values(state)
        
        # if state is not terminal and history has been visited, select action to explore
        action = self.explore(history)

        # apply the selected action to the state
        state.apply_action(action, state.current_player())
        # simulate from the new state and new history
        new_history = history.copy()
        new_history.observations.append(state.get_observation(state.current_player()))
        new_q_values = self.simulate(new_history, state)
        # update the q values and visit counts
        self.visit_counts[(history, action)] += 1
        for player, value in new_q_values.items():
            self.action_value_estimates[(history, action)][player] += (value - self.action_value_estimates[(history, action)][player]) / (self.visit_counts[(history, action)])
        return new_q_values

    def select_random_state(self, history: KuhnPokerHistory, beliefs: Dict[int, float]) -> KuhnPokerState:
        '''
        Chooses a random opponent card according to the beliefs and reconstructs the state from the history.
        '''
        opponent_card = random.choices(list(beliefs.keys()), weights=beliefs.values())[0]
        last_observation = history.get_last_observation()
        return KuhnPokerState.init_from_observation(last_observation, opponent_card)

    def choose_action(self, history: KuhnPokerHistory, player_id: int) -> int:
        '''
        Runs MCTS and chooses the best action based on action value estimates.
        '''
        for _ in range(self.num_simulations):
            # select random state according to beliefs
            state = self.select_random_state(history, self.beliefs)
            # simulate from the selected state
            returns = self.simulate(history, state)
        
        # get best action for the current history by returning action with highest q value
        best_action = None
        best_value = float('-inf')
        for action in history.get_legal_actions():
            q_value = self.action_value_estimates.get((history, action), defaultdict(float))[player_id]
            if q_value > best_value:
                best_value = q_value
                best_action = action
        assert best_action is not None, "No best action found"
        return best_action


In [120]:
# Now we can simulate the game with the MCTS player
mcts_player = HistoryMCTSPlayer(num_simulations=100, exploration_constant=1.0)
simulator = Simulator([mcts_player, RandomPlayer()])
results = simulator.simulate_episodes(10)
print(results)

state Hands: [2, 1], Bets: [1, 1], Folded: [False, False], Current Player: 1, Bet Amount: None, Winner: None
Player 1 chooses action 71 in state Hands: [2, 1], Bets: [1, 1], Folded: [False, False], Current Player: 1, Bet Amount: None, Winner: None
state Hands: [2, 1], Bets: [1, 72], Folded: [False, False], Current Player: 0, Bet Amount: 71, Winner: None
Estimating values for state:
 Hands: [2, 0], Bets: [1, 72], Folded: [False, False], Current Player: 0, Bet Amount: 71, Winner: None
The legal actions are [<ActionType.FOLD: -1>, 71] at state:
 Hands: [2, 0], Bets: [1, 72], Folded: [False, False], Current Player: 0, Bet Amount: 71, Winner: None
Player 0 chooses action -1 in state Hands: [2, 0], Bets: [1, 72], Folded: [False, False], Current Player: 0, Bet Amount: 71, Winner: None
The state after the action is:
 Hands: [2, 0], Bets: [1, 72], Folded: [True, False], Current Player: -1, Bet Amount: 71, Winner: 1
Estimating values for state:
 Hands: [2, 1], Bets: [1, 72], Folded: [True, False

AssertionError: Player 1 bet 143