97 changes: 51 additions & 46 deletions Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala
@@ -26,6 +26,8 @@ import scalaz.syntax.tag._
// TODO: Rename to VirtualTensors, like virtual-dom
trait Tensors extends OpenCL {

  def concatenate(tensors: Seq[Tensor], dimension: Int): Tensor = ???

  protected val trees: FloatArrayTrees with StructuralTrees { type Category = Floats with Arrays } =
    Factory[FloatArrayTrees with StructuralTrees].newInstance()
  import trees._
@@ -81,8 +83,56 @@ trait Tensors extends OpenCL {
}

  sealed trait Tensor { thisTensor =>
    def broadcast(newShape: Array[Int]): Tensor = ???

    def *(rightHandSide: Tensor): Tensor = ???

    def +(rightHandSide: Tensor): Tensor = ???

    def translate(
        offset: Array[Double],
        newShape: Array[Int] = shape) /*(implicit debuggingInformation0: Implicitly[DebuggingInformation])*/: Tensor = {
      if (offset.length != thisTensor.shape.length) {
        throw new IllegalArgumentException
      }

      thisTensor match {
        case thisTensor: TransformedTensor =>
          new TransformedTensor {
            val matrix: RealMatrix = {
              val newMatrix = thisTensor.matrix.copy()
              for (i <- offset.indices) {
                newMatrix.addToEntry(i, newMatrix.getColumnDimension - 1, offset(i))
              }
              newMatrix
            }
            val checkpoint: Tensor = thisTensor.checkpoint
            val shape: Array[Int] = thisTensor.shape
            // val debuggingInformation: Implicitly[DebuggingInformation] = debuggingInformation0
            val padding: Float = thisTensor.padding
          }
        case _ =>
          new TransformedTensor {
            val checkpoint: Tensor = thisTensor
            val shape: Array[Int] = newShape
            // val debuggingInformation: Implicitly[DebuggingInformation] = debuggingInformation0
            val matrix: RealMatrix = {
              val newMatrix = MatrixUtils.createRealMatrix(shape.length, shape.length + 1)
              for (i <- offset.indices) {
                newMatrix.setEntry(i, i, 1.0)
                newMatrix.setEntry(i, newMatrix.getColumnDimension - 1, offset(i))
              }
              newMatrix
            }

            def padding: Float = checkpoint.padding
          }
      }
    }

    // def debuggingInformation: Implicitly[DebuggingInformation]
    def split(dimension: Int): Seq[Tensor] = ???

    // def debuggingInformation: Implicitly[DebuggingInformation]

    def shape: Array[Int]

@@ -231,49 +281,4 @@ trait Tensors extends OpenCL {
}
}

  def translate(previousTensor: Tensor, offset: Seq[Double]): Tensor = {
    translate(previousTensor, offset, previousTensor.shape)
  }

  def translate(previousTensor: Tensor,
                offset: Seq[Double],
                newShape: Array[Int]) /*(implicit debuggingInformation0: Implicitly[DebuggingInformation])*/: Tensor = {
    if (offset.length != previousTensor.shape.length) {
      throw new IllegalArgumentException
    }

    previousTensor match {
      case previousTensor: TransformedTensor =>
        new TransformedTensor {
          val matrix: RealMatrix = {
            val newMatrix = previousTensor.matrix.copy()
            for (i <- offset.indices) {
              newMatrix.addToEntry(i, newMatrix.getColumnDimension - 1, offset(i))
            }
            newMatrix
          }
          val checkpoint: Tensor = previousTensor.checkpoint
          val shape: Array[Int] = previousTensor.shape
          // val debuggingInformation: Implicitly[DebuggingInformation] = debuggingInformation0
          val padding: Float = previousTensor.padding
        }
      case _ =>
        new TransformedTensor {
          val checkpoint: Tensor = previousTensor
          val shape: Array[Int] = newShape
          // val debuggingInformation: Implicitly[DebuggingInformation] = debuggingInformation0
          val matrix: RealMatrix = {
            val newMatrix = MatrixUtils.createRealMatrix(shape.length, shape.length + 1)
            for (i <- offset.indices) {
              newMatrix.setEntry(i, i, 1.0)
              newMatrix.setEntry(i, newMatrix.getColumnDimension - 1, offset(i))
            }
            newMatrix
          }

          def padding: Float = checkpoint.padding
        }
    }
  }

}
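The `translate` method moved onto `Tensor` above keeps translations lazy: a plain tensor is wrapped in a `TransformedTensor` whose affine matrix is an n × (n + 1) identity block with the offsets in the last column, while translating an existing `TransformedTensor` only adds to that column, so chained translations collapse into a single matrix. The snippet below is a minimal, standalone sketch of that matrix bookkeeping; it assumes Apache Commons Math on the classpath (the usual home of the `MatrixUtils`/`RealMatrix` API the diff uses), and the object and helper names are illustrative, not part of the Tensors API.

```scala
import org.apache.commons.math3.linear.{MatrixUtils, RealMatrix}

object TranslateMatrixSketch {

  // Build the n × (n + 1) affine matrix for a pure translation:
  // an identity block plus the offsets in the last column.
  def translationMatrix(offset: Array[Double]): RealMatrix = {
    val n = offset.length
    val matrix = MatrixUtils.createRealMatrix(n, n + 1)
    for (i <- offset.indices) {
      matrix.setEntry(i, i, 1.0)
      matrix.setEntry(i, n, offset(i))
    }
    matrix
  }

  // Composing a further translation only touches the last column,
  // mirroring the `addToEntry` branch taken for an existing TransformedTensor.
  def translateAgain(matrix: RealMatrix, offset: Array[Double]): RealMatrix = {
    val composed = matrix.copy()
    for (i <- offset.indices) {
      composed.addToEntry(i, composed.getColumnDimension - 1, offset(i))
    }
    composed
  }

  def main(args: Array[String]): Unit = {
    val first = translationMatrix(Array(0.0, 1.0, -1.0))
    val second = translateAgain(first, Array(0.0, 2.0, 0.5))
    // The offsets accumulate: the last column is now (0.0, 3.0, -0.5).
    println(second.getColumnVector(second.getColumnDimension - 1))
  }
}
```

Because composition only rewrites one matrix, a long chain of `translate` calls still costs a single `TransformedTensor` rather than a stack of wrappers, which is presumably the point of special-casing `TransformedTensor` in the method.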
57 changes: 55 additions & 2 deletions Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala
@@ -16,8 +16,8 @@ import org.scalatest._
* @author 杨博 (Yang Bo)
*/
class TensorsSpec extends AsyncFreeSpec with Matchers {
  private def doTensors: Do[Tensors] = Do.monadicCloseable(
    Factory[
  private def doTensors: Do[Tensors] =
    Do.monadicCloseable(Factory[
      OpenCL.GlobalExecutionContext with OpenCL.UseAllDevices with OpenCL.UseFirstPlatform with OpenCL.CommandQueuePool with Tensors]
      .newInstance(
        handleOpenCLNotification = handleOpenCLNotification,
@@ -48,6 +48,59 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
    }
  }.run.toScalaFuture

"convolution" ignore {
    doTensors.flatMap { tensors =>
      import tensors.Tensor
      import tensors.concatenate
      def convolute(input: Tensor /* batchSize × height × width × depth */,
                    weight: Tensor /* kernelHeight × kernelWidth × depth × filterSize */,
                    bias: Tensor /* filterSize*/ ): Tensor = {
        input.shape match {
          case Array(batchSize, height, width, depth) =>
            weight.shape match {
              case Array(kernelHeight, kernelWidth, `depth`, filterSize) =>
                bias.shape match {
                  case Array(`filterSize`) =>
                    val inputSeq: Seq[Tensor /* batchSize × height × width */ ] = input.split(dimension = 3)

                    val weightSeq: Seq[Seq[Seq[Seq[Tensor]]]] /* filterSize × kernelHeight × kernelWidth × depth */ =
                      weight
                        .split(dimension = 3)
                        .map(_.split(dimension = 0).map(_.split(dimension = 0).map(_.split(dimension = 0))))

                    val biasSeq: Seq[Tensor] /* filterSize */ = bias.split(dimension = 0)

                    val outputChannels: Seq[Tensor] = weightSeq.view.zip(biasSeq).map {
                      case (weightPerFilter, biasPerFilter) =>
                        val summands: Seq[Tensor] = for {
                          (offsetY, weightPerRow) <- (-1 to 1).view.zip(weightPerFilter)
                          (offsetX, weightPerPixel) <- (-1 to 1).view.zip(weightPerRow)
                          (
                            inputPerChannel /* batchSize × height × width */,
                            weightPerChannel /* scalar */
                          ) <- inputSeq.view.zip(weightPerPixel)
                        } yield {
                          inputPerChannel.translate(Array(0, offsetY, offsetX)) *
                            weightPerChannel.broadcast(Array(batchSize, height, width))
                        }

                        biasPerFilter.broadcast(Array(batchSize, height, width)) + summands.reduce(_ + _)
                    }
                    concatenate(outputChannels, dimension = 3)
                  case _ =>
                    throw new IllegalArgumentException
                }
              case _ =>
                throw new IllegalArgumentException
            }
          case _ =>
            throw new IllegalArgumentException
        }
      }

      ??? : Do[Assertion]
    }
  }.run.toScalaFuture
}

object TensorsSpec {
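The ignored `convolution` spec shows how a convolution layer could be assembled from the new tensor primitives alone: `split` unpacks the depth and kernel dimensions, `translate` shifts each input channel by a kernel offset, `broadcast` stretches the scalar weights and the bias to the image shape, and `concatenate` stitches the per-filter channels back together. As a plain-Scala reference for the arithmetic that decomposition relies on, here is a 3 × 3 same-size convolution on a single channel (cross-correlation convention, with out-of-range reads replaced by a padding value, in the spirit of `Tensor#padding`); everything in this sketch is illustrative and independent of the Tensors API.

```scala
object ConvolutionSketch {

  // 3 × 3 same-size convolution on one channel, written the way the test writes it:
  // a sum over kernel offsets of the shifted input times one scalar weight per offset.
  def convolve2d(input: Array[Array[Float]],
                 kernel: Array[Array[Float]],
                 padding: Float = 0.0f): Array[Array[Float]] = {
    require(kernel.length == 3 && kernel.forall(_.length == 3), "this sketch assumes a 3 × 3 kernel")
    val height = input.length
    val width = input.head.length

    // Reads outside the image return the padding value.
    def at(y: Int, x: Int): Float =
      if (y < 0 || y >= height || x < 0 || x >= width) padding else input(y)(x)

    Array.tabulate(height, width) { (y, x) =>
      val summands = for {
        (offsetY, weightRow) <- (-1 to 1).zip(kernel)
        (offsetX, weight) <- (-1 to 1).zip(weightRow)
      } yield at(y + offsetY, x + offsetX) * weight
      summands.sum
    }
  }

  def main(args: Array[String]): Unit = {
    val image = Array(
      Array(1.0f, 2.0f),
      Array(3.0f, 4.0f)
    )
    // An identity kernel leaves the image unchanged, which is an easy sanity check.
    val identityKernel = Array(
      Array(0.0f, 0.0f, 0.0f),
      Array(0.0f, 1.0f, 0.0f),
      Array(0.0f, 0.0f, 0.0f)
    )
    convolve2d(image, identityKernel).foreach(row => println(row.mkString(" ")))
  }
}
```

Each output pixel is the sum over the nine kernel offsets of the shifted input value times the corresponding scalar weight, which is what the spec builds per channel with `translate`, `*`, and `broadcast` before summing with `reduce(_ + _)` and adding the broadcast bias.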