Skip to content

Commit

Permalink
Merge a746a77 into 224e553
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Jul 1, 2020
2 parents 224e553 + a746a77 commit d88ea37
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 0 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.7.2"
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Expand Down
3 changes: 3 additions & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Reexport
using Base.Broadcast: materialize, materialize!, broadcasted, Broadcasted, broadcastable
using LinearAlgebra
using LinearAlgebra.BLAS
using Random
using Requires
using Statistics

Expand Down Expand Up @@ -35,6 +36,8 @@ include("rulesets/LinearAlgebra/dense.jl")
include("rulesets/LinearAlgebra/structured.jl")
include("rulesets/LinearAlgebra/factorization.jl")

include("rulesets/Random/random.jl")

# Note: The following is only required because package authors sometimes do not
# declare their own rules using `ChainRulesCore.jl`. For arguably good reasons.
# So we define them here for them.
Expand Down
8 changes: 8 additions & 0 deletions src/rulesets/Random/random.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
frule(Δargs, ::typeof(MersenneTwister), args...) = MersenneTwister(args...), Zero()

function rrule(::typeof(MersenneTwister), args...)
function MersenneTwister_rrule(ΔΩ)
return (NO_FIELDS, map(_ -> Zero(), args)...)
end
return MersenneTwister(args...), MersenneTwister_rrule
end
22 changes: 22 additions & 0 deletions test/rulesets/Random/random.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
@testset "random" begin
@testset "MersenneTwister" begin
@testset "no args" begin
rng, dΩ = frule((5.0,), MersenneTwister)
@test rng isa MersenneTwister
@testisa Zero

rng, pb = rrule(MersenneTwister)
@test rng isa MersenneTwister
@test first(pb(10)) isa Zero
end
@testset "unary" begin
rng, dΩ = frule((5.0, 4.0), MersenneTwister, 123)
@test rng isa MersenneTwister
@testisa Zero

rng, pb = rrule(MersenneTwister, 123)
@test rng isa MersenneTwister
@test all(map(x -> x isa Zero, pb(10)))
end
end
end
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ println("Testing ChainRules.jl")

print(" ")

@testset "Random" begin
include(joinpath("rulesets", "Random", "random.jl"))
end

print(" ")

@testset "packages" begin
include(joinpath("rulesets", "packages", "NaNMath.jl"))
include(joinpath("rulesets", "packages", "SpecialFunctions.jl"))
Expand Down

0 comments on commit d88ea37

Please sign in to comment.