From 2af1966315d4cff086d6cb6ecc8c0b1d963231de Mon Sep 17 00:00:00 2001 From: Yang Bo Date: Thu, 29 Mar 2018 18:14:16 +0800 Subject: [PATCH] Split should return TransformedTensors --- Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala | 4 ++-- .../src/test/scala/com/thoughtworks/compute/TensorsSpec.scala | 2 +- .../src/jmh/scala/com/thoughtworks/compute/benchmarks.scala | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala index a21a944f..9851a819 100644 --- a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala +++ b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala @@ -937,10 +937,10 @@ trait Tensors extends OpenCL { /** * @group delayed */ - def split(dimension: Int): IndexedSeq[Tensor] = { + def split(dimension: Int): IndexedSeq[TransformedTensor] = { // TODO: override map/reduce to produce less OpenCL C code val newShape = shape.patch(dimension, Nil, 1) - new IndexedSeq[Tensor] { + new IndexedSeq[TransformedTensor] { val length: Int = shape(dimension) diff --git a/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala b/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala index 9ac0e89c..1e7693c5 100644 --- a/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala +++ b/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala @@ -307,7 +307,7 @@ class TensorsSpec extends AsyncFreeSpec with Matchers { val Array(`j`, k) = matrix2.shape val product = matrix1.broadcast(Array(i, j, k)) * matrix2.reshape(Array(1, j, k)).broadcast(Array(i, j, k)) - product.split(1).reduce(_ + _) + product.split(1).reduce[Tensor](_ + _) } diff --git a/benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala b/benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala index 137af16f..4c217e79 100644 --- a/benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala +++ b/benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala @@ -100,7 +100,7 @@ object benchmarks { // unroll only j val Array(`j`, k) = matrix2.shape val product = matrix1.broadcast(Array(i, j, k)) * matrix2.reshape(Array(1, j, k)).broadcast(Array(i, j, k)) - product.split(1).reduce(_ + _) + product.split(1).reduce[Tensor](_ + _) } }