In [1]:
import ctypes
import numpy as np
import numba as nb

In [2]:
# Native GEMM library; the './' path is resolved against the kernel's current working directory.
intrinsic_gemm_so = ctypes.cdll.LoadLibrary('./libintrinsic_gemm.so')

In [3]:
@nb.jit(nopython=True)
def cout_diff(matrix_a, matrix_b, length, loss_rate=1e-4, debug=False):
    """Return the fraction of the first `length` elements whose relative error exceeds `loss_rate`.

    The relative error of element i is |b[i] - a[i]| / (|a[i]| + 1e-10), so
    `matrix_a` acts as the reference. The tiny offset keeps an exact zero in
    the reference from causing a division by zero.

    Args:
        matrix_a: reference values (1-D, indexable).
        matrix_b: values under test (1-D, indexable).
        length: number of leading elements to compare.
        loss_rate: relative-error tolerance.
        debug: if true, print the index and both values for every element.

    Returns:
        Mismatch fraction in [0, 1]; 0.0 when `length` is 0 or negative.
    """
    if length <= 0:
        # Nothing to compare; avoids a ZeroDivisionError in the final ratio.
        return 0.0
    zero_offset = 1e-10
    diff_num = 0
    for idx in range(length):
        # Convert before subtracting so integer inputs cannot wrap in the difference.
        ref = float(matrix_a[idx])
        val = float(matrix_b[idx])
        cur_loss_rate = abs(val - ref) / (abs(ref) + zero_offset)
        # Written as "not <=" so a NaN error (NaN compares False with everything)
        # is counted as a mismatch instead of silently passing.
        if not (cur_loss_rate <= loss_rate):
            diff_num += 1
        if debug:
            print(idx, matrix_a[idx], matrix_b[idx])
    return float(diff_num) / float(length)

In [4]:
@nb.jit(nopython=True)
def gemm_reference(matrix_a, matrix_b, matrix_c, m, n, k):
    """Naive reference GEMM: accumulate A . B^T into `matrix_c` in place.

    All buffers are flat and indexed row-major:
      matrix_a[m_idx*k + k_idx]  -> A is m x k
      matrix_b[n_idx*k + k_idx]  -> B is n x k (hence the transpose)
      matrix_c[m_idx*n + n_idx]  -> C is m x n

    Products are added to the existing contents of `matrix_c` (C += A . B^T),
    so pass a zeroed buffer to get a plain product.

    Returns:
        None; `matrix_c` is modified in place.
    """
    for m_idx in range(m):
        a_row = m_idx * k
        c_row = m_idx * n
        for n_idx in range(n):
            # Row n_idx of B; the second operand must come from matrix_b, not matrix_a.
            b_row = n_idx * k
            for k_idx in range(k):
                matrix_c[c_row + n_idx] += matrix_a[a_row + k_idx] * matrix_b[b_row + k_idx]
    return None

In [5]:
def intrinsic_gemm_impl(matrix_a, matrix_b, matrix_c, m, n, k):
    """Run one GEMM through the native library loaded as `intrinsic_gemm_so`.

    Steps:
      1. Declare the ctypes signatures of get_instance, instance_init and
         instance_dispatch.
      2. Create an engine with get_instance(3, 3, 5).
      3. Initialise it with (False, False, m, n, k).
      4. Dispatch with alpha=1.0, beta=1.0 and the raw data pointers of the
         three numpy arrays.

    The library writes its result through the pointer to `matrix_c`;
    nothing is returned.
    """
    lib = intrinsic_gemm_so

    get_instance = lib.get_instance
    get_instance.argtypes = (ctypes.c_int,) * 3
    get_instance.restype = ctypes.c_void_p

    instance_init = lib.instance_init
    instance_init.argtypes = (ctypes.c_void_p, ctypes.c_bool, ctypes.c_bool,
                              ctypes.c_int, ctypes.c_int, ctypes.c_int)
    instance_init.restype = ctypes.c_int

    instance_dispatch = lib.instance_dispatch
    instance_dispatch.argtypes = (ctypes.c_void_p, ctypes.c_float, ctypes.c_float,
                                  ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p)
    instance_dispatch.restype = ctypes.c_int

    # NOTE(review): the meaning of (3, 3, 5) and of the two bool flags is not
    # visible here — confirm against the library header.
    engine = get_instance(3, 3, 5)
    # NOTE(review): the int statuses returned by init/dispatch are not checked;
    # their success convention is not visible from this notebook.
    instance_init(engine, False, False, m, n, k)

    # Raw pointers into the numpy buffers; presumably the arrays must be
    # C-contiguous with the dtypes `test` builds (int8 A/B, int32 C) — TODO confirm.
    pointers = [arr.ctypes.data_as(ctypes.c_void_p)
                for arr in (matrix_a, matrix_b, matrix_c)]
    instance_dispatch(engine, 1.0, 1.0, *pointers)
    return None

In [6]:
def test(m, n, k, debug=False):
    """Compare the intrinsic GEMM against the numba reference and print the mismatch rate.

    Builds int8 inputs A (m*k) and B (n*k). It runs `intrinsic_gemm_impl` into
    a zeroed int32 output and `gemm_reference` in float64 on the same inputs.
    It then prints the fraction of output elements whose relative error
    exceeds 1e-4.

    Args:
        m, n, k: GEMM dimensions.
        debug: if true, use all-ones inputs (every reference output equals k)
            and have `cout_diff` print every element pair.

    Returns:
        None; the error rate is printed.
    """
    if debug:
        matrix_a = np.ones(m * k, dtype=np.int8)
        matrix_b = np.ones(n * k, dtype=np.int8)
    else:
        # Ramp inputs: for realistic sizes the ramp exceeds the int8 range, so the
        # values wrap modulo 256 (0..127, -128..-1, 0, ...). Building in int64 and
        # casting makes that wrap explicit instead of relying on np.arange's
        # in-dtype overflow.
        matrix_a = np.arange(m * k, dtype=np.int64).astype(np.int8)
        matrix_b = np.arange(n * k, dtype=np.int64).astype(np.int8)
    matrix_c = np.zeros(m * n, dtype=np.int32)
    # np.float64, not the np.float alias, which was removed in NumPy 1.24.
    matrix_c_ref = np.zeros(m * n, dtype=np.float64)

    intrinsic_gemm_impl(matrix_a, matrix_b, matrix_c, m, n, k)
    # astype already returns a fresh copy, so no explicit .copy() is needed.
    gemm_reference(matrix_a.astype(np.float64), matrix_b.astype(np.float64),
                   matrix_c_ref, m, n, k)

    error_rate = cout_diff(matrix_c_ref, matrix_c, m * n, 1e-4, debug)
    print('error_rate:', error_rate)
    return None

In [7]:
# Square output (m == n) with a long reduction dimension. Note that with the
# non-debug ramp inputs, A and B are identical whenever m == n.
m = 95
n = 95
k = 32768
test(m, n, k, debug=False)

error_rate: 0.0
