Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Special case make_zer(array) for perf #1415

Merged
merged 2 commits into from
May 7, 2024
Merged

Special case make_zer(array) for perf #1415

merged 2 commits into from
May 7, 2024

Conversation

wsmoses
Copy link
Member

@wsmoses wsmoses commented May 6, 2024

No description provided.

@@ -1177,6 +1177,10 @@ function allocate_sret!(gutils::API.EnzymeGradientUtilsRef, N)
end
end

@inline function EnzymeCore.make_zero(x::Array{FT, N})::Array{FT, N} where {FT <: AbstractFloat, N}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inline function EnzymeCore.make_zero(x::Array{FT, N})::Array{FT, N} where {FT <: AbstractFloat, N}
@inline function EnzymeCore.make_zero(x::Array{FT,N})::Array{FT,
N} where {FT<:AbstractFloat,N}

@inline function EnzymeCore.make_zero(x::Array{FT, N})::Array{FT, N} where {FT <: AbstractFloat, N}
return Base.zero(x)
end

@inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT<:AbstractFloat}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive, RT<:AbstractFloat}
@inline function EnzymeCore.make_zero(::Type{RT}, seen::IdDict, prev::RT,
::Val{copy_if_inactive}=Val(false))::RT where {copy_if_inactive,
RT<:AbstractFloat}

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -67,6 +67,13 @@ end
end)
end

@inline function vaTypeof(args::Vararg{Any, N}) where N
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inline function vaTypeof(args::Vararg{Any, N}) where N
@inline function vaTypeof(args::Vararg{Any,N}) where {N}

Comment on lines +72 to +74
Base.@_inline_meta
Core.Typeof(args[i])
end)...}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
Base.@_inline_meta
Core.Typeof(args[i])
end)...}
Base.@_inline_meta
return Core.Typeof(args[i])
end)...}

Comment on lines +111 to +112

@inline function refn_seed(x::T) where T
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inline function refn_seed(x::T) where T
@inline function refn_seed(x::T) where {T}

end
end

@inline function imfn_seed(x::T) where T
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inline function imfn_seed(x::T) where T
@inline function imfn_seed(x::T) where {T}

end
end

@inline function seed_complex_args(seen, seen2, args::Vararg{Annotation, Nargs}) where {Nargs}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inline function seed_complex_args(seen, seen2, args::Vararg{Annotation, Nargs}) where {Nargs}
@inline function seed_complex_args(seen, seen2,
args::Vararg{Annotation,Nargs}) where {Nargs}

if Compiler.active_reg_inner(X, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState
dx = Ref(make_zero(x))
autodiff(Reverse, f∘only, Active, Duplicated(Ref(x), dx))
autodiff(rm, f∘only, Active, Duplicated(Ref(x), dx))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
autodiff(rm, fonly, Active, Duplicated(Ref(x), dx))
autodiff(rm, f only, Active, Duplicated(Ref(x), dx))

@@ -362,6 +362,13 @@ end
end
end

@inline function staticInTup(::Val{T}, tup::NTuple{N, Val}) where {T, N}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inline function staticInTup(::Val{T}, tup::NTuple{N, Val}) where {T, N}
@inline function staticInTup(::Val{T}, tup::NTuple{N,Val}) where {T,N}

src/compiler.jl Outdated
Comment on lines 367 to 369
Base.@_inline_meta
T == tup[i]
end)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
Base.@_inline_meta
T == tup[i]
end)
Base.@_inline_meta
return T == tup[i]
end)

Comment on lines 372 to 373
@inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false))::ActivityState where {ST,T, justActive, UnionSret}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing, UInt}, ::Val{justActive}=Val(false), ::Val{UnionSret}=Val(false))::ActivityState where {ST,T, justActive, UnionSret}
@inline function active_reg_inner(::Type{T}, seen::ST, world::Union{Nothing,UInt},
::Val{justActive}=Val(false),
::Val{UnionSret}=Val(false))::ActivityState where {ST,T,
justActive,
UnionSret}

Comment on lines +481 to +488
Base.@_inline_meta
sT = T.parameters[i]
if sT isa Core.TypeofVararg
Any
else
sT
end
end)...}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
Base.@_inline_meta
sT = T.parameters[i]
if sT isa Core.TypeofVararg
Any
else
sT
end
end)...}
Base.@_inline_meta
sT = T.parameters[i]
if sT isa Core.TypeofVararg
Any
else
sT
end
end)...}

ntuple(Val(Nargs)) do i
Base.@_inline_meta
if args[i] isa Active
Compiler.recursive_add(Compiler.recursive_add(results[1][i][1], results[1][i][2], refn_seed), results[1][i][3], imfn_seed)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
Compiler.recursive_add(Compiler.recursive_add(results[1][i][1], results[1][i][2], refn_seed), results[1][i][3], imfn_seed)
Compiler.recursive_add(Compiler.recursive_add(results[1][i][1],
results[1][i][2], refn_seed),
results[1][i][3], imfn_seed)

src/Enzyme.jl Outdated
Comment on lines 506 to 507
@inline function autodiff_deferred(mode::CMode, f::F, args::Varargs{Annotation, Nargs}) where {F, CMode<:Mode, Nargs}
autodiff_deferred(mode, Const(f), args...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inline function autodiff_deferred(mode::CMode, f::F, args::Varargs{Annotation, Nargs}) where {F, CMode<:Mode, Nargs}
autodiff_deferred(mode, Const(f), args...)
@inline function autodiff_deferred(mode::CMode, f::F,
args::Varargs{Annotation,Nargs}) where {F,CMode<:Mode,
Nargs}
return autodiff_deferred(mode, Const(f), args...)

src/Enzyme.jl Outdated
Comment on lines 509 to 510
@inline function autodiff_deferred(mode::CMode, f::F, ::Type{RT}, args::Varargs{Annotation, Nargs}) where {F, RT<:Annotation, CMode<:Mode, Nargs}
autodiff_deferred(mode, Const(f), RT, args...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inline function autodiff_deferred(mode::CMode, f::F, ::Type{RT}, args::Varargs{Annotation, Nargs}) where {F, RT<:Annotation, CMode<:Mode, Nargs}
autodiff_deferred(mode, Const(f), RT, args...)
@inline function autodiff_deferred(mode::CMode, f::F, ::Type{RT},
args::Varargs{Annotation,Nargs}) where {F,RT<:Annotation,
CMode<:Mode,
Nargs}
return autodiff_deferred(mode, Const(f), RT, args...)


fty = Merger{seen,typeof(world),justActive, UnionSret}(world)
fty = Merger{seen2,typeof(world),justActive, UnionSret}(world)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
fty = Merger{seen2,typeof(world),justActive, UnionSret}(world)
fty = Merger{seen2,typeof(world),justActive,UnionSret}(world)

@@ -521,7 +536,7 @@ end
return res
end

Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = guess_activity(T, convert(API.CDerivativeMode, mode))
@inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = guess_activity(T, convert(API.CDerivativeMode, mode))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = guess_activity(T, convert(API.CDerivativeMode, mode))
@inline Enzyme.guess_activity(::Type{T}, mode::Enzyme.Mode) where {T} = guess_activity(T,
convert(API.CDerivativeMode,
mode))


tx = (1.0, 2.0, 3.0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
tx = (1.0, 2.0, 3.0)


tx = (1.0, 2.0, 3.0)

@inferred Enzyme.Compiler.active_reg_inner(Tuple{Float64,Float64,Float64}, (), nothing, Val(true))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inferred Enzyme.Compiler.active_reg_inner(Tuple{Float64,Float64,Float64}, (), nothing, Val(true))
tx = (1.0, 2.0, 3.0)
@inferred Enzyme.Compiler.active_reg_inner(Tuple{Float64,Float64,Float64}, (), nothing,
Val(true))


@inferred Enzyme.Compiler.active_reg_inner(Tuple{Float64,Float64,Float64}, (), nothing, Val(true))
@inferred Enzyme.make_zero(tx)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change


@inferred gradient(Reverse, abssum, tx)
@inferred gradient(Forward, abssum, tx)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

@inferred gradient(Forward, abssum, tx)

end

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

@wsmoses wsmoses force-pushed the makezeroarr branch 3 times, most recently from 6312f06 to b25a7a6 Compare May 7, 2024 01:11
Comment on lines +367 to +369
Base.@_inline_meta
Val(T) == tup[i]
end)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
Base.@_inline_meta
Val(T) == tup[i]
end)
Base.@_inline_meta
return Val(T) == tup[i]
end)

@@ -556,8 +571,9 @@ result, ∂v, ∂A
```
"""
@inline function autodiff_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI<:ABI, Nargs}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inline function autodiff_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI}, ::Type{FA}, ::Type{A}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT,RABI<:ABI, Nargs}
@inline function autodiff_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,
ModifiedBetweenT,RABI}, ::Type{FA},
::Type{A},
args::Vararg{Type{<:Annotation},Nargs}) where {FA<:Annotation,
A<:Annotation,
ReturnPrimal,
ReturnShadow,
Width,
ModifiedBetweenT,
RABI<:ABI,
Nargs}

@wsmoses wsmoses force-pushed the makezeroarr branch 3 times, most recently from adc9221 to ccb1566 Compare May 7, 2024 01:26
Comment on lines +416 to +423
@inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x))
@inferred autodiff(Enzyme.ReverseWithPrimal, abssum, Duplicated(x,x))
@inferred autodiff(Enzyme.ReverseHolomorphic, abssum, Duplicated(x,x))
@inferred autodiff(Enzyme.ReverseHolomorphicWithPrimal, abssum, Duplicated(x,x))
@inferred autodiff(Enzyme.Forward, abssum, Duplicated(x,x))
@inferred autodiff(Enzyme.Forward, abssum, Duplicated, Duplicated(x,x))
@inferred autodiff(Enzyme.Forward, abssum, DuplicatedNoNeed, Duplicated(x,x))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x,x))
@inferred autodiff(Enzyme.ReverseWithPrimal, abssum, Duplicated(x,x))
@inferred autodiff(Enzyme.ReverseHolomorphic, abssum, Duplicated(x,x))
@inferred autodiff(Enzyme.ReverseHolomorphicWithPrimal, abssum, Duplicated(x,x))
@inferred autodiff(Enzyme.Forward, abssum, Duplicated(x,x))
@inferred autodiff(Enzyme.Forward, abssum, Duplicated, Duplicated(x,x))
@inferred autodiff(Enzyme.Forward, abssum, DuplicatedNoNeed, Duplicated(x,x))
@inferred autodiff(Enzyme.Reverse, abssum, Duplicated(x, x))
@inferred autodiff(Enzyme.ReverseWithPrimal, abssum, Duplicated(x, x))
@inferred autodiff(Enzyme.ReverseHolomorphic, abssum, Duplicated(x, x))
@inferred autodiff(Enzyme.ReverseHolomorphicWithPrimal, abssum, Duplicated(x, x))
@inferred autodiff(Enzyme.Forward, abssum, Duplicated(x, x))
@inferred autodiff(Enzyme.Forward, abssum, Duplicated, Duplicated(x, x))
@inferred autodiff(Enzyme.Forward, abssum, DuplicatedNoNeed, Duplicated(x, x))

@@ -1029,7 +1044,7 @@ grad = gradient(Forward, f, [2.0, 3.0], Val(2))
(3.0, 2.0)
```
"""
@inline function gradient(::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X<:Array, chunk}
@inline function gradient(::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inline function gradient(::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk}
@inline function gradient(::ForwardMode, f::F, x::X, ::Val{chunk};
shadow=chunkedonehot(x, Val(chunk))) where {F,X,chunk}

@@ -1039,7 +1054,7 @@ grad = gradient(Forward, f, [2.0, 3.0], Val(2))
tupleconcat(tmp...)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
tupleconcat(tmp...)
return tupleconcat(tmp...)

@@ -1039,7 +1054,7 @@ grad = gradient(Forward, f, [2.0, 3.0], Val(2))
tupleconcat(tmp...)
end

@inline function gradient(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X<:Array}
@inline function gradient(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inline function gradient(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X}
@inline function gradient(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F,X}

@@ -1039,7 +1054,7 @@ grad = gradient(Forward, f, [2.0, 3.0], Val(2))
tupleconcat(tmp...)
end

@inline function gradient(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X<:Array}
@inline function gradient(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X}
ntuple(length(shadow)) do i
autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1]
return autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1]

@@ -780,10 +795,10 @@ result, ∂v, ∂A
(7.26, 2.2, [3.3])
```
"""
@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{TapeType}, ::Type{FA}, ::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A<:Annotation, A2, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs}
@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{TapeType}, ::Type{FA}, ::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A2<:Annotation, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI}, ::Type{TapeType}, ::Type{FA}, ::Type{A2}, args::Vararg{Type{<:Annotation}, Nargs}) where {FA<:Annotation, A2<:Annotation, TapeType, ReturnPrimal,ReturnShadow,Width,ModifiedBetweenT, RABI<:ABI, Nargs}
@inline function autodiff_deferred_thunk(::ReverseModeSplit{ReturnPrimal,ReturnShadow,Width,
ModifiedBetweenT,RABI},
::Type{TapeType}, ::Type{FA}, ::Type{A2},
args::Vararg{Type{<:Annotation},Nargs}) where {FA<:Annotation,
A2<:Annotation,
TapeType,
ReturnPrimal,
ReturnShadow,
Width,
ModifiedBetweenT,
RABI<:ABI,
Nargs}

@wsmoses wsmoses merged commit d62e4f3 into main May 7, 2024
39 of 48 checks passed
@wsmoses wsmoses deleted the makezeroarr branch May 7, 2024 02:35
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant