Skip to content

Commit

Permalink
fix conftest (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
juliagsy committed Jul 26, 2023
1 parent 40af0bd commit 75de8d8
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 36 deletions.
25 changes: 8 additions & 17 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,22 @@
# global
import pytest
from typing import Dict

# local
import ivy
from ivy_tests.test_ivy import helpers
import jax


FW_STRS = ["numpy", "jax", "tensorflow", "torch"]
jax.config.update("jax_enable_x64", True)


TEST_BACKENDS: Dict[str, callable] = {
"numpy": lambda: helpers.globals._get_ivy_numpy(),
"jax": lambda: helpers.globals._get_ivy_jax(),
"tensorflow": lambda: helpers.globals._get_ivy_tensorflow(),
"torch": lambda: helpers.globals._get_ivy_torch(),
}
FW_STRS = ["numpy", "jax", "tensorflow", "torch"]


@pytest.fixture(autouse=True)
def run_around_tests(device, f, compile_graph, fw):
def run_around_tests(device, compile_graph, fw):
if "gpu" in device and fw == "numpy":
# Numpy does not support GPU
pytest.skip()
with f.use:
with ivy.utils.backend.ContextManager(fw):
with ivy.DefaultDevice(device):
yield

Expand All @@ -39,7 +32,7 @@ def pytest_generate_tests(metafunc):
# framework
raw_value = metafunc.config.getoption("--backend")
if raw_value == "all":
backend_strs = TEST_BACKENDS.keys()
backend_strs = FW_STRS
else:
backend_strs = raw_value.split(",")

Expand All @@ -57,10 +50,8 @@ def pytest_generate_tests(metafunc):
for backend_str in backend_strs:
for device in devices:
for compile_graph in compile_modes:
configs.append(
(device, TEST_BACKENDS[backend_str](), compile_graph, backend_str)
)
metafunc.parametrize("device,f,compile_graph,fw", configs)
configs.append((device, compile_graph, backend_str))
metafunc.parametrize("device,compile_graph,fw", configs)


def pytest_addoption(parser):
Expand Down
5 changes: 2 additions & 3 deletions ivy_mech_demos/interactive/polar_to_cartesian_coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def __init__(self, interactive, try_use_sim):
plt.show()


def main(interactive=True, try_use_sim=True, f=None, fw=None):
def main(interactive=True, try_use_sim=True, fw=None):
fw = ivy.choose_random_backend() if fw is None else fw
ivy.set_backend(fw)
sim = Simulator(interactive, try_use_sim)
Expand Down Expand Up @@ -181,5 +181,4 @@ def main(interactive=True, try_use_sim=True, f=None, fw=None):
)
parsed_args = parser.parse_args()
fw = parsed_args.backend
f = None if fw is None else ivy.with_backend(backend=fw)
main(not parsed_args.non_interactive, not parsed_args.no_sim, f, fw)
main(not parsed_args.non_interactive, not parsed_args.no_sim, fw)
5 changes: 2 additions & 3 deletions ivy_mech_demos/interactive/target_facing_rotation_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __init__(self, interactive, try_use_sim):
plt.show()


def main(interactive=True, try_use_sim=True, f=None, fw=None):
def main(interactive=True, try_use_sim=True, fw=None):
fw = ivy.choose_random_backend() if fw is None else fw
ivy.set_backend(fw)
sim = Simulator(interactive, try_use_sim)
Expand Down Expand Up @@ -157,5 +157,4 @@ def main(interactive=True, try_use_sim=True, f=None, fw=None):
)
parsed_args = parser.parse_args()
fw = parsed_args.backend()
f = None if fw is None else ivy.with_backend(backend=fw)
main(not parsed_args.non_interactive, not parsed_args.no_sim, f, fw)
main(not parsed_args.non_interactive, not parsed_args.no_sim, fw)
6 changes: 2 additions & 4 deletions ivy_mech_demos/run_through.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
import ivy_mech


def main(f=None, fw=None):
def main(fw=None):
# Framework Setup #
# ----------------#

# choose random framework

fw = ivy.choose_random_backend() if fw is None else fw
ivy.set_backend(fw)
f = ivy.with_backend(backend=fw) if f is None else f

# Orientation #
# ------------#
Expand Down Expand Up @@ -96,5 +95,4 @@ def main(f=None, fw=None):
)
parsed_args = parser.parse_args()
fw = parsed_args.backend()
f = None if fw is None else ivy.with_backend(backend=fw)
main(f, fw)
main(fw)
15 changes: 6 additions & 9 deletions ivy_mech_tests/test_ivy_mech_demos.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,27 @@
import pytest


def test_demo_run_through(device, f, fw):
def test_demo_run_through(device, fw):
from ivy_mech_demos.run_through import main

if fw == "tensorflow_graph":
# these particular demos are only implemented in eager mode, without compilation
pytest.skip()
main(f=f, fw=fw)
main(fw=fw)


@pytest.mark.parametrize("with_sim", [False])
def test_demo_target_facing_rotation_vector(with_sim, device, f, fw):
def test_demo_target_facing_rotation_vector(with_sim, device, fw):
from ivy_mech_demos.interactive.target_facing_rotation_matrix import main

if fw == "tensorflow_graph":
# these particular demos are only implemented in eager mode, without compilation
pytest.skip()
main(False, with_sim, f=f, fw=fw)
main(False, with_sim, fw=fw)


@pytest.mark.parametrize("with_sim", [False])
def test_demo_polar_to_cartesian_coords(with_sim, device, f, fw):
def test_demo_polar_to_cartesian_coords(with_sim, device, fw):
from ivy_mech_demos.interactive.polar_to_cartesian_coords import main

if fw == "tensorflow_graph":
# these particular demos are only implemented in eager mode, without compilation
pytest.skip()
main(False, with_sim, f=f, fw=fw)
main(False, with_sim, fw=fw)

0 comments on commit 75de8d8

Please sign in to comment.