@@ -120,6 +120,7 @@ object Expressions {
     def exp(operand: FloatTerm): FloatTerm
     def abs(operand: FloatTerm): FloatTerm
     def sqrt(operand: FloatTerm): FloatTerm
+    def tanh(operand: FloatTerm): FloatTerm
   }

   /** @template */
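The new `tanh` joins `exp`, `abs` and `sqrt` as a fourth unary primitive of the float expression API. The addition is worth having: without a native primitive, tanh would presumably have to be composed from `exp` through the identity tanh(x) = (e^(2x) - 1) / (e^(2x) + 1), several generated operations instead of one. A self-contained plain-Scala sketch (not the Compute.scala API) that checks the identity:

    object TanhViaExp extends App {
      // tanh(x) = (e^(2x) - 1) / (e^(2x) + 1), with exp as the only transcendental
      def tanhViaExp(x: Double): Double = {
        val e2x = math.exp(2.0 * x)
        (e2x - 1.0) / (e2x + 1.0)
      }

      for (x <- Seq(-2.0, -0.5, 0.0, 0.5, 2.0)) {
        // agrees with the built-in tanh up to floating-point error
        assert(math.abs(tanhViaExp(x) - math.tanh(x)) < 1e-12)
      }
      println("identity verified on sample points")
    }
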
@@ -265,6 +265,14 @@ trait OpenCLKernelBuilder extends AllExpressions {
       """
       float.termFactory.newInstance(valueTermName)
     }
+
+    def tanh(operand0: FloatTerm): FloatTerm = {
+      val valueTermName = freshName("")
+      localDefinitions += fastraw"""
+        const ${operand0.typeCode} $valueTermName = tanh(${operand0.termCode});
+      """
+      float.termFactory.newInstance(valueTermName)
+    }
   }

   type FloatType <: (ValueType with Any) with ClFloatType
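The generated definition lowers directly to `tanh`, which is a built-in math function in OpenCL C, so each application costs a single call in the kernel source. The surrounding pattern (a fresh name plus one `const` definition appended to `localDefinitions`) is the same one used by `sqrt` above. A simplified standalone sketch of that pattern (hypothetical `MiniKernelBuilder`; the real builder's `freshName`, `typeCode` and `termCode` are only modeled here):

    object MiniKernelBuilder extends App {
      private var counter = 0
      private val localDefinitions = collection.mutable.Buffer.empty[String]

      private def freshName(prefix: String): String = {
        val name = s"$prefix$counter"
        counter += 1
        name
      }

      // mirrors the shape of the new tanh method: emit one definition, return the term
      def tanh(termCode: String, typeCode: String = "float"): String = {
        val valueTermName = freshName("v")
        localDefinitions += s"const $typeCode $valueTermName = tanh($termCode);"
        valueTermName
      }

      val result = tanh(tanh("x"))       // chaining emits two dependent definitions
      localDefinitions.foreach(println)  // const float v0 = tanh(x);
                                         // const float v1 = tanh(v0);
      println(s"// result term: $result")
    }
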
Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala (4 additions, 0 deletions)
@@ -502,6 +502,10 @@ trait Tensors extends OpenCL {
       leftHandSide.derivedTensor(trees.float.sqrt(leftHandSide.closure.asInstanceOf[FloatTerm]))
     }

+    def tanh(leftHandSide: Tensor): Tensor = {
+      leftHandSide.derivedTensor(trees.float.tanh(leftHandSide.closure.asInstanceOf[FloatTerm]))
+    }
+
     def exp(leftHandSide: Tensor): Tensor = {
       leftHandSide.derivedTensor(trees.float.exp(leftHandSide.closure.asInstanceOf[FloatTerm]))
     }
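`Tensors.tanh` follows the same shape as the neighboring `sqrt` and `exp`: it wraps the operand's closure in a derived tensor, so the operation stays lazy and can fuse with surrounding element-wise expressions instead of materializing an intermediate buffer. A hedged usage sketch, built only from calls that appear elsewhere in this diff (`Tensor.randomNormal`, `flatArray.run.blockingAwait`); obtaining the configured `tensors` instance is elided:

    // assumes `tensors` is an already-configured Tensors instance
    import tensors.Tensor

    val input: Tensor = Tensor.randomNormal(Array(32, 32))
    val activated: Tensor = Tensor.tanh(input) // lazily extends the expression tree
    // forcing the result compiles and runs the generated OpenCL kernel
    val host: Array[Float] = activated.flatArray.run.blockingAwait
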
Trees/src/main/scala/com/thoughtworks/compute/Trees.scala (20 additions, 0 deletions)
@@ -421,6 +421,21 @@ object Trees {
       }
     }

+    /** @group AST */
+    @(silent @companionObject)
+    final case class Tanh(operand0: FloatTree) extends FloatOperator {
+
+      protected def erasedExport(foreignCategory: Category, context: ExportContext) = {
+        context.asScala.getOrElseUpdate(this, foreignCategory.float.tanh(operand0.export(foreignCategory, context)))
+      }
+
+      protected def erasedAlphaConversion(context: AlphaConversionContext): Tree = {
+        def converted = copy(operand0 = operand0.alphaConversion(context))
+        context.asScala.getOrElseUpdate(this, converted)
+      }
+    }
+
+
     /** @group AST */
     @(silent @companionObject)
     final case class Sqrt(operand0: FloatTree) extends FloatOperator {
@@ -619,6 +634,11 @@ object Trees {
       term(Sqrt(operand.tree))
     }

+    @inline
+    def tanh(operand: FloatTerm): FloatTerm = {
+      term(Tanh(operand.tree))
+    }
+
   }

   type FloatType <: (ValueType with Any) with FloatTreeType
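Both methods of the new `Tanh` node go through `context.asScala.getOrElseUpdate`, the caching idiom used by every AST node in this file: the context wraps an identity-keyed Java map as a Scala mutable map, so a subtree shared by several parents is exported (or alpha-converted) exactly once. A self-contained sketch of the idiom with a hypothetical `Node` (not the real `ExportContext`):

    import scala.collection.JavaConverters._ // scala.jdk.CollectionConverters in 2.13+

    final case class Node(name: String)

    object MemoizedExport extends App {
      // IdentityHashMap keys on reference identity, matching how tree nodes are cached
      val context = new java.util.IdentityHashMap[Node, String]().asScala

      def export(node: Node): String =
        context.getOrElseUpdate(node, {
          println(s"exporting ${node.name}") // evaluated at most once per node instance
          s"code(${node.name})"
        })

      val shared = Node("x")
      export(shared)
      export(shared) // cache hit: no second "exporting x"
    }
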
benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala (25 additions, 15 deletions)
@@ -9,6 +9,7 @@ import com.thoughtworks.raii.covariant._
 import com.thoughtworks.tryt.covariant._
 import com.typesafe.scalalogging.StrictLogging
 import org.lwjgl.opencl.CLCapabilities
+import org.lwjgl.system.Configuration
 import org.nd4j.linalg.api.ndarray.INDArray
 import org.nd4j.linalg.convolution.Convolution
 import org.nd4j.linalg.factory.Nd4j
@@ -24,24 +25,29 @@ object benchmarks {

   @Threads(value = Threads.MAX)
   @State(Scope.Benchmark)
-  class Nd4jSigmoid extends SigmoidState {
+  class Nd4jTanh extends TanhState {

     @transient
     private lazy val input = Nd4j.randn(Array.fill(numberOfDimensions)(size))
-    private def sigmoid(x: INDArray): INDArray = {
+    private def tanh(x: INDArray): INDArray = {
       val expX = Transforms.exp(x)
       expX.div(expX.add(1.0))
     }
     @Benchmark
-    final def nd4jSigmoidBenchmark(): Array[Float] = {
-      sigmoid(input).data().asFloat()
+    final def nd4jTanhBenchmark(): Array[Float] = {
+      (0 until numberOfIterations)
+        .foldLeft(input) { (input, _) =>
+          Transforms.tanh(input)
+        }
+        .data()
+        .asFloat()
     }

   }

   @Threads(value = Threads.MAX)
   @State(Scope.Benchmark)
-  class TensorSigmoid extends SigmoidState {
+  class TensorTanh extends TanhState {
     trait Benchmarks
         extends StrictLogging
         with Tensors.UnsafeMathOptimizations
@@ -56,18 +62,18 @@

       protected val numberOfCommandQueuesPerDevice: Int = 2

-      private def sigmoid(x: Tensor): Tensor = {
-        val expX = Tensor.exp(x)
-        expX / (expX + Tensor.fill(1.0f, expX.shape))
-      }
-
       def doBenchmark(): Do[() => Array[Float]] = {
         val input = Tensor.randomNormal(Array.fill(numberOfDimensions)(size))

         input.doBuffer.map { _ =>
           { () =>
-            sigmoid(input).flatArray.run.blockingAwait
+
+            (0 until numberOfIterations)
+              .foldLeft(input) { (input, _) =>
+                Tensor.tanh(input)
+              }
+              .flatArray
+              .run
+              .blockingAwait
           }
         }
       }
@@ -77,6 +83,7 @@

     @Setup
     final def setup(): Unit = {
+      // Configuration.OPENCL_LIBRARY_NAME.set("/opt/pocl-1.1/lib/libOpenCL.dylib")
       assert(benchmarkResouce == null)
       val Do(TryT(ResourceT(resourceContinuation))) =
         Do.monadicCloseable(Factory[Benchmarks].newInstance()).flatMap(_.doBenchmark())
@@ -91,14 +98,17 @@
     }

     @Benchmark
-    final def tensorSigmoidBenchmark(): Array[Float] = {
+    final def tensorTanhBenchmark(): Array[Float] = {
       benchmarkResouce.value.get.apply()
     }

   }

-  trait SigmoidState {
-    @Param(Array("3", "2", "1"))
+  trait TanhState {
+    @Param(Array("100", "10", "1"))
+    protected var numberOfIterations: Int = _
+
+    @Param(Array("2", "3", "1"))
     protected var numberOfDimensions: Int = _

     @Param(Array("128", "64", "32", "16"))
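Two notes on the benchmark changes. First, the new `numberOfIterations` parameter turns a single element-wise pass into a chain of dependent applications, so the benchmark now also measures kernel-launch and scheduling overhead rather than one kernel's throughput. The plain-Scala shape of the loop, for reference (`Double` standing in for a tensor):

    val numberOfIterations = 10
    val result = (0 until numberOfIterations).foldLeft(0.5) { (acc, _) =>
      math.tanh(acc) // each step depends on the previous one
    }

Second, two leftovers worth flagging for review: the private helper in `Nd4jTanh` was only renamed, so its body still computes the logistic sigmoid e^x / (e^x + 1) and is now unreferenced (the benchmark calls `Transforms.tanh` directly); and the new `org.lwjgl.system.Configuration` import is only used by the commented-out `OPENCL_LIBRARY_NAME` line, which apparently points at a local POCL 1.1 build used during development.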