
Commit

Merge pull request #67 from LuxDL/ap/tstable
Fixes to type stability of Zygote
avik-pal committed May 11, 2024
2 parents 7ab9307 + 4bcf0a1 commit 0dea4f1
Showing 24 changed files with 184 additions and 155 deletions.
2 changes: 0 additions & 2 deletions .buildkite/pipeline.yml
@@ -29,7 +29,6 @@ steps:
- "normalization"
- "common_ops"
- "others"
- "normalization_sp"

# Downstream CUDA Tests
- group: ":telescope: Downstream CUDA"
@@ -116,7 +115,6 @@ steps:
- "normalization"
- "common_ops"
- "others"
- "normalization_sp"

# Downstream AMDGPU Tests
- group: ":telescope: Downstream AMD GPU"
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
@@ -23,7 +23,6 @@ jobs:
- "normalization"
- "common_ops"
- "others"
- "normalization_sp"
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
2 changes: 1 addition & 1 deletion .github/workflows/Downgrade.yml
@@ -17,7 +17,7 @@ jobs:
fail-fast: false
matrix:
version: ['1.10']
test_group: ['normalization', 'common_ops', 'others', 'normalization_sp']
test_group: ['normalization', 'common_ops', 'others']
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v2
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "0.3.20"
version = "0.3.21"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
4 changes: 2 additions & 2 deletions ext/LuxLibTrackercuDNNExt.jl
@@ -16,8 +16,8 @@ const TR_BNParamType = Union{

function LuxLib.batchnorm(
x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType, bias::TR_BNParamType,
running_mean::TR_BNParamType, running_var::TR_BNParamType,
σ::F=identity; momentum::Real, training::Val, epsilon::Real) where {F}
running_mean::TR_BNParamType, running_var::TR_BNParamType, training::Val,
σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F}
rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training)
# NOTE: The following returns a tracked tuple so we can't do `first` on it
x_ = LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training)[1]
4 changes: 2 additions & 2 deletions ext/LuxLibcuDNNExt/LuxLibcuDNNExt.jl
@@ -20,8 +20,8 @@ const CUDNN_BN_ARRAY_TYPE = Union{
const BNParamType = Union{Nothing, CuVector{<:Union{Float32, Float64}}}

function LuxLib.batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType,
running_mean::BNParamType, running_var::BNParamType, σ::F=identity;
momentum::Real, training::Val, epsilon::Real) where {F}
running_mean::BNParamType, running_var::BNParamType, training::Val,
σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F}
rm, rv = LuxLib._get_batchnorm_statistics(x, running_mean, running_var, training)
x_ = first(LuxLib.batchnorm_cudnn(rm, rv, scale, bias, x, momentum, epsilon, training))
return LuxLib.fast_activation!!(σ, x_), (; running_mean=rm, running_var=rv)
4 changes: 3 additions & 1 deletion src/LuxLib.jl
@@ -4,7 +4,7 @@ using PrecompileTools: @recompile_invalidations

@recompile_invalidations begin
using ArrayInterface: ArrayInterface
using ChainRulesCore: ChainRulesCore
using ChainRulesCore: ChainRulesCore, NoTangent
using FastBroadcast: @..
using FastClosures: @closure
using GPUArraysCore: GPUArraysCore, AnyGPUArray
@@ -43,6 +43,8 @@ include("api/dense.jl")
include("api/conv.jl")
include("api/fast_activation.jl")

include("deprecations.jl")

export batchnorm, groupnorm, instancenorm, layernorm, alpha_dropout, dropout
export fused_dense_bias_activation, fused_conv_bias_activation
export fast_activation!!
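
The newly included `deprecations.jl` is not shown in this diff; a plausible sketch of what such a file could contain (an assumption, not the actual contents) is a set of keyword-argument shims that forward to the new positional methods, e.g.:

    # Hypothetical sketch of src/deprecations.jl -- the real file is not part of this diff.
    # Old keyword-argument entry points forward to the new positional methods with a warning.
    function dropout(rng::AbstractRNG, x::AbstractArray, p::T, t::Val;
            dims, invp::T=inv(p)) where {T}
        Base.depwarn("`dropout(rng, x, p, t; dims, invp)` is deprecated, use " *
                     "`dropout(rng, x, p, t, invp, dims)` instead", :dropout)
        return dropout(rng, x, p, t, invp, dims)
    end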
39 changes: 18 additions & 21 deletions src/api/batchnorm.jl
@@ -1,6 +1,6 @@
@doc doc"""
batchnorm(x, scale, bias, running_mean, running_var, σ=identity; momentum, epsilon,
training)
batchnorm(x, scale, bias, running_mean, running_var, training, σ=identity,
momentum = 0.1f0, epsilon = 1f-5)
Batch Normalization. For details see [1].
@@ -15,13 +15,10 @@ accordingly.
- `bias`: Bias factor (``\beta``) (can be `nothing`)
- `running_mean`: Running mean (can be `nothing`)
- `running_var`: Running variance (can be `nothing`)
- `σ`: Activation function (default: `identity`)
## Keyword Arguments
- `momentum`: Momentum for updating running mean and variance
- `epsilon`: Value added to the denominator for numerical stability
- `training`: Set to `Val(true)` if running in training mode
- `σ`: Activation function (default: `identity`)
- `momentum`: Momentum for updating running mean and variance (default: `0.1f0`)
- `epsilon`: Value added to the denominator for numerical stability (default: `1f-5`)
## Returns
@@ -43,8 +40,8 @@ fallback is used which is not highly optimized.
function batchnorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector},
bias::Union{Nothing, <:AbstractVector},
running_mean::Union{Nothing, <:AbstractVector},
running_var::Union{Nothing, <:AbstractVector}, σ::F=identity;
momentum::Real, training::Val, epsilon::Real) where {F, N}
running_var::Union{Nothing, <:AbstractVector}, training::Val,
σ::F=identity, momentum::Real=0.1f0, epsilon::Real=1.0f-5) where {F, N}
x_, xm, xv = _normalization(x, _drop_forwarddiff_partials(running_mean),
_drop_forwarddiff_partials(running_var), scale, bias,
_get_batchnorm_reduce_dims(x), training, momentum, epsilon, σ)
@@ -57,17 +54,17 @@ end
return :($(Val(Tuple(collect([1:(N - 2); N])))))
end

function _get_batchnorm_statistics(x, running_mean, running_var, ::Val{T}) where {T}
if T
# NNlib silently updates running_mean and running_var. Copying them!
rm = _copy_autodiff_barrier(running_mean)
rv = _copy_autodiff_barrier(running_var)
else
N = ndims(x)
dims = collect([1:(N - 2); N])
rm = running_mean === nothing ? mean(x; dims) : running_mean
rv = running_var === nothing ? var(x; mean=rm, dims, corrected=false) : running_var
end
function _get_batchnorm_statistics(x, running_mean, running_var, ::Val{true})
rm = _copy_autodiff_barrier(running_mean)
rv = _copy_autodiff_barrier(running_var)
return rm, rv
end

function _get_batchnorm_statistics(
x::AbstractArray{T, N}, running_mean, running_var, ::Val{false}) where {T, N}
dims = collect([1:(N - 2); N])
rm = running_mean === nothing ? mean(x; dims) : running_mean
rv = running_var === nothing ? var(x; mean=rm, dims, corrected=false) : running_var
return rm, rv
end
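
For reference, a minimal usage sketch of the updated positional `batchnorm` signature (illustrative values; `relu` is assumed to come from NNlib and is not part of this diff):

    using LuxLib
    using NNlib: relu

    x = randn(Float32, 8, 8, 4, 2)                 # W × H × C × N
    γ, β = ones(Float32, 4), zeros(Float32, 4)
    rμ, rσ² = zeros(Float32, 4), ones(Float32, 4)

    # `training` is now positional; σ, momentum and epsilon follow with defaults
    y, stats = batchnorm(x, γ, β, rμ, rσ², Val(true), relu, 0.1f0, 1.0f-5)
    stats.running_mean, stats.running_var         # updated running statistics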

69 changes: 38 additions & 31 deletions src/api/dropout.jl
@@ -1,7 +1,6 @@
@doc doc"""
dropout(rng::AbstractRNG, x, p, ::Val{training}, invp; dims)
dropout(rng::AbstractRNG, x, mask, p, ::Val{training}, ::Val{update_mask}, invp;
dims)
dropout(rng::AbstractRNG, x, p, ::Val{training}, invp, dims)
dropout(rng::AbstractRNG, x, mask, p, ::Val{training}, ::Val{update_mask}, invp, dims)
Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see [1].
@@ -16,9 +15,6 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see
- `Val(update_mask)`: If `true` then the mask is generated and used. Else, the `mask`
provided is directly used
- `invp`: Inverse of the probability
## Keyword Arguments
- `dims`: Dimensions along which dropout is applied
- `invp`: Inverse of the probability (``\frac{1}{p}``)
@@ -34,43 +30,33 @@ Dropout: Simple Way to prevent Neural Networks for Overfitting. For details see
overfitting." The journal of machine learning research 15.1 (2014): 1929-1958.
"""
function dropout(
rng::AbstractRNG, x::AbstractArray, p::T, ::Val{true}, invp::T; dims) where {T}
rng::AbstractRNG, x::AbstractArray, p::T, ::Val{true}, invp::T, dims) where {T}
rng = LuxCore.replicate(rng)
mask = _generate_dropout_mask(rng, x, p, invp; dims)
return (x .* CRC.ignore_derivatives(mask), mask, rng)
end

function dropout(
rng::AbstractRNG, x::AbstractArray, p::T, ::Val{false}, ::T; dims) where {T}
rng::AbstractRNG, x::AbstractArray, p::T, ::Val{false}, ::T, dims) where {T}
return (x, x, rng)
end

function dropout(
rng::AbstractRNG, x::AbstractArray, p::T, t::Val; dims, invp::T=inv(p)) where {T}
return dropout(rng, x, p, t, invp; dims)
end

function dropout(rng::AbstractRNG, x::AbstractArray, mask::AbstractArray,
p::T, t::Val, ::Val{true}, invp::T; dims) where {T}
return dropout(rng, x, p, t; dims, invp)
function dropout(rng::AbstractRNG, x::AbstractArray, ::AbstractArray,
p::T, t::Val, ::Val{true}, invp::T, dims) where {T}
return dropout(rng, x, p, t, invp, dims)
end

function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N},
p::T, ::Val{true}, ::Val{false}, invp::T; dims) where {T, T1, T2, N}
size(x) != size(mask) && return dropout(rng, x, p, Val(true); dims, invp)
p::T, ::Val{true}, ::Val{false}, invp::T, dims) where {T, T1, T2, N}
size(x) != size(mask) && return dropout(rng, x, p, Val(true), invp, dims)
return x .* CRC.ignore_derivatives(mask), mask, rng
end

function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N},
p::T, ::Val{false}, ::Val{false}, invp::T; dims) where {T, T1, T2, N}
p::T, ::Val{false}, ::Val{false}, invp::T, dims) where {T, T1, T2, N}
return (x, mask, rng)
end

function dropout(rng::AbstractRNG, x::AbstractArray{T1, N}, mask::AbstractArray{T2, N},
p::T, t::Val, um::Val; dims, invp::T=inv(p)) where {T, T1, T2, N}
return dropout(rng, x, mask, p, t, um, invp; dims)
end
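
A minimal sketch of calling the reworked positional form (the keyword wrapper above has been removed; values here are illustrative):

    using LuxLib, Random

    rng = Random.default_rng()
    x, p = randn(Float32, 4, 3), 0.5f0

    # `invp` and `dims` are now positional: (rng, x, p, training, invp, dims)
    y, mask, rng_out = dropout(rng, x, p, Val(true), inv(p), :)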

"""
alpha_dropout(rng::AbstractRNG, x, p, ::Val{training})
alpha_dropout(rng::AbstractRNG, x, p, ::Val{training}, α, A, B)
@@ -104,7 +90,6 @@ function alpha_dropout(rng::AbstractRNG, x::AbstractArray{T}, p, t::Val{true}) w
α = T(-1.7580993408473766)
A = T(inv(sqrt((1 - p) * (1 + p * α^2))))
B = T(-A * α * p)

return alpha_dropout(rng, x, p, t, α, A, B)
end

@@ -113,12 +98,11 @@ function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, t::Val{false})
end

function alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{true}, α, A, B)
rng = LuxCore.replicate(rng)
noise = rand!(rng, similar(x, _dropout_fptype(x)))
# NOTE(@avik-pal): Combining the last 2 lines causes a compilation error for Tracker
# on GPU
y = ifelse.(noise .> p, x, α)
return (A .* y .+ B), rng
noise, rng = _alpha_dropout_noise(rng, x)
# NOTE: Combining the last 2 lines causes a compilation error for Tracker on GPU
y = _alpha_dropout_kernel(noise, p, x, α)
res = @. A * y + B
return res, rng
end

alpha_dropout(rng::AbstractRNG, x::AbstractArray, p, ::Val{false}, α, A, B) = (x, rng)
@@ -131,8 +115,31 @@ end

@inline _dropout_kernel(y, p, invp) = ifelse(y > p, invp, oftype(y, 0))

@inline _alpha_dropout_kernel(noise, p, x, α) = @. ifelse(noise > p, x, α)

## Zygote is otherwise type unstable
@inline function CRC.rrule(::typeof(_alpha_dropout_kernel), noise, p, x, α)
_cond = noise .> p
y = ifelse.(_cond, x, α)
_∇alpha_dropout_kernel = @closure Δ -> begin
return NoTangent(), NoTangent(), NoTangent(), (_cond .* Δ), sum(@.((1 - _cond)*Δ))
end
return y, _∇alpha_dropout_kernel
end

@inline _dropout_fptype(x) = float(real(eltype(x)))

CRC.@non_differentiable _dropout_fptype(::Any...)

@inline function _alpha_dropout_noise(rng, x)
rng = LuxCore.replicate(rng)
noise = similar(x, _dropout_fptype(x))
rand!(rng, noise)
return noise, rng
end

CRC.@non_differentiable _alpha_dropout_noise(::Any...)
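
Since the custom `rrule` and the `@non_differentiable` helpers exist to keep Zygote inference happy, a rough sketch of how that could be verified (Zygote and Test are assumed here; this snippet is not part of the diff):

    using LuxLib, Random, Test, Zygote

    rng = Random.default_rng()
    x = randn(Float32, 4, 3)

    # The intent of the PR is that the pullback through alpha_dropout now infers.
    @inferred Zygote.gradient(x -> sum(first(alpha_dropout(rng, x, 0.5f0, Val(true)))), x)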

@inline function _generate_dropout_mask(rng::AbstractRNG, x, p, invp; dims)
realfptype = _dropout_fptype(x)
y = rand!(rng, similar(x, realfptype, _dropout_shape(x, dims)))
11 changes: 6 additions & 5 deletions src/api/fast_activation.jl
@@ -1,5 +1,5 @@
"""
fast_activation!!(σ::F, x) where {F}
fast_activation!!(σ::F, x::AbstractArray) where {F}
Compute `σ.(x)` with the best possible implementation available. If it is possible to
rewrite `x` in-place, it does so. If `x` is an immutable array, it falls back to the
@@ -19,8 +19,9 @@ generic implementation.
- Output Array with the same size as `x`
"""
@inline function fast_activation!!(σ::F, x::AbstractArray) where {F}
σ === identity && return x
ArrayInterface.can_setindex(x) && return __fast_activation_impl!!(σ, x)
return σ.(x)
@inline fast_activation!!(::typeof(identity), x::AbstractArray) = x

@inline @generated function fast_activation!!(σ::F, x::AbstractArray) where {F}
ArrayInterface.can_setindex(x) && :(return __fast_activation_impl!!(σ, x))
return :(σ.(x))
end
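
A small sketch of the resulting dispatch behaviour (`relu` is assumed from NNlib; illustrative only):

    using LuxLib
    using NNlib: relu

    x = randn(Float32, 4)
    fast_activation!!(identity, x) === x      # identity short-circuits via its own method
    y = fast_activation!!(relu, x)            # setindex-able input: may be updated in place
    z = fast_activation!!(relu, 1.0f0:4.0f0)  # immutable input: falls back to broadcasting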
51 changes: 24 additions & 27 deletions src/api/groupnorm.jl
@@ -1,5 +1,5 @@
@doc doc"""
groupnorm(x, scale, bias; groups, epsilon)
groupnorm(x, scale, bias, groups, σ::F=identity, epsilon::Real=1.0f-5)
Group Normalization. For details see [1].
@@ -13,11 +13,9 @@ statistics.
- `x`: Input to be Normalized
- `scale`: Scale factor (``\gamma``) (can be `nothing`)
- `bias`: Bias factor (``\beta``) (can be `nothing`)
## Keyword Arguments
- `groups`: Number of groups
- `epsilon`: Value added to the denominator for numerical stability
- `σ`: Activation function (default: `identity`)
- `epsilon`: Value added to the denominator for numerical stability (default: `1f-5`)
## Returns
@@ -44,19 +42,10 @@ interface.
function groupnorm(x::AbstractArray{<:Union{Float32, Float64}, 4},
scale::AbstractVector{<:Union{Float32, Float64}},
bias::AbstractVector{<:Union{Float32, Float64}},
σ::F=identity; groups::Int, epsilon::Real) where {F}
_assert_same_backend(x, scale, bias)
if length(scale) != length(bias) != size(x, 3)
throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \
channels (N - 1 dim of the input array)."))
end
if size(x, 3) % groups != 0
throw(ArgumentError(lazy"Number of channels $(size(x, 3)) must be divisible by the number of groups $groups."))
end

groups::Int, σ::F=identity, epsilon::Real=1.0f-5) where {F}
_test_valid_groupnorm_arguments(x, scale, bias, groups)
# FIXME: We need to fuse the activation function into the kernel for optimal performance
return fast_activation!!(σ, __fast_groupnorm(x, groups, scale, bias, epsilon))
# return σ.(__fast_groupnorm(x, groups, scale, bias, epsilon))
end

# Separate this out for a cleaner rrule later on
@@ -66,16 +55,9 @@ end

# Slow Fallback (without custom Pullback Implementation)
function groupnorm(x::AbstractArray{<:Real, N}, scale::Union{Nothing, <:AbstractVector},
bias::Union{Nothing, <:AbstractVector}, σ::F=identity;
groups::Int, epsilon::Real) where {F, N}
_assert_same_backend(x, scale, bias)
if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3)
throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \
channels (N - 1 dim of the input array)."))
end
if size(x, N - 1) % groups != 0
throw(ArgumentError(lazy"Number of channels $(size(x, 3)) must be divisible by the number of groups $groups."))
end
bias::Union{Nothing, <:AbstractVector}, groups::Int,
σ::F=identity, epsilon::Real=1.0f-5) where {F, N}
_test_valid_groupnorm_arguments(x, scale, bias, groups)

sz = size(x)
x_reshaped = reshape(x, sz[1:(N - 2)]..., sz[N - 1] ÷ groups, groups, sz[N])
@@ -94,7 +76,22 @@ function CRC.rrule(::typeof(__fast_groupnorm), x, groups, scale, bias, epsilon)
y, μ, σ⁻¹ = _groupnorm(x, groups, scale, bias, epsilon)
∇groupnorm = @closure Δ -> begin
∂x, ∂scale, ∂bias = _∇groupnorm(Δ, y, x, groups, scale, bias, μ, σ⁻¹)
return CRC.NoTangent(), ∂x, CRC.NoTangent(), ∂scale, ∂bias, CRC.NoTangent()
return NoTangent(), ∂x, NoTangent(), ∂scale, ∂bias, NoTangent()
end
return y, ∇groupnorm
end

function _test_valid_groupnorm_arguments(
x::AbstractArray{T, N}, scale, bias, groups) where {T, N}
_assert_same_backend(x, scale, bias)
if scale !== nothing && bias !== nothing && length(scale) != length(bias) != size(x, 3)
throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of \
channels (N - 1 dim of the input array)."))
end
if size(x, N - 1) % groups != 0
throw(ArgumentError(lazy"Number of channels $(size(x, N - 1)) must be divisible by the number of groups $groups."))
end
return nothing
end

CRC.@non_differentiable _test_valid_groupnorm_arguments(::Any...)
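
For reference, a minimal sketch of the new positional `groupnorm` call (illustrative values; `relu` assumed from NNlib):

    using LuxLib
    using NNlib: relu

    x = randn(Float32, 8, 8, 6, 2)             # W × H × C × N with C = 6
    γ, β = ones(Float32, 6), zeros(Float32, 6)

    # `groups` is now positional, followed by σ and epsilon with defaults
    y = groupnorm(x, γ, β, 3, relu, 1.0f-5)    # 3 groups; C = 6 divisible by 3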

2 comments on commit 0dea4f1

@avik-pal
Member Author


@JuliaRegistrator


Registration pull request created: JuliaRegistries/General/106617

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.3.21 -m "<description of version>" 0dea4f19f2b2572ee96b071faf3e4290b840d48e
git push origin v0.3.21
