In [None]:
import numpy as np
import chess
from chess import svg, Move
from chess.svg import Arrow

from ipywidgets import interact
from IPython.display import display, HTML, clear_output

from pypad.alpha_zero import AlphaZero, PytorchNeuralNetwork

from pypad.games import Chess
from pypad.states import ChessState
from pypad.states.chess_enums import ObsPlanes, ActionPlanes, KeyGames

# Human-readable labels used as dropdown options by the interactive widgets.
GAME_NAMES, OBS_NAMES, ACTION_NAMES = (
    KeyGames.labels(),
    ObsPlanes.labels(),
    ActionPlanes.labels(),
)

# FEN Inspector

In [None]:
# Render an arbitrary position: replace the FEN string below with any other.
game = Chess()
state = game.initial_state(
    '4Qbk1/2r4p/5Rp1/3p4/8/5N1P/3NBPPK/2r5 b - - 0 33'
)
state.board  # last expression of the cell, so Jupyter displays the board

# View Key Games

In [None]:
def view_key_games(label: str, move_count: int, flip: bool):
    """Render the key game ``label`` after replaying its first ``move_count`` SAN moves.

    The most recent move is highlighted. When ``flip`` is set, the board is drawn
    from Black's side whenever Black is to move.
    """
    board = game.initial_state(KeyGames.get(label)[:move_count]).board
    if board.move_stack:
        highlighted = board.move_stack[-1]
    else:
        highlighted = None
    orient_for_black = flip and not board.turn
    return svg.board(board, flipped=orient_for_black, size=350, lastmove=highlighted)

_ = interact(
    view_key_games,
    label=GAME_NAMES,
    move_count=range(445),
    flip=False,
)

# Input Feature Planes

In [None]:
def inspect_observation_planes(label: str, plane_name: str, move_count: int):
    """Plot one input-feature (observation) plane for a key-game position.

    The position is reached by replaying the first ``move_count`` SAN moves of the
    game ``label``. Returns an SVG board, drawn from the side to move, with the
    last move highlighted.
    """
    position = game.initial_state(KeyGames.get(label)[:move_count])
    board = position.board

    position.plot(ObsPlanes.get(plane_name))

    previous = board.move_stack[-1] if board.move_stack else None
    return svg.board(board, flipped=not board.turn, size=350, lastmove=previous)

_ = interact(
    inspect_observation_planes,
    label=GAME_NAMES,
    plane_name=OBS_NAMES,
    move_count=range(445),
)

# Inspect Action Planes - Legal Moves

In [None]:
def inspect_policy_planes(label: str, plane_name: str, move_count: int):
    """Plot one action plane and draw every legal move that is encoded in it.

    Replays the first ``move_count`` SAN moves of the key game ``label``, plots
    the plane named ``plane_name``, and returns an SVG board with one arrow per
    legal move whose policy location falls on that plane.
    """
    position = game.initial_state(KeyGames.get(label)[:move_count])
    board = position.board

    legal = position.status().legal_moves
    plane = ActionPlanes.get(plane_name)
    position.plot_policy(plane)

    arrows = []
    for move in legal:
        # policy_loc_3d(move)[0] is the plane index the move is encoded on.
        if position.policy_loc_3d(move)[0] == plane:
            arrows.append(Arrow(move.from_square, move.to_square))

    last = board.move_stack[-1] if board.move_stack else None
    return svg.board(board, flipped=not board.turn, size=350, lastmove=last, arrows=arrows)

_ = interact(
    inspect_policy_planes,
    label=GAME_NAMES,
    plane_name=ACTION_NAMES,
    move_count=range(200),
)

# Visualising Model Predictions for Chess

In [None]:
# Build the network for `game` from '..' (presumably the directory holding the
# trained checkpoints — confirm) and wrap it in an AlphaZero agent for the cells below.
network = PytorchNeuralNetwork.create(game, '..')
alpha_zero = AlphaZero(network)

## Top moves by action plane

In [None]:
def inspect_policy_planes(label: str, plane_name: str, move_count: int, moves_to_show: int):
    """Plot the network's raw policy for one action plane and draw that plane's strongest legal moves.

    Replays the first ``move_count`` SAN moves of the key game ``label``, plots
    ``alpha_zero.raw_policy`` (presumably the network output without MCTS) for
    the plane named ``plane_name``, and returns an SVG board with arrows for up
    to ``moves_to_show`` legal moves *belonging to that plane*, ranked by their
    raw-policy probability.
    """
    sans = KeyGames.get(label)
    state = game.initial_state(sans[:move_count])
    board = state.board

    legal_moves = state.status().legal_moves
    plane = ActionPlanes.get(plane_name)

    policy = alpha_zero.raw_policy(state)
    state.plot_policy(plane, policy)

    # Rank within the selected plane. Taking the global top-k over the whole
    # policy first and only then filtering by plane left planes whose moves were
    # not among the overall favourites with few or no arrows.
    # Assumes policy.encoded_policy is flat and indexed by state.policy_loc(move)
    # (as the earlier argpartition / `in` comparison implied) — confirm.
    plane_moves = [move for move in legal_moves if state.policy_loc_3d(move)[0] == plane]
    plane_moves.sort(key=lambda move: policy.encoded_policy[state.policy_loc(move)], reverse=True)
    arrows = [Arrow(move.from_square, move.to_square) for move in plane_moves[:moves_to_show]]

    lastmove = board.move_stack[-1] if board.move_stack else None
    return svg.board(board, flipped=not board.turn, size=350, lastmove=lastmove, arrows=arrows)

_ = interact(inspect_policy_planes, label=GAME_NAMES, plane_name=ACTION_NAMES, move_count=range(200), moves_to_show=(1,10))

## Top moves (all planes)

In [None]:
def inspect_policy_planes(label: str, move_count: int, moves_to_show: int):
    """Plot the raw network policy over all planes and draw its top legal moves.

    The top ``moves_to_show`` entries of ``alpha_zero.raw_policy`` are found with
    ``np.argpartition``. An arrow is drawn for each legal move whose policy index
    is among them.
    NOTE(review): if the raw policy puts mass on illegal moves, fewer than
    ``moves_to_show`` arrows will appear.
    """
    position = game.initial_state(KeyGames.get(label)[:move_count])
    board = position.board

    raw = alpha_zero.raw_policy(position)
    position.plot_policy(-1, raw)  # -1: presumably "all planes" — confirm

    candidates = position.status().legal_moves
    k = moves_to_show
    best = np.argpartition(raw.encoded_policy, -k)[-k:]

    arrows = []
    for move in candidates:
        if position.policy_loc(move) in best:
            arrows.append(Arrow(move.from_square, move.to_square))

    last = board.move_stack[-1] if board.move_stack else None
    return svg.board(board, flipped=not board.turn, size=350, lastmove=last, arrows=arrows)

_ = interact(
    inspect_policy_planes,
    label=GAME_NAMES,
    move_count=range(200),
    moves_to_show=(1, 10),
)

In [None]:
def inspect_policy_planes(label: str, move_count: int, moves_to_show: int, num_mcts_sims: int):
    """Plot the search policy after ``num_mcts_sims`` simulations and draw its top legal moves.

    Uses ``alpha_zero.policy(state, num_mcts_sims)``. An arrow is drawn for every
    legal move whose policy index is among the ``moves_to_show`` largest entries.
    """
    position = game.initial_state(KeyGames.get(label)[:move_count])
    board = position.board

    searched = alpha_zero.policy(position, num_mcts_sims)
    position.plot_policy(-1, searched)  # -1: presumably "all planes" — confirm

    candidates = position.status().legal_moves
    k = moves_to_show
    best = np.argpartition(searched.encoded_policy, -k)[-k:]

    arrows = []
    for move in candidates:
        if position.policy_loc(move) in best:
            arrows.append(Arrow(move.from_square, move.to_square))

    last = board.move_stack[-1] if board.move_stack else None
    return svg.board(board, flipped=not board.turn, size=350, lastmove=last, arrows=arrows)

# Simulation budgets offered in the dropdown.
sims = [2, 100, 200, 500, 1_000, 2_000, 5_000]
_ = interact(
    inspect_policy_planes,
    label=GAME_NAMES,
    move_count=range(200),
    moves_to_show=(1, 10),
    num_mcts_sims=sims,
)

# Play against AlphaZero

In [None]:
game = Chess()
state = game.initial_state()
challenger_plays_as = 1  # the human challenger plays as Player 1 (moves first) or Player 2
if challenger_plays_as not in (1, 2):
    raise ValueError(f'challenger_plays_as must be 1 or 2, got {challenger_plays_as!r}')

network = PytorchNeuralNetwork.create(game, '..')
alpha_zero = AlphaZero(network)

# Older network (the third argument, 20, presumably selects a checkpoint — confirm).
# Only used if the challenger branch in get_move is switched to an engine.
network_old = PytorchNeuralNetwork.create(game, '..', 20)
alpha_zero_old = AlphaZero(network_old)

def get_move(i: int) -> int:
    """Return the move to play at ply ``i`` (the current ``state.move_count``).

    Player 1 moves on even plies and Player 2 on odd plies, so the challenger is
    to move when ``i % 2 == challenger_plays_as - 1``. The previous check,
    ``i % 2 == challenger_plays_as``, could never be true for Player 2 (``i % 2``
    is 0 or 1). That handed every move, for both sides, to the human.
    Assumes ``state.move_count`` is 0 before the first move — confirm.
    """
    if i % 2 == challenger_plays_as - 1:
        # For an engine-vs-engine match, return alpha_zero_old.select_move(state, 60) instead.
        return state.get_input_move()
    return alpha_zero.select_move(state, 100)

while state.status().is_in_progress:
    # wait=True defers clearing until new output arrives, avoiding flicker.
    clear_output(wait=True)
    display(state.board)
    move = get_move(state.move_count)
    state.set_move(move)

clear_output(wait=True)
display(state.board)
if state.status().value > 0:
    # A positive value is treated as a win for the player who made the last move.
    print('Challenger wins!' if state.played_by == challenger_plays_as else "AlphaZero wins!")
else:
    print("It's a draw!")

print(state.pgn())