/
generate_mcts_games.py
70 lines (55 loc) · 2.16 KB
/
generate_mcts_games.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import argparse
import numpy as np
from dlgo.encoders import get_encoder_by_name
from dlgo import goboard
from dlgo import mcts
from dlgo.utils import print_board, print_move
def generate_game(board_size, rounds, max_moves, temperature):
    """Let one MCTSAgent play a single game against itself and record it.

    Args:
        board_size: Passed to both the 'oneplane' encoder and
            ``GameState.new_game``.
        rounds: First argument forwarded to ``mcts.MCTSAgent``.
        max_moves: Upper bound on the number of moves applied in this game
            (all moves returned by the bot count, not only board plays).
        temperature: Second argument forwarded to ``mcts.MCTSAgent``.

    Returns:
        A ``(boards, moves)`` pair of NumPy arrays. ``boards`` stacks the
        encoder output for the position *before* each recorded move;
        ``moves`` stacks one-hot vectors of length ``encoder.num_points()``
        marking the point that was played. Only moves with ``is_play`` set
        are recorded, so both arrays are empty when no stone was placed.
    """
    boards, moves = [], []
    encoder = get_encoder_by_name('oneplane', board_size)
    game = goboard.GameState.new_game(board_size)
    bot = mcts.MCTSAgent(rounds, temperature)

    num_moves = 0
    # The cap is checked before each move so that ``max_moves`` is a true
    # upper bound; testing ``num_moves > max_moves`` after the increment
    # let one extra move through.
    while num_moves < max_moves and not game.is_over():
        print_board(game.board)
        move = bot.select_move(game)
        if move.is_play:
            # Encode the position the bot was looking at, i.e. before the
            # move is applied, and pair it with the chosen point.
            boards.append(encoder.encode(game))
            move_one_hot = np.zeros(encoder.num_points())
            move_one_hot[encoder.encode_point(move.point)] = 1
            moves.append(move_one_hot)

        # Non-play moves are still printed and applied; they are just not
        # stored as training samples.
        print_move(game.next_player, move)
        game = game.apply_move(move)
        num_moves += 1

    return np.array(boards), np.array(moves)
def main():
    """Generate several self-play games and save them as two NumPy files.

    Command-line options:
        --board-size / -b: board size handed to ``generate_game`` (default 9).
        --rounds / -r: rounds handed to ``generate_game`` (default 1000).
        --temperature / -t: temperature handed to ``generate_game``
            (default 0.8).
        --max-moves / -m: max moves per game (default 60).
        --num-games / -n: number of games to generate (default 10, must be
            at least 1).
        --board-out: destination passed to ``np.save`` for the features
            (required).
        --move-out: destination passed to ``np.save`` for the labels
            (required).

    Invalid or missing arguments terminate the process through argparse
    before any game is generated.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--board-size', '-b', type=int, default=9)
    parser.add_argument('--rounds', '-r', type=int, default=1000)
    parser.add_argument('--temperature', '-t', type=float, default=0.8)
    parser.add_argument('--max-moves', '-m', type=int, default=60,
                        help='Max moves per game.')
    parser.add_argument('--num-games', '-n', type=int, default=10)
    # Both outputs are required: without them the script would only fail at
    # np.save, after all games had already been generated.
    parser.add_argument('--board-out', required=True)
    parser.add_argument('--move-out', required=True)

    # Parse command line args
    args = parser.parse_args()
    if args.num_games < 1:
        # np.concatenate rejects an empty list, so refuse this up front.
        parser.error('--num-games must be at least 1')

    xs = []
    ys = []
    for i in range(args.num_games):
        print('Generating game %d/%d...' % (i + 1, args.num_games))
        # Generate the game data
        x, y = generate_game(args.board_size, args.rounds,
                             args.max_moves, args.temperature)
        # A game without any recorded play comes back as 1-D empty arrays,
        # which np.concatenate cannot combine with non-empty games.
        if len(x) > 0:
            xs.append(x)
            ys.append(y)

    # Concatenate features and labels after all games are generated; fall
    # back to empty arrays if no game produced a sample.
    x = np.concatenate(xs) if xs else np.array([])
    y = np.concatenate(ys) if ys else np.array([])

    # Store feature and label data
    np.save(args.board_out, x)
    np.save(args.move_out, y)
# Run the generator only when this file is executed as a script.
if "__main__" == __name__:
    main()