In [None]:
import numpy as np
import torch
from ai.debug_utils import debug_print

# Value Loss

In [10]:
def r_gamma(rewards: np.ndarray, gamma):
    """Return the discounted return ``r_0 + gamma*r_1 + gamma^2*r_2 + ...``.

    The sum is evaluated back-to-front (Horner style): every reward except the
    first is folded into the accumulator and multiplied by ``gamma``; the first
    reward is then added undiscounted.

    Args:
        rewards: sequence (list or 1-D array) of per-step rewards, earliest first.
        gamma: discount factor.

    Returns:
        The discounted sum of ``rewards``; ``0`` for an empty sequence.
    """
    # An empty trajectory has no return; previously this raised IndexError
    # on ``rewards[0]``.
    if len(rewards) == 0:
        return 0

    discounted = 0  # separate local name so the function name is not shadowed
    for reward in rewards[:0:-1]:  # walk backwards, stopping before index 0
        discounted = gamma * (discounted + reward)
        debug_print("reward: ", reward)
    return discounted + rewards[0]

def value_function(state): # TODO: implement this
    """Placeholder state-value estimate: currently returns 0 for every state."""
    return 0

def v_loss(r_gamma, state, deltas):
    """Squared error between the clipped discounted return and the value estimate.

    Args:
        r_gamma: discounted return for the state (a number, not the function of
            the same name).
        state: state handed to the module-level ``value_function``.
        deltas: indexable of clipping constants; index 1 gives the magnitude of
            the lower bound and index 2 the upper bound.

    Returns:
        ``(clip(r_gamma, -deltas[1], deltas[2]) - value_function(state)) ** 2``.
    """
    lower_bound = -deltas[1]
    upper_bound = deltas[2]
    clipped_return = np.clip(r_gamma, lower_bound, upper_bound)
    error = clipped_return - value_function(state)
    return error ** 2

# Policy Loss

In [3]:
def policy(state):
    """Hard-coded deterministic policy over 5 actions, keyed by betting round.

    Returns a one-hot integer array of length 5 for 'Preflop', 'Flop', 'Turn'
    and 'River', and the scalar ``0`` for any other state.
    """
    # Earlier stochastic variants, kept for reference:
    #   Preflop [0.5, 0.5, 0, 0, 0]   Flop  [0, 0.5, 0.5, 0, 0]
    #   Turn    [0, 0, 1/3, 1/3, 1/3] River [0, 1/2, 1/2, 0, 0]
    # (round name, index of the action that receives all the probability mass)
    chosen_action_by_round = (
        ('Preflop', 1),
        ('Flop', 2),
        ('Turn', 3),
        ('River', 2),
    )
    for round_name, action_index in chosen_action_by_round:
        if state == round_name:
            probabilities = np.zeros(5, dtype=int)
            probabilities[action_index] = 1
            return probabilities
    return 0

def get_deltas(state):
    """Return the clipping constants ``(delta1, delta2, delta3)`` for a betting round.

    ``delta1`` is the same for every round; ``delta2`` and ``delta3`` depend on
    the round. In this file ``delta1`` is used as the upper ratio bound in
    ``tc_loss_function`` and ``delta2``/``delta3`` as the return-clipping bounds
    in ``v_loss``.

    Open question from the original author: it is not settled whether the
    deltas should accumulate the chips from previous rounds or only count the
    current round.

    Args:
        state: one of 'Preflop', 'Flop', 'Turn', 'River'.

    Returns:
        Tuple ``(delta1, delta2, delta3)``.

    Raises:
        ValueError: if ``state`` is not one of the four betting rounds
            (previously this surfaced as an UnboundLocalError).
    """
    delta1 = 3
    round_deltas = {
        # Preflop delta3: the opponent posted the big blind and the agent just
        # bet; the opponent has not put in any further chips yet.
        'Preflop': (20, 10),
        'Flop': (40, 20),
        'Turn': (120, 80),
        'River': (120, 120),
    }
    try:
        delta2, delta3 = round_deltas[state]
    except (KeyError, TypeError):
        raise ValueError(f"no deltas defined for state {state!r}") from None
    return delta1, delta2, delta3

def ratio(old_policy, new_policy, action, state):
    """Probability ratio ``new_policy(state)[action] / old_policy(state)[action]``.

    Both policies are callables mapping a state to an indexable of action
    probabilities.
    """
    new_probability = new_policy(state)[action]
    old_probability = old_policy(state)[action]
    return new_probability / old_probability

def a_gae(results, states, value_function, gamma, lambda_):
    """
    Truncated Generalized Advantage Estimate for the first state of a trajectory.

    Assumes len(states) == len(results); no extra terminal state beyond
    `states` is assumed.

    results:        list/array of rewards at each timestep
    states:         list/array of states at each timestep
    value_function: callable mapping a state to a scalar value
    gamma:          discount factor
    lambda_:        GAE parameter

    Returns 0.0 for an empty trajectory, otherwise

        (1 - lambda_) * sum_{k=1..N-1} lambda_^(k-1) * a_k

    with a_k = -V(s_0) + sum_{i<k} gamma^i * results[i] + gamma^k * V(s_k).
    k stops at N-1 so that states[k] is always a valid index.
    """
    horizon = len(results)
    if horizon == 0:
        return 0.0

    # V(s_0) is shared by every k-step term, so evaluate it once.
    baseline = value_function(states[0])

    # discounted_prefix[k] holds the discounted sum of the first k rewards
    # (discounted_prefix[0] == 0).
    discounted_prefix = np.zeros(horizon + 1, dtype=float)
    for step, reward in enumerate(results):
        discounted_prefix[step + 1] = discounted_prefix[step] + (gamma ** step) * reward

    def k_step_advantage(k):
        # k-step return bootstrapped with V(s_k), minus the baseline V(s_0).
        return -baseline + discounted_prefix[k] + (gamma ** k) * value_function(states[k])

    weighted_total = 0.0
    for k in range(1, horizon):
        weighted_total += (lambda_ ** (k - 1)) * k_step_advantage(k)

    return (1 - lambda_) * weighted_total

# I wasn't sure how to treat the showdown state. With the current approach, when only the river and the showdown are
# fed to a_gae, the single k-step term is a_1 = -V(river) + r(river) + gamma * V(showdown). I think the values of the
# river state and the showdown state will differ, because the showdown value depends on the number of chips the agent
# put in during the river state.
    
def tc_loss_function(ratio, advantage, epsilon, deltas):
    """Trinal-clip policy term for a single sample.

    The ratio is first clipped to ``[1 - epsilon, 1 + epsilon]``; that result
    is then used as the lower bound, and ``deltas[0]`` as the upper bound, of a
    second clip of the raw ratio. The outcome is multiplied by ``advantage``.

    This is computed for every hand and averaged by the caller.
    """
    lower = 1 - epsilon
    upper = 1 + epsilon
    inner_clip = np.clip(ratio, lower, upper)
    outer_clip = np.clip(ratio, inner_clip, deltas[0])
    return outer_clip * advantage
    

In [4]:
def get_action(policy: callable, state):
    """Sample an action index from ``policy(state)``.

    Args:
        policy: callable mapping a state to a 1-D sequence of action
            probabilities.
        state: state to evaluate the policy at.

    Returns:
        An index in ``range(len(policy(state)))`` drawn with those probabilities.
    """
    # Evaluate the policy once: the original called it twice, which doubles
    # the cost and could pair the length of one call with the probabilities
    # of another if the policy is not deterministic.
    probabilities = policy(state)
    return np.random.choice(len(probabilities), p=probabilities)

In [11]:
# Smoke test: discounted return and value loss for a hand-written reward sequence.
rewards = np.array([-20, -40, 0, -100, 320])
deltas = [3, 160, 160]  # [delta1, delta2, delta3]; v_loss reads indices 1 and 2
debug_print(r_gamma(rewards, 0.999))
# The state argument is passed as 0; value_function currently ignores its argument.
debug_print(v_loss(r_gamma(rewards, 0.999), 0, deltas))

reward:  320
reward:  -100
reward:  0
reward:  -40
159.06161882032
reward:  320
reward:  -100
reward:  0
reward:  -40
25300.598581740778


In [6]:
def get_losses(states, rewards, policy, value_function, gamma=0.999, lambda_=0.99, epsilon=0.2):
    """Average trinal-clip policy loss and value loss over one hand.

    The last entry of ``states`` is treated as the terminal (showdown) state
    and is not acted on; every earlier state contributes one policy-loss term
    and one value-loss term, computed from the rewards and states from that
    point onwards.

    Args:
        states: sequence of state names; the final one is the terminal state.
        rewards: sequence of rewards aligned with ``states``.
        policy: callable mapping a state to action probabilities. It is used
            as both the old and the new policy, so the ratio is 1 here.
        value_function: callable mapping a state to a scalar; forwarded to
            ``a_gae``.
        gamma: discount factor (default matches the previously hard-coded 0.999).
        lambda_: GAE parameter (default matches the previously hard-coded 0.99).
        epsilon: PPO clip range (default matches the previously hard-coded 0.2).

    Returns:
        ``(tc_loss, value_loss)`` averaged over the non-terminal states, or
        ``(0.0, 0.0)`` when there are no non-terminal states.
    """
    decision_states = states[:-1]
    # Nothing to average over; previously this ended in a division by zero.
    if len(decision_states) == 0:
        return 0.0, 0.0

    tc_loss = 0
    value_loss = 0
    for i, state in enumerate(decision_states):
        deltas = get_deltas(state)
        rewards_from_now = rewards[i:]
        states_from_now = states[i:]
        # The original author was unsure whether this use of a_gae is correct.
        advantage = a_gae(rewards_from_now, states_from_now, value_function, gamma, lambda_)
        action = get_action(policy, state)
        old_policy = policy
        new_policy = policy
        r = ratio(old_policy, new_policy, action, state)
        tc_loss += tc_loss_function(r, advantage, epsilon, deltas)
        # Pass the actual state (the original passed the literal 0).
        # NOTE: v_loss evaluates the module-level value_function, not the
        # `value_function` argument of this function.
        value_loss += v_loss(r_gamma(rewards_from_now, gamma), state, deltas)

    tc_loss /= len(decision_states)
    value_loss /= len(decision_states)
    return tc_loss, value_loss

In [12]:
# Step-by-step walkthrough of the loss computation for one hand, printing the
# deltas and the advantage at every decision state.
states = ['Preflop', 'Flop', 'Turn', 'River', 'Showdown']
rewards = [-20, -20, -80, 0, 240]  # one reward per entry in `states`, i.e. one more than the number of decision states
tc_loss = 0
value_loss = 0



# The final state ('Showdown') is terminal and is not acted on.
states_without_showdown = states[:-1]
for i, state in enumerate(states_without_showdown):
    deltas = get_deltas(state)
    debug_print(states[i:])
    debug_print('deltas: ', deltas)
    rewards_from_now = rewards[i:]
    states_from_now = states[i:]
    advantage = a_gae(rewards_from_now, states_from_now, value_function, 0.999, 0.99) # I'm not sure if this is correct
    debug_print('Advantage: ', advantage)
    action = get_action(policy, state)
    # Old and new policy are the same callable, so the ratio is 1.
    old_policy = policy
    new_policy = policy
    r = ratio(old_policy, new_policy, action, state)
    debug_print(deltas[1], deltas[2])
    tc_loss += tc_loss_function(r, advantage, 0.2, deltas)
    # The state argument is passed as 0; value_function currently ignores its argument.
    value_loss += v_loss(r_gamma(rewards_from_now, 0.999), 0, deltas)

# Average both losses over the decision states.
tc_loss /= len(states_without_showdown)
value_loss /= len(states_without_showdown)
debug_print('TC loss: ', tc_loss)
debug_print('Value loss: ', value_loss)

['Preflop', 'Flop', 'Turn', 'River', 'Showdown']
deltas:  (3, 20, 10)
Advantage:  -2.932771642119203
20 10
reward:  240
reward:  0
reward:  -80
reward:  -20
['Flop', 'Turn', 'River', 'Showdown']
deltas:  (3, 40, 20)
Advantage:  -2.168523920000002
40 20
reward:  240
reward:  0
reward:  -80
['Turn', 'River', 'Showdown']
deltas:  (3, 120, 80)
Advantage:  -1.5920000000000012
120 80
reward:  240
reward:  0
['River', 'Showdown']
deltas:  (3, 120, 120)
Advantage:  0.0
120 120
reward:  240
TC loss:  -1.6733238905298016
Value loss:  5325.0


In [8]:
# Same walkthrough of the loss computation for one hand, without printing the
# full deltas tuple.
states = ['Preflop', 'Flop', 'Turn', 'River', 'Showdown']
rewards = [-20, -20, -80, 0, 240]  # one reward per entry in `states`, i.e. one more than the number of decision states
tc_loss = 0
value_loss = 0



# The final state ('Showdown') is terminal and is not acted on.
states_without_showdown = states[:-1]
for i, state in enumerate(states_without_showdown):
    deltas = get_deltas(state)
    debug_print(states[i:])
    rewards_from_now = rewards[i:]
    states_from_now = states[i:]
    advantage = a_gae(rewards_from_now, states_from_now, value_function, 0.999, 0.99) # I'm not sure if this is correct
    debug_print('Advantage: ', advantage)
    action = get_action(policy, state)
    # Old and new policy are the same callable, so the ratio is 1.
    old_policy = policy
    new_policy = policy
    r = ratio(old_policy, new_policy, action, state)
    debug_print(deltas[1], deltas[2])
    tc_loss += tc_loss_function(r, advantage, 0.2, deltas)
    # The state argument is passed as 0; value_function currently ignores its argument.
    value_loss += v_loss(r_gamma(rewards_from_now, 0.999), 0, deltas)

# Average both losses over the decision states.
tc_loss /= len(states_without_showdown)
value_loss /= len(states_without_showdown)
debug_print('TC loss: ', tc_loss)
debug_print('Value loss: ', value_loss)

['Preflop', 'Flop', 'Turn', 'River', 'Showdown']
Advantage:  -2.932771642119203
20 10
240
0
-80
-20
['Flop', 'Turn', 'River', 'Showdown']
Advantage:  -2.168523920000002
40 20
240
0
-80
['Turn', 'River', 'Showdown']
Advantage:  -1.5920000000000012
120 80
240
0
['River', 'Showdown']
Advantage:  0.0
120 120
240
TC loss:  -1.6733238905298016
Value loss:  5325.0


In [41]:
# Run the packaged version of the loop above; prints the (tc_loss, value_loss) tuple.
debug_print(get_losses(states, rewards, policy, value_function))

1 [-20, -20, -80, 0, 240]
2 [-20, -20, -80, 0, 240]
3 [-20, -20, -80, 0, 240]
4 [-20, -20, -80, 0, 240]
240
0
-80
-20
1 [-20, -80, 0, 240]
2 [-20, -80, 0, 240]
3 [-20, -80, 0, 240]
240
0
-80
1 [-80, 0, 240]
2 [-80, 0, 240]
240
0
1 [0, 240]
240
(np.float64(-0.9304432345098008), np.float64(5325.0))


# Game State Representation

## Card representation

In [14]:
import numpy as np

class CardRepresentation:
    """
    Card tensor of shape (6, 4, 13) = (channel, suit, rank), filled street by street.

    Channel layout:
      0: hole cards
      1: flop
      2: turn
      3: river
      4: all public cards (flop + turn + river)
      5: hole cards + public cards
    """
    def __init__(self):
        # Allocated once as all zeros; the set_* methods write 1.0 entries into it.
        self.card_tensor = np.zeros((6, 4, 13), dtype=np.float32)

        # Cards seen so far, kept as (rank, suit) tuples.
        self.hole_cards = []
        self.public_cards = []

    def _mark_card(self, channel, rank, suit):
        """Set the entry for (channel, suit, rank) to 1."""
        self.card_tensor[channel, suit, rank] = 1.0

    def set_preflop(self, hole_cards):
        """
        Record the hole cards.

        hole_cards: list of 2 tuples [(rank, suit), (rank, suit)]
        Writes channel 0 (hole) and channel 5 (hole+public).
        """
        self.hole_cards = hole_cards[:]  # keep our own copy
        for rank, suit in hole_cards:
            for channel in (0, 5):
                self._mark_card(channel, rank, suit)

    def set_flop(self, flop_cards):
        """
        Record the flop.

        flop_cards: list of 3 tuples [(rank, suit), ...]
        Writes channel 1 (flop), channel 4 (public) and channel 5 (hole+public).
        """
        for rank, suit in flop_cards:
            for channel in (1, 4, 5):
                self._mark_card(channel, rank, suit)
        self.public_cards.extend(flop_cards)

    def set_turn(self, turn_card):
        """
        Record the turn card; a falsy value (e.g. None) is ignored.

        turn_card: single tuple (rank, suit)
        Writes channel 2 (turn), channel 4 (public) and channel 5 (hole+public).
        """
        if not turn_card:
            return
        rank, suit = turn_card
        for channel in (2, 4, 5):
            self._mark_card(channel, rank, suit)
        self.public_cards.append(turn_card)

    def set_river(self, river_card):
        """
        Record the river card; a falsy value (e.g. None) is ignored.

        river_card: single tuple (rank, suit)
        Writes channel 3 (river), channel 4 (public) and channel 5 (hole+public).
        """
        if not river_card:
            return
        rank, suit = river_card
        for channel in (3, 4, 5):
            self._mark_card(channel, rank, suit)
        self.public_cards.append(river_card)


In [15]:

# Example usage: build the card tensor street by street.
# 1) Initialize
card_rep = CardRepresentation()

# 2) Preflop
# Cards are (rank, suit) tuples. Going by the examples in this cell, rank 12 is the Ace and
# the suit indices are 0 => heart, 1 => diamond, 2 => club, 3 => spade.
hole_cards = [(12, 3), (12, 2)]  # 'As Ac'
card_rep.set_preflop(hole_cards)

# 3) Flop arrives
flop = [(7, 1), (3, 3), (9, 2)] # '8d 4s Tc'
card_rep.set_flop(flop)

# 4) Turn arrives
turn = (5, 0) # '6h'
card_rep.set_turn(turn)

# 5) River arrives
river = (11, 3) # 'Qs'
card_rep.set_river(river)

debug_print("Card tensor shape:", card_rep.card_tensor.shape)  # (6, 4, 13)
debug_print("Card tensor:", card_rep.card_tensor)

Card tensor shape: (6, 4, 13)
Card tensor: [[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]]

 [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
  [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]

 [[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0.]]

 [[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0.]
  [0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 1. 0.]]

 [[0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
  [0. 0. 0. 0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 

## Action Representation

### The example builds a (24, 4, 9) tensor (rounds=4, max_actions_per_round=6, nb=9)
So there are 24 "channels" → (4 betting rounds × 6 possible actions per round).
Each channel is a 4×9 "mini-grid", where:
- rows 0 and 1 mark which action was chosen by player 0 or player 1,
- row 2 can store the cumulative pot or the sum of bets,
- row 3 can store the legality of each action.

The idea is to fill in these channels incrementally as the hand progresses (one channel per action taken), so that the CNN can pick up the pattern and sequence of actions so far.

In [16]:
class ActionRepresentation:
    """
    Action tensor of shape (rounds * max_actions_per_round, 4, nb), filled one action at a time.

    With the defaults this is 24 x 4 x 9: one channel per (round, action slot),
    and within each channel four rows over the nb bet options:
      row 0 / row 1: action chosen by player 0 / player 1
      row 2:         optional sum-of-bets (pot-size) index
      row 3:         legal actions at that step
    """
    def __init__(self, nb=9, max_actions_per_round=6, rounds=4):
        self.nb = nb
        self.max_actions = max_actions_per_round
        self.rounds = rounds

        # One (4 x nb) plane per action slot, all zeros until actions are added.
        channel_count = rounds * max_actions_per_round
        self.action_tensor = np.zeros((channel_count, 4, nb), dtype=np.float32)

    def add_action(self, round_id, action_index_in_round, player_id, action_idx, legal_actions=None, sum_idx=None):
        """
        Write one action into its channel.

        round_id:              betting round, 0-based
        action_index_in_round: slot of this action within the round, 0-based
        player_id:             row of the acting player (0 or 1)
        action_idx:            chosen bet option, in [0..nb-1]
        legal_actions:         optional iterable of legal bet options, marked in row 3;
                               entries outside [0..nb-1] are skipped
        sum_idx:               optional index marked in row 2; ignored when None
                               or outside [0..nb-1]
        """
        channel_id = round_id * self.max_actions + action_index_in_round
        plane = self.action_tensor[channel_id]  # view: writes go to the tensor

        # Row of the acting player: which bet option was taken.
        plane[player_id, action_idx] = 1.0

        # Row 2: sum-of-bets marker, only when a valid index was supplied.
        if sum_idx is not None and 0 <= sum_idx < self.nb:
            plane[2, sum_idx] = 1.0

        # Row 3: legality mask.
        if legal_actions:
            for option in legal_actions:
                if not (0 <= option < self.nb):
                    continue
                plane[3, option] = 1.0

In [17]:
# Example usage: record the first two preflop actions.
action_rep = ActionRepresentation(nb=9, max_actions_per_round=6, rounds=4)

# Preflop, first action: round_id=0, action_index_in_round=0
# Player 0 (small blind) "bet pot" => let's say pot = 1, action_idx=6 means "bet pot"
# legal actions might be [0,1,2,3,4,5,6,7,8] if all are valid
action_rep.add_action(
    round_id=0, 
    action_index_in_round=0, 
    player_id=0, 
    action_idx=6, 
    legal_actions=range(9),  # all marked valid (in reality only calling, raising or folding are available here, since we need to match the BB)
    sum_idx=None            # or some pot-based index if desired
)

# Next action: round_id=0, action_index_in_round=1
# Player 1 calls => action_idx=1 means "check/call"
action_rep.add_action(
    round_id=0,
    action_index_in_round=1,
    player_id=1,
    action_idx=1,
    legal_actions=range(9)
)

debug_print("Action tensor shape:", action_rep.action_tensor.shape)  # (24, 4, 9)

Action tensor shape: (24, 4, 9)


# Pseudo-Siamese Network Implementation

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PseudoSiameseNet(nn.Module):
    """
    Two-branch CNN: one convolutional branch for the action tensor and one for
    the card tensor. The two branch embeddings are concatenated, passed through
    shared fully-connected layers, and fed to a policy head and a value head.
    """

    def __init__(
        self,
        action_in_shape=(24, 4, 9),  # (channels, H, W) of the action tensor
        card_in_shape=(6, 4, 13),    # (channels, H, W) of the card tensor
        conv_out_dim=128,            # embedding size produced by each branch
        hidden_dim=256,              # width of the fused hidden layers
        num_actions=9                # size of the policy output
    ):
        super().__init__()

        # 1) + 2) Convolutional branches. Both use the same layer recipe but
        # do not share weights. Sub-modules are created in the order
        # action branch, card branch, ... so seeded initialisation and
        # state_dict ordering are unchanged.
        self.action_conv = self._conv_tower(action_in_shape[0])
        self.card_conv = self._conv_tower(card_in_shape[0])

        # Discover the flattened size of each branch by pushing an all-zero
        # dummy input through it, without tracking gradients.
        with torch.no_grad():
            self.act_conv_flat_size = self._flat_size(self.action_conv, action_in_shape)
            self.card_conv_flat_size = self._flat_size(self.card_conv, card_in_shape)

        # Linear projections bringing each branch to conv_out_dim.
        self.action_fc = nn.Sequential(
            nn.Linear(self.act_conv_flat_size, conv_out_dim),
            nn.ReLU()
        )
        self.card_fc = nn.Sequential(
            nn.Linear(self.card_conv_flat_size, conv_out_dim),
            nn.ReLU()
        )

        # 3) Fusion layers; their input is the concatenation of both embeddings.
        self.fusion_fc = nn.Sequential(
            nn.Linear(2 * conv_out_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

        # 4) Output heads.
        self.policy_head = nn.Linear(hidden_dim, num_actions)
        self.value_head = nn.Linear(hidden_dim, 1)

    @staticmethod
    def _conv_tower(in_channels):
        """Conv -> ReLU -> MaxPool -> Conv -> ReLU -> MaxPool stack used by both branches."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

    @staticmethod
    def _flat_size(conv, in_shape):
        """Number of features `conv` yields for one zero input of shape `in_shape`."""
        probe = torch.zeros(1, in_shape[0], in_shape[1], in_shape[2])
        return conv(probe).view(1, -1).size(1)

    @staticmethod
    def _embed(conv, fc, batch):
        """Run one branch: convolutions, flatten per sample, linear projection."""
        features = conv(batch)
        return fc(features.view(features.size(0), -1))

    def forward(self, action_input, card_input):
        """
        Inputs:
            action_input: batch of action tensors, e.g. shape (B, 24, 4, 9)
            card_input:   batch of card tensors, e.g. shape (B, 6, 4, 13)
        Outputs:
            policy_logits: shape (B, num_actions)
            value:         shape (B, 1)
        """
        action_embedding = self._embed(self.action_conv, self.action_fc, action_input)
        card_embedding = self._embed(self.card_conv, self.card_fc, card_input)

        # Concatenate along the feature axis, then apply the shared layers.
        fused = self.fusion_fc(torch.cat([action_embedding, card_embedding], dim=1))

        return self.policy_head(fused), self.value_head(fused)


In [19]:
def logits_to_probs(logits):
    """
    Turn raw policy logits into probabilities by applying a softmax
    over the last dimension.
    """
    return torch.softmax(logits, dim=-1)

In [20]:
# Quick forward-pass test with a single-sample batch
batch_size = 1

# Create random input for action tensor: (B, 24, 4, 9)
action_input = torch.randn(batch_size, 24, 4, 9)
# Create random input for card tensor: (B, 6, 4, 13)
card_input = torch.randn(batch_size, 6, 4, 13)

# Same values as the constructor defaults, spelled out for clarity.
model = PseudoSiameseNet(
    action_in_shape=(24, 4, 9),
    card_in_shape=(6, 4, 13),
    conv_out_dim=128,
    hidden_dim=256,
    num_actions=9
)

policy_logits, value = model(action_input, card_input)
debug_print("policy_logits shape:", policy_logits.shape)  # (B, 9)
debug_print("value shape:", value.shape)                  # (B, 1)

debug_print(policy_logits)
debug_print(value)

policy_logits shape: torch.Size([1, 9])
value shape: torch.Size([1, 1])
tensor([[-0.0289,  0.0285,  0.0359, -0.0546, -0.0115, -0.0646, -0.0314,  0.0004,
          0.0022]], grad_fn=<AddmmBackward0>)
tensor([[-0.0673]], grad_fn=<AddmmBackward0>)


In [29]:
def get_action_from_probs(probs):
    """Draw one action index at random, weighted by the probabilities in `probs`."""
    action_count = len(probs)
    sampled_index = np.random.choice(action_count, p=probs)
    return sampled_index

In [30]:
import torch.optim as optim


##########################################################
# 4) DEMONSTRATION: TWO ITERATIONS WITH POLICY CHANGES
##########################################################

def build_card_rep():
    """Return a CardRepresentation filled with one fixed board, identical on every call."""
    hole = [(12, 3), (12, 2)]         # As, Ac
    flop = [(7, 1), (3, 3), (9, 2)]   # 8d, 4s, Tc
    turn = (5, 0)                     # 6h
    river = (11, 3)                   # Qs

    rep = CardRepresentation()
    rep.set_preflop(hole)
    rep.set_flop(flop)
    rep.set_turn(turn)
    rep.set_river(river)
    return rep

def build_action_rep():
    """Return an ActionRepresentation pre-filled with two preflop actions (demo data)."""
    rep = ActionRepresentation(nb=9, max_actions_per_round=6, rounds=4)
    # (slot within the preflop round, player_id, chosen bet option);
    # slot 0 lands in channel 0 and slot 1 in channel 1.
    preflop_actions = ((0, 0, 6), (1, 1, 1))
    for slot, player, bet_option in preflop_actions:
        rep.add_action(0, slot, player, bet_option, legal_actions=range(9))
    return rep

def to_torch_input(card_rep, action_rep):
    """Wrap both representations as float tensors with a leading batch axis of size 1.

    Returns the pair in model-argument order: (action tensor, card tensor).
    """
    def as_batched_tensor(array):
        # Prepend the batch axis, then hand the array to torch.
        return torch.from_numpy(array[None, ...]).float()

    card_batch = as_batched_tensor(card_rep.card_tensor)
    action_batch = as_batched_tensor(action_rep.action_tensor)
    return action_batch, card_batch

def run_one_iteration(model, optimizer, iteration_idx):
    """
    Run one demonstration pass over a fixed hand and update the model after every step.

    The scenario is always the same: states=['Preflop','Flop','Turn','River','Showdown'],
    rewards=[-20, -20, -80, 0, 240]. For every non-terminal state the function
    computes the advantage, runs the model to get a policy distribution, samples
    an action, computes the policy and value loss values, and performs one
    optimizer step on a demonstration loss (explicitly NOT real PPO).

    Because the model is updated, the next call may yield different policy
    distributions and therefore different actions.

    Args:
        model: callable taking (action tensor, card tensor) and returning
            (policy_logits, value); its parameters are updated in place.
        optimizer: optimizer used for zero_grad() / step().
        iteration_idx: label used only in the printed output.

    Returns:
        (chosen_actions, policy_distributions): the sampled action index and
        the probability array for each non-terminal state.
    """
    debug_print(f"\n=== Iteration {iteration_idx} ===")
    # Same states & rewards on every call
    states = ['Preflop','Flop','Turn','River','Showdown']
    rewards = [-20, -20, -80, 0, 240]
    
    # Build reps. The card representation already contains the full board
    # (flop, turn and river) before the first step.
    card_rep = build_card_rep()
    action_rep = build_action_rep()
    
    # Running totals for the policy loss and value loss
    tc_loss_total = 0
    value_loss_total = 0
    steps_count = 0
    
    # Per-step sampled actions and distributions (returned to the caller)
    chosen_actions = []
    policy_distributions = []
    
    # Only up to the second-to-last state: the final state is terminal
    for i, state in enumerate(states[:-1]):
        # A) GAE advantage, computed with the module-level value_function
        #    (not with the model's value head)
        rewards_from_now = rewards[i:]
        states_from_now  = states[i:]
        advantage = a_gae(rewards_from_now, states_from_now, value_function, 0.999, 0.99)
        
        # B) Model forward => get policy distribution
        action_t, card_t = to_torch_input(card_rep, action_rep)
        policy_logits, val_out = model(action_t, card_t)
        probs = logits_to_probs(policy_logits)[0].detach().numpy()  # shape (num_actions,), e.g. (9,) with the default model
        
        # pick an action from this distribution
        action_idx = get_action_from_probs(probs)
        
        # store distribution + chosen action
        policy_distributions.append(probs)
        chosen_actions.append(action_idx)
        
        # C) ratio = 1 for demonstration (we're not storing old vs. new)
        r_val = 1.0
        
        # D) Trinal-Clip policy loss
        pol_loss_val = tc_loss_function(r_val, advantage, 0.2, get_deltas(state))
        
        # E) Value loss
        #    compute r_gamma from future rewards; the state argument of v_loss
        #    is passed as 0 (value_function currently ignores its argument)
        r_g = r_gamma(np.array(rewards_from_now), gamma=0.999)
        v_loss_val = v_loss(r_g, 0, get_deltas(state))
        
        # Minimal single-sample gradient step in PyTorch 
        # (in real code, you'd accumulate across the entire trajectory).
        #
        # pol_loss_val & v_loss_val are treated as if they were "ground truth"
        # signals. This is not correct PPO, but it shows how an update changes
        # the model.
        
        # Convert pol_loss_val & v_loss_val to torch
        # NOTE(review): pol_loss_tensor is not used anywhere below.
        pol_loss_tensor = torch.tensor(pol_loss_val, dtype=torch.float32, requires_grad=True)
        val_loss_tensor = torch.tensor(v_loss_val, dtype=torch.float32, requires_grad=True)
        
        # Dummy "training loss": weight the log-probability of the sampled
        # action by pol_loss_val, and pull the model's value output towards
        # v_loss_val. This is purely for demonstration. The real PPO is more
        # complicated.
        
        log_probs = F.log_softmax(policy_logits, dim=-1)[0]
        chosen_log_prob = log_probs[action_idx]
        
        # combined_loss = - pol_loss_val * chosen_log_prob + (val_out[0] - v_loss_val)**2
        # This isn't real PPO, but it ensures changes in the net if pol_loss_val or v_loss_val are large.
        combined_loss = - pol_loss_val * chosen_log_prob + (val_out[0] - val_loss_tensor)**2
        
        optimizer.zero_grad()
        combined_loss.backward()
        optimizer.step()
        
        # accumulate
        tc_loss_total += pol_loss_val
        value_loss_total += v_loss_val
        steps_count += 1
        
        # For demonstration, record the sampled action in the ActionRepresentation
        # so the model input changes on the next step.
        # NOTE(review): with i == 0 this writes to round 0, slot 0, which
        # build_action_rep has already filled — confirm this is intended.
        action_rep.add_action(
            round_id=i, 
            action_index_in_round=0, 
            player_id=0,
            action_idx=action_idx,
            legal_actions=range(9),
            sum_idx=None
        )
        
        debug_print(f"  State={state}, advantage={advantage:.2f}, chosen_action={action_idx}")
        debug_print(f"    distribution={probs}")
        debug_print(f"    pol_loss_val={pol_loss_val:.3f}, val_loss_val={v_loss_val:.3f}")
    
    # max(1, ...) guards the averages against a zero step count
    tc_loss_avg = tc_loss_total / max(1, steps_count)
    val_loss_avg = value_loss_total / max(1, steps_count)
    
    debug_print("\nActions chosen this iteration:", chosen_actions)
    debug_print("Policy distributions each step:")
    for dist in policy_distributions:
        debug_print(" ", dist.round(3))
    debug_print(f"\n=> iteration {iteration_idx} done. avg policy loss={tc_loss_avg:.3f}, avg value loss={val_loss_avg:.3f}")
    return chosen_actions, policy_distributions


# Demo driver: run the same scenario twice on one model and compare the sampled actions.
if __name__ == "__main__":
    # Build model (default architecture) & optimizer
    model = PseudoSiameseNet()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    
    # Run iteration 1
    debug_print("RUNNING ITERATION 1 ...")
    actions1, dists1 = run_one_iteration(model, optimizer, iteration_idx=1)
    
    # Run iteration 2 (same states & rewards => see if policy changes)
    debug_print("\nRUNNING ITERATION 2 (same scenario) ...")
    actions2, dists2 = run_one_iteration(model, optimizer, iteration_idx=2)
    
    # Compare chosen actions; the distributions (dists1, dists2) are printed
    # inside run_one_iteration and are not compared here.
    debug_print("\nCOMPARISON:")
    debug_print(" Iteration 1 chosen actions:", actions1)
    debug_print(" Iteration 2 chosen actions:", actions2)
    debug_print("\nYou may see differences in the chosen actions or in the distributions,")
    debug_print("which shows the policy was updated between iterations.\n")


RUNNING ITERATION 1 ...

=== Iteration 1 ===
reward:  240
reward:  0
reward:  -80
reward:  -20
  State=Preflop, advantage=-2.93, chosen_action=2
    distribution=[0.11144958 0.11221532 0.10817553 0.10764666 0.11327564 0.11123017
 0.11398275 0.11353203 0.10849233]
    pol_loss_val=-2.933, val_loss_val=100.000
reward:  240
reward:  0
reward:  -80
  State=Flop, advantage=-2.17, chosen_action=7
    distribution=[0.11164989 0.11254499 0.10774592 0.1077005  0.113539   0.11010551
 0.11415704 0.11402137 0.10853586]
    pol_loss_val=-2.169, val_loss_val=400.000
reward:  240
reward:  0
  State=Turn, advantage=-1.59, chosen_action=7
    distribution=[0.11223447 0.11286046 0.10725048 0.107858   0.11389577 0.10940454
 0.11449676 0.11322542 0.10877414]
    pol_loss_val=-1.592, val_loss_val=6400.000
reward:  240
  State=River, advantage=0.00, chosen_action=5
    distribution=[0.11293585 0.1133142  0.10618255 0.10805468 0.11445083 0.10903878
 0.11470548 0.11203111 0.10928655]
    pol_loss_val=0.000, v