Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ trait OpenCLKernelBuilder extends AllExpressions {
def extract: Element = {
val numberOfRows = originalShape.length
val numberOfColumns = matrix.length / numberOfRows
if (matrix.length % numberOfRows != 0) {
if (matrix.length != numberOfRows * numberOfColumns) {
throw new IllegalStateException()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ trait Tensors extends OpenCL {
if (i < length) {
shape(i) match {
case di if di == newShape(i) =>
matrix1(i * (length + 1) + i) = 1.0
matrix1(i * (newLength + 1) + i) = 1.0
case 1 =>
case _ =>
throw new IllegalArgumentException(
Expand Down
62 changes: 62 additions & 0 deletions Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -293,4 +293,66 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
.run
.toScalaFuture

"matrix multiplication" in doTensors
.map { tensors =>
import tensors._

def matrixMultiply(matrix1: Tensor, matrix2: Tensor): Tensor = {
val Array(i, j) = matrix1.shape
val Array(`j`, k) = matrix2.shape
val product = matrix1.broadcast(Array(i, j, k)) * matrix2.reshape(Array(1, j, k)).broadcast(Array(i, j, k))

product.unzip(1).reduce(_ + _)

}

val matrix1 = Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f)))
val matrix2 = Tensor(
Array(Array(7.0f, 8.0f, 9.0f, 10.0f), Array(11.0f, 12.0f, 13.0f, 14.0f), Array(15.0f, 16.0f, 17.0f, 18.0f)))

matrixMultiply(matrix1, matrix2).toString should be("[[74.0,80.0,86.0,92.0],[173.0,188.0,203.0,218.0]]")

}
.run
.toScalaFuture

"broadcast" in doTensors
.map { tensors =>
import tensors._

val matrix1 = Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f)))
matrix1.broadcast(Array(2, 3, 4)).toString should be(
"[[[1.0,1.0,1.0,1.0],[2.0,2.0,2.0,2.0],[3.0,3.0,3.0,3.0]],[[4.0,4.0,4.0,4.0],[5.0,5.0,5.0,5.0],[6.0,6.0,6.0,6.0]]]")
}
.run
.toScalaFuture

"unrolled matrix multiplication" in doTensors
.map { tensors =>
import tensors._

def matrixMultiply(matrix1: Tensor, matrix2: Tensor): Tensor = {

val columns1 = matrix1.unzip(1)

Tensor.zip(matrix2.unzip(1).map { column2: Tensor =>
(columns1 zip column2.unzip(0))
.map {
case (l: Tensor, r: Tensor) =>
l * r.broadcast(l.shape)
}
.reduce[Tensor](_ + _)
})
}

matrixMultiply(
Tensor(Array(Array(1.0f, 2.0f, 3.0f), Array(4.0f, 5.0f, 6.0f))),
Tensor(
Array(Array(7.0f, 8.0f, 9.0f, 10.0f), Array(11.0f, 12.0f, 13.0f, 14.0f), Array(15.0f, 16.0f, 17.0f, 18.0f)))
).toString should be("[[74.0,80.0,86.0,92.0],[173.0,188.0,203.0,218.0]]")

}
.run
.toScalaFuture

}