Skip to content

Commit

Permalink
Merge pull request #13 from SBuercklin/sbuercklin/simple-array-rules
Browse files Browse the repository at this point in the history
Sbuercklin/simple array rules
  • Loading branch information
SBuercklin committed Aug 9, 2022
2 parents f6bb195 + 53f5884 commit dde0633
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 1 deletion.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.1.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[compat]
Expand Down
5 changes: 4 additions & 1 deletion src/UnitfulChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ module UnitfulChainRules

using Unitful
using Unitful: Quantity, Units, NoDims, FreeUnits
using ChainRulesCore: NoTangent, @scalar_rule, @thunk
using ChainRulesCore
import ChainRulesCore: rrule, frule, ProjectTo
using LinearAlgebra

const REALCOMPLEX = Union{Real, Complex}

Expand All @@ -18,4 +19,6 @@ include("./trig.jl") # sin, cos, tan, etc for degrees

include("./math.jl") # other math

include("./arraymath.jl") # Simple scalar-array math

end # module
66 changes: 66 additions & 0 deletions src/arraymath.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
const CommutativeMulQuantity = Quantity{T,D,U} where {T<:Union{Real,Complex}, D, U}
const CommMulVal = Union{Real, Complex, CommutativeMulQuantity}

# Reference: https://github.com/JuliaDiff/ChainRules.jl/blob/148fa8875725a19cf658405609fa1a56671d0cbd/src/rulesets/Base/arraymath.jl

# Defines *, / for the pairs where:
# 1. The scalar is a commutative/mul quantity and the array is real, complex, or a comm/mul quantity
# 2. The scalar is a commutative/mul number and the array is a comm/mul quantity
# We have to be careful defining this so that we always have a Quantity in the signature, otherwise
# we overwrite methods from ChainRules.jl
for (s_type,a_type) in (
(:CommutativeMulQuantity, :(<:CommMulVal)),
(:(Union{Real,Complex}), :(<:CommutativeMulQuantity))
)
@eval function rrule(
::typeof(*), A::$(s_type), B::AbstractArray{$(a_type)}
)
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function times_pullback(ȳ)
= unthunk(ȳ)
return (
NoTangent(),
@thunk(project_A(dot(Ȳ, B)')),
InplaceableThunk(
-> mul!(X̄, conj(A), Ȳ, true, true),
@thunk(project_B(A' * Ȳ)),
)
)
end
return A * B, times_pullback
end

@eval function rrule(
::typeof(*), B::AbstractArray{$(a_type)}, A::$(s_type)
)
project_A = ProjectTo(A)
project_B = ProjectTo(B)
function times_pullback(ȳ)
= unthunk(ȳ)
return (
NoTangent(),
InplaceableThunk(
-> mul!(X̄, conj(A), Ȳ, true, true),
@thunk(project_B(A' * Ȳ)),
),
@thunk(project_A(dot(Ȳ, B)')),
)
end
return A * B, times_pullback
end

@eval function rrule(::typeof(/), A::AbstractArray{$(a_type)}, b::$(s_type))
Y = A/b
function slash_pullback_scalar(ȳ)
= unthunk(ȳ)
Athunk = InplaceableThunk(
dA -> dA .+=./ conj(b),
@thunk(Ȳ / conj(b)),
)
bthunk = @thunk(-dot(A,Ȳ) / conj(b^2))
return (NoTangent(), Athunk, bthunk)
end
return Y, slash_pullback_scalar
end
end
46 changes: 46 additions & 0 deletions test/arraymath.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
using Unitful
using UnitfulChainRules

using Zygote

using Random
rng = VERSION >= v"1.7" ? Random.Xoshiro(0x0451) : Random.MersenneTwister(0x0451)

@testset "Array-Scalar Multiplication" begin
for (a_unit, s_unit) in ((1.0,oneunit(1.0u"m")), (oneunit(1.0u"m"), 1.0))
A = randn(rng, 5) * a_unit
s = randn(rng) * s_unit

@testset "A * s ($a_unit, $s_unit)" begin
Ω, pb = Zygote.pullback(*, A, s)

@test Ω A * s
@test all(first(pb(one.(Ω))) .≈ s)
@test last(pb(one.(Ω))) sum(A)
end

@testset "s * A ($s_unit, $a_unit)" begin

Ω, pb = Zygote.pullback(*, s, A)

@test Ω s * A
@test all(last(pb(one.(Ω))) .≈ s)
@test first(pb(one.(Ω))) sum(A)
end
end
end

@testset "Array-Scalar Division" begin
for (a_unit, s_unit) in ((1.0,oneunit(1.0u"m")), (oneunit(1.0u"m"), 1.0))
@testset "($a_unit, $s_unit) division" begin
A = randn(rng, 5) * a_unit
s = randn(rng) * s_unit

Ω, pb = Zygote.pullback(/, A, s)

@test Ω A / s
@test all(first(pb(one.(Ω))) .≈ inv(s))
@test last(pb(one.(Ω))) -sum(A)/s^2
end
end
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,8 @@ end

@safetestset "Math" begin
include("./math.jl")
end

@safetestset "Array Math" begin
include("./arraymath.jl")
end

2 comments on commit dde0633

@SBuercklin
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/65973

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.2 -m "<description of version>" dde0633c84c1c4a3c65216ec5c8242259bbcf538
git push origin v0.1.2

Please sign in to comment.