In [1]:
pip install treys

Note: you may need to restart the kernel to use updated packages.


In [2]:
import random
from enum import Enum, auto

class Action(Enum):
    FOLD = 0
    CHECK = 1
    CALL = 2
    RAISE = 3
    BET = 4

class Player:
    def __init__(self, id, stack):
        self.id = id
        self.stack = stack
        self.hand = []
        self.current_bet = 0
        self.folded = False

class TexasHoldEm:
    def __init__(self, players, small_blind=10, big_blind=20):
        self.players = players  # List of Player objects
        self.small_blind = small_blind
        self.big_blind = big_blind
        self.deck = self.initialize_deck()
        self.board = []  # Community cards
        self.pot = 0
        self.current_bet = 0
        self.game_over = False
        self.dealer_index = 0  # Index of the dealer in self.players
        self.current_player_index = (self.dealer_index + 1) % len(self.players)  # Player to act
        self.betting_round = 'pre-flop'  # Can be 'pre-flop', 'flop', 'turn', 'river'
        self.round_bets = {}  # Tracks bets per player in the current round
        self.last_raiser = None  # Tracks the last player who raised

    def initialize_deck(self):
        suits = ['H', 'D', 'C', 'S']
        ranks = range(2, 15)  # 2-14 where 11-14 are J, Q, K, A
        deck = [(rank, suit) for rank in ranks for suit in suits]
        random.shuffle(deck)
        return deck

    def deal_hole_cards(self):
        for player in self.players:
            player.hand = [self.deck.pop(), self.deck.pop()]
            player.current_bet = 0
            player.folded = False
        self.round_bets = {player.id: 0 for player in self.players}

    def post_blinds(self):
        small_blind_player = self.players[(self.dealer_index + 1) % len(self.players)]
        big_blind_player = self.players[(self.dealer_index + 2) % len(self.players)]
        
        self._post_blind(small_blind_player, self.small_blind)
        self._post_blind(big_blind_player, self.big_blind)
        
        self.current_bet = self.big_blind
        self.last_raiser = big_blind_player.id

        # Set current player to the one after the big blind
        self.current_player_index = (self.dealer_index + 3) % len(self.players)

    def _post_blind(self, player, amount):
        player.stack -= amount
        player.current_bet = amount
        self.pot += amount
        self.round_bets[player.id] = amount

    def deal_flop(self):
        self.deck.pop()  # Burn card
        self.board.extend([self.deck.pop() for _ in range(3)])
        self.betting_round = 'flop'
        self.reset_bets()

    def deal_turn(self):
        self.deck.pop()  # Burn card
        self.board.append(self.deck.pop())
        self.betting_round = 'turn'
        self.reset_bets()

    def deal_river(self):
        self.deck.pop()  # Burn card
        self.board.append(self.deck.pop())
        self.betting_round = 'river'
        self.reset_bets()

    def reset_bets(self):
        self.current_bet = 0
        for player in self.players:
            player.current_bet = 0
        self.round_bets = {player.id: 0 for player in self.players}
        self.current_player_index = self.dealer_index  # Start with the player after the dealer
        self.last_raiser = None

    def get_current_player(self):
        while True:
            player = self.players[self.current_player_index]
            if not player.folded:
                return player
            self.current_player_index = (self.current_player_index + 1) % len(self.players)

    def get_available_actions(self, player):
        if player.current_bet < self.current_bet:
            # Player needs to call or fold
            actions = [Action.FOLD, Action.CALL, Action.RAISE]
        else:
            # Player can check or bet/raise
            if self.current_bet == 0:
                actions = [Action.CHECK, Action.BET]
            else:
                actions = [Action.CHECK, Action.RAISE]
        return actions

    def execute_action(self, player, action, raise_amount=0):
        if action == Action.FOLD:
            self.handle_fold(player)
        elif action == Action.CHECK:
            self.handle_check(player)
        elif action == Action.CALL:
            self.handle_call(player)
        elif action == Action.BET:
            self.handle_bet(player, raise_amount)
        elif action == Action.RAISE:
            self.handle_raise(player, raise_amount)
        else:
            raise ValueError("Invalid action")

        # Move to the next player
        self.current_player_index = (self.current_player_index + 1) % len(self.players)

    def handle_fold(self, player):
        player.folded = True
        print(f"Player {player.id} folds.")

    def handle_check(self, player):
        print(f"Player {player.id} checks.")

    def handle_call(self, player):
        call_amount = self.current_bet - player.current_bet
        player.stack -= call_amount
        player.current_bet += call_amount
        self.pot += call_amount
        self.round_bets[player.id] += call_amount
        print(f"Player {player.id} calls {call_amount}.")

    def handle_bet(self, player, amount):
        if amount <= 0 or amount > player.stack:
            raise ValueError("Invalid bet amount")
        player.stack -= amount
        player.current_bet += amount
        self.current_bet = player.current_bet
        self.pot += amount
        self.round_bets[player.id] += amount
        self.last_raiser = player.id
        print(f"Player {player.id} bets {amount}.")

    def handle_raise(self, player, amount):
        if amount <= 0 or amount > player.stack:
            raise ValueError("Invalid raise amount")
        call_amount = self.current_bet - player.current_bet
        total_amount = call_amount + amount
        player.stack -= total_amount
        player.current_bet += total_amount
        self.current_bet = player.current_bet
        self.pot += total_amount
        self.round_bets[player.id] += total_amount
        self.last_raiser = player.id
        print(f"Player {player.id} raises by {amount} to {player.current_bet}.")

    def is_round_over(self):
        # The betting round is over when all players have either called the current bet or folded
        active_players = [p for p in self.players if not p.folded]
        if len(active_players) == 1:
            return True  # Only one player remains
        for player in active_players:
            if player.id == self.last_raiser:
                continue  # Skip the last raiser
            if player.current_bet != self.current_bet:
                return False
        return True

    def progress_round(self):
        if self.betting_round == 'pre-flop':
            self.deal_flop()
        elif self.betting_round == 'flop':
            self.deal_turn()
        elif self.betting_round == 'turn':
            self.deal_river()
        elif self.betting_round == 'river':
            self.game_over = True  # Proceed to showdown
        else:
            raise ValueError("Invalid betting round")

    def is_game_over(self):
        # The game is over if only one player remains or all betting rounds are complete
        active_players = [p for p in self.players if not p.folded]
        if len(active_players) == 1:
            return True
        return self.game_over

    def determine_winner(self):
        # If only one player remains
        active_players = [p for p in self.players if not p.folded]
        if len(active_players) == 1:
            winner = active_players[0]
            winner.stack += self.pot
            print(f"Player {winner.id} wins the pot of {self.pot} by default.")
            self.pot = 0
            return

        # Showdown: compare hands
        from treys import Evaluator, Card
        evaluator = Evaluator()
        best_rank = None
        winners = []
        for player in active_players:
            hand = [Card.new(f"{self.rank_to_str(rank)}{suit}") for rank, suit in player.hand]
            board = [Card.new(f"{self.rank_to_str(rank)}{suit}") for rank, suit in self.board]
            rank = evaluator.evaluate(board, hand)
            if best_rank is None or rank < best_rank:
                best_rank = rank
                winners = [player]
            elif rank == best_rank:
                winners.append(player)
        # Split the pot among winners
        split_pot = self.pot / len(winners)
        for winner in winners:
            winner.stack += split_pot
            print(f"Player {winner.id} wins {split_pot} from the pot.")
        self.pot = 0

    def rank_to_str(self, rank):
        if rank == 14:
            return 'A'
        elif rank == 13:
            return 'K'
        elif rank == 12:
            return 'Q'
        elif rank == 11:
            return 'J'
        elif rank == 10:
            return 'T'
        else:
            return str(rank)

    def get_reward(self, player):
        # Define reward as the change in the player's stack
        return player.stack - 1000  # Assuming initial stack is 1000

    # Additional helper methods can be added as needed

In [3]:
class BeliefState:
    def __init__(self, observed_actions, public_cards, pot_size=0):
        self.observed_actions = observed_actions  # Sequence of actions taken
        self.public_cards = public_cards          # Community cards revealed
        self.private_cards = None                 # The AI's own hand
        self.pot_size = pot_size                  # Pot size

    def update(self, action, new_public_cards=None, pot_size=None):
        self.observed_actions.append(action)
        if new_public_cards is not None:
            self.public_cards = new_public_cards
        if pot_size is not None:
            self.pot_size = pot_size

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim

class ValueNetwork(nn.Module):
    def __init__(self, input_size):
        super(ValueNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, 1)  # Output a scalar value
        )

    def forward(self, x):
        return self.fc(x)

class PolicyNetwork(nn.Module):
    def __init__(self, input_size, action_space):
        super(PolicyNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, action_space)
        )

    def forward(self, x):
        return nn.Softmax(dim=-1)(self.fc(x))



MAX_FEATURE_LENGTH = 25  # Adjust this value based on your game dynamics

def extract_features(belief_state):
    # Convert actions to numerical representation
    action_features = [action.value for action in belief_state.observed_actions]
    
    # Flatten public cards
    public_card_features = [rank for (rank, suit) in belief_state.public_cards]
    
    # Combine features
    features = action_features + public_card_features
    
    # Pad or truncate to fixed size
    features = features[:MAX_FEATURE_LENGTH] + [0] * max(0, MAX_FEATURE_LENGTH - len(features))
    
    return torch.tensor(features, dtype=torch.float32)


In [5]:
# Global dictionary to store CFRNodes
cfr_nodes = {}

class CFRNode:
    def __init__(self, info_set, actions):
        self.info_set = info_set
        self.actions = actions
        self.regret_sum = {action: 0.0 for action in actions}
        self.strategy = {action: 1.0 / len(actions) for action in actions}
        self.strategy_sum = {action: 0.0 for action in actions}

    def get_strategy(self):
        normalizing_sum = 0.0
        for action in self.actions:
            self.strategy[action] = max(self.regret_sum[action], 0.0)
            normalizing_sum += self.strategy[action]
        if normalizing_sum > 0:
            for action in self.actions:
                self.strategy[action] /= normalizing_sum
        else:
            for action in self.actions:
                self.strategy[action] = 1.0 / len(self.actions)
        # Update strategy sum for averaging
        for action in self.actions:
            self.strategy_sum[action] += self.strategy[action]
        return self.strategy

def get_or_create_cfr_node(info_set, actions):
    if info_set not in cfr_nodes:
        cfr_nodes[info_set] = CFRNode(info_set, actions)
    return cfr_nodes[info_set]

def cfr(node, p0, p1):
    if node.is_terminal():
        return node.utility()
    info_set = node.get_info_set()
    actions = node.get_available_actions()
    cfr_node = get_or_create_cfr_node(info_set, actions)
    strategy = cfr_node.get_strategy()
    util = {}
    node_util = 0
    for action in actions:
        next_node = node.take_action(action)
        if node.current_player == 0:
            util[action] = -cfr(next_node, p0 * strategy[action], p1)
        else:
            util[action] = -cfr(next_node, p0, p1 * strategy[action])
        node_util += strategy[action] * util[action]
    for action in actions:
        regret = util[action] - node_util
        if node.current_player == 0:
            cfr_node.regret_sum[action] += p1 * regret
        else:
            cfr_node.regret_sum[action] += p0 * regret
    return node_util

In [6]:
num_players = 2  # Since we're working with a two-player game

In [7]:
# class GameNode:
#     def __init__(self, current_player, hands, board, history, num_players=2, 
#                  pot=0, player_stacks=None, player_bets=None, current_bet=0, 
#                  folded_players=None, betting_round='pre-flop', last_raiser=None):
#         self.current_player = current_player  # ID of the current player
#         self.hands = hands  # Dictionary mapping player IDs to their hands
#         self.board = board  # Community cards
#         self.history = history  # List of actions taken
#         self.num_players = num_players  # Total number of players
#         self.pot = pot  # Current size of the pot
#         self.player_stacks = player_stacks if player_stacks is not None else {player_id: 1000 for player_id in hands}
#         self.player_bets = player_bets if player_bets is not None else {player_id: 0 for player_id in hands}
#         self.current_bet = current_bet  # Highest bet in the current betting round
#         self.folded_players = folded_players if folded_players is not None else set()  # Players who have folded
#         self.betting_round = betting_round  # Current betting round ('pre-flop', 'flop', 'turn', 'river')
#         self.last_raiser = last_raiser  # ID of the last player who raised

    
#     def is_terminal(self):
#         # Implement logic to determine if the game has ended
#         pass
    
#     def utility(self):
#         # Calculate the utility (payoff) for the current player
#         pass
    
#     def get_info_set(self):
#         # As defined earlier
#         player_hand = self.hands[self.current_player]
#         public_cards = self.board
#         observed_actions = self.history
#         info_set = (
#             tuple(sorted(player_hand)),
#             tuple(sorted(public_cards)),
#             tuple(observed_actions)
#         )
#         return info_set
    
#     def get_available_actions(self):
#         # Return a list of possible actions at this node
#         return [Action.FOLD, Action.CALL, Action.RAISE]  # Example actions


#     def take_action(self, action, amount=0):
#         new_history = self.history + [action]
#         new_hands = self.hands.copy()
#         new_board = self.board.copy()
#         new_pot = self.pot
#         new_player_stacks = self.player_stacks.copy()
#         new_player_bets = self.player_bets.copy()
#         new_folded_players = self.folded_players.copy()
#         new_current_bet = self.current_bet
#         new_last_raiser = self.last_raiser
#         new_betting_round = self.betting_round

#         current_player_id = self.current_player

#         if action == Action.FOLD:
#             new_folded_players.add(current_player_id)
#         elif action == Action.CALL:
#             call_amount = new_current_bet - new_player_bets[current_player_id]
#             if call_amount > new_player_stacks[current_player_id]:
#                 call_amount = new_player_stacks[current_player_id]  # All-in
#             new_player_stacks[current_player_id] -= call_amount
#             new_player_bets[current_player_id] += call_amount
#             new_pot += call_amount
#         elif action == Action.RAISE:
#             raise_amount = amount
#             total_bet = new_current_bet + raise_amount
#             bet_amount = total_bet - new_player_bets[current_player_id]
#             if bet_amount > new_player_stacks[current_player_id]:
#                 bet_amount = new_player_stacks[current_player_id]  # All-in
#                 total_bet = new_player_bets[current_player_id] + bet_amount
#             new_player_stacks[current_player_id] -= bet_amount
#             new_player_bets[current_player_id] += bet_amount
#             new_pot += bet_amount
#             new_current_bet = total_bet
#             new_last_raiser = current_player_id
#         elif action == Action.CHECK:
#             pass  # No changes needed for check
#         elif action == Action.BET:
#             bet_amount = amount
#             if bet_amount > new_player_stacks[current_player_id]:
#                 bet_amount = new_player_stacks[current_player_id]  # All-in
#             new_player_stacks[current_player_id] -= bet_amount
#             new_player_bets[current_player_id] += bet_amount
#             new_pot += bet_amount
#             new_current_bet = new_player_bets[current_player_id]
#             new_last_raiser = current_player_id
#         else:
#             raise ValueError("Invalid action")

#         # Determine the next player
#         next_player = (self.current_player + 1) % self.num_players
#         while next_player in new_folded_players:
#             next_player = (next_player + 1) % self.num_players
#             if next_player == self.current_player:
#                 break  # All other players have folded

#         # Create a new GameNode with updated state
#         new_node = GameNode(
#             current_player=next_player,
#             hands=new_hands,
#             board=new_board,
#             history=new_history,
#             num_players=self.num_players,
#             pot=new_pot,
#             player_stacks=new_player_stacks,
#             player_bets=new_player_bets,
#             current_bet=new_current_bet,
#             folded_players=new_folded_players,
#             betting_round=new_betting_round,
#             last_raiser=new_last_raiser
#         )

#         # Update betting round if needed
#         if self.should_progress_round(new_node):
#             new_node.progress_round()

#         return new_node
    
#     def should_progress_round(self, node):
#         active_players = [p for p in range(self.num_players) if p not in node.folded_players]
#         # If only one player remains, the game ends
#         if len(active_players) <= 1:
#             return False
#         # If all active players have matched the current bet or are all-in
#         for player_id in active_players:
#             player_bet = node.player_bets[player_id]
#             if player_bet != node.current_bet and node.player_stacks[player_id] > 0:
#                 return False
#         return True
    
#     def progress_round(self):
#         self.player_bets = {player_id: 0 for player_id in self.player_bets}
#         self.current_bet = 0
#         self.last_raiser = None

#         if self.betting_round == 'pre-flop':
#             self.deal_flop()
#             self.betting_round = 'flop'
#         elif self.betting_round == 'flop':
#             self.deal_turn()
#             self.betting_round = 'turn'
#         elif self.betting_round == 'turn':
#             self.deal_river()
#             self.betting_round = 'river'
#         elif self.betting_round == 'river':
#             self.betting_round = 'showdown'
#         else:
#             pass  # Game is over

#     def deal_flop(self):
#         # Assuming self.deck is managed elsewhere
#         self.board.extend([self.deck.pop() for _ in range(3)])

#     def deal_turn(self):
#         self.board.append(self.deck.pop())

#     def deal_river(self):
#         self.board.append(self.deck.pop())

In [8]:
class GameNode:
    def __init__(self, current_player, hands, board, history, num_players=2, 
                 pot=0, player_stacks=None, player_bets=None, current_bet=0, 
                 folded_players=None, betting_round='pre-flop', last_raiser=None, deck=None):
        self.current_player = current_player  # ID of the current player
        self.hands = hands  # Dictionary mapping player IDs to their hands
        self.board = board  # Community cards
        self.history = history  # List of actions taken
        self.num_players = num_players  # Total number of players
        self.pot = pot  # Current size of the pot
        self.player_stacks = player_stacks if player_stacks is not None else {player_id: 1000 for player_id in hands}
        self.player_bets = player_bets if player_bets is not None else {player_id: 0 for player_id in hands}
        self.current_bet = current_bet  # Highest bet in the current betting round
        self.folded_players = folded_players if folded_players is not None else set()  # Players who have folded
        self.betting_round = betting_round  # Current betting round ('pre-flop', 'flop', 'turn', 'river', 'showdown')
        self.last_raiser = last_raiser  # ID of the last player who raised
        self.deck = deck if deck is not None else self.initialize_deck()  # Remaining cards in the deck

    def initialize_deck(self):
        suits = ['h', 'd', 'c', 's']
        ranks = range(2, 15)  # 2-14 where 11-14 are J, Q, K, A
        deck = [(rank, suit) for rank in ranks for suit in suits]
        random.shuffle(deck)
        return deck

    def is_terminal(self):
        active_players = [p for p in range(self.num_players) if p not in self.folded_players]
        # Game ends if only one player remains or after showdown
        if len(active_players) <= 1 or self.betting_round == 'showdown':
            return True
        return False

    def utility(self):
        if not self.is_terminal():
            return 0  # Utility is zero if the game is not over

        active_players = [p for p in range(self.num_players) if p not in self.folded_players]
        if len(active_players) == 1:
            # Only one player left; they win the pot
            winner = active_players[0]
            if self.current_player == winner:
                return self.pot
            else:
                return -self.pot
        else:
            # Showdown: compare hands
            hand_strengths = {p: self.evaluate_hand(p) for p in active_players}
            best_strength = min(hand_strengths.values())
            winners = [p for p, strength in hand_strengths.items() if strength == best_strength]
            if self.current_player in winners:
                # Split the pot among winners
                return self.pot / len(winners)
            else:
                return -self.pot / len(winners)

    def evaluate_hand(self, player_id):
        from treys import Evaluator, Card
        evaluator = Evaluator()
        player_hand = self.hands[player_id]
        board = self.board
        # Convert to treys Card objects
        player_cards = [Card.new(f"{self.rank_to_str(rank)}{suit}") for rank, suit in player_hand]
        board_cards = [Card.new(f"{self.rank_to_str(rank)}{suit}") for rank, suit in board]
        rank = evaluator.evaluate(board_cards, player_cards)
        return rank  # Lower rank means a better hand

    def rank_to_str(self, rank):
        rank_dict = {14: 'A', 13: 'K', 12: 'Q', 11: 'J', 10: 'T'}
        return rank_dict.get(rank, str(rank))

    def get_info_set(self):
        player_hand = self.hands[self.current_player]
        public_cards = self.board
        observed_actions = self.history
        info_set = (
            tuple(sorted(player_hand)),
            tuple(sorted(public_cards)),
            tuple(observed_actions)
        )
        return info_set

    def get_available_actions(self):
        if self.current_bet > self.player_bets[self.current_player]:
            # Player needs to call, raise, or fold
            return [Action.FOLD, Action.CALL, Action.RAISE]
        else:
            # Player can check or bet/raise
            if self.current_bet == 0:
                return [Action.CHECK, Action.BET]
            else:
                return [Action.CHECK, Action.RAISE]

    def take_action(self, action, amount=0):
        new_history = self.history + [action]
        new_hands = self.hands.copy()
        new_board = self.board.copy()
        new_pot = self.pot
        new_player_stacks = self.player_stacks.copy()
        new_player_bets = self.player_bets.copy()
        new_folded_players = self.folded_players.copy()
        new_current_bet = self.current_bet
        new_last_raiser = self.last_raiser
        new_betting_round = self.betting_round
        new_deck = self.deck.copy()

        current_player_id = self.current_player

        if action == Action.FOLD:
            new_folded_players.add(current_player_id)
        elif action == Action.CALL:
            call_amount = new_current_bet - new_player_bets[current_player_id]
            available_stack = new_player_stacks[current_player_id]
            actual_call = min(call_amount, available_stack)
            new_player_stacks[current_player_id] -= actual_call
            new_player_bets[current_player_id] += actual_call
            new_pot += actual_call
        elif action == Action.RAISE:
            raise_amount = amount
            total_bet = new_current_bet + raise_amount
            bet_amount = total_bet - new_player_bets[current_player_id]
            available_stack = new_player_stacks[current_player_id]
            actual_bet = min(bet_amount, available_stack)
            new_player_stacks[current_player_id] -= actual_bet
            new_player_bets[current_player_id] += actual_bet
            new_pot += actual_bet
            new_current_bet = new_player_bets[current_player_id]
            new_last_raiser = current_player_id
        elif action == Action.CHECK:
            pass  # No action needed
        elif action == Action.BET:
            bet_amount = amount
            available_stack = new_player_stacks[current_player_id]
            actual_bet = min(bet_amount, available_stack)
            new_player_stacks[current_player_id] -= actual_bet
            new_player_bets[current_player_id] += actual_bet
            new_pot += actual_bet
            new_current_bet = new_player_bets[current_player_id]
            new_last_raiser = current_player_id
        else:
            raise ValueError("Invalid action")

        # Determine the next player
        next_player = (self.current_player + 1) % self.num_players
        while next_player in new_folded_players:
            next_player = (next_player + 1) % self.num_players
            if next_player == self.current_player:
                break  # All other players have folded

        # Create a new GameNode with updated state
        new_node = GameNode(
            current_player=next_player,
            hands=new_hands,
            board=new_board,
            history=new_history,
            num_players=self.num_players,
            pot=new_pot,
            player_stacks=new_player_stacks,
            player_bets=new_player_bets,
            current_bet=new_current_bet,
            folded_players=new_folded_players,
            betting_round=new_betting_round,
            last_raiser=new_last_raiser,
            deck=new_deck
        )

        # Update betting round if needed
        if self.should_progress_round(new_node):
            new_node.progress_round()

        return new_node

    def should_progress_round(self, node):
        active_players = [p for p in range(self.num_players) if p not in node.folded_players]
        if len(active_players) <= 1:
            return True  # Only one player remains
        # Check if all active players have matched the current bet or are all-in
        for player_id in active_players:
            if node.player_stacks[player_id] > 0 and node.player_bets[player_id] != node.current_bet:
                return False
        if node.last_raiser == node.current_player:
            return True
        return False

    def progress_round(self):
        # Reset player bets for the new round
        self.player_bets = {player_id: 0 for player_id in self.player_bets}
        self.current_bet = 0
        self.last_raiser = None

        if self.betting_round == 'pre-flop':
            self.deal_flop()
            self.betting_round = 'flop'
        elif self.betting_round == 'flop':
            self.deal_turn()
            self.betting_round = 'turn'
        elif self.betting_round == 'turn':
            self.deal_river()
            self.betting_round = 'river'
        elif self.betting_round == 'river':
            self.betting_round = 'showdown'
        else:
            pass  # Game is over

    def deal_flop(self):
        # Burn a card
        self.deck.pop()
        # Deal three community cards
        self.board.extend([self.deck.pop() for _ in range(3)])

    def deal_turn(self):
        # Burn a card
        self.deck.pop()
        # Deal one community card
        self.board.append(self.deck.pop())

    def deal_river(self):
        # Burn a card
        self.deck.pop()
        # Deal one community card
        self.board.append(self.deck.pop())



hands = {
    0: [(14, 'h'), (13, 'd')],  # Player 0's hand
    1: [(12, 's'), (11, 'c')]   # Player 1's hand
}

# Create the root game node
root_node = GameNode(
    current_player=0,
    hands=hands,
    board=[],
    history=[],
    num_players=2
)

# Player 0 takes an action
next_node = root_node.take_action(Action.CALL)

# Player 1 takes an action
next_node = next_node.take_action(Action.CHECK)

In [9]:
# value_net = ValueNetwork(input_size=MAX_FEATURE_LENGTH)
# policy_net = PolicyNetwork(input_size=MAX_FEATURE_LENGTH, action_space=len(Action))
# value_optimizer = optim.Adam(value_net.parameters(), lr=1e-4)
# policy_optimizer = optim.Adam(policy_net.parameters(), lr=1e-4)

# NUM_EPISODES = 500

# import torch
# import torch.nn as nn
# import torch.optim as optim
# from enum import Enum, auto

# # Define sample_action function
# def sample_action(action_probs):
#     """
#     Samples an action from the given action probabilities.

#     Args:
#         action_probs (torch.Tensor): A tensor containing the probabilities for each action.

#     Returns:
#         Action: The selected action.
#     """
#     # Ensure action_probs is a 1D tensor
#     if action_probs.dim() > 1:
#         action_probs = action_probs.squeeze()

#     # Sample an action index based on the probabilities
#     action_index = torch.multinomial(action_probs, num_samples=1).item()

#     # Map the index to an action
#     action_list = list(Action)
#     selected_action = action_list[action_index]

#     return selected_action

# # Initialize neural networks and optimizers
# value_net = ValueNetwork(input_size=MAX_FEATURE_LENGTH)
# policy_net = PolicyNetwork(input_size=MAX_FEATURE_LENGTH, action_space=len(Action))
# value_optimizer = optim.Adam(value_net.parameters(), lr=1e-4)
# policy_optimizer = optim.Adam(policy_net.parameters(), lr=1e-4)

# # Training loop
# for episode in range(NUM_EPISODES):
#     # Initialize players
#     players = [Player(id=0, stack=1000), Player(id=1, stack=1000)]
    
#     # Initialize the game
#     game = TexasHoldEm(players)
#     game.deal_hole_cards()
    
#     # Initialize belief states for each player
#     belief_states = {player.id: BeliefState(observed_actions=[], public_cards=game.board) for player in players}
    
#     done = False
#     experiences = []  # To store experiences for training
    
#     while not done:
#         current_player = game.get_current_player()
#         belief_state = belief_states[current_player.id]
        
#         # Extract features
#         features = extract_features(belief_state)
        
#         # Get action probabilities from the policy network
#         action_probs = policy_net(features)
        
#         # Choose an action
#         action = sample_action(action_probs)
        
#         # Execute the action
#         game.execute_action(current_player, action)
        
#         # Get reward (to be defined based on your reward structure)
#         reward = game.get_reward(current_player)
        
#         # Record experience
#         experiences.append((features, action, reward))
        
#         # Update belief states for all players
#         for player in players:
#             belief_states[player.id].update(action, new_public_cards=game.board)
        
#         # Check for end of round/game
#         if game.is_round_over():
#             game.progress_round()
#         if game.is_game_over():
#             game.determine_winner()
#             done = True  # End the game loop
    
#     # After the game, update the networks using collected experiences

#     # Step 1: Prepare the training data
#     states = torch.stack([exp[0] for exp in experiences])  # Features are tensors
#     actions = [exp[1] for exp in experiences]              # Actions are enums
#     rewards = [exp[2] for exp in experiences]              # Rewards are scalars

#     # Step 2: Convert actions to indices
#     action_to_index = {action: idx for idx, action in enumerate(Action)}
#     action_indices = torch.tensor([action_to_index[action] for action in actions], dtype=torch.long)

#     # Step 3: Compute value targets (discounted cumulative rewards)
#     rewards = torch.tensor(rewards, dtype=torch.float32)
#     gamma = 1.0  # No discounting
#     returns = []
#     R = 0
#     for r in reversed(rewards):
#         R = r + gamma * R
#         returns.insert(0, R)
#     value_targets = torch.tensor(returns, dtype=torch.float32)

#     # Step 4: Compute value loss
#     value_predictions = value_net(states).squeeze()
#     value_loss_fn = nn.MSELoss()
#     value_loss = value_loss_fn(value_predictions, value_targets)

#     # Step 5: Compute policy loss
#     policy_outputs = policy_net(states)  # Shape: (batch_size, num_actions)
#     action_probs = policy_outputs.gather(1, action_indices.unsqueeze(1)).squeeze()
#     log_probs = torch.log(action_probs + 1e-10)
#     with torch.no_grad():
#         advantages = value_targets - value_predictions.detach()
#     policy_loss = - (log_probs * advantages).mean()

#     # Step 6: Update the networks
#     # Update value network
#     value_optimizer.zero_grad()
#     value_loss.backward()
#     value_optimizer.step()

#     # Update policy network
#     policy_optimizer.zero_grad()
#     policy_loss.backward()
#     policy_optimizer.step()
    
#     # Compute entropy of the policy
#     entropy = - (policy_outputs * torch.log(policy_outputs + 1e-10)).sum(dim=1).mean()

#     # Add entropy regularization to the policy loss
#     entropy_coef = 0.01  # Adjust as needed
#     policy_loss = policy_loss - entropy_coef * entropy
    
#     # Extract CFR strategies and values
#     cfr_strategies = torch.stack([exp[2] for exp in experiences])  # CFR strategies
#     cfr_values = torch.tensor([exp[3] for exp in experiences], dtype=torch.float32)  # CFR values

#     # Value Loss
#     value_predictions = value_net(states).squeeze()
#     value_loss_fn = nn.MSELoss()
#     value_loss = value_loss_fn(value_predictions, cfr_values)

#     # Policy Loss
#     policy_outputs = policy_net(states)
#     policy_loss_fn = nn.CrossEntropyLoss()
#     policy_loss = policy_loss_fn(policy_outputs, cfr_strategies.argmax(dim=1))



#     # Optionally, print progress
#     if (episode + 1) % 100 == 0:
#         print(f"Completed episode {episode + 1}/{NUM_EPISODES}")

In [10]:
import torch.nn as nn

MAX_FEATURE_LENGTH = 25

# Define ValueNetwork
class ValueNetwork(nn.Module):
    def __init__(self, input_size):
        super(ValueNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)  # Output a single value
        )
    
    def forward(self, x):
        return self.fc(x)

# Define PolicyNetwork
class PolicyNetwork(nn.Module):
    def __init__(self, input_size, action_space):
        super(PolicyNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, action_space)
        )
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self, x):
        x = self.fc(x)
        return self.softmax(x)
    
# Initialize neural networks and optimizers
value_net = ValueNetwork(input_size=MAX_FEATURE_LENGTH)
policy_net = PolicyNetwork(input_size=MAX_FEATURE_LENGTH, action_space=len(Action))

In [11]:
class Player:
    def __init__(self, id, stack):
        self.id = id
        self.stack = stack  # Total chips the player has
        self.hand = []      # The player's hole cards
        self.current_bet = 0
        self.folded = False

    def reset_for_new_hand(self):
        self.hand = []
        self.current_bet = 0
        self.folded = False

import random
from enum import Enum

class TexasHoldEm:
    def __init__(self, players, small_blind=10, big_blind=20):
        self.players = players
        self.small_blind = small_blind
        self.big_blind = big_blind
        self.deck = self.initialize_deck()
        self.board = []  # Community cards
        self.pot = 0
        self.current_bet = 0
        self.game_over = False
        self.dealer_index = 0
        self.current_player_index = (self.dealer_index + 1) % len(self.players)
        self.betting_round = 'pre-flop'
        self.round_bets = {player.id: 0 for player in self.players}
        self.last_raiser = None

    def initialize_deck(self):
        suits = ['h', 'd', 'c', 's']
        ranks = list(range(2, 15))  # 2-14 where 11-14 are J, Q, K, A
        deck = [(rank, suit) for rank in ranks for suit in suits]
        random.shuffle(deck)
        return deck

    def deal_hole_cards(self):
        for player in self.players:
            player.hand = [self.deck.pop(), self.deck.pop()]
            player.current_bet = 0
            player.folded = False
        self.round_bets = {player.id: 0 for player in self.players}

    def post_blinds(self):
        small_blind_player = self.players[(self.dealer_index + 1) % len(self.players)]
        big_blind_player = self.players[(self.dealer_index + 2) % len(self.players)]
        self._post_blind(small_blind_player, self.small_blind)
        self._post_blind(big_blind_player, self.big_blind)
        self.current_bet = self.big_blind
        self.last_raiser = big_blind_player.id
        self.current_player_index = (self.dealer_index + 3) % len(self.players)

    def _post_blind(self, player, amount):
        player.stack -= amount
        player.current_bet = amount
        self.pot += amount
        self.round_bets[player.id] += amount

    def get_current_player(self):
        while True:
            player = self.players[self.current_player_index]
            if not player.folded and player.stack > 0:
                return player
            self.current_player_index = (self.current_player_index + 1) % len(self.players)

    def get_available_actions(self, player):
        actions = []
        if player.current_bet < self.current_bet:
            actions.extend([Action.FOLD, Action.CALL])
            if player.stack > (self.current_bet - player.current_bet):
                actions.append(Action.RAISE)
        else:
            actions.append(Action.CHECK)
            if player.stack > 0:
                actions.append(Action.BET)
        return actions

    def execute_action(self, player, action, raise_amount=0):
        if action == Action.FOLD:
            self.handle_fold(player)
        elif action == Action.CHECK:
            self.handle_check(player)
        elif action == Action.CALL:
            self.handle_call(player)
        elif action == Action.BET:
            self.handle_bet(player, raise_amount)
        elif action == Action.RAISE:
            self.handle_raise(player, raise_amount)
        else:
            raise ValueError("Invalid action")

        # Move to the next player
        self.current_player_index = (self.current_player_index + 1) % len(self.players)

    def handle_fold(self, player):
        player.folded = True

    def handle_check(self, player):
        pass  # No action needed

    def handle_call(self, player):
        call_amount = self.current_bet - player.current_bet
        actual_call = min(call_amount, player.stack)
        player.stack -= actual_call
        player.current_bet += actual_call
        self.pot += actual_call
        self.round_bets[player.id] += actual_call

    def handle_bet(self, player, amount):
        if amount <= 0 or amount > player.stack:
            raise ValueError("Invalid bet amount")
        player.stack -= amount
        player.current_bet += amount
        self.current_bet = player.current_bet
        self.pot += amount
        self.round_bets[player.id] += amount
        self.last_raiser = player.id

    def handle_raise(self, player, amount):
        if amount <= 0 or amount > player.stack:
            raise ValueError("Invalid raise amount")
        raise_amount = amount
        call_amount = self.current_bet - player.current_bet
        total_bet = call_amount + raise_amount
        if total_bet > player.stack:
            total_bet = player.stack
        player.stack -= total_bet
        player.current_bet += total_bet
        self.current_bet = player.current_bet
        self.pot += total_bet
        self.round_bets[player.id] += total_bet
        self.last_raiser = player.id

    def is_round_over(self):
        active_players = [p for p in self.players if not p.folded and p.stack > 0]
        if len(active_players) <= 1:
            return True
        for player in active_players:
            if player.current_bet != self.current_bet:
                return False
        return True

    def progress_round(self):
        for player in self.players:
            player.current_bet = 0
        self.round_bets = {player.id: 0 for player in self.players}
        self.current_bet = 0
        self.last_raiser = None

        if self.betting_round == 'pre-flop':
            self.deal_flop()
            self.betting_round = 'flop'
        elif self.betting_round == 'flop':
            self.deal_turn()
            self.betting_round = 'turn'
        elif self.betting_round == 'turn':
            self.deal_river()
            self.betting_round = 'river'
        elif self.betting_round == 'river':
            self.game_over = True

    def deal_flop(self):
        self.deck.pop()  # Burn card
        self.board.extend([self.deck.pop() for _ in range(3)])

    def deal_turn(self):
        self.deck.pop()  # Burn card
        self.board.append(self.deck.pop())

    def deal_river(self):
        self.deck.pop()  # Burn card
        self.board.append(self.deck.pop())

    def is_game_over(self):
        active_players = [p for p in self.players if not p.folded and p.stack > 0]
        if len(active_players) <= 1:
            return True
        return self.game_over

    def determine_winner(self):
        active_players = [p for p in self.players if not p.folded]
        if len(active_players) == 1:
            winner = active_players[0]
            winner.stack += self.pot
            self.pot = 0
        else:
            # Showdown: evaluate hands
            from treys import Evaluator, Card
            evaluator = Evaluator()
            best_rank = None
            winners = []
            for player in active_players:
                player_cards = [Card.new(f"{self.rank_to_str(rank)}{suit}") for rank, suit in player.hand]
                board_cards = [Card.new(f"{self.rank_to_str(rank)}{suit}") for rank, suit in self.board]
                hand_rank = evaluator.evaluate(board_cards, player_cards)
                if best_rank is None or hand_rank < best_rank:
                    best_rank = hand_rank
                    winners = [player]
                elif hand_rank == best_rank:
                    winners.append(player)
            split_pot = self.pot / len(winners)
            for winner in winners:
                winner.stack += split_pot
            self.pot = 0

    def rank_to_str(self, rank):
        rank_dict = {14: 'A', 13: 'K', 12: 'Q', 11: 'J', 10: 'T'}
        return rank_dict.get(rank, str(rank))

    def get_reward(self, player):
        initial_stack = 1000
        return player.stack - initial_stack

class BeliefState:
    def __init__(self, observed_actions, public_cards, pot_size=0):
        self.observed_actions = observed_actions
        self.public_cards = public_cards
        self.pot_size = pot_size  # Add pot size

    def update(self, action, new_public_cards, pot_size):
        self.observed_actions.append(action)
        self.public_cards = new_public_cards
        self.pot_size = pot_size

def extract_features(belief_state):
    features = []

    # Encode observed actions
    action_encoding = {
        Action.FOLD: 0,
        Action.CHECK: 1,
        Action.CALL: 2,
        Action.RAISE: 3,
        Action.BET: 4
    }
    max_history_length = 12  # Increased from 10 to 12
    action_features = [action_encoding[action] for action in belief_state.observed_actions[-max_history_length:]]
    action_features += [0] * (max_history_length - len(action_features))
    features.extend(action_features)

    # Encode public cards
    rank_encoding = {r: i for i, r in enumerate(range(2, 15), start=1)}
    suit_encoding = {'h': 0, 'd': 1, 'c': 2, 's': 3}
    max_board_cards = 6  # Increased from 5 to 6
    board_features = []
    for rank, suit in belief_state.public_cards:
        rank_feature = rank_encoding.get(rank, 0)
        suit_feature = suit_encoding.get(suit, 0)
        board_features.extend([rank_feature, suit_feature])
    while len(board_features) < max_board_cards * 2:
        board_features.extend([0, 0])
    features.extend(board_features)

    # Ensure total feature length is 25
    # Current length: max_history_length (12) + max_board_cards * 2 (12) = 24
    # Add one more feature, e.g., pot size
    pot_size = belief_state.pot_size if hasattr(belief_state, 'pot_size') else 0
    features.append(pot_size / 1000)  # Normalize pot size

    # Convert to tensor
    features = torch.tensor(features, dtype=torch.float32)
    return features

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from enum import Enum, auto

class Action(Enum):
    FOLD = 0
    CHECK = 1
    CALL = 2
    RAISE = 3
    BET = 4

# Define MAX_FEATURE_LENGTH
MAX_FEATURE_LENGTH = 25  # Adjust as needed

# Define sample_action function
def sample_action(action_probs, valid_actions):
    """
    Samples an action from the given action probabilities, considering only valid actions.

    Args:
        action_probs (torch.Tensor): A tensor containing the probabilities for each action.
        valid_actions (list of Action): List of valid actions in the current state.

    Returns:
        Action: The selected action.
    """
    action_list = list(Action)
    action_to_index = {action: idx for idx, action in enumerate(action_list)}
    valid_action_indices = [action_to_index[action] for action in valid_actions]

    # Get probabilities of valid actions
    valid_action_probs = action_probs[valid_action_indices]
    valid_action_probs /= valid_action_probs.sum()  # Normalize

    # Sample from valid actions
    chosen_index = torch.multinomial(valid_action_probs, num_samples=1).item()
    selected_action = valid_actions[chosen_index]

    return selected_action

def determine_raise_amount(player, game):
    """
    Determines the amount to raise or bet.

    Args:
        player (Player): The player who is raising.
        game (TexasHoldEm): The current game state.

    Returns:
        float: The raise amount.
    """
    # Simple strategy: raise by a fixed amount or percentage of the pot
    # For this example, we'll raise by half the pot or the player's remaining stack, whichever is smaller
    raise_amount = min(player.stack, max(game.pot * 0.5, game.big_blind))
    return raise_amount

# Initialize neural networks and optimizers
value_net = ValueNetwork(input_size=MAX_FEATURE_LENGTH)
policy_net = PolicyNetwork(input_size=MAX_FEATURE_LENGTH, action_space=len(Action))
value_optimizer = optim.Adam(value_net.parameters(), lr=1e-4)
policy_optimizer = optim.Adam(policy_net.parameters(), lr=1e-4)

NUM_EPISODES = 500

# Training loop
for episode in range(NUM_EPISODES):
    # Initialize players
    players = [Player(id=0, stack=1000), Player(id=1, stack=1000)]
    
    # Initialize the game
    game = TexasHoldEm(players)
    game.deal_hole_cards()
    game.post_blinds()
    
    # Initialize belief states for each player
    belief_states = {
        player.id: BeliefState(
            observed_actions=[],
            public_cards=game.board,
            pot_size=game.pot  # Optional
        ) for player in players
    }
    
    done = False
    experiences = {player.id: [] for player in players}  # Collect experiences for each player separately
    
    while not done:
        current_player = game.get_current_player()
        belief_state = belief_states[current_player.id]
        
        # Extract features
        features = extract_features(belief_state)

        # Ensure features have the correct shape
        if features.dim() == 1:
            features = features.unsqueeze(0)  # Add batch dimension if necessary
        
        # Get action probabilities from the policy network
        action_probs = policy_net(features)
        action_probs = action_probs.squeeze(0)  # Remove batch dimension
        
        # Get available actions for the current player
        valid_actions = game.get_available_actions(current_player)
        
        # Choose an action
        action = sample_action(action_probs, valid_actions)
        
        # Determine raise amount if necessary
        if action in [Action.BET, Action.RAISE]:
            raise_amount = determine_raise_amount(current_player, game)
            game.execute_action(current_player, action, raise_amount=raise_amount)
        else:
            game.execute_action(current_player, action)
        
        # Record experience for the current player
        experiences[current_player.id].append((features.squeeze(0), action, 0))  # Squeeze features before storing, and reward is zero for now
        
        # Update belief states for all players
        for player in players:
            belief_states[player.id].update(
                action,
                new_public_cards=game.board,
                pot_size=game.pot  # Add this argument
            )
        
        # Check for end of round/game
        if game.is_round_over():
            game.progress_round()
        if game.is_game_over():
            game.determine_winner()
            done = True  # End the game loop
    
    # After the game ends, assign rewards and update networks
    
    # Assign final rewards to each player's experiences
    # Update belief states for all players
    for player in players:
        belief_states[player.id].update(
            action,
            new_public_cards=game.board,
            pot_size=game.pot
        )
    
    # Combine experiences from both players
    all_experiences = []
    for exp_list in experiences.values():
        all_experiences.extend(exp_list)
    
    # Proceed to prepare data and update networks

    # Step 1: Prepare the training data
    states = torch.stack([exp[0] for exp in all_experiences])  # Features are tensors
    actions = [exp[1] for exp in all_experiences]              # Actions are enums
    rewards = [exp[2] for exp in all_experiences]              # Rewards are scalars

    # Step 2: Convert actions to indices
    action_to_index = {action: idx for idx, action in enumerate(Action)}
    action_indices = torch.tensor([action_to_index[action] for action in actions], dtype=torch.long)

    # Step 3: Compute value targets
    value_targets = torch.tensor(rewards, dtype=torch.float32)
    
    # Step 4: Compute value loss
    value_predictions = value_net(states).squeeze()
    value_loss_fn = nn.MSELoss()
    value_loss = value_loss_fn(value_predictions, value_targets)

    # Step 5: Compute policy loss
    policy_outputs = policy_net(states)  # Shape: (batch_size, num_actions)
    
    action_probs_taken = policy_outputs.gather(1, action_indices.unsqueeze(1)).squeeze()
    log_probs = torch.log(action_probs_taken + 1e-10)
    with torch.no_grad():
        advantages = value_targets - value_predictions.detach()
    policy_loss = - (log_probs * advantages).mean()
    
    # Optional: Add entropy regularization to encourage exploration
    entropy = - (policy_outputs * torch.log(policy_outputs + 1e-10)).sum(dim=1).mean()
    entropy_coef = 0.01  # Adjust as needed
    policy_loss = policy_loss - entropy_coef * entropy

    # Step 6: Update the networks
    # Update value network
    value_optimizer.zero_grad()
    value_loss.backward()
    value_optimizer.step()

    # Update policy network
    policy_optimizer.zero_grad()
    policy_loss.backward()
    policy_optimizer.step()
    
    # Optionally, print progress
    if (episode + 1) % 100 == 0:
        print(f"Completed episode {episode + 1}/{NUM_EPISODES}")

  return F.mse_loss(input, target, reduction=self.reduction)


Completed episode 100/500
Completed episode 200/500
Completed episode 300/500
Completed episode 400/500
Completed episode 500/500


In [13]:
# After training is complete
torch.save(policy_net.state_dict(), 'policy_net.pth')
torch.save(value_net.state_dict(), 'value_net.pth')