In [1]:
from __future__ import division, absolute_import, print_function

In [2]:
import numpy as np

In [463]:
def uct_search(game_state, policy, num_reads, verbose=0):
    count = [0,0,0]
    root = UCT_Node(game_state)
    for _ in range(num_reads):
        leaf = root.select_leaf()
        child_priors, _,  value_estimate = policy.evaluate(leaf.game_state)
        count[value_estimate+1] += 1
        if value_estimate == 0:
            leaf.expand(child_priors)
        leaf.backup(value_estimate)
    print("Counts: %s" % count)
    move, _ = max(root.children.items(), 
               key = lambda item: item[1].number_visits)
    return root, move

In [500]:
class UCT_Node:
    def __init__(self, game_state, move=None, parent=None, C=1.4):
        self.game_state = game_state
        self.move = move
        self.is_expanded = False
        self.parent = parent
        self.children = {}
        self.child_priors = np.zeros(
            [self.game_state.n_actions+1], dtype=np.float32)
        self.child_total_value =  np.zeros(
            [self.game_state.n_actions+1], dtype=np.float32)
        self.child_number_visits =  np.zeros(
            [self.game_state.n_actions+1], dtype=np.float32)
        self.C = C
        self.penalty = 0

        
    def any_argmax(self, aa):
        """
        stochastic function:
        returns any of the indices that have the maximum.
        """
        import random
        ind = np.argmax(aa)
        m = aa[ind]
        choices = np.where(aa == m)
        return random.choice(np.ndarray.tolist(choices[0]))        

    
    def child_Q(self):
        return self.child_total_value / (1 + self.child_number_visits)

    def child_U(self):
        if self.number_visits == 0: 
            return np.zeros([5], dtype=float)
        else:
            return np.sqrt(np.log(self.number_visits) / 
                           (1 + self.child_number_visits))

        # TODO: This is not reflecting the theory
        return np.sqrt(self.number_visits) * (
            self.child_priors / (1 + self.child_number_visits))

    def best_child(self):
        """
        We are looking for the child that has the *worst* value, because
        the value is always computed as from the point of view of the one
        who's moving next. And the children of this node are all adversary
        moves. In short: Maximize the opponents pain! My move is your pain.
        """
        qus = - self.child_Q() + self.C * self.child_U()
        move = self.any_argmax(qus[1:])+1 # because we're starting at index 1
        if self.children:
            return self.children[move]
        else:
            return None

    
    def select_leaf(self):
        current = self
        while current.is_expanded:
            current.number_visits += 1
            current.total_value -= self.penalty #make it less attractive until backup
            current = current.best_child()
        return current
    
    def expand(self, child_priors):
        self.is_expanded = True
        self.child_priors = child_priors
        
        for move in self.game_state.legal_actions:
            self.add_child(move)
            
    def add_child(self, move):
        self.children[move] = UCT_Node(
            self.game_state.play(move), move=move, parent=self)
    
    def backup(self, value_estimate):
        current = self
        while current.parent is not None:
            upd = value_estimate * current.game_state.to_play + current.penalty
            #print("Updating: %s: %s, %s" % (self, current.game_state.to_play, upd))
            current.total_value += upd
            current = current.parent
            
    def pathid(self):
        leaf = self
        name = [leaf.move]
        while leaf.parent:
            name = [leaf.parent.move] + name
            leaf = leaf.parent
        return name    
            
    def __repr__(self):
        return str(self.pathid())
        
    @property
    def number_visits(self):
        if self.parent:
            return self.parent.child_number_visits[self.move]
        else:
            return 0
    
    @number_visits.setter
    def number_visits(self, value):
        #print("Value: %s" % value)
        #raise ValueError()
        if self.parent:
            self.parent.child_number_visits[self.move] = value
        
    @property
    def total_value(self):
        if self.parent:
            return self.parent.child_total_value[self.move]
        else:
            return 0
    
    @total_value.setter
    def total_value(self, value):
        if self.parent:
            self.parent.child_total_value[self.move] = value

In [501]:
from copy import deepcopy
class SillyGame: 
    
    def __init__(self, board='......I....II.........III............IIII'):
        self.legal_actions = [1,2,3,4]
        self.n_actions = len(self.legal_actions)
        self.total = 0
        self.to_play = 1
        self.board=board

    def play(self, move):     
        if self.total + move >= len(self.board):            
            print(self)
            raise ValueError("Beyond boundary: Pos=%s, Move=%s" % (self.total,move))
        if not move in self.legal_actions:
            raise ValueError("Not a legal move: %s" %move)
        cp = deepcopy(self)
        cp.to_play *= -1
        cp.total += move
        return cp
    
    def __repr__(self):
        stone = 'x' if self.to_play==1 else 'o'
        board = self.board[:self.total] + stone + self.board[self.total+1:]
        return ("total: %02d, %s" % (self.total, board))
    

In [502]:
class SillyPolicy:

    def __init__(self, lookahead=False):
        self.lookahead = lookahead
    
    def _is_done(self, board, total):
        return board[total] == 'I'
            
    def evaluate(self, game, view_point=1):
        if self._is_done(game.board, game.total):
            # the opponent lost -> I won
            return [0,0,0,0,0], 0, view_point * game.to_play
        else: 
            distr = [0,0,0,0,0]
            array = [1,2,3,4]
            np.random.shuffle(array)
            if not self.lookahead:
                return [0., .25,  .25,  .25,  .25], array[0], 0
            for n in array:
                if not self._is_done(game.board, game.total+n):
                    distr[n] = 1
                    return distr, n, 0
            # Can't move -> I make my losing move
            return distr, 0, view_point * game.to_play
                    

In [503]:
p = SillyPolicy()
g = SillyGame()
print(g)
p._is_done(g.board, 6)

total: 00, x.....I....II.........III............IIII


True

In [504]:
p = SillyPolicy()
g = SillyGame()

print(g.to_play)
_, choice, state = p.evaluate(g)
print(g, state)

g.play(1)
print(g.to_play)
_, choice, state = p.evaluate(g)
print(g, state)

g.play(1)
print(g.to_play)
_, choice, state = p.evaluate(g)
print(g, state)

g.play(4)
print(g.to_play)
_, choice, state = p.evaluate(g)
print(g, state)



1
total: 00, x.....I....II.........III............IIII 0
1
total: 00, x.....I....II.........III............IIII 0
1
total: 00, x.....I....II.........III............IIII 0
1
total: 00, x.....I....II.........III............IIII 0


In [505]:
p = SillyPolicy(lookahead=True)
g = SillyGame()
choice = 2

while choice != 0:
    g = g.play(choice)
    _, choice, value = p.evaluate(g)
    print(g, choice, value)
    
print(("white" if value == -1 else "black") + " wins.")

total: 02, ..o...I....II.........III............IIII 1 0
total: 03, ...x..I....II.........III............IIII 2 0
total: 05, .....oI....II.........III............IIII 2 0
total: 07, ......Ix...II.........III............IIII 2 0
total: 09, ......I..o.II.........III............IIII 1 0
total: 10, ......I...xII.........III............IIII 4 0
total: 14, ......I....II.o.......III............IIII 2 0
total: 16, ......I....II...x.....III............IIII 3 0
total: 19, ......I....II......o..III............IIII 2 0
total: 21, ......I....II........xIII............IIII 4 0
total: 25, ......I....II.........IIIo...........IIII 1 0
total: 26, ......I....II.........III.x..........IIII 2 0
total: 28, ......I....II.........III...o........IIII 4 0
total: 32, ......I....II.........III.......x....IIII 1 0
total: 33, ......I....II.........III........o...IIII 3 0
total: 36, ......I....II.........III...........xIIII 0 1
black wins.


In [506]:
p = SillyPolicy(lookahead=False)
g0 = SillyGame('...............IIII')
root, move = uct_search(g0, p, 3)

Counts: [0, 3, 0]


In [507]:
root = UCT_Node(g0)
root.expand([0., .25, .25, .25, .25])
leaf = root.select_leaf()
child_priors, _,  value_estimate = p.evaluate(leaf.game_state)
if value_estimate == 0:
    leaf.expand(child_priors)
leaf.backup(value_estimate)
print("P/V: %s / %s" %(leaf.child_priors, leaf.total_value))
print(leaf, leaf.parent.number_visits)

P/V: [0.0, 0.25, 0.25, 0.25, 0.25] / 0.0
[None, 3] 0


In [508]:
leaf = root.select_leaf()
child_priors, _,  value_estimate = p.evaluate(leaf.game_state)
if value_estimate == 0:
    leaf.expand(child_priors)
leaf.backup(value_estimate)
print("P/V: %s / %s" %(leaf.child_priors, leaf.total_value))
print(leaf.pathid(), leaf.game_state, leaf.parent.number_visits)

P/V: [0.0, 0.25, 0.25, 0.25, 0.25] / 0.0
[None, 4] total: 04, ....o..........IIII 0


In [509]:
leaf = root.select_leaf()
child_priors, _,  value_estimate = p.evaluate(leaf.game_state)
if value_estimate == 0:
    leaf.expand(child_priors)
leaf.backup(value_estimate)
print("P/V: %s / %s" %(leaf.child_priors, leaf.total_value))
print(leaf.pathid(), leaf.game_state)

P/V: [0.0, 0.25, 0.25, 0.25, 0.25] / 0.0
[None, 4, 4] total: 08, ........x......IIII


In [510]:
leaf = root.select_leaf()
child_priors, _,  value_estimate = p.evaluate(leaf.game_state)
if value_estimate == 0:
    leaf.expand(child_priors)
leaf.backup(value_estimate)
print("P/V: %s / %s" %(leaf.child_priors, leaf.total_value))
print(leaf.pathid(), leaf.game_state)

P/V: [0.0, 0.25, 0.25, 0.25, 0.25] / 0.0
[None, 4, 1] total: 05, .....x.........IIII


In [511]:
leaf = root.select_leaf()
child_priors, _,  value_estimate = p.evaluate(leaf.game_state)
if value_estimate == 0:
    leaf.expand(child_priors)
leaf.backup(value_estimate)
print("P/V: %s / %s" %(leaf.child_priors, leaf.total_value))
print(leaf.pathid(), leaf.game_state)

P/V: [0.0, 0.25, 0.25, 0.25, 0.25] / 0.0
[None, 3, 2] total: 05, .....x.........IIII


In [519]:
b_sure_win = '....-....-....-IIII'
w_sure_win = '.....-....-....-IIII'
g0 = SillyGame(b_sure_win)
print(g0)
root = UCT_Node(g0)
leaf = root

for move in [4,4,1,2,3,1]:
    leaf.expand([0]+[.25]*4)
    leaf = leaf.children[move]
print(leaf.game_state, leaf)
p.evaluate(leaf.game_state)


#white_loss = g0.play(4).play(4).play(1).play(2).play(3).play(1)
#print(white_loss)
#print(p.evaluate(white_loss))
#black_loss = g0.play(2).play(2).play(4).play(1).play(2).play(3).play(1)
#print(black_loss)
#print(p.evaluate(black_loss))

total: 00, x...-....-....-IIII
total: 15, ....-....-....-xIII [None, 4, 4, 1, 2, 3, 1]


([0, 0, 0, 0, 0], 0, 1)

In [513]:
leaf.backup(1)

In [514]:
while leaf.parent:
    print(leaf, leaf.total_value)
    leaf = leaf.parent

[None, 4, 4, 1, 2, 3, 1] 1.0
[None, 4, 4, 1, 2, 3] -1.0
[None, 4, 4, 1, 2] 1.0
[None, 4, 4, 1] -1.0
[None, 4, 4] 1.0
[None, 4] -1.0


In [515]:
root.game_state.to_play, root.child_Q()

(1, array([ 0.,  0.,  0.,  0., -1.], dtype=float32))

In [516]:
root.children[4].child_Q()

array([0., 0., 0., 0., 1.], dtype=float32)

In [499]:
root.children[4].children[4].child_Q()

array([ 0., -1.,  0.,  0.,  0.], dtype=float32)

In [544]:
g0 = SillyGame(b_sure_win)
root = UCT_Node(g0)
root, move = uct_search(g0, p, 10000)
move

Counts: [2814, 3556, 3630]


4

In [545]:
root.best_child().best_child().best_child().best_child().best_child()

[None, 4, 1, 4, 1, 4]

In [395]:
move = res[1]
print(move.game_state, move)
print()
while move.best_child():
    move = move.best_child()
    print("Q:      ", move.parent.child_Q()[1:])
    print("U:      ", move.parent.child_U()[1:])
    print("Values: ", move.parent.child_total_value[1:])
    print("Visits: ", move.parent.child_number_visits[1:], 
          move.parent.number_visits)
    print(move.game_state, move)
    print()

total: 03, ...o...........IIII [None, 3]

Q:       [ 0.03539823 -0.03703704 -0.07575758 -0.20512821]
U:       [0.22460277 0.26528448 0.29388833 0.3823156 ]
Values:  [ 4. -3. -5. -8.]
Visits:  [112.  80.  65.  38.] 299.0
total: 04, ....x..........IIII [None, 3, 1]

Q:       [-0.03571429 -0.07692308 -0.03703704  0.        ]
U:       [0.4105092  0.42600554 0.41804212 0.39014053]
Values:  [-1. -2. -1.  0.]
Visits:  [27. 25. 26. 30.] 112.0
total: 07, .......o.......IIII [None, 3, 1, 3]

Q:       [ 0.125       0.14285715  0.125      -0.6666667 ]
U:       [0.6381709  0.68223333 0.6381709  1.0421287 ]
Values:  [ 1.  1.  1. -2.]
Visits:  [7. 6. 7. 2.] 26.0
total: 09, .........x.....IIII [None, 3, 1, 3, 2]

Q:       [ 0.   0.  -0.5  0. ]
U:       [1.3385662  0.94650924 0.94650924 1.3385662 ]
Values:  [ 0.  0. -1.  0.]
Visits:  [0. 1. 1. 0.] 6.0
total: 13, .............o.IIII [None, 3, 1, 3, 2, 4]

Q:       [0. 0. 0. 0.]
U:       [0. 0. 0. 0.]
Values:  [0. 0. 0. 0.]
Visits:  [0. 0. 0. 0.] 0.0
tot

In [106]:
move.parent.child_U()

0

In [61]:
if value_estimate == 0:
    leaf.expand(child_priors)
leaf.backup(value_estimate)

In [65]:
root, leaf

(<__main__.UCT_Node at 0x24e899feef0>, <__main__.UCT_Node at 0x24e899feef0>)

In [None]:
return max(root.children.items(), 
           key = lambda item: item[1].number_visits)