diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py index 0668dc83df55..caf6bc4f6778 100644 --- a/python/tvm/script/parser.py +++ b/python/tvm/script/parser.py @@ -1226,7 +1226,7 @@ def from_source( elif inspect.isfunction(input_func): _, start_line = inspect.getsourcelines(input_func) env: Dict[str, Any] = input_func.__globals__ - namespace = [key for key in env.keys() if env[key] == tir] + namespace = [key for key in env.keys() if env[key] is tir] parser = TVMScriptParser(start_line, namespace) result = to_ast(input_func, TVMDiagnosticCtx(), parser) return result diff --git a/tests/python/unittest/test_tvmscript_regression.py b/tests/python/unittest/test_tvmscript_regression.py new file mode 100644 index 000000000000..3ad8090893eb --- /dev/null +++ b/tests/python/unittest/test_tvmscript_regression.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import numpy + +import tvm +from tvm.script import tir as T + + +# This numpy array is used to test the comparison between the global objects and the +# `tvm.script.tir` submodule. +np_array = numpy.array([0, 1, 2, 3]) + + +@T.prim_func +def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128]) + B = T.match_buffer(b, [128, 128]) + C = T.match_buffer(c, [128, 128]) + + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = T.float32(0) + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] + + +def test_multi_element_array_in_outmost_namespace(): + func = matmul + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func) + + +if __name__ == "__main__": + test_multi_element_array_in_outmost_namespace()