In [None]:
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from pettingzoo.classic import texas_holdem_v4
from stable_baselines3 import A2C
from stable_baselines3.common.env_checker import check_env

import matplotlib.pyplot as plt
import pickle
import os
import imageio


In [None]:
class SingleAgentWrapper(gym.Env):
    """Convert PettingZoo Texas Hold'em into a single-agent, SB3-friendly env.

    The SB3 policy always plays the seat ``self.agent`` (``env.agents[0]``
    after the initial reset). Every other seat is played by a uniformly random
    legal-action opponent. Opponent moves are advanced inside ``reset``/``step``
    so that each returned observation is a decision point for ``self.agent``.
    Rewards and termination flags are reported for ``self.agent`` specifically.

    Illegal policy actions (mask entry 0) are replaced by a uniformly random
    legal action; the policy itself is not action-masked. All random choices
    use ``self.np_random``, which ``reset(seed=...)`` seeds.
    """

    metadata = {"render_modes": ["rgb_array"]}

    def __init__(self):
        super().__init__()
        self.env = texas_holdem_v4.env()
        self.env.reset()

        # The seat controlled by the SB3 policy; all other seats act randomly.
        self.agent = self.env.agents[0]

        obs_dict, _, _, _, _ = self.env.last()
        obs = obs_dict["observation"]
        action_mask = obs_dict["action_mask"]

        # Observation space mirrors the PettingZoo observation vector.
        self.observation_space = spaces.Box(
            low=-np.inf, high=np.inf,
            shape=obs.shape, dtype=obs.dtype
        )

        # One discrete action per entry of the PettingZoo action mask.
        self.action_space = spaces.Discrete(len(action_mask))

        # Legal-action mask for the learner's current decision.
        self.last_mask = action_mask

    def _hand_over(self):
        """True once the learner's seat is terminated or truncated."""
        return bool(self.env.terminations[self.agent] or self.env.truncations[self.agent])

    def _random_legal(self, mask):
        """Uniformly sample a legal action index from a 0/1 action mask."""
        legal = np.flatnonzero(mask)
        return int(self.np_random.choice(legal))

    def _play_opponents(self):
        """Step random legal actions for other seats until it is the learner's
        turn or the hand is over."""
        while not self._hand_over() and self.env.agent_selection != self.agent:
            opp_obs, _, _, _, _ = self.env.last()
            self.env.step(self._random_legal(opp_obs["action_mask"]))

    def reset(self, seed=None, options=None):
        """Deal a new hand and advance to the learner's first decision.

        If the opponent ends the hand before the learner ever acts (e.g. it
        folds first), a fresh hand is dealt, because Gymnasium's ``reset`` must
        not return a terminal state.
        """
        super().reset(seed=seed)  # seeds self.np_random
        self.env.reset(seed=seed)
        self._play_opponents()
        while self._hand_over():
            # Unseeded re-deal continues the already-seeded RNG stream.
            self.env.reset()
            self._play_opponents()

        obs_dict = self.env.observe(self.agent)
        self.last_mask = obs_dict["action_mask"]
        return obs_dict["observation"], {}

    def step(self, action):
        """Apply the learner's action, let opponents respond, and return the
        learner's next observation, reward, and termination flags."""
        action = int(action)
        # Enforce a legal action: swap illegal choices for a random legal one.
        if self.last_mask[action] == 0:
            action = self._random_legal(self.last_mask)

        self.env.step(action)
        self._play_opponents()

        obs_dict = self.env.observe(self.agent)
        self.last_mask = obs_dict["action_mask"]
        # `rewards` stays 0 until the hand ends, then holds each seat's payoff.
        reward = float(self.env.rewards[self.agent])
        terminated = bool(self.env.terminations[self.agent])
        truncated = bool(self.env.truncations[self.agent])
        info = dict(self.env.infos[self.agent])
        return obs_dict["observation"], reward, terminated, truncated, info


### Check Environment

In [None]:
# Validate the wrapper against SB3's Gymnasium API checker before training.
test_env = SingleAgentWrapper()
check_env(env=test_env, warn=True)
print("Environment OK!")

### Run Directory Helper

In [None]:
def make_run_dir(base="runs/a2c"):
    """Create and return a fresh numbered run directory ``f"{base}_run_{N}"``.

    N starts at 1 and is the first index whose path does not already exist.
    Missing parent directories (e.g. ``runs/`` for the default) are created.
    No directory is created at ``base`` itself; runs live next to it.

    Args:
        base: path prefix for the run directories.

    Returns:
        The path of the newly created directory.
    """
    run_id = 1
    while True:
        run_dir = f"{base}_run_{run_id}"
        try:
            # makedirs creates missing parents and claims the name in one call;
            # FileExistsError means this id is already taken.
            os.makedirs(run_dir)
        except FileExistsError:
            run_id += 1
            continue
        print(f"Created directory: {run_dir}")
        return run_dir


### Evaluation Function

In [None]:
def evaluate(model, episodes=20, env=None):
    """Run ``episodes`` episodes with a deterministic policy and report returns.

    Args:
        model: object with an SB3-style ``predict(obs, deterministic=True)``
            returning ``(action, state)``.
        episodes: number of episodes to play.
        env: optional Gymnasium-style env to evaluate in; a fresh
            ``SingleAgentWrapper`` is created when omitted.

    Returns:
        ``(mean_reward, rewards)``. ``rewards`` lists each episode's summed
        reward, and ``mean_reward`` is NaN when no episodes were run.
    """
    env_eval = env if env is not None else SingleAgentWrapper()
    rewards = []

    for _ in range(episodes):
        obs, _ = env_eval.reset()
        total = 0.0
        done = False
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            obs, reward, terminated, truncated, _ = env_eval.step(action)
            total += reward
            # An episode ends on either signal; ignoring `truncated` could loop forever.
            done = terminated or truncated
        rewards.append(total)

    mean_reward = float(np.mean(rewards)) if rewards else float("nan")
    return mean_reward, rewards


In [None]:
import imageio
from pettingzoo.classic import texas_holdem_v4

# Lookup tables for decode_card: integer card IDs 0–51 -> e.g. "A of Hearts".
# Rank 1 is the ace, 2-10 are pip cards, 11-13 are J/Q/K.
value_map = {1: "A", **{pip: str(pip) for pip in range(2, 11)}, 11: "J", 12: "Q", 13: "K"}
# Suit index (card_id // 13) -> suit name.
suit_map = dict(enumerate(["Hearts", "Diamonds", "Clubs", "Spades"]))

def decode_card(card_id):
    """Turn an integer card index into text such as ``"A of Hearts"``.

    The index is split with ``divmod(card_id, 13)``: the remainder plus one
    selects the rank label from ``value_map`` and the quotient selects the suit
    label from ``suit_map``. An out-of-range suit raises ``KeyError``.
    """
    suit_idx, rank_offset = divmod(card_id, 13)
    rank_label = value_map[rank_offset + 1]
    suit_label = suit_map[suit_idx]
    return f"{rank_label} of {suit_label}"

def _format_card(card):
    """Render one card as ``"<rank> of <suit>"`` for logging.

    Accepts an integer card ID (delegated to ``decode_card``) or an object
    with RLCard-style ``rank`` (``'A'``, ``'2'``-``'9'``, ``'T'``, ``'J'``,
    ``'Q'``, ``'K'``) and ``suit`` (``'S'``, ``'H'``, ``'D'``, ``'C'``)
    attributes. Anything else falls back to ``str(card)``.
    """
    if isinstance(card, (int, np.integer)):
        return decode_card(int(card))
    rank = getattr(card, "rank", None)
    suit = getattr(card, "suit", None)
    if rank is None or suit is None:
        return str(card)
    suit_names = {"S": "Spades", "H": "Hearts", "D": "Diamonds", "C": "Clubs"}
    rank_name = "10" if rank == "T" else rank
    return f"{rank_name} of {suit_names.get(suit, suit)}"


def extract_visible_cards(env, player):
    """Return ``player``'s hole cards followed by the community cards as text.

    The card state is read from the RLCard game behind PettingZoo's Texas
    Hold'em (``env.unwrapped.env.game``). This assumes ``game.players[i].hand``
    and ``game.public_cards`` hold RLCard ``Card`` objects; verify this against
    the installed RLCard version. The list includes private cards, so it is
    "visible" only from ``player``'s seat.

    Args:
        env: PettingZoo Texas Hold'em AEC env (wrapped or raw).
        player: agent name (e.g. ``"player_0"``) or integer seat index.

    Returns:
        A list of strings like ``"A of Hearts"``. It is empty if the game state
        cannot be located, since this is a best-effort helper used only for logs.
    """
    try:
        raw_env = env.unwrapped
        game = raw_env.env.game
        if isinstance(player, str):
            seat = raw_env.possible_agents.index(player)
        else:
            seat = int(player)
    except (AttributeError, ValueError, TypeError):
        return []

    cards = []
    # Private (hole) cards.
    try:
        cards.extend(_format_card(c) for c in game.players[seat].hand)
    except (AttributeError, IndexError, TypeError):
        pass
    # Public cards (flop/turn/river).
    try:
        cards.extend(_format_card(c) for c in game.public_cards)
    except (AttributeError, TypeError):
        pass
    return cards


def _choose_gif_action(model, obs, mask, rng):
    """Pick an action for the GIF rollout: model prediction if legal, else random legal."""
    if model is not None:
        action_pred, _ = model.predict(obs, deterministic=True)
        action = int(action_pred)  # numpy scalar -> int
        if mask[action] != 0:
            return action
    legal = np.flatnonzero(mask)
    return int(rng.choice(legal))


def generate_pretty_poker_gif_verbose(model=None, path="poker.gif", max_steps=200, seed=None):
    """Play one hand of Texas Hold'em, log each decision, and save it as a GIF.

    Every seat is driven by the same policy. If ``model`` is given, it is used,
    and illegal predictions are replaced by a random legal action. Otherwise
    actions are chosen uniformly at random among the legal ones.

    Args:
        model: optional SB3-style model with ``predict(obs, deterministic=True)``.
        path: output GIF path.
        max_steps: safety cap on the number of actions taken.
        seed: optional seed for the deal and for random action choices.
    """
    # PettingZoo limit Texas Hold'em action IDs.
    action_names = {0: "Call", 1: "Raise", 2: "Fold", 3: "Check"}
    rng = np.random.default_rng(seed)

    env = texas_holdem_v4.env(render_mode="rgb_array")
    frames = []
    try:
        env.reset(seed=seed)
        obs_dict, _, terminated, truncated, _ = env.last()
        step = 0

        while not (terminated or truncated) and step < max_steps:
            frames.append(env.render())

            player = env.agent_selection  # seat that is about to act
            visible = extract_visible_cards(env, player)

            print("\n-------------------")
            print(player)
            print(f"Visible cards: {visible}")

            action = _choose_gif_action(
                model, obs_dict["observation"], obs_dict["action_mask"], rng
            )
            print(f"Chosen action: {action_names.get(action, str(action))}")
            print("-------------------")

            env.step(action)
            obs_dict, _, terminated, truncated, _ = env.last()
            step += 1

        # Also capture the final table (showdown or fold).
        frames.append(env.render())
        # PettingZoo's Hold'em reports payoffs only at hand end, per seat.
        final_rewards = {agent: float(r) for agent, r in env.rewards.items()}
        print(f"Final rewards: {final_rewards}")
    finally:
        env.close()

    imageio.mimsave(path, frames, fps=2)
    print(f"Saved to {path}")


### Start Training

In [None]:
# Train A2C against the random opponent in fixed-size chunks, evaluating after
# each chunk so the learning curve can be plotted afterwards.
SEED = 0  # seeds SB3 (python/numpy/torch RNGs + env) for reproducible re-runs
TIMESTEPS_PER_ITER = 5000
N_ITERS = 40  # total = N_ITERS * TIMESTEPS_PER_ITER = 200k steps

run_dir = make_run_dir()

env = SingleAgentWrapper()
model = A2C("MlpPolicy", env, verbose=1, learning_rate=7e-4, gamma=0.99, seed=SEED)

reward_history = []

for i in range(N_ITERS):
    # reset_num_timesteps=False keeps the timestep counter continuous across chunks.
    model.learn(total_timesteps=TIMESTEPS_PER_ITER, reset_num_timesteps=False)
    mean_r, _ = evaluate(model, episodes=20)
    reward_history.append(mean_r)
    print(f"Iter {i+1}/{N_ITERS} | mean reward = {mean_r:.3f}")

# Save model
model.save(f"{run_dir}/a2c_model.zip")

# Save reward history
with open(f"{run_dir}/reward_history.pkl", "wb") as f:
    pickle.dump(reward_history, f)

print(f"Model + history saved to folder: {run_dir}")


### Plot Training Curve

In [None]:
# Mean evaluation reward vs. the random opponent after each training chunk.
# The x-axis is derived from TIMESTEPS_PER_ITER, so it stays correct if that changes.
timesteps = [(i + 1) * TIMESTEPS_PER_ITER for i in range(len(reward_history))]

fig, ax = plt.subplots()
ax.plot(timesteps, reward_history, marker="o")
ax.set_xlabel("Training timesteps")
ax.set_ylabel("Mean reward vs Random")
ax.set_title("A2C Poker Training Curve")
ax.grid(True)
fig.savefig(f"{run_dir}/training_curve.png")
plt.show()


In [None]:
# Record one hand with a purely random policy (baseline, model=None) and one
# with the trained A2C model, saving each GIF inside the run directory.
for gif_label, policy in (("before_training", None), ("after_training", model)):
    generate_pretty_poker_gif_verbose(
        model=policy,
        path=f"{run_dir}/{gif_label}_verbose.gif",
    )
