Skip to content

Commit

Permalink
ELL and BSR correctness test scripts (apache#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
yzh119 authored and MasterJH5574 committed Dec 22, 2021
1 parent 97e36b3 commit aa4e8f0
Showing 1 changed file with 152 additions and 13 deletions.
165 changes: 152 additions & 13 deletions tests/python/sparsetir/test_tir_sparse_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm.runtime.ndarray import device
import tvm.tir as tir
import scipy.sparse as sp
import numpy as np
Expand All @@ -36,49 +37,90 @@ def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.ha


@T.prim_func
def csrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, M: T.int32, N: T.int32, K: T.int32, nnz: T.int32) -> None:
def csrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, M: T.int32, N: T.int32, K: T.int32, NNZ: T.int32) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A_data = T.match_buffer(a, (nnz,), "float32")
A_data = T.match_buffer(a, (NNZ,), "float32")
B = T.match_buffer(b, (N * K,), "float32")
C = T.match_buffer(c, (M * K,), "float32")
A_indptr = T.match_buffer(indptr, (M + 1,), "int32")
A_indices = T.match_buffer(indices, (nnz,), "int32")
A_indices = T.match_buffer(indices, (NNZ,), "int32")
for i, k in T.grid(M, K):
with T.block("spmm_outer"):
vi, vk = T.axis.remap("SS", [i, k])
with T.init():
C[vi * K + vk] = 0.
for j in T.serial(0, A_indptr[vi + 1] - A_indptr[vi]):
with T.block("spmm_inner"):
vj = T.axis.R(N, j + A_indptr[vi])
C[vi * K + vk] = C[vi * K + vk] + A_data[vj] * B[A_indices[vj] * K + vk]
vj = T.axis.R(NNZ, j + A_indptr[vi])
C[vi * K + vk] = C[vi * K + vk] + \
A_data[vj] * B[A_indices[vj] * K + vk]


@T.prim_func
def bsrmm_tir(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, MB: T.int32, NB: T.int32, K: T.int32, BLOCK_SIZE: T.int32, NNZB: T.int32) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A_data = T.match_buffer(a, (NNZB * BLOCK_SIZE * BLOCK_SIZE), "float32")
B = T.match_buffer(b, (NB * BLOCK_SIZE * K,), "float32")
C = T.match_buffer(c, (MB * BLOCK_SIZE * K,), "float32")
A_indptr = T.match_buffer(indptr, (MB + 1,), "int32")
A_indices = T.match_buffer(indices, (NNZB,), "int32")
for io, ii, ji, k in T.grid(MB, BLOCK_SIZE, BLOCK_SIZE, K):
with T.block("spmm_outer"):
vio, vii, vji, vk = T.axis.remap("SSSS", [io, ii, ji, k])
with T.init():
C[(vio * BLOCK_SIZE + vii) * K + vk] = 0.
for jo in T.serial(0, A_indptr[vio + 1] - A_indptr[vio]):
with T.block("spmm_inner"):
vjo = T.axis.R(NNZB, jo + A_indptr[vio])
C[(vio * BLOCK_SIZE + vii) * K + vk] = C[(vio * BLOCK_SIZE + vii) * K + vk] + A_data[(
vjo * BLOCK_SIZE + vii) * BLOCK_SIZE + vji] * B[(A_indices[vjo] * BLOCK_SIZE + vji) * K + vk]


@T.prim_func
def ellmm_tir(a: T.handle, b: T.handle, c: T.handle, indices: T.handle, M: T.int32, N: T.int32, K: T.int32, NNZ_COLS: T.int32) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A_data = T.match_buffer(a, (M * NNZ_COLS,), "float32")
B = T.match_buffer(b, (N * K,), "float32")
C = T.match_buffer(c, (M * K,), "float32")
A_indices = T.match_buffer(indices, (M * NNZ_COLS,), "int32")
for i, j, k in T.grid(M, NNZ_COLS, K):
with T.block("spmm"):
vi, vj, vk = T.axis.remap("SRS", [i, j, k])
with T.init():
C[vi * K + vk] = 0.
C[vi * K + vk] = C[vi * K + vk] + A_data[vi * NNZ_COLS + vj] * \
B[A_indices[vi * NNZ_COLS + vj] * K + vk]


def test_csrmm():
# generate random input
A = sp.random(4096, 4096, dtype="float32", density=0.0125, format='csr')
x = np.random.rand(4096, 256).astype("float32")
m = 4096
n = 4096
k = 256
A = sp.random(m, n, dtype="float32", density=0.0125, format='csr')
nnz = A.nnz
x = np.random.rand(n, k).astype("float32")
y_ground_truth = A * x
y = np.zeros((4096, 256)).astype("float32")
y = np.zeros((m * k,)).astype("float32")

# specialize function
_, _, _, _, _, m, n, k, nnz = csrmm_tir.params
_, _, _, _, _, M, N, K, NNZ = csrmm_tir.params
sch = tir.Schedule(
csrmm_tir.specialize(
{m: 4096, n: 4096, k: 256, nnz: A.nnz}
{M: m, N: n, K: k, NNZ: nnz}
)
)
blk_outer = sch.get_block("spmm_outer")
i, k = sch.get_loops(blk_outer)
sch.bind(i, "blockIdx.x")
sch.bind(k, "threadIdx.x")

# convert numpy tensor to tvm ndarray
A_indptr = tvm.nd.array(A.indptr.astype("int32"), device=tvm.cuda(0))
A_indices = tvm.nd.array(A.indices.astype("int32"), device=tvm.cuda(0))
A_data = tvm.nd.array(A.data.astype("float32"), device=tvm.cuda(0))
X_nd = tvm.nd.array(x.reshape(-1), device=tvm.cuda(0))
Y_nd = tvm.nd.array(y.reshape(-1), device=tvm.cuda(0))
Y_nd = tvm.nd.array(y, device=tvm.cuda(0))

# build function
f = tvm.build(sch.mod, target='cuda')
Expand All @@ -88,5 +130,102 @@ def test_csrmm():
assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy())


def test_bsrmm():
# generate random input
block_size = 1
mb = 64
nb = 64
k = 256
m = mb * block_size
n = nb * block_size
A_block = sp.random(mb, nb, dtype="float32", density=0.05, format='csr')
indptr = A_block.indptr
indices = A_block.indices
nnzb = A_block.nnz
data = np.random.rand(nnzb, block_size, block_size)
A = sp.bsr_matrix((data, indices, indptr), shape=(m, n))
x = np.random.rand(n, k).astype("float32")
y_ground_truth = A * x
y = np.zeros((m * k,)).astype("float32")

# specialize function
_, _, _, _, _, MB, NB, K, BLOCK_SIZE, NNZB = bsrmm_tir.params
sch = tir.Schedule(
bsrmm_tir.specialize(
{MB: mb, NB: nb, K: k, BLOCK_SIZE: block_size, NNZB: nnzb}
)
)
blk_outer = sch.get_block("spmm_outer")
io, ii, ji, k = sch.get_loops(blk_outer)
sch.unroll(ii)
sch.unroll(ji)
sch.bind(io, "blockIdx.x")
sch.bind(k, "threadIdx.x")

# convert numpy tensor to tvm ndarray
A_indptr = tvm.nd.array(indptr.astype("int32"), device=tvm.cuda(0))
A_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0))
A_data = tvm.nd.array(
data.reshape(-1).astype("float32"), device=tvm.cuda(0))
X_nd = tvm.nd.array(x.reshape(-1), device=tvm.cuda(0))
Y_nd = tvm.nd.array(y, device=tvm.cuda(0))

# build function
f = tvm.build(sch.mod, target="cuda")
f(A_data, X_nd, Y_nd, A_indptr, A_indices)

# assertion
assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy())


def test_ellmm():
# generate random input
nnz_cols = 64
m = 4096
n = 4096
k = 256
nnz = nnz_cols * m
indptr = np.arange(0, (m + 1) * nnz_cols, nnz_cols)
indices = np.random.randint(0, n, size=(nnz,))
data = np.random.rand(nnz)
A = sp.csr_matrix((data, indices, indptr), shape=(m, n))
x = np.random.rand(n, k).astype("float32")
y_ground_truth = A * x
y = np.zeros((m * k,)).astype("float32")
# specialize function
_, _, _, _, M, N, K, NNZ_COLS = ellmm_tir.params
sch = tir.Schedule(
ellmm_tir.specialize(
{M: m, N: n, K: k, NNZ_COLS: nnz_cols}
)
)
blk = sch.get_block("spmm")
i, j, k = sch.get_loops(blk)
sch.bind(i, "blockIdx.x")
sch.bind(k, "threadIdx.x")
sch.unroll(j)

# convert numpy tensor to tvm ndarray
A_indices = tvm.nd.array(indices.astype("int32"), device=tvm.cuda(0))
A_data = tvm.nd.array(data.astype("float32"), device=tvm.cuda(0))
X_nd = tvm.nd.array(x.reshape(-1), device=tvm.cuda(0))
Y_nd = tvm.nd.array(y, device=tvm.cuda(0))

# build function
f = tvm.build(sch.mod, target="cuda")
f(A_data, X_nd, Y_nd, A_indices)

# assertion
assert np.allclose(y_ground_truth.reshape(-1), Y_nd.numpy())


def test_bmm():
# TODO(zihao)
pass


if __name__ == "__main__":
test_csrmm()
test_csrmm()
test_bsrmm()
test_ellmm()
test_bmm()

0 comments on commit aa4e8f0

Please sign in to comment.