# Using Numba to Accelerate 2048

In [1]:
import numpy as np
import torch
from numba import cuda, jit, guvectorize

In [2]:
def check_action(state, action):
    """Check whether a move in the given direction would change the board.

    The scan itself is written for one direction only: within each row, a
    tile at a higher column index tries to move toward a lower column index.
    The other directions are obtained by rotating every index pair
    ``action % 4`` quarter-turns before reading ``state``.

    Args:
        state: 2-D array of tile values; cells equal to 0 are treated as
            empty. The rotation uses ``state.shape[0]`` for both axes, so the
            board is expected to be square.
        action: Integer direction; only ``action % 4`` is used.

    Returns:
        True if some tile can slide into an empty cell or merge with an equal
        tile in that direction, False otherwise.
    """
    n_rows = state.shape[0]
    n_cols = state.shape[1]
    # Loop-invariant: the same number of quarter-turns applies to every index.
    num_rotate = action % 4
    for x in range(n_rows):
        for y1 in range(n_cols - 1):
            # The destination cell depends only on (x, y1), so rotate it once
            # here instead of once per y2.
            rot_x1, rot_y1 = x, y1
            turns_left = num_rotate
            while turns_left > 0:
                rot_x1, rot_y1 = rot_y1, n_rows - rot_x1 - 1
                turns_left -= 1
            for y2 in range(y1 + 1, n_cols):
                rot_x2, rot_y2 = x, y2
                turns_left = num_rotate
                while turns_left > 0:
                    rot_x2, rot_y2 = rot_y2, n_rows - rot_x2 - 1
                    turns_left -= 1
                moving = state[rot_x2, rot_y2]
                if moving == 0:
                    # Nothing to move from this cell; look further along.
                    continue
                target = state[rot_x1, rot_y1]
                # An empty target can be slid into; an equal one can be merged.
                if target == 0 or target == moving:
                    return True
                # The nearest tile differs from the target, so nothing behind
                # it can reach the target either.
                break
    return False

## Pure Python

In [3]:
# Direction index passed to check_action in the single-board benchmarks.
action = 3
# Example 4x4 board holding the values 0..15 in row-major order.
state = np.arange(16, dtype=np.uint8).reshape((4, 4))

In [4]:
# Baseline: time the uncompiled Python implementation on a single board.
%timeit check_action(state, action)

43.4 µs ± 416 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


## Just in Time Compilation

In [5]:
# Wrap check_action with numba's jit in nopython mode; the original Python
# function stays available under its own name.
jitted_check_action = jit(nopython=True)(check_action)

In [6]:
# First call of the jitted function, timed once. It is far slower than the
# repeated timing that follows, which is consistent with compilation
# happening on first use.
%time jitted_check_action(state, action)

CPU times: user 245 ms, sys: 12.4 ms, total: 258 ms
Wall time: 260 ms


False

In [7]:
# Repeated timing of the already-called jitted function on a single board.
%timeit jitted_check_action(state, action)

334 ns ± 8.45 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


## Generalized Universal Functions

In [8]:
# Benchmark batch: 10,000 boards with one action per board.
batch_size = 10_000
# NOTE(review): both ranges exceed what uint8 / int8 can hold, so the values
# presumably wrap around; confirm that this filler data is intended.
actions = np.arange(batch_size, dtype=np.int8)
states = np.arange(16 * batch_size, dtype=np.uint8).reshape((batch_size, 4, 4))

In [9]:
%%timeit
# Batch baseline: call the jitted function once per board from a Python loop.
for state, action in zip(states, actions):
    jitted_check_action(state, action)

4.85 ms ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
@guvectorize(
    ['void(u1[:,:], i1[:], b1[:])'],
    '(n,n),()->()',
    target='cpu',
    nopython=True,
)
def ufunc_check_action(state, action, output):
    """Gufunc kernel: evaluate one board and store the result in ``output[0]``.

    With the layout ``(n,n),()->()``, ``state`` is a single n-by-n board,
    ``action`` is a one-element array holding the direction, and ``output`` is
    a one-element boolean array that receives the result.
    """
    direction = action[0]
    output[0] = jitted_check_action(state, direction)

In [11]:
# Single-shot timing of the CPU gufunc over the whole batch.
%time ufunc_check_action(states, actions)

CPU times: user 1.39 ms, sys: 34 µs, total: 1.42 ms
Wall time: 803 µs


array([ True, False, False, ..., False, False, False])

In [12]:
# Repeated timing of the CPU gufunc over the whole batch.
%timeit ufunc_check_action(states, actions)

717 µs ± 2.89 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


## GPU Acceleration

In [3]:
# Scale the batch up to 1e8 boards for the GPU comparison.
# NOTE(review): at one byte per cell, `states` alone takes 16 * 1e8 bytes
# (about 1.6 GB) of host memory.
batch_size = int(1e8)
states = np.arange(16 * batch_size, dtype=np.uint8).reshape(batch_size, 4, 4)
actions = np.arange(batch_size, dtype=np.int8)

In [14]:
# CPU gufunc on the enlarged batch, as the reference point for the GPU runs.
%timeit ufunc_check_action(states, actions)

7.15 s ± 28.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
# Compile check_action as a CUDA device function (device=True) so that it can
# be called from the GPU gufunc kernel.
cuda_check_action = cuda.jit(device=True)(check_action)

In [5]:
@guvectorize(
    ['void(u1[:,:], i1[:], u1[:])'],
    '(n,n),()->()',
    target='cuda',
    nopython=True,
)
def gpu_check_action(state, action, output):
    """CUDA gufunc kernel: evaluate one board and store the result in ``output[0]``.

    With the layout ``(n,n),()->()``, ``state`` is a single n-by-n board,
    ``action`` is a one-element array holding the direction, and ``output`` is
    a one-element uint8 array that receives the result.
    """
    direction = action[0]
    output[0] = cuda_check_action(state, direction)

In [6]:
# Single-shot timing of the CUDA gufunc with `states` / `actions` as NumPy arrays.
# NOTE(review): the result comes back as a NumPy array, so this timing
# presumably includes host/device copies as well as the kernel; confirm.
%time gpu_check_action(states, actions)

CPU times: user 269 ms, sys: 19 ms, total: 288 ms
Wall time: 290 ms


array([1, 0, 0, ..., 0, 0, 0], dtype=uint8)

In [7]:
# Repeated timing of the CUDA gufunc, still fed with NumPy arrays.
%timeit gpu_check_action(states, actions)

299 ms ± 1.56 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
# Rebuild the inputs as torch tensors created on the CUDA device, and
# preallocate the output buffer there too.
device = torch.device("cuda")
output = torch.zeros((batch_size,), dtype=torch.uint8, device=device)
actions = torch.arange(batch_size, dtype=torch.int8, device=device)
states = torch.arange(
    16 * batch_size,
    dtype=torch.uint8,
    device=device,
).reshape((batch_size, 4, 4))

In [10]:
%timeit gpu_check_action(states, actions, out=output)

29 ms ± 140 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
