Skip to content

Commit

Permalink
Merge pull request #30 from ThoughtWorksInc/transformed-views
Browse files Browse the repository at this point in the history
Implement some transformed views
  • Loading branch information
Atry committed Feb 6, 2018
2 parents 5510b51 + bc2c539 commit ff87a60
Showing 1 changed file with 57 additions and 1 deletion.
58 changes: 57 additions & 1 deletion Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala
Expand Up @@ -163,7 +163,63 @@ trait Tensors extends OpenCL {
}
}

def split(dimension: Int): Seq[Tensor] = ???
def permute(dimensions: Array[Int]): TransformedTensor = {
val length = shape.length
if (dimensions.length != length) {
throw new IllegalArgumentException
}
val newShape = Array.ofDim[Int](length)
val matrix = Array.ofDim[Double](length * (length + 1))
@tailrec def loop(newDimensionIndex: Int): Unit = {
if (newDimensionIndex < length) {
val oldDimensionIndex = dimensions(newDimensionIndex)
newShape(newDimensionIndex) = shape(oldDimensionIndex)
matrix(oldDimensionIndex * (length + 1) + newDimensionIndex) = 1.0
loop(newDimensionIndex + 1)
}
}
loop(0)
transform(newShape, matrix)
}

def split(dimension: Int): IndexedSeq[Tensor] = {
// TODO: override map/reduce to produce less OpenCL C code
new IndexedSeq[Tensor] {

private val newShape = shape.patch(dimension, Nil, 1)

val length: Int = shape(dimension)

def apply(index: Int): TransformedTensor = {
val length = shape.length
val matrix = Array.ofDim[Double](length * length)

@tailrec
def loopBefore(i: Int): Unit = {
if (i < dimension) {
matrix(i * length + i) = 1.0
loopBefore(i + 1)
}
}
loopBefore(0)

matrix(dimension * length + length - 1) = index.toDouble

@tailrec
def loopAfter(i: Int): Unit = {
if (i < length) {
matrix(i * length + i - 1) = 1.0
loopAfter(i + 1)
}
}
loopAfter(dimension + 1)

// TODO: Cache the transformed tensor
transform(newShape, matrix)
}

}
}

// def debuggingInformation: Implicitly[DebuggingInformation]

Expand Down

0 comments on commit ff87a60

Please sign in to comment.