In [1]:
import numba as nb
import numpy as np

# Plain for loop
def add1(x, c):
    """Return a new list holding every element of `x` increased by `c`.

    The result list is pre-allocated with ``len(x)`` float zeros and then
    filled in place by an explicit Python loop.
    """
    out = [0.] * len(x)
    for pos, value in enumerate(x):
        out[pos] = value + c
    return out

# List comprehension
def add2(x, c):
    """Return a new list with `c` added to each element of `x`, built by a comprehension."""
    return [value + c for value in x]

# The same for loop, accelerated with jit
@nb.jit(nopython=True)
def add_with_jit(x, c):
    """Return a list holding every element of `x` increased by `c`.

    The loop body is compiled by numba in nopython mode; the result list is
    pre-allocated with ``len(x)`` float zeros and filled in place.
    """
    out = [0.] * len(x)
    for pos, value in enumerate(x):
        out[pos] = value + c
    return out

y = np.random.random(10**5).astype(np.float32)
x = y.tolist()

assert np.allclose(add1(x, 1), add2(x, 1), add_with_jit(x, 1))
%timeit add1(x, 1)
%timeit add2(x, 1)
%timeit add_with_jit(x, 1)
print(np.allclose(wrong_add(x, 1), 1))

In [2]:
# Check the jit-compiled list loop against NumPy's vectorized addition, then
# time both: the jit function receives the Python list x, NumPy the array y.
assert np.allclose(y + 1, add_with_jit(x, 1))
%timeit add_with_jit(x, 1)
%timeit y + 1

212 ms ± 7.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
38.7 µs ± 1.12 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [3]:
@nb.vectorize(nopython=True)
def add_with_vec(yy, c):
    """Scalar kernel returning ``yy + c``; `nb.vectorize` wraps it as an element-wise function.

    No signature is passed to the decorator here.
    """
    return yy + c

assert np.allclose(y + 1, add_with_vec(y, 1), add_with_vec(y, 1.))
%timeit add_with_vec(y, 1)
%timeit add_with_vec(y, 1.)
%timeit y + 1
%timeit y + 1.

105 µs ± 1.34 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
117 µs ± 2.8 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
39 µs ± 861 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
38.9 µs ± 1.21 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [4]:
@nb.vectorize("float32(float32, float32)", target="parallel", nopython=True)
def add_with_vec(y, c):
    """Scalar kernel returning ``y + c``, declared as float32(float32, float32) with target="parallel".

    NOTE(review): this re-binds the name `add_with_vec` used by the earlier
    cell, and the parameter `y` shadows the module-level array `y` inside
    the function body.
    """
    return y + c

# Check the parallel-target vectorized function against NumPy, then time both.
assert np.allclose(y+1, add_with_vec(y,1.))
%timeit add_with_vec(y, 1.)
%timeit y + 1

115 µs ± 1.53 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
39.1 µs ± 1.67 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [5]:
import numba as nb
import numpy as np

# Plain convolution
def conv_kernel(x, w, rs, n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w):
    """Accumulate sliding-window filter responses of `x` into `rs` and return `rs`.

    For each of the `n` samples and each output position (row, col), the
    window ``x[sample, ..., row:row+filter_height, col:col+filter_width]`` is
    multiplied element-wise with every filter ``w[f]`` and the summed product
    is added to ``rs[sample, f, row, col]``. The window advances one position
    at a time.

    `n_channels`, `height` and `width` are accepted but not used by the body.
    `rs` is modified in place (with ``+=``) and also returned.
    """
    for sample in range(n):
        for top in range(out_h):
            bottom = top + filter_height
            for left in range(out_w):
                patch = x[sample, ..., top:bottom, left:left + filter_width]
                for f in range(n_filters):
                    rs[sample, f, top, left] += np.sum(w[f] * patch)
    return rs

# The same convolution with a jit decorator simply added
@nb.jit(nopython=True)
def jit_conv_kernel(x, w, rs, n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w):
    """Accumulate sliding-window filter responses of `x` into `rs` (numba nopython mode).

    For each of the `n` samples and each output position (row, col), the
    window ``x[sample, ..., row:row+filter_height, col:col+filter_width]`` is
    multiplied element-wise with every filter ``w[f]`` and the summed product
    is added to ``rs[sample, f, row, col]``.

    `n_channels`, `height` and `width` are accepted but not used by the body.
    `rs` is modified in place; nothing is returned.
    """
    for sample in range(n):
        for top in range(out_h):
            bottom = top + filter_height
            for left in range(out_w):
                patch = x[sample, ..., top:bottom, left:left + filter_width]
                for f in range(n_filters):
                    rs[sample, f, top, left] += np.sum(w[f] * patch)

# Wrapper around a convolution kernel
def conv(x, w, kernel, args):
    """Allocate a zeroed float32 output, let `kernel` fill it, and return it.

    `args` is the positional tuple forwarded to the kernel: its element 0 is
    used as the number of samples, element 4 as the number of filters, and
    its last two elements as the output height and width. The kernel is
    called as ``kernel(x, w, out, *args)`` and is expected to write into
    `out` in place; its return value is ignored.
    """
    batch = args[0]
    filters = args[4]
    rows, cols = args[-2:]
    out = np.zeros((batch, filters, rows, cols), dtype=np.float32)
    kernel(x, w, out, *args)
    return out

# 64 input images of size 3 x 28 x 28 (mimicking MNIST)
x = np.random.randn(64, 3, 28, 28).astype(np.float32)
# 16 kernels of size 5 x 5, each spanning all input channels
w = np.random.randn(16, x.shape[1], 5, 5).astype(np.float32)

# Unpack the shapes and derive the output size for a window that moves one
# position at a time without padding.
n, n_channels, height, width = x.shape
n_filters, _, filter_height, filter_width = w.shape
out_h = height - filter_height + 1
out_w = width - filter_width + 1
# Positional arguments forwarded to the kernels, in the order of their signatures.
args = (n, n_channels, height, width, n_filters, filter_height, filter_width, out_h, out_w)

# Print the L2 norm of the difference between the plain and the jit results
# as a consistency check, then time both variants.
print(np.linalg.norm((conv(x, w, conv_kernel, args) - conv(x, w, jit_conv_kernel, args)).ravel()))
%timeit conv(x, w, conv_kernel, args)
%timeit conv(x, w, jit_conv_kernel, args)

0.0011400561
7.78 s ± 317 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
336 ms ± 4.39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


池化操作

In [6]:
# Plain MaxPool
def max_pool_kernel(x, rs, *args):
    """Add the maximum of every pooling window of `x` to the matching cell of `rs`.

    `args` must unpack to (n, n_channels, pool_height, pool_width, out_h,
    out_w). For each sample, channel and output position (row, col), the
    maximum of ``x[sample, ch, row:row+pool_height, col:col+pool_width]`` is
    added to ``rs[sample, ch, row, col]``. `rs` is modified in place; nothing
    is returned.
    """
    n, n_channels, pool_height, pool_width, out_h, out_w = args
    for sample in range(n):
        for ch in range(n_channels):
            for top in range(out_h):
                bottom = top + pool_height
                for left in range(out_w):
                    patch = x[sample, ch, top:bottom, left:left + pool_width]
                    rs[sample, ch, top, left] += np.max(patch)

# The same MaxPool with a jit decorator simply added
@nb.jit(nopython=True)
def jit_max_pool_kernel(x, rs, *args):
    """Add the maximum of every pooling window of `x` to the matching cell of `rs` (numba nopython mode).

    `args` must unpack to (n, n_channels, pool_height, pool_width, out_h,
    out_w). Each window is sliced out of `x` and reduced with ``np.max``.
    `rs` is modified in place; nothing is returned.
    """
    n, n_channels, pool_height, pool_width, out_h, out_w = args
    for sample in range(n):
        for ch in range(n_channels):
            for top in range(out_h):
                bottom = top + pool_height
                for left in range(out_w):
                    patch = x[sample, ch, top:bottom, left:left + pool_width]
                    rs[sample, ch, top, left] += np.max(patch)

# A "more C-like" MaxPool that is not afraid of explicit for loops
@nb.jit(nopython=True)
def jit_max_pool_kernel2(x, rs, *args):
    """Add the maximum of every pooling window of `x` to the matching cell of `rs` (numba nopython mode).

    `args` must unpack to (n, n_channels, pool_height, pool_width, out_h,
    out_w). Instead of slicing a window and calling ``np.max``, the window is
    scanned element by element while tracking the largest value seen so far.
    `rs` is modified in place; nothing is returned.
    """
    n, n_channels, pool_height, pool_width, out_h, out_w = args
    for sample in range(n):
        for ch in range(n_channels):
            for top in range(out_h):
                for left in range(out_w):
                    # Seed the running maximum with the window's top-left element
                    best = x[sample, ch, top, left]
                    for d_row in range(pool_height):
                        for d_col in range(pool_width):
                            candidate = x[sample, ch, top + d_row, left + d_col]
                            if candidate > best:
                                best = candidate
                    rs[sample, ch, top, left] += best

# Wrapper around a MaxPool kernel
def max_pool(x, kernel, args):
    """Allocate a zeroed float32 output, let `kernel` fill it, and return it.

    `args` is the positional tuple forwarded to the kernel: its first two
    elements are the number of samples and the number of channels, its last
    two the output height and width. The kernel is called as
    ``kernel(x, rs, *args)`` and is expected to write into `rs` in place; its
    return value is ignored.

    Returns an array of shape (n, n_channels, out_h, out_w).
    """
    n, n_channels = args[:2]
    out_h, out_w = args[-2:]
    # The kernels index rs by channel, so the second axis is sized by
    # n_channels taken from `args`, not by any module-level filter count.
    rs = np.zeros([n, n_channels, out_h, out_w], dtype=np.float32)
    kernel(x, rs, *args)
    return rs

# 2 x 2 pooling window; `x` is not redefined here, so this cell relies on the
# array created in the convolution cell.
pool_height, pool_width = 2, 2
n, n_channels, height, width = x.shape
# Output size for a window that moves one position at a time without padding.
out_h = height - pool_height + 1
out_w = width - pool_width + 1
# Positional arguments in the order the pooling kernels unpack them.
args = (n, n_channels, pool_height, pool_width, out_h, out_w)

# All three kernels must agree before they are timed.
assert np.allclose(max_pool(x, max_pool_kernel, args), max_pool(x, jit_max_pool_kernel, args))
assert np.allclose(max_pool(x, jit_max_pool_kernel, args), max_pool(x, jit_max_pool_kernel2, args))
%timeit max_pool(x, max_pool_kernel, args)
%timeit max_pool(x, jit_max_pool_kernel, args)
%timeit max_pool(x, jit_max_pool_kernel2, args)

1.43 s ± 113 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
12.9 ms ± 394 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.73 ms ± 70.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
