Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pylint] fix pylint issues for thrust&tflite_runtime&util #16023

Merged
merged 5 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions tests/lint/pylint.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ python3 -m pylint tests/python/ci --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/integration/ --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/conftest.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_cblas.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_tflite_runtime.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_thrust.py --rcfile="$(dirname "$0")"/pylintrc
python3 -m pylint tests/python/contrib/test_util.py --rcfile="$(dirname "$0")"/pylintrc

# tests/python/contrib/test_hexagon tests
python3 -m pylint tests/python/contrib/test_hexagon/*.py --rcfile="$(dirname "$0")"/pylintrc
Expand Down
18 changes: 11 additions & 7 deletions tests/python/contrib/test_tflite_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,29 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Configure pytest"""
import pytest

import tvm
from tvm import te
import numpy as np
import tvm
from tvm import rpc
from tvm.contrib import utils, tflite_runtime


def _create_tflite_model():
"""Functions of creating a tflite model"""
if not tvm.runtime.enabled("tflite"):
print("skip because tflite runtime is not enabled...")
return
return None
if not tvm.get_global_func("tvm.tflite_runtime.create", True):
print("skip because tflite runtime is not enabled...")
return
return None

try:
# pylint: disable=import-outside-toplevel
import tensorflow as tf
except ImportError:
print("skip because tensorflow not installed...")
return
return None

root = tf.Module()
root.const = tf.constant([1.0, 2.0], tf.float32)
Expand All @@ -55,6 +56,7 @@ def _create_tflite_model():

@pytest.mark.skip("skip because accessing output tensor is flakey")
def test_local():
"""Local tests of tflite model"""
if not tvm.runtime.enabled("tflite"):
print("skip because tflite runtime is not enabled...")
return
Expand All @@ -63,6 +65,7 @@ def test_local():
return

try:
# pylint: disable=import-outside-toplevel
import tensorflow as tf
except ImportError:
print("skip because tensorflow not installed...")
Expand Down Expand Up @@ -96,6 +99,7 @@ def test_local():


def test_remote():
"""Remote tests of tflite model"""
if not tvm.runtime.enabled("tflite"):
print("skip because tflite runtime is not enabled...")
return
Expand All @@ -104,6 +108,7 @@ def test_remote():
return

try:
# pylint: disable=import-outside-toplevel
import tensorflow as tf
except ImportError:
print("skip because tensorflow not installed...")
Expand All @@ -130,7 +135,6 @@ def test_remote():
# inference via remote tvm tflite runtime
def check_remote(server):
remote = rpc.connect(server.host, server.port)
a = remote.upload(tflite_model_path)

with open(tflite_model_path, "rb") as model_fin:
runtime = tflite_runtime.create(model_fin.read(), remote.cpu(0))
Expand Down
6 changes: 5 additions & 1 deletion tests/python/contrib/test_thrust.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,21 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Configure pytest"""
import numpy as np
import tvm
import tvm.testing
from tvm import te
from tvm.topi.cuda import stable_sort_by_key_thrust
from tvm.topi.cuda.scan import exclusive_scan, scan_thrust, schedule_scan
from tvm.contrib.thrust import can_use_thrust, can_use_rocthrust
import numpy as np


thrust_check_func = {"cuda": can_use_thrust, "rocm": can_use_rocthrust}


def test_stable_sort_by_key():
"""Tests function test_stable_sort_by_key"""
size = 6
keys = te.placeholder((size,), name="keys", dtype="int32")
values = te.placeholder((size,), name="values", dtype="int32")
Expand Down Expand Up @@ -64,6 +66,7 @@ def test_stable_sort_by_key():


def test_exclusive_scan():
"""Tests function test_exclusive_scan"""
for target in ["cuda", "rocm"]:
if not tvm.testing.device_enabled(target):
print("Skip because %s is not enabled" % target)
Expand Down Expand Up @@ -105,6 +108,7 @@ def test_exclusive_scan():


def test_inclusive_scan():
"""Tests function test_inclusive_scan"""
out_dtype = "int64"

for target in ["cuda", "rocm"]:
Expand Down
4 changes: 3 additions & 1 deletion tests/python/contrib/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@


def validate_debug_dir_path(temp_dir, expected_basename):
"""Validate the dir path of debugging"""
dirname, basename = os.path.split(temp_dir.temp_dir)
assert basename == expected_basename, "unexpected basename: %s" % (basename,)

Expand All @@ -32,7 +33,8 @@ def validate_debug_dir_path(temp_dir, expected_basename):


def test_tempdir():
assert utils.TempDirectory._KEEP_FOR_DEBUG == False, "don't submit with KEEP_FOR_DEBUG == True"
"""Tests for temporary dir"""
assert utils.TempDirectory._KEEP_FOR_DEBUG is False, "don't submit with KEEP_FOR_DEBUG == True"

temp_dir = utils.tempdir()
assert os.path.exists(temp_dir.temp_dir)
Expand Down