## Training Trump players

In [1]:
import os

# XLA reads these when JAX first initialises its backend, so this cell must run
# before anything imports jax. Preallocation reserves 86% of GPU memory up front.
os.environ.update({
    "XLA_PYTHON_CLIENT_PREALLOCATE": "true",
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.86",
})

In [2]:
# Warm-up: run one large matmul so JAX initialises the accelerator backend
# (and claims its preallocated memory) before training starts.

import jax.numpy as jnp

x = jnp.ones((10000, 10000))
y = jnp.dot(x, x)

# JAX dispatches asynchronously; block_until_ready() waits for the device result
# without copying the whole 10000x10000 matrix back to the host just to print it.
y.block_until_ready()

[[10000. 10000. 10000. ... 10000. 10000. 10000.]
 [10000. 10000. 10000. ... 10000. 10000. 10000.]
 [10000. 10000. 10000. ... 10000. 10000. 10000.]
 ...
 [10000. 10000. 10000. ... 10000. 10000. 10000.]
 [10000. 10000. 10000. ... 10000. 10000. 10000.]
 [10000. 10000. 10000. ... 10000. 10000. 10000.]]


In [3]:
# Drop the warm-up matrices (about 400 MB each, assuming default float32) so JAX can reuse that memory.
del y
del x

In [4]:
from trump_utils import *
from cfvfp_main import DeepCFVFPSolver

In [None]:
# Run 1: small configuration; buffers/nets are written to the "*_z3-q-only" directories.
# Three-element lists are per phase, in phase-index order (bid, ad, play — see PHASES).
run1_solver_kwargs = dict(
    game_name="trump",
    q_value_network_models=q_value_network_models,
    avg_policy_network_models=avg_policy_network_models,
    info_state_tensor_transformers=info_state_tensor_transformers,
    action_transformers=action_transformers,
    phase_classifier_fn=trump_phase_classifier,
    dummy_infostate=np.array([dummy_infostate], dtype=np.float32),
    data_augmentors=data_augmentors,
    revelation_transformer=revelation_transformer,
    revelation_intensity=[0.1, 0.1],
    revelation_decay_mode='linear',
    num_iterations=150,
    num_iterations_q_per_pi=3,
    num_traversals_per_player=72,
    uniform=True,
    learning_rate=1e-4,
    batch_size_q_value=[512, 128, 2048],
    batch_size_avg_policy=[512, 128, 1024],
    q_value_network_train_steps=[500, 250, 1000],
    avg_policy_network_train_steps=[250, 250, 500],
    q_value_memory_capacity=[6e3, 1e3, 7.8e4],
    avg_policy_memory_capacity=[3e4, 5e3, 3.9e5],
    save_dir_buffers="cfvfp_buffers_z3-q-only",
    save_dir_nets="cfvfp_nets_z3-q-only",
    seed=2,
    num_workers=6,
)
solver = DeepCFVFPSolver(**run1_solver_kwargs)

# 5. Run the training loop
print(f"\n--- Starting Trump game training with DeepCFVFPSolver ({solver._num_phases} phases) ---")
print(f"Global number of actions used by solver: {solver._global_num_actions}")

final_policy_params_by_phase, q_losses_by_phase, avg_policy_losses_by_phase = solver.solve()

print("--- Training finished ---")

In [None]:
# Run 2: larger configuration; buffers/nets are written to "cfvfp_buffers"/"cfvfp_nets".
# Three-element lists are per phase, in phase-index order (bid, ad, play — see PHASES).
run2_solver_kwargs = dict(
    game_name="trump",
    q_value_network_models=q_value_network_models,
    avg_policy_network_models=avg_policy_network_models,
    info_state_tensor_transformers=[bid_transformer, ad_transformer, play_transformer_pi],
    action_transformers=action_transformers,
    phase_classifier_fn=trump_phase_classifier,
    dummy_infostate=np.array([dummy_infostate], dtype=np.float32),
    data_augmentors=data_augmentors,
    revelation_transformer=revelation_transformer,
    revelation_intensity=[0.2, 0.2],
    revelation_decay_mode='linear',
    num_iterations=600,
    num_iterations_q_per_pi=6,
    num_traversals_per_player=300,
    uniform=True,
    learning_rate=1e-4,
    batch_size_q_value=[2048, 512, 4096],
    batch_size_avg_policy=[2048, 1024, 2048],
    q_value_network_train_steps=[1500, 120, 3000],
    avg_policy_network_train_steps=[500, 30, 800],
    q_value_memory_capacity=[6e4, 2e4, 7.8e5],
    avg_policy_memory_capacity=[6e4, 2e4, 7.8e5],
    save_dir_buffers="cfvfp_buffers",
    save_dir_nets="cfvfp_nets",
    seed=1,
    num_workers=6,
)
solver = DeepCFVFPSolver(**run2_solver_kwargs)

# 5. Run the training loop
print(f"\n--- Starting Trump game training with DeepCFVFPSolver ({solver._num_phases} phases) ---")
print(f"Global number of actions used by solver: {solver._global_num_actions}")

final_policy_params_by_phase, q_losses_by_phase, avg_policy_losses_by_phase = solver.solve()

print("--- Training finished ---")

In [10]:
# 6. Optional: summarise the recorded losses per phase and player.
def _print_loss_summary(title, losses_by_phase, no_data_message):
    """Print, for each phase (by PHASES name) and player, last/avg/min/max of its loss history."""
    print(title)
    for phase_index, losses_by_player in losses_by_phase.items():
        print(f"  Phase '{PHASES[phase_index]}':")
        for player_id, history in losses_by_player.items():
            if not history:
                print(f"    Player {player_id}: {no_data_message}")
                continue
            print(
                f"    Player {player_id}: {history[-1]:.4f} "
                f"(Avg: {np.mean(history):.4f}, Min: {np.min(history):.4f}, "
                f"Max: {np.max(history):.4f} over {len(history)} entries)"
            )


_print_loss_summary(
    "\nQ-Value Losses (last value per player, per phase index):",
    q_losses_by_phase,
    "No Q-loss data",
)
_print_loss_summary(
    "\nAverage Policy Losses (last value per player, per phase index):",
    avg_policy_losses_by_phase,
    "No AvgPolicy-loss data",
)

print("\nSolver run complete.")


Q-Value Losses (last value per player, per phase index):
  Phase 'bid':
    Player 0: 17.6334 (Avg: 16.9620, Min: 16.2329, Max: 17.8518 over 30 entries)
  Phase 'ad':
    Player 0: 0.6131 (Avg: 0.5561, Min: 0.2232, Max: 1.0016 over 30 entries)
  Phase 'play':
    Player 0: 2.8631 (Avg: 2.7940, Min: 2.4560, Max: 2.9394 over 30 entries)

Average Policy Losses (last value per player, per phase index):
  Phase 'bid':
    Player 0: 0.0379 (Avg: 0.0270, Min: 0.0140, Max: 0.0379 over 10 entries)
  Phase 'ad':
    Player 0: 0.0053 (Avg: 0.0064, Min: 0.0008, Max: 0.0134 over 10 entries)
  Phase 'play':
    Player 0: 0.8269 (Avg: 0.8239, Min: 0.8185, Max: 0.8269 over 10 entries)

Solver run complete.


## Interactive game with the trained strategy displayed

In [3]:
import pyspiel

game = pyspiel.load_game("trump")


def print_game_info(game):
    """Print an OpenSpiel game's static properties, one labelled field per line."""
    game_type = game.get_type()
    # Each value is read lazily so fields are queried in the same order they are printed.
    fields = (
        ("Game Short Name:", lambda: game_type.short_name),
        ("Game Long Name:", lambda: game_type.long_name),
        ("Number of Players:", game.num_players),
        ("Min Utility:", game.min_utility),
        ("Max Utility:", game.max_utility),
        ("Max Game Length:", game.max_game_length),
        ("Tensor Shape (Infostate):", game.information_state_tensor_shape),
    )
    for label, read_value in fields:
        print(label, read_value())


print_game_info(game)

Game Short Name: trump
Game Long Name: Trump
Number of Players: 4
Min Utility: -33.0
Max Utility: 28.0
Max Game Length: 120
Tensor Shape (Infostate): [588]


In [4]:
# Pin JAX to the CPU for interactive play. This must run before trump_interactive
# is imported (NOTE(review): presumably that module touches JAX on import — confirm).
import jax
jax.config.update('jax_platform_name', 'cpu')

from trump_interactive import interactive_trump_game

# Play one interactive game; the returned state is inspected by the cells below.
state = interactive_trump_game(game)

=== STARTING INTERACTIVE TRUMP GAME ===

--- 1. Dealing Cards (Randomly) ---

--- Initial Hands Dealt (Player 0's perspective for their hand) ---
Player 0 Hand: C3, C5, CT, D4, D6, D7, D8, DA, HK, S2, S5, S7, ST
Player 1 Hand: C2, C7, CK, DK, H4, H5, H6, H7, H9, HT, S3, SJ, SK
Player 2 Hand: C6, C8, CA, D2, D3, H2, H3, H8, HA, HQ, S6, S8, S9
Player 3 Hand: C4, C9, CJ, CQ, D5, D9, DJ, DQ, DT, HJ, S4, SA, SQ

--- 2. Bidding Phase ---
Current state before bids:
Phase: Bidding
Current Player: P0
Hands (visible in full state string):
P0: C3 C5 CT DA D4 D6 D7 D8 HK S2 S5 S7 ST
P1: C2 C7 CK DK H4 H5 H6 H7 H9 HT S3 SJ SK
P2: CA C6 C8 D2 D3 HA H2 H3 H8 HQ S6 S8 S9
P3: C4 C9 CJ CQ D5 D9 DT DJ DQ HJ SA S4 SQ
Bid Cards Status: (All bids are hidden until revealed simultaneously after this phase)


Player 0's turn to bid.
  Legal moves for P0: C5:0.01 D4:0.03 HK:0.00 ST:0.00 S5:0.00 D6:0.02 D8:0.00 S2:0.00 C3:0.94 CT:0.00 D7:0.00 DA:0.00 S7:0.00
  A random choice: C3
Player 0 bids with: C3

Player 1

In [10]:
from trump_utils import print_graveyards

# Decode the opponent-knowledge ("graveyard") section of the acting player's info state.
acting_player = state.current_player()
tensor = state.information_state_tensor(acting_player)
print_graveyards(tensor)

Graveyards (Opponent Card Knowledge):
Values: -1=has, 0=unknown, 1=had, 2=never had

Opponent 1:
C: C2: 0 C3: 0 C4: 0 C5: 0 C6: 2 C7: 0 C8: 2 C9: 0 CT: 0 CJ: 0 CQ: 0 CK: 0 CA: 2 
D: D2: 2 D3: 0 D4: 0 D5:-1 D6: 2 D7: 0 D8: 2 D9: 0 DT: 0 DJ: 0 DQ: 0 DK: 2 DA: 0 
H: H2: 0 H3: 2 H4: 2 H5: 2 H6: 2 H7: 0 H8: 0 H9: 0 HT: 0 HJ: 0 HQ: 2 HK: 0 HA: 0 
S: S2: 2 S3: 2 S4: 2 S5: 0 S6: 1 S7: 2 S8: 0 S9: 0 ST: 2 SJ: 2 SQ: 0 SK: 0 SA: 1 

Opponent 2:
C: C2: 0 C3: 0 C4: 0 C5: 0 C6: 2 C7: 0 C8: 2 C9: 0 CT: 0 CJ: 0 CQ: 0 CK: 0 CA: 2 
D: D2: 2 D3: 0 D4: 0 D5: 2 D6: 2 D7: 0 D8: 2 D9: 0 DT: 0 DJ: 0 DQ: 0 DK: 2 DA: 0 
H: H2: 0 H3: 2 H4: 2 H5: 2 H6: 2 H7: 0 H8: 0 H9: 0 HT: 0 HJ: 0 HQ: 2 HK: 0 HA: 0 
S: S2: 2 S3: 2 S4: 1 S5: 0 S6: 2 S7: 2 S8: 0 S9: 0 ST: 1 SJ: 2 SQ: 0 SK: 0 SA: 2 

Opponent 3:
C: C2: 0 C3: 0 C4: 0 C5: 0 C6: 2 C7: 0 C8: 2 C9: 0 CT: 0 CJ: 0 CQ: 0 CK: 0 CA: 2 
D: D2: 2 D3: 0 D4: 0 D5: 2 D6:-1 D7: 0 D8: 2 D9: 0 DT: 0 DJ: 0 DQ: 0 DK: 2 DA: 0 
H: H2: 0 H3: 2 H4: 2 H5: 2 H6: 2 H7: 0 H8: 0 H9: 0 HT: 0 

In [None]:
# Dump the raw information-state tensor currently bound to `tensor` (set by an earlier cell).
print(tensor)

[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 2.0, 1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, -1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 12.0, 1.0, 0.0, 0.0, 0.0, 13.0, 1.0, 0.0, 0.0, 0.0, 11.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 

In [None]:
from trump_utils import revelation_transformer

# Gather every seat's information-state tensor in seat order 0..3, then let the
# revelation transformer combine them from the acting player's point of view.
all_tensors = [state.information_state_tensor(seat) for seat in range(4)]
tensor0, tensor1, tensor2, tensor3 = all_tensors

tensor_revealed = revelation_transformer(all_tensors, state.current_player())
print_graveyards(tensor_revealed)

Graveyards (Opponent Card Knowledge):
Values: -1=has, 0=unknown, 1=had, 2=never had

Opponent 1:
C: C2: 1 C3: 2 C4: 2 C5: 2 C6: 2 C7: 2 C8: 2 C9:-1 CT:-1 CJ: 2 CQ: 2 CK:-1 CA:-1 
D: D2: 2 D3: 2 D4: 2 D5: 2 D6: 1 D7: 1 D8: 2 D9:-1 DT: 2 DJ: 2 DQ: 2 DK: 1 DA: 2 
H: H2: 2 H3:-1 H4: 2 H5: 2 H6: 2 H7: 2 H8: 2 H9: 2 HT: 2 HJ: 2 HQ: 2 HK: 2 HA: 2 
S: S2: 2 S3: 1 S4: 2 S5: 2 S6: 2 S7: 2 S8: 2 S9: 2 ST: 2 SJ: 1 SQ:-1 SK: 2 SA: 2 

Opponent 2:
C: C2: 2 C3: 1 C4: 2 C5: 2 C6: 2 C7: 2 C8: 2 C9: 2 CT: 2 CJ:-1 CQ: 2 CK: 2 CA: 2 
D: D2: 1 D3: 2 D4:-1 D5: 2 D6: 2 D7: 2 D8: 2 D9: 2 DT: 1 DJ: 2 DQ: 2 DK: 2 DA: 1 
H: H2:-1 H3: 2 H4: 2 H5: 2 H6:-1 H7:-1 H8: 2 H9: 2 HT: 2 HJ: 2 HQ: 2 HK: 2 HA: 2 
S: S2: 1 S3: 2 S4:-1 S5: 2 S6: 2 S7: 1 S8: 2 S9: 2 ST:-1 SJ: 2 SQ: 2 SK: 2 SA: 2 

Opponent 3:
C: C2: 2 C3: 2 C4:-1 C5:-1 C6: 1 C7: 2 C8: 2 C9: 2 CT: 2 CJ: 2 CQ:-1 CK: 2 CA: 2 
D: D2: 2 D3: 1 D4: 2 D5: 2 D6: 2 D7: 2 D8: 2 D9: 2 DT: 2 DJ: 1 DQ: 1 DK: 2 DA: 2 
H: H2: 2 H3: 2 H4: 2 H5:-1 H6: 2 H7: 2 H8: 2 H9: 2 HT: 2 

In [None]:
# Full action history of the current state (the leading entries are presumably the chance deal).
print(state.history())

[38, 11, 37, 5, 40, 12, 9, 16, 50, 15, 32, 47, 14, 28, 10, 24, 43, 26, 17, 2, 19, 48, 4, 35, 18, 0, 3, 29, 25, 1, 30, 41, 34, 44, 7, 46, 27, 8, 20, 33, 51, 42, 39, 49, 21, 6, 22, 23, 45, 31, 36, 13, 40, 28, 27, 42, 13, 16, 24, 25, 23, 15, 19, 20, 21, 14, 18, 8, 39, 50, 48, 41, 51, 40, 43, 44, 22, 5, 17, 34, 45, 47]


Traceback (most recent call last):
  File "/mnt/s/py_repos/my-drl-gaming/venv_wsl/lib/python3.11/site-packages/IPython/core/completer.py", line 3246, in _complete
    result = matcher(context)
             ^^^^^^^^^^^^^^^^
  File "/mnt/s/py_repos/my-drl-gaming/venv_wsl/lib/python3.11/site-packages/IPython/core/completer.py", line 2139, in magic_matcher
    matches = self.magic_matches(text)
              ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/s/py_repos/my-drl-gaming/venv_wsl/lib/python3.11/site-packages/IPython/core/completer.py", line 2172, in magic_matches
    global_matches = self.global_matches(bare_text)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/s/py_repos/my-drl-gaming/venv_wsl/lib/python3.11/site-packages/IPython/core/completer.py", line 1114, in global_matches
    for word in lst:
RuntimeError: dictionary changed size during iteration


In [3]:
from trump_z3_bool_current_hand import analyze_hand_cards
import numpy as np

# Analyse the acting player's hand using only that player's own information state.
acting_player = state.current_player()
tensor = np.array(state.information_state_tensor(acting_player))
analyze_hand_cards(tensor, debug=True)

3♣ loses due to not following suit.
6♣ loses due to not following suit.
Q♦ loses due to not following suit.
K♦ loses due to not following suit.
A valid play where player plays 5♠:
Opponent Current Hands:
  P1: J♦ 8♥ 4♠ J♠ Q♠ A♠
  P2: 3♦ 8♦ T♦ Q♥ 7♠ 9♠
  P3: 5♦ 7♦ 2♥ 3♥ 9♥ K♠
Current Trick:    P1:8♥ P2:Q♥ P3:9♥
A scenario where player loses with 5♠:
Opponent Current Hands:
  P1: 7♦ 8♦ 7♠ 9♠ Q♠ A♠
  P2: T♦ 2♥ 3♥ 8♥ Q♥ K♠
  P3: 3♦ 5♦ J♦ 9♥ 4♠ J♠
Current Trick:    P1:7♠ P2:Q♥ P3:9♥
A scenario where player wins with 5♠:
Opponent Current Hands:
  P1: 7♦ 8♦ 4♠ 7♠ 9♠ Q♠
  P2: T♦ 2♥ 3♥ 8♥ Q♥ A♠
  P3: 3♦ 5♦ J♦ 9♥ J♠ K♠
Current Trick:    P1:4♠ P2:Q♥ P3:9♥
A valid play where player plays T♠:
Opponent Current Hands:
  P1: 7♦ 8♦ 4♠ 7♠ 9♠ Q♠
  P2: T♦ 2♥ 3♥ 8♥ Q♥ A♠
  P3: 3♦ 5♦ J♦ 9♥ J♠ K♠
Current Trick:    P1:4♠ P2:Q♥ P3:9♥
A scenario where player loses with T♠:
Opponent Current Hands:
  P1: 7♦ 8♦ 4♠ 7♠ Q♠ K♠
  P2: T♦ 2♥ 3♥ 8♥ Q♥ A♠
  P3: 3♦ 5♦ J♦ 9♥ 9♠ J♠
Current Trick:    P1:K♠ P2:Q♥ P3:9♥
A scenar

[-1,
 -1,
 1,
 -1,
 -1,
 1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 1,
 1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 3,
 -1,
 -1,
 -1,
 -1,
 3,
 -1,
 -1,
 -1]

In [None]:
import numpy as np
from trump_utils import reveal_p0
from trump_z3_bool_current_hand import analyze_hand_cards

observer = state.current_player()
assert observer >= 0, "state must be at a player decision node"
num_seats = state.num_players()

# Row 0 is the acting player's own info state (recomputed here instead of relying on a
# `tensor` left over from another cell); rows 1-3 are the other seats in seat order
# starting after the observer. For observer == 3 this is seats 0, 1, 2.
# NOTE(review): assumes reveal_p0 expects opponents relative to row 0 — confirm.
tensor = np.array(state.information_state_tensor(observer))
tensor_p1, tensor_p2, tensor_p3 = (
    np.array(state.information_state_tensor((observer + offset) % num_seats))
    for offset in (1, 2, 3)
)

tensor_all = np.array([tensor, tensor_p1, tensor_p2, tensor_p3])

tensor_revealed = reveal_p0(tensor_all)
analyze_hand_cards(tensor_revealed, debug=True, max_timeout_ms=None)

3♣ loses to 6♣
5♣ loses to 6♣
A valid play where player plays T♣:
Opponent Current Hands:
  P1: 4♣ 2♦ 5♦ 6♦ 7♦ 9♦ T♥ 2♠ 3♠ 4♠ T♠
  P2: 9♣ 4♦ Q♦ K♦ 7♥ K♥ 5♠ 8♠ 9♠ Q♠ K♠
  P3: 6♣ Q♣ A♣ 3♦ J♦ 5♥ 6♥ 8♥ J♥ J♠ A♠
Current Trick:    P1:4♣ P2:9♣ P3:6♣
A valid play where player plays K♣:
Opponent Current Hands:
  P1: 4♣ 2♦ 5♦ 6♦ 7♦ 9♦ T♥ 2♠ 3♠ 4♠ T♠
  P2: 9♣ 4♦ Q♦ K♦ 7♥ K♥ 5♠ 8♠ 9♠ Q♠ K♠
  P3: 6♣ Q♣ A♣ 3♦ J♦ 5♥ 6♥ 8♥ J♥ J♠ A♠
Current Trick:    P1:4♣ P2:9♣ P3:6♣


[-1,
 -1,
 1,
 -1,
 1,
 -1,
 -1,
 -1,
 -1,
 4,
 -1,
 -1,
 4,
 0,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 0,
 -1,
 0,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 0,
 -1,
 -1,
 0,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 0,
 0,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1]

In [None]:
from trump_utils import print_graveyards

# Show the graveyard section of `tensor_revealed` (from whichever revelation cell ran last).
print_graveyards(tensor_revealed)

Graveyards (Opponent Card Knowledge):
Values: -1=has, 0=unknown, 1=had, 2=never had

Opponent 1:
C: C2:-1 C3: 2 C4: 2 C5: 2 C6: 2 C7:-1 C8: 2 C9: 2 CT: 2 CJ: 2 CQ: 2 CK:-1 CA: 2 
D: D2:-1 D3: 2 D4: 2 D5: 1 D6: 2 D7: 2 D8: 2 D9: 2 DT: 2 DJ: 2 DQ:-1 DK: 2 DA:-1 
H: H2: 2 H3:-1 H4: 2 H5: 2 H6: 2 H7:-1 H8:-1 H9: 2 HT: 2 HJ: 2 HQ: 2 HK: 2 HA: 2 
S: S2: 2 S3: 2 S4: 2 S5:-1 S6: 2 S7:-1 S8: 2 S9: 2 ST: 2 SJ:-1 SQ: 2 SK: 2 SA: 2 

Opponent 2:
C: C2: 2 C3: 2 C4: 2 C5: 2 C6: 2 C7: 2 C8: 2 C9:-1 CT: 2 CJ:-1 CQ: 2 CK: 2 CA: 2 
D: D2: 2 D3: 2 D4: 1 D5: 2 D6: 2 D7: 2 D8: 2 D9: 2 DT:-1 DJ: 2 DQ: 2 DK: 2 DA: 2 
H: H2: 2 H3: 2 H4: 2 H5:-1 H6:-1 H7: 2 H8: 2 H9: 2 HT: 2 HJ:-1 HQ:-1 HK: 2 HA: 2 
S: S2:-1 S3:-1 S4: 2 S5: 2 S6:-1 S7: 2 S8:-1 S9: 2 ST:-1 SJ: 2 SQ: 2 SK: 2 SA: 2 

Opponent 3:
C: C2: 2 C3: 2 C4:-1 C5:-1 C6: 2 C7: 2 C8: 2 C9: 2 CT:-1 CJ: 2 CQ:-1 CK: 2 CA:-1 
D: D2: 2 D3: 2 D4: 2 D5: 2 D6: 2 D7: 1 D8:-1 D9:-1 DT: 2 DJ: 2 DQ: 2 DK: 2 DA: 2 
H: H2: 2 H3: 2 H4: 2 H5: 2 H6: 2 H7: 2 H8: 2 H9:-1 HT:-1 

In [4]:
import jax
import jax.numpy as jnp
from trump_utils import data_augmentor
from trump_z3_bool_current_hand import analyze_hand_cards

# Apply one data augmentation (fixed PRNG key for reproducibility) to the acting
# player's info state and re-run the hand analysis on the augmented tensor.
# All dependencies are imported/recomputed here so the cell works on a fresh kernel.
acting_tensor = jnp.asarray(state.information_state_tensor(state.current_player()))
tensor_permuted = data_augmentor(acting_tensor, jax.random.PRNGKey(22))
analyze_hand_cards(tensor_permuted, debug=True)

Q♣ loses due to not following suit.
K♣ loses due to not following suit.
3♠ loses due to not following suit.
6♠ loses due to not following suit.
A valid play where player plays 5♦:
Opponent Current Hands:
  P1: J♣ J♦ Q♦ 2♥ 3♥ 8♥
  P2: 5♣ 7♣ T♣ 4♦ 7♦ Q♥
  P3: 3♣ 8♣ 9♦ K♦ A♦ 9♥
Current Trick:    P1:3♥ P2:Q♥ P3:9♥
A scenario where player loses with 5♦:
Opponent Current Hands:
  P1: 7♣ 8♣ 7♦ Q♦ K♦ A♦
  P2: 5♣ 9♦ J♦ 3♥ 8♥ Q♥
  P3: 3♣ T♣ J♣ 4♦ 2♥ 9♥
Current Trick:    P1:7♦ P2:Q♥ P3:9♥
A scenario where player wins with 5♦:
Opponent Current Hands:
  P1: 7♣ 8♣ 4♦ 7♦ K♦ A♦
  P2: 9♦ J♦ Q♦ 3♥ 8♥ Q♥
  P3: 3♣ 5♣ T♣ J♣ 2♥ 9♥
Current Trick:    P1:4♦ P2:Q♥ P3:9♥
A valid play where player plays T♦:
Opponent Current Hands:
  P1: 7♣ 8♣ 4♦ 7♦ K♦ A♦
  P2: 9♦ J♦ Q♦ 3♥ 8♥ Q♥
  P3: 3♣ 5♣ T♣ J♣ 2♥ 9♥
Current Trick:    P1:4♦ P2:Q♥ P3:9♥
A scenario where player loses with T♦:
Opponent Current Hands:
  P1: 7♣ 8♣ 4♦ Q♦ K♦ A♦
  P2: 7♦ 9♦ J♦ 3♥ 8♥ Q♥
  P3: 3♣ 5♣ T♣ J♣ 2♥ 9♥
Current Trick:    P1:Q♦ P2:Q♥ P3:9♥
A scenar

[-1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 1,
 1,
 -1,
 -1,
 -1,
 -1,
 3,
 -1,
 -1,
 -1,
 -1,
 3,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 1,
 -1,
 -1,
 1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1,
 -1]

In [12]:
from trump_utils import print_graveyards

# Graveyard view of the augmented tensor produced by the data_augmentor cell.
print_graveyards(np.array(tensor_permuted))

# Rebind `tensor` to the acting player's raw info state for the inspection cells below.
tensor = state.information_state_tensor(state.current_player())

Graveyards (Opponent Card Knowledge):
Values: -1=has, 0=unknown, 1=had, 2=never had

Opponent 1:
C: C2: 0 C3: 2 C4: 2 C5: 2 C6: 2 C7: 0 C8: 0 C9: 0 CT: 0 CJ: 0 CQ: 2 CK: 0 CA: 0 
D: D2: 2 D3: 2 D4: 2 D5: 0 D6: 1 D7: 2 D8: 0 D9: 0 DT: 2 DJ: 2 DQ: 0 DK: 0 DA: 1 
H: H2: 0 H3: 0 H4: 0 H5: 0 H6: 2 H7: 0 H8: 2 H9: 0 HT: 0 HJ: 0 HQ: 0 HK: 0 HA: 2 
S: S2: 2 S3: 0 S4: 0 S5:-1 S6: 2 S7: 0 S8: 2 S9: 0 ST: 0 SJ: 0 SQ: 0 SK: 2 SA: 0 

Opponent 2:
C: C2: 0 C3: 2 C4: 2 C5: 2 C6: 2 C7: 0 C8: 0 C9: 0 CT: 0 CJ: 0 CQ: 2 CK: 0 CA: 0 
D: D2: 2 D3: 2 D4: 1 D5: 0 D6: 2 D7: 2 D8: 0 D9: 0 DT: 1 DJ: 2 DQ: 0 DK: 0 DA: 2 
H: H2: 0 H3: 0 H4: 0 H5: 0 H6: 2 H7: 0 H8: 2 H9: 0 HT: 0 HJ: 0 HQ: 0 HK: 0 HA: 2 
S: S2: 2 S3: 0 S4: 0 S5: 2 S6: 2 S7: 0 S8: 2 S9: 0 ST: 0 SJ: 0 SQ: 0 SK: 2 SA: 0 

Opponent 3:
C: C2: 0 C3: 2 C4: 2 C5: 2 C6: 2 C7: 0 C8: 0 C9: 0 CT: 0 CJ: 0 CQ: 2 CK: 0 CA: 0 
D: D2: 2 D3: 1 D4: 2 D5: 0 D6: 2 D7: 2 D8: 0 D9: 0 DT: 2 DJ: 1 DQ: 0 DK: 0 DA: 2 
H: H2: 0 H3: 0 H4: 0 H5: 0 H6: 2 H7: 0 H8: 2 H9: 0 HT: 0 

In [None]:
# Dump the raw information-state tensor currently bound to `tensor`.
print(tensor)

[0.0, 1.0, 1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 1.0, 0.0, 3.0, 1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,

In [None]:
# Sanity check: should equal game.information_state_tensor_shape()[0] (588 for "trump").
len(tensor)

588

## Players from different iterations playing against each other

In [None]:
import os
import numpy as np
import jax
jax.config.update('jax_platform_name', 'cpu')
import jax.numpy as jnp
import flax
import flax.linen as nn
import pyspiel
from typing import List, Dict, Any, Callable
from collections import defaultdict
from tqdm.notebook import tqdm

def load_params(save_dir_nets, player, phase, iteration):
    """Load the saved average-policy parameters for one (player, phase, iteration).

    The checkpoint is read from
    ``<save_dir_nets>/player<player>/phase<phase>/pi_data/pi_params_iter<iteration>.msgpack``
    and restored with ``flax.serialization.from_bytes`` (no target structure).

    Returns:
        The ``'params'`` entry of the restored state dict.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        KeyError: If the restored state dict has no ``'params'`` key.
    """
    checkpoint_path = os.path.join(
        save_dir_nets,
        f"player{player}",
        f"phase{phase}",
        "pi_data",
        f"pi_params_iter{iteration}.msgpack",
    )
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Parameter file not found: {checkpoint_path}")

    with open(checkpoint_path, 'rb') as checkpoint_file:
        restored = flax.serialization.from_bytes(None, checkpoint_file.read())

    if 'params' in restored:
        return restored['params']
    raise KeyError(f"'params' key not found in loaded state_dict from {checkpoint_path}")

def evaluate_players(
    game_name: str,
    save_dir_nets: str,
    iterations_per_player: List[int],
    num_eval_games: int,
    pi_models: List[nn.Module],
    info_state_tensor_transformers: List[Callable[[jax.Array], jax.Array]],
    action_transformers: List[Callable[[jax.Array], jax.Array]],
    phase_classifier_fn: Callable[[jax.Array], int],
    uniform: bool = False,
    seed: int = 42
) -> Dict[int, float]:
    """
    Evaluate saved average-policy checkpoints by having them play against each other.

    Seat ``p`` always plays the checkpoint saved at iteration ``iterations_per_player[p]``.
    All sampling (chance outcomes and player actions) uses a local
    ``np.random.RandomState(seed)``, so a seed reproduces the same games without
    touching NumPy's global random state.

    Args:
        game_name: Name of the OpenSpiel game (e.g., 'kuhn_poker', 'trump').
        save_dir_nets: Directory where network parameters are saved.
        iterations_per_player: Checkpoint iteration for each seat [seat0_iter, seat1_iter, ...].
        num_eval_games: Number of games to play (must be >= 1).
        pi_models: Policy network model for each phase.
        info_state_tensor_transformers: Info-state transformer for each phase.
        action_transformers: Per-phase function mapping the network output to the
            global action space.
        phase_classifier_fn: Returns the phase index for a full info-state tensor.
        uniform: If True, every seat loads its checkpoint from the ``player0``
            directory (one network shared by all seats during training), still at
            that seat's own iteration. If False, seat ``p`` loads from ``player{p}``.
        seed: Random seed for reproducibility.

    Returns:
        Dictionary mapping player_id to that seat's average return over all games.

    Raises:
        ValueError: If ``num_eval_games < 1``, if ``iterations_per_player`` does not have
            one entry per player, or if the per-phase lists differ in length.
        FileNotFoundError, KeyError: Propagated from ``load_params``.
    """
    if num_eval_games < 1:
        # Guards the division when averaging scores at the end.
        raise ValueError(f"num_eval_games must be at least 1, got {num_eval_games}")

    game = pyspiel.load_game(game_name)
    num_players = game.num_players()
    global_num_actions = game.num_distinct_actions()
    rng = np.random.RandomState(seed)

    num_phases = len(pi_models)
    if len(info_state_tensor_transformers) != num_phases or len(action_transformers) != num_phases:
        raise ValueError(
            f"Expected one info-state transformer and one action transformer per phase ({num_phases}), "
            f"got {len(info_state_tensor_transformers)} and {len(action_transformers)}"
        )

    if len(iterations_per_player) != num_players:
        raise ValueError(f"Number of iterations ({len(iterations_per_player)}) must match number of players ({num_players})")

    # Load parameters for every seat and phase. With uniform=True the files come from
    # player0's directory, but each seat keeps its OWN iteration. Identical
    # (source dir, phase, iteration) triples are loaded once and shared.
    loaded: Dict[tuple, Any] = {}
    player_params = []
    for player in range(num_players):
        source_player = 0 if uniform else player
        iteration = iterations_per_player[player]
        per_phase = []
        for phase in range(num_phases):
            key = (source_player, phase, iteration)
            if key not in loaded:
                loaded[key] = load_params(save_dir_nets, source_player, phase, iteration)
            per_phase.append(loaded[key])
        player_params.append(per_phase)

    # One jitted inference function per phase. Parameters are passed as arguments, so
    # all seats share one compiled function instead of compiling an identical copy per seat.
    policy_fns = [
        _get_jitted_avg_policy(pi_models[phase], action_transformers[phase])
        for phase in range(num_phases)
    ]

    total_scores = defaultdict(float)

    for _game_round in tqdm(range(num_eval_games)):
        state = game.new_initial_state()

        while not state.is_terminal():
            if state.is_chance_node():
                chance_actions, chance_probs = zip(*state.chance_outcomes())
                chance_probs_np = np.array(chance_probs, dtype=np.float64)
                chance_probs_np /= np.sum(chance_probs_np)
                sampled_action = rng.choice(chance_actions, p=chance_probs_np)
            else:
                active_player = state.current_player()

                # Classify the phase from the full info state, then transform it for that phase.
                full_info_state_np = np.array(state.information_state_tensor(active_player), dtype=np.float32)
                phase = int(phase_classifier_fn(full_info_state_np))

                # Add a batch dimension of 1 for network inference.
                info_state_batch = jnp.expand_dims(
                    info_state_tensor_transformers[phase](full_info_state_np), axis=0
                )
                legal_mask_batch = jnp.expand_dims(
                    np.array(state.legal_actions_mask(active_player), dtype=bool), axis=0
                )

                # Index params by the ACTIVE seat even when uniform=True: using seat 0
                # here would make every seat play seat 0's checkpoint and silently
                # ignore iterations_per_player.
                policy_probs = policy_fns[phase](
                    player_params[active_player][phase], info_state_batch, legal_mask_batch
                )

                # Renormalise (in float64) before sampling from the policy.
                policy_probs_np = np.array(policy_probs, dtype=np.float64)
                policy_probs_np /= np.sum(policy_probs_np)
                sampled_action = rng.choice(global_num_actions, p=policy_probs_np)

            # rng.choice returns a NumPy integer; pass a plain int to pyspiel.
            state = state.child(int(sampled_action))

        for player, score in enumerate(state.returns()):
            total_scores[player] += score

    return {player_idx: total_scores[player_idx] / num_eval_games
            for player_idx in range(num_players)}

def _get_jitted_avg_policy(pi_model_instance, action_transformer_fn):
    """Create JIT compiled policy inference function."""
    @jax.jit
    def get_policy(params_avg_policy: Any, info_state_transformed: jax.Array, legal_actions_mask_global: jax.Array):
        masked_logits_net_output = pi_model_instance.apply(
            {'params': params_avg_policy}, info_state_transformed, legal_actions_mask_global
        )
        masked_logits_global = action_transformer_fn(masked_logits_net_output)
        avg_policy_probs_global = jax.nn.softmax(masked_logits_global, axis=-1)
        return jnp.squeeze(avg_policy_probs_global, axis=0)
    return get_policy

In [None]:
from trump_utils import (
    PolicyNetworkWrapper,
    TrumpBiddingPolicyNet, TrumpAD_PolicyNet, TrumpPlayPolicyNet,
    action_transformers, #info_state_tensor_transformers,
    trump_phase_classifier
)

# --- Tensor Specification, Slicing Helper, and Phase Detection Indices ---
TENSOR_COMPONENT_SPEC = [
    ("Hand", 52), 
    ("BidCards", 4 * 5),       # 20
    ("TrumpSuit", 4),          
    ("RoundBidStatus", 1),     
    ("History", 13 * (4 + 4 * 5)),   # 312
    ("OpponentGraveyard", 3 * 52), # 156
    ("ANTC", 4),               
    ("BreakOccurred", 1),      
    ("CurrentTrickCards", 4 * 5), # 20
    ("CurrentTrickLeader", 4), 
    ("CurrentTrickTrumpUncertainty", 13),    # NEW ITEM #11
    ("CurrentTrickNumber", 1)  
]

GLOBAL_NUM_ACTIONS = 52

_tensor_component_slices: Dict[str, slice] = {}
_current_offset = 0
for _name, _size in TENSOR_COMPONENT_SPEC:
    _tensor_component_slices[_name] = slice(_current_offset, _current_offset + _size)
    _current_offset += _size

# Define indices for phase classification using the slices
INDEX_ROUND_BID_STATUS = _tensor_component_slices["RoundBidStatus"].start
INDEX_FIRST_BID_CARD_FEATURE = _tensor_component_slices["BidCards"].start
INDEX_HAND_START = _tensor_component_slices["Hand"].start
INDEX_HAND_END = _tensor_component_slices["Hand"].stop
INDEX_BID_CARDS_END = _tensor_component_slices["BidCards"].stop

def bid_transformer(infostate_tensor: jax.Array) -> jax.Array:
    return infostate_tensor[..., INDEX_HAND_START:INDEX_HAND_END]

def ad_transformer(infostate_tensor: jax.Array) -> jax.Array:
    return infostate_tensor[..., INDEX_HAND_START:INDEX_BID_CARDS_END]

def play_transformer_pi(infostate_tensor: jax.Array) -> jax.Array:
    return infostate_tensor
info_state_tensor_transformers = [bid_transformer, ad_transformer, play_transformer_pi]

# One policy network per phase, in phase-index order: 0 = bidding, 1 = AD, 2 = play.
pi_models = [
    PolicyNetworkWrapper(phase_net=phase_net_cls())
    for phase_net_cls in (TrumpBiddingPolicyNet, TrumpAD_PolicyNet, TrumpPlayPolicyNet)
]

# Checkpoint iteration per seat: seats 0-2 -> 300, seat 3 -> 700.
evaluation_scores = evaluate_players(
    game_name="trump",
    save_dir_nets="cfvfp_nets_z3-q-only",
    iterations_per_player=[300, 300, 300, 700],
    uniform=True,
    num_eval_games=1000,
    pi_models=pi_models,
    info_state_tensor_transformers=info_state_tensor_transformers,
    action_transformers=action_transformers,
    phase_classifier_fn=trump_phase_classifier,
    seed=256,
)
evaluation_scores

  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: 