In [96]:
import ray
from ray.rllib.env import PettingZooEnv
from ray.tune.registry import register_env
from pettingzoo.classic import texas_holdem_v4

import gymnasium as gym

RAISE_ACTION = 1 
RAISE_PENALTY = 0.5
# In here penalize the reward by making it expensive a bit
class RaisePenaltyWrapper(PettingZooEnv):
    def __init__(self, env, positive_penalize=0.5, negative_penalize=1.5):
        super().__init__(env)
        self.positive_penalize = positive_penalize
        self.negative_penalize = negative_penalize

    def step(self, action_dict):
        obs, rewards, terminations, truncations, infos = super().step(action_dict)

        new_rewards = {}

        for agent_id, r in rewards.items():
            a = action_dict.get(agent_id, None)

            # 1) small penalty per raise
            if a == RAISE_ACTION:
                r -= 0.1

            # 2) asymmetric win/loss scaling at terminal
            done = terminations.get(agent_id, False) or truncations.get(agent_id, False)
            if done:
                if r > 0:
                    r = self.positive_penalize * r
                elif r < 0:
                    r = self.negative_penalize * r

            new_rewards[agent_id] = r
        return obs, new_rewards, terminations, truncations, infos


def env_creator(config=None):
    # raw PettingZoo AEC env
    env = texas_holdem_v4.env()
    env.reset(seed=(config or {}).get("seed", None))
    # RLlib wrapper so it looks like a MultiAgentEnv
    #env = PettingZooEnv(env)
    wrapper_env = RaisePenaltyWrapper(
        env,
        positive_penalize=0.75,
        negative_penalize=1.5,
    )
    return wrapper_env

env_name = "texas_holdem_v4"
register_env(env_name, lambda config: env_creator(config))

# Create one test env to grab spaces
test_env = env_creator()
obs_space = test_env.observation_space
act_space = test_env.action_space

In [97]:
test_env = texas_holdem_v4.env(render_mode="human")
test_env.reset()
obs, reward, termination, truncation, info = test_env.last()

print("obs['observation'].shape:", obs["observation"].shape)
print("obs['action_mask']:", obs["action_mask"])

obs['observation'].shape: (72,)
obs['action_mask']: [1 1 1 0]


In [6]:
from pettingzoo.classic import texas_holdem_v4

env = texas_holdem_v4.env(render_mode="human")
env.reset(seed=42)

for agent in env.agent_iter():
    observation, reward, termination, truncation, info = env.last()

    if termination or truncation:
        action = None
    else:
        mask = observation["action_mask"]
        # this is where you would insert your policy
        action = env.action_space(agent).sample(mask)

    env.step(action)
env.close()


[2025-11-25 21:54:38,726 E 6405 4628719] core_worker_process.cc:837: Failed to establish connection to the metrics exporter agent. Metrics will not be exported. Exporter agent status: RpcError: Running out of retries to initialize the metrics agent. rpc_code: 14


In [3]:
obs_space

Dict('player_0': Dict('action_mask': Box(0, 1, (4,), int8), 'observation': Box(0.0, 1.0, (72,), float32)), 'player_1': Dict('action_mask': Box(0, 1, (4,), int8), 'observation': Box(0.0, 1.0, (72,), float32)))

In [42]:
act_space

Dict('player_0': Discrete(4), 'player_1': Discrete(4))

In [None]:
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.utils.framework import try_import_torch
import gymnasium as gym

torch, nn = try_import_torch()

import numpy as np
import gymnasium as gym

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.models import ModelCatalog
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.core.rl_module.rl_module import RLModule


torch, nn = try_import_torch()

from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

class MaskedDQNTorchRLModule(TorchRLModule):
    def __init__(self, observation_space, action_space, model_config, **kwargs):
        super().__init__(
            observation_space=observation_space,
            action_space=action_space,
            model_config=model_config,
            **kwargs,
        )

        # obs_space is a Dict with keys "observation" and "action_mask"
        obs_box = observation_space["observation"]
        self.mask_dim = action_space.n

        # Use obs_box directly, no manual flatten-with-mask trick
        obs_low = obs_box.low
        obs_high = obs_box.high

        self.base_model = FullyConnectedNetwork(
            obs_space=obs_box,
            action_space=action_space,
            num_outputs=self.mask_dim,
            model_config=model_config,
            name="masked_dqn_fcnet",
        )

    def _forward_inference(self, batch, **kwargs):
        return self._forward_masked(batch, explore=False)

    def _forward_exploration(self, batch, **kwargs):
        return self._forward_masked(batch, explore=True)

    def _forward_train(self, batch, **kwargs):
        # Training uses Q-values only
        out = self._forward_masked(batch, explore=False)
        return {"q_values": out["q_values"]}

    # def _forward_single(self, batch):
    #     out = {}
    #     for agent_id, obs in batch.items():
    #         out[agent_id] = self._forward_single(obs)
    #     return out

    def _to_tensor(self, x):
        if isinstance(x, torch.Tensor):
            t = x
        else:
            t = torch.as_tensor(x, dtype=torch.float32)

        if t.dim() == 1:
            t = t.unsqueeze(0)

        return t.to(next(self.parameters()).device)

    def _forward_masked(self, batch, explore: bool):
        # batch["obs"] is Dict-space: {"observation": ..., "action_mask": ...}
        # out = {}
        # for agent_id, obs in batch.items():
        #     out[agent_id] = self._forward_single(obs)
        obs_dict = batch["obs"]
        #print(obs_dict)

        obs = obs_dict["observation"]#.float()
        raw_mask = obs_dict["action_mask"]
        mask = self._to_tensor(raw_mask)

        obs = self._to_tensor(obs)    # [B, obs_dim]

        # Q-values from base FC network
        logits, _ = self.base_model({"obs": obs}, [], None)

        # Convert mask to 0/1
        legal = (mask > 0.5).float()

        # Mask illegal actions with -inf
        inf = torch.finfo(logits.dtype).min
        masked_q = logits + (1.0 - legal) * inf
        print(masked_q)

        if explore:
            # epsilon-greedy sampling
            eps = 0.05  # you can tune
            rand = torch.rand(masked_q.shape[0], device=masked_q.device)
            greedy = torch.argmax(masked_q, dim=-1)

            # random legal action
            legal_indices = legal.nonzero(as_tuple=False)
            random_actions = torch.zeros_like(greedy)
            for i in range(mask.shape[0]):
                valid = torch.where(mask[i] > 0.5)[0]
                random_actions[i] = valid[torch.randint(len(valid), (1,))]

            actions = torch.where(rand < eps, random_actions, greedy)

        else:
            # greedy action
            actions = torch.argmax(masked_q, dim=-1)

        # RLlib DQN Learner EXPECTS this structure:
        return {
            "q_values": masked_q,
            "action_distribution_inputs": masked_q,  # DQN uses Q-values as dist inputs
            "actions": actions,
            "actions_for_env": actions,  # connectors need this
        }




In [95]:
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.ppo import PPOConfig
import gymnasium as gym
from ray.rllib.policy.policy import PolicySpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec


config = DQNConfig()
config.api_stack(  # <-- turn ON the new stack explicitly
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
config = (
    config
    .environment(env=env_name)
    .framework("torch")
    .env_runners(num_env_runners=0, create_env_on_local_worker=True)
    .rl_module(
        rl_module_spec= RLModuleSpec(
        module_class=MaskedDQNTorchRLModule, 
        model_config={}, 
        )
    )
    .training(
        replay_buffer_config = {
            "type": "MultiAgentPrioritizedReplayBuffer",
            "capacity": 50_000,
            "prioritized_replay_alpha": 0.6,
            "prioritized_replay_beta": 0.4,
            "prioritized_replay_eps": 1e-6,
        },
        gamma=0.99,
        lr=1e-3,
        train_batch_size=64,
        minibatch_size=32,
        #entropy_coeff=0.0001,  #penalize the deterministic policy
        #entropy_coeff_schedule=None,
        # model={
        #     #"custom_model": FlatActionMaskingModel,
        #     "use_lstm": True,
        #     "fcnet_hiddens": [512, 512],
        #     "fcnet_activation": "tanh",
        #     "_disable_preprocessor_api": True,
        # }
        
    )
    .multi_agent(
        # policies={
        #     "shared_policy": PolicySpec(
        #         observation_space=obs_space["player_0"],
        #         action_space=act_space["player_0"],
        #         config={},
        #     ),
        # },
        # policy_mapping_fn=lambda agent_id, *a, **k: "shared_policy",

        policies={
            "player_0": PolicySpec(observation_space=obs_space["player_0"],
                action_space=act_space['player_0'],
                config={},),
            "player_1": PolicySpec(observation_space=obs_space["player_0"],
                action_space=act_space['player_1'],
                config={},),
        },
        # in this env the agent_ids are literally "player_0"/"player_1"
        policy_mapping_fn=lambda agent_id, *a, **k: agent_id,
    )
    #.experimental(_validate_config=False)
)


algo = config.build()


`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
  logger.warn(f"Overriding environment {new_spec.id} already in registry.")


obs['action_mask'] contains a mask of all legal moves that can be chosen.




In [46]:
obs["player_0"]

{'observation': array([0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.,
        0., 0., 0., 0.], dtype=float32),
 'action_mask': array([1, 1, 1, 0], dtype=int8)}

In [84]:
obs, info = env.reset()

player0_input = {"obs": obs["player_0"]}

output = algo.get_module("player_0").forward_exploration(player0_input, explore = False)
print(output)



tensor([[-3.7183e-03, -2.4676e-03, -2.7701e-03, -3.4028e+38]],
       grad_fn=<AddBackward0>)
{'q_values': tensor([[-3.7183e-03, -2.4676e-03, -2.7701e-03, -3.4028e+38]],
       grad_fn=<AddBackward0>), 'action_distribution_inputs': tensor([[-3.7183e-03, -2.4676e-03, -2.7701e-03, -3.4028e+38]],
       grad_fn=<AddBackward0>), 'actions': tensor([1]), 'actions_for_env': tensor([1])}


In [98]:
import collections

env = env_creator({})

def select_action(module, agent_obs, explore=False):
    # Build batch of size 1 for RLModule
    batch = {
        "obs": {
            "observation": torch.tensor(agent_obs["observation"])[None, :].float(),
            "action_mask": torch.tensor(agent_obs["action_mask"])[None, :].float(),
        }
    }

    with torch.no_grad():
        out = (
            module.forward_exploration(batch)
            if explore
            else module.forward_inference(batch)
        )

    # DQN RLModule emits q_values
    qvals = out["q_values"]        # shape [1, num_actions]
    return int(torch.argmax(qvals, dim=-1)[0].item())


def eval_one_episode(render=False):
    obs, info = env.reset()
    done = False
    modules = {
        "player_0": algo.get_module("player_0"),
        "player_1": algo.get_module("player_1"),
    }
    # Start empty; we'll add agents as we see them
    ep_reward = collections.defaultdict(float)
    length = 0

    while not done:
        if render:
            env.render()

        actions = {}
        print(obs.items())
        for agent_id, agent_obs in obs.items():
            action = select_action(modules[agent_id], agent_obs)
            print(action)
            actions[agent_id] = int(action)  # Discrete(4) -> int

        obs, rewards, terminated, truncated, infos = env.step(actions)
        length += 1

        # accumulate rewards; skip special keys like "__all__"
        for aid, r in rewards.items():
            if aid == "__all__":
                continue
            ep_reward[aid] += r

        # episode done: use "__all__" if provided, else fall back
        if "__all__" in terminated:
            done = terminated["__all__"] or truncated.get("__all__", False)
        else:
            done = all(terminated.values()) or all(truncated.values())

    return dict(ep_reward), length

# Run ONE eval episode
ep_rew, ep_len = eval_one_episode(render=False)
print("Episode length:", ep_len)
print("Episode rewards:", ep_rew)


dict_items([('player_1', {'observation': array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.,
       0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.,
       0., 0., 0., 0.], dtype=float32), 'action_mask': array([1, 1, 1, 0], dtype=int8)})])
tensor([[-1.5070e-03, -1.1657e-03,  3.8397e-03, -3.4028e+38]])
2
Episode length: 1
Episode rewards: {'player_0': np.float64(0.375), 'player_1': np.float64(-0.75)}


In [99]:
rewards = []
lengths = []
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

import torch
from IPython.display import display, clear_output
import matplotlib.pyplot as plt

plot_handle = display(None, display_id=True)

def plot_metrics(rew_player0, rew_player1, lengths, curr_iter=0, sum_iter=0, shared=False):
    fig = plt.figure(1, figsize=(16, 8))
    plt.clf()

    # Subplots
    ax1 = fig.add_subplot(1, 2, 1)
    ax2 = fig.add_subplot(1, 2, 2)

    # ---- Reward subplot ----
    if shared:
        ax1.set_title(f"Mean Rewards {curr_iter}/{sum_iter}")
        ax1.set_xlabel("Eval Interval")
        ax1.set_ylabel("Mean Reward")

        ax1.plot(rew_player0, label="player_0", color="blue")
        # ax1.plot(rew_player1, label="player_1", color="red")
        ax1.legend()
    else:
        ax1.set_title(f"Mean Rewards (2 Agents) {curr_iter}/{sum_iter}")
        ax1.set_xlabel("Eval Interval")
        ax1.set_ylabel("Mean Reward")

        ax1.plot(rew_player0, label="player_0", color="blue")
        ax1.plot(rew_player1, label="player_1", color="red")
        ax1.legend()

    # ---- Episode length subplot ----
    ax2.set_title(f"Mean Episode Length {curr_iter}/{sum_iter}")
    ax2.set_xlabel("Eval Interval")
    ax2.set_ylabel("Episode Length")
    ax2.plot(lengths, label="episode length", color="green")

    # Display
    display(fig)
    #plot_handle.update(fig)
    clear_output(wait=True)


None

In [100]:
a1_reward = []
a2_reward = []

In [101]:
TRAINING_ITERATIONS = 1500
import pickle
EVAL_INTERVAL = 50
from tqdm import tqdm

with tqdm(total=TRAINING_ITERATIONS, desc="Training", unit="iter") as pbar:
    for i in range(TRAINING_ITERATIONS):
        algo.train()

        if (i+1) % EVAL_INTERVAL == 0:
            # metrics = algo.evaluate()['evaluation']
            # rewards.append(metrics['episode_reward_mean'])
            # lengths.append(metrics['episode_len_mean'])
            metrics = algo.evaluate()
            metrics = metrics["env_runners"]

            # r_mean = metrics["episode_return_mean"]
            a1_reward.append(metrics["policy_reward_mean"]["player_0"])
            a2_reward.append(metrics["policy_reward_mean"]["player_1"])
            l_mean = metrics["episode_len_mean"]

            # rewards.append(r_mean)
            lengths.append(l_mean)
            plot_metrics(a1_reward, a2_reward, lengths, i, TRAINING_ITERATIONS)
            algo.save("./checkpoints_3_DQN")
        #pbar.update(1)

Training:   0%|          | 0/1500 [00:00<?, ?iter/s]

tensor([[ 5.7337e-04, -3.9633e-03, -1.7375e-03, -3.4028e+38]],
       grad_fn=<AddBackward0>)





AttributeError: 'list' object has no attribute 'as_multi_agent'

In [112]:
import numpy as np

SUITS = ["Spades", "Hearts", "Diamonds", "Clubs"]
RANKS = ["A", "2", "3", "4", "5", "6", "7", "8", "9", "10", "J", "Q", "K"]

def decode_visible_cards(obs_dict):
    """Return list of human-readable cards visible to this player."""
    vec = obs_dict["observation"]
    # ensure it's a flat array
    vec = np.array(vec).astype(int)

    cards_bits = vec[:52]
    cards = []
    for i, bit in enumerate(cards_bits):
        if bit == 1:
            suit = SUITS[i // 13]
            rank = RANKS[i % 13]
            cards.append(f"{rank} of {suit}")
    return cards
ACTION_NAMES = {
    0: "Call",
    1: "Raise",
    2: "Fold",
    3: "Check",
}


In [118]:
import imageio
import torch
from PIL import Image

rmode = "rgb_array"

env = texas_holdem_v4.env(render_mode=rmode)
env.reset()
frames = []
for agent in env.agent_iter():
    print("-------------------")
    print(agent)
    obs, reward, termination, truncation, info = env.last()
    # print(observation)
    visible_cards = decode_visible_cards(obs)
    print("Visible cards:", visible_cards)
    print("Reward so far:", reward)
    if termination or truncation:
        action = None
        print("-------------------")
    else:
        rllib_obs = {
                "action_mask": obs["action_mask"],
                "observation": obs["observation"],
            }

        # policy_id MUST match the agent name ("player_0" / "player_1")
        action = algo.compute_single_action(
                rllib_obs,
                policy_id= agent,#"shared_policy",
                explore=False,      # deterministic eval
        )
        action_name = ACTION_NAMES.get(int(action), f"Unknown({action})")

        print("Chosen action:", action_name)
        print("-------------------")
    if rmode == "rgb_array":
        frame = env.render()
        if frame is not None:
            frames.append(frame)
    env.step(action)

env.close()


if rmode != "human":
   imageio.mimsave("ppo_poke.gif", frames, fps=2)
   print("Saved to ppo_poke.gif")

-------------------
player_1
Visible cards: ['A of Hearts', '5 of Clubs']
Reward so far: 0
Chosen action: Call
-------------------
-------------------
player_0
Visible cards: ['8 of Clubs', '9 of Clubs']
Reward so far: 0
Chosen action: Raise
-------------------
-------------------
player_1
Visible cards: ['A of Hearts', '5 of Clubs']
Reward so far: 0
Chosen action: Raise
-------------------
-------------------
player_0
Visible cards: ['8 of Clubs', '9 of Clubs']
Reward so far: 0
Chosen action: Raise
-------------------
-------------------
player_1
Visible cards: ['A of Hearts', '5 of Clubs']
Reward so far: 0
Chosen action: Raise
-------------------
-------------------
player_0
Visible cards: ['8 of Clubs', '9 of Clubs']
Reward so far: 0
Chosen action: Call
-------------------
-------------------
player_1
Visible cards: ['A of Hearts', '2 of Diamonds', '10 of Diamonds', 'A of Clubs', '5 of Clubs']
Reward so far: 0
Chosen action: Raise
-------------------
-------------------
player_0
Vis