diff --git a/Expressions/src/main/scala/com/thoughtworks/compute/Expressions.scala b/Expressions/src/main/scala/com/thoughtworks/compute/Expressions.scala
index b44a17a8..9d4bbde9 100644
--- a/Expressions/src/main/scala/com/thoughtworks/compute/Expressions.scala
+++ b/Expressions/src/main/scala/com/thoughtworks/compute/Expressions.scala
@@ -120,6 +120,7 @@ object Expressions {
     def exp(operand: FloatTerm): FloatTerm
     def abs(operand: FloatTerm): FloatTerm
     def sqrt(operand: FloatTerm): FloatTerm
+    def tanh(operand: FloatTerm): FloatTerm
   }
 
   /** @template */
diff --git a/OpenCLKernelBuilder/src/main/scala/com/thoughtworks/compute/OpenCLKernelBuilder.scala b/OpenCLKernelBuilder/src/main/scala/com/thoughtworks/compute/OpenCLKernelBuilder.scala
index edeca4c0..260db270 100644
--- a/OpenCLKernelBuilder/src/main/scala/com/thoughtworks/compute/OpenCLKernelBuilder.scala
+++ b/OpenCLKernelBuilder/src/main/scala/com/thoughtworks/compute/OpenCLKernelBuilder.scala
@@ -265,6 +265,14 @@ trait OpenCLKernelBuilder extends AllExpressions {
       """
       float.termFactory.newInstance(valueTermName)
     }
+
+    def tanh(operand0: FloatTerm): FloatTerm = {
+      val valueTermName = freshName("")
+      localDefinitions += fastraw"""
+        const ${operand0.typeCode} $valueTermName = tanh(${operand0.termCode});
+      """
+      float.termFactory.newInstance(valueTermName)
+    }
   }
 
   type FloatType <: (ValueType with Any) with ClFloatType
diff --git a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala
index e82be69c..1e198608 100644
--- a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala
+++ b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala
@@ -502,6 +502,10 @@ trait Tensors extends OpenCL {
       leftHandSide.derivedTensor(trees.float.sqrt(leftHandSide.closure.asInstanceOf[FloatTerm]))
     }
 
+    def tanh(leftHandSide: Tensor): Tensor = {
+      leftHandSide.derivedTensor(trees.float.tanh(leftHandSide.closure.asInstanceOf[FloatTerm]))
+    }
+
     def exp(leftHandSide: Tensor): Tensor = {
       leftHandSide.derivedTensor(trees.float.exp(leftHandSide.closure.asInstanceOf[FloatTerm]))
     }
diff --git a/Trees/src/main/scala/com/thoughtworks/compute/Trees.scala b/Trees/src/main/scala/com/thoughtworks/compute/Trees.scala
index 9fdbff2e..e029f98d 100644
--- a/Trees/src/main/scala/com/thoughtworks/compute/Trees.scala
+++ b/Trees/src/main/scala/com/thoughtworks/compute/Trees.scala
@@ -421,6 +421,21 @@ object Trees {
       }
     }
 
+    /** @group AST */
+    @(silent @companionObject)
+    final case class Tanh(operand0: FloatTree) extends FloatOperator {
+
+      protected def erasedExport(foreignCategory: Category, context: ExportContext) = {
+        context.asScala.getOrElseUpdate(this, foreignCategory.float.tanh(operand0.export(foreignCategory, context)))
+      }
+
+      protected def erasedAlphaConversion(context: AlphaConversionContext): Tree = {
+        def converted = copy(operand0 = operand0.alphaConversion(context))
+        context.asScala.getOrElseUpdate(this, converted)
+      }
+    }
+
+
     /** @group AST */
     @(silent @companionObject)
     final case class Sqrt(operand0: FloatTree) extends FloatOperator {
@@ -619,6 +634,11 @@ object Trees {
       term(Sqrt(operand.tree))
     }
 
+    @inline
+    def tanh(operand: FloatTerm): FloatTerm = {
+      term(Tanh(operand.tree))
+    }
+
   }
 
   type FloatType <: (ValueType with Any) with FloatTreeType
diff --git a/benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala b/benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala
index 49e5c2d2..23d79945 100644
--- a/benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala
+++ b/benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala
@@ -9,6 +9,7 @@ import com.thoughtworks.raii.covariant._
 import com.thoughtworks.tryt.covariant._
 import com.typesafe.scalalogging.StrictLogging
 import org.lwjgl.opencl.CLCapabilities
+import org.lwjgl.system.Configuration
 import org.nd4j.linalg.api.ndarray.INDArray
 import org.nd4j.linalg.convolution.Convolution
 import org.nd4j.linalg.factory.Nd4j
@@ -24,24 +25,24 @@ object benchmarks {
 
   @Threads(value = Threads.MAX)
   @State(Scope.Benchmark)
-  class Nd4jSigmoid extends SigmoidState {
+  class Nd4jTanh extends TanhState {
 
     @transient private lazy val input = Nd4j.randn(Array.fill(numberOfDimensions)(size))
-
-    private def sigmoid(x: INDArray): INDArray = {
-      val expX = Transforms.exp(x)
-      expX.div(expX.add(1.0))
-    }
 
     @Benchmark
-    final def nd4jSigmoidBenchmark(): Array[Float] = {
-      sigmoid(input).data().asFloat()
+    final def nd4jTanhBenchmark(): Array[Float] = {
+      (0 until numberOfIterations)
+        .foldLeft(input) { (input, _) =>
+          Transforms.tanh(input)
+        }
+        .data()
+        .asFloat()
     }
   }
 
   @Threads(value = Threads.MAX)
   @State(Scope.Benchmark)
-  class TensorSigmoid extends SigmoidState {
+  class TensorTanh extends TanhState {
 
     trait Benchmarks extends StrictLogging
       with Tensors.UnsafeMathOptimizations
@@ -56,18 +62,18 @@ object benchmarks {
 
     protected val numberOfCommandQueuesPerDevice: Int = 2
 
-    private def sigmoid(x: Tensor): Tensor = {
-      val expX = Tensor.exp(x)
-      expX / (expX + Tensor.fill(1.0f, expX.shape))
-    }
-
     def doBenchmark(): Do[() => Array[Float]] = {
       val input = Tensor.randomNormal(Array.fill(numberOfDimensions)(size))
 
       input.doBuffer.map { _ =>
        { () =>
-          sigmoid(input).flatArray.run.blockingAwait
-
+          (0 until numberOfIterations)
+            .foldLeft(input) { (input, _) =>
+              Tensor.tanh(input)
+            }
+            .flatArray
+            .run
+            .blockingAwait
        }
      }
    }
@@ -77,6 +83,7 @@ object benchmarks {
 
     @Setup
     final def setup(): Unit = {
+//      Configuration.OPENCL_LIBRARY_NAME.set("/opt/pocl-1.1/lib/libOpenCL.dylib")
       assert(benchmarkResouce == null)
       val Do(TryT(ResourceT(resourceContinuation))) =
         Do.monadicCloseable(Factory[Benchmarks].newInstance()).flatMap(_.doBenchmark())
@@ -91,14 +98,17 @@ object benchmarks {
     }
 
     @Benchmark
-    final def tensorSigmoidBenchmark(): Array[Float] = {
+    final def tensorTanhBenchmark(): Array[Float] = {
       benchmarkResouce.value.get.apply()
     }
 
   }
 
-  trait SigmoidState {
-    @Param(Array("3", "2", "1"))
+  trait TanhState {
+    @Param(Array("100", "10", "1"))
+    protected var numberOfIterations: Int = _
+
+    @Param(Array("2", "3", "1"))
     protected var numberOfDimensions: Int = _
 
     @Param(Array("128", "64", "32", "16"))