In [1]:
import ctypes
import numpy as np

In [2]:
def gemm_reference(matrix_a, matrix_b, matrix_c, m, n, k):
    """Naive reference GEMM used to validate the native kernel.

    Accumulates ``C += A @ B.T`` in place, with every matrix stored as a flat
    row-major buffer:

    * ``matrix_a`` holds A, shape (m, k): element (i, p) is ``matrix_a[i*k + p]``.
    * ``matrix_b`` holds B, shape (n, k): element (j, p) is ``matrix_b[j*k + p]``.
    * ``matrix_c`` holds C, shape (m, n): element (i, j) is ``matrix_c[i*n + j]``.

    NOTE(review): B is read as (n, k) rows (i.e. C = A @ B.T), following the
    original indexing -- confirm this matches the native kernel's layout.

    Every element is converted with ``float()`` before multiplying, so integer
    inputs (e.g. int8) cannot overflow during accumulation.

    Args:
        matrix_a: Flat indexable buffer of length >= m*k.
        matrix_b: Flat indexable buffer of length >= n*k.
        matrix_c: Flat mutable buffer of length >= m*n; updated in place.
            Presumably a float buffer (float64 in this notebook).
        m, n, k: GEMM dimensions.

    Returns:
        None; the result is written into ``matrix_c``.
    """
    for m_idx in range(m):
        a_row = m_idx * k
        for n_idx in range(n):
            b_row = n_idx * k
            out_idx = m_idx * n + n_idx  # C is (m, n): its row stride is n, not m
            # Start from the existing C value and add products in order, which is the
            # same float summation sequence as updating C on every step, but writes
            # the buffer only once per output element.
            acc = float(matrix_c[out_idx])
            for k_idx in range(k):
                acc += float(matrix_a[a_row + k_idx]) * float(matrix_b[b_row + k_idx])
            matrix_c[out_idx] = acc
    return None

In [3]:
def cout_diff(matrix_a, matrix_b, length, loss_rate=1e-4):
    """Return the fraction of elements whose relative error exceeds ``loss_rate``.

    ``matrix_a`` is the reference and ``matrix_b`` the values under test. For each
    of the first ``length`` elements the relative error is ``|b - a| / |a|``. When
    the reference is exactly 0, the absolute error ``|b|`` is used instead, so a zero
    reference does not divide by zero. Errors in either direction count, and
    a NaN on either side is counted as a mismatch.

    Args:
        matrix_a: Flat indexable reference buffer.
        matrix_b: Flat indexable buffer to check.
        length: Number of leading elements to compare.
        loss_rate: Largest tolerated relative error.

    Returns:
        ``mismatches / length`` as a float in [0, 1], or 0.0 when ``length`` <= 0.
    """
    if length <= 0:
        return 0.0
    diff_num = 0
    for idx in range(length):
        # Convert before subtracting so narrow integer dtypes cannot overflow.
        expected = float(matrix_a[idx])
        actual = float(matrix_b[idx])
        # Scale by the reference magnitude (not by ``length``) and use the absolute
        # difference so results that are too small are caught as well as too large.
        scale = abs(expected) if expected != 0.0 else 1.0
        rel_err = abs(actual - expected) / scale
        # ``not <=`` (rather than ``>``) makes NaN errors count as mismatches.
        if not rel_err <= loss_rate:
            diff_num += 1
    return float(diff_num) / float(length)

In [4]:
# Load the native GEMM library; the relative path resolves against the kernel's
# current working directory, not the notebook file's location.
intrinsic_gemm_so = ctypes.cdll.LoadLibrary('./libintrinsic_gemm.so')

In [5]:
# GEMM problem size: A is m*k, B is n*k, C is m*n elements.
m = n = k = 100

In [6]:
# Declare the C signatures explicitly. ctypes assumes a C int return type by
# default and cannot infer argument types, so pointers and floats need these.
# NOTE(review): the meaning of the int/bool/float parameters is not visible here --
# confirm against the library's header.
get_instance_func = intrinsic_gemm_so.get_instance
get_instance_func.restype = ctypes.c_void_p  # pointer later passed as first arg below
get_instance_func.argtypes = (ctypes.c_int,) * 3

instance_init_func = intrinsic_gemm_so.instance_init
instance_init_func.restype = ctypes.c_int
instance_init_func.argtypes = (ctypes.c_void_p,) + (ctypes.c_bool,) * 2 + (ctypes.c_int,) * 3

instance_dispatch_func = intrinsic_gemm_so.instance_dispatch
instance_dispatch_func.restype = ctypes.c_int
instance_dispatch_func.argtypes = (ctypes.c_void_p,) + (ctypes.c_float,) * 2 + (ctypes.c_void_p,) * 3

In [7]:
# Kernel inputs: flat int8 buffers A (m*k) and B (n*k), and an int32 output C (m*n).
# np.arange(..., dtype=np.int8) with a stop above 127 depends on silent int8
# overflow; spell the wrap out so the values (0..127, -128..-1, 0, ...) do not
# depend on how a NumPy version handles out-of-range arange results.
matrix_a = ((np.arange(m * k) + 128) % 256 - 128).astype(np.int8)
matrix_b = ((np.arange(n * k) + 128) % 256 - 128).astype(np.int8)
matrix_c = np.zeros(shape=(m*n), dtype=np.int32)

# NOTE(review): the meaning of the (3, 3, 5) arguments is not visible here -- confirm.
intrinnsic_gemm_engine = get_instance_func(3, 3, 5)
# restype is c_void_p, so a NULL handle comes back as None; stop here rather than
# handing NULL to the native code.
if not intrinnsic_gemm_engine:
    raise RuntimeError('get_instance returned a NULL engine handle')
init_status = instance_init_func(intrinnsic_gemm_engine, False, False, m, n, k)
status = instance_dispatch_func(intrinnsic_gemm_engine, 1.0, 1.0,
                                matrix_a.ctypes.data_as(ctypes.c_void_p),
                                matrix_b.ctypes.data_as(ctypes.c_void_p),
                                matrix_c.ctypes.data_as(ctypes.c_void_p))
# NOTE(review): the status-code convention (e.g. 0 == success) is not visible here;
# once confirmed, assert on init_status and status.

In [8]:
# Float64 copies of the *same* int8 inputs the kernel consumed, so the reference
# multiplies identical values. New names keep the kernel inputs intact, and
# np.float64 replaces the np.float alias that NumPy 1.24 removed.
matrix_a_ref = matrix_a.astype(np.float64)
matrix_b_ref = matrix_b.astype(np.float64)
matrix_c_ref = np.zeros(shape=(m*n), dtype=np.float64)
gemm_reference(matrix_a_ref, matrix_b_ref, matrix_c_ref, m, n, k)

In [9]:
# Fraction of kernel outputs that disagree with the float reference beyond
# cout_diff's default tolerance.
error_rate = cout_diff(matrix_c_ref, matrix_c, m*n)
# str.format prints the same text under Python 2 and 3 (a Python 2 print of
# ('error_rate:', x) shows a tuple, as in the saved output).
print('error_rate: {}'.format(error_rate))

('error_rate:', 0.0)
