Skip to content

Commit

Permalink
[TF] Add bullet operator (•) for matrix multiplication (#17173)
Browse files Browse the repository at this point in the history
Add `•` operator for matmul, and remove `Tensor.dot` and `⊗` completely.
  • Loading branch information
rxwei committed Jun 14, 2018
1 parent d2bf0dc commit b876a3d
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 25 deletions.
13 changes: 3 additions & 10 deletions stdlib/public/TensorFlow/Ops.swift
Expand Up @@ -265,21 +265,14 @@ public func matmul<Scalar : Numeric>(
return Raw.matMul(left, right)
}

infix operator : MultiplicationPrecedence
infix operator : MultiplicationPrecedence

public extension Tensor where Scalar : Numeric {
@_inlineable @inline(__always)
@available(*, renamed: "matmul(_:_:)")
func dot(_ other: Tensor) -> Tensor {
return matmul(self, other)
}

/// Performs matrix multiplication between two tensors and produces the
/// result.
@_inlineable @inline(__always)
@available(*, renamed: "matmul(_:_:)")
static func (lhs: Tensor, rhs: Tensor) -> Tensor {
return lhs.dot(rhs)
static func (lhs: Tensor, rhs: Tensor) -> Tensor {
return matmul(lhs, rhs)
}
}

Expand Down
12 changes: 6 additions & 6 deletions test/TensorFlow/crashers.swift
Expand Up @@ -33,7 +33,7 @@ public func postdom_crash1(w1: Tensor<Float>, inputBatch: Tensor<Float>) {
// expected-warning @-2 {{'inputBatch' implicitly copied to the accelerator}}
let iterationCount = 1000
for _ in 0..<iterationCount {
_ = inputBatch w1 // expected-note 2 {{value used here}}
_ = inputBatch w1 // expected-note 2 {{value used here}}
}
}

Expand Down Expand Up @@ -91,10 +91,10 @@ public func testStraightLineXORTraining() {

// Training loop
for _ in 0..<iterationCount {
let mmul1 = inputBatch w1
let mmul1 = inputBatch w1
let l1 = mmul1 + b1
let o1 = sigmoid(l1)
let mmul2 = o1 w2
let mmul2 = o1 w2
let l2 = mmul2 + b2
let pred = sigmoid(l2)

Expand All @@ -109,15 +109,15 @@ public func testStraightLineXORTraining() {
let dL2 = dPred * pred * (1 - pred)
let dMmul2 = dL2
let dB2 = dL2
let dO1 = dMmul2 w2.transposed(withPermutations: 1, 0)
let dW2 = o1.transposed(withPermutations: 1, 0) dMmul2
let dO1 = dMmul2 w2.transposed(withPermutations: 1, 0)
let dW2 = o1.transposed(withPermutations: 1, 0) dMmul2
let dL1 = dO1 * l1 * (1 - l1)
let dMmul1 = dL1
let dB1 = dL1

// Statically detected shape mismatch!
// expected-error @+1 {{(op: 'MatMul') with input shapes: [4,2], [4,4]}}
let dW1 = inputBatch dMmul1
let dW1 = inputBatch dMmul1

// Descent
w1 -= (dW1 * learningRate)
Expand Down
14 changes: 7 additions & 7 deletions test/TensorFlow/no_copy.swift
Expand Up @@ -167,8 +167,8 @@ struct Classifier {
var b2 = Tensor<Float>(zeros: [1, 10])

func prediction(for input: Tensor<Float>) -> Tensor<Float> {
let h1 = sigmoid(input w1 + b1)
return sigmoid(h1 w2 + b2)
let h1 = sigmoid(input w1 + b1)
return sigmoid(h1 w2 + b2)
}

mutating func train(images: Tensor<Float>, labels: Tensor<Float>,
Expand All @@ -177,17 +177,17 @@ struct Classifier {
var epochCount = epochCount
repeat {
// Forward pass
let z1 = images w1 + b1
let z1 = images w1 + b1
let h1 = sigmoid(z1)
let z2 = h1 w2 + b2
let z2 = h1 w2 + b2
let pred = sigmoid(z2)

// Backward pass
let dz2 = pred - labels
let dw2 = h1.transposed(withPermutations: 1, 0) dz2
let dw2 = h1.transposed(withPermutations: 1, 0) dz2
let db2 = dz2.sum(squeezingAxes: 0)
let dz1 = dz2.dot(w2.transposed(withPermutations: 1, 0)) * h1 * (1 - h1)
let dw1 = images.transposed(withPermutations: 1, 0) dz1
let dz1 = matmul(dz2, w2.transposed(withPermutations: 1, 0)) * h1 * (1 - h1)
let dw1 = images.transposed(withPermutations: 1, 0) dz1
let db1 = dz1.sum(squeezingAxes: 0)

// Gradient descent
Expand Down
2 changes: 1 addition & 1 deletion test/TensorFlowRuntime/tensor_debuglog.swift
Expand Up @@ -26,7 +26,7 @@ TensorTests.testAllBackends("XWPlusB") {
// Shape: 2
let b = Tensor<Float>([0.5, 0.5])
// Do xW+b!
let result = x w + b
let result = x w + b
expectEqual([1, 2], result.shape)
expectEqual([12.5, 6.5], result.scalars)
}
Expand Down
2 changes: 1 addition & 1 deletion test/TensorFlowRuntime/tensor_xla_debuglog.swift
Expand Up @@ -27,7 +27,7 @@ XLATests.test("XWPlusB_XLA") {
// Shape: 2
let b = Tensor<Float>([0.5, 0.5])
// Do xW+b!
let result = x w + b
let result = x w + b
expectEqual([1, 2], result.shape)
expectEqual([12.5, 6.5], result.scalars)
#endif
Expand Down

0 comments on commit b876a3d

Please sign in to comment.