Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ object Expressions {
}

protected trait TupleTermApi extends ValueTermApi with TupleExpressionApi { this: TupleTerm =>
def unzip: Seq[Element]
def split: Seq[Element]
}

/** @template */
Expand All @@ -213,7 +213,7 @@ object Expressions {

def parameter(id: Any, element: ValueType, length: Int): TupleTerm { type Element = element.ThisTerm }

def zip[Element0 <: ValueTerm](elements: Element0*): TupleTerm { type Element = Element0 }
def join[Element0 <: ValueTerm](elements: Element0*): TupleTerm { type Element = Element0 }

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ trait OpenCLKernelBuilder extends AllExpressions {
type FloatTerm <: (ValueTerm with Any) with ClFloatTerm

trait ClTupleTerm extends TupleTermApi with ClValueTerm { thisTupleTerm: TupleTerm =>
def unzip: Seq[Element] = new IndexedSeq[Element] {
def split: Seq[Element] = new IndexedSeq[Element] {

def length: Int = thisTupleTerm.length

Expand Down Expand Up @@ -566,7 +566,7 @@ trait OpenCLKernelBuilder extends AllExpressions {
tupleTermFactory[element.ThisTerm].newInstance(element, length, termCode)
}

def zip[Element0 <: ValueTerm](elements: Element0*): TupleTerm {
def join[Element0 <: ValueTerm](elements: Element0*): TupleTerm {
type Element = Element0
} = {
val elementType = elements.head.valueType
Expand Down
6 changes: 3 additions & 3 deletions Tensors/src/main/scala/com/thoughtworks/compute/Tensors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ trait Tensors extends OpenCL {
}
}

def zip(tensors0: Seq[Tensor]): BufferedTensor = {
def join(tensors0: Seq[Tensor]): BufferedTensor = {
def force[A](seq: Seq[A]) = {
seq match {
case seqView: SeqView[A, _] @unchecked =>
Expand All @@ -559,7 +559,7 @@ trait Tensors extends OpenCL {
} with BufferedTensor {
private[compute] val doBuffer = {
val elements = tensors.map(_.closure)
enqueueClosure(trees.tuple.zip(elements: _*), headTensor.shape).asInstanceOf[Do[PendingBuffer[Float]]]
enqueueClosure(trees.tuple.join(elements: _*), headTensor.shape).asInstanceOf[Do[PendingBuffer[Float]]]
}.shared
}
}
Expand Down Expand Up @@ -937,7 +937,7 @@ trait Tensors extends OpenCL {
/**
* @group delayed
*/
def unzip(dimension: Int): IndexedSeq[Tensor] = {
def split(dimension: Int): IndexedSeq[Tensor] = {
// TODO: override map/reduce to produce less OpenCL C code
val newShape = shape.patch(dimension, Nil, 1)
new IndexedSeq[Tensor] {
Expand Down
26 changes: 13 additions & 13 deletions Tensors/src/test/scala/com/thoughtworks/compute/TensorsSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
doTensors.map { tensors =>
import tensors._
val tensor = Tensor(Seq(Seq(Seq(Seq(1.0f, 5.0f)))))
tensor.unzip(dimension = 3).map(_.toString) should be(Seq("[[[1.0]]]", "[[[5.0]]]"))
tensor.split(dimension = 3).map(_.toString) should be(Seq("[[[1.0]]]", "[[[5.0]]]"))
}
}.run.toScalaFuture

Expand All @@ -140,7 +140,7 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
"convolution" in {
doTensors.flatMap { tensors =>
import tensors.Tensor
import tensors.Tensor.zip
import tensors.Tensor.join
def convolute(input: Tensor /* batchSize × height × width × depth */,
weight: Tensor /* kernelHeight × kernelWidth × depth × filterSize */,
bias: Tensor /* filterSize */ ): Tensor = {
Expand All @@ -150,20 +150,20 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
case Array(kernelHeight, kernelWidth, `depth`, filterSize) =>
bias.shape match {
case Array(`filterSize`) =>
val inputSeq: Seq[Tensor /* batchSize × height × width */ ] = input.unzip(dimension = 3)
val inputSeq: Seq[Tensor /* batchSize × height × width */ ] = input.split(dimension = 3)

inputSeq.size should be(depth)
inputSeq.head.shape should be(Array(batchSize, height, width))

val weightSeq: Seq[Seq[Seq[Seq[Tensor]]]] /* filterSize × kernelHeight × kernelWidth × depth */ =
weight.unzip(dimension = 3).map { khKwD =>
weight.split(dimension = 3).map { khKwD =>
khKwD.shape should be(Array(kernelHeight, kernelWidth, depth))

khKwD.unzip(dimension = 0).map { kwD =>
khKwD.split(dimension = 0).map { kwD =>
kwD.shape should be(Array(kernelWidth, depth))
kwD.unzip(dimension = 0).map { d =>
kwD.split(dimension = 0).map { d =>
d.shape should be(Array(depth))
d.unzip(dimension = 0)
d.split(dimension = 0)
}
}
}
Expand All @@ -173,7 +173,7 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
weightSeq.head.head.length should be(kernelWidth)
weightSeq.head.head.head.length should be(depth)

val biasSeq: Seq[Tensor] /* filterSize */ = bias.unzip(dimension = 0)
val biasSeq: Seq[Tensor] /* filterSize */ = bias.split(dimension = 0)

val outputChannels: Seq[Tensor] = weightSeq.view
.zip(biasSeq)
Expand All @@ -197,7 +197,7 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
biasPerFilter.broadcast(Array(batchSize, height, width)) + summands.reduce(_ + _)
}

zip(outputChannels)
join(outputChannels)
case _ =>
throw new IllegalArgumentException
}
Expand Down Expand Up @@ -307,7 +307,7 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {
val Array(`j`, k) = matrix2.shape
val product = matrix1.broadcast(Array(i, j, k)) * matrix2.reshape(Array(1, j, k)).broadcast(Array(i, j, k))

product.unzip(1).reduce(_ + _)
product.split(1).reduce(_ + _)

}

Expand Down Expand Up @@ -338,10 +338,10 @@ class TensorsSpec extends AsyncFreeSpec with Matchers {

def matrixMultiply(matrix1: Tensor, matrix2: Tensor): Tensor = {

val columns1 = matrix1.unzip(1)
val columns1 = matrix1.split(1)

Tensor.zip(matrix2.unzip(1).map { column2: Tensor =>
(columns1 zip column2.unzip(0))
Tensor.join(matrix2.split(1).map { column2: Tensor =>
(columns1 zip column2.split(0))
.map {
case (l: Tensor, r: Tensor) =>
l * r.broadcast(l.shape)
Expand Down
8 changes: 4 additions & 4 deletions Trees/src/main/scala/com/thoughtworks/compute/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ object Trees {
protected def erasedExport(foreignCategory: Category, context: ExportContext): Category#Term = {
def foreignTerm = {
val foreignElements = elementTrees.map(_.export(foreignCategory, context))
foreignCategory.tuple.zip(foreignElements: _*)
foreignCategory.tuple.join(foreignElements: _*)
}
context.asScala.getOrElseUpdate(this, foreignTerm)
}
Expand All @@ -984,7 +984,7 @@ object Trees {

protected def erasedExport(foreignCategory: Category, context: ExportContext) = {
context.asScala
.getOrElseUpdate(this, tuple.export(foreignCategory, context).unzip.apply(index))
.getOrElseUpdate(this, tuple.export(foreignCategory, context).split.apply(index))

}

Expand All @@ -1004,7 +1004,7 @@ object Trees {

val length: Int

def unzip: Seq[Element] = {
def split: Seq[Element] = {
new IndexedSeq[Element] {
def length = thisTuple.length
def apply(index: Int): Element = {
Expand Down Expand Up @@ -1037,7 +1037,7 @@ object Trees {
tupleTermFactory[element.ThisTerm].newInstance(element, length, parameterTree)
}

def zip[Element0 <: ValueTerm](elements: Element0*): TupleTerm { type Element = Element0 } = {
def join[Element0 <: ValueTerm](elements: Element0*): TupleTerm { type Element = Element0 } = {
val elementTrees = elements.map(_.tree.asInstanceOf[Tree { type TermIn[C <: Category] = Element0#TermIn[C] }])
val zipTree = Concatenate[Element0](elementTrees)
tupleTermFactory[Element0].newInstance(elements.head.valueType, elements.length, zipTree)
Expand Down
10 changes: 5 additions & 5 deletions Trees/src/test/scala/com/thoughtworks/compute/TreesSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ final class TreesSpec extends FreeSpec with Matchers {
"tuple.zip" - {
"reflexive" in {
reflexive(
trees.tuple.zip(
trees.tuple.join(
trees.float.parameter("my_id"),
trees.float.literal(2.0f),
trees.float.literal(3.0f)
Expand All @@ -65,12 +65,12 @@ final class TreesSpec extends FreeSpec with Matchers {

"sameStructuralDifferentParameterName" in {
sameStructuralDifferentParameterName(
trees.tuple.zip(
trees.tuple.join(
trees.float.parameter("my_id1"),
trees.float.parameter("my_id2"),
trees.float.literal(0.0f)
),
trees.tuple.zip(
trees.tuple.join(
trees.float.parameter("my_id2"),
trees.float.parameter("my_id3"),
trees.float.literal(0.0f)
Expand All @@ -80,11 +80,11 @@ final class TreesSpec extends FreeSpec with Matchers {

"differentStructural" in {
differentStructural(
trees.tuple.zip(
trees.tuple.join(
trees.float.literal(1.0f),
trees.float.literal(0.0f)
),
trees.tuple.zip(
trees.tuple.join(
trees.float.literal(0.0f),
trees.float.literal(1.0f)
)
Expand Down
22 changes: 11 additions & 11 deletions benchmarks/src/jmh/scala/com/thoughtworks/compute/benchmarks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ object benchmarks {
val Array(i, j) = matrix1.shape
if (i >= unrollThreshold) {
// unroll j and k
val columns1 = matrix1.unzip(1)
Tensor.zip(matrix2.unzip(1).map { column2: Tensor =>
(columns1 zip column2.unzip(0))
val columns1 = matrix1.split(1)
Tensor.join(matrix2.split(1).map { column2: Tensor =>
(columns1 zip column2.split(0))
.map {
case (l: Tensor, r: Tensor) =>
l * r.broadcast(l.shape)
Expand All @@ -100,7 +100,7 @@ object benchmarks {
// unroll only j
val Array(`j`, k) = matrix2.shape
val product = matrix1.broadcast(Array(i, j, k)) * matrix2.reshape(Array(1, j, k)).broadcast(Array(i, j, k))
product.unzip(1).reduce(_ + _)
product.split(1).reduce(_ + _)
}
}

Expand Down Expand Up @@ -380,7 +380,7 @@ object benchmarks {
case Array(kernelHeight, kernelWidth, `depth`, filterSize) =>
bias.shape match {
case Array(`filterSize`) =>
val inputSeq: Seq[Tensor /* batchSize × height × width */ ] = input.unzip(dimension = 3)
val inputSeq: Seq[Tensor /* batchSize × height × width */ ] = input.split(dimension = 3)

if (inputSeq.size != depth) {
throw new IllegalArgumentException
Expand All @@ -393,27 +393,27 @@ object benchmarks {
}

val weightSeq: Seq[Seq[Seq[Seq[Tensor]]]] /* filterSize × kernelHeight × kernelWidth × depth */ =
weight.unzip(dimension = 3).map { khKwD =>
weight.split(dimension = 3).map { khKwD =>
khKwD.shape match {
case Array(kernelHeight, kernelWidth, depth) =>
case _ =>
throw new IllegalArgumentException
}

khKwD.unzip(dimension = 0).map { kwD =>
khKwD.split(dimension = 0).map { kwD =>
kwD.shape match {
case Array(kernelWidth, depth) =>
case _ =>
throw new IllegalArgumentException
}

kwD.unzip(dimension = 0).map { d =>
kwD.split(dimension = 0).map { d =>
d.shape match {
case Array(depth) =>
case _ =>
throw new IllegalArgumentException
}
d.unzip(dimension = 0)
d.split(dimension = 0)
}
}
}
Expand All @@ -428,7 +428,7 @@ object benchmarks {
throw new IllegalArgumentException
}

val biasSeq: Seq[Tensor] /* filterSize */ = bias.unzip(dimension = 0)
val biasSeq: Seq[Tensor] /* filterSize */ = bias.split(dimension = 0)

val outputChannels: Seq[Tensor] = weightSeq.view
.zip(biasSeq)
Expand All @@ -454,7 +454,7 @@ object benchmarks {
biasPerFilter.broadcast(Array(batchSize, height, width)) + summands.reduce(_ + _)
}

Tensor.zip(outputChannels)
Tensor.join(outputChannels)
case _ =>
throw new IllegalArgumentException
}
Expand Down