Skip to content

Commit

Permalink
Fix for cumath.sqrt usage, reported by @julian121266
Browse files Browse the repository at this point in the history
  • Loading branch information
flukeskywalker committed Nov 17, 2015
1 parent f35dae5 commit 4f69701
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion brainstorm/handlers/pycuda_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ def split_add_tt(self, x, out_a, out_b):
block=block, grid=grid)

def sqrt_t(self, a, out):
cumath.sqrt(a, out)
cumath.sqrt(a, out=out)

def subtract_mv(self, m, v, out):
cumisc.binaryop_matvec('-', m, v, None, out, None)
Expand Down
11 changes: 11 additions & 0 deletions brainstorm/tests/test_handlers_against_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,17 @@ def test_log_t(handler):
assert operation_check(handler, 'log_t', ref_args)


@pytest.mark.parametrize("handler", non_default_handlers, ids=handler_ids)
def test_sqrt_t(handler):
list_a = get_random_arrays(some_nd_shapes)

for a in list_a:
a += 10 # to remove negatives
out = np.zeros_like(a, dtype=ref_dtype)
ref_args = (a, out)
assert operation_check(handler, 'sqrt_t', ref_args)


@pytest.mark.parametrize("handler", non_default_handlers, ids=handler_ids)
def test_abs_t(handler):
list_a = get_random_arrays(some_nd_shapes)
Expand Down

0 comments on commit 4f69701

Please sign in to comment.