58 changes: 37 additions & 21 deletions src/utils.jl
@@ -80,10 +80,12 @@ julia> Flux.glorot_uniform(2, 3)

[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
"""
glorot_uniform(rng::AbstractRNG, dims...) = (rand(rng, Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
glorot_uniform(dims...) = glorot_uniform(rng_from_array(), dims...)
glorot_uniform(rng::AbstractRNG, dims::Integer...) = (rand(rng, Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
glorot_uniform(dims::Integer...) = glorot_uniform(rng_from_array(), dims...)
glorot_uniform(rng::AbstractRNG) = (dims...) -> glorot_uniform(rng, dims...)

ChainRulesCore.@non_differentiable glorot_uniform(::Any...)

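An illustrative sketch of what the tightened `dims::Integer...` signature accepts (assuming the exported `Flux.glorot_uniform` shown above):

using Flux

W = Flux.glorot_uniform(3, 4)    # 3×4 Float32 matrix, uniform on ±√(6 / (fan_in + fan_out))
@assert size(W) == (3, 4) && eltype(W) == Float32
# A tuple of dims is no longer a valid call:
# Flux.glorot_uniform((3, 4))    # MethodError
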
"""
glorot_normal([rng=GLOBAL_RNG], dims...)

@@ -113,10 +115,12 @@ julia> Flux.glorot_normal(3, 2)

[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
"""
glorot_normal(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
glorot_normal(dims...) = glorot_normal(rng_from_array(), dims...)
glorot_normal(rng::AbstractRNG, dims::Integer...) = randn(rng, Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
glorot_normal(dims::Integer...) = glorot_normal(rng_from_array(), dims...)
glorot_normal(rng::AbstractRNG) = (dims...) -> glorot_normal(rng, dims...)

ChainRulesCore.@non_differentiable glorot_normal(::Any...)

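A minimal sketch of the single-RNG method, which curries the RNG into a reusable initialiser, e.g. for a layer's `init` keyword (assuming standard Flux and Random exports):

using Flux, Random

init = Flux.glorot_normal(MersenneTwister(1))    # closure over the RNG
@assert init(3, 4) == Flux.glorot_normal(MersenneTwister(1), 3, 4)    # same seed, same draw
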
"""
kaiming_uniform([rng=GLOBAL_RNG], dims...; gain = √2)

@@ -146,14 +150,16 @@ julia> Flux.kaiming_uniform(3, 2)

[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." _Proceedings of the IEEE international conference on computer vision_. 2015.
"""
function kaiming_uniform(rng::AbstractRNG, dims...; gain = √2)
function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain = √2)
bound = Float32(√3 * gain / sqrt(first(nfan(dims...)))) # fan_in
return (rand(rng, Float32, dims...) .- 0.5f0) .* 2bound
end

kaiming_uniform(dims...; kwargs...) = kaiming_uniform(rng_from_array(), dims...; kwargs...)
kaiming_uniform(dims::Integer...; kwargs...) = kaiming_uniform(rng_from_array(), dims...; kwargs...)
kaiming_uniform(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...)

ChainRulesCore.@non_differentiable kaiming_uniform(::Any...)

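A sketch of keyword currying, mirroring the tests below: keywords given at partial-application time act as defaults that call-site keywords override:

using Flux

pk = Flux.kaiming_uniform(gain = 100)    # no dims, so this returns a partially-applied initialiser
W  = pk(8, 8)                            # drawn with the stored gain = 100
W1 = pk(8, 8, gain = 1)                  # call-site keyword overrides the stored one
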
"""
kaiming_normal([rng=GLOBAL_RNG], dims...; gain = √2)

@@ -183,14 +189,16 @@ julia> Flux.kaiming_normal(3, 2)

[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." _Proceedings of the IEEE international conference on computer vision_. 2015.
"""
function kaiming_normal(rng::AbstractRNG, dims...; gain = √2f0)
function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain = √2f0)
std = Float32(gain / sqrt(first(nfan(dims...)))) # fan_in
return randn(rng, Float32, dims...) .* std
end

kaiming_normal(dims...; kwargs...) = kaiming_normal(rng_from_array(), dims...; kwargs...)
kaiming_normal(dims::Integer...; kwargs...) = kaiming_normal(rng_from_array(), dims...; kwargs...)
kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...)

ChainRulesCore.@non_differentiable kaiming_normal(::Any...)

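A sketch of what the `@non_differentiable` mark buys, mirroring the test below and assuming Flux's re-exported `gradient`: AD treats the initialiser's output as a constant instead of differentiating through the RNG draw:

using Flux

g = gradient(x -> sum(x .* Flux.kaiming_normal(3, 4)), 5.0)[1]
@assert g isa Number    # no attempt to push gradients into kaiming_normal itself
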
"""
truncated_normal([rng=GLOBAL_RNG], dims...; mean = 0, std = 1, lo = -2, hi = 2)

@@ -221,7 +229,7 @@ julia> round(std(Flux.truncated_normal(10^6; lo = -100, hi = 100)))
[PDF](https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf).
Department of Scientific Computing website.
"""
function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2, hi = 2)
function truncated_normal(rng::AbstractRNG, dims::Integer...; mean = 0, std = 1, lo = -2, hi = 2)
norm_cdf(x) = 0.5 * (1 + erf(x/√2))
if (mean < lo - 2 * std) || (mean > hi + 2 * std)
@warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1
@@ -237,9 +245,11 @@ function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2,
return xs
end

truncated_normal(dims...; kwargs...) = truncated_normal(rng_from_array(), dims...; kwargs...)
truncated_normal(dims::Integer...; kwargs...) = truncated_normal(rng_from_array(), dims...; kwargs...)
truncated_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...)

ChainRulesCore.@non_differentiable truncated_normal(::Any...)

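A short sketch of the keyword interface (assuming the exported `Flux.truncated_normal`): samples are confined to `[lo, hi]` and come back as Float32 even when the keyword arguments are Float64:

using Flux
using Statistics: mean

v = Flux.truncated_normal(10_000; mean = 0.0, std = 1.0, lo = -2, hi = 2)
@assert all(x -> -2 ≤ x ≤ 2, v)
@assert eltype(v) == Float32
@assert isapprox(mean(v), 0; atol = 0.05)    # loose statistical check on 10⁴ samples
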
"""
orthogonal([rng=GLOBAL_RNG], dims...; gain = 1)

@@ -307,6 +317,8 @@ end
orthogonal(dims::Integer...; kwargs...) = orthogonal(rng_from_array(), dims...; kwargs...)
orthogonal(rng::AbstractRNG; init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...)

ChainRulesCore.@non_differentiable orthogonal(::Any...)

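A sketch of the property `orthogonal` guarantees (assuming standard Flux and LinearAlgebra exports): with at least as many rows as columns, the columns are orthonormal:

using Flux, LinearAlgebra

W = Flux.orthogonal(5, 3)
@assert W' * W ≈ I    # ≈ the 3×3 identity, up to Float32 tolerance
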
"""
sparse_init([rng=GLOBAL_RNG], dims...; sparsity, std = 0.01)

@@ -336,7 +348,7 @@ julia> Flux.sparse_init(3, 2, sparsity=0.1)

[1] Martens, J, "Deep learning via Hessian-free optimization" _Proceedings of the 27th International Conference on International Conference on Machine Learning_. 2010.
"""
function sparse_init(rng::AbstractRNG, dims...; sparsity, std = 0.01)
function sparse_init(rng::AbstractRNG, dims::Integer...; sparsity, std = 0.01)
if length(dims) != 2
throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization."))
end
@@ -348,9 +360,11 @@ function sparse_init(rng::AbstractRNG, dims...; sparsity, std = 0.01)
return mapslices(shuffle, sparse_array, dims=1)
end

sparse_init(dims...; kwargs...) = sparse_init(rng_from_array(), dims...; kwargs...)
sparse_init(dims::Integer...; kwargs...) = sparse_init(rng_from_array(), dims...; kwargs...)
sparse_init(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...)

ChainRulesCore.@non_differentiable sparse_init(::Any...)

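A sketch of the sparsity contract (assuming the exported `Flux.sparse_init`): only 2-d shapes are accepted, and each column gets `ceil(Int, rows * sparsity)` exact zeros at shuffled positions:

using Flux

W = Flux.sparse_init(10, 4; sparsity = 0.3, std = 0.01)
@assert all(count(iszero, col) == 3 for col in eachcol(W))    # ceil(10 * 0.3) == 3 zeros per column
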
"""
identity_init([rng=GLOBAL_RNG], dims...; gain=1, shift=0)

@@ -415,30 +429,32 @@ julia> Flux.identity_init(3,3,2,2)
```
"""
# Assume bias
identity_init(cols; gain=1, shift=0) = zeros32(cols)
identity_init(cols::Integer; gain=1, shift=0) = zeros32(cols)

# Assume matrix multiplication
identity_init(rows, cols; gain=1, shift=0) = circshift(Matrix{Float32}(I * gain, rows,cols), shift)
identity_init(rows::Integer, cols::Integer; gain=1, shift=0) = circshift(Matrix{Float32}(I * gain, rows,cols), shift)

# Assume convolution
function identity_init(dims...; gain=1, shift=0)
function identity_init(dims::Integer...; gain=1, shift=0)
nin, nout = dims[end-1], dims[end]
centers = map(d -> cld(d, 2), dims[1:end-2])
weights = zeros32(dims)
weights = zeros32(dims...)
for i in 1:min(nin,nout)
weights[centers..., i, i] = gain
end
return circshift(weights, shift)
end

identity_init(::AbstractRNG, dims...; kwargs...) = identity_init(dims...; kwargs...)
identity_init(::AbstractRNG, dims::Integer...; kwargs...) = identity_init(dims...; kwargs...)
identity_init(; init_kwargs...) = identity_init(rng_from_array(); init_kwargs...)
identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...)

ones32(dims...) = Base.ones(Float32, dims...)
zeros32(dims...) = Base.zeros(Float32, dims...)
rand32(dims...) = Base.rand(Float32, dims...)
randn32(dims...) = Base.randn(Float32, dims...)
ChainRulesCore.@non_differentiable identity_init(::Any...)

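A sketch of the two non-trivial cases (assuming standard Flux and LinearAlgebra exports): the matrix method builds a scaled, optionally shifted identity, and the convolution method builds a centre-tap kernel that maps channel i to channel i:

using Flux, LinearAlgebra

W = Flux.identity_init(4, 4)
@assert W == Matrix{Float32}(I, 4, 4)

K = Flux.identity_init(3, 3, 2, 2)    # 3×3 spatial kernel, 2 in and 2 out channels
@assert K[2, 2, 1, 1] == 1f0 && K[2, 2, 2, 2] == 1f0    # gain placed at the spatial centre
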
ones32(dims::Integer...) = Base.ones(Float32, dims...)
zeros32(dims::Integer...) = Base.zeros(Float32, dims...)
rand32(dims::Integer...) = Base.rand(Float32, dims...)
randn32(dims::Integer...) = Base.randn(Float32, dims...)

"""
create_bias(weights, bias, size...)
79 changes: 53 additions & 26 deletions test/utils.jl
@@ -1,7 +1,7 @@
using Flux
using Flux: throttle, nfan, glorot_uniform, glorot_normal,
kaiming_normal, kaiming_uniform, orthogonal, truncated_normal,
sparse_init, stack, unstack, batch, unbatch,
sparse_init, identity_init, stack, unstack, batch, unbatch,
unsqueeze, params, loadparams!
using StatsBase: var, std
using Statistics, LinearAlgebra
@@ -69,6 +69,37 @@ end
@test nfan(2, 3, 4, 50, 60) == (2 * 3 * 4 * 50, 2 * 3 * 4 * 60) #For 3D Conv layer
end

@testset "Basics: $init" for init in [
glorot_uniform, glorot_normal,
kaiming_uniform, kaiming_normal,
orthogonal,
sparse_init,
truncated_normal,
identity_init,
]
if init == sparse_init
init = (args...) -> sparse_init(args...; sparsity=0.5)
else
# sparse_init is the only one which accepts only matrices:
@test size(init(3)) == (3,)
@test size(init(3, 4, 5)) == (3, 4, 5)
end
@test size(init(3, 4)) == (3, 4)
# only init(size...) is accepted:
@test_throws MethodError size(init((3, 4, 5))) == (3, 4, 5)

# rng, and currying:
@test size(init(MersenneTwister(1), 3, 4)) == (3, 4)
closure = init(MersenneTwister(1))
@test size(closure(3, 4)) == (3, 4)

# eltype, default Float32
@test eltype(init(3, 4)) == Float32

# @non_differentiable
@test gradient(x -> sum(x .* init(3, 4)), 5.0)[1] isa Number
end

@testset "glorot" begin
# glorot_uniform and glorot_normal should both yield a kernel with
# variance ≈ 2/(fan_in + fan_out)
@@ -78,7 +109,6 @@ end
fan_in, fan_out = nfan(dims...)
σ2 = 2 / (fan_in + fan_out)
@test 0.9σ2 < var(v) < 1.1σ2
@test eltype(v) == Float32
end
end
end
@@ -91,12 +121,10 @@ end
σ2 = sqrt(6/n_out)
@test -1σ2 < minimum(v) < -0.9σ2
@test 0.9σ2 < maximum(v) < 1σ2
@test eltype(v) == Float32

v = kaiming_normal(n_in, n_out)
σ2 = sqrt(2/n_out)
@test 0.9σ2 < std(v) < 1.1σ2
@test eltype(v) == Float32
end
end

@@ -125,54 +153,53 @@ end
@test_throws ArgumentError sparse_init(100, 100, 100, sparsity=0.1)
v = sparse_init(100, 100, sparsity=-0.1)
@test sum(v .== 0) == 0
@test eltype(v) == Float32
v = sparse_init(100, 100, sparsity=1.1)
@test sum(v .== 0) == length(v)
@test eltype(v) == Float32

for (n_in, n_out, sparsity, σ) in [(100, 100, 0.25, 0.1), (100, 400, 0.75, 0.01)]
expected_zeros = ceil(Integer, n_in * sparsity)
v = sparse_init(n_in, n_out, sparsity=sparsity, std=σ)
@test all([sum(v[:,col] .== 0) == expected_zeros for col in 1:n_out])
@test 0.9 * σ < std(v[v .!= 0]) < 1.1 * σ
@test eltype(v) == Float32
end
end

@testset "truncated_normal" begin
size = (100, 100, 100)
for (μ, σ, lo, hi) in [(0., 1, -2, 2), (0, 1, -4., 4)]
v = truncated_normal(size; mean = μ, std = σ, lo, hi)
m = truncated_normal(100, 100)
@test minimum(m) ≈ -2 atol = 0.05 # default arguments
@test maximum(m) ≈ 2 atol = 0.05
@test mean(m) ≈ 0 atol = 0.1

size100 = (100, 100, 100)
for (μ, σ, lo, hi) in [(0.0, 1, -2, 3), (1, 2, -4.0, 5.0)]
v = truncated_normal(size100...; mean = μ, std = σ, lo, hi)
@test isapprox(mean(v), μ; atol = 1f-1)
@test isapprox(minimum(v), lo; atol = 1f-1)
@test isapprox(maximum(v), hi; atol = 1f-1)
@test eltype(v) == Float32
@test isapprox(minimum(v), lo; atol = 1f-2)
@test isapprox(maximum(v), hi; atol = 1f-2)
@test eltype(v) == Float32 # despite some Float64 arguments
end
for (μ, σ, lo, hi) in [(6, 2, -100., 100), (7., 10, -100, 100)]
v = truncated_normal(size...; mean = μ, std = σ, lo, hi)
for (μ, σ, lo, hi) in [(6, 2, -100.0, 100), (-7.0, 10, -100, 100)]
v = truncated_normal(size100...; mean = μ, std = σ, lo, hi)
@test isapprox(mean(v), μ; atol = 1f-1)
@test isapprox(std(v), σ; atol = 1f-1)
@test eltype(v) == Float32
end
end

@testset "partial_application" begin
big = 1e9

partial_ku = kaiming_uniform(gain=big)
@test maximum(partial_ku(8, 8)) > big / 2
@test maximum(partial_ku(8, 8, gain=1)) < big / 2
@testset "Partial application" begin
partial_ku = kaiming_uniform(gain=1e9)
@test maximum(partial_ku(8, 8)) > 1e9 / 2
@test maximum(partial_ku(8, 8, gain=1)) < 1e9 / 2

partial_kn = kaiming_normal(gain=big)
@test maximum(partial_kn(8, 8)) > big / 2
@test maximum(partial_kn(8, 8, gain=1)) < big / 2
partial_kn = kaiming_normal(gain=1e9)
@test maximum(partial_kn(8, 8)) > 1e9 / 2
@test maximum(partial_kn(8, 8, gain=1)) < 1e9 / 2

partial_si = sparse_init(sparsity=1)
@test maximum(partial_si(8, 8)) == 0
@test maximum(partial_si(8, 8, sparsity=0)) > 0
end

@testset "identity_init" begin
import Flux: identity_init

@testset "Basic" begin
partial = identity_init(gain=3)