In [23]:
# Save this notebook at: training/roomplay/roomplay.ipynb

import os, sys, json, math, time, random
from pathlib import Path
from dataclasses import dataclass, field
from typing import Dict, List, Any, Optional, Tuple

import numpy as np
from PIL import Image, ImageDraw, ImageFont

import torch

# Jupyter display + widgets
from IPython.display import display, clear_output
import ipywidgets as W

# --- repo paths ---
REPO_ROOT = Path.cwd()
if not (REPO_ROOT / "training").exists():
    # If running from inside training/roomplay/, climb up until we find /training
    for p in Path.cwd().resolve().parents:
        if (p / "training").exists():
            REPO_ROOT = p
            break

sys.path.append(str(REPO_ROOT))  # so "training" package is importable

ASSETS_DIR = REPO_ROOT / "training" / "assets" / "cards"  # card images dir
DEFAULT_CONFIG = REPO_ROOT / "training" / "configs" / "full-tiny-smoke.json"  # change if needed
CHECKPOINT_PATH = REPO_ROOT / "training" / "runs" / "smoke" / "full" / "run" / "policy_final.pt"  # optional
# --- fallbacks for fonts (used for text-based card placeholders) ---
def _get_font(size=18):
    try:
        return ImageFont.truetype("DejaVuSans-Bold.ttf", size)
    except:
        return ImageFont.load_default()


In [24]:
CHECKPOINT_PATH

WindowsPath('E:/sequence_game_board/sequence_board_game/training/runs/smoke/full/run/policy_final.pt')

In [25]:
from training.utils.jsonio import load_json, deep_update
from training.utils.seeding import set_all_seeds
from training.envs.sequence_env import SequenceEnv
from training.algorithms.ppo_lstm.ppo_lstm_policy import PPORecurrentPolicy

Device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_cfg(path: Path, override: Optional[dict] = None) -> dict:
    cfg = load_json(str(path))
    if override:
        cfg = deep_update(cfg, override)
    return cfg

def to_uint8_image(obs_chw: np.ndarray) -> Image.Image:
    """Convert CHW float/uint8 numpy -> PIL.Image for preview."""
    arr = obs_chw
    if arr.dtype != np.uint8:
        a = arr.astype(np.float32)
        a = (255.0 * (a - a.min()) / (a.ptp() + 1e-8)).clip(0, 255).astype(np.uint8)
    else:
        a = arr
    if a.shape[0] in (1, 3, 4):  # CHW
        a = np.moveaxis(a, 0, 2)
    return Image.fromarray(a)
# --- Patch: make to_uint8_image robust to odd obs shapes (e.g., (1,1,10)) ---

import numpy as np
from PIL import Image

def to_uint8_image(obs) -> Image.Image:
    a = np.asarray(obs)

    def _norm(x: np.ndarray) -> np.ndarray:
        x = x.astype(np.float32)
        vmin, vmax = float(np.min(x)), float(np.max(x))
        if not np.isfinite(vmin) or not np.isfinite(vmax) or abs(vmax - vmin) < 1e-8:
            return np.zeros_like(x, dtype=np.uint8)
        x = (x - vmin) / (vmax - vmin)
        return (x * 255.0).clip(0, 255).astype(np.uint8)

    # remove singleton dims; many envs yield shapes like (1,1,10) etc.
    a = np.squeeze(a)

    # normalize to uint8 if needed
    if a.dtype != np.uint8:
        a = _norm(a)

    # Handle by dimensionality
    if a.ndim == 1:
        # vector -> 1xN grayscale strip
        a = a.reshape(1, -1)
        return Image.fromarray(a, mode="L")

    if a.ndim == 2:
        # HxW grayscale
        return Image.fromarray(a, mode="L")

    if a.ndim == 3:
        # Try to interpret channels
        # If channels-first and looks like image, move to HWC
        if a.shape[0] in (1, 3, 4) and a.shape[1] > 1 and a.shape[2] > 1:
            a = np.moveaxis(a, 0, 2)

        # After this, we expect HWC. If channels is not 1/3/4, collapse to grayscale.
        if a.shape[2] not in (1, 3, 4):
            a = a.mean(axis=2).astype(np.uint8)
            return Image.fromarray(a, mode="L")

        # Channels are 1/3/4 — pick appropriate mode
        c = a.shape[2]
        if c == 1:
            return Image.fromarray(a[:, :, 0], mode="L")
        if c == 3:
            return Image.fromarray(a, mode="RGB")
        if c == 4:
            return Image.fromarray(a, mode="RGBA")

    # Final fallback: flatten anything else to grayscale
    a = _norm(a.astype(np.float32)).reshape(1, -1)
    return Image.fromarray(a, mode="L")

def legal_random_action(legal_mask: Optional[np.ndarray], action_dim: int) -> int:
    if legal_mask is None:
        return random.randrange(action_dim)
    idxs = np.flatnonzero(legal_mask > 0.5)
    if len(idxs) == 0:
        return random.randrange(action_dim)
    return int(random.choice(idxs))


In [26]:
# Player indices are seats in turn order: 0,1,2,3 ...
TEAM_PRESETS = {
    "1v1": {
        "players": [0, 1],
        "teams": {0: [0], 1: [1]},
        "team_colors": {0: "#E63946", 1: "#457B9D"},
    },
    "2v2": {
        "players": [0, 1, 2, 3],
        "teams": {0: [0, 2], 1: [1, 3]},
        "team_colors": {0: "#E76F51", 1: "#2A9D8F"},
    },
    "1v1v1": {
        "players": [0, 1, 2],
        "teams": {0: [0], 1: [1], 2: [2]},
        "team_colors": {0: "#E63946", 1: "#457B9D", 2: "#2A9D8F"},
    },
}

def infer_num_players_from_preset(preset_name: str) -> int:
    return len(TEAM_PRESETS[preset_name]["players"])


In [27]:
CARD_W, CARD_H = 140, 200  # will be updated on first load if assets exist

def _text_wh(draw: ImageDraw.ImageDraw, txt: str, font):
    """Return (w,h) for text using textbbox, with safe fallbacks."""
    try:
        bbox = draw.textbbox((0, 0), txt, font=font)
        return bbox[2] - bbox[0], bbox[3] - bbox[1]
    except Exception:
        # older fallbacks
        try:
            return font.getsize(txt)
        except Exception:
            try:
                L = int(draw.textlength(txt, font=font))
                return L, getattr(font, "size", 18)
            except Exception:
                return len(txt) * 10, 18

def load_card_image(code: str) -> Image.Image:
    """
    code examples: '2A', 'JC', 'TD', 'QS', 'AH'
    Where the asset file is {code}.png in ASSETS_DIR.
    If missing, returns a text placeholder card.
    """
    global CARD_W, CARD_H
    p = ASSETS_DIR / f"{code}.png"
    if p.exists():
        img = Image.open(p).convert("RGBA")
        CARD_W, CARD_H = img.size
        return img
    # fallback placeholder
    img = Image.new("RGBA", (CARD_W, CARD_H), (245, 245, 245, 255))
    d = ImageDraw.Draw(img)
    d.rectangle([0, 0, CARD_W-1, CARD_H-1], outline=(50, 50, 50, 255), width=3)
    txt = code or "??"
    font = _get_font(32)
    w, h = _text_wh(d, txt, font)
    d.text(((CARD_W-w)//2, (CARD_H-h)//2), txt, fill=(10, 10, 10, 255), font=font)
    return img

def draw_chip(color_hex: str, diameter: int = 24) -> Image.Image:
    img = Image.new("RGBA", (diameter, diameter), (0, 0, 0, 0))
    d = ImageDraw.Draw(img)
    d.ellipse([2, 2, diameter-2, diameter-2], fill=color_hex, outline="black", width=2)
    return img

def stitch_h(images: List[Image.Image], pad: int = 8, bg=(0,0,0,0)) -> Image.Image:
    if not images:
        return Image.new("RGBA", (1, 1), bg)
    w = sum(im.width for im in images) + pad * (len(images)-1)
    h = max(im.height for im in images)
    out = Image.new("RGBA", (w, h), bg)
    x = 0
    for idx, im in enumerate(images):
        out.paste(im, (x, (h - im.height) // 2), im)
        x += im.width + (pad if idx < len(images)-1 else 0)
    return out

def stitch_v(images: List[Image.Image], pad: int = 8, bg=(0,0,0,0)) -> Image.Image:
    if not images:
        return Image.new("RGBA", (1, 1), bg)
    w = max(im.width for im in images)
    h = sum(im.height for im in images) + pad * (len(images)-1)
    out = Image.new("RGBA", (w, h), bg)
    y = 0
    for idx, im in enumerate(images):
        out.paste(im, ((w - im.width) // 2, y), im)
        y += im.height + (pad if idx < len(images)-1 else 0)
    return out


def card_back() -> Image.Image:
    # simple generic back
    img = Image.new("RGBA", (CARD_W, CARD_H), (30, 30, 35, 255))
    d = ImageDraw.Draw(img)
    d.rectangle([6, 6, CARD_W-6, CARD_H-6], outline=(200, 200, 220, 255), width=4)
    font = _get_font(28)
    txt = "BACK"
    w, h = _text_wh(d, txt, font)
    d.text(((CARD_W-w)//2, (CARD_H-h)//2), txt, fill=(220, 220, 235, 255), font=font)
    return img

In [28]:
@dataclass
class BoardView:
    current_player: int = 0
    table_cards: List[str] = field(default_factory=list)     # e.g., ["2A","JC"]
    hands: Dict[int, List[str]] = field(default_factory=dict)  # seat -> list[str]
    team_points: Dict[int, int] = field(default_factory=dict)  # team_id -> points
    legal_mask: Optional[np.ndarray] = None
    meta: Dict[str, Any] = field(default_factory=dict)

def decode_board_from_info(info: Dict[str, Any], fallback_hand_seats: List[int]) -> BoardView:
    """
    Attempts to interpret common keys from env.info. Customize if your env uses different names.
    """
    view = BoardView()
    # current player / turn
    for k in ("current_player", "player_turn", "turn", "actor"):
        if k in info:
            view.current_player = int(info[k])
            break
    # public/board cards
    for k in ("table_cards", "board_cards", "public_cards", "table", "board"):
        if k in info and isinstance(info[k], (list, tuple)):
            view.table_cards = list(map(str, info[k]))
            break
    # hands (per seat) — try multiple patterns
    for k in ("hands", "player_hands"):
        if k in info and isinstance(info[k], dict):
            view.hands = {int(s): list(map(str, v)) for s, v in info[k].items()}
            break
    if not view.hands:
        # per-seat singleton keys like hand0, hand1, ...
        hands_found = {}
        for seat in fallback_hand_seats:
            key = f"hand{seat}"
            if key in info:
                hands_found[seat] = list(map(str, info[key]))
        if hands_found:
            view.hands = hands_found
    # team points / scores
    for k in ("team_points", "scores", "score_by_team"):
        if k in info and isinstance(info[k], dict):
            view.team_points = {int(t): int(v) for t, v in info[k].items()}
            break
    # legal mask
    if "legal_mask" in info:
        lm = info["legal_mask"]
        view.legal_mask = np.asarray(lm, dtype=np.float32)
    # meta dump
    view.meta = {k: v for k, v in info.items() if k not in ("legal_mask",)}
    return view

class RLAgent:
    """PPO recurrent policy wrapper with per-seat hidden states."""
    def __init__(self, policy: PPORecurrentPolicy, num_seats: int, device: torch.device):
        self.policy = policy
        self.device = device
        self.h, self.c = policy.get_initial_state(batch_size=num_seats)

    def reset(self):
        self.h, self.c = self.policy.get_initial_state(batch_size=self.h.shape[1])

    @torch.no_grad()
    def act(self, obs_chw: np.ndarray, legal_mask: Optional[np.ndarray], seat: int) -> Tuple[int, Dict[str, Any]]:
        obs_t = torch.from_numpy(obs_chw).unsqueeze(0).to(self.device, dtype=torch.float32)
        h0 = self.h[:, [seat], :].contiguous()
        c0 = self.c[:, [seat], :].contiguous()
        legal_t = None
        if legal_mask is not None:
            legal_t = torch.from_numpy(legal_mask).unsqueeze(0).to(self.device, dtype=torch.float32)
        out = self.policy.select_action(obs=obs_t, legal_mask=legal_t, h0=h0, c0=c0)
        a = int(out["action"].item())
        # update hidden for just this seat
        self.h[:, [seat], :] = out["h"]
        self.c[:, [seat], :] = out["c"]
        extra = {
            "logp": float(out["log_prob"].item()),
            "value": float(out["value"].item()),
        }
        return a, extra


In [29]:
@dataclass
class StepRecord:
    obs: np.ndarray
    info: Dict[str, Any]
    action: Optional[int]
    reward: float
    done: bool
    seat: int
    view: BoardView

def build_env(cfg_path: Path, preset_name: str, seed: int = 123, override: Optional[dict] = None) -> Tuple[SequenceEnv, dict, np.ndarray, Dict[str, Any]]:
    cfg = load_cfg(cfg_path, override=override or {})
    # Try to hint the number of players in cfg if your env uses it
    n_players = infer_num_players_from_preset(preset_name)
    cfg = deep_update(cfg, {"env": {"num_players": n_players}})  # adjust to your schema if needed

    set_all_seeds(seed)
    env = SequenceEnv(cfg)
    obs, info = env.reset(seed=seed)
    return env, cfg, obs, info

def build_policy(cfg: dict, obs_shape: Tuple[int, int, int], action_dim: int, device: torch.device, ckpt_path: Optional[Path] = None) -> PPORecurrentPolicy:
    policy = PPORecurrentPolicy(
        obs_shape=tuple(obs_shape),
        action_dim=action_dim,
        conv_channels=cfg["model"]["conv_channels"],
        lstm_hidden=int(cfg["model"]["lstm_hidden"]),
        lstm_layers=int(cfg["model"].get("lstm_layers", 1)),
        device=device,
    )
    if ckpt_path and Path(ckpt_path).exists():
        sd = torch.load(str(ckpt_path), map_location=device)
        policy.load_state_dict(sd, strict=False)
        policy.eval()
    return policy

def run_episode_with_agents(
    cfg_path: Path,
    preset_name: str,
    seed: int = 123,
    ckpt_path: Optional[Path] = CHECKPOINT_PATH,
    max_steps: int = 10_000,
) -> Dict[str, Any]:
    env, cfg, obs, info = build_env(cfg_path, preset_name, seed=seed)
    obs_shape = tuple(obs.shape)
    action_dim = env.action_dim

    policy = build_policy(cfg, obs_shape, action_dim, Device, ckpt_path if ckpt_path and Path(ckpt_path).exists() else None)
    num_seats = infer_num_players_from_preset(preset_name)
    agent = RLAgent(policy, num_seats=num_seats, device=Device)

    history: List[StepRecord] = []
    done = False
    steps = 0

    # initial seat from info
    view = decode_board_from_info(info, fallback_hand_seats=list(range(num_seats)))
    current_seat = int(view.current_player)

    while not done and steps < max_steps:
        legal_mask = view.legal_mask
        action, _extra = agent.act(obs, legal_mask, seat=current_seat)
        next_obs, reward, terminated, truncated, info2 = env.step(action)
        done = bool(terminated or truncated)

        history.append(
            StepRecord(
                obs=obs.copy(),
                info=info.copy(),
                action=action,
                reward=float(reward),
                done=done,
                seat=current_seat,
                view=view,
            )
        )

        obs, info = next_obs, info2
        view = decode_board_from_info(info, fallback_hand_seats=list(range(num_seats)))
        current_seat = int(view.current_player)
        steps += 1

        if done:
            # push final frame (post-terminal) for viewing
            history.append(
                StepRecord(
                    obs=obs.copy(),
                    info=info.copy(),
                    action=None,
                    reward=0.0,
                    done=True,
                    seat=current_seat,
                    view=view,
                )
            )
            break

    return {
        "preset": preset_name,
        "history": history,
        "num_seats": num_seats,
        "teams": TEAM_PRESETS[preset_name]["teams"],
        "team_colors": TEAM_PRESETS[preset_name]["team_colors"],
        "action_dim": action_dim,
        "cfg_used": cfg,
    }


In [30]:
def render_board_frame(
    rec: StepRecord,
    teams: Dict[int, List[int]],
    team_colors: Dict[int, str],
    show_seat: Optional[int] = None,
    show_obs_fallback: bool = True,
) -> Image.Image:
    """
    Composes a single image with:
      - team chips and score line (if any)
      - table cards
      - selected agent's hand (open)
      - other players' hands as backs (unless info provided per seat)
      - fallback to obs image if no explicit card codes found
    """
    # title/status bar
    status_img = Image.new("RGBA", (max(800, CARD_W*5), 50), (250, 250, 250, 255))
    d = ImageDraw.Draw(status_img)
    font = _get_font(18)
    status = f"Seat {rec.seat} acted: {rec.action if rec.action is not None else '-'} | Reward {rec.reward:.2f} | Done={rec.done}"
    d.text((10, 15), status, fill=(10,10,10,255), font=font)

    # team chips + points
    chips = []
    for team_id, seats in teams.items():
        chip = draw_chip(team_colors.get(team_id, "#999999"), 24)
        label = Image.new("RGBA", (110, 24), (0,0,0,0))
        d2 = ImageDraw.Draw(label)
        ttxt = f"T{team_id} [{','.join(map(str,seats))}]"
        if rec.view.team_points:
            pts = rec.view.team_points.get(team_id, 0)
            ttxt += f"  {pts}"
        d2.text((0,4), ttxt, fill=(20,20,20,255), font=_get_font(14))
        chips.append(stitch_h([chip, label], pad=6))
    chips_bar = stitch_h(chips, pad=16)

    # table/public cards
    table_cards = rec.view.table_cards
    table_imgs = [load_card_image(code) for code in table_cards] if table_cards else []
    table_row = stitch_h(table_imgs, pad=8) if table_imgs else None

    # hands
    n_seats = max([*range(0,1), *rec.view.hands.keys()], default=0) + 1
    show_seat = rec.seat if show_seat is None else show_seat
    hands_blocks = []
    for seat in range(n_seats):
        label = Image.new("RGBA", (100, 24), (0,0,0,0))
        d3 = ImageDraw.Draw(label)
        tag = f"Seat {seat}" + (" ←" if seat == rec.view.current_player else "")
        d3.text((0,4), tag, fill=(0,0,0,255), font=_get_font(14))

        if seat in rec.view.hands and rec.view.hands[seat]:
            codes = rec.view.hands[seat]
            imgs = [load_card_image(c) for c in codes]
        else:
            # unknown hand — show backs (5 backs placeholder)
            imgs = [card_back() for _ in range(5)]

        # open only selected seat, back for others
        if seat != show_seat:
            imgs = [card_back() for _ in imgs]

        hands_blocks.append(stitch_v([label, stitch_h(imgs, pad=6)], pad=6))

    # fallback obs view
    obs_img = None
    if show_obs_fallback and (not table_cards and all(not v for v in rec.view.hands.values())):
        obs_img = to_uint8_image(rec.obs)

    # compose final
    rows = [status_img, chips_bar]
    if table_row:
        rows.append(table_row)
    if hands_blocks:
        rows.append(stitch_h(hands_blocks, pad=12))
    if obs_img:
        rows.append(obs_img)
    canvas = stitch_v(rows, pad=12)
    return canvas

class HistoryViewer:
    def __init__(self, run: Dict[str, Any]):
        self.run = run
        self.history: List[StepRecord] = run["history"]
        self.teams = run["teams"]
        self.colors = run["team_colors"]
        self.n = len(self.history)

        self.idx = W.IntSlider(description="Step", min=0, max=max(0, self.n-1), step=1, value=0, continuous_update=False)
        self.prev_btn = W.Button(description="◀︎ Back", layout=W.Layout(width="100px"))
        self.next_btn = W.Button(description="Forward ▶︎", layout=W.Layout(width="120px"))
        self.seat_dd = W.Dropdown(description="Show seat", options=self._seat_options(), value=self._default_show_seat())
        self.out = W.Output()

        self.prev_btn.on_click(lambda b: self._jump(-1))
        self.next_btn.on_click(lambda b: self._jump(+1))
        self.idx.observe(self._on_idx, names="value")
        self.seat_dd.observe(self._on_idx, names="value")

    def _seat_options(self):
        seats = set()
        for r in self.history:
            seats.add(r.seat)
            seats.update(r.view.hands.keys())
        opts = sorted(list(seats))
        if not opts:
            opts = [0]
        return [(f"Seat {s}", s) for s in opts]

    def _default_show_seat(self):
        return self.history[0].seat if self.history else 0

    def _jump(self, delta: int):
        self.idx.value = int(np.clip(self.idx.value + delta, self.idx.min, self.idx.max))

    def _on_idx(self, change):
        self.render()

    def render(self):
        if not self.history:
            with self.out:
                clear_output()
                print("No history recorded.")
            return
        rec = self.history[self.idx.value]
        img = render_board_frame(rec, self.teams, self.colors, show_seat=self.seat_dd.value)
        with self.out:
            clear_output(wait=True)
            display(img)

    def widget(self):
        ctrls = W.HBox([self.prev_btn, self.next_btn, self.idx, self.seat_dd])
        box = W.VBox([ctrls, self.out])
        self.render()
        return box


In [31]:
# Choose a config JSON that your env understands
CONFIG_PATH = DEFAULT_CONFIG  # set to your config path if different
SEED = 123

# 1 vs 1
run_1v1 = run_episode_with_agents(CONFIG_PATH, "1v1", seed=SEED, ckpt_path=CHECKPOINT_PATH)

In [35]:
run_1v1.get('history')[40]

StepRecord(obs=array([[[0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 0.        ],
        [0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 0.        ],
        [0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 0.        ],
        ...,
        [0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 0.        ],
        [0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 0.        ],
        [0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 0.        ]],

       [[0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 0.        ],
        [0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 0.        ],
        [0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 0.        ],
        ...,
        [0.        , 0.        , 0.        , ..., 0.        ,
         0.        , 0.        

In [33]:
viewer_1v1 = HistoryViewer(run_1v1)
display(W.HTML("<h3>1 vs 1</h3>"))
display(viewer_1v1.widget())

HTML(value='<h3>1 vs 1</h3>')

VBox(children=(HBox(children=(Button(description='◀︎ Back', layout=Layout(width='100px'), style=ButtonStyle())…