Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FRONTEND] [HYBRID] Augmented assign operator supported! #1459

Merged
merged 8 commits into from
Jul 20, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion python/tvm/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

def list_to_block(visit, lst):
"""Convert a list of Python IR nodes to HalideIR Block"""
lst = list(map(visit, lst))
lst = [visit(i) for i in lst]
lst = [stmt for stmt in lst if not _ir_pass.Equal(stmt, make_nop())]
if not lst:
return make_nop()
Expand Down Expand Up @@ -162,6 +162,13 @@ def visit_Name(self, node):
def visit_Num(self, node):
return _api.const(node.n)

def visit_AugAssign(self, node):
lhs = self.visit(node.target)
rhs = self.visit(node.value)
rhs = HybridParser._binop_maker[type(node.op)](lhs, rhs)
if not isinstance(lhs, _expr.Call):
raise ValueError("The LHS of an AugAssign is supposed to be a call!")
return _make.Provide(lhs.func, 0, rhs, lhs.args)

def visit_Assign(self, node):
if len(node.targets) != 1:
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/hybrid/var_decl.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(self, args):
self.scope_level = []
self._args = {}
self.args = args
self.aug_assign_ = False


def visit_FunctionDef(self, node):
Expand Down Expand Up @@ -48,6 +49,12 @@ def visit_Call(self, node):
self.visit(elem)


def visit_AugAssign(self, node):
self.aug_assign_ = True
self.generic_visit(node)
self.aug_assign_ = False


def visit_Name(self, node):
# If it is from the argument list or loop variable, we do not worry about it!
if node.id in self._args.keys():
Expand All @@ -62,6 +69,8 @@ def visit_Name(self, node):
if node.id not in self.status.keys():
if not isinstance(node.ctx, ast.Store):
raise ValueError('In Python, "first store" indicates "declaration"')
if self.aug_assign_:
raise ValueError('"First store" cannot be an AugAssign')
self.status[node.id] = (node, self.scope_level[-1], set())
else:
decl, loop, usage = self.status[node.id]
Expand Down
75 changes: 50 additions & 25 deletions tests/python/unittest/test_hybrid_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def tvm_val_2_py_val(val):
module(*nd_args)

for nd, np in to_check:
numpy.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-5, atol=1e-5)
numpy.testing.assert_allclose(nd.asnumpy(), np, rtol=1e-3, atol=1e-3)

return module


@script
Expand Down Expand Up @@ -83,7 +85,7 @@ def test_outer_product():
func = tvm.lower(ir, [n, m, a, b, c])
func = tvm.build(func)

run_and_check(outer_product, [n, m, a, b, c], [c], {n: 999, m: 1001})
run_and_check(outer_product, [n, m, a, b, c], [c], {n: 99, m: 101})

for key, _ in HYBRID_GLOBALS.items():
assert key not in globals().keys()
Expand Down Expand Up @@ -165,20 +167,32 @@ def fanout(n, a, b):
run_and_check(fanout, [n, a, b], [b], {n: 10})


@script
def failure():
for i in range(1, 100):
i = 0

def test_failure():
try:
@script
def failure():
for i in range(1, 100):
i = 0
tvm.hybrid.parse(failure, [])
except IOError as err:
assert sys.version_info[0] == 2
print('[Warning] Case test_failure is skipped by Python2 because "%s"' % str(err))
except Exception as err:
print('[Warning] Case test_failure.0 is skipped by Python2 because "%s"' % str(err))
except ValueError as err:
assert str(err) == 'You CAN NEVER overwrite a loop variable!'

try:
@tvm.hybrid.script
def augdefine():
for i in range(10):
es += 0
tvm.hybrid.parse(augdefine, [])
except IOError as err:
assert sys.version_info[0] == 2
print('[Warning] Case test_failure.1 is skipped by Python2 because "%s"' % str(err))
except ValueError as err:
assert str(err) == '"First store" cannot be an AugAssign'



def test_looptype():
@script
Expand Down Expand Up @@ -280,7 +294,7 @@ def blur(a, b):
s = 0.0
for di in range(3):
for dj in range(3):
s = s + a[i-di, j-dj]
s += a[i-di, j-dj]
b[i-2, j-2] = s / 9.0
try:
a = tvm.placeholder((32, 32), 'float32', 'a')
Expand Down Expand Up @@ -315,29 +329,39 @@ def blur2d(a, b):

a = tvm.placeholder((32, 32), 'float32', 'a')
b = tvm.placeholder((30, 30), 'float32', 'b')

run_and_check(blur2d, [a, b], [b])

if tvm.gpu().exist:
@tvm.hybrid.script
def share_vec_add(a, b, c):
shared = allocate((256, ), 'float32', 'shared')
for i in bind("threadIdx.x", 256):
shared[i] = a[i]
local = allocate((256, ), 'float32', 'local')
for i in bind("threadIdx.x", 256):
local[i] = b[i]
for i in bind("threadIdx.x", 256):
c[i] = shared[i] + local[i]

a = tvm.placeholder((256, ), dtype='float32', name='a')
b = tvm.placeholder((256, ), dtype='float32', name='b')
c = tvm.placeholder((256, ), dtype='float32', name='c')
run_and_check(share_vec_add, [a, b, c], [c], target='cuda')
def shared_gemm(a, b, c):
for io in bind('blockIdx.x', 8):
for ii in bind('blockIdx.y', 8):
shared_b = allocate((64, 64), 'float32', 'shared')
for k in range(64):
shared_b[io * 8 + ii, k] = b[io * 8 + ii, k]
for jo in bind('threadIdx.y', 8):
for ji in bind('threadIdx.x', 8):
for k in range(64):
c[io*8+ii, jo*8+ji] += a[io*8+ii, k] * shared_b[k, jo*8+ji]

a = tvm.placeholder((64, 64), dtype='float32', name='a')
b = tvm.placeholder((64, 64), dtype='float32', name='b')
c = tvm.placeholder((64, 64), dtype='float32', name='c')
module = run_and_check(shared_gemm, [a, b, c], [c], target='cuda')
assert "__syncthreads()" in module.imported_modules[0].get_source()
else:
print('[Warning] No GPU found! Skip shared mem test!')


def test_augassign():
@tvm.hybrid.script
def augassign(a):
for i in range(a.shape[0]):
a[i] += 1.0
a = tvm.placeholder((16, ), dtype='float32', name='a')
run_and_check(augassign, [a], [a])


if __name__ == "__main__":
test_outer_product()
test_fanout()
Expand All @@ -348,4 +372,5 @@ def share_vec_add(a, b, c):
test_math_intrin()
test_non_zero()
test_allocate()
test_augassign()