In [5]:
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from pettingzoo.classic import texas_holdem_v4

class SingleAgentWrapper(gym.Env):
    """Single-agent Gymnasium view of PettingZoo's Texas Hold'em (``texas_holdem_v4``).

    The learner always plays the seat ``self.agent`` (the first entry of the
    PettingZoo env's ``agents`` list). Every other seat is driven by a policy
    that picks uniformly among its legal actions, sampled from ``self.np_random``
    so a seeded ``reset`` makes episodes reproducible.

    PettingZoo's ``agent_selection`` is not guaranteed to alternate between
    seats (hold'em's acting order changes between betting rounds), so after
    every learner move the opponents are stepped until control genuinely
    returns to ``self.agent`` or the hand ends. Observations, rewards and
    done-flags are always read for ``self.agent``, never for whichever seat
    happens to be selected.
    """

    metadata = {"render_modes": ["human"]}

    def __init__(self):
        super().__init__()
        self.env = texas_holdem_v4.env()
        self.env.reset()

        # The learner's seat; fixed for the lifetime of the wrapper.
        self.agent = self.env.agents[0]

        # Sample one observation only to derive shapes/dtypes for the Gymnasium spaces.
        obs_dict = self.env.observe(self.agent)
        obs = obs_dict["observation"]
        action_mask = obs_dict["action_mask"]

        self.observation_space = spaces.Box(
            low=-np.inf,
            high=np.inf,
            shape=obs.shape,
            dtype=obs.dtype,
        )
        # One discrete action per entry of the action mask.
        self.action_space = spaces.Discrete(len(action_mask))

    def reset(self, seed=None, options=None):
        """Start a new hand and advance to the learner's first decision.

        Hands that end before the learner gets to act (e.g. the opponent's
        opening action finishes the hand) are skipped, because Gymnasium's
        ``reset`` must return an observation the agent can act on.

        Returns:
            (observation, info) for ``self.agent``.
        """
        # Seeds self.np_random, which drives the opponent and the illegal-action fallback.
        super().reset(seed=seed)
        self.env.reset(seed=seed)
        self._play_opponents()
        while self._hand_over():
            # Unseeded so the PettingZoo env continues its current RNG stream.
            self.env.reset()
            self._play_opponents()
        return self._observe_agent(), {}

    def step(self, action):
        """Apply the learner's action, let the opponents respond, and report the outcome.

        Args:
            action: index into the action mask. If it is currently illegal, a
                uniformly random legal action is played instead, because plain
                A2C has no action masking.

        Returns:
            (observation, reward, terminated, truncated, info), all for ``self.agent``.
        """
        action = int(action)

        # Force a legal action: replace an illegal choice with a random legal one.
        if self.last_mask[action] == 0:
            legal_actions = np.flatnonzero(self.last_mask)
            action = int(self.np_random.choice(legal_actions))

        self.env.step(action)
        self._play_opponents()

        obs = self._observe_agent()
        reward = float(self.env.rewards[self.agent])
        terminated = bool(self.env.terminations[self.agent])
        truncated = bool(self.env.truncations[self.agent])
        info = dict(self.env.infos[self.agent])
        return obs, reward, terminated, truncated, info

    def _hand_over(self):
        """True once the learner's seat is terminated or truncated."""
        return bool(
            self.env.terminations[self.agent] or self.env.truncations[self.agent]
        )

    def _play_opponents(self):
        """Step random legal actions for other seats until it is the learner's turn or the hand ends."""
        while self.env.agent_selection != self.agent and not self._hand_over():
            opp_mask = self.env.observe(self.env.agent_selection)["action_mask"]
            opp_action = int(self.np_random.choice(np.flatnonzero(opp_mask)))
            self.env.step(opp_action)

    def _observe_agent(self):
        """Return the learner's observation vector and cache its action mask for ``step``."""
        obs_dict = self.env.observe(self.agent)
        self.last_mask = np.asarray(obs_dict["action_mask"])
        return obs_dict["observation"]


In [6]:
from stable_baselines3.common.env_checker import check_env

# Validate the wrapper against the Gymnasium API (spaces, reset/step signatures)
# before handing it to Stable-Baselines3.
env = SingleAgentWrapper()
check_env(env=env, warn=True)


In [None]:
from stable_baselines3 import A2C

SEED = 0  # passed to A2C so training runs are reproducible
TOTAL_TIMESTEPS = 200_000  # 100k-200k steps is a reasonable first run

env = SingleAgentWrapper()

model = A2C(
    "MlpPolicy",
    env,
    verbose=1,
    learning_rate=7e-4,
    gamma=0.99,
    seed=SEED,
)

model.learn(total_timesteps=TOTAL_TIMESTEPS)
model.save("a2c_poker_model")


Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 1.93     |
|    ep_rew_mean        | -0.035   |
| time/                 |          |
|    fps                | 804      |
|    iterations         | 100      |
|    time_elapsed       | 0        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -1.38    |
|    explained_variance | -0.19    |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | 3.89     |
|    value_loss         | 10.1     |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 1.7      |
|    ep_rew_mean        | 0.015    |
| time/                 |          |
|    fps                | 799      |
|    iterations         | 200      |
|    time_elapsed 

### Evaluate mean episode reward

In [None]:
import numpy as np

def evaluate(model, episodes=50, env=None, deterministic=False, seed=None):
    """Roll out ``model`` for ``episodes`` full episodes and summarise the returns.

    Parameters
    ----------
    model : object with ``predict(obs, deterministic=...)`` returning ``(action, state)``
        e.g. a trained Stable-Baselines3 model.
    episodes : int
        Number of episodes to play; must be at least 1.
    env : Gymnasium-style env, optional
        Environment to evaluate in. Defaults to a fresh ``SingleAgentWrapper()``.
    deterministic : bool
        Forwarded to ``model.predict``. Defaults to ``False`` (the previous behaviour).
    seed : int, optional
        Seed for the first ``env.reset`` only; later resets are unseeded.

    Returns
    -------
    (float, list[float])
        Mean total reward per episode, and the list of per-episode totals.

    Raises
    ------
    ValueError
        If ``episodes`` < 1 (the mean over zero episodes is undefined).
    """
    if episodes < 1:
        raise ValueError(f"episodes must be >= 1, got {episodes}")
    if env is None:
        env = SingleAgentWrapper()

    rewards = []
    for episode in range(episodes):
        # Seed only the first reset; re-seeding every reset would make every episode identical.
        obs, _ = env.reset(seed=seed if episode == 0 else None)
        total_reward = 0.0
        done = False
        while not done:
            action, _ = model.predict(obs, deterministic=deterministic)
            obs, reward, terminated, truncated, _ = env.step(action)
            total_reward += float(reward)
            # Gymnasium episodes end on termination OR truncation.
            done = terminated or truncated
        rewards.append(total_reward)
    return float(np.mean(rewards)), rewards

mean_reward, reward_list = evaluate(model=model, episodes=50)
# Last expression is rendered: mean total reward per hand over the evaluation episodes.
mean_reward
