Skip to content

Commit

Permalink
- numpy import fix for CUDA (#64)
Browse files Browse the repository at this point in the history
- skip tagLocation for empty arrays

Signed-off-by: raver119 <raver119@gmail.com>
  • Loading branch information
raver119 authored and AlexDBlack committed Jul 20, 2019
1 parent c9e867b commit c499dc9
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 31 deletions.
Expand Up @@ -529,7 +529,7 @@ public AllocationPoint allocateMemory(DataBuffer buffer, AllocationShape require
* @param objectId
* @return
*/
protected AllocationPoint getAllocationPoint(Long objectId) {
protected AllocationPoint getAllocationPoint(@NonNull Long objectId) {
return allocationsMap.get(objectId);
}

Expand Down
Expand Up @@ -339,6 +339,10 @@ public DataBuffer replicateToDevice(Integer deviceId, DataBuffer buffer) {
*/
@Override
public void tagLocation(INDArray array, Location location) {
// we can't tag empty arrays.
if (array.isEmpty())
return;

if (location == Location.HOST)
AtomicAllocator.getInstance().getAllocationPoint(array).tickHostWrite();
else if (location == Location.DEVICE)
Expand Down
Expand Up @@ -116,6 +116,7 @@ public BaseCudaDataBuffer(Pointer pointer, Indexer indexer, long length) {

//cuda specific bits
this.allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false);
this.trackingPoint = allocationPoint.getObjectId();

Nd4j.getDeallocatorService().pickObject(this);

Expand All @@ -124,40 +125,19 @@ public BaseCudaDataBuffer(Pointer pointer, Indexer indexer, long length) {

val perfD = PerformanceTracker.getInstance().helperStartTransaction();

NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getHostPointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream());
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), allocationPoint.getHostPointer(), length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream());
if (allocationPoint.getHostPointer() != null) {
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getHostPointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream());
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), allocationPoint.getHostPointer(), length * getElementSize(), CudaConstants.cudaMemcpyHostToHost, context.getSpecialStream());
} else {
NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), pointer, length * getElementSize(), CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream());
}

context.getSpecialStream().synchronize();

PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST);
PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);

this.pointer = new CudaPointer(allocationPoint.getHostPointer(), length * getElementSize(), 0);

switch (dataType()) {
case INT: {
setIndexer(IntIndexer.create(((CudaPointer) this.pointer).asIntPointer()));
}
break;
case FLOAT: {
setIndexer(FloatIndexer.create(((CudaPointer) this.pointer).asFloatPointer()));
}
break;
case DOUBLE: {
setIndexer(DoubleIndexer.create(((CudaPointer) this.pointer).asDoublePointer()));
}
break;
case HALF: {
setIndexer(ShortIndexer.create(((CudaPointer) this.pointer).asShortPointer()));
}
break;
case LONG: {
setIndexer(LongIndexer.create(((CudaPointer) this.pointer).asLongPointer()));
}
break;
}
if (allocationPoint.getHostPointer() != null)
PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_HOST);

this.trackingPoint = allocationPoint.getObjectId();
PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD / 2, allocationPoint.getNumberOfBytes(), MemcpyDirection.HOST_TO_DEVICE);

}

Expand Down
Expand Up @@ -310,6 +310,13 @@ public void testAbsentNumpyFile_1() throws Exception {
INDArray act1 = Nd4j.createFromNpyFile(f);
}

/**
 * Loads an array from a fixed .npy file path via {@code Nd4j.createFromNpyFile}
 * and logs the shape and the sum of the loaded array.
 *
 * NOTE(review): the path is an absolute, machine-specific location; presumably
 * this only works where that file exists - confirm whether the test should be
 * skipped/ignored elsewhere.
 *
 * @throws Exception propagated from loading or reducing the array
 */
@Test
public void testAbsentNumpyFile_2() throws Exception {
    final File npyFile = new File("c:/develop/batch-x-1.npy");
    final INDArray loaded = Nd4j.createFromNpyFile(npyFile);

    // Evaluate the log arguments up front, in the same order as they are reported.
    val shape = loaded.shape();
    val sum = loaded.sumNumber().doubleValue();
    log.info("Array shape: {}; sum: {};", shape, sum);
}

@Override
public char ordering() {
return 'c';
Expand Down

0 comments on commit c499dc9

Please sign in to comment.