Skip to content

Commit

Permalink
support undo
Browse files Browse the repository at this point in the history
  • Loading branch information
wodesuck committed Feb 12, 2019
1 parent d446f5a commit 55dba6b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 16 deletions.
22 changes: 19 additions & 3 deletions mcts/mcts_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,24 @@ MCTSEngine::~MCTSEngine()
LOG(INFO) << "~MCTSEngine: Deconstruct MCTSEngin succ";
}

void MCTSEngine::Reset()
void MCTSEngine::Reset(const std::string &init_moves)
{
SearchPause();
ChangeRoot(nullptr);
m_board.CopyFrom(GoState(!m_config.disable_positional_superko()));
m_simulation_counter = 0;
m_num_moves = 0;
m_moves_str.clear();
m_num_moves = (init_moves.size() + 1) / 3;
m_moves_str = init_moves;
m_gen_passes = 0;
m_byo_yomi_timer.Reset();

for (size_t i = 0; i < init_moves.size(); i += 3) {
GoCoordId x, y;
GoFunction::StrToCoord(init_moves.substr(i, 2), x, y);
m_board.Move(x, y);
m_root->move = GoFunction::CoordToId(x, y);
}

if (m_config.enable_background_search()) {
SearchResume();
}
Expand Down Expand Up @@ -219,6 +226,15 @@ void MCTSEngine::GenMove(GoCoordId &x, GoCoordId &y, std::vector<int> &visit_cou
}
}

bool MCTSEngine::Undo()
{
if (m_num_moves == 0) {
return false;
}
Reset(m_moves_str.substr(0, m_moves_str.size() - 3));
return true;
}

const GoState &MCTSEngine::GetBoard()
{
return m_board;
Expand Down
3 changes: 2 additions & 1 deletion mcts/mcts_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,11 @@ class MCTSEngine
MCTSEngine(const MCTSConfig &config);
~MCTSEngine();

void Reset();
void Reset(const std::string &init_moves="");
void Move(GoCoordId x, GoCoordId y);
void GenMove(GoCoordId &x, GoCoordId &y);
void GenMove(GoCoordId &x, GoCoordId &y, std::vector<int> &visit_count, float &v_resign);
bool Undo();
const GoState &GetBoard();
MCTSConfig &GetConfig();
void SetPendingConfig(std::unique_ptr<MCTSConfig> config);
Expand Down
25 changes: 13 additions & 12 deletions mcts/mcts_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,6 @@ void ReloadConfig(MCTSEngine &engine, const std::string &config_path)
}
}

void InitMoves(MCTSEngine &engine, const std::string &moves)
{
for (size_t i = 0; i < moves.size(); i += 3) {
GoCoordId x, y;
GoFunction::StrToCoord(moves.substr(i, 2), x, y);
engine.Move(x, y);
}
}

std::string EncodeMove(GoCoordId x, GoCoordId y)
{
if (GoFunction::IsPass(x, y)) {
Expand Down Expand Up @@ -148,7 +139,7 @@ std::pair<bool, std::string> GTPExecute(MCTSEngine &engine, const std::string &c
return {true, "2"};
}
if (op == "list_commands") {
return {true, "name\nversion\nprotocol_version\nlist_commands\nquit\nclear_board\nboardsize\nkomi\ntime_settings\ntime_left\nplace_free_handicap\nset_free_handicap\nplay\ngenmove\nfinal_score\nget_debug_info\nget_last_move_debug_info"};
return {true, "name\nversion\nprotocol_version\nlist_commands\nquit\nclear_board\nboardsize\nkomi\ntime_settings\ntime_left\nplace_free_handicap\nset_free_handicap\nplay\ngenmove\nfinal_score\nget_debug_info\nget_last_move_debug_info\nundo"};
}
if (op == "quit") {
return {true, ""};
Expand Down Expand Up @@ -259,14 +250,22 @@ std::pair<bool, std::string> GTPExecute(MCTSEngine &engine, const std::string &c
if (op == "get_last_move_debug_info") {
return {true, engine.GetDebugger().GetLastMoveDebugStr()};
}
if (op == "undo") {
if (!engine.Undo()) {
return {false, "stack empty"};
}
return {true, ""};
}
LOG(ERROR) << "invalid op: " << op;
return {false, "unknown command"};
}

void GTPServing(std::istream &in, std::ostream &out)
{
auto engine = InitEngine(FLAGS_config_path);
InitMoves(*engine, FLAGS_init_moves);
if (FLAGS_init_moves.size()) {
engine->Reset(FLAGS_init_moves);
}
std::cerr << std::flush;

int id;
Expand Down Expand Up @@ -306,7 +305,9 @@ void GTPServing(std::istream &in, std::ostream &out)
void GenMoveOnce()
{
auto engine = InitEngine(FLAGS_config_path);
InitMoves(*engine, FLAGS_init_moves);
if (FLAGS_init_moves.size()) {
engine->Reset(FLAGS_init_moves);
}

GoCoordId x, y;
engine->GenMove(x, y);
Expand Down

0 comments on commit 55dba6b

Please sign in to comment.