diff --git a/OpenCL/src/main/scala/com/thoughtworks/compute/OpenCL.scala b/OpenCL/src/main/scala/com/thoughtworks/compute/OpenCL.scala
index 7f3594b2..c691f22a 100644
--- a/OpenCL/src/main/scala/com/thoughtworks/compute/OpenCL.scala
+++ b/OpenCL/src/main/scala/com/thoughtworks/compute/OpenCL.scala
@@ -50,6 +50,29 @@ import scala.language.higherKinds
   */
 object OpenCL {
 
+  /** A native OpenCL platform handle; `Owner` is the singleton type of an [[OpenCL]] instance.
+    *
+    * @param handle the native platform id handed to `clGetDeviceIDs`
+    */
+  final case class PlatformId[Owner <: Singleton with OpenCL](handle: Long) extends AnyVal {
+    /** Returns the ids of this platform's devices matching `deviceType` (a `CL_DEVICE_TYPE_*` value). */
+    final def deviceIdsByType(deviceType: Int): Seq[DeviceId[Owner]] = {
+      // First call passes no output buffer: it only reports how many devices match.
+      val deviceCount = new Array[Int](1)
+      checkErrorCode(clGetDeviceIDs(handle, deviceType, null, deviceCount))
+      val numberOfDevices = deviceCount(0)
+      val stack = stackPush()
+      try {
+        // Second call fills a stack-allocated buffer, which is read before `stack.close()` runs.
+        val deviceIdBuffer = stack.mallocPointer(numberOfDevices)
+        checkErrorCode(clGetDeviceIDs(handle, deviceType, deviceIdBuffer, null: IntBuffer))
+        (0 until numberOfDevices).map(i => new DeviceId[Owner](deviceIdBuffer.get(i)))
+      } finally {
+        stack.close()
+      }
+    }
+  }
+
   /** Returns a [[String]] for the C string `address`.
     *
     * @note We don't know the exact charset of the C string. Use [[memASCII]] because lwjgl treats them as ASCII.
@@ -96,33 +119,33 @@ object OpenCL { } }) object Exceptions { - final class MisalignedSubBufferOffset extends IllegalArgumentException + final class MisalignedSubBufferOffset(message: String = null) extends IllegalArgumentException(message) - final class ExecStatusErrorForEventsInWaitList extends IllegalArgumentException + final class ExecStatusErrorForEventsInWaitList(message: String = null) extends IllegalArgumentException(message) - final class InvalidProperty extends IllegalArgumentException + final class InvalidProperty(message: String = null) extends IllegalArgumentException(message) - final class PlatformNotFoundKhr extends IllegalStateException + final class PlatformNotFoundKhr(message: String = null) extends NoSuchElementException(message) - final class DeviceNotFound extends IllegalArgumentException + final class DeviceNotFound(message: String = null) extends NoSuchElementException(message) - final class DeviceNotAvailable extends IllegalStateException + final class DeviceNotAvailable(message: String = null) extends IllegalStateException(message) - final class CompilerNotAvailable extends IllegalStateException + final class CompilerNotAvailable(message: String = null) extends IllegalStateException(message) - final class MemObjectAllocationFailure extends IllegalStateException + final class MemObjectAllocationFailure(message: String = null) extends IllegalStateException(message) - final class OutOfResources extends IllegalStateException + final class OutOfResources(message: String = null) extends IllegalStateException(message) - final class OutOfHostMemory extends IllegalStateException + final class OutOfHostMemory(message: String = null) extends IllegalStateException(message) - final class ProfilingInfoNotAvailable extends IllegalStateException + final class ProfilingInfoNotAvailable(message: String = null) extends IllegalStateException(message) - final class MemCopyOverlap extends IllegalStateException + final class MemCopyOverlap(message: String = null) 
extends IllegalStateException(message) - final class ImageFormatMismatch extends IllegalStateException + final class ImageFormatMismatch(message: String = null) extends IllegalStateException(message) - final class ImageFormatNotSupported extends IllegalStateException + final class ImageFormatNotSupported(message: String = null) extends IllegalStateException(message) final class BuildProgramFailure(buildLogs: Map[Long /* device id */, String] = Map.empty) extends IllegalStateException({ @@ -134,71 +157,71 @@ object OpenCL { .mkString("\n") }) - final class MapFailure extends IllegalStateException + final class MapFailure(message: String = null) extends IllegalStateException(message) - final class InvalidValue extends IllegalArgumentException + final class InvalidValue(message: String = null) extends IllegalArgumentException(message) - final class InvalidDeviceType extends IllegalArgumentException + final class InvalidDeviceType(message: String = null) extends IllegalArgumentException(message) - final class InvalidPlatform extends IllegalArgumentException + final class InvalidPlatform(message: String = null) extends IllegalArgumentException(message) - final class InvalidDevice extends IllegalArgumentException + final class InvalidDevice(message: String = null) extends IllegalArgumentException(message) - final class InvalidContext extends IllegalArgumentException + final class InvalidContext(message: String = null) extends IllegalArgumentException(message) - final class InvalidQueueProperties extends IllegalArgumentException + final class InvalidQueueProperties(message: String = null) extends IllegalArgumentException(message) - final class InvalidCommandQueue extends IllegalArgumentException + final class InvalidCommandQueue(message: String = null) extends IllegalArgumentException(message) - final class InvalidHostPtr extends IllegalArgumentException + final class InvalidHostPtr(message: String = null) extends IllegalArgumentException(message) - final class 
InvalidMemObject extends IllegalArgumentException + final class InvalidMemObject(message: String = null) extends IllegalArgumentException(message) - final class InvalidImageFormatDescriptor extends IllegalArgumentException + final class InvalidImageFormatDescriptor(message: String = null) extends IllegalArgumentException(message) - final class InvalidImageSize extends IllegalArgumentException + final class InvalidImageSize(message: String = null) extends IllegalArgumentException(message) - final class InvalidSampler extends IllegalArgumentException + final class InvalidSampler(message: String = null) extends IllegalArgumentException(message) - final class InvalidBinary extends IllegalArgumentException + final class InvalidBinary(message: String = null) extends IllegalArgumentException(message) - final class InvalidBuildOptions extends IllegalArgumentException + final class InvalidBuildOptions(message: String = null) extends IllegalArgumentException(message) - final class InvalidProgram extends IllegalArgumentException + final class InvalidProgram(message: String = null) extends IllegalArgumentException(message) - final class InvalidProgramExecutable extends IllegalArgumentException + final class InvalidProgramExecutable(message: String = null) extends IllegalArgumentException(message) - final class InvalidKernelName extends IllegalArgumentException + final class InvalidKernelName(message: String = null) extends IllegalArgumentException(message) - final class InvalidKernelDefinition extends IllegalArgumentException + final class InvalidKernelDefinition(message: String = null) extends IllegalArgumentException(message) - final class InvalidKernel extends IllegalArgumentException + final class InvalidKernel(message: String = null) extends IllegalArgumentException(message) - final class InvalidArgIndex extends IllegalArgumentException + final class InvalidArgIndex(message: String = null) extends IllegalArgumentException(message) - final class InvalidArgValue extends 
IllegalArgumentException + final class InvalidArgValue(message: String = null) extends IllegalArgumentException(message) - final class InvalidArgSize extends IllegalArgumentException + final class InvalidArgSize(message: String = null) extends IllegalArgumentException(message) - final class InvalidKernelArgs extends IllegalArgumentException + final class InvalidKernelArgs(message: String = null) extends IllegalArgumentException(message) - final class InvalidWorkDimension extends IllegalArgumentException + final class InvalidWorkDimension(message: String = null) extends IllegalArgumentException(message) - final class InvalidWorkGroupSize extends IllegalArgumentException + final class InvalidWorkGroupSize(message: String = null) extends IllegalArgumentException(message) - final class InvalidWorkItemSize extends IllegalArgumentException + final class InvalidWorkItemSize(message: String = null) extends IllegalArgumentException(message) - final class InvalidGlobalOffset extends IllegalArgumentException + final class InvalidGlobalOffset(message: String = null) extends IllegalArgumentException(message) - final class InvalidEventWaitList extends IllegalArgumentException + final class InvalidEventWaitList(message: String = null) extends IllegalArgumentException(message) - final class InvalidEvent extends IllegalArgumentException + final class InvalidEvent(message: String = null) extends IllegalArgumentException(message) - final class InvalidOperation extends IllegalArgumentException + final class InvalidOperation(message: String = null) extends IllegalArgumentException(message) - final class InvalidBufferSize extends IllegalArgumentException + final class InvalidBufferSize(message: String = null) extends IllegalArgumentException(message) - final class InvalidGlobalWorkSize extends IllegalArgumentException + final class InvalidGlobalWorkSize(message: String = null) extends IllegalArgumentException(message) final class UnknownErrorCode(errorCode: Int) extends 
IllegalStateException(s"Unknown error code: $errorCode") @@ -265,17 +288,10 @@ object OpenCL { } } - trait UseFirstPlatform { + trait UseFirstPlatform extends OpenCL { @transient - protected lazy val platformId: Long = { - val stack = stackPush() - try { - val platformIdBuffer = stack.mallocPointer(1) - checkErrorCode(clGetPlatformIDs(platformIdBuffer, null: IntBuffer)) - platformIdBuffer.get(0) - } finally { - stack.close() - } + protected lazy val platformId: PlatformId = { + platformIds.head } } @@ -283,7 +299,7 @@ object OpenCL { @transient protected lazy val deviceIds: Seq[DeviceId] = { - deviceIdsByType(CL_DEVICE_TYPE_ALL) + platformId.deviceIdsByType(CL_DEVICE_TYPE_ALL) } } @@ -292,7 +308,7 @@ object OpenCL { @transient protected lazy val deviceIds: Seq[DeviceId] = { - val allDeviceIds = deviceIdsByType(CL_DEVICE_TYPE_ALL) + val allDeviceIds = platformId.deviceIdsByType(CL_DEVICE_TYPE_ALL) Seq(allDeviceIds.head) } @@ -302,7 +318,7 @@ object OpenCL { @transient protected lazy val deviceIds: Seq[DeviceId] = { - deviceIdsByType(CL_DEVICE_TYPE_GPU) + platformId.deviceIdsByType(CL_DEVICE_TYPE_GPU) } } @@ -310,7 +326,7 @@ object OpenCL { @transient protected lazy val deviceIds: Seq[DeviceId] = { - val allDeviceIds = deviceIdsByType(CL_DEVICE_TYPE_GPU) + val allDeviceIds = platformId.deviceIdsByType(CL_DEVICE_TYPE_GPU) Seq(allDeviceIds.head) } } @@ -318,7 +334,7 @@ object OpenCL { @transient protected lazy val deviceIds: Seq[DeviceId] = { - val allDeviceIds = deviceIdsByType(CL_DEVICE_TYPE_CPU) + val allDeviceIds = platformId.deviceIdsByType(CL_DEVICE_TYPE_CPU) Seq(allDeviceIds.head) } } @@ -327,7 +343,7 @@ object OpenCL { @transient protected lazy val deviceIds: Seq[DeviceId] = { - deviceIdsByType(CL_DEVICE_TYPE_CPU) + platformId.deviceIdsByType(CL_DEVICE_TYPE_CPU) } } @@ -1010,20 +1026,18 @@ trait OpenCL extends MonadicCloseable[UnitContinuation] with ImplicitsSingleton type Event = OpenCL.Event[this.type] type CommandQueue = OpenCL.CommandQueue[this.type] type 
DeviceId = OpenCL.DeviceId[this.type] + type PlatformId = OpenCL.PlatformId[this.type] - protected final def deviceIdsByType(deviceType: Int): Seq[DeviceId] = { - val Array(numberOfDevices) = { - val a = Array(0) - checkErrorCode(clGetDeviceIDs(platformId, deviceType, null, a)) - a - } + def platformIds: Seq[PlatformId] = { val stack = stackPush() try { - val deviceIdBuffer = stack.mallocPointer(numberOfDevices) - checkErrorCode(clGetDeviceIDs(platformId, deviceType, deviceIdBuffer, null: IntBuffer)) - for (i <- 0 until deviceIdBuffer.capacity()) yield { - val deviceId = deviceIdBuffer.get(i) - new DeviceId(deviceId) + val numberOfPlatformsBuffer = stack.mallocInt(1) + checkErrorCode(clGetPlatformIDs(null, numberOfPlatformsBuffer)) + val numberOfPlatforms = numberOfPlatformsBuffer.get(0) + val platformIdBuffer = stack.mallocPointer(numberOfPlatforms) + checkErrorCode(clGetPlatformIDs(platformIdBuffer, null: IntBuffer)) + (0 until numberOfPlatforms).map { i => + new PlatformId(platformIdBuffer.get(i)) } } finally { stack.close() @@ -1174,12 +1188,12 @@ trait OpenCL extends MonadicCloseable[UnitContinuation] with ImplicitsSingleton import OpenCL._ - protected val platformId: Long + protected val platformId: PlatformId protected val deviceIds: Seq[DeviceId] @transient protected lazy val platformCapabilities: CLCapabilities = { - CL.createPlatformCapabilities(platformId) + CL.createPlatformCapabilities(platformId.handle) } protected def createCommandQueue(deviceId: DeviceId, properties: Map[Int, Long]): CommandQueue = new CommandQueue( @@ -1187,7 +1201,7 @@ trait OpenCL extends MonadicCloseable[UnitContinuation] with ImplicitsSingleton val cl20Properties = (properties.view.flatMap { case (key, value) => Seq(key, value) } ++ Seq(0L)).toArray val a = Array(0) val commandQueue = - clCreateCommandQueueWithProperties(platformId, deviceId.handle, cl20Properties, a) + clCreateCommandQueueWithProperties(platformId.handle, deviceId.handle, cl20Properties, a) 
checkErrorCode(a(0)) commandQueue } else { @@ -1211,7 +1225,7 @@ trait OpenCL extends MonadicCloseable[UnitContinuation] with ImplicitsSingleton val stack = stackPush() try { val errorCodeBuffer = stack.ints(CL_SUCCESS) - val contextProperties = stack.pointers(CL_CONTEXT_PLATFORM, platformId, 0) + val contextProperties = stack.pointers(CL_CONTEXT_PLATFORM, platformId.handle, 0) val deviceIdBuffer = stack.pointers(deviceIds.view.map(_.handle): _*) val context = clCreateContext(contextProperties, diff --git a/benchmarks/build.sbt b/benchmarks/build.sbt index 74c83fa9..4385c009 100644 --- a/benchmarks/build.sbt +++ b/benchmarks/build.sbt @@ -2,9 +2,11 @@ enablePlugins(JmhPlugin) libraryDependencies += "org.nd4j" % "nd4j-api" % "0.8.0" -libraryDependencies += "org.nd4j" % "nd4j-cuda-8.0-platform" % "0.8.0" +val nd4jRuntime = settingKey[String]("\"cuda-8.0\" to run benchmark on GPU, \"native\" to run benchmark on CPU.") -libraryDependencies += "org.nd4j" % "nd4j-native-platform" % "0.8.0" +nd4jRuntime in Global := "native" + +libraryDependencies += "org.nd4j" % s"nd4j-${nd4jRuntime.value}-platform" % "0.8.0" libraryDependencies += ("org.lwjgl" % "lwjgl" % "3.1.6").jar().classifier { import scala.util.Properties._ diff --git a/benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala b/benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala index 23d79945..e2ce6694 100644 --- a/benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala +++ b/benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala @@ -1,5 +1,6 @@ package com.thoughtworks.compute +import com.thoughtworks.compute.OpenCL.Exceptions.DeviceNotFound import com.thoughtworks.compute.benchmarks.RandomNormalState import com.thoughtworks.feature.Factory import com.thoughtworks.future._ @@ -8,7 +9,7 @@ import com.thoughtworks.raii.asynchronous._ import com.thoughtworks.raii.covariant._ import com.thoughtworks.tryt.covariant._ import com.typesafe.scalalogging.StrictLogging 
-import org.lwjgl.opencl.CLCapabilities +import org.lwjgl.opencl.{CL10, CLCapabilities} import org.lwjgl.system.Configuration import org.nd4j.linalg.api.ndarray.INDArray import org.nd4j.linalg.convolution.Convolution @@ -23,6 +24,54 @@ import scala.util.Try object benchmarks { + trait TensorState { + @Param(Array("CPU", "GPU")) + protected var tensorDeviceType: String = _ + + trait BenchmarkTensors + extends StrictLogging + with Tensors.UnsafeMathOptimizations + with Tensors.SuppressWarnings + with OpenCL.LogContextNotification + with OpenCL.GlobalExecutionContext + with OpenCL.CommandQueuePool + with OpenCL.DontReleaseEventTooEarly + with Tensors.WangHashingRandomNumberGenerator { + @transient + protected lazy val (platformId: PlatformId, deviceIds: Seq[DeviceId]) = { + val deviceType = classOf[CL10].getField(s"CL_DEVICE_TYPE_$tensorDeviceType").get(null).asInstanceOf[Int] + + object MatchDeviceType { + def unapply(platformId: PlatformId): Option[Seq[DeviceId]] = { + (try { + platformId.deviceIdsByType(deviceType) + } catch { + case e: DeviceNotFound => + return None + }) match { + case devices if devices.nonEmpty => + Some(devices) + case _ => + None + } + + } + } + + platformIds.collectFirst { + case platformId @ MatchDeviceType(deviceIds) => + (platformId, deviceIds) + } match { + case None => + throw new DeviceNotFound(s"$tensorDeviceType device is not found") + case Some(pair) => + pair + } + } + + } + } + @Threads(value = Threads.MAX) @State(Scope.Benchmark) class Nd4jTanh extends TanhState { @@ -47,18 +96,8 @@ object benchmarks { @Threads(value = Threads.MAX) @State(Scope.Benchmark) - class TensorTanh extends TanhState { - trait Benchmarks - extends StrictLogging - with Tensors.UnsafeMathOptimizations - with Tensors.SuppressWarnings - with OpenCL.LogContextNotification - with OpenCL.GlobalExecutionContext - with OpenCL.UseAllCpuDevices - with OpenCL.UseFirstPlatform - with OpenCL.CommandQueuePool - with OpenCL.DontReleaseEventTooEarly - with 
Tensors.WangHashingRandomNumberGenerator { + class TensorTanh extends TanhState with TensorState { + trait Benchmarks extends BenchmarkTensors { protected val numberOfCommandQueuesPerDevice: Int = 2 @@ -130,17 +169,8 @@ object benchmarks { @Threads(value = Threads.MAX) @State(Scope.Benchmark) - class TensorSum extends SumState { - trait Benchmarks - extends StrictLogging - with Tensors.UnsafeMathOptimizations - with OpenCL.LogContextNotification - with OpenCL.GlobalExecutionContext - with OpenCL.UseAllCpuDevices - with OpenCL.UseFirstPlatform - with OpenCL.CommandQueuePool - with OpenCL.DontReleaseEventTooEarly - with Tensors.WangHashingRandomNumberGenerator { + class TensorSum extends SumState with TensorState { + trait Benchmarks extends BenchmarkTensors { protected val numberOfCommandQueuesPerDevice: Int = 2 @@ -200,17 +230,8 @@ object benchmarks { @Threads(value = Threads.MAX) @State(Scope.Benchmark) - class TensorRandomNormal extends RandomNormalState { - trait Benchmarks - extends StrictLogging - with Tensors.UnsafeMathOptimizations - with OpenCL.LogContextNotification - with OpenCL.GlobalExecutionContext - with OpenCL.UseAllCpuDevices - with OpenCL.UseFirstPlatform - with OpenCL.CommandQueuePool - with OpenCL.DontReleaseEventTooEarly - with Tensors.WangHashingRandomNumberGenerator { + class TensorRandomNormal extends RandomNormalState with TensorState { + trait Benchmarks extends BenchmarkTensors { protected val numberOfCommandQueuesPerDevice: Int = 2 @@ -295,19 +316,9 @@ object benchmarks { @Threads(value = Threads.MAX) @State(Scope.Benchmark) - class TensorConvolution extends ConvolutionState { + class TensorConvolution extends ConvolutionState with TensorState { - trait Benchmarks - extends StrictLogging - with Tensors.UnsafeMathOptimizations - with OpenCL.LogContextNotification - with OpenCL.GlobalExecutionContext - with OpenCL.UseAllCpuDevices - with OpenCL.UseFirstPlatform - with OpenCL.CommandQueuePool - with OpenCL.DontReleaseEventTooEarly - with 
Tensors.WangHashingRandomNumberGenerator - with ConvolutionTensors { + trait Benchmarks extends BenchmarkTensors with ConvolutionTensors { protected val numberOfCommandQueuesPerDevice = 2