// Patch overview (commentary ahead of the first diff header, not part of any hunk): rewrites the divisibility check in OpenCLKernelBuilder.extract, changes one matrix1 index in Tensors to use newLength, and adds three TensorsSpec tests (matrix multiplication, broadcast, unrolled matrix multiplication).
diff --git a/OpenCLKernelBuilder/src/main/scala/com/thoughtworks/compute/OpenCLKernelBuilder.scala b/OpenCLKernelBuilder/src/main/scala/com/thoughtworks/compute/OpenCLKernelBuilder.scala
index 260db270..5785979e 100644
--- a/OpenCLKernelBuilder/src/main/scala/com/thoughtworks/compute/OpenCLKernelBuilder.scala
+++ b/OpenCLKernelBuilder/src/main/scala/com/thoughtworks/compute/OpenCLKernelBuilder.scala
@@ -298,7 +298,7 @@ trait OpenCLKernelBuilder extends AllExpressions {
         def extract: Element = {
           val numberOfRows = originalShape.length
           val numberOfColumns = matrix.length / numberOfRows
-          if (matrix.length % numberOfRows != 0) {
+          if (matrix.length != numberOfRows * numberOfColumns) { // with truncating integer division above, the product equals matrix.length only when the length is a multiple of numberOfRows (assumes Int lengths - confirm)
             throw new IllegalStateException()
           }
 
diff --git a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala
index fb6e10f0..31947af3 100644
--- a/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala
+++ b/Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala
@@ -711,7 +711,7 @@ trait Tensors extends OpenCL {
             if (i < length) {
               shape(i) match {
                 case di if di == newShape(i) =>
-                  matrix1(i * (length + 1) + i) = 1.0
+                  matrix1(i * (newLength + 1) + i) = 1.0 // index now uses newLength + 1 instead of length + 1; presumably the row stride of matrix1, making this the (i, i) entry - TODO confirm against where matrix1 is allocated
                 case 1 =>
                 case _ =>
                   throw new IllegalArgumentException(
diff --git a/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala b/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala
index 2e3b0ecd..12ee9d44 100644
--- a/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala
+++ b/Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala
@@ -293,4 +293,66 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
     .run
     .toScalaFuture
 
+  "matrix multiplication" in doTensors // multiplies a 2x3 tensor by a 3x4 tensor using broadcast, `*` and a reduction, then compares the string form
+    .map { tensors =>
+      import tensors._
+
+      def matrixMultiply(matrix1: Tensor, matrix2: Tensor): Tensor = { // local helper: product of two rank-2 tensors built only from reshape/broadcast/unzip
+        val Array(i, j) = matrix1.shape // pattern only matches a 2-element shape
+        val Array(`j`, k) = matrix2.shape // backticked j matches against the existing j, so the inner dimensions must agree
+        val product = matrix1.broadcast(Array(i, j, k)) * matrix2.reshape(Array(1, j, k)).broadcast(Array(i, j, k)) // both operands are brought to shape (i, j, k) before `*` (presumably elementwise - confirm)
+
+        product.unzip(1).reduce(_ + _) // adds together the tensors returned by unzip(1); presumably the slices along the j axis - confirm unzip semantics
+
+      }
+
+      val matrix1 = Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f))) // 2 rows, 3 columns
+      val matrix2 = Tensor(
+        Array(Array(7.0f, 8.0f, 9.0f, 10.0f), Array(11.0f, 12.0f, 13.0f, 14.0f), Array(15.0f, 16.0f, 17.0f, 18.0f))) // 3 rows, 4 columns
+
+      matrixMultiply(matrix1, matrix2).toString should be("[[74.0,80.0,86.0,92.0],[173.0,188.0,203.0,218.0]]") // e.g. 74 = 1*7 + 2*11 + 3*15 and 173 = 4*7 + 5*11 + 6*15
+
+    }
+    .run
+    .toScalaFuture
+
+  "broadcast" in doTensors // checks the string form of a 2x3 tensor broadcast to shape (2, 3, 4)
+    .map { tensors =>
+      import tensors._
+
+      val matrix1 = Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f))) // 2 rows, 3 columns
+      matrix1.broadcast(Array(2, 3, 4)).toString should be(
+        "[[[1.0,1.0,1.0,1.0],[2.0,2.0,2.0,2.0],[3.0,3.0,3.0,3.0]],[[4.0,4.0,4.0,4.0],[5.0,5.0,5.0,5.0],[6.0,6.0,6.0,6.0]]]") // every input element is expected to appear 4 times along the new last axis
+    }
+    .run
+    .toScalaFuture
+
+  "unrolled matrix multiplication" in doTensors // same 2x3 by 3x4 product, assembled piece by piece with unzip / Tensor.zip
+    .map { tensors =>
+      import tensors._
+
+      def matrixMultiply(matrix1: Tensor, matrix2: Tensor): Tensor = { // local helper: computes one result tensor per element of matrix2.unzip(1)
+
+        val columns1 = matrix1.unzip(1) // presumably the pieces of matrix1 split along axis 1 - confirm unzip semantics
+
+        Tensor.zip(matrix2.unzip(1).map { column2: Tensor => // the per-piece results are recombined with Tensor.zip
+          (columns1 zip column2.unzip(0)) // pairs each entry of columns1 with the matching piece of column2
+            .map {
+              case (l: Tensor, r: Tensor) =>
+                l * r.broadcast(l.shape) // r is broadcast to l's shape before the multiplication
+            }
+            .reduce[Tensor](_ + _) // sums the pairwise products
+        })
+      }
+
+      matrixMultiply(
+        Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f))), // 2 rows, 3 columns
+        Tensor(
+          Array(Array(7.0f, 8.0f, 9.0f, 10.0f), Array(11.0f, 12.0f, 13.0f, 14.0f), Array(15.0f, 16.0f, 17.0f, 18.0f))) // 3 rows, 4 columns
+      ).toString should be("[[74.0,80.0,86.0,92.0],[173.0,188.0,203.0,218.0]]") // same expected string as the "matrix multiplication" test
+
+    }
+    .run
+    .toScalaFuture
+
 }