From 7e9286b49becc61b5baa0a5756d18fd4366fa6d5 Mon Sep 17 00:00:00 2001 From: Yang Bo Date: Tue, 6 Feb 2018 09:48:02 +0800 Subject: [PATCH 1/2] Add split implementation --- .../com/thoughtworks/compute/Tensors.scala | 39 ++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala index 586bbc28..8dd0a167 100644 --- a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala +++ b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala @@ -163,7 +163,44 @@ trait Tensors extends OpenCL { } } - def split(dimension: Int): Seq[Tensor] = ??? + def split(dimension: Int): IndexedSeq[Tensor] = { + // TODO: override map/reduce to produce less OpenCL C code + new IndexedSeq[Tensor] { + + private val newShape = shape.patch(dimension, Nil, 1) + + val length: Int = shape(dimension) + + def apply(index: Int): TransformedTensor = { + val length = shape.length + val matrix = Array.ofDim[Double](length * length) + + @tailrec + def loopBefore(i: Int): Unit = { + if (i < dimension) { + matrix(i * length + i) = 1.0 + loopBefore(i + 1) + } + } + loopBefore(0) + + matrix(dimension * length + length - 1) = index.toDouble + + @tailrec + def loopAfter(i: Int): Unit = { + if (i < length) { + matrix(i * length + i - 1) = 1.0 + loopAfter(i + 1) + } + } + loopAfter(dimension + 1) + + // TODO: Cache the transformed tensor + transform(newShape, matrix) + } + + } + } // def debuggingInformation: Implicitly[DebuggingInformation] From bc2c539769288b852230298f59ff6aab97e6093a Mon Sep 17 00:00:00 2001 From: Yang Bo Date: Tue, 6 Feb 2018 10:06:07 +0800 Subject: [PATCH 2/2] Add permute implementation --- .../com/thoughtworks/compute/Tensors.scala | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala index 8dd0a167..c986d464 100644 --- a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala +++ b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala @@ -163,6 +163,25 @@ trait Tensors extends OpenCL { } } + def permute(dimensions: Array[Int]): TransformedTensor = { + val length = shape.length + if (dimensions.length != length) { + throw new IllegalArgumentException + } + val newShape = Array.ofDim[Int](length) + val matrix = Array.ofDim[Double](length * (length + 1)) + @tailrec def loop(newDimensionIndex: Int): Unit = { + if (newDimensionIndex < length) { + val oldDimensionIndex = dimensions(newDimensionIndex) + newShape(newDimensionIndex) = shape(oldDimensionIndex) + matrix(oldDimensionIndex * (length + 1) + newDimensionIndex) = 1.0 + loop(newDimensionIndex + 1) + } + } + loop(0) + transform(newShape, matrix) + } + def split(dimension: Int): IndexedSeq[Tensor] = { // TODO: override map/reduce to produce less OpenCL C code new IndexedSeq[Tensor] {