Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ celu
elu
gelu
gelu_tanh
gelu_sigmoid
gelu_erf
hardsigmoid
sigmoid_fast
Expand Down
42 changes: 37 additions & 5 deletions src/activations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

ACTIVATIONS = [
:σ, :hardσ, :hardtanh, :relu,
:leakyrelu, :relu6, :rrelu, :elu, :gelu_tanh, :gelu_erf, :swish, :hardswish, :selu,
:leakyrelu, :relu6, :rrelu, :elu, :gelu_tanh, :gelu_sigmoid, :gelu_erf, :swish, :hardswish, :selu,
:celu, :softplus, :softsign, :logσ, :logcosh,
:mish, :tanhshrink, :softshrink, :trelu, :lisht,
:tanh_fast, :sigmoid_fast,
Expand Down Expand Up @@ -305,6 +305,10 @@ deriv_elu(Ω, α=1) = ifelse(Ω ≥ 0, one(Ω), Ω + oftype(Ω, α))

Activation function from ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415) using tanh approximation.

This implementation uses `tanh` which allows for better pattern matching and fusion in optimizing
compilers compared to the sigmoid-based implementation. For a potentially faster implementation
that uses `sigmoid_fast`, see [`gelu_sigmoid`](@ref).

```julia-repl
julia> lineplot(gelu_tanh, -2, 2, height=7)
┌────────────────────────────────────────┐
Expand Down Expand Up @@ -337,16 +341,43 @@ julia> lineplot!(ans, swish)
"""
function gelu_tanh(x)
    # Tanh approximation of GELU (Hendrycks & Gimpel, arXiv:1606.08415):
    #   gelu_tanh(x) ≈ x/2 * (1 + tanh(√(2/π) * (x + 0.044715x³)))
    # The tanh form is kept (rather than the equivalent sigmoid form) so that
    # optimizing compilers can pattern-match and fuse it; see `gelu_sigmoid`
    # for the sigmoid-based variant.
    α = oftf(x, 0.044715)   # cubic coefficient of the approximation
    λ = oftf(x, gelu_λ)     # √(2/π), converted to the eltype of x
    return x/2 * (1 + tanh_fast(λ * (x + α * x^3)))
end

# Scale factors for the tanh-approximation of GELU.
# gelu_λ  = √(2/π)  — used directly in the tanh form.
# gelu_2λ = √(8/π) = 2√(2/π) — used in the sigmoid form, since
# tanh(z) = 2σ(2z) − 1 doubles the argument.
const gelu_λ = √(2 / π)
const gelu_2λ = √(8 / π)

function deriv_gelu_tanh(x)
    # Derivative of gelu_tanh(x) = x/2 * (1 + tanh(z)) with z = λ·x·(1 + αx²):
    #   d/dx = (1 + tanh(z))/2 + x/2 * sech²(z) * dz/dx
    # where dz/dx = λ(1 + 3αx²) = λ(t + 2αx²) with t = 1 + αx².
    α = oftf(x, 0.044715)
    α2 = oftf(x, 0.08943)      # 2α, the coefficient in dz/dx = λ(t + 2αx²)
    λ = oftf(x, gelu_λ)        # √(2/π)
    x2 = x * x
    t = muladd(x2, α, one(x))  # t = 1 + αx²
    z = λ * x * t
    Ω = tanh_fast(z)
    sech2 = 1 - Ω^2            # sech²(z) = 1 − tanh²(z)
    return (1 + Ω)/2 + x * λ * muladd(x2, α2, t) * sech2 / 2
end

"""
gelu_sigmoid(x) = x * σ(√(8/π) * (x + 0.044715x^3))

Alternative implementation of the GELU activation function using `sigmoid` instead of `tanh`.
This is mathematically equivalent to [`gelu_tanh`](@ref) but may be faster in some cases.

The sigmoid-based implementation may prevent pattern matching and fusion in some optimizing
compilers. Use [`gelu_tanh`](@ref) if you need better compiler optimization support.

See ["Gaussian Error Linear Units"](https://arxiv.org/abs/1606.08415).
"""
function gelu_sigmoid(x)
    # Sigmoid form of the tanh-approximated GELU: since tanh(z) = 2σ(2z) − 1,
    #   x/2 * (1 + tanh(λ(x + αx³))) == x * σ(2λ · x · (1 + αx²))
    a = oftf(x, 0.044715)              # cubic coefficient of the approximation
    scale = oftf(x, gelu_2λ)           # 2λ = √(8/π)
    cubic = muladd(x^2, a, one(x))     # 1 + 0.044715x²
    return x * sigmoid_fast(scale * x * cubic)
end

function deriv_gelu_sigmoid(x)
α = oftf(x, 0.044715)
α2 = oftf(x, 0.08943)
λλ = oftf(x, gelu_2λ)
Expand Down Expand Up @@ -896,6 +927,7 @@ UNARY_ACTS = [ # f, dfdx
# rrelu is random, can't write a rule.
(:elu, :(deriv_elu(Ω))),
(:gelu_tanh, :(deriv_gelu_tanh(x))),
(:gelu_sigmoid, :(deriv_gelu_sigmoid(x))),
(:gelu_erf, :(deriv_gelu_erf(x))),
(:swish, :(Ω + sigmoid_fast(x) * (1 - Ω))),
(:hardswish, :(deriv_hardswish(x))),
Expand Down
9 changes: 6 additions & 3 deletions test/activations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI
@test elu(0.0) == 0.0
@test gelu(0.0) == 0.0
@test gelu_tanh(0.0) == 0.0
@test gelu_sigmoid(0.0) == 0.0
@test gelu_erf(0.0) == 0.0
@test swish(0.0) == 0.0
@test hardswish(0.0) == 0.0
Expand All @@ -37,8 +38,9 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI
@test relu6(1.0) == 1.0
@test rrelu(1.0) == 1.0
@test elu(1.0) == 1.0
@test gelu(1.0) == 0.8411919906082768
@test gelu_tanh(1.0) == 0.8411919906082768
@test gelu(1.0) ≈ 0.8411919906082768
@test gelu_tanh(1.0) ≈ 0.8411919906082768
@test gelu_sigmoid(1.0) ≈ 0.8411919906082768
@test gelu_erf(1.0) == 0.8413447460685429
@test swish(1.0) == sigmoid(1.0)
@test hardswish(1.0) == hardsigmoid(1.0)
Expand All @@ -63,6 +65,7 @@ BINARY_ACTIVATIONS = filter(f -> hasmethod(f, Tuple{Float64, Float64}), ACTIVATI
@test elu(-1.0) == exp(-1.0) - 1.0
@test gelu(-1.0) ≈ -0.15880800939172324
@test gelu_tanh(-1.0) ≈ -0.15880800939172324
@test gelu_sigmoid(-1.0) ≈ -0.15880800939172324
@test gelu_erf(-1.0) == -0.15865525393145707
@test swish(-1.0) == -sigmoid(-1.0)
@test hardswish(-1.0) == -hardsigmoid(-1.0)
Expand Down Expand Up @@ -120,7 +123,7 @@ end
a == softsign && continue
@test !isnan(a(Inf32))

a in [gelu, gelu_tanh, gelu_erf, swish, hardswish, logcosh, mish] && continue
a in [gelu, gelu_tanh, gelu_sigmoid, gelu_erf, swish, hardswish, logcosh, mish] && continue
@test !isnan(a(-Inf32))
end
end
Expand Down
Loading