Skip to content

Commit

Permalink
Try #1113:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] committed Feb 1, 2023
2 parents ba71538 + 18a9cf4 commit 6a894f7
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 47 deletions.
35 changes: 19 additions & 16 deletions src/DataLayouts/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -407,26 +407,29 @@ end

# broadcasting scalar assignment
# Performance optimization for the common identity scalar case: dest .= val
@inline function Base.copyto!(
function Base.copyto!(
dest::AbstractData,
bc::Base.Broadcast.Broadcasted{Style},
) where {Style <: Base.Broadcast.AbstractArrayStyle{0}}
) where {
Style <:
Union{Base.Broadcast.AbstractArrayStyle{0}, Base.Broadcast.Style{Tuple}},
}
bc = Base.Broadcast.instantiate(
Base.Broadcast.Broadcasted{Style}(bc.f, bc.args, ()),
)
@inbounds bc0 = bc[]
fill!(dest, bc0)
end

@inline function Base.copyto!(
function Base.copyto!(
dest::DataF{S},
bc::Union{DataF{S, A}, Base.Broadcast.Broadcasted{DataFStyle{A}}},
) where {S, A}
@inbounds dest[] = convert(S, bc[])
return dest
end

@inline function Base.copyto!(
function Base.copyto!(
dest::IJFH{S, Nij},
bc::Union{IJFH{S, Nij}, Base.Broadcast.Broadcasted{<:IJFHStyle{Nij}}},
) where {S, Nij}
Expand All @@ -439,7 +442,7 @@ end
return dest
end

@inline function Base.copyto!(
function Base.copyto!(
dest::IFH{S, Ni},
bc::Union{IFH{S, Ni}, Base.Broadcast.Broadcasted{<:IFHStyle{Ni}}},
) where {S, Ni}
Expand All @@ -453,7 +456,7 @@ end
end

# inline inner slab(::DataSlab2D) copy
@inline function Base.copyto!(
function Base.copyto!(
dest::IJF{S, Nij},
bc::Union{IJF{S, Nij, A}, Base.Broadcast.Broadcasted{IJFStyle{Nij, A}}},
) where {S, Nij, A}
Expand All @@ -465,7 +468,7 @@ end
end

# inline inner slab(::DataSlab1D) copy
@inline function Base.copyto!(
function Base.copyto!(
dest::IF{S, Ni},
bc::Base.Broadcast.Broadcasted{IFStyle{Ni, A}},
) where {S, Ni, A}
Expand All @@ -477,7 +480,7 @@ end
end

# inline inner column(::DataColumn) copy
@inline function Base.copyto!(
function Base.copyto!(
dest::VF{S},
bc::Union{VF{S, A}, Base.Broadcast.Broadcasted{VFStyle{A}}},
) where {S, A}
Expand All @@ -489,7 +492,7 @@ end
return dest
end

@inline function _serial_copyto!(
function _serial_copyto!(
dest::VIFH{S, Ni},
bc::Union{VIFH{S, Ni, A}, Base.Broadcast.Broadcasted{VIFHStyle{Ni, A}}},
) where {S, Ni, A}
Expand All @@ -503,7 +506,7 @@ end
return dest
end

@inline function _threaded_copyto!(
function _threaded_copyto!(
dest::VIFH{S, Ni},
bc::Base.Broadcast.Broadcasted{VIFHStyle{Ni, A}},
) where {S, Ni, A}
Expand All @@ -522,14 +525,14 @@ end
return dest
end

@inline function Base.copyto!(
function Base.copyto!(
dest::VIFH{S, Ni},
source::VIFH{S, Ni, A},
) where {S, Ni, A}
return _serial_copyto!(dest, source)
end

@inline function Base.copyto!(
function Base.copyto!(
dest::VIFH{S, Ni},
bc::Base.Broadcast.Broadcasted{VIFHStyle{Ni, A}},
) where {S, Ni, A}
Expand All @@ -539,7 +542,7 @@ end
return _serial_copyto!(dest, bc)
end

@inline function _serial_copyto!(
function _serial_copyto!(
dest::VIJFH{S, Nij},
bc::Union{VIJFH{S, Nij, A}, Base.Broadcast.Broadcasted{VIJFHStyle{Nij, A}}},
) where {S, Nij, A}
Expand All @@ -553,7 +556,7 @@ end
return dest
end

@inline function _threaded_copyto!(
function _threaded_copyto!(
dest::VIJFH{S, Nij},
bc::Base.Broadcast.Broadcasted{VIJFHStyle{Nij, A}},
) where {S, Nij, A}
Expand All @@ -572,14 +575,14 @@ end
return dest
end

@inline function Base.copyto!(
function Base.copyto!(
dest::VIJFH{S, Nij},
source::VIJFH{S, Nij, A},
) where {S, Nij, A}
return _serial_copyto!(dest, source)
end

@inline function Base.copyto!(
function Base.copyto!(
dest::VIJFH{S, Nij},
bc::Base.Broadcast.Broadcasted{VIJFHStyle{Nij, A}},
) where {S, Nij, A}
Expand Down
31 changes: 24 additions & 7 deletions src/Fields/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,15 @@ FieldStyle(x::Base.Broadcast.Unknown) = x
Base.Broadcast.BroadcastStyle(::Type{Field{V, S}}) where {V, S} =
FieldStyle(DataStyle(V))

# Broadcasting over scalars (Ref or Tuple)
Base.Broadcast.BroadcastStyle(
::Base.Broadcast.AbstractArrayStyle{0},
fs::AbstractFieldStyle,
) = fs
Base.Broadcast.BroadcastStyle(
::Base.Broadcast.Style{Tuple},
fs::AbstractFieldStyle,
) = fs

Base.Broadcast.BroadcastStyle(
::FieldStyle{DS1},
Expand All @@ -42,15 +47,14 @@ _first(data::DataLayouts.VF) = data[]
_first(field::Field) = _first_data_layout(field_values(column(field, 1, 1, 1)))
_first(space::Spaces.AbstractSpace) =
_first_data_layout(field_values(column(space, 1, 1, 1)))
_first(bc::Base.Broadcast.Broadcasted) = _first.(bc.args) # Is this case necessary?
_first(x::Base.RefValue{T}) where {T} = x
unref(x::Ref{T}) where {T} = x.x
unref(T) = T
_first(bc::Base.Broadcast.Broadcasted) = _first(copy(bc))
_first(x::Ref{T}) where {T} = x.x
_first(x::Tuple{T}) where {T} = x[1]

function call_with_first(bc)
# Try calling with first applied to all arguments:
bc′ = Base.Broadcast.preprocess(nothing, bc)
first_args = map(arg -> unref(_first(arg)), bc′.args)
first_args = map(arg -> _first(arg), bc′.args)
bc.f(first_args...)
end

Expand Down Expand Up @@ -183,8 +187,8 @@ end
end
return space1
end
@inline Base.Broadcast.broadcast_shape(space::AbstractSpace, ::Tuple{}) = space
@inline Base.Broadcast.broadcast_shape(::Tuple{}, space::AbstractSpace) = space
@inline Base.Broadcast.broadcast_shape(space::AbstractSpace, ::Tuple) = space
@inline Base.Broadcast.broadcast_shape(::Tuple, space::AbstractSpace) = space

@inline Base.Broadcast.broadcast_shape(
pointspace::AbstractPointSpace,
Expand Down Expand Up @@ -232,6 +236,12 @@ end
)
return nothing
end
@inline function Base.Broadcast.check_broadcast_shape(
::AbstractSpace,
::Tuple{T},
) where {T}
return nothing
end
@inline function Base.Broadcast.check_broadcast_shape(
::AbstractSpace,
::AbstractPointSpace,
Expand Down Expand Up @@ -380,6 +390,13 @@ function Base.Broadcast.copyto!(
copyto!(Fields.field_values(field), bc)
return field
end
function Base.Broadcast.copyto!(
field::Field,
bc::Base.Broadcast.Broadcasted{Base.Broadcast.Style{Tuple}},
)
copyto!(Fields.field_values(field), bc)
return field
end

function Base.Broadcast.copyto!(field::Field, nt::NamedTuple)
copyto!(
Expand Down
7 changes: 4 additions & 3 deletions src/Operators/finitedifference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3109,6 +3109,7 @@ Base.@propagate_inbounds function getidx(
end

# unwap boxed scalars
@inline getidx(scalar::Tuple{T}, loc::Location, idx, hidx) where {T} = scalar[1]
@inline getidx(scalar::Ref, loc::Location, idx, hidx) = scalar[]
@inline getidx(field::Fields.PointField, loc::Location, idx, hidx) = field[]
@inline getidx(field::Fields.PointField, loc::Location, idx) = field[]
Expand Down Expand Up @@ -3275,7 +3276,7 @@ function Base.similar(
return Field(Eltype, sp)
end

@inline function _serial_copyto!(
function _serial_copyto!(
field_out::Field,
bc::Base.Broadcast.Broadcasted{S},
Ni::Int,
Expand All @@ -3288,7 +3289,7 @@ end
return field_out
end

@inline function _threaded_copyto!(
function _threaded_copyto!(
field_out::Field,
bc::Base.Broadcast.Broadcasted{S},
Ni::Int,
Expand All @@ -3305,7 +3306,7 @@ end
return field_out
end

@inline function Base.copyto!(
function Base.copyto!(
field_out::Field,
bc::Base.Broadcast.Broadcasted{S},
) where {S <: AbstractStencilStyle}
Expand Down
42 changes: 23 additions & 19 deletions test/Fields/field_opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ end
include(joinpath(@__DIR__, "util_spaces.jl"))

# https://github.com/CliMA/ClimaCore.jl/issues/946
@testset "Allocations with broadcasting Refs" begin
@testset "Allocations with broadcasting Scalars" begin
FT = Float64
function foo!(Yx::Fields.Field)
Yx .= Ref(1) .+ Yx
Yx .= (1,) .+ Yx
return nothing
end
function foocolumn!(Yx::Fields.Field)
Fields.bycolumn(axes(Yx)) do colidx
Yx[colidx] .= Ref(1) .+ Yx[colidx]
Yx[colidx] .= (1,) .+ Yx[colidx]
nothing
end
return nothing
Expand Down Expand Up @@ -58,7 +58,7 @@ end
nothing
end
function callfill!(Y)
fill!(Y, Ref((; x = 2.0)))
fill!(Y, ((; x = 2.0),))
nothing
end
for space in all_spaces(FT)
Expand All @@ -80,13 +80,13 @@ function allocs_test1!(Y)
x = Y.x
FT = Spaces.undertype(axes(x))
I = sc(FT)
x .= x .+ Ref(I)
x .= x .+ (I,)
nothing
end
function allocs_test2!(Y)
x = Y.x
FT = Spaces.undertype(axes(x))
IR = Ref(sc(FT))
IR = (sc(FT),)
@. x += IR
nothing
end
Expand All @@ -96,15 +96,15 @@ function allocs_test1_column!(Y)
FT = Spaces.undertype(axes(x))
# I = sc(FT)
I = Operators.StencilCoefs{-1, 1}((zero(FT), one(FT), zero(FT)))
x[colidx] .= x[colidx] .+ Ref(I)
x[colidx] .= x[colidx] .+ (I,)
end
nothing
end
function allocs_test2_column!(Y)
Fields.bycolumn(axes(Y.x)) do colidx
x = Y.x
FT = Spaces.undertype(axes(x))
IR = Ref(sc(FT))
IR = (sc(FT),)
@. x[colidx] += IR
end
nothing
Expand All @@ -119,10 +119,10 @@ end

function allocs_test3_column!(x)
FT = Spaces.undertype(axes(x))
IR = Ref(Operators.StencilCoefs{-1, 1}((zero(FT), one(FT), zero(FT))))
IR = (Operators.StencilCoefs{-1, 1}((zero(FT), one(FT), zero(FT))),)
@. x += IR
I = Operators.StencilCoefs{-1, 1}((zero(FT), one(FT), zero(FT)))
x .+= Ref(I)
x .+= (I,)
nothing
end

Expand Down Expand Up @@ -154,9 +154,9 @@ end
end
nothing

function allocs_test_Ref_with_compose!(S, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄)
function allocs_test_scalar_with_compose!(S, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄)
Fields.bycolumn(axes(S)) do colidx
allocs_test_Ref_with_compose_column!(
allocs_test_scalar_with_compose_column!(
S[colidx],
∂ᶠ𝕄ₜ∂ᶜρ[colidx],
∂ᶜρₜ∂ᶠ𝕄[colidx],
Expand All @@ -165,15 +165,15 @@ function allocs_test_Ref_with_compose!(S, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂
nothing
end

function allocs_test_Ref_with_compose_column!(S, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄)
function allocs_test_scalar_with_compose_column!(S, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄)
compose = Operators.ComposeStencils()
FT = Spaces.undertype(axes(S))
IR = Ref(Operators.StencilCoefs{-1, 1}((zero(FT), one(FT), zero(FT))))
IR = (Operators.StencilCoefs{-1, 1}((zero(FT), one(FT), zero(FT))),)
@. S = compose(∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄) - IR
nothing
end

@testset "Allocations StencilCoefs Ref with ComposeStencils broadcasting" begin
@testset "Allocations StencilCoefs scalar with ComposeStencils broadcasting" begin
FT = Float64
for space in all_spaces(FT)
space isa Spaces.CenterExtrudedFiniteDifferenceSpace || continue
Expand All @@ -185,12 +185,16 @@ end
tridiag_type = Operators.StencilCoefs{-1, 1, NTuple{3, FT}}
S = Fields.Field(tridiag_type, fspace)

allocs_test_Ref_with_compose!(S, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄)
p = @allocated allocs_test_Ref_with_compose!(S, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄)
allocs_test_scalar_with_compose!(S, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄)
p = @allocated allocs_test_scalar_with_compose!(S, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄)
@test p == 0

allocs_test_Ref_with_compose_column!(S, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄)
p = @allocated allocs_test_Ref_with_compose_column!(S, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄)
allocs_test_scalar_with_compose_column!(S, ∂ᶠ𝕄ₜ∂ᶜρ, ∂ᶜρₜ∂ᶠ𝕄)
p = @allocated allocs_test_scalar_with_compose_column!(
S,
∂ᶠ𝕄ₜ∂ᶜρ,
∂ᶜρₜ∂ᶠ𝕄,
)
@test p == 0
end
end
Expand Down
4 changes: 2 additions & 2 deletions test/Operators/finitedifference/opt_examples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ function alloc_test_derivative(cfield, ffield, ∇c, ∇f)
p = @allocated begin
c∇closure()
end
@test_broken p == 0
@test p == 0

##### C2F
# wvec = Geometry.WVector # cannot re-define, otherwise many allocations
Expand Down Expand Up @@ -168,7 +168,7 @@ function alloc_test_operators_in_loops(cfield, ffield)
p = @allocated begin
c∇closure()
end
@test_broken p == 0
@test p == 0
end
end
function alloc_test_nested_expressions_1(cfield, ffield)
Expand Down

0 comments on commit 6a894f7

Please sign in to comment.