In [13]:
import tvm
from tvm import te

# Problem size: C (N x M) is the product of A (N x L) and B (L x M).
N = M = L = 1024

# Symbolic inputs and the reduction axis that runs over the shared dimension.
A = te.placeholder((N, L), name="A")
B = te.placeholder((L, M), name="B")
k = te.reduce_axis((0, L), "k")

# C[i, j] = sum over k of A[i, k] * B[k, j]
C = te.compute(
    (N, M),
    lambda i, j: te.sum(A[i, k] * B[k, j], axis=k),
    name="C",
)

s = te.create_schedule(C.op)

# Thread axes for a 2-D block/thread layout. They are created here but only
# referenced by the disabled bind calls further down.
block_x, thread_x, block_y, thread_y = (
    te.thread_axis(tag)
    for tag in ("blockIdx.x", "threadIdx.x", "blockIdx.y", "threadIdx.y")
)

# Pick a block / thread-block size to decompose the computation of C:
# each output axis is split into an outer part and an inner part of extent 32.
row_axis, col_axis = C.op.axis
bx, tx = s[C].split(row_axis, factor=32)
by, ty = s[C].split(col_axis, factor=32)

# Binding the split axes to the thread axes is currently disabled.
# NOTE(review): presumably because the build below targets plain C rather
# than a GPU backend -- confirm before re-enabling.
# s[C].bind(bx, block_x)
# s[C].bind(tx, thread_x)
# s[C].bind(by, block_y)
# s[C].bind(ty, thread_y)

# Build for the C source backend and dump the generated code.
mod = tvm.build(s, [A, B, C], "c", name="matmul")
print(mod.get_source())


// tvm target: c -keys=cpu 
#define TVM_EXPORTS
#include "tvm/runtime/c_runtime_api.h"
#include "tvm/runtime/c_backend_api.h"
#include <math.h>
#include <stdbool.h>
#ifdef __cplusplus
extern "C"
#endif
TVM_DLL int32_t matmul(void* args, int32_t* arg_type_ids, int32_t num_args, void* out_ret_value, int32_t* out_ret_tcode, void* resource_handle) {
  void* arg_A = (((TVMValue*)args)[0].v_handle);
  int32_t arg_A_code = arg_type_ids[0];
  void* arg_B = (((TVMValue*)args)[1].v_handle);
  int32_t arg_B_code = arg_type_ids[1];
  void* arg_C = (((TVMValue*)args)[2].v_handle);
  int32_t arg_C_code = arg_type_ids[2];
  void* A = (((DLTensor*)arg_A)[0].data);
  void* arg_A_shape = (((DLTensor*)arg_A)[0].shape);
  void* arg_A_strides = (((DLTensor*)arg_A)[0].strides);
  int32_t dev_id = (((DLTensor*)arg_A)[0].device.device_id);
  void* B = (((DLTensor*)arg_B)[0].data);
  void* arg_B_shape = (((DLTensor*)arg_B)[0].shape);
  void* arg_B_strides = (((DLTensor*)arg_B)[0].strides);
  void* C = (((DLTen