Skip to content

Commit

Permalink
[Pylint] fix pylint issues from test_random to test_tedd (#16065)
Browse files Browse the repository at this point in the history
* Update pylint.sh

* Update test_sparse.py

* Update test_sort.py

* Update test_rpc_tracker.py

* Update test_rpc_server_device.py

* Update test_rpc_proxy.py

* Update test_rocblas.py

* Update test_random.py

* Update test_tedd.py

* Update test_random.py

* Update test_tedd.py

* Update test_sparse.py

* Update test_sort.py

* Update test_rpc_tracker.py

* Update test_rpc_proxy.py

* Update test_rocblas.py

* Update test_random.py

* Update test_tedd.py

* Update test_tedd.py
  • Loading branch information
tlopex committed Nov 7, 2023
1 parent 641225a commit 2f20264
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 33 deletions.
8 changes: 8 additions & 0 deletions tests/lint/pylint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ python3 -m pylint tests/python/contrib/test_cblas.py --rcfile="$(dirname "$0")"/
python3 -m pylint tests/python/contrib/test_tflite_runtime.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_thrust.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_util.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_sort.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_sparse.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_tedd.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_rpc_tracker.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_rpc_server_device.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_rpc_proxy.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_rocblas.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_random.py --rcfile="$(dirname "$0")"/pylintrc

# tests/python/contrib/test_hexagon tests
python3 -m pylint tests/python/contrib/test_hexagon/*.py --rcfile="$(dirname "$0")"/pylintrc
Expand Down
15 changes: 10 additions & 5 deletions tests/python/contrib/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Configure pytest"""
# pylint: disable=invalid-name
import threading
import numpy as np
import tvm
from tvm import te
import numpy as np
from tvm.contrib import random
from tvm import rpc
import tvm.testing
import threading


def test_randint():
"""Tests randint function"""
m = 10240
n = 10240
A = random.randint(-127, 128, size=(m, n), dtype="int32")
Expand All @@ -49,6 +52,7 @@ def verify(target="llvm"):


def test_uniform():
"""Tests uniform function"""
m = 10240
n = 10240
A = random.uniform(0, 1, size=(m, n))
Expand All @@ -74,6 +78,7 @@ def verify(target="llvm"):


def test_normal():
"""Tests normal function"""
m = 10240
n = 10240
A = random.normal(3, 4, size=(m, n))
Expand All @@ -99,6 +104,8 @@ def verify(target="llvm"):

@tvm.testing.uses_gpu
def test_random_fill():
"""Tests random_fill function"""

def test_local(dev, dtype):
if not tvm.get_global_func("tvm.contrib.random.random_fill", True):
print("skip because extern function is not available")
Expand All @@ -120,8 +127,6 @@ def test_rpc(dtype):
if not tvm.testing.device_enabled("rpc") or not tvm.runtime.enabled("llvm"):
return

np_ones = np.ones((512, 512), dtype=dtype)

def check_remote(server):
remote = rpc.connect(server.host, server.port)
value = tvm.nd.empty((512, 512), dtype, remote.cpu())
Expand Down Expand Up @@ -171,7 +176,7 @@ def test_body():
test_input = tvm.runtime.ndarray.empty((10, 10))
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill_for_measure")
random_fill(test_input)
except:
except: # pylint: disable=bare-except
nonlocal no_exception_happened
no_exception_happened = False

Expand Down
8 changes: 6 additions & 2 deletions tests/python/contrib/test_rocblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Configure pytest"""
# pylint: disable=invalid-name
import numpy as np
import tvm
import tvm.testing
from tvm import te
import numpy as np
import tvm.topi.testing
import tvm.testing
from tvm.contrib import rocblas


@tvm.testing.requires_rocm
def test_matmul():
"""Tests matmul operation using roc"""
n = 1024
l = 128
m = 235
Expand All @@ -49,6 +51,7 @@ def verify(target="rocm"):


def verify_batch_matmul(batch, m, k, n, lib, transa=False, transb=False, dtype="float32"):
"""Tests matmul operation in batch using roc"""
ashape = (batch, k, m) if transa else (batch, m, k)
bshape = (batch, n, k) if transb else (batch, k, n)
A = te.placeholder(ashape, name="A", dtype=dtype)
Expand Down Expand Up @@ -85,6 +88,7 @@ def verify(target="rocm"):

@tvm.testing.requires_rocm
def test_batch_matmul():
"""Tests of matmul operation in batch using roc"""
verify_batch_matmul(128, 64, 512, 512, rocblas, transa=False, transb=False)
verify_batch_matmul(128, 64, 512, 512, rocblas, transa=False, transb=True)
verify_batch_matmul(128, 64, 512, 512, rocblas, transa=True, transb=False)
Expand Down
7 changes: 4 additions & 3 deletions tests/python/contrib/test_rpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
"""Configure pytest"""
# pylint: disable=invalid-name
import logging
import numpy as np
import time
import multiprocessing
import tvm
from tvm import rpc


Expand All @@ -35,6 +35,7 @@ def rpc_proxy_check():
"""

try:
# pylint: disable=import-outside-toplevel
from tvm.rpc import proxy

web_port = 8888
Expand Down
1 change: 0 additions & 1 deletion tests/python/contrib/test_rpc_server_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
"""iOS RPC Server tests."""
# pylint: disable=invalid-name, no-value-for-parameter, missing-function-docstring, import-error
import sys
import multiprocessing
import pytest
import numpy as np
Expand Down
13 changes: 8 additions & 5 deletions tests/python/contrib/test_rpc_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,21 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
"""Configure pytest"""
# pylint: disable=invalid-name
import logging
import numpy as np
import time
import multiprocessing
import tvm
from tvm import rpc


def check_server_drop():
"""test when server drops"""
try:
# pylint: disable=import-outside-toplevel
from tvm.rpc import tracker, proxy, base

# pylint: disable=import-outside-toplevel
from tvm.rpc.base import TrackerCode

@tvm.register_func("rpc.test2.addone")
Expand Down Expand Up @@ -80,10 +82,11 @@ def myfunc(remote):
f1 = remote2.get_function("rpc.test2.addone")
assert f1(10) == 11

except tvm.error.TVMError as e:
except tvm.error.TVMError:
pass
remote3 = tclient.request("abc")
f1 = remote3.get_function("rpc.test2.addone")
assert f1(10) == 11
remote3 = tclient.request("xyz1")
f1 = remote3.get_function("rpc.test2.addone")
assert f1(10) == 11
Expand Down
11 changes: 8 additions & 3 deletions tests/python/contrib/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Configure pytest"""
# pylint: disable=invalid-name
import numpy as np
import tvm
import tvm.testing
from tvm import te
from tvm.topi.cuda import sort_by_key
import numpy as np


def test_sort():
"""Tests sort function"""
n = 2
l = 5
m = 3
Expand All @@ -38,7 +41,7 @@ def test_sort():
dtype="int32",
name="sort_tensor",
)
input = [
input_data = [
[[1, 2, 3], [2, 4.5, 3.5], [1.1, 0.5, 1], [3.2, -5, 0.5], [1.5, 0, 0]],
[[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]],
]
Expand All @@ -52,14 +55,15 @@ def test_sort():
target = "llvm"
s = te.create_schedule(out.op)
f = tvm.build(s, [data, sort_num, out], target)
a = tvm.nd.array(np.array(input).astype(data.dtype), dev)
a = tvm.nd.array(np.array(input_data).astype(data.dtype), dev)
b = tvm.nd.array(np.array(sort_num_input).astype(sort_num.dtype), dev)
c = tvm.nd.array(np.zeros(a.shape, dtype=out.dtype), dev)
f(a, b, c)
tvm.testing.assert_allclose(c.numpy(), np.array(sorted_index).astype(out.dtype), rtol=1e-5)


def test_sort_np():
"""Tests sort function using numpy"""
dshape = (1, 2, 3, 4, 5, 6)
axis = 4
reduced_shape = (1, 2, 3, 4, 6)
Expand Down Expand Up @@ -92,6 +96,7 @@ def test_sort_np():


def test_sort_by_key_gpu():
"""Tests sort function using gpu"""
size = 6
keys = te.placeholder((size,), name="keys", dtype="int32")
values = te.placeholder((size,), name="values", dtype="int32")
Expand Down
12 changes: 7 additions & 5 deletions tests/python/contrib/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Configure pytest"""
# pylint: disable=invalid-name
from collections import namedtuple
import numpy as np
import tvm
import tvm.testing
from tvm import te
import tvm.contrib.sparse as tvmsp
import tvm.runtime.ndarray as _nd
import numpy as np
from collections import namedtuple


def test_static_tensor():
"""Tests static tensor"""
dtype = "float32"
stype = "csr"
target = "llvm"
dev = tvm.device(target, 0)
m = te.size_var("m")
Expand All @@ -50,8 +52,8 @@ def test_static_tensor():


def test_dynamic_tensor():
"""Tests dynamic tensor"""
dtype = "float32"
stype = "csr"
target = "llvm"
dev = tvm.device(target, 0)
nr, nc, n = te.size_var("nr"), te.size_var("nc"), te.size_var("n")
Expand All @@ -77,8 +79,8 @@ def test_dynamic_tensor():


def test_sparse_array_tuple():
"""Tests array when it is sparse"""
dtype, itype = "float32", "int32"
stype = "csr"
target = "llvm"
dev = tvm.device(target, 0)
nr, nc, n = te.size_var("nr"), te.size_var("nc"), te.size_var("n")
Expand Down
28 changes: 19 additions & 9 deletions tests/python/contrib/test_tedd.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Configure pytest of Tensor Expression Debug Display"""
# pylint: disable=invalid-name
import re

import tvm
from tvm import te
from tvm import topi
Expand All @@ -24,27 +25,31 @@
from tvm.relay.backend import Runtime, Executor


def findany(pattern, str):
matches = re.findall(pattern, str)
assert len(matches) > 0, "Pattern not found.\nPattern: " + pattern + "\nString: " + str
def findany(pattern, _str):
matches = re.findall(pattern, _str)
assert len(matches) > 0, "Pattern not found.\nPattern: " + pattern + "\nString: " + _str


def checkdependency():
# pylint: disable=import-outside-toplevel
import pkg_resources

# pylint: disable=E1133
return not {"graphviz", "ipython"} - {pkg.key for pkg in pkg_resources.working_set}


def test_dfg():
"""Tests dataflow graph"""
A = te.placeholder((1024, 4096), dtype="float32", name="A")
B = topi.nn.softmax(A)
# confirm lower works
s = te.create_schedule([B.op])

def verify():
# pylint: disable=import-outside-toplevel
from tvm.contrib import tedd

str = tedd.viz_dataflow_graph(s, False, "", True)
_str = tedd.viz_dataflow_graph(s, False, "", True)
# Check all edges are available
findany(r"digraph \"Dataflow Graph\"", str)
findany(r"Stage_0:O_0 -> Tensor_0_0", str)
Expand All @@ -64,6 +69,7 @@ def verify():


def test_itervar_relationship_graph():
"""Tests itervars relationship graph"""
n = te.var("n")
m = te.var("m")
A = te.placeholder((n, m), name="A")
Expand All @@ -74,9 +80,10 @@ def test_itervar_relationship_graph():
s[B].split(B.op.reduce_axis[0], factor=16)

def verify():
# pylint: disable=import-outside-toplevel
from tvm.contrib import tedd

str = tedd.viz_itervar_relationship_graph(s, False, "", True)
_str = tedd.viz_itervar_relationship_graph(s, False, "", True)
findany(r"digraph \"IterVar Relationship Graph\"", str)
findany(r"subgraph cluster_legend", str)
# Check subgraphs for stages
Expand All @@ -97,6 +104,7 @@ def verify():


def test_schedule_tree():
"""Tests schedule tree"""
block_x = te.thread_axis("blockIdx.x")
thread_x = te.thread_axis("threadIdx.x")
n = te.var("n")
Expand Down Expand Up @@ -124,9 +132,10 @@ def test_schedule_tree():
s[C].bind(s[C].op.axis[1], thread_x)

def verify():
# pylint: disable=import-outside-toplevel
from tvm.contrib import tedd

str = tedd.viz_schedule_tree(s, False, "", True)
_str = tedd.viz_schedule_tree(s, False, "", True)
findany(r"digraph \"Schedule Tree\"", str)
findany(r"subgraph cluster_legend", str)
# Check the A_shared stage, including memory scope, itervars,
Expand All @@ -153,6 +162,7 @@ def test_tedd_with_schedule_record():
"""Test to build a nn model and check if all schedules could be generated"""

def check_schedule(executor):
# pylint: disable=import-outside-toplevel
from tvm.contrib import tedd

error = {}
Expand All @@ -167,12 +177,12 @@ def check_schedule(executor):
tedd.viz_dataflow_graph(sch, False, "", True)
tedd.viz_itervar_relationship_graph(sch, False, "", True)
tedd.viz_schedule_tree(sch, False, "", True)
except:
except: # pylint: disable=W0702
if func_name not in error:
error[func_name] = []
error[func_name].append(index)

assert error == {}, str(error)
assert not error, str(error)

if checkdependency():
relay_mod, params = testing.mobilenet.get_workload(batch_size=1, dtype="float32")
Expand Down

0 comments on commit 2f20264

Please sign in to comment.