Skip to content

Commit

Permalink
fix conftest
Browse files Browse the repository at this point in the history
  • Loading branch information
juliagsy committed Jul 26, 2023
1 parent ff78e63 commit d8b6b9d
Showing 1 changed file with 9 additions and 35 deletions.
44 changes: 9 additions & 35 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,25 @@
# global
import pytest
from typing import Dict

# local
import ivy
from ivy_tests.test_ivy import helpers


FW_STRS = ["numpy", "jax", "tensorflow", "torch"]


TEST_BACKENDS: Dict[str, callable] = {
"numpy": lambda: helpers.globals._get_ivy_numpy(),
"jax": lambda: helpers.globals._get_ivy_jax(),
"tensorflow": lambda: helpers.globals._get_ivy_tensorflow(),
"torch": lambda: helpers.globals._get_ivy_torch(),
}


@pytest.fixture(autouse=True)
def run_around_tests(dev_str, f, compile_graph, implicit, fw):
if "gpu" in dev_str and fw == "numpy":
def run_around_tests(device, f, compile_graph, fw):
if "gpu" in device and fw == "numpy":
# Numpy does not support GPU
pytest.skip()
ivy.unset_backend()
with f.use:
with ivy.DefaultDevice(dev_str):
with ivy.DefaultDevice(device):
yield


def pytest_generate_tests(metafunc):
# device
# dev_str
raw_value = metafunc.config.getoption("--device")
if raw_value == "all":
devices = ["cpu", "gpu:0", "tpu:0"]
Expand All @@ -40,7 +29,7 @@ def pytest_generate_tests(metafunc):
# framework
raw_value = metafunc.config.getoption("--backend")
if raw_value == "all":
backend_strs = TEST_BACKENDS.keys()
backend_strs = FW_STRS
else:
backend_strs = raw_value.split(",")

Expand All @@ -53,33 +42,18 @@ def pytest_generate_tests(metafunc):
else:
compile_modes = [False]

# with_implicit
raw_value = metafunc.config.getoption("--with_implicit")
if raw_value == "true":
implicit_modes = [True, False]
else:
implicit_modes = [False]

# create test configs
configs = list()
for backend_str in backend_strs:
for device in devices:
for compile_graph in compile_modes:
for implicit in implicit_modes:
configs.append(
(
device,
TEST_BACKENDS[backend_str](),
compile_graph,
implicit,
backend_str,
)
)
metafunc.parametrize("dev_str,f,compile_graph,implicit,fw", configs)
configs.append(
(device, ivy.with_backend(backend_str), compile_graph, backend_str)
)
metafunc.parametrize("device,f,compile_graph,fw", configs)


def pytest_addoption(parser):
parser.addoption("--device", action="store", default="cpu")
parser.addoption("--backend", action="store", default="numpy,jax,tensorflow,torch")
parser.addoption("--compile_graph", action="store", default="true")
parser.addoption("--with_implicit", action="store", default="false")

0 comments on commit d8b6b9d

Please sign in to comment.