In [2]:
from __future__ import division, absolute_import, print_function

In [1]:
import numpy as np
from wgomoku import UCT_Node, PolicyAdapter, GomokuEnvironment, GomokuState

# MCTS with UCB and policy support

In [2]:
def uct_search(game_state, policy, num_reads, verbose=0, rollout_delay=0):
    """
    """
    def _is_terminal(value):
        return value == 1 or value == -1
    
    def _choose_from(distr):
        return np.random.choice(range(5), 1, p=distr)[0]
    
    count = [0,0,0]
    root = UCT_Node(game_state)
    for read_count in range(num_reads):

        # UCB-driven selection
        leaf = root.select_leaf()
        
        # policy-advised rollout until terminal state
        leaf_priors, _,  value_estimate = policy.evaluate(leaf.game_state)

        priors = leaf_priors
        game = leaf.game_state
        
        if read_count > rollout_delay:
            while value_estimate == 0:
                move = _choose_from(priors)
                game = game.play(move)
                priors, _,  value_estimate = policy.evaluate(game, view_point = leaf.game_state.to_play )
        
        count[value_estimate+1] += 1

        if _is_terminal(value_estimate):
            if verbose > 1:
                print(leaf, leaf.game_state)
            leaf.backup(value_estimate)
        else:
            # Only expand non-terminal states
            leaf.expand(leaf_priors)
            
    if verbose > 0:
        print("Counts: %s" % count)
    move, _ = max(root.children.items(), 
               key = lambda item: item[1].number_visits)
    return root, move

In [3]:
from copy import deepcopy
class SillyGame: 
    
    def __init__(self, board):
        self.legal_actions = [1,2,3,4]
        self.n_actions = len(self.legal_actions)
        self.total = 0
        self.to_play = 1
        self.board=board

    def play(self, move):     
        if self.total + move >= len(self.board):            
            print(self)
            raise ValueError("Beyond boundary: Pos=%s, Move=%s" % (self.total,move))
        if not move in self.legal_actions:
            raise ValueError("Not a legal move: %s" %move)
        cp = deepcopy(self)
        cp.to_play *= -1
        cp.total += move
        return cp
    
    def __repr__(self):
        stone = 'x' if self.to_play==1 else 'o'
        board = self.board[:self.total] + stone + self.board[self.total+1:]
        return ("total: %02d, %s" % (self.total, board))
    

In [4]:
class SillyPolicy:

    def __init__(self, lookahead=False):
        self.lookahead = lookahead
    
    def _is_done(self, board, total):
        return board[total] == 'I'
            
    def evaluate(self, game, view_point=1):
        """
        The game is over, once any player has stepped on a mine ('I'). The other
        (obviously) has won, then. If view_point==1 (black), evaluate returns +1,
        if the next move (to_play) would be black: i.e. white stepped on the mine.
        """
        if self._is_done(game.board, game.total):
            # the opponent (previous player) lost -> I won
            return [0,0,0,0,0], 0, view_point * game.to_play
        else: 
            distr = [0,0,0,0,0]
            array = [1,2,3,4]
            np.random.shuffle(array)
            if not self.lookahead:
                return [0., .25,  .25,  .25,  .25], array[0], 0
            for n in array:
                
                #
                #  TODO: This logic is ill. It'll default to n=1 unless that's fatal
                #  then 2 unless that's fatal, etc.
                #
                if not self._is_done(game.board, game.total+n):
                    distr[n] = 1
                    return distr, n, 0
            # Can't move -> from the viewpoint of the one to play: Bad!
            return distr, 0, - view_point * game.to_play
                    

In [5]:
p = SillyPolicy()
g = SillyGame('..-...I-.II.-....-III....-....-IIII')
print(g)
p._is_done(g.board, 6), p._is_done(g.board, 7)

total: 00, x.-...I-.II.-....-III....-....-IIII


(True, False)

In [6]:
p = SillyPolicy()
g = SillyGame('..-...I-.II.-....-III....-....-IIII')

print(g.to_play)
_, choice, state = p.evaluate(g)
print(g, state)

g = g.play(1)
print(g.to_play)
_, choice, state = p.evaluate(g)
print(g, state)

g = g.play(1)
print(g.to_play)
_, choice, state = p.evaluate(g)
print(g, state)

g = g.play(4)
print(g.to_play)
_, choice, state = p.evaluate(g)
print(g, state)

1
total: 00, x.-...I-.II.-....-III....-....-IIII 0
-1
total: 01, .o-...I-.II.-....-III....-....-IIII 0
1
total: 02, ..x...I-.II.-....-III....-....-IIII 0
-1
total: 06, ..-...o-.II.-....-III....-....-IIII -1


#### A Game with a lookahead policy
The policy's ```evaluate()``` method returns the value of the board from black's point of view,
if not specified otherwise. I.e.: Here, a value of $-1$ means "white wins".

In [7]:
p = SillyPolicy(lookahead=True)
g = SillyGame('..-...I-.II.-....-III....-....-IIII')
choice = 2

while choice != 0:
    g = g.play(choice)
    _, choice, value = p.evaluate(g)
    print(g, choice, value)
    
print(("white" if value == -1 else "black") + " wins.")

total: 02, ..o...I-.II.-....-III....-....-IIII 2 0
total: 04, ..-.x.I-.II.-....-III....-....-IIII 1 0
total: 05, ..-..oI-.II.-....-III....-....-IIII 2 0
total: 07, ..-...Ix.II.-....-III....-....-IIII 1 0
total: 08, ..-...I-oII.-....-III....-....-IIII 4 0
total: 12, ..-...I-.II.x....-III....-....-IIII 4 0
total: 16, ..-...I-.II.-...o-III....-....-IIII 1 0
total: 17, ..-...I-.II.-....xIII....-....-IIII 4 0
total: 21, ..-...I-.II.-....-IIIo...-....-IIII 2 0
total: 23, ..-...I-.II.-....-III..x.-....-IIII 2 0
total: 25, ..-...I-.II.-....-III....o....-IIII 1 0
total: 26, ..-...I-.II.-....-III....-x...-IIII 1 0
total: 27, ..-...I-.II.-....-III....-.o..-IIII 3 0
total: 30, ..-...I-.II.-....-III....-....xIIII 0 -1
white wins.


### Monte Carlo Tree Search with UCB

In [9]:
b_sure_win = '....-....-....-IIII'
w_sure_win = '.....-....-....-IIII'
g0 = SillyGame(b_sure_win)
print(g0)
root = UCT_Node(g0, C=0.5)
leaf = root

for move in [4,4,1,2,3,1]:
    leaf.expand([0]+[.25]*4)
    leaf = leaf.children[move]
print(leaf.game_state, leaf)
p.evaluate(leaf.game_state)

total: 00, x...-....-....-IIII


AttributeError: 'list' object has no attribute 'items'

In [11]:
leaf.backup(1)

Please observe that the back propagation comes with alternating signs. A win sequence for one player is bad for the other.

In [12]:
while leaf.parent:
    print(leaf, leaf.total_value)
    leaf = leaf.parent

[None, 4, 4, 1, 2, 3, 1] 1.0
[None, 4, 4, 1, 2, 3] -1.0
[None, 4, 4, 1, 2] 1.0
[None, 4, 4, 1] -1.0
[None, 4, 4] 1.0
[None, 4] -1.0


"[None, 4] -1.0" means: After the first move of 4 fields, the value of the board for the next (to_play == white) player is $-1$.

---

In [13]:
root.game_state.to_play, root.child_Q()

(1, array([ 0.,  0.,  0.,  0., -1.], dtype=float32))

Read this like: Black is next to move and a move of $4$ will leave its opponent a $-1$ valued board. I guess that's a good reason to do just that.

---

In [14]:
root.children[4].child_Q()

array([0., 0., 0., 0., 1.], dtype=float32)

Read this like: If white plays a $4$ that'd be good for black. We know that that holds for any other move, too.

---

### The Rule of 5
The rule of five says: Once on one of the safe positions, always move such that your move adds up to 5 with the opponents move. That brings you to the next safe position. 
Now, with 10000 runs, the tree search eventually *understands* the rule of 5.

In [15]:
g0 = SillyGame(b_sure_win)
root, move = uct_search(g0, p, 10000, verbose=1, rollout_delay=5000)
endpos = root.best_child().best_child().best_child().best_child().best_child()
print(endpos, endpos.game_state)
print(root.child_Q())

Counts: [1125, 1474, 7401]
[None, 4, 2, 3, 4, 1] total: 14, ....-....-....oIIII
[ 0.          0.16852011  0.16846848  0.16853933 -0.95321965]


What (almost always) works with 10000 runs, sometimes doesn't work with less than that. Try it a couple of times with 2000 to verify. The reason for that is that the tree is expanded in some mixed manner that explores breadth early in the search. Thus, it takes quite a number of selections until eventually more and more terminal states are explored. And only those provide useful feedback for the evaluation of the moves.

With a single field added to the board, whatever black does, white can play the winning strategy, because black starting from a safe position must leave it with the first move so that white can occupy the next safe position and so on to win the game.

In [16]:
g0 = SillyGame(w_sure_win)
root, move = uct_search(g0, p, 10000, verbose=1, rollout_delay=5000)
endpos = root.best_child().best_child().best_child().best_child().best_child().best_child()
print(endpos, endpos.game_state)
print(root.child_Q())

Counts: [6859, 2316, 825]
[None, 1, 4, 3, 2, 2, 3] total: 15, .....-....-....xIIII
[0.         0.60344005 0.6034602  0.60348165 0.6034985 ]


You can actually see from the Q-Values of black's first move that whatever black does will leave white with an almost equal positive board value. Observe that the moves add up to 5, pairwise.

### More complex games
Now look at the below complex sure-win for black that even has a deviation for the rule of five. It takes much longer for the tree search to find a winning sequence with confidence. Be warned that 60000 runs will take some time - may be up to a minute - depending on your machine.

In [17]:
complex_b_win='..-...I-.II.-....-III....-....-IIII'

In [19]:
g0 = SillyGame(complex_b_win)
root, move = uct_search(g0, p, 40000, verbose=1, rollout_delay=20000)
endpos = (root.best_child().best_child().best_child().best_child().best_child()
          .best_child().best_child().best_child().best_child().best_child().best_child())
print(endpos, endpos.game_state)
print(root.child_Q())

Counts: [2518, 6079, 31403]
[None, 2, 2, 3, 1, 4, 4, 1, 4, 4, 2, 3] total: 30, ..-...I-.II.-....-III....-....oIIII
[ 0.          0.1        -0.72265166  0.1         0.25      ]


One of the reasons for the search to take so long to converge is the intermediate traps that could be avoided easily with a little foresight. But since our current policy is oblivious of the board situation, it needs a long time before the tree learns to get around the intermediate traps. So what if we're supported by a policy that applies at least a one-move foresight?

In [26]:
forsight_policy = SillyPolicy(lookahead=True)
g0 = SillyGame(complex_b_win)
root, move = uct_search(g0, forsight_policy, 20000, verbose=1, rollout_delay=20000)
endpos = (root.best_child().best_child().best_child().best_child().best_child()
          .best_child().best_child().best_child().best_child().best_child().best_child())
print(endpos, endpos.game_state)
print(root.child_Q())

Counts: [2941, 7040, 10019]
[None, 2, 3, 2, 1, 4, 2, 3, 4, 4, 3, 2] total: 30, ..-...I-.II.-....-III....-....oIIII
[ 0.          0.2        -0.35436893  0.2         0.14285715]


That policy expectedly helped the search converge a bit faster. And that's exactly the clue here: MCTS with UCB can be significantly improved with the help of a reasonable policy instead of just random (Monte Carlo) move selection.