# Using Numba to Accelerate 2048

In [1]:
import numpy as np
import torch
from numba import cuda, jit, guvectorize

In [8]:
def check_action(state, action):
    """Return True if `action` can move or merge at least one tile on `state`.

    Every line of the board is scanned in the direction selected by
    ``action % 4``. The direction is obtained by applying the index rotation
    ``(i, j) -> (j, n - 1 - i)`` that many times, with ``n = state.shape[0]``.
    For each cell, the next non-empty cell further along the line is examined:
    if the nearer cell is empty (0) or holds the same value, the action is
    valid.

    Parameters
    ----------
    state : 2-D array indexed as ``state[i, j]``; the value 0 marks an empty
        cell. The rotation uses ``state.shape[0]`` as the side length for both
        axes, so a square board is assumed.
    action : integer; only ``action % 4`` is used.

    Returns
    -------
    bool
        True as soon as one slide/merge opportunity is found, else False.
    """
    side = state.shape[0]
    width = state.shape[1]
    turns = action % 4

    for line in range(side):
        for near in range(width - 1):
            for far in range(near + 1, width):
                # Translate both logical positions into board coordinates.
                near_i, near_j = line, near
                far_i, far_j = line, far
                for _ in range(turns):
                    near_i, near_j = near_j, side - 1 - near_i
                    far_i, far_j = far_j, side - 1 - far_i

                far_tile = state[far_i, far_j]
                if far_tile == 0:
                    # Nothing here yet; keep looking further along the line.
                    continue

                near_tile = state[near_i, near_j]
                if near_tile == 0 or near_tile == far_tile:
                    # The far tile can slide into a gap or merge with its twin.
                    return True

                # First non-empty neighbour blocks this cell; try the next cell.
                break
    return False

## Pure Python

In [9]:
# Benchmark fixture: one 4x4 uint8 board holding the values 0..15, plus a single action id.
state = np.arange(16, dtype=np.uint8).reshape(4, 4)
action = 3

In [10]:
# Baseline: time the plain Python implementation on a single board.
%timeit check_action(state, action)

45.7 µs ± 451 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


## Just in Time Compilation

In [11]:
# Wrap check_action with Numba's jit in nopython mode; the plain function stays bound to its own name.
jitted_check_action = jit(nopython=True)(check_action)

In [12]:
# Call once before timing — presumably so that first-call compilation for these argument
# types is not part of the measurement below. The trailing ';' hides the cell output.
jitted_check_action(state, action);

In [13]:
# Time the JIT-wrapped function on the same single board.
%timeit jitted_check_action(state, action)

336 ns ± 0.731 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


## Generalized Universal Functions

In [8]:
# Batch fixture: 10,000 boards of shape 4x4 and one action per board.
# NOTE(review): the arange stop values exceed the uint8 / int8 ranges, so the contents are
# presumably wrapped rather than strictly increasing — confirm this is the intended test data.
batch_size = int(1e4)
states = np.arange(16 * batch_size, dtype=np.uint8).reshape(batch_size, 4, 4)
actions = np.arange(batch_size, dtype=np.int8)

In [9]:
%%timeit
# Batch baseline: call the JIT-wrapped function once per board from a Python-level loop.
for state, action in zip(states, actions):
    jitted_check_action(state, action)

6.28 ms ± 69.4 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
@guvectorize(['void(u1[:,:], i1[:], b1[:])'], '(n,n),()->()', target='cpu', nopython=True)
def ufunc_check_action(state, action, output):
    """CPU gufunc core: evaluate one board and store the result in `output`.

    state  -- 2-D uint8 board (the ``(n,n)`` core dimensions).
    action -- 1-element int8 array carrying the action for this board.
    output -- 1-element boolean array, written in place.
    """
    move = action[0]
    output[0] = jitted_check_action(state, move)

In [11]:
# Run once over the whole batch before timing; the trailing ';' hides the returned array.
ufunc_check_action(states, actions);

In [12]:
# Time a single gufunc call that covers the entire batch.
%timeit ufunc_check_action(states, actions)

1.39 ms ± 2.79 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## GPU Acceleration

In [13]:
# Scale the batch up to 1e8 boards (16 bytes each, i.e. roughly 1.6 GB of board data).
# This rebinds batch_size / states / actions, so any earlier cell re-run after this one
# will see the large batch.
batch_size = int(1e8)
states = np.arange(16 * batch_size, dtype=np.uint8).reshape(batch_size, 4, 4)
actions = np.arange(batch_size, dtype=np.int8)

In [14]:
# CPU gufunc on the large batch: the reference point for the GPU timings that follow.
%timeit ufunc_check_action(states, actions)

13.9 s ± 9.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [15]:
# Build check_action as a CUDA device function (device=True), i.e. a routine meant to be
# called from other CUDA-compiled code.
cuda_check_action = cuda.jit(device=True)(check_action)

In [16]:
@guvectorize(['void(u1[:,:], i1[:], u1[:])'], '(n,n),()->()', target='cuda', nopython=True)
def gpu_check_action(state, action, output):
    """CUDA gufunc core: evaluate one board and store the result in `output`.

    state  -- 2-D uint8 board (the ``(n,n)`` core dimensions).
    action -- 1-element int8 array carrying the action for this board.
    output -- 1-element uint8 array, written in place (non-zero means the
              action is valid).
    """
    # This body is compiled for the GPU, so it must call the CUDA device build
    # of check_action rather than the CPU-targeted jitted_check_action.
    output[0] = cuda_check_action(state, action[0])

In [17]:
# Time the CUDA gufunc with the NumPy arrays as inputs.
# NOTE(review): the inputs live in host memory here, so this figure presumably includes
# host/device copies — compare with the device-resident run further down.
%timeit gpu_check_action(states, actions)

296 ms ± 2.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [18]:
# Rebuild the fixture as PyTorch tensors allocated on the CUDA device, together with a
# preallocated uint8 result buffer. This rebinds states / actions from NumPy arrays to tensors.
device = torch.device("cuda")
states = torch.arange(16 * batch_size, dtype=torch.uint8, device=device).reshape(batch_size, 4, 4)
actions = torch.arange(batch_size, dtype=torch.int8, device=device)
output = torch.zeros(batch_size, dtype=torch.uint8, device=device)

In [19]:
# Time the CUDA gufunc on the CUDA-resident tensors, writing into the preallocated buffer via out=.
# NOTE(review): confirm the call blocks until the kernel has finished; if it returns early,
# this measures launch overhead rather than the full computation.
%timeit gpu_check_action(states, actions, out=output)

48.9 ms ± 533 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
