In [1]:
from numba import jit, cuda, guvectorize
import numpy

In [2]:
def base_rotate_index(x, y, rotation, board_size):
    """Map board coordinate ``(x, y)`` through ``rotation`` quarter turns.

    One quarter turn sends ``(x, y)`` to ``(y, board_size - x - 1)``.
    Four such turns are the identity, so only ``rotation % 4`` turns are
    applied.

    Parameters
    ----------
    x, y : int
        Coordinate on a square board with ``board_size`` cells per side.
    rotation : int
        Number of quarter turns; reduced modulo 4, so negative values wrap.
    board_size : int
        Side length of the board.

    Returns
    -------
    tuple of int
        The rotated ``(x, y)`` pair.
    """
    for _ in range(rotation % 4):
        x, y = y, board_size - x - 1
    return x, y

In [3]:
def base_do_action(state, action, punishment, simulate):
    """Apply one slide/merge move to a square board and return its score.

    The board is scanned in a frame rotated by ``action`` quarter turns
    (through the module-level jitted ``rotate_index``). In that frame every
    move is a slide toward index 0 along the second axis: a tile at ``y2``
    moves into an empty cell at ``y1 < y2``. Equal tiles merge into
    ``value + 1``, and each cell ``y1`` accepts at most one merge per move.

    Parameters
    ----------
    state : 2-D integer array
        Square board where 0 marks an empty cell. The other values are
        presumably tile exponents: a merge of two ``v`` tiles stores ``v + 1``
        and scores ``2 ** (v + 1)``. Updated in place unless ``simulate``.
    action : int
        Number of quarter-turn rotations that selects the move direction.
    punishment : int
        Returned negated when the move changes nothing.
    simulate : bool
        If true, score the move without modifying ``state``.

    Returns
    -------
    int
        Sum of ``2 ** merged_value`` over all merges, or ``-punishment`` if
        no tile slid or merged.
    """
    # Play the move on a scratch copy so that a simulated move scans exactly
    # the same intermediate board (and earns the same reward) as a real one.
    # Skipping the writes instead left later comparisons reading a stale
    # board, which mis-scored merges in simulate mode.
    board = state.copy()
    size = board.shape[0]
    reward = 0
    update = False
    for x in range(size):
        for y1 in range(board.shape[1] - 1):
            for y2 in range(y1 + 1, board.shape[1]):
                rot_x1, rot_y1 = rotate_index(x, y1, action, size)
                rot_x2, rot_y2 = rotate_index(x, y2, action, size)
                if board[rot_x2, rot_y2] == 0:
                    continue
                elif board[rot_x1, rot_y1] == 0:
                    # Slide into the gap, then keep scanning: the tile that
                    # just moved may still merge with the next one.
                    board[rot_x1, rot_y1] = board[rot_x2, rot_y2]
                    board[rot_x2, rot_y2] = 0
                    update = True
                else:
                    if board[rot_x1, rot_y1] == board[rot_x2, rot_y2]:
                        board[rot_x1, rot_y1] += 1
                        board[rot_x2, rot_y2] = 0
                        reward += 2 ** board[rot_x1, rot_y1]
                        update = True
                    # Either it merged or a different tile blocks y1;
                    # in both cases y1 is settled.
                    break
    if not simulate:
        state[:, :] = board
    return reward if update else -1 * punishment

In [4]:
def base_do_actions(state, action, punishment, simulate, reward):
    """Gufunc kernel: score one board and store the result in ``reward[0]``.

    ``action`` and ``reward`` are indexed at element 0, matching their use as
    ``()``-shaped gufunc arguments declared as 1-D arrays. The work is
    delegated to the module-level jitted ``do_action``.
    """
    rotation = action[0]
    reward[0] = do_action(state, rotation, punishment, simulate)

In [5]:
class Test(object):
    """Holds a CPU gufunc that scores a stack of boards with ``do_action``.

    ``self.do_actions(states, actions, punishment, simulate)`` loops over the
    leading axes of ``states``, treats each trailing ``(n, n)`` slice as one
    uint8 board, and returns one int32 reward per board.

    NOTE: numba resolves ``do_action`` (and, inside it, ``rotate_index``) from
    the module globals of ``base_do_actions`` / ``base_do_action``. Both
    jitted globals must therefore be bound before ``Test()`` is created.
    Names assigned locally in ``__init__`` would not be seen, which is why
    none are assigned here.
    """

    def __init__(self):
        # Passing explicit signatures makes guvectorize compile right here,
        # so every instance builds its own gufunc.
        self.do_actions = guvectorize(
            ['void(u1[:,:], i1[:], i4, b1, i4[:])'],
            '(n,n),(),(),()->()',
            target='cpu',
            nopython=True,
        )(base_do_actions)

In [6]:
# Bind the jitted helpers as module globals *before* building Test: numba
# looks up `rotate_index` (from base_do_action) and `do_action` (from
# base_do_actions) in module scope when Test() compiles the gufunc.
rotate_index = jit(base_rotate_index, nopython=True)
do_action = jit(base_do_action, nopython=True)
x = Test()

array([ 0, -1], dtype=int32)

In [8]:
# Time the gufunc on two 4x4 uint8 boards (action=1, punishment=1, simulate=True).
# NOTE: the input array is built inside the timed expression, so its construction cost is included.
%timeit x.do_actions(numpy.arange(32, dtype=numpy.uint8).reshape((2, 4, 4)), 1, 1, True)

6.7 µs ± 114 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
