In [13]:
%%time
# Compiling HexState
from Numba_hex_class import HexState, _simulate
_simulate(10)

Wall time: 1.99 ms


In [14]:
from settings import RAVE_constants
from random import randint
from math import sqrt, log
from numba import njit, prange
from numba import int64, deferred_type, optional, types, float64, typed
from numba.experimental import jitclass
import numpy as np



In [15]:
%%time
@njit(parallel=True)
def run_sims():
    for x0 in prange(100):
        _simulate(100)
        
run_sims()

Wall time: 3.8 s


In [16]:
# Multi processed
%time run_sims()

Wall time: 464 ms


In [17]:
# Not multi-processed
%time _simulate(10000)

Wall time: 920 ms


In [18]:

Node_type = deferred_type()

spec = (
    
    ('parent', optional(Node_type[:])),
    ('move', int64),
    ('children', optional(Node_type[:])),

    ('N', int64),
    ('Q', int64),
    ('N_rave', int64),
    ('Q_rave', int64),
    
)


@jitclass(spec)
class Node:
    def __init__(self, move, parent):
        self.parent = parent
        self.move = move
        self.children = np.zeros(0, dtype=Node_type)

        self.N = 0
        self.Q = 0
        self.N_rave = 0
        self.Q_rave = 0

    def value(self):
        if self.N == 0:
            return float64(1000)
        
        rave_weight = max(0, 1 - (self.N/RAVE_constants.rave_const))
        UCT_value = self.Q/self.N + RAVE_constants.explore * \
            sqrt(2 * log(self.parent.N/self.N))
        
        if self.N_rave != 0:
            rave_value = self.Q_rave/self.N_rave
        else: 
            rave_value = 0

        value = (1 - rave_weight)*UCT_value + rave_weight*rave_value
        return float64(value)

    def get_N(self):
        return self.N
    
    def get_move(self):
        return self.move
    
    def get_children(self):
        return self.children
    
    def get_parent(self):
        return self.parent
    
    def get_move(self):
        return self.move

    def add_children(self, new_children):
        self.children = np.append(self.children, new_children)
        
    def addon_rave(self, q,n):
        self.Q_rave += q
        self.N_rave += n
        
    def addon_mcts(self, q, n):
        self.Q += q
        self.N += n
        


In [19]:

@njit
def expand(parent: Node, state):
    if state.winner() != 0:
        return 0

    new_children = np.zeros(0, dtype=Node_type)
    for move in state.possible_moves():
        new_children.append(Node(move, np.array([parent])))

    parent.add_children(new_children)
    return 1


In [20]:

@njit
def fetch_leaf_node(root_node, root_state):
    node = root_node
    state = root_state.copy()

    while node.get_children().size != 0:
        benchmark = float64(-1000)
        for child in node.children:
            value_of_child = child.value()
            if value_of_child > benchmark:
                selected_node = child
                benchmark = value_of_child
        node = selected_node
        move = node.get_move()
        state.step(move)

        if node.get_N() == 0:
            return node, state

    if expand(node, state) == 1:
        children = node.get_children()
        child_index = randint(0, children.size-1)
        node = children[child_index]
        state.step(node.get_move())

    return node, state



In [21]:

@njit
def rollout(state):
    moves = state.possible_moves()
    while state.winner() != 0:
        move_index = randint(0, moves.size-1)
        move = moves[move_index]
        state.step(move)
        moves = np.delete(moves, move_index)

    blk_rave_pieces = np.zeros(0, dtype=np.int64)
    wht_rave_pieces = np.zeros(0, dtype=np.int64)

    board = state.get_board()
    blk_rave_pieces.append(np.where(board == 1)[0])
    wht_rave_pieces.append(np.where(board == -1)[0])

    return state.winner(), blk_rave_pieces, wht_rave_pieces


In [22]:
@njit
def backup(outcome, turn, node, blk_rave_pieces, wht_rave_pieces ):
    
    if outcome == turn:
        reward = -1
    else:
        reward = 1
        
    while node.get_parent().size != 0:
        if turn == 1:
            for child in node.get_children():
                child_move = child.get_move()
                if child_move in blk_rave_pieces:
                    child.addon_rave(-reward, 1)
        else:
            for child in node.get_children():
                child_move = child.get_move()
                if child_move in wht_rave_pieces:
                    child.addon_rave(-reward, 1)
                    
        node.addon_mcts(reward, 1)
        turn *= -1
        reward *= -1
        node = node.get_parent()[0]

In [23]:

@njit
def fetch_best_move(state):
    root_state = state.copy()
    root_node = Node(100, np.zeros(0, dtype= Node_type))
    
    num_simulation = 0
    while num_simulation >10:
        node, state = fetch_leaf_node(root_node,root_state)
        turn = state.get_to_play()
        
        winner, blk_rave_pieces, wht_rave_pieces = rollout(state)
        backup(winner, turn, node, blk_rave_pieces, wht_rave_pieces)
        num_simulation +=1

In [24]:

parent_type = (types.int64, types.int64)
rank_type = (types.int64, types.int64)
groups_type = (types.int64, types.int64[:])
size = 6

brd = np.zeros(size**2, dtype=np.int64)
to_play = 1
EDGE_START = 1000
EDGE_FINISH = -1000
blk_parent = typed.Dict.empty(*parent_type)
blk_rank = typed.Dict.empty(*rank_type)
blk_groups = typed.Dict.empty(*groups_type)
wht_parent = typed.Dict.empty(*parent_type)
wht_rank = typed.Dict.empty(*rank_type)
wht_groups = typed.Dict.empty(*groups_type)


board = HexState(size, brd, to_play, EDGE_START, EDGE_FINISH, blk_parent, blk_rank, blk_groups, wht_parent, wht_rank, wht_groups)

In [25]:
fetch_best_move(board)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
[1mUntyped global name 'backup':[0m [1m[1mCannot determine Numba type of <class 'function'>[0m
[1m
File "..\..\..\..\HARSHK~1\AppData\Local\Temp\ipykernel_8676\2628448253.py", line 12:[0m
[1m<source missing, REPL/exec in use?>[0m
[0m