diff --git a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala index 1d8f6e1b..bc653ceb 100644 --- a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala +++ b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala @@ -937,10 +937,10 @@ trait Tensors extends OpenCL { /** * @group delayed */ - def split(dimension: Int): IndexedSeq[Tensor] = { + def split(dimension: Int): IndexedSeq[TransformedTensor] = { // TODO: override map/reduce to produce less OpenCL C code val newShape = shape.patch(dimension, Nil, 1) - final class TensorSeq extends IndexedSeq[Tensor] { + final class TensorSeq extends IndexedSeq[TransformedTensor] { override def stringPrefix = "TensorSeq" diff --git a/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala b/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala index 9ac0e89c..1e7693c5 100644 --- a/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala +++ b/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala @@ -307,7 +307,7 @@ class TensorsSpec extends AsyncFreeSpec with Matchers { val Array(`j`, k) = matrix2.shape val product = matrix1.broadcast(Array(i, j, k)) * matrix2.reshape(Array(1, j, k)).broadcast(Array(i, j, k)) - product.split(1).reduce(_ + _) + product.split(1).reduce[Tensor](_ + _) } diff --git a/benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala b/benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala index 137af16f..4c217e79 100644 --- a/benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala +++ b/benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala @@ -100,7 +100,7 @@ object benchmarks { // unroll only j val Array(`j`, k) = matrix2.shape val product = matrix1.broadcast(Array(i, j, k)) * matrix2.reshape(Array(1, j, k)).broadcast(Array(i, j, k)) - product.split(1).reduce(_ + _) + product.split(1).reduce[Tensor](_ + _) } }