From 2b6f87bb6af69d1da90de5705e647ee78e2e8a30 Mon Sep 17 00:00:00 2001 From: Yang Bo Date: Mon, 26 Mar 2018 17:10:55 +0800 Subject: [PATCH 1/2] Fix broadcast method --- .../com/thoughtworks/compute/OpenCLKernelBuilder.scala | 2 +- .../main/scala/com/thoughtworks/compute/Tensors.scala | 2 +- .../scala/com/thoughtworks/compute/TensorsSpec.scala | 10 ++++++++++ 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/OpenCLKernelBuilder/src/main/scala/com/thoughtworks/compute/OpenCLKernelBuilder.scala b/OpenCLKernelBuilder/src/main/scala/com/thoughtworks/compute/OpenCLKernelBuilder.scala index 260db270..5785979e 100644 --- a/OpenCLKernelBuilder/src/main/scala/com/thoughtworks/compute/OpenCLKernelBuilder.scala +++ b/OpenCLKernelBuilder/src/main/scala/com/thoughtworks/compute/OpenCLKernelBuilder.scala @@ -298,7 +298,7 @@ trait OpenCLKernelBuilder extends AllExpressions { def extract: Element = { val numberOfRows = originalShape.length val numberOfColumns = matrix.length / numberOfRows - if (matrix.length % numberOfRows != 0) { + if (matrix.length != numberOfRows * numberOfColumns) { throw new IllegalStateException() } diff --git a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala index fb6e10f0..31947af3 100644 --- a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala +++ b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala @@ -711,7 +711,7 @@ trait Tensors extends OpenCL { if (i < length) { shape(i) match { case di if di == newShape(i) => - matrix1(i * (length + 1) + i) = 1.0 + matrix1(i * (newLength + 1) + i) = 1.0 case 1 => case _ => throw new IllegalArgumentException( diff --git a/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala b/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala index 2e3b0ecd..8d33221e 100644 --- a/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala +++ b/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala @@ -293,4 +293,14 @@ class TensorsSpec extends AsyncFreeSpec with Matchers { .run .toScalaFuture + "broadcast" in doTensors + .map { tensors => + import tensors._ + + val matrix1 = Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f))) + matrix1.broadcast(Array(2, 3, 4)).toString should be( + "[[[1.0,1.0,1.0,1.0],[2.0,2.0,2.0,2.0],[3.0,3.0,3.0,3.0]],[[4.0,4.0,4.0,4.0],[5.0,5.0,5.0,5.0],[6.0,6.0,6.0,6.0]]]") + } + .run + .toScalaFuture } From 2e773688e6142908812e3d7e8305d11981f9b50b Mon Sep 17 00:00:00 2001 From: Yang Bo Date: Mon, 26 Mar 2018 17:12:50 +0800 Subject: [PATCH 2/2] Add tests for matrix multiplication --- .../thoughtworks/compute/TensorsSpec.scala | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala b/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala index 8d33221e..12ee9d44 100644 --- a/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala +++ b/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala @@ -293,6 +293,29 @@ class TensorsSpec extends AsyncFreeSpec with Matchers { .run .toScalaFuture + "matrix multiplication" in doTensors + .map { tensors => + import tensors._ + + def matrixMultiply(matrix1: Tensor, matrix2: Tensor): Tensor = { + val Array(i, j) = matrix1.shape + val Array(`j`, k) = matrix2.shape + val product = matrix1.broadcast(Array(i, j, k)) * matrix2.reshape(Array(1, j, k)).broadcast(Array(i, j, k)) + + product.unzip(1).reduce(_ + _) + + } + + val matrix1 = Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f))) + val matrix2 = Tensor( + Array(Array(7.0f, 8.0f, 9.0f, 10.0f), Array(11.0f, 12.0f, 13.0f, 14.0f), Array(15.0f, 16.0f, 17.0f, 18.0f))) + + matrixMultiply(matrix1, matrix2).toString should be("[[74.0,80.0,86.0,92.0],[173.0,188.0,203.0,218.0]]") + + } + .run + .toScalaFuture + "broadcast" in doTensors .map { tensors => import tensors._ @@ -303,4 +326,33 @@ class TensorsSpec extends AsyncFreeSpec with Matchers { } .run .toScalaFuture + + "unrolled matrix multiplication" in doTensors + .map { tensors => + import tensors._ + + def matrixMultiply(matrix1: Tensor, matrix2: Tensor): Tensor = { + + val columns1 = matrix1.unzip(1) + + Tensor.zip(matrix2.unzip(1).map { column2: Tensor => + (columns1 zip column2.unzip(0)) + .map { + case (l: Tensor, r: Tensor) => + l * r.broadcast(l.shape) + } + .reduce[Tensor](_ + _) + }) + } + + matrixMultiply( + Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f))), + Tensor( + Array(Array(7.0f, 8.0f, 9.0f, 10.0f), Array(11.0f, 12.0f, 13.0f, 14.0f), Array(15.0f, 16.0f, 17.0f, 18.0f))) + ).toString should be("[[74.0,80.0,86.0,92.0],[173.0,188.0,203.0,218.0]]") + + } + .run + .toScalaFuture + }