Skip to content

Commit

Permalink
#8057 Nd4j.create overload dtype fix
Browse files Browse the repository at this point in the history
Signed-off-by: AlexDBlack <blacka101@gmail.com>
  • Loading branch information
AlexDBlack committed Jul 29, 2019
1 parent 52d5db0 commit e51fc06
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
Expand Up @@ -538,7 +538,7 @@ public static long[] shape(INDArray arr) {
public static INDArray create(int[] sliceShape, float[]... arrays) {
//TODO: Remove duplicate code.
int slices = arrays.length;
INDArray ret = Nd4j.create(ArrayUtil.combine(new int[] {slices}, sliceShape));
INDArray ret = Nd4j.createUninitialized(DataType.FLOAT, ArrayUtil.toLongArray(ArrayUtil.combine(new int[] {slices}, sliceShape)));
for (int i = 0; i < ret.slices(); i++)
ret.putSlice(i, Nd4j.create(arrays[i]).reshape(ArrayUtil.toLongArray(sliceShape)));
return ret;
Expand Down Expand Up @@ -572,7 +572,7 @@ public static INDArray create(LongShapeDescriptor descriptor, boolean initialize
*/
public static INDArray create(int[] sliceShape, double[]... arrays) {
int slices = arrays.length;
INDArray ret = Nd4j.create(ArrayUtil.combine(new int[] {slices}, sliceShape));
INDArray ret = Nd4j.createUninitialized(DataType.DOUBLE, ArrayUtil.toLongArray(ArrayUtil.combine(new int[] {slices}, sliceShape)));
for (int i = 0; i < ret.slices(); i++)
ret.putSlice(i, Nd4j.create(arrays[i]).reshape(ArrayUtil.toLongArray(sliceShape)));
return ret;
Expand Down
Expand Up @@ -7877,6 +7877,24 @@ public void mmulToScalar() {
final INDArray arr2 = arr1.reshape(3,1);
assertEquals("Incorrect type!", DataType.FLOAT, arr1.mmul(arr2).dataType());
}


@Test
public void testCreateDtypes() {
int[] sliceShape = new int[] {9};
float[] arrays = new float[] {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
double [] arrays_double = new double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0};

INDArray x = Nd4j.create( sliceShape, arrays, arrays );
assertEquals(DataType.FLOAT, x.dataType());

INDArray xd = Nd4j.create( sliceShape, arrays_double, arrays_double );
assertEquals(DataType.DOUBLE, xd.dataType());
}




///////////////////////////////////////////////////////
protected static void fillJvmArray3D(float[][][] arr) {
int cnt = 1;
Expand Down

0 comments on commit e51fc06

Please sign in to comment.