Skip to content

Commit

Permalink
Update conftest.py
Browse files Browse the repository at this point in the history
  • Loading branch information
juliagsy committed Aug 16, 2023
1 parent 57d8331 commit fedc3d0
Showing 1 changed file with 10 additions and 35 deletions.
45 changes: 10 additions & 35 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,28 @@
# global
import pytest
from typing import Dict

# local
import ivy
from ivy_tests.test_ivy import helpers
import jax


FW_STRS = ["numpy", "jax", "tensorflow", "torch"]
jax.config.update("jax_enable_x64", True)


TEST_BACKENDS: Dict[str, callable] = {
"numpy": lambda: helpers.globals._get_ivy_numpy(),
"jax": lambda: helpers.globals._get_ivy_jax(),
"tensorflow": lambda: helpers.globals._get_ivy_tensorflow(),
"torch": lambda: helpers.globals._get_ivy_torch(),
}
FW_STRS = ["numpy", "jax", "tensorflow", "torch"]


@pytest.fixture(autouse=True)
def run_around_tests(dev_str, f, compile_graph, implicit, fw):
if "gpu" in dev_str and fw == "numpy":
def run_around_tests(dev_str, compile_graph, fw):
if "gpu" in device and fw == "numpy":
# Numpy does not support GPU
pytest.skip()
ivy.unset_backend()
with f.use:
with ivy.utils.backend.ContextManager(fw):
with ivy.DefaultDevice(dev_str):
yield


def pytest_generate_tests(metafunc):
# device
# dev_str
raw_value = metafunc.config.getoption("--device")
if raw_value == "all":
devices = ["cpu", "gpu:0", "tpu:0"]
Expand All @@ -40,7 +32,7 @@ def pytest_generate_tests(metafunc):
# framework
raw_value = metafunc.config.getoption("--backend")
if raw_value == "all":
backend_strs = TEST_BACKENDS.keys()
backend_strs = FW_STRS
else:
backend_strs = raw_value.split(",")

Expand All @@ -53,33 +45,16 @@ def pytest_generate_tests(metafunc):
else:
compile_modes = [False]

# with_implicit
raw_value = metafunc.config.getoption("--with_implicit")
if raw_value == "true":
implicit_modes = [True, False]
else:
implicit_modes = [False]

# create test configs
configs = list()
for backend_str in backend_strs:
for device in devices:
for compile_graph in compile_modes:
for implicit in implicit_modes:
configs.append(
(
device,
TEST_BACKENDS[backend_str](),
compile_graph,
implicit,
backend_str,
)
)
metafunc.parametrize("dev_str,f,compile_graph,implicit,fw", configs)
configs.append((device, compile_graph, backend_str))
metafunc.parametrize("dev_str,compile_graph,fw", configs)


def pytest_addoption(parser):
parser.addoption("--device", action="store", default="cpu")
parser.addoption("--backend", action="store", default="numpy,jax,tensorflow,torch")
parser.addoption("--compile_graph", action="store", default="true")
parser.addoption("--with_implicit", action="store", default="false")

0 comments on commit fedc3d0

Please sign in to comment.