Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pylint] fix pylint issues from test_random to test_tedd #16065

Merged
merged 19 commits into from
Nov 7, 2023
Merged
8 changes: 8 additions & 0 deletions tests/lint/pylint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ python3 -m pylint tests/python/contrib/test_cblas.py --rcfile="$(dirname "$0")"/
python3 -m pylint tests/python/contrib/test_tflite_runtime.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_thrust.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_util.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_sort.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_sparse.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_tedd.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_rpc_tracker.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_rpc_server_device.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_rpc_proxy.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_rocblas.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_random.py --rcfile="$(dirname "$0")"/pylintrc

# tests/python/contrib/test_hexagon tests
python3 -m pylint tests/python/contrib/test_hexagon/*.py --rcfile="$(dirname "$0")"/pylintrc
Expand Down
60 changes: 32 additions & 28 deletions tests/python/contrib/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Configure pytest"""
import threading
import numpy as np
import tvm
from tvm import te
import numpy as np
from tvm.contrib import random
from tvm import rpc
import tvm.testing
import threading


def test_randint():
"""Tests randint function"""
m = 10240
n = 10240
A = random.randint(-127, 128, size=(m, n), dtype="int32")
s = te.create_schedule(A.op)
input_a = random.randint(-127, 128, size=(m, n), dtype="int32")
s = te.create_schedule(input_a.op)

def verify(target="llvm"):
if not tvm.testing.device_enabled(target):
Expand All @@ -37,22 +39,23 @@ def verify(target="llvm"):
print("skip because extern function is not available")
return
dev = tvm.cpu(0)
f = tvm.build(s, [A], target)
a = tvm.nd.array(np.zeros((m, n), dtype=A.dtype), dev)
f = tvm.build(s, [input_a], target)
a = tvm.nd.array(np.zeros((m, n), dtype=input_a.dtype), dev)
f(a)
na = a.numpy()
assert abs(np.mean(na)) < 0.3
assert np.min(na) == -127
assert np.max(na) == 127
_na = a.numpy()
assert abs(np.mean(_na)) < 0.3
assert np.min(_na) == -127
assert np.max(_na) == 127

verify()


def test_uniform():
"""Tests uniform function"""
m = 10240
n = 10240
A = random.uniform(0, 1, size=(m, n))
s = te.create_schedule(A.op)
input_a = random.uniform(0, 1, size=(m, n))
s = te.create_schedule(input_a.op)

def verify(target="llvm"):
if not tvm.testing.device_enabled(target):
Expand All @@ -62,22 +65,23 @@ def verify(target="llvm"):
print("skip because extern function is not available")
return
dev = tvm.cpu(0)
f = tvm.build(s, [A], target)
a = tvm.nd.array(np.zeros((m, n), dtype=A.dtype), dev)
f = tvm.build(s, [input_a], target)
a = tvm.nd.array(np.zeros((m, n), dtype=input_a.dtype), dev)
f(a)
na = a.numpy()
assert abs(np.mean(na) - 0.5) < 1e-1
assert abs(np.min(na) - 0.0) < 1e-3
assert abs(np.max(na) - 1.0) < 1e-3
op_na = a.numpy()
assert abs(np.mean(op_na) - 0.5) < 1e-1
assert abs(np.min(op_na) - 0.0) < 1e-3
assert abs(np.max(op_na) - 1.0) < 1e-3

verify()


def test_normal():
"""Tests normal function"""
m = 10240
n = 10240
A = random.normal(3, 4, size=(m, n))
s = te.create_schedule(A.op)
input_a = random.normal(3, 4, size=(m, n))
s = te.create_schedule(input_a.op)

def verify(target="llvm"):
if not tvm.testing.device_enabled(target):
Expand All @@ -87,18 +91,20 @@ def verify(target="llvm"):
print("skip because extern function is not available")
return
dev = tvm.cpu(0)
f = tvm.build(s, [A], target)
a = tvm.nd.array(np.zeros((m, n), dtype=A.dtype), dev)
f = tvm.build(s, [input_a], target)
a = tvm.nd.array(np.zeros((m, n), dtype=input_a.dtype), dev)
f(a)
na = a.numpy()
assert abs(np.mean(na) - 3) < 1e-1
assert abs(np.std(na) - 4) < 1e-2
_na = a.numpy()
assert abs(np.mean(_na) - 3) < 1e-1
assert abs(np.std(_na) - 4) < 1e-2

verify()


@tvm.testing.uses_gpu
def test_random_fill():
"""Tests random_fill function"""

def test_local(dev, dtype):
if not tvm.get_global_func("tvm.contrib.random.random_fill", True):
print("skip because extern function is not available")
Expand All @@ -120,8 +126,6 @@ def test_rpc(dtype):
if not tvm.testing.device_enabled("rpc") or not tvm.runtime.enabled("llvm"):
return

np_ones = np.ones((512, 512), dtype=dtype)

def check_remote(server):
remote = rpc.connect(server.host, server.port)
value = tvm.nd.empty((512, 512), dtype, remote.cpu())
Expand Down Expand Up @@ -171,7 +175,7 @@ def test_body():
test_input = tvm.runtime.ndarray.empty((10, 10))
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill_for_measure")
random_fill(test_input)
except:
except: # pylint: disable=bare-except
nonlocal no_exception_happened
no_exception_happened = False

Expand Down
41 changes: 22 additions & 19 deletions tests/python/contrib/test_rocblas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,47 +14,49 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Configure pytest"""
import numpy as np
import tvm
import tvm.testing
from tvm import te
import numpy as np
import tvm.topi.testing
import tvm.testing
from tvm.contrib import rocblas


@tvm.testing.requires_rocm
def test_matmul():
"""Tests matmul operation using roc"""
n = 1024
l = 128
op_l = 128
m = 235
A = te.placeholder((n, l), name="A")
B = te.placeholder((l, m), name="B")
C = rocblas.matmul(A, B)
s = te.create_schedule(C.op)
input_a = te.placeholder((n, op_l), name="input_a")
input_b = te.placeholder((op_l, m), name="input_b")
result_c = rocblas.matmul(input_a, input_b)
s = te.create_schedule(result_c.op)

def verify(target="rocm"):
if not tvm.get_global_func("tvm.contrib.rocblas.matmul", True):
print("skip because extern function is not available")
return
dev = tvm.rocm(0)
f = tvm.build(s, [A, B, C], target)
a = tvm.nd.array(np.random.uniform(size=(n, l)).astype(A.dtype), dev)
b = tvm.nd.array(np.random.uniform(size=(l, m)).astype(B.dtype), dev)
c = tvm.nd.array(np.zeros((n, m), dtype=C.dtype), dev)
f = tvm.build(s, [input_a, input_b, result_c], target)
a = tvm.nd.array(np.random.uniform(size=(n, op_l)).astype(input_a.dtype), dev)
b = tvm.nd.array(np.random.uniform(size=(op_l, m)).astype(input_b.dtype), dev)
c = tvm.nd.array(np.zeros((n, m), dtype=result_c.dtype), dev)
f(a, b, c)
tvm.testing.assert_allclose(c.numpy(), np.dot(a.numpy(), b.numpy()), rtol=1e-5)

verify()


def verify_batch_matmul(batch, m, k, n, lib, transa=False, transb=False, dtype="float32"):
"""Tests matmul operation in batch using roc"""
ashape = (batch, k, m) if transa else (batch, m, k)
bshape = (batch, n, k) if transb else (batch, k, n)
A = te.placeholder(ashape, name="A", dtype=dtype)
B = te.placeholder(bshape, name="B", dtype=dtype)
C = lib.batch_matmul(A, B, transa, transb)
s = te.create_schedule(C.op)
input_a = te.placeholder(ashape, name="input_a", dtype=dtype)
input_b = te.placeholder(bshape, name="input_b", dtype=dtype)
result_c = lib.batch_matmul(input_a, input_b, transa, transb)
s = te.create_schedule(result_c.op)

def get_numpy(a, b, transa, transb):
if transa:
Expand All @@ -71,10 +73,10 @@ def verify(target="rocm"):
print("skip because extern function is not available")
return
dev = tvm.rocm(0)
f = tvm.build(s, [A, B, C], target)
a = tvm.nd.array(np.random.uniform(size=ashape).astype(A.dtype), dev)
b = tvm.nd.array(np.random.uniform(size=bshape).astype(B.dtype), dev)
c = tvm.nd.array(np.zeros((batch, m, n), dtype=C.dtype), dev)
f = tvm.build(s, [input_a, input_b, result_c], target)
a = tvm.nd.array(np.random.uniform(size=ashape).astype(input_a.dtype), dev)
b = tvm.nd.array(np.random.uniform(size=bshape).astype(input_b.dtype), dev)
c = tvm.nd.array(np.zeros((batch, m, n), dtype=result_c.dtype), dev)
f(a, b, c)
tvm.testing.assert_allclose(
c.numpy(), get_numpy(a.numpy(), b.numpy(), transa, transb), rtol=1e-5
Expand All @@ -85,6 +87,7 @@ def verify(target="rocm"):

@tvm.testing.requires_rocm
def test_batch_matmul():
"""Tests of matmul operation in batch using roc"""
verify_batch_matmul(128, 64, 512, 512, rocblas, transa=False, transb=False)
verify_batch_matmul(128, 64, 512, 512, rocblas, transa=False, transb=True)
verify_batch_matmul(128, 64, 512, 512, rocblas, transa=True, transb=False)
Expand Down
12 changes: 6 additions & 6 deletions tests/python/contrib/test_rpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
"""Configure pytest"""
import logging
import numpy as np
import time
import multiprocessing
import tvm
from tvm import rpc


Expand All @@ -35,6 +34,7 @@ def rpc_proxy_check():
"""

try:
# pylint: disable=import-outside-toplevel
from tvm.rpc import proxy

web_port = 8888
Expand All @@ -52,9 +52,9 @@ def check():
server.deamon = True
server.start()
client = rpc.connect(prox.host, prox.port, key="x1")
f1 = client.get_function("testing.echo")
assert f1(10) == 10
assert f1("xyz") == "xyz"
test_f1 = client.get_function("testing.echo")
assert test_f1(10) == 10
assert test_f1("xyz") == "xyz"

check()
except ImportError:
Expand Down
1 change: 0 additions & 1 deletion tests/python/contrib/test_rpc_server_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
# under the License.
"""iOS RPC Server tests."""
# pylint: disable=invalid-name, no-value-for-parameter, missing-function-docstring, import-error
import sys
import multiprocessing
import pytest
import numpy as np
Expand Down
30 changes: 16 additions & 14 deletions tests/python/contrib/test_rpc_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
"""Configure pytest"""
import logging
import numpy as np
import time
import multiprocessing
import tvm
from tvm import rpc


def check_server_drop():
"""test when server drops"""
try:
# pylint: disable=import-outside-toplevel
from tvm.rpc import tracker, proxy, base

# pylint: disable=import-outside-toplevel
from tvm.rpc.base import TrackerCode

@tvm.register_func("rpc.test2.addone")
Expand Down Expand Up @@ -63,8 +64,8 @@ def _put(tclient, value):
def check_timeout(timeout, sleeptime):
def myfunc(remote):
time.sleep(sleeptime)
f1 = remote.get_function("rpc.test2.addone")
assert f1(10) == 11
test_f1 = remote.get_function("rpc.test2.addone")
assert test_f1(10) == 11

try:
tclient.request_and_run("xyz", myfunc, session_timeout=timeout)
Expand All @@ -75,18 +76,19 @@ def myfunc(remote):
remote = tclient.request("xyz", priority=0, session_timeout=timeout)
remote2 = tclient.request("xyz", session_timeout=timeout)
time.sleep(sleeptime)
f1 = remote.get_function("rpc.test2.addone")
assert f1(10) == 11
f1 = remote2.get_function("rpc.test2.addone")
assert f1(10) == 11
test_f1 = remote.get_function("rpc.test2.addone")
assert test_f1(10) == 11
test_f1 = remote2.get_function("rpc.test2.addone")
assert test_f1(10) == 11

except tvm.error.TVMError as e:
except tvm.error.TVMError:
pass
remote3 = tclient.request("abc")
f1 = remote3.get_function("rpc.test2.addone")
test_f1 = remote3.get_function("rpc.test2.addone")
assert test_f1(10) == 11
remote3 = tclient.request("xyz1")
f1 = remote3.get_function("rpc.test2.addone")
assert f1(10) == 11
test_f1 = remote3.get_function("rpc.test2.addone")
assert test_f1(10) == 11

check_timeout(0.01, 0.1)
check_timeout(2, 0)
Expand Down