Skip to content

Commit

Permalink
Add the out argument to dot. Fixes #24.
Browse files Browse the repository at this point in the history
  • Loading branch information
FilipeMaia committed Oct 23, 2016
1 parent 279c1cf commit e328912
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
12 changes: 6 additions & 6 deletions afnumpy/linalg/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from afnumpy import asarray, sqrt, abs
from afnumpy.lib import asfarray
from .. import private_utils as pu
from ..decorators import *

def isComplexType(t):
return issubclass(t, complexfloating)
Expand All @@ -13,14 +14,14 @@ def vdot(a, b):
s = arrayfire.dot(arrayfire.conjg(a.flat.d_array), b.flat.d_array)
return afnumpy.ndarray((), dtype=a.dtype, af_array=s)[()]

# TODO: Implement multidimensional dot
@outufunc
def dot(a, b):
# Arrayfire requires that the types match for dot and matmul
res_dtype = numpy.result_type(a,b)
a = a.astype(res_dtype, copy=False)
b = b.astype(res_dtype, copy=False)
if a.ndim == 1 and b.ndim == 1:
s = arrayfire.dot((a.flat.d_array), b.flat.d_array)
s = arrayfire.dot(a.d_array, b.d_array)
return afnumpy.ndarray((), dtype=a.dtype, af_array=s)[()]

a_shape = a.shape
Expand All @@ -33,7 +34,9 @@ def dot(a, b):
if a.ndim == 2 and b.ndim == 2:
# Notice the order of the arguments to matmul. It's not a bug!
s = arrayfire.matmul(b.d_array, a.d_array)
return afnumpy.ndarray(pu.af_shape(s), dtype=pu.typemap(s.dtype()), af_array=s)
return afnumpy.ndarray(pu.af_shape(s), dtype=pu.typemap(s.dtype()),
af_array=s)

# Multidimensional dot is done with loops

# Calculate the shape of the result array
Expand All @@ -43,9 +46,6 @@ def dot(a, b):
b_shape.pop(-2)
res_shape = a_shape + b_shape

# Initialize the output array
res = afnumpy.empty(res_shape, dtype=res_dtype)

# Make sure the arrays are at least 3D
if a.ndim < 3:
a = a.reshape((1,)+a.shape)
Expand Down
6 changes: 6 additions & 0 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ def test_dot_2D():
a = afnumpy.array(b)
fassert(afnumpy.dot(a,a), numpy.dot(b,b))

a = numpy.random.random((3,3))+numpy.random.random((3,3))*1.0j
b = numpy.random.random((3,3))
fassert(afnumpy.dot(afnumpy.array(a),afnumpy.array(b)), numpy.dot(a,b))
out = afnumpy.array(a)
fassert(afnumpy.dot(afnumpy.array(a),afnumpy.array(b),out=out), numpy.dot(a,b))

def test_dot_3D():
b = numpy.random.random((3,3,3))+numpy.random.random((3,3,3))*1.0j
a = afnumpy.array(b)
Expand Down

0 comments on commit e328912

Please sign in to comment.