# In [6]:
import numpy as np

import gym
from gym.spaces import Discrete, Box

from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten
from keras.optimizers import Adam

from rl.agents.dqn import DQNAgent
from rl.policy import BoltzmannQPolicy
from rl.memory import SequentialMemory
from rl.callbacks import Callback


class TicTacToe:
    """Square noughts-and-crosses board with win/tie detection.

    The board is a 2d numpy array indexed as ``board_state[y][x]`` (row, then
    column). Cell values: 0 = empty, 1 = cross, 2 = nought. Crosses move
    first (see ``get_turn``).
    """

    def __init__(self):
        # No board until set_state() is called.
        self.board_state = None

    def set_state(self, new_state):
        """ 2d array of cell positions of the board. 0 = cell not occupied,
            1 = cross occupies cell, 2 = nought occupies cell.
            Example: [
                [0, 0, 1],
                [0, 0, 2],
                [0, 0, 0]
            ]

            Returns the stored numpy array.
            Raises AssertionError if the array is not square. """

        new_state = np.array(new_state)

        size = len(new_state)
        # Raised explicitly (instead of `assert`) so the check survives
        # `python -O`; the exception type is unchanged for callers.
        if new_state.shape != (size, size):
            raise AssertionError(
                "board must be square, got shape %s" % (new_state.shape,))

        self.board_state = new_state

        return self.board_state

    def is_finished(self):
        """ 0 = not finished, 1 = cross win, 2 = nought win, 3 = tie """

        size = self.get_board_size()

        # Every line that wins the game when filled by a single player.
        # Adapted from https://codereview.stackexchange.com/a/24775
        positions_groups = (
            [[(a, b) for b in range(size)] for a in range(size)] +  # lines along the second index
            [[(a, b) for a in range(size)] for b in range(size)] +  # lines along the first index
            [[(d, d) for d in range(size)]] +  # main diagonal
            [[(size - 1 - d, d) for d in range(size)]]  # anti-diagonal, for any board size
        )

        # Winners are checked BEFORE the tie test: a winning move that also
        # fills the last empty cell is a win, not a tie.
        for positions in positions_groups:
            values = [self.board_state[a][b] for (a, b) in positions]
            if values[0] and len(set(values)) == 1:
                return int(values[0])

        # No winner and no empty cell left: tie.
        if np.count_nonzero(self.board_state) == self.board_state.size:
            return 3

        # Game isn't finished
        return 0

    def get_board_size(self):
        """ Returns the side length of the board """
        return len(self.board_state)

    def get_turn(self):
        """ Returns 1 for crosses turn, 2 for noughts turn """

        crosses = np.count_nonzero(self.board_state == 1)
        noughts = np.count_nonzero(self.board_state == 2)

        # Crosses move first, so noughts play only when crosses are ahead.
        return 2 if crosses > noughts else 1

    def make_move(self, x, y):
        """ Updates the state with the requested new occupied cell.

            x is the column and y the row (the cell is board_state[y][x]).
            The mark placed is whichever player get_turn() reports.
            Returns the new board array.
            Raises AssertionError when the cell is off the board, already
            occupied, or the game is already finished. """
        x = int(x)
        y = int(y)

        size = self.get_board_size()

        # Sanity checks. Explicit raises keep them active under `python -O`;
        # AssertionError is kept because callers catch exactly that type.
        # The lower bound stops negative indexes from wrapping around.
        if not (0 <= x < size and 0 <= y < size):
            raise AssertionError("cell (%d, %d) is outside the board" % (x, y))
        if self.board_state[y][x] != 0:
            raise AssertionError("cell (%d, %d) is already occupied" % (x, y))
        if self.is_finished() != 0:
            raise AssertionError("the game is already finished")

        new_state = self.board_state.copy()
        new_state[y][x] = self.get_turn()

        return self.set_state(new_state)

    @staticmethod
    def translate_position_to_xy(position, board_size=3):
        """ Takes a single number and maps it to x, y coordinates.
            Example: 8 = 2, 2 for a board_size of 3 """

        x = position % board_size
        # Floor division: plain `/` yields a float row index on Python 3.
        y = position // board_size

        return x, y


class TicTacToeEnv:
    """Gym-style environment in which the agent plays noughts (2).

    A uniformly random opponent plays crosses (1): it makes the opening move
    in ``reset`` and replies after every agent move in ``step``.

    When ``predict_for`` is given, the environment is a one-shot wrapper:
    ``reset`` returns that board and ``step`` ends the episode immediately
    without touching the board.
    """

    # Class-level default for the standard 3x3 board, kept for code that
    # reads TicTacToeEnv.action_space; instances override it in __init__.
    action_space = Discrete(3**2)

    def __init__(self, board_size=3, predict_for=None):
        """board_size: side length of the board.
        predict_for: optional fixed board state (2d array) to start from."""
        self.board_size = board_size
        self.predict_for = predict_for

        cell_count = self.board_size ** 2

        # One action per cell, so the action space follows board_size just
        # like the observation space does.
        self.action_space = Discrete(cell_count)
        self.observation_space = Box(
            low=np.array([0] * cell_count),
            high=np.array([2] * cell_count)
        )

    def reset(self):
        """Start a new episode and return the flattened board.

        In prediction mode the board is ``predict_for``; otherwise it is an
        empty board on which the random opponent has made the first move.
        """
        self.tictactoe = TicTacToe()

        if self.predict_for is not None:
            self.tictactoe.set_state(self.predict_for)
            return self.tictactoe.board_state.flatten()

        # Empty board sized from board_size (previously hard-coded to 3x3).
        self.tictactoe.set_state(
            np.zeros((self.board_size, self.board_size), dtype=int))

        move = self._get_random_move()
        self.tictactoe.make_move(move[0], move[1])

        return self.tictactoe.board_state.flatten()

    def step(self, action):
        """Play the agent's move, then the random opponent's reply.

        action: cell index in [0, board_size ** 2).
        Returns (flattened board, reward, done, info) where reward is
        +1 for an agent win, -1 for an illegal move or an opponent win,
        and 0 otherwise (tie or game still running).
        """
        if self.predict_for is not None:
            return self.tictactoe.board_state.flatten(), 0, True, {}

        # Decode with the actual board size, not the 3x3 default.
        translated_action = TicTacToe.translate_position_to_xy(
            action, self.board_size)

        try:
            self.tictactoe.make_move(translated_action[0], translated_action[1])

        except AssertionError:
            # Illegal move (occupied cell, off board, finished game):
            # punish and end the episode.
            return self.tictactoe.board_state.flatten(), -1, True, {}

        reward = 0
        done = False
        winner = self.tictactoe.is_finished()
        if winner == 0:
            # Game still running: the opponent replies.
            move = self._get_random_move()
            self.tictactoe.make_move(move[0], move[1])

            next_winner = self.tictactoe.is_finished()
            if next_winner == 1:
                reward = -1
                done = True
            elif next_winner == 3:
                reward = 0
                done = True

        elif winner == 2:
            reward = 1
            done = True

        elif winner == 3:
            reward = 0
            done = True

        return self.tictactoe.board_state.flatten(), reward, done, {}

    def _get_random_move(self):
        """Return a uniformly random empty cell as an (x, y) tuple.

        Raises AssertionError if the game is already finished.
        """
        assert self.tictactoe.is_finished() == 0

        positions = []
        for x in range(self.board_size):
            for y in range(self.board_size):
                if self.tictactoe.board_state[y][x] == 0:
                    positions.append((x, y))

        return positions[np.random.choice(len(positions), 1)[0]]

    
class ModelIntervalCheckpoint(Callback):
    """Callback that saves model weights when the mean step reward improves.

    Rewards are collected on every step. Each time the internal step counter
    is a multiple of ``interval``, the mean of the collected rewards is
    compared with the best mean seen so far; a strictly better mean triggers
    a weight save. The collected rewards are then cleared.
    """

    def __init__(self, interval, verbose=0):
        """interval: number of steps between checkpoint evaluations.
        verbose: accepted but not used by this class."""
        super(ModelIntervalCheckpoint, self).__init__()
        self.rewards = []
        self.last_max = -1
        self.step = 0
        self.interval = interval

    def reset(self):
        """Discard the rewards collected so far."""
        self.rewards = []

    def on_step_begin(self, step, logs):
        """Evaluate, and possibly checkpoint, at the start of each interval."""
        if self.step % self.interval != 0:
            return

        if self.rewards:
            average = np.nanmean(self.rewards, axis=0)
            if average > self.last_max:
                # The mean reward itself becomes the file name.
                target = 'saved-weights/%s.h5f' % average
                print("\nSaving model checkpoint with mean reward %s to %s" % (average, target))

                self.model.save_weights(target, overwrite=True)

                self.last_max = average

        # Start a fresh reward window for the next interval.
        self.reset()

    def on_step_end(self, step, logs={}):
        """Record the reward of the finished step and advance the counter."""
        self.rewards.append(logs['reward'])
        self.step = self.step + 1
        
def predict(board_state, model_path):
    """Return the action the saved agent picks for the given board.

    board_state: 2d board array passed to the environment as ``predict_for``.
    model_path: path of the weights file loaded into a freshly built agent.
    """
    # Prediction-mode environment: a single episode of a single step.
    environment = TicTacToeEnv(predict_for=board_state)

    agent = build_dqn(environment)
    agent.load_weights(model_path)

    # Running one test episode makes the agent choose exactly one action,
    # which is then read back from `recent_action`.
    agent.test(environment, nb_episodes=1, visualize=False, verbose=0)
    chosen_action = agent.recent_action

    return chosen_action


def build_dqn(env):
    """Build and compile the DQN agent for the given environment.

    env: environment exposing ``action_space.n`` and
        ``observation_space.shape``.
    Returns a compiled ``DQNAgent`` whose network is three ReLU hidden
    layers (128, 64, 32 units) followed by a linear output with one unit
    per action.
    """
    nb_actions = env.action_space.n

    model = Sequential()
    # The leading 1 matches window_length=1 of the replay memory below.
    model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
    # Dense and Activation stay separate layers so the layer sequence is
    # the same as before.
    for units in (128, 64, 32):
        model.add(Dense(units))
        model.add(Activation('relu'))
    model.add(Dense(nb_actions, activation='linear'))

    memory = SequentialMemory(limit=5000000, window_length=1)
    policy = BoltzmannQPolicy()

    dqn = DQNAgent(
        model=model,
        nb_actions=nb_actions,
        memory=memory,
        nb_steps_warmup=1000,
        enable_dueling_network=False,
        target_model_update=1e-2,
        policy=policy
    )

    dqn.compile(Adam(lr=1e-5), metrics=['accuracy', 'mae'])

    return dqn


# Training script: build the agent for a fresh environment and train it,
# checkpointing the weights whenever the mean reward per interval improves.
env = TicTacToeEnv()

nb_actions = env.action_space.n

# build_dqn() constructs the same network, memory, policy and compiled agent
# that used to be duplicated here line for line.
dqn = build_dqn(env)

# Kept as module-level names in case later cells refer to them.
model = dqn.model
memory = dqn.memory
policy = dqn.policy

log_interval = 10000

dqn.fit(env, nb_steps=50000, visualize=False, verbose=1,
    callbacks=[ModelIntervalCheckpoint(interval=log_interval)],
    log_interval=log_interval
)

# In [8]:
# Board to query the trained agent with: crosses (1) hold the top-left and
# centre cells, every other cell is empty (0).
board_state = np.array(
    [
        [1, 0, 0],
        [0, 1, 0],
        [0, 0, 0],
    ]
)
# Load the saved weights and ask the agent for its move on this board.
predict(board_state, model_path='saved-weights/-0.2235.h5f')

# Output of the cell above: 6