# Game environment

We first define the game environment here

In [23]:
import enum
import random
import copy
import math
import numpy as np
from typing import Hashable, List, Dict, Optional, Tuple, Callable
from collections import defaultdict
import matplotlib.pyplot as plt
from abc import abstractmethod, ABC
from dataclasses import dataclass



    
class State(ABC):
    @abstractmethod
    def apply_action(self, action: int, player: int):
        pass

    @abstractmethod
    def get_observation(self, player: int) -> Hashable:
        pass

    @abstractmethod
    def current_player(self) -> int:
        pass

    @abstractmethod
    def is_terminal(self) -> bool:
        pass

    @abstractmethod
    def legal_actions(self) -> List[int]:
        pass

    @abstractmethod
    def get_returns(self) -> dict[int, float]:
        pass

    @abstractmethod
    def __str__(self) -> str:
        pass

    @abstractmethod
    def __hash__(self) -> int:
        pass

class History(ABC):

    @abstractmethod
    def get_legal_actions(self) -> List[int]:
        pass

class ActionType(enum.IntEnum):
    FOLD = -1
    CHECK = 0
    # BET actions will be numbers from 1 to 100


class PlayerType(enum.IntEnum):
    CHANCE = -2  # For card dealing
    TERMINAL = -1  # Game is over
    # Regular players are 0, 1, 2, ...


@dataclass
class KuhnPokerObservation:
    player_hand: int
    player_index: int
    bets: List[int]
    current_player: int
    folded: List[bool]
    bet_amount: Optional[int]
    winner: Optional[int] = None

    def get_legal_actions(self) -> List[int]:
        if self.is_terminal():
            return []
        elif self.bet_amount is None:
            return [ActionType.CHECK] + list(range(1, KuhnPokerState.MAX_BET + 1))
        else:
            return [ActionType.FOLD, self.bet_amount]
        
    def is_terminal(self) -> bool:
        return self.current_player == PlayerType.TERMINAL

@dataclass
class KuhnPokerHistory:
    observations: List[KuhnPokerObservation]

    def get_legal_actions(self) -> List[int]:
        return self.observations[-1].get_legal_actions()
    
    def copy(self) -> 'KuhnPokerHistory':
        return copy.deepcopy(self)

class KuhnPokerState:
    DECK = tuple([0, 1, 2])  # Cards are dealt from this deck
    MAX_BET = 100  # Maximum bet amount

    def __init__(self):
        self.players_hands = random.sample(KuhnPokerState.DECK, 2)  # Deal two cards to two players
        self.bets = [1, 1]  # Individual player bets. Ante is 1
        self.folded = [False, False]  # Track if players have folded
        self.current_player_index = random.randint(0, 1)  # Random starting player
        self.bet_amount = None  # Amount of the bet. None if no bet has been made
        self.winner = None  # Winner of the game

    @staticmethod
    def init_from_observation(observation: KuhnPokerObservation, opponent_card: int) -> 'KuhnPokerState':
        state = KuhnPokerState()
        state.players_hands = [0, 0]
        state.players_hands[observation.player_index] = observation.player_hand
        state.players_hands[1 - observation.player_index] = opponent_card
        state.bets = observation.bets
        state.folded = observation.folded
        state.current_player_index = observation.current_player
        state.bet_amount = observation.bet_amount
        state.winner = observation.winner
        return state

    def apply_action(self, action: int, player: int):
        if self.current_player_index != player:
            raise ValueError("Not this player's turn!")

        if action not in self.get_legal_actions():
            raise ValueError("Illegal action!")

        if action == ActionType.FOLD:
            self.folded[player] = True
            self.current_player_index = PlayerType.TERMINAL
            self.winner = 1 - player  # The other player wins
        elif action == ActionType.CHECK:
            if self.bet_amount is None:  # First check
                self.bet_amount = 0
                self.current_player_index = 1 - player
            else:  # Both players checked
                self.current_player_index = PlayerType.TERMINAL
                self.determine_winner()
        else:  # BET or CALL
            if self.bet_amount is None:  # First bet
                self.bet_amount = action
                self.bets[player] += action
                self.current_player_index = 1 - player
            else:  # CALL
                self.bets[player] += self.bet_amount
                self.current_player_index = PlayerType.TERMINAL
                self.determine_winner()

    def determine_winner(self):
        if self.folded[0]:
            self.winner = 1
        elif self.folded[1]:
            self.winner = 0
        else:  # Compare hands
            if self.players_hands[0] > self.players_hands[1]:
                self.winner = 0
            else:
                self.winner = 1

    def get_observation(self, player: int) -> KuhnPokerObservation:
        return KuhnPokerObservation(
            player_hand=self.players_hands[player],
            player_index=player,
            bets=self.bets,
            current_player=self.current_player_index,
            folded=self.folded,
            bet_amount=self.bet_amount,
            winner=self.winner
        )

    def current_player(self) -> int:
        return self.current_player_index

    def is_terminal(self) -> bool:
        return self.current_player_index == PlayerType.TERMINAL

    def get_legal_actions(self) -> List[int]:
        if self.is_terminal():
            return []
        elif self.bet_amount is None:
            return [ActionType.CHECK] + list(range(1, KuhnPokerState.MAX_BET + 1))
        else:
            return [ActionType.FOLD, self.bet_amount]

    def get_returns(self) -> dict[int, float]:
        if not self.is_terminal():
            return {0: 0.0, 1: 0.0}  # No payoff if the game is not over

        pot = sum(self.bets)
        if self.winner == 0:
            return {0: pot, 1: -pot}
        else:
            return {0: -pot, 1: pot}

    def __str__(self) -> str:
        return (f"Hands: {self.players_hands}, Bets: {self.bets}, Folded: {self.folded}, "
                f"Current Player: {self.current_player_index}, Bet Amount: {self.bet_amount}, "
                f"Winner: {self.winner}")

    def __hash__(self) -> int:
        return hash((tuple(self.players_hands), tuple(self.bets), tuple(self.folded), self.current_player_index, self.bet_amount))

    def copy(self) -> 'KuhnPokerState':
        return copy.deepcopy(self)
    
    def get_pot(self) -> int:
        return sum(self.bets)

    

class Player:

    @abstractmethod
    def choose_action(self, history: History, player_id: int) -> int:
        pass

    @abstractmethod
    def get_policy(self) -> Dict:
        pass

class RandomPlayer(Player):
    
    def choose_action(self, history: History, player_id: int) -> int:
        return random.choice(history.get_legal_actions())

    def get_policy(self) -> Dict:
        return {}

@dataclass
class SimulatorResults:
    player_0_wins: int
    player_1_wins: int
    draws: int
    average_pot: float
    player_0_wins_by_card: Dict[int, int]
    player_1_wins_by_card: Dict[int, int]
    player_0_average_profit: float
    player_1_average_profit: float
    player_0_average_profit_by_card: Dict[int, float]
    player_1_average_profit_by_card: Dict[int, float]
    total_episodes: int


class Simulator:
    def __init__(self, players: List[Player]):
        self.players = players

    def simulate_episodes(self, num_episodes: int) -> SimulatorResults:
        player_0_wins = 0
        player_1_wins = 0
        draws = 0
        average_pot = 0
        player_0_wins_by_card = defaultdict(int)
        player_1_wins_by_card = defaultdict(int)
        player_0_total_profit = 0
        player_1_total_profit = 0
        player_0_total_profit_by_card = defaultdict(float)
        player_1_total_profit_by_card = defaultdict(float)

        for _ in range(num_episodes):
            # Initialize the state and player-specific histories for each episode
            state = KuhnPokerState()
            player_histories = [KuhnPokerHistory(observations=[state.get_observation(player_id)]) for player_id in range(len(self.players))]

            while not state.is_terminal():
                current_player = state.current_player()
                current_history = player_histories[current_player]

                # Current player chooses an action based on their visible history
                action = self.players[current_player].choose_action(current_history, current_player)

                # Apply the action to the state
                state.apply_action(action, current_player)

                # Update all players' histories
                for player_id in range(len(self.players)):
                    observation = state.get_observation(player_id)
                    player_histories[player_id].observations.append(observation)

            # Calculate returns and update metrics
            returns = state.get_returns()
            average_pot += state.get_pot()
            if returns[0] > 0:
                player_0_wins += 1
                player_0_wins_by_card[state.players_hands[0]] += 1
            elif returns[1] > 0:
                player_1_wins += 1
                player_1_wins_by_card[state.players_hands[1]] += 1
            else:
                draws += 1
            player_0_total_profit += returns[0]
            player_1_total_profit += returns[1]
            player_0_total_profit_by_card[state.players_hands[0]] += returns[0]
            player_1_total_profit_by_card[state.players_hands[1]] += returns[1]

        return SimulatorResults(
            player_0_wins=player_0_wins,
            player_1_wins=player_1_wins,
            draws=draws,
            average_pot=average_pot / num_episodes,
            player_0_wins_by_card=player_0_wins_by_card,
            player_1_wins_by_card=player_1_wins_by_card,
            player_0_average_profit=player_0_total_profit / num_episodes,
            player_1_average_profit=player_1_total_profit / num_episodes,
            player_0_average_profit_by_card={k: v / player_0_wins_by_card[k] for k, v in player_0_total_profit_by_card.items()},
            player_1_average_profit_by_card={k: v / player_1_wins_by_card[k] for k, v in player_1_total_profit_by_card.items()},
            total_episodes=num_episodes
        )


In [24]:
# simulate random player vs random player
simulator = Simulator([RandomPlayer(), RandomPlayer()])
results = simulator.simulate_episodes(10000)
print(results)

SimulatorResults(player_0_wins=5137, player_1_wins=4863, draws=0, average_pot=77.3126, player_0_wins_by_card=defaultdict(<class 'int'>, {0: 842, 2: 2566, 1: 1729}), player_1_wins_by_card=defaultdict(<class 'int'>, {2: 2365, 0: 797, 1: 1701}), player_0_average_profit=2.6954, player_1_average_profit=-2.6954, player_0_average_profit_by_card={1: 8.491613649508386, 0: -195.19952494061758, 2: 68.83476227591582}, player_1_average_profit_by_card={2: 67.71797040169133, 1: -4.583774250440917, 0: -224.98117942283562}, total_episodes=10000)


## Players

Next we will define a couple of different players that use different strategies to play the game.

In [25]:
class ForwardSearchPlayer(Player):
    """Player that uses forward search with policy extraction."""
    
    def __init__(self, num_simulations: int = 100):
        self.num_simulations = num_simulations
        self.policy_dict = defaultdict(lambda: defaultdict(float))
        self.value_dict = defaultdict(lambda: defaultdict(float))
        
    def choose_action(self, history: KuhnPokerHistory, player_id: int) -> int:
        legal_actions = history.get_legal_actions()
        
        if not legal_actions:
            return ActionType.CHECK  # Default to CHECK if no legal actions
        
        # Get the player's current observation
        observation = history.observations[-1]
        my_card = observation.player_hand
        all_cards = set(KuhnPokerState.DECK)
        possible_opponent_cards = all_cards - {my_card}
        
        # Get game state key for policy recording
        game_state = self.get_state_key(history)
        
        best_action = None
        best_value = float('-inf')
        action_values = {}
        visit_counts = defaultdict(int)
        
        for action in legal_actions:
            total_value = 0
            visit_counts[action] = 0
            
            for opponent_card in possible_opponent_cards:
                # Reconstruct the state from the history for simulations
                sim_state = self.reconstruct_state_from_history(history, opponent_card)
                value = self.simulate_action(sim_state, action, my_card, opponent_card, player_id)
                total_value += value
                visit_counts[action] += 1
                
            avg_value = total_value / len(possible_opponent_cards)
            action_values[action] = avg_value
            
            if avg_value > best_value:
                best_value = avg_value
                best_action = action
        
        # Record policy and values
        total_visits = sum(visit_counts.values())
        for action, visits in visit_counts.items():
            self.policy_dict[game_state][action] = visits / total_visits
            self.value_dict[game_state][action] = action_values[action]
        
        return best_action
    
    def simulate_action(self, state: KuhnPokerState, action: int, my_card: int, 
                        opponent_card: int, player_id: int) -> float:
        state.apply_action(action, player_id)
        if state.is_terminal():
            returns = state.get_returns()
            return returns[player_id]
        
        total_value = 0
        opponent_actions = state.get_legal_actions()
        
        for opp_action in opponent_actions:
            sim_state = state.copy()
            sim_state.apply_action(opp_action, 1 - player_id)
            
            if sim_state.is_terminal():
                returns = sim_state.get_returns()
                value = returns[player_id]
            else:
                value = self.evaluate_position(sim_state, my_card, opponent_card, player_id)
            
            total_value += value / len(opponent_actions)
        
        return total_value
    
    def evaluate_position(self, state: KuhnPokerState, my_card: int, opponent_card: int, 
                          player_id: int) -> float:
        if my_card > opponent_card:
            return state.get_pot() * 0.8
        elif my_card < opponent_card:
            return -state.get_pot() * 0.8
        else:
            return -state.get_pot() * 0.1
    
    def get_state_key(self, history: KuhnPokerHistory) -> str:
        """Create a key representing the current game state."""
        observation = history.observations[-1]
        betting_history = f"Bets:{observation.bets}|Folded:{observation.folded}"
        return f"Card:{observation.player_hand}|{betting_history}"
    
    def reconstruct_state_from_history(self, history: KuhnPokerHistory, opponent_card: int) -> KuhnPokerState:
        """Reconstruct the state from the given history and opponent's card."""
        latest_observation = history.observations[-1]
        return KuhnPokerState.init_from_observation(latest_observation, opponent_card)
    
    def get_policy(self) -> Dict:
        """Return the extracted policy."""
        return dict(self.policy_dict)
    
    def get_values(self) -> Dict:
        """Return the action values."""
        return dict(self.value_dict)
