Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 92 additions & 78 deletions OpenCL/src/main/scala/com/thoughtworks/compute/OpenCL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,29 @@ import scala.language.higherKinds
*/
object OpenCL {

final case class PlatformId[Owner <: Singleton with OpenCL](handle: Long) extends AnyVal {

final def deviceIdsByType(deviceType: Int): Seq[DeviceId[Owner]] = {
val Array(numberOfDevices) = {
val a = Array(0)
checkErrorCode(clGetDeviceIDs(handle, deviceType, null, a))
a
}
val stack = stackPush()
try {
val deviceIdBuffer = stack.mallocPointer(numberOfDevices)
checkErrorCode(clGetDeviceIDs(handle, deviceType, deviceIdBuffer, null: IntBuffer))
for (i <- 0 until numberOfDevices) yield {
val deviceId = deviceIdBuffer.get(i)
new DeviceId[Owner](deviceId)
}
} finally {
stack.close()
}
}

}

/** Returns a [[String]] for the C string `address`.
*
* @note We don't know the exact charset of the C string. Use [[memASCII]] because lwjgl treats them as ASCII.
Expand Down Expand Up @@ -96,33 +119,33 @@ object OpenCL {
}
})
object Exceptions {
final class MisalignedSubBufferOffset extends IllegalArgumentException
final class MisalignedSubBufferOffset(message: String = null) extends IllegalArgumentException(message)

final class ExecStatusErrorForEventsInWaitList extends IllegalArgumentException
final class ExecStatusErrorForEventsInWaitList(message: String = null) extends IllegalArgumentException(message)

final class InvalidProperty extends IllegalArgumentException
final class InvalidProperty(message: String = null) extends IllegalArgumentException(message)

final class PlatformNotFoundKhr extends IllegalStateException
final class PlatformNotFoundKhr(message: String = null) extends NoSuchElementException(message)

final class DeviceNotFound extends IllegalArgumentException
final class DeviceNotFound(message: String = null) extends NoSuchElementException(message)

final class DeviceNotAvailable extends IllegalStateException
final class DeviceNotAvailable(message: String = null) extends IllegalStateException(message)

final class CompilerNotAvailable extends IllegalStateException
final class CompilerNotAvailable(message: String = null) extends IllegalStateException(message)

final class MemObjectAllocationFailure extends IllegalStateException
final class MemObjectAllocationFailure(message: String = null) extends IllegalStateException(message)

final class OutOfResources extends IllegalStateException
final class OutOfResources(message: String = null) extends IllegalStateException(message)

final class OutOfHostMemory extends IllegalStateException
final class OutOfHostMemory(message: String = null) extends IllegalStateException(message)

final class ProfilingInfoNotAvailable extends IllegalStateException
final class ProfilingInfoNotAvailable(message: String = null) extends IllegalStateException(message)

final class MemCopyOverlap extends IllegalStateException
final class MemCopyOverlap(message: String = null) extends IllegalStateException(message)

final class ImageFormatMismatch extends IllegalStateException
final class ImageFormatMismatch(message: String = null) extends IllegalStateException(message)

final class ImageFormatNotSupported extends IllegalStateException
final class ImageFormatNotSupported(message: String = null) extends IllegalStateException(message)

final class BuildProgramFailure(buildLogs: Map[Long /* device id */, String] = Map.empty)
extends IllegalStateException({
Expand All @@ -134,71 +157,71 @@ object OpenCL {
.mkString("\n")
})

final class MapFailure extends IllegalStateException
final class MapFailure(message: String = null) extends IllegalStateException(message)

final class InvalidValue extends IllegalArgumentException
final class InvalidValue(message: String = null) extends IllegalArgumentException(message)

final class InvalidDeviceType extends IllegalArgumentException
final class InvalidDeviceType(message: String = null) extends IllegalArgumentException(message)

final class InvalidPlatform extends IllegalArgumentException
final class InvalidPlatform(message: String = null) extends IllegalArgumentException(message)

final class InvalidDevice extends IllegalArgumentException
final class InvalidDevice(message: String = null) extends IllegalArgumentException(message)

final class InvalidContext extends IllegalArgumentException
final class InvalidContext(message: String = null) extends IllegalArgumentException(message)

final class InvalidQueueProperties extends IllegalArgumentException
final class InvalidQueueProperties(message: String = null) extends IllegalArgumentException(message)

final class InvalidCommandQueue extends IllegalArgumentException
final class InvalidCommandQueue(message: String = null) extends IllegalArgumentException(message)

final class InvalidHostPtr extends IllegalArgumentException
final class InvalidHostPtr(message: String = null) extends IllegalArgumentException(message)

final class InvalidMemObject extends IllegalArgumentException
final class InvalidMemObject(message: String = null) extends IllegalArgumentException(message)

final class InvalidImageFormatDescriptor extends IllegalArgumentException
final class InvalidImageFormatDescriptor(message: String = null) extends IllegalArgumentException(message)

final class InvalidImageSize extends IllegalArgumentException
final class InvalidImageSize(message: String = null) extends IllegalArgumentException(message)

final class InvalidSampler extends IllegalArgumentException
final class InvalidSampler(message: String = null) extends IllegalArgumentException(message)

final class InvalidBinary extends IllegalArgumentException
final class InvalidBinary(message: String = null) extends IllegalArgumentException(message)

final class InvalidBuildOptions extends IllegalArgumentException
final class InvalidBuildOptions(message: String = null) extends IllegalArgumentException(message)

final class InvalidProgram extends IllegalArgumentException
final class InvalidProgram(message: String = null) extends IllegalArgumentException(message)

final class InvalidProgramExecutable extends IllegalArgumentException
final class InvalidProgramExecutable(message: String = null) extends IllegalArgumentException(message)

final class InvalidKernelName extends IllegalArgumentException
final class InvalidKernelName(message: String = null) extends IllegalArgumentException(message)

final class InvalidKernelDefinition extends IllegalArgumentException
final class InvalidKernelDefinition(message: String = null) extends IllegalArgumentException(message)

final class InvalidKernel extends IllegalArgumentException
final class InvalidKernel(message: String = null) extends IllegalArgumentException(message)

final class InvalidArgIndex extends IllegalArgumentException
final class InvalidArgIndex(message: String = null) extends IllegalArgumentException(message)

final class InvalidArgValue extends IllegalArgumentException
final class InvalidArgValue(message: String = null) extends IllegalArgumentException(message)

final class InvalidArgSize extends IllegalArgumentException
final class InvalidArgSize(message: String = null) extends IllegalArgumentException(message)

final class InvalidKernelArgs extends IllegalArgumentException
final class InvalidKernelArgs(message: String = null) extends IllegalArgumentException(message)

final class InvalidWorkDimension extends IllegalArgumentException
final class InvalidWorkDimension(message: String = null) extends IllegalArgumentException(message)

final class InvalidWorkGroupSize extends IllegalArgumentException
final class InvalidWorkGroupSize(message: String = null) extends IllegalArgumentException(message)

final class InvalidWorkItemSize extends IllegalArgumentException
final class InvalidWorkItemSize(message: String = null) extends IllegalArgumentException(message)

final class InvalidGlobalOffset extends IllegalArgumentException
final class InvalidGlobalOffset(message: String = null) extends IllegalArgumentException(message)

final class InvalidEventWaitList extends IllegalArgumentException
final class InvalidEventWaitList(message: String = null) extends IllegalArgumentException(message)

final class InvalidEvent extends IllegalArgumentException
final class InvalidEvent(message: String = null) extends IllegalArgumentException(message)

final class InvalidOperation extends IllegalArgumentException
final class InvalidOperation(message: String = null) extends IllegalArgumentException(message)

final class InvalidBufferSize extends IllegalArgumentException
final class InvalidBufferSize(message: String = null) extends IllegalArgumentException(message)

final class InvalidGlobalWorkSize extends IllegalArgumentException
final class InvalidGlobalWorkSize(message: String = null) extends IllegalArgumentException(message)

final class UnknownErrorCode(errorCode: Int) extends IllegalStateException(s"Unknown error code: $errorCode")

Expand Down Expand Up @@ -265,25 +288,18 @@ object OpenCL {
}
}

trait UseFirstPlatform {
trait UseFirstPlatform extends OpenCL {
@transient
protected lazy val platformId: Long = {
val stack = stackPush()
try {
val platformIdBuffer = stack.mallocPointer(1)
checkErrorCode(clGetPlatformIDs(platformIdBuffer, null: IntBuffer))
platformIdBuffer.get(0)
} finally {
stack.close()
}
protected lazy val platformId: PlatformId = {
platformIds.head
}
}

trait UseAllDevices extends OpenCL {

@transient
protected lazy val deviceIds: Seq[DeviceId] = {
deviceIdsByType(CL_DEVICE_TYPE_ALL)
platformId.deviceIdsByType(CL_DEVICE_TYPE_ALL)
}

}
Expand All @@ -292,7 +308,7 @@ object OpenCL {

@transient
protected lazy val deviceIds: Seq[DeviceId] = {
val allDeviceIds = deviceIdsByType(CL_DEVICE_TYPE_ALL)
val allDeviceIds = platformId.deviceIdsByType(CL_DEVICE_TYPE_ALL)
Seq(allDeviceIds.head)
}

Expand All @@ -302,23 +318,23 @@ object OpenCL {

@transient
protected lazy val deviceIds: Seq[DeviceId] = {
deviceIdsByType(CL_DEVICE_TYPE_GPU)
platformId.deviceIdsByType(CL_DEVICE_TYPE_GPU)
}
}

trait UseFirstGpuDevice extends OpenCL {

@transient
protected lazy val deviceIds: Seq[DeviceId] = {
val allDeviceIds = deviceIdsByType(CL_DEVICE_TYPE_GPU)
val allDeviceIds = platformId.deviceIdsByType(CL_DEVICE_TYPE_GPU)
Seq(allDeviceIds.head)
}
}
trait UseFirstCpuDevice extends OpenCL {

@transient
protected lazy val deviceIds: Seq[DeviceId] = {
val allDeviceIds = deviceIdsByType(CL_DEVICE_TYPE_CPU)
val allDeviceIds = platformId.deviceIdsByType(CL_DEVICE_TYPE_CPU)
Seq(allDeviceIds.head)
}
}
Expand All @@ -327,7 +343,7 @@ object OpenCL {

@transient
protected lazy val deviceIds: Seq[DeviceId] = {
deviceIdsByType(CL_DEVICE_TYPE_CPU)
platformId.deviceIdsByType(CL_DEVICE_TYPE_CPU)
}
}

Expand Down Expand Up @@ -1010,20 +1026,18 @@ trait OpenCL extends MonadicCloseable[UnitContinuation] with ImplicitsSingleton
type Event = OpenCL.Event[this.type]
type CommandQueue = OpenCL.CommandQueue[this.type]
type DeviceId = OpenCL.DeviceId[this.type]
type PlatformId = OpenCL.PlatformId[this.type]

protected final def deviceIdsByType(deviceType: Int): Seq[DeviceId] = {
val Array(numberOfDevices) = {
val a = Array(0)
checkErrorCode(clGetDeviceIDs(platformId, deviceType, null, a))
a
}
def platformIds: Seq[PlatformId] = {
val stack = stackPush()
try {
val deviceIdBuffer = stack.mallocPointer(numberOfDevices)
checkErrorCode(clGetDeviceIDs(platformId, deviceType, deviceIdBuffer, null: IntBuffer))
for (i <- 0 until deviceIdBuffer.capacity()) yield {
val deviceId = deviceIdBuffer.get(i)
new DeviceId(deviceId)
val numberOfPlatformsBuffer = stack.mallocInt(1)
checkErrorCode(clGetPlatformIDs(null, numberOfPlatformsBuffer))
val numberOfPlatforms = numberOfPlatformsBuffer.get(0)
val platformIdBuffer = stack.mallocPointer(numberOfPlatforms)
checkErrorCode(clGetPlatformIDs(platformIdBuffer, null: IntBuffer))
(0 until numberOfPlatforms).map { i =>
new PlatformId(platformIdBuffer.get(i))
}
} finally {
stack.close()
Expand Down Expand Up @@ -1174,20 +1188,20 @@ trait OpenCL extends MonadicCloseable[UnitContinuation] with ImplicitsSingleton

import OpenCL._

protected val platformId: Long
protected val platformId: PlatformId
protected val deviceIds: Seq[DeviceId]

@transient
protected lazy val platformCapabilities: CLCapabilities = {
CL.createPlatformCapabilities(platformId)
CL.createPlatformCapabilities(platformId.handle)
}

protected def createCommandQueue(deviceId: DeviceId, properties: Map[Int, Long]): CommandQueue = new CommandQueue(
if (deviceCapabilities(deviceId).OpenCL20) {
val cl20Properties = (properties.view.flatMap { case (key, value) => Seq(key, value) } ++ Seq(0L)).toArray
val a = Array(0)
val commandQueue =
clCreateCommandQueueWithProperties(platformId, deviceId.handle, cl20Properties, a)
clCreateCommandQueueWithProperties(platformId.handle, deviceId.handle, cl20Properties, a)
checkErrorCode(a(0))
commandQueue
} else {
Expand All @@ -1211,7 +1225,7 @@ trait OpenCL extends MonadicCloseable[UnitContinuation] with ImplicitsSingleton
val stack = stackPush()
try {
val errorCodeBuffer = stack.ints(CL_SUCCESS)
val contextProperties = stack.pointers(CL_CONTEXT_PLATFORM, platformId, 0)
val contextProperties = stack.pointers(CL_CONTEXT_PLATFORM, platformId.handle, 0)
val deviceIdBuffer = stack.pointers(deviceIds.view.map(_.handle): _*)
val context =
clCreateContext(contextProperties,
Expand Down
6 changes: 4 additions & 2 deletions benchmarks/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ enablePlugins(JmhPlugin)

libraryDependencies += "org.nd4j" % "nd4j-api" % "0.8.0"

libraryDependencies += "org.nd4j" % "nd4j-cuda-8.0-platform" % "0.8.0"
val nd4jRuntime = settingKey[String]("\"cuda-8.0\" to run benchmark on GPU, \"native\" to run benchmark on CPU.")

libraryDependencies += "org.nd4j" % "nd4j-native-platform" % "0.8.0"
nd4jRuntime in Global := "native"

libraryDependencies += "org.nd4j" % s"nd4j-${nd4jRuntime.value}-platform" % "0.8.0"

libraryDependencies += ("org.lwjgl" % "lwjgl" % "3.1.6").jar().classifier {
import scala.util.Properties._
Expand Down
Loading