In [None]:
import numpy as np
import chess
from chess import svg, Move
from ipywidgets import interact

from pypad.games.chess import Chess, ChessState
from pypad.games.chess_enums import ObsPlanes, ActionPlanes, KeyGames

# Member names of each enum-like class; they are handed to `interact` in the
# cells below as the selectable values for `label` / `plane_name`.
GAME_NAMES = KeyGames.__members__.keys()
OBS_NAMES = ObsPlanes.__members__.keys()
ACTION_NAMES = ActionPlanes.__members__.keys()

# View Key Games

In [None]:
def view_key_games(label: str, move_count: int, flip: bool):
    """Render a key game's position after its first ``move_count`` moves.

    Args:
        label: Name of the game, resolved through ``KeyGames.get``.
        move_count: How many leading moves of the game to play.
        flip: When true, the board is drawn flipped whenever ``board.turn``
            is falsy; when false, the board is never flipped.

    Returns:
        Whatever ``svg.board`` produces for the resulting position.
    """
    moves_played = KeyGames.get(label)[:move_count]
    position = ChessState.create(moves_played).board

    # Pass the most recent move to the renderer, if any move has been played.
    if position.move_stack:
        previous_move = position.move_stack[-1]
    else:
        previous_move = None

    show_flipped = flip and not position.turn
    return svg.board(
        position,
        flipped=show_flipped,
        size=350,
        lastmove=previous_move,
    )

# Hook the viewer up to interactive controls: game name, number of moves
# played (0..444) and a flip toggle. The return value is bound to `_`,
# presumably to keep it out of the cell output.
_ = interact(
    view_key_games,
    label=GAME_NAMES,
    move_count=range(445),
    flip=False,
)

# Inspect Observation Planes

In [None]:
def inspect_observation_planes(label: str, plane_name: str, move_count: int):
    """Plot one observation plane for a key-game position and render its board.

    Args:
        label: Name of the game, resolved through ``KeyGames.get``.
        plane_name: Name of the observation plane, resolved through
            ``ObsPlanes.get``.
        move_count: How many leading moves of the game to play.

    Returns:
        Whatever ``svg.board`` produces for the position; the board is drawn
        flipped whenever ``board.turn`` is falsy.
    """
    moves_played = KeyGames.get(label)[:move_count]
    game_state = ChessState.create(moves_played)
    position = game_state.board

    # Side effect: draw the selected plane before the board is rendered.
    game_state.plot(ObsPlanes.get(plane_name))

    # Pass the most recent move to the renderer, if any move has been played.
    if position.move_stack:
        previous_move = position.move_stack[-1]
    else:
        previous_move = None

    return svg.board(
        position,
        flipped=not position.turn,
        size=350,
        lastmove=previous_move,
    )

# Hook the observation-plane inspector up to interactive controls: game name,
# observation plane name and number of moves played (0..444).
_ = interact(
    inspect_observation_planes,
    label=GAME_NAMES,
    plane_name=OBS_NAMES,
    move_count=range(445),
)

# Inspect Action Planes

In [None]:
from chess.svg import Arrow
from pypad.views.plot import plot_chess_slice

def inspect_policy_planes(label: str, plane_name: str, move_count: int):
    """Plot the legal-move mask around one action plane and render the board.

    Every legal move of the position is marked with 1.0 at the location given
    by ``state.policy_loc_3d``. Moves that fall on the selected plane are also
    drawn as arrows on the board.

    Args:
        label: Name of the game, resolved through ``KeyGames.get``.
        plane_name: Name of the action plane, resolved through
            ``ActionPlanes.get``.
        move_count: How many leading moves of the game to play.

    Returns:
        Whatever ``svg.board`` produces for the position; the board is drawn
        flipped whenever ``board.turn`` is falsy.
    """
    moves_played = KeyGames.get(label)[:move_count]
    game_state = ChessState.create(moves_played)
    position = game_state.board

    plane_arrows = []
    selected_plane = ActionPlanes.get(plane_name)
    legal_mask = np.zeros(ActionPlanes.shape())

    for legal_move in game_state.status().legal_moves:
        plane_i, row, col = game_state.policy_loc_3d(legal_move)
        legal_mask[plane_i, row, col] = 1.0
        if plane_i != selected_plane:
            continue
        # Only moves on the selected plane are shown on the board.
        plane_arrows.append(Arrow(legal_move.from_square, legal_move.to_square))

    plot_chess_slice(legal_mask, selected_plane, (3, 3))

    # Pass the most recent move to the renderer, if any move has been played.
    if position.move_stack:
        previous_move = position.move_stack[-1]
    else:
        previous_move = None

    return svg.board(
        position,
        flipped=not position.turn,
        size=390,
        lastmove=previous_move,
        arrows=plane_arrows,
    )

# Hook the action-plane inspector up to interactive controls: game name,
# action plane name and number of moves played (0..199).
_ = interact(
    inspect_policy_planes,
    label=GAME_NAMES,
    plane_name=ACTION_NAMES,
    move_count=range(200),
)

# FEN Inspector

In [None]:
# Build a state from a FEN string (fields here: white to move, no castling
# rights, no en-passant square, halfmove clock 1, fullmove number 49).
# NOTE(review): the other cells pass a list of moves to ChessState.create;
# confirm it is also meant to accept a FEN string.
state = ChessState.create('7k/8/p5pp/5nn1/3P3P/P2r4/1P2r2q/6RK w - - 1 49')
# Bare last expression so the notebook displays the board for this position.
state.board