From a2ee33301ecb0a5e7cd6a0c3fd6bed81912ffa7f Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Thu, 19 Jul 2018 19:17:22 -0700 Subject: [PATCH 1/8] support augmented assign --- python/tvm/hybrid/parser.py | 8 ++++++ python/tvm/hybrid/var_decl.py | 9 +++++++ tests/python/unittest/test_hybrid_script.py | 30 ++++++++++++++++----- 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index 1e532367a321..ba4dd05fe95c 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -162,6 +162,14 @@ def visit_Name(self, node): def visit_Num(self, node): return _api.const(node.n) + def visit_AugAssign(self, node): + lhs = node.target + rhs = self.visit(node.value) + rhs = HybridParser._binop_maker[type(node.op)](lhs, rhs) + if not isinstance(lhs, _expr.Call): + raise ValueError("The LHS of an AugAssign is supposed to be a call!") + buf = self._get_buffer_from_id(lhs.name) + return _make.Provide(buf.op, 0, rhs, lhs.args) def visit_Assign(self, node): if len(node.targets) != 1: diff --git a/python/tvm/hybrid/var_decl.py b/python/tvm/hybrid/var_decl.py index df38bac1acba..d6ca3b911eb5 100644 --- a/python/tvm/hybrid/var_decl.py +++ b/python/tvm/hybrid/var_decl.py @@ -14,6 +14,7 @@ def __init__(self, args): self.scope_level = [] self._args = {} self.args = args + self.aug_assign_ = False def visit_FunctionDef(self, node): @@ -48,6 +49,12 @@ def visit_Call(self, node): self.visit(elem) + def visit_AugAssign(self, node): + self.aug_assign_ = True + self.generic_visit(node) + self.aug_assign_ = False + + def visit_Name(self, node): # If it is from the argument list or loop variable, we do not worry about it! if node.id in self._args.keys(): @@ -62,6 +69,8 @@ def visit_Name(self, node): if node.id not in self.status.keys(): if not isinstance(node.ctx, ast.Store): raise ValueError('In Python, "first store" indicates "declaration"') + if self.aug_assign_: + raise ValueError('"First store" cannot be an AugAssign') self.status[node.id] = (node, self.scope_level[-1], set()) else: decl, loop, usage = self.status[node.id] diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 0f500d7c704f..a52ff3c7b2a1 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -165,20 +165,29 @@ def fanout(n, a, b): run_and_check(fanout, [n, a, b], [b], {n: 10}) -@script -def failure(): - for i in range(1, 100): - i = 0 - def test_failure(): try: + @script + def failure(): + for i in range(1, 100): + i = 0 tvm.hybrid.parse(failure, []) except IOError as err: assert sys.version_info[0] == 2 print('[Warning] Case test_failure is skipped by Python2 because "%s"' % str(err)) - except Exception as err: + except ValueError as err: assert str(err) == 'You CAN NEVER overwrite a loop variable!' + try: + @tvm.hybrid.script + def augdefine(): + for i in range(10): + s += 0 + tvm.hybrid.parse(augdefine, []) + except ValueError as err: + assert str(err) == '"First store" cannot be an AugAssign' + + def test_looptype(): @script @@ -338,6 +347,15 @@ def share_vec_add(a, b, c): print('[Warning] No GPU found! Skip shared mem test!') +def test_augassign(): + @tvm.hybrid.script + def augassign(a): + for i in range(a.shape[0]): + a[i] += 1.0 + a = tvm.placeholder((16, ), dtype='float32', name='a') + run_and_check(augassign, [a], [a]) + + if __name__ == "__main__": test_outer_product() test_fanout() From 5d4ba1e9e3f7e46b8503ad9117a31f0f3f3d3218 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Thu, 19 Jul 2018 19:18:56 -0700 Subject: [PATCH 2/8] share memory gpu test add! --- tests/python/unittest/test_hybrid_script.py | 31 +++++++++++---------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index a52ff3c7b2a1..cd572135d61d 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -324,25 +324,26 @@ def blur2d(a, b): a = tvm.placeholder((32, 32), 'float32', 'a') b = tvm.placeholder((30, 30), 'float32', 'b') - run_and_check(blur2d, [a, b], [b]) if tvm.gpu().exist: @tvm.hybrid.script - def share_vec_add(a, b, c): - shared = allocate((256, ), 'float32', 'shared') - for i in bind("threadIdx.x", 256): - shared[i] = a[i] - local = allocate((256, ), 'float32', 'local') - for i in bind("threadIdx.x", 256): - local[i] = b[i] - for i in bind("threadIdx.x", 256): - c[i] = shared[i] + local[i] - - a = tvm.placeholder((256, ), dtype='float32', name='a') - b = tvm.placeholder((256, ), dtype='float32', name='b') - c = tvm.placeholder((256, ), dtype='float32', name='c') - run_and_check(share_vec_add, [a, b, c], [c], target='cuda') + def shared_gemm(a, b, c): + for io in bind('blockIdx.x', 8): + for ii in bind('blockIdx.y', 8): + shared_b = allocate((64, 64), 'float32', 'shared') + for k in range(64): + shared_b[io * 8 + ii, k] = b[io * 8 + ii, k] + for jo in bind('threadIdx.y', 8): + for ji in bind('threadIdx.x', 8): + for k in range(64): + c[io*8+ii, jo*8+ji] = c[io*8+ii, jo*8+ji] + a[io*8+ii, k] * shared_b[k, jo*8+ji] + + a = tvm.placeholder((64, 64), dtype='float32', name='a') + b = tvm.placeholder((64, 64), dtype='float32', name='b') + c = tvm.placeholder((64, 64), dtype='float32', name='c') + module = run_and_check(shared_gemm, [a, b, c], [c], target='cuda') + assert "__syncthreads()" in module.imported_modules[0].get_source() else: print('[Warning] No GPU found! Skip shared mem test!') From 66ef1b10eb45a0d16a8557d03514af0156b6b17c Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Thu, 19 Jul 2018 19:25:12 -0700 Subject: [PATCH 3/8] test case fixed --- python/tvm/hybrid/parser.py | 4 ++-- tests/python/unittest/test_hybrid_script.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index ba4dd05fe95c..e5a77c8c3425 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -15,7 +15,7 @@ def list_to_block(visit, lst): """Convert a list of Python IR nodes to HalideIR Block""" - lst = list(map(visit, lst)) + lst = [visit(i) for i in lst] lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, make_nop())] if not lst: return make_nop() @@ -163,7 +163,7 @@ def visit_Num(self, node): return _api.const(node.n) def visit_AugAssign(self, node): - lhs = node.target + lhs = self.visit(node.target) rhs = self.visit(node.value) rhs = HybridParser._binop_maker[type(node.op)](lhs, rhs) if not isinstance(lhs, _expr.Call): diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index cd572135d61d..bc7c4a994100 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -289,7 +289,7 @@ def blur(a, b): s = 0.0 for di in range(3): for dj in range(3): - s = s + a[i-di, j-dj] + s += a[i-di, j-dj] b[i-2, j-2] = s / 9.0 try: a = tvm.placeholder((32, 32), 'float32', 'a') @@ -337,7 +337,7 @@ def shared_gemm(a, b, c): for jo in bind('threadIdx.y', 8): for ji in bind('threadIdx.x', 8): for k in range(64): - c[io*8+ii, jo*8+ji] = c[io*8+ii, jo*8+ji] + a[io*8+ii, k] * shared_b[k, jo*8+ji] + c[io*8+ii, jo*8+ji] += a[io*8+ii, k] * shared_b[k, jo*8+ji] a = tvm.placeholder((64, 64), dtype='float32', name='a') b = tvm.placeholder((64, 64), dtype='float32', name='b') @@ -367,4 +367,5 @@ def augassign(a): test_math_intrin() test_non_zero() test_allocate() + test_augassign() From 9f3754e04f17cac1ebb2e4cb04450e02ed6fb0c2 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Thu, 19 Jul 2018 19:36:30 -0700 Subject: [PATCH 4/8] fix test --- python/tvm/hybrid/parser.py | 3 +-- tests/python/unittest/test_hybrid_script.py | 6 ++++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/tvm/hybrid/parser.py b/python/tvm/hybrid/parser.py index e5a77c8c3425..afb1a1a917fb 100644 --- a/python/tvm/hybrid/parser.py +++ b/python/tvm/hybrid/parser.py @@ -168,8 +168,7 @@ def visit_AugAssign(self, node): rhs = HybridParser._binop_maker[type(node.op)](lhs, rhs) if not isinstance(lhs, _expr.Call): raise ValueError("The LHS of an AugAssign is supposed to be a call!") - buf = self._get_buffer_from_id(lhs.name) - return _make.Provide(buf.op, 0, rhs, lhs.args) + return _make.Provide(lhs.func, 0, rhs, lhs.args) def visit_Assign(self, node): if len(node.targets) != 1: diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index bc7c4a994100..2beb4e9f7eb1 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -40,6 +40,8 @@ def tvm_val_2_py_val(val): for nd, np in to_check: numpy.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5) + return module + @script def outer_product(n, m, a, b, c): @@ -182,7 +184,7 @@ def failure(): @tvm.hybrid.script def augdefine(): for i in range(10): - s += 0 + es += 0 tvm.hybrid.parse(augdefine, []) except ValueError as err: assert str(err) == '"First store" cannot be an AugAssign' @@ -289,7 +291,7 @@ def blur(a, b): s = 0.0 for di in range(3): for dj in range(3): - s += a[i-di, j-dj] + s = s + a[i-di, j-dj] b[i-2, j-2] = s / 9.0 try: a = tvm.placeholder((32, 32), 'float32', 'a') From f6f8a37ee2785b8e73390378def855fa763bd3d3 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Thu, 19 Jul 2018 19:38:05 -0700 Subject: [PATCH 5/8] fixlint --- python/tvm/hybrid/var_decl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/hybrid/var_decl.py b/python/tvm/hybrid/var_decl.py index d6ca3b911eb5..8982b9d39ae9 100644 --- a/python/tvm/hybrid/var_decl.py +++ b/python/tvm/hybrid/var_decl.py @@ -53,7 +53,7 @@ def visit_AugAssign(self, node): self.aug_assign_ = True self.generic_visit(node) self.aug_assign_ = False - + def visit_Name(self, node): # If it is from the argument list or loop variable, we do not worry about it! From 47441a6cb8b9a7153b2c22b095fd862363919ca4 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Thu, 19 Jul 2018 19:43:45 -0700 Subject: [PATCH 6/8] test case updated! --- tests/python/unittest/test_hybrid_script.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index 2beb4e9f7eb1..c21ea782dc44 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -291,7 +291,7 @@ def blur(a, b): s = 0.0 for di in range(3): for dj in range(3): - s = s + a[i-di, j-dj] + s += a[i-di, j-dj] b[i-2, j-2] = s / 9.0 try: a = tvm.placeholder((32, 32), 'float32', 'a') From 4a6a438b4d1d3f44bd96a45d91b831d3faa490d2 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Fri, 20 Jul 2018 11:01:15 -0700 Subject: [PATCH 7/8] one fixed another in wip --- tests/python/unittest/test_hybrid_script.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index c21ea782dc44..caa18db02c53 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -85,7 +85,7 @@ def test_outer_product(): func = tvm.lower(ir, [n, m, a, b, c]) func = tvm.build(func) - run_and_check(outer_product, [n, m, a, b, c], [c], {n: 999, m: 1001}) + run_and_check(outer_product, [n, m, a, b, c], [c], {n: 99, m: 101}) for key, _ in HYBRID_GLOBALS.items(): assert key not in globals().keys() @@ -176,7 +176,7 @@ def failure(): tvm.hybrid.parse(failure, []) except IOError as err: assert sys.version_info[0] == 2 - print('[Warning] Case test_failure is skipped by Python2 because "%s"' % str(err)) + print('[Warning] Case test_failure.0 is skipped by Python2 because "%s"' % str(err)) except ValueError as err: assert str(err) == 'You CAN NEVER overwrite a loop variable!' @@ -186,6 +186,9 @@ def augdefine(): for i in range(10): es += 0 tvm.hybrid.parse(augdefine, []) + except IOError as err: + assert sys.version_info[0] == 2 + print('[Warning] Case test_failure.1 is skipped by Python2 because "%s"' % str(err)) except ValueError as err: assert str(err) == '"First store" cannot be an AugAssign' From 072126e42e8c97e7d01170f8277ed3eb278209a7 Mon Sep 17 00:00:00 2001 From: Jian Weng Date: Fri, 20 Jul 2018 14:17:00 -0700 Subject: [PATCH 8/8] relase the error check bound --- tests/python/unittest/test_hybrid_script.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_hybrid_script.py b/tests/python/unittest/test_hybrid_script.py index caa18db02c53..f725d29e8b57 100644 --- a/tests/python/unittest/test_hybrid_script.py +++ b/tests/python/unittest/test_hybrid_script.py @@ -38,7 +38,7 @@ def tvm_val_2_py_val(val): module(*nd_args) for nd, np in to_check: - numpy.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5) + numpy.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-3, atol=1e-3) return module