Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ trait Expressions {

type Term <: TermApi

protected trait TypeApi extends ExpressionApi {}
protected trait TypeApi extends ExpressionApi

type Type <: TypeApi

Expand Down
4 changes: 1 addition & 3 deletions OpenCL/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ val lwjglNatives: String = {

libraryDependencies += "org.lwjgl" % "lwjgl-opencl" % "3.1.5"

libraryDependencies += "org.lwjgl" % "lwjgl" % "3.1.5"

libraryDependencies += "org.lwjgl" % "lwjgl" % "3.1.5" % Test classifier lwjglNatives
libraryDependencies += ("org.lwjgl" % "lwjgl" % "3.1.5" % Test).classifier(lwjglNatives).jar()

libraryDependencies += "com.thoughtworks.raii" %% "asynchronous" % "3.0.0-M8"

Expand Down
6 changes: 3 additions & 3 deletions OpenCL/src/main/scala/com/thoughtworks/compute/OpenCL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -588,13 +588,13 @@ object OpenCL {
* will cause memory access error.
*
*/
final def toHostBuffer(implicit witnessOwner: Witness.Aux[Owner],
memory: Memory[Element]): Do[memory.HostBuffer] = {
final def toHostBuffer(preconditionEvents: Event[Owner]*)(implicit witnessOwner: Witness.Aux[Owner],
memory: Memory[Element]): Do[memory.HostBuffer] = {
Do(TryT(ResourceT(UnitContinuation.delay {
val hostBuffer = memory.allocate(length)
Resource(value = Success(hostBuffer), release = UnitContinuation.delay { memory.free(hostBuffer) })
}))).flatMap { hostBuffer =>
enqueueReadBuffer[memory.HostBuffer](hostBuffer)(witnessOwner, memory)
enqueueReadBuffer[memory.HostBuffer](hostBuffer, preconditionEvents: _*)(witnessOwner, memory)
.flatMap { event =>
Do.garbageCollected(event.waitForComplete())
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ object OpenCLKernelBuilder {

final case class ClTypeSymbol(firstDefinition: ClTypeDefinition, typeCode: ClTypeCode)

final class GlobalContext {
final class GlobalContext extends Fastring {

private var seed = 0

Expand All @@ -56,8 +56,8 @@ object OpenCLKernelBuilder {
name
}

val globalDeclarations = mutable.Buffer.empty[Fastring]
val globalDefinitions = mutable.Buffer.empty[Fastring]
protected[OpenCLKernelBuilder] val globalDeclarations = mutable.Buffer.empty[Fastring]
protected[OpenCLKernelBuilder] val globalDefinitions = mutable.Buffer.empty[Fastring]
private val typeSymbolCache = mutable.HashMap.empty[ClTypeDefinition, ClTypeSymbol]

val floatSymbol = cachedSymbol(FloatDefinition)
Expand All @@ -73,6 +73,10 @@ object OpenCLKernelBuilder {
typeSymbol
}

def foreach[U](f: String => U): Unit = {
globalDeclarations.foreach(_.foreach(f))
globalDefinitions.foreach(_.foreach(f))
}
}

}
Expand All @@ -98,7 +102,8 @@ trait OpenCLKernelBuilder extends FloatArrays {
val (outputParameters, outputAssignments) = outputs.map { output =>
val outputTermCode = output.termCode
val outputTypeCode = output.typeCode
val outputParameter = fast"global $outputTypeCode *output_$outputTermCode"
val outputId = freshName("output")
val outputParameter = fast"global $outputTypeCode *$outputId"
def outputIndex(dimension: Int): Fastring = {
if (dimension == 0) {
fast"get_global_id(0)"
Expand All @@ -108,11 +113,11 @@ trait OpenCLKernelBuilder extends FloatArrays {
}

val index = outputIndex(numberOfDimensions - 1)
val outputAssignment = fast"output_$outputTermCode[$index] = $outputTermCode;\n"
val outputAssignment = fast"$outputId[$index] = $outputTermCode;\n"
(outputParameter, outputAssignment)
}.unzip
fastraw"""
kernel void $functionName(${parameterDeclarations.mkFastring(", ")}, ${outputParameters.mkFastring(", ")}) {
kernel void $functionName(${(parameterDeclarations.view ++ outputParameters).mkFastring(", ")}) {
${localDefinitions.mkFastring}
${outputAssignments.mkFastring}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,7 @@ class OpenCLKernelBuilderSpec extends FreeSpec with Matchers {
),
Seq(f.tree.export(openCLFunctionContext, map)))

globalContext.globalDeclarations.foreach(print)
globalContext.globalDefinitions.foreach(print)
globalContext.foreach(print)
sourceCode.foreach(print)

// TODO: Convert this example to a test case
Expand Down
19 changes: 19 additions & 0 deletions Tensors/build.sbt
Original file line number Diff line number Diff line change
@@ -1,3 +1,22 @@
scalacOptions += "-Ypartial-unification"

libraryDependencies += "com.google.guava" % "guava" % "23.6-jre"

libraryDependencies += "org.scalatest" %% "scalatest" % "3.0.4" % Test

val lwjglNatives: String = {
import scala.util.Properties._
if (isMac) {
"natives-macos"
} else if (isLinux) {
"natives-linux"
} else if (isWin) {
"natives-windows"
} else {
throw new MessageOnlyException(s"lwjgl does not support $osName")
}
}

libraryDependencies += ("org.lwjgl" % "lwjgl" % "3.1.5" % Test).classifier(lwjglNatives).jar()

fork in Test := true
27 changes: 20 additions & 7 deletions Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ trait Tensors extends OpenCL {
def event: Event

def buffer: DeviceBuffer[Float]

def toHostBuffer = {
buffer.toHostBuffer(event)
}
}
object Tensor {
def fill(value: Float, shape0: Array[Int]) = {
new InlineTensor {
val padding: Float = value
val shape: shape0.type = shape0
val closure: trees.FloatTerm = float.literal(value)
}
}
}

sealed trait Tensor { thisTensor =>
Expand All @@ -75,6 +88,7 @@ trait Tensors extends OpenCL {

def closure: ValueTerm

// TODO: rename to make buffer
def enqueue: Do[PendingBuffer]

def padding: Float
Expand All @@ -96,7 +110,7 @@ trait Tensors extends OpenCL {
})
}

protected val kernelCache: Cache[ValueTerm, CompiledKernel] = kernelCacheBuilder.build()
protected[compute] val kernelCache: Cache[ValueTerm, CompiledKernel] = kernelCacheBuilder.build()

protected implicit val executionContext: ExecutionContext

Expand Down Expand Up @@ -132,22 +146,21 @@ trait Tensors extends OpenCL {
def call(): CompiledKernel = {

val alphConversionContext = new AlphaConversionContext
val convertedTerm = closure.tree.alphaConversion(alphConversionContext).asInstanceOf[ValueTerm]
val convertedTree = closure.tree.alphaConversion(alphConversionContext)

val sourceCode = {
val globalContext = new GlobalContext
val functionContext = Factory[OpenCLKernelBuilder].newInstance(globalContext)

val exportContext = new ExportContext
val kernelBody = convertedTerm.tree.export(functionContext, exportContext)
val kernelBody = convertedTree.export(functionContext, exportContext).asInstanceOf[functionContext.Term]

val kernelParameters = upvalues(closure.tree).map { upvalue: Parameter =>
exportContext.get(alphConversionContext.get(upvalue)).asInstanceOf[functionContext.Term]
}
fastraw"""
${globalContext.globalDeclarations}
${globalContext.globalDefinitions}
${functionContext.generateKernelSourceCode("kernel", shape.length, kernelParameters, Seq(kernelBody))}
$globalContext
${functionContext.generateKernelSourceCode("jit_kernel", shape.length, kernelParameters, Seq(kernelBody))}
"""
}

Expand Down Expand Up @@ -188,7 +201,7 @@ trait Tensors extends OpenCL {

}
}
kernelCache.put(convertedTerm, compiledKernel)
kernelCache.put(float.factory.newInstance(convertedTree.asInstanceOf[float.Tree]), compiledKernel)
compiledKernel
}
}
Expand Down
30 changes: 26 additions & 4 deletions Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala
Original file line number Diff line number Diff line change
@@ -1,27 +1,49 @@
package com.thoughtworks.compute

import java.nio.ByteBuffer
import java.nio.{ByteBuffer, FloatBuffer}

import com.thoughtworks.feature.Factory
import TensorsSpec._
import com.thoughtworks.future._
import com.thoughtworks.raii.asynchronous._
import org.lwjgl.opencl.CLCapabilities

import scalaz.syntax.all._
import scala.language.existentials
import org.scalatest._

/**
* @author 杨博 (Yang Bo)
*/
class TensorsSpec {
private val hyperparameters =
class TensorsSpec extends AsyncFreeSpec with Matchers {
private val tensors: Tensors =
Factory[
OpenCL.GlobalExecutionContext with OpenCL.UseAllDevices with OpenCL.UseFirstPlatform with OpenCL.CommandQueuePool with Tensors]
.newInstance(
handleOpenCLNotification = handleOpenCLNotification,
numberOfCommandQueuesForDevice = { (deviceId: Long, capabilities: CLCapabilities) =>
1
5
}
)

"create a tensor of a constant" in {
val shape = Array(2, 3, 5)
val element = 42.0f
val zeros = tensors.Tensor.fill(element, shape)

for {
pendingBuffer <- zeros.enqueue
floatBuffer <- pendingBuffer.toHostBuffer
} yield {
for (i <- 0 until floatBuffer.capacity()) {
floatBuffer.get(i) should be(element)
}
floatBuffer.position() should be(0)
floatBuffer.limit() should be(shape.product)
floatBuffer.capacity() should be(shape.product)
}
}.run.toScalaFuture

}

object TensorsSpec {
Expand Down
Loading