Skip to content

Commit

Permalink
documented MCTS more, updated MCTS_eval_demo, corrected score to be i…
Browse files Browse the repository at this point in the history
…n POV view
  • Loading branch information
QueensGambit committed Oct 25, 2018
1 parent 91a09e0 commit 7e126bf
Show file tree
Hide file tree
Showing 9 changed files with 250 additions and 167 deletions.
2 changes: 2 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
* linguist-vendored
*.py linguist-vendored=false
*.ipynb linguist-vendored=false
*.sh linguist-vendored=false
File renamed without changes.
229 changes: 142 additions & 87 deletions DeepCrazyhouse/src/domain/agent/player/MCTSAgent.py

Large diffs are not rendered by default.

3 changes: 0 additions & 3 deletions DeepCrazyhouse/src/domain/agent/player/RawNetAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ def evaluate_board_state(self, state: _GameState, verbose=True):
t_start_eval = time()
pred_value, pred_policy = self._net.predict_single(state.get_state_planes())

if state.is_white_to_move() is False:
pred_value *= -1

legal_moves = list(state.get_legal_moves())
p_vec_small = get_probs_of_move_list(pred_policy, legal_moves, state.is_white_to_move())

Expand Down
4 changes: 2 additions & 2 deletions DeepCrazyhouse/src/domain/agent/player/util/Node.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@ class Node:
Helper Class which stores the statistics of all child nodes of this node in the search tree
"""

def __init__(self, p_vec_small: np.ndarray, legal_moves: [chess.Move], str_legal_moves: str):
def __init__(self, value, p_vec_small: np.ndarray, legal_moves: [chess.Move], str_legal_moves: str):

# lock object for this node to protect its member variables
self.lock = Lock()

# store the initial value prediction of the current board position
#self.v = v
self.v = value
# specify the number of direct child nodes from this node
self.nb_direct_child_nodes = np.array(len(p_vec_small)) #, np.uint32)
# prior probability selecting each child, which is estimated by the neural network
Expand Down
1 change: 0 additions & 1 deletion DeepCrazyhouse/src/domain/crazyhouse/GameState.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def get_pythonchess_board(self):

def is_draw(self):
# check if you can claim a draw - it's assumed that the draw is always claimed
print('check for draw')
return self.board.can_claim_draw()

def is_won(self):
Expand Down
67 changes: 44 additions & 23 deletions DeepCrazyhouse/src/samples/MCTS_eval_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,7 @@
"metadata": {},
"outputs": [],
"source": [
"batch_size = 8\n",
"nb_playouts = 256\n",
"cpuct = 1\n",
"dirichlet_epsilon = 0.25\n",
"nb_workers = 64"
"raw_agent = RawNetAgent(net)"
]
},
{
Expand All @@ -76,18 +72,9 @@
"metadata": {},
"outputs": [],
"source": [
"rawAgent = RawNetAgent(net)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"AGENT = MCTSAgent(net, cpuct=1, nb_playouts_empty_pockets=2048, nb_playouts_filled_pockets=4096,\n",
" cpuct_decay=0, cpuct_min=0, search_depth=40,\n",
" nb_playouts_update=256, max_search_time_s=900, dirichlet_alpha=0.2, dirichlet_epsilon=0.1)"
"mcts_agent = MCTSAgent(net, threads=8, playouts_empty_pockets=128, playouts_filled_pockets=256,\n",
" playouts_update=512, cpuct=1, dirichlet_epsilon=.1, dirichlet_alpha=0.2, max_search_time_s=300,\n",
" max_search_depth=15, temperature=0., clip_quantil=0., virtual_loss=3, verbose=True)"
]
},
{
Expand All @@ -98,15 +85,15 @@
"source": [
"board = chess.variant.CrazyhouseBoard()\n",
"\n",
"#board.push_uci('e2e4')\n",
"board.push_uci('e2e4')\n",
"#board.push_uci('e7e6')\n",
"\n",
"#fen = 'rnbqkb1r/ppp1pppp/5n2/3P4/8/8/PPPP1PPP/RNBQKBNR/P w KQkq - 1 3'\n",
"fen = 'r4rk1/ppp2pp1/3p1q1p/n1bPp3/2B1B1b1/3P1N2/PPP2PPP/R2Q1RK1[Nn] w - - 2 13'\n",
"#fen = 'r4rk1/ppp2pp1/3p1q1p/n1bPp3/2B1B1b1/3P1N2/PPP2PPP/R2Q1RK1[Nn] w - - 2 13'\n",
"#fen = 'rnb2rk1/p3bppp/2p5/3p2P1/4n3/8/PPPPBPPP/RNB1K1NR/QPPq w KQ - 0 11'\n",
"#fen = 'r1b1kbnr/ppp1pppp/2n5/3q4/3P4/8/PPP1NPPP/RNBQKB1R/Pp b KQkq - 1 4'\n",
"#fen = 'r1b1k2r/ppp2ppp/2n5/3np3/3P4/2PBP3/PpPB1PPP/1Q2K1NR/QNrb b Kkq - 27 14'\n",
"board.set_fen(fen)\n",
"#board.set_fen(fen)\n",
"\n",
"state = GameState(board)\n",
"board"
Expand Down Expand Up @@ -147,15 +134,49 @@
" plt.yticks(range(len(moves_ordered)), moves_ordered)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluation using the raw network"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"t_s = time()\n",
"value, legal_moves, p_vec_small = raw_agent.evaluate_board_state(state)\n",
"print('Elapsed time: %.4fs' % (time()-t_s))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_moves_with_prob(legal_moves, p_vec_small, only_top_x=10)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Evaluation using the MCTS-Agent"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"t_s = time()\n",
"value, legal_moves, p_vec_small = AGENT.evaluate_board_state(state)\n",
"print('Elapsed time: %.4fs' % (t_s - time()))"
"value, legal_moves, p_vec_small = mcts_agent.evaluate_board_state(state)\n",
"print('Elapsed time: %.4fs' % (time()-t_s))"
]
},
{
Expand All @@ -164,7 +185,7 @@
"metadata": {},
"outputs": [],
"source": [
"AGENT.get_calclated_line()"
"mcts_agent.get_calclated_line()"
]
},
{
Expand Down
3 changes: 1 addition & 2 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -631,8 +631,7 @@ to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.

<one line to give the program's name and a brief idea of what it does.>
Copyright (C) <year> <name of author>
CrazyAra Copyright (c) 2018 Johannes Czech, Alena Beyer, Moritz Willig

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
Expand Down
108 changes: 59 additions & 49 deletions crazyara.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,16 @@ def log(text: str):
print(INTRO_PT2, end="")

# GLOBAL VARIABLES
MCTSAGENT = None
RAWNETAGENT = None
GAMESTATE = None
mcts_agent = None
rawnet_agent = None
gamestate = None

SWITCHED_TO_RAW_NET = False
switched_to_raw_net = False

SETUP_DONE = False
setup_done = False

# SETTINGS
S = {
s = {
'UCI_Variant': 'crazyhouse',
# set the context in which the neural networks calculation will be done
# choose 'gpu' using the settings if there is one available
Expand Down Expand Up @@ -111,32 +111,32 @@ def setup_network():
:return:
"""

global GAMESTATE
global SETUP_DONE
global RAWNETAGENT
global MCTSAGENT
global S
global gamestate
global setup_done
global rawnet_agent
global mcts_agent
global s

if SETUP_DONE is False:
if setup_done is False:
from DeepCrazyhouse.src.domain.crazyhouse.GameState import GameState
from DeepCrazyhouse.src.domain.agent.NeuralNetAPI import NeuralNetAPI
from DeepCrazyhouse.src.domain.agent.player.RawNetAgent import RawNetAgent
from DeepCrazyhouse.src.domain.agent.player.MCTSAgent import MCTSAgent

net = NeuralNetAPI(ctx=S['context'])
net = NeuralNetAPI(ctx=s['context'])

RAWNETAGENT = RawNetAgent(net, temperature=S['centi_temperature'], clip_quantil=S['centi_clip_quantil'])
rawnet_agent = RawNetAgent(net, temperature=s['centi_temperature'], clip_quantil=s['centi_clip_quantil'])

MCTSAGENT = MCTSAgent(net, cpuct=S['centi_cpuct']/100, nb_playouts_empty_pockets=S['playouts_empty_pockets'],
nb_playouts_filled_pockets=S['playouts_filled_pockets'], max_search_depth=S['max_search_depth'],
nb_playouts_update=S['playouts_update_stats'], max_search_time_s=S['max_search_time_s'],
dirichlet_alpha=S['centi_dirichlet_alpha']/100, dirichlet_epsilon=S['centi_dirichlet_epsilon']/100,
virtual_loss=S['virtual_loss'], threads=S['threads'], temperature=S['centi_temperature']/100,
clip_quantil=S['centi_clip_quantil']/100)
mcts_agent = MCTSAgent(net, cpuct=s['centi_cpuct'] / 100, playouts_empty_pockets=s['playouts_empty_pockets'],
playouts_filled_pockets=s['playouts_filled_pockets'], max_search_depth=s['max_search_depth'],
playouts_update=s['playouts_update_stats'], max_search_time_s=s['max_search_time_s'],
dirichlet_alpha=s['centi_dirichlet_alpha'] / 100, dirichlet_epsilon=s['centi_dirichlet_epsilon'] / 100,
virtual_loss=s['virtual_loss'], threads=s['threads'], temperature=s['centi_temperature'] / 100,
clip_quantil=s['centi_clip_quantil'] / 100)

GAMESTATE = GameState()
gamestate = GameState()

SETUP_DONE = True
setup_done = True


def perform_action(cmd_list):
Expand All @@ -145,37 +145,37 @@ def perform_action(cmd_list):
:return:
"""

global SWITCHED_TO_RAW_NET
global switched_to_raw_net
global AGENT
global GAMESTATE
global MCTSAGENT
global RAWNETAGENT
global gamestate
global mcts_agent
global rawnet_agent

if len(cmd_list) >= 5:
if cmd_list[1] == 'wtime' and cmd_list[3] == 'btime':

wtime = int(cmd_list[2])
btime = int(cmd_list[4])

if GAMESTATE.is_white_to_move() is True:
if gamestate.is_white_to_move() is True:
my_time = wtime
else:
my_time = btime

if SWITCHED_TO_RAW_NET is False and int(my_time) < S['threshold_time_for_raw_net_ms']:
if switched_to_raw_net is False and int(my_time) < s['threshold_time_for_raw_net_ms']:
log_print('Switching to raw network for fast mode...')
# switch to RawNetwork-Agent
SWITCHED_TO_RAW_NET = True
switched_to_raw_net = True

elif SWITCHED_TO_RAW_NET is True and my_time >= S['threshold_time_for_raw_net_ms']:
elif switched_to_raw_net is True and my_time >= s['threshold_time_for_raw_net_ms']:
log_print('Switching back to MCTS network for slow mode...')
# switch back to the MCTS-Agent
SWITCHED_TO_RAW_NET = False
switched_to_raw_net = False

if SWITCHED_TO_RAW_NET is True or S['use_raw_network'] is True:
value, selected_move, confidence, _ = RAWNETAGENT.perform_action(GAMESTATE)
if switched_to_raw_net is True or s['use_raw_network'] is True:
value, selected_move, confidence, _ = rawnet_agent.perform_action(gamestate)
else:
value, selected_move, confidence, _ = MCTSAGENT.perform_action(GAMESTATE)
value, selected_move, confidence, _ = mcts_agent.perform_action(gamestate)

log_print('bestmove %s' % selected_move.uci())

Expand All @@ -184,17 +184,17 @@ def setup_gamestate(cmd_list):

position_type = cmd_list[1]
if position_type == "startpos":
GAMESTATE.new_game()
gamestate.new_game()

elif position_type == "fen":
sub_command_offset = cmd_list.index("moves") if "moves" in cmd_list else len(cmd_list)
fen = " ".join(cmd_list[2:sub_command_offset])

GAMESTATE.set_fen(fen)
gamestate.set_fen(fen)

mv_list = cmd_list[3:]
for move in mv_list:
GAMESTATE.apply_move(chess.Move.from_uci(move))
gamestate.apply_move(chess.Move.from_uci(move))


def set_options(cmd_list):
Expand All @@ -205,12 +205,12 @@ def set_options(cmd_list):
:return:
"""
# SETTINGS
global S
global s

if cmd_list[1] == 'name' and cmd_list[3] == 'value':
option_name = cmd_list[2]

if option_name not in S:
if option_name not in s:
raise Exception("The given option %s wasn't found in the settings list" % option_name)

if option_name in ['UCI_Variant', 'context', 'use_raw_network']:
Expand All @@ -220,11 +220,11 @@ def set_options(cmd_list):

if option_name == 'use_raw_network':
if value == 'true':
S['use_raw_network'] = True
s['use_raw_network'] = True
else:
S['use_raw_network'] = False
s['use_raw_network'] = False
else:
S[option_name] = value
s[option_name] = value

log_print('Updated option %s to %s' % (option_name, value))

Expand All @@ -235,7 +235,6 @@ def set_options(cmd_list):

# wait for an std-in input command
if line:
try:
# split the line to a list which makes parsing easier
cmd_list = line.rstrip().split(' ')
# extract the first command from the list for evaluation
Expand Down Expand Up @@ -278,11 +277,22 @@ def set_options(cmd_list):
elif main_cmd == "setoption":
set_options(cmd_list)
elif main_cmd == 'go':
perform_action(cmd_list)
try:
perform_action(cmd_list)
except:
# log the error message to the log-file and exit the script
traceback_text = traceback.format_exc()
log_print(traceback_text)
sys.exit(-1)
elif main_cmd == 'quit' or 'exit':
sys.exit(0)
except:
# log the error message to the log-file and exit the script
traceback_text = traceback.format_exc()
log_print(traceback_text)
sys.exit(-1)
else:
# give the user a message that the command was ignored
print("Unknown command: %s" % line)


# TODO: Fix misleading exception text
# describe the parameters more
# add network visualization
# correct link for lichess-Account
# check why threads number doesn't give as much

0 comments on commit 7e126bf

Please sign in to comment.