In [1]:
%%time
# Compiling HexState
from Numba_hex_class import HexState, _simulate
_simulate(10)

Wall time: 10.3 s


In [3]:
from random import randint
from math import sqrt, log
from numba import njit, prange
from numba import int64, deferred_type, optional, types, float64, typed
from numba.experimental import jitclass
import numpy as np


In [4]:
%%time
@njit(parallel=True)
def run_sims():
    for x0 in prange(100):
        _simulate(100)
        
run_sims()
# Multi processed
%time run_sims()

# Not multi-processed
%time _simulate(10000)

Wall time: 474 ms
Wall time: 1.02 s
Wall time: 7.13 s


In [44]:
#NODE CLASS
Node_type = deferred_type()

spec = (

    ('parent', optional(Node_type)),
    ('move', int64),
    ('children', int64[:]),

    ('N', int64),
    ('Q', int64),
    ('N_rave', int64),
    ('Q_rave', int64),

)


@jitclass(spec)
class Node:
    def __init__(self, parent, move):
        self.parent = parent
        self.move = move
        self.children = np.zeros(0, dtype=np.int64)

        self.N = 0
        self.Q = 0
        self.N_rave = 0
        self.Q_rave = 0

    def value(self):
        if self.N == 0: return float64(1000)

        rave_weight = max(0, 1 - (self.N/300))
        UCT_value = self.Q/self.N + 0.5 * sqrt(2 * log(self.parent.N/self.N))
        if self.N_rave != 0: rave_value = self.Q_rave/self.N_rave
        else: rave_value = 0
        value = (1 - rave_weight)*UCT_value + rave_weight*rave_value
        
        return float64(value)

    def set_children(self, new_children):
        self.children = np.append(self.children, new_children)

    def get_N(self):
        return self.N

    def get_children(self):
        return self.children

    def get_move(self):
        return self.move
    
    def get_parent(self):
        return self.parent

    def addon_rave(self, q, n):
        self.Q_rave += q
        self.N_rave += n

    def addon_mcts(self, q, n):
        self.Q += q
        self.N += n
    
    def get_stats(self):
        print(f'N:{self.N}, Q:{self.Q}, Q_r:{self.Q_rave}, N_r:{self.N_rave}, M:{self.move}')


Node_type.define(Node.class_type.instance_type)


In [45]:
#EXPAND
@njit
def expand(parent: Node, state, mem, mem_address):
    if state.winner() != 0:
        return (0, mem_address)

    possible_moves = state.possible_moves()
    children = np.zeros_like(possible_moves)
        
    for index, move in enumerate(possible_moves):
        child_node = Node(parent, move)
        mem_address += 1
        mem[mem_address] = child_node
        children[index] = mem_address

    parent.set_children(children)
    return (1, mem_address)


In [46]:
#lead node
@njit
def leaf_node(root_node : Node, root_state, mem, mem_addrs):
    
    node = root_node
    state = root_state
    
    while node.get_children().size != 0:
        benchmark = float64(-1000)
        for child_mem_addrs in node.get_children():
            child = mem[child_mem_addrs]
            value_of_child = child.value()
            if value_of_child > benchmark:
                selected_node = child
                benchmark = value_of_child
        node = selected_node
        move = node.get_move()
        state.step(move)


        if node.get_N() == 0:

            return (node, state, mem_addrs)

    bool_expand, mem_addrs_new = expand(node, state,mem, mem_addrs)
    if bool_expand:
        children_addrs = node.get_children()
        child_index = randint(0, children_addrs.size-1)
        child_addrs = children_addrs[child_index]
        node = mem[child_addrs]
        state.step(node.get_move())

    return (node, state, mem_addrs_new)


In [47]:
#rollout
@njit
def rollout(state):
    moves = state.possible_moves()
    while state.winner() != 0:
        move_index = randint(0, moves.size-1)
        move = moves[move_index]
        state.step(move)
        moves = np.delete(moves, move_index)

    blk_rave_pieces = np.zeros(0, dtype=np.int64)
    wht_rave_pieces = np.zeros(0, dtype=np.int64)

    board = state.get_board()
    blk_rave_pieces = np.append(blk_rave_pieces,np.where(board == 1)[0] )
    wht_rave_pieces = np.append(wht_rave_pieces,np.where(board == -1)[0] )

    return (state.winner(), blk_rave_pieces, wht_rave_pieces)


In [48]:
#BACKUP
@njit
def backup(outcome, turn, node, blk_rave_pieces, wht_rave_pieces, mem ):
    
    if outcome == turn:
        reward = -1
    else:
        reward = 1
        
    while node is not None:
        if turn == 1:
            for child_addrs in node.get_children():
                child  = mem[child_addrs]
                child_move = child.get_move()
                if child_move in blk_rave_pieces:
                    child.addon_rave(-reward, 1)
        else:
            for child_addrs in node.get_children():
                child  = mem[child_addrs]
                child_move = child.get_move()
                if child_move in wht_rave_pieces:
                    child.addon_rave(-reward, 1)
                    
        node.addon_mcts(reward, 1)

        turn *= -1
        reward *= -1
        node = node.get_parent()

In [49]:
#fetch_best_move
@njit
def fetch_best_move(state, limit):
    memory = {}
    memory_address = 0
    
    root_state = state.copy()
    root_node = Node(None, 1000)
    memory[memory_address] = root_node
    
    num_simulation = 0
    
    while num_simulation < limit:
        state_copy = root_state.copy()
        
        node, new_state, memory_address = leaf_node(root_node,state_copy,memory, memory_address)
        turn = new_state.get_to_play()
        winner, blk_rave_pieces, wht_rave_pieces = rollout(new_state)
        
        backup(winner, turn, node, blk_rave_pieces, wht_rave_pieces, memory)
        
        num_simulation +=1
        
    i  = 0
    while i<36:
        n = memory[i]
        i +=1
        memory[i].get_stats()

    

In [50]:
#SETTING UP EMPTY BOARD
parent_type = (types.int64, types.int64)
rank_type = (types.int64, types.int64)
groups_type = (types.int64, types.int64[:])
size = 6

brd = np.zeros(size**2, dtype=np.int64)
to_play = 1
EDGE_START = 1000
EDGE_FINISH = -1000
blk_parent = typed.Dict.empty(*parent_type)
blk_rank = typed.Dict.empty(*rank_type)
blk_groups = typed.Dict.empty(*groups_type)
wht_parent = typed.Dict.empty(*parent_type)
wht_rank = typed.Dict.empty(*rank_type)
wht_groups = typed.Dict.empty(*groups_type)


board = HexState(size, brd, to_play, EDGE_START, EDGE_FINISH, blk_parent, blk_rank, blk_groups, wht_parent, wht_rank, wht_groups)

In [54]:
fetch_best_move(board, 100000)

N:33715, Q:-423, Q_r:1126, N_r:50332, M:0
N:31817, Q:-435, Q_r:2648, N_r:49294, M:1
N:33708, Q:-426, Q_r:3181, N_r:53019, M:2
N:23, Q:-21, Q_r:7648, N_r:27818, M:3
N:23, Q:-21, Q_r:4875, N_r:48667, M:4
N:23, Q:-21, Q_r:969, N_r:2553, M:5
N:23, Q:-21, Q_r:1629, N_r:4913, M:6
N:23, Q:-21, Q_r:690, N_r:3834, M:7
N:23, Q:-21, Q_r:673, N_r:3829, M:8
N:23, Q:-21, Q_r:675, N_r:3827, M:9
N:23, Q:-21, Q_r:680, N_r:3832, M:10
N:23, Q:-21, Q_r:668, N_r:3830, M:11
N:23, Q:-21, Q_r:628, N_r:3774, M:12
N:23, Q:-21, Q_r:603, N_r:3747, M:13
N:23, Q:-21, Q_r:597, N_r:3747, M:14
N:23, Q:-21, Q_r:582, N_r:3742, M:15
N:23, Q:-21, Q_r:580, N_r:3720, M:16
N:23, Q:-21, Q_r:576, N_r:3716, M:17
N:23, Q:-21, Q_r:570, N_r:3714, M:18
N:23, Q:-21, Q_r:572, N_r:3714, M:19
N:24, Q:-22, Q_r:534, N_r:3672, M:20
N:23, Q:-21, Q_r:466, N_r:3608, M:21
N:23, Q:-21, Q_r:407, N_r:3551, M:22
N:23, Q:-21, Q_r:338, N_r:3480, M:23
N:23, Q:-21, Q_r:189, N_r:3333, M:24
N:23, Q:-21, Q_r:-231, N_r:2907, M:25
N:23, Q:-21, Q_r:-717, N

In [27]:
%time a = fetch_best_move(board, 100)

-0.3200734158596588
-0.3200734158596588
-0.3200734158596588
-0.3200734158596588
-0.3200734158596588
Wall time: 3.99 ms


In [26]:
%time a = fetch_best_move(board, 100)

-0.32009218931682176
-0.32009218931682176
-0.32009218931682176
-0.32009218931682176
-0.32009218931682176
Wall time: 2.99 ms
