From 39af0276f08372c888bcfc07889c42b48dcc769c Mon Sep 17 00:00:00 2001 From: Victor Guerra Date: Mon, 11 Jan 2021 09:58:38 +0100 Subject: [PATCH 1/2] [AutoDiff] Registers VJPs for `FloatingPoint.[maximum|minimum]` Resolves TF-1134. --- .../FloatingPointDifferentiation.swift.gyb | 24 +++++++++++++++++++ test/AutoDiff/stdlib/floating_point.swift.gyb | 16 +++++++++++++ 2 files changed, 40 insertions(+) diff --git a/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb b/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb index f2ff2fcbff530..2595cbb277230 100644 --- a/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb +++ b/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb @@ -282,4 +282,28 @@ where let y = squareRoot() return (y, { v in v / (2 * y) }) } + + @inlinable + @derivative(of: minimum) + static func _vjpMinimum(_ x: Self, _ y: Self) -> ( + value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector) + ) { + func pullback(_ v: TangentVector) -> (TangentVector, TangentVector) { + if x <= y || y.isNaN { return (v, .zero) } + return (.zero, v) + } + return (value: Self.minimum(x, y), pullback: pullback) + } + + @inlinable + @derivative(of: maximum) + static func _vjpMaximum(_ x: Self, _ y: Self) -> ( + value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector) + ) { + func pullback(_ v: TangentVector) -> (TangentVector, TangentVector) { + if x > y || y.isNaN { return (v, .zero) } + return (.zero, v) + } + return (value: Self.maximum(x, y), pullback: pullback) + } } diff --git a/test/AutoDiff/stdlib/floating_point.swift.gyb b/test/AutoDiff/stdlib/floating_point.swift.gyb index c1eef2f03a8b5..512070963df7a 100644 --- a/test/AutoDiff/stdlib/floating_point.swift.gyb +++ b/test/AutoDiff/stdlib/floating_point.swift.gyb @@ -83,6 +83,22 @@ FloatingPointDerivativeTests.test("${Self}.addingProduct") { expectEqual((1, 2, 3), gradient(at: ${Self}(10), 3, 2, in: { $0.addingProduct($1, $2) })) } +FloatingPointDerivativeTests.test("${Self}.minimum") { + expectEqual((1.0, 0.0), gradient(at: ${Self}(1), ${Self}(2), in : { ${Self}.minimum($0, $1) })) + expectEqual((1.0, 0.0), gradient(at: ${Self}(1), ${Self}(1), in : { ${Self}.minimum($0, $1) })) + expectEqual((0.0, 1.0), gradient(at: ${Self}(2), ${Self}(1), in : { ${Self}.minimum($0, $1) })) + expectEqual((1.0, 0.0), gradient(at: ${Self}(1), .nan, in : { ${Self}.minimum($0, $1) })) + expectEqual((0.0, 1.0), gradient(at: .nan, ${Self}(1), in : { ${Self}.minimum($0, $1) })) +} + +FloatingPointDerivativeTests.test("${Self}.maximum") { + expectEqual((0.0, 1.0), gradient(at: ${Self}(1), ${Self}(2), in : { ${Self}.maximum($0, $1) })) + expectEqual((0.0, 1.0), gradient(at: ${Self}(1), ${Self}(1), in : { ${Self}.maximum($0, $1) })) + expectEqual((1.0, 0.0), gradient(at: ${Self}(2), ${Self}(1), in : { ${Self}.maximum($0, $1) })) + expectEqual((1.0, 0.0), gradient(at: ${Self}(1), .nan, in : { ${Self}.maximum($0, $1) })) + expectEqual((0.0, 1.0), gradient(at: .nan, ${Self}(1), in : { ${Self}.maximum($0, $1) })) +} + %if Self == 'Float80': #endif %end From bc40409420a2c5fa688fcc0347e44d094ee53f50 Mon Sep 17 00:00:00 2001 From: Victor Guerra Date: Wed, 13 Jan 2021 11:58:40 +0100 Subject: [PATCH 2/2] Simplifying control flow within pullbacks. To avoid executing control flow code twice ( once in VJP and once in the pullback ). --- .../FloatingPointDifferentiation.swift.gyb | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb b/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb index 2595cbb277230..8d784c1ae1274 100644 --- a/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb +++ b/stdlib/public/Differentiation/FloatingPointDifferentiation.swift.gyb @@ -288,11 +288,8 @@ where static func _vjpMinimum(_ x: Self, _ y: Self) -> ( value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector) ) { - func pullback(_ v: TangentVector) -> (TangentVector, TangentVector) { - if x <= y || y.isNaN { return (v, .zero) } - return (.zero, v) - } - return (value: Self.minimum(x, y), pullback: pullback) + if x <= y || y.isNaN { return (x, { v in (v, .zero) }) } + return (y, { v in (.zero, v) }) } @inlinable @@ -300,10 +297,7 @@ where static func _vjpMaximum(_ x: Self, _ y: Self) -> ( value: Self, pullback: (TangentVector) -> (TangentVector, TangentVector) ) { - func pullback(_ v: TangentVector) -> (TangentVector, TangentVector) { - if x > y || y.isNaN { return (v, .zero) } - return (.zero, v) - } - return (value: Self.maximum(x, y), pullback: pullback) + if x > y || y.isNaN { return (x, { v in (v, .zero) }) } + return (y, { v in (.zero, v) }) } }