diff --git a/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala b/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala index 2a1a72b3..082dea3a 100644 --- a/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala +++ b/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala @@ -16,7 +16,7 @@ import org.scalatest._ * @author 杨博 (Yang Bo) */ class TensorsSpec extends AsyncFreeSpec with Matchers { - private val tensors: Tensors = + private def doTensors: Do[Tensors] = Do.monadicCloseable( Factory[ OpenCL.GlobalExecutionContext with OpenCL.UseAllDevices with OpenCL.UseFirstPlatform with OpenCL.CommandQueuePool with Tensors] .newInstance( @@ -24,26 +24,27 @@ class TensorsSpec extends AsyncFreeSpec with Matchers { numberOfCommandQueuesForDevice = { (deviceId: Long, capabilities: CLCapabilities) => 5 } - ) + )) "create a tensor of a constant" in { - val shape = Array(2, 3, 5) - val element = 42.0f - val zeros = tensors.Tensor.fill(element, shape) - - for { - pendingBuffer <- zeros.enqueue - floatBuffer <- pendingBuffer.toHostBuffer - } yield { - for (i <- 0 until floatBuffer.capacity()) { - floatBuffer.get(i) should be(element) + doTensors.flatMap { tensors => + val shape = Array(2, 3, 5) + val element = 42.0f + val zeros = tensors.Tensor.fill(element, shape) + for { + pendingBuffer <- zeros.enqueue + floatBuffer <- pendingBuffer.toHostBuffer + } yield { + for (i <- 0 until floatBuffer.capacity()) { + floatBuffer.get(i) should be(element) + } + floatBuffer.position() should be(0) + floatBuffer.limit() should be(shape.product) + floatBuffer.capacity() should be(shape.product) + tensors.kernelCache.getIfPresent(zeros.closure) should not be null + val zeros2 = tensors.Tensor.fill(element, shape) + tensors.kernelCache.getIfPresent(zeros2.closure) should not be null } - floatBuffer.position() should be(0) - floatBuffer.limit() should be(shape.product) - floatBuffer.capacity() should be(shape.product) - tensors.kernelCache.getIfPresent(zeros.closure) should not be null - val zeros2 = tensors.Tensor.fill(element, shape) - tensors.kernelCache.getIfPresent(zeros2.closure) should not be null } }.run.toScalaFuture