
Commit

Merge pull request #1867 from ShoofLLC/master
Updated Dropout for more input types.
ToucheSir committed Feb 8, 2022
2 parents 7b56813 + ccb328c commit 1930966
Showing 3 changed files with 8 additions and 1 deletion.
1 change: 1 addition & 0 deletions NEWS.md
@@ -6,6 +6,7 @@ been removed in favour of MLDatasets.jl.
 * `params` is not exported anymore since it is a common name and is also exported by Distributions.jl
 * `flatten` is not exported anymore due to clash with Iterators.flatten.
 * Remove Juno.jl progress bar support as it is now obsolete.
+* Improved compatibility of Dropout with Int and Complex types.

 ## v0.12.10
 * `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)
3 changes: 2 additions & 1 deletion src/layers/normalise.jl
@@ -49,7 +49,8 @@ dropout_mask(rng, x::CuArray, p; kwargs...) =
   throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only support CUDA.RNG for CuArrays."))
 dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
 function _dropout_mask(rng, x, p; dims=:)
-  y = rand!(rng, similar(x, _dropout_shape(x, dims)))
+  realfptype = float(real(eltype(x)))
+  y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims)))
   y .= _dropout_kernel.(y, p, 1 - p)
   return y
 end
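
For context, a minimal sketch (not part of this commit) of why the mask is now sampled into a real floating-point buffer: `similar(x, _dropout_shape(x, dims))` inherits the eltype of `x`, so a Complex input produced a complex mask buffer whose samples cannot be ordered against the dropout probability `p` inside `_dropout_kernel`, and an Int input made `rand!` draw arbitrary machine integers rather than uniform samples in [0, 1). Converting via `float(real(eltype(x)))` avoids both:

    using Random

    x = ComplexF64[1 + 0im, 2 + 1im, 3 + 3im]

    # Old buffer: inherits eltype(x), so rand! would fill it with complex
    # values that cannot be compared against the dropout probability p.
    # similar(x, size(x)) isa Vector{ComplexF64}

    # New buffer: the real floating-point counterpart of eltype(x).
    realfptype = float(real(eltype(x)))   # ComplexF64 -> Float64 (Int -> Float64)
    y = rand!(Random.default_rng(), similar(x, realfptype, size(x)))  # uniform in [0, 1)
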
5 changes: 5 additions & 0 deletions test/layers/normalisation.jl
@@ -5,6 +5,11 @@ evalwgrad(f, x...) = pullback(f, x...)[1]

 @testset "Dropout" begin
   @testset for rng_kwargs in ((), (; rng = MersenneTwister()))
+    x = [1.0+0im,2.0+1im,3.0+3im]
+    @test x == Dropout(0.1; rng_kwargs...)(x)
+    @test x == evalwgrad(Dropout(0; rng_kwargs...), x)
+    @test zero(x) == evalwgrad(Dropout(1; rng_kwargs...), x)
+
     x = [1.,2.,3.]
     @test x == Dropout(0.1; rng_kwargs...)(x)
     @test x == evalwgrad(Dropout(0; rng_kwargs...), x)
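
The added tests above exercise the Complex path; a quick ad-hoc check of the Int path mentioned in NEWS.md (not part of this diff, and assuming Flux's `trainmode!` to force the layer active outside a gradient call) might look like:

    using Flux, Test

    m = Dropout(0.5)
    Flux.trainmode!(m)   # Dropout is a no-op outside gradient calls unless forced active

    xi = [1, 2, 3]       # Int input
    yi = m(xi)           # the Float64 mask promotes the output eltype
    @test eltype(yi) == Float64
    @test all(y -> y == 0 || y in (2 .* xi), yi)  # kept entries scale by 1/(1 - p) = 2
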
