Skip to content

Commit

Permalink
Merge pull request #484 from JuliaDiff/ox/norngcon
Browse files Browse the repository at this point in the history
Make construtors of all AbstractRNGs as nondifferentiable
  • Loading branch information
oxinabox committed Jul 30, 2021
2 parents cf55350 + ef2d549 commit 36508af
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 14 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.3.0"
version = "1.4.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
8 changes: 4 additions & 4 deletions src/rulesets/Random/random.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
frule(Δargs, ::Type{MersenneTwister}, args...) = MersenneTwister(args...), ZeroTangent()
frule(Δargs, T::Type{<:AbstractRNG}, args...) = T(args...), ZeroTangent()

function rrule(::Type{MersenneTwister}, args...)
function MersenneTwister_pullback(ΔΩ)
function rrule(T::Type{<:AbstractRNG}, args...)
function AbstractRNG_pullback(ΔΩ)
return (NoTangent(), map(_ -> ZeroTangent(), args)...)
end
return MersenneTwister(args...), MersenneTwister_pullback
return T(args...), AbstractRNG_pullback
end

@non_differentiable Broadcast.broadcastable(::AbstractRNG)
Expand Down
26 changes: 17 additions & 9 deletions test/rulesets/Random/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,26 @@ end
Random.rand(d::NormalDistribution) = d.μ + d.σ*randn()

@testset "random" begin
@testset "MersenneTwister" begin
rng_types = [MersenneTwister]
isdefined(Random, :Xoshiro) && push!(rng_types, getfield(Random, :Xoshiro))

@testset "$Rng" for Rng in rng_types
@testset "no args" begin
rng, dΩ = frule((5.0,), MersenneTwister)
@test rng isa MersenneTwister
rng, dΩ = frule((5.0,), Rng)
@test rng isa Rng
@testisa ZeroTangent

rng, pb = rrule(MersenneTwister)
@test rng isa MersenneTwister
rng, pb = rrule(Rng)
@test rng isa Rng
@test first(pb(10)) isa typeof(NoTangent())
end
@testset "unary" begin
rng, dΩ = frule((5.0, 4.0), MersenneTwister, 123)
@test rng isa MersenneTwister
rng, dΩ = frule((5.0, 4.0), Rng, 123)
@test rng isa Rng
@testisa ZeroTangent

rng, pb = rrule(MersenneTwister, 123)
@test rng isa MersenneTwister
rng, pb = rrule(Rng, 123)
@test rng isa Rng
@test all(map(x -> x isa AbstractZero, pb(10)))
end
end
Expand All @@ -37,6 +40,11 @@ Random.rand(d::NormalDistribution) = d.μ + d.σ*randn()
((Float32,(2,2)), Matrix{<:Float32}),
((2,2), Matrix{<:Float64}),
]
if isdefined(Random, :Xoshiro)
Xoshiro = getfield(Random, :Xoshiro)
push!(non_differentiables, ((Xoshiro(123),), Float64))
push!(non_differentiables, ((Xoshiro(123),2,2), Matrix{<:Float64}))
end

for (args, xType) in non_differentiables
x, dΩ = frule((ZeroTangent(), randn(args...)), rand, args...)
Expand Down

2 comments on commit 36508af

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/41876

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.4.0 -m "<description of version>" 36508afd519cf12647098aa5e69e92f869f7a2a1
git push origin v1.4.0

Please sign in to comment.