Skip to content

Commit

Permalink
Rename get_debug_values to get_test_values
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Oct 18, 2020
1 parent 0ea1058 commit da6ea2f
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 16 deletions.
24 changes: 12 additions & 12 deletions tests/gof/test_op.py
Expand Up @@ -288,23 +288,23 @@ def test_test_value_op():


@change_flags(compute_test_value="off")
def test_get_debug_values_no_debugger():
"""Tests that `get_debug_values` returns `[]` when debugger is off."""
def test_get_test_values_no_debugger():
"""Tests that `get_test_values` returns `[]` when debugger is off."""

x = tt.vector()
assert op.get_debug_values(x) == []
assert op.get_test_values(x) == []


@change_flags(compute_test_value="ignore")
def test_get_det_debug_values_ignore():
"""Tests that `get_debug_values` returns `[]` when debugger is set to "ignore" and some values are missing."""
def test_get_test_values_ignore():
"""Tests that `get_test_values` returns `[]` when debugger is set to "ignore" and some values are missing."""

x = tt.vector()
assert op.get_debug_values(x) == []
assert op.get_test_values(x) == []


def test_get_debug_values_success():
"""Tests that `get_debug_value` returns values when available (and the debugger is on)."""
def test_get_test_values_success():
"""Tests that `get_test_values` returns values when available (and the debugger is on)."""

for mode in ["ignore", "warn", "raise"]:
with change_flags(compute_test_value=mode):
Expand All @@ -314,7 +314,7 @@ def test_get_debug_values_success():

iters = 0

for x_val, y_val in op.get_debug_values(x, y):
for x_val, y_val in op.get_test_values(x, y):

assert x_val.shape == (4,)
assert y_val.shape == (5, 5)
Expand All @@ -325,9 +325,9 @@ def test_get_debug_values_success():


@change_flags(compute_test_value="raise")
def test_get_debug_values_exc():
"""Tests that `get_debug_value` raises an exception when debugger is set to raise and a value is missing."""
def test_get_test_values_exc():
"""Tests that `get_test_values` raises an exception when debugger is set to raise and a value is missing."""

with pytest.raises(AttributeError):
x = tt.vector()
assert op.get_debug_values(x) == []
assert op.get_test_values(x) == []
2 changes: 1 addition & 1 deletion theano/gof/op.py
Expand Up @@ -1080,7 +1080,7 @@ def missing_test_message(msg):
assert action in ["ignore", "off"]


def get_debug_values(*args):
def get_test_values(*args):
"""
Intended use:
Expand Down
6 changes: 3 additions & 3 deletions theano/gradient.py
Expand Up @@ -15,7 +15,7 @@
from theano.gof import utils, Variable

from theano.gof.null_type import NullType, null_type
from theano.gof.op import get_debug_values
from theano.gof.op import get_test_values
from theano.compile import ViewOp, FAST_RUN, DebugMode, get_mode

__authors__ = "James Bergstra, Razvan Pascanu, Arnaud Bergeron, Ian Goodfellow"
Expand Down Expand Up @@ -1217,7 +1217,7 @@ def try_to_copy_if_needed(var):
continue
if isinstance(new_output_grad.type, DisconnectedType):
continue
for orig_output_v, new_output_grad_v in get_debug_values(*packed):
for orig_output_v, new_output_grad_v in get_test_values(*packed):
o_shape = orig_output_v.shape
g_shape = new_output_grad_v.shape
if o_shape != g_shape:
Expand Down Expand Up @@ -1310,7 +1310,7 @@ def try_to_copy_if_needed(var):
# has the right shape
if hasattr(term, "shape"):
orig_ipt = inputs[i]
for orig_ipt_v, term_v in get_debug_values(orig_ipt, term):
for orig_ipt_v, term_v in get_test_values(orig_ipt, term):
i_shape = orig_ipt_v.shape
t_shape = term_v.shape
if i_shape != t_shape:
Expand Down

0 comments on commit da6ea2f

Please sign in to comment.