Skip to content

Commit

Permalink
More inline and fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
charleskawczynski committed Jul 26, 2022
1 parent c59c4bb commit 5645fe3
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 42 deletions.
28 changes: 23 additions & 5 deletions src/DataLayouts/struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,10 @@ promote_parent_array_type(
Create an instance of type `T` from a tuple of field values `args`, bypassing
possible internal constructors. `T` should be a concrete type.
"""
Base.@propagate_inbounds @generated function bypass_constructor(::Type{T}, args) where {T}
Base.@propagate_inbounds @generated function bypass_constructor(
::Type{T},
args,
) where {T}
vars = ntuple(_ -> gensym(), fieldcount(T))
assign = [
:(@inbounds $var::$(fieldtype(T, i)) = getfield(args, $i)) for
Expand Down Expand Up @@ -166,7 +169,11 @@ end
Construct an object of type `S` from the values of `array`, optionally offset by `offset` from the start of the array.
"""
Base.@propagate_inbounds @generated function get_struct(array::AbstractArray{T}, ::Type{S}, offset) where {T, S}
Base.@propagate_inbounds @generated function get_struct(
array::AbstractArray{T},
::Type{S},
offset,
) where {T, S}
tup = :(())
for i in 1:fieldcount(S)
push!(
Expand Down Expand Up @@ -195,11 +202,18 @@ Base.@propagate_inbounds function get_struct(
return @inbounds array[offset + 1]
end

Base.@propagate_inbounds function get_struct(array::AbstractArray{T}, ::Type{S}) where {T, S}
Base.@propagate_inbounds function get_struct(
array::AbstractArray{T},
::Type{S},
) where {T, S}
@inbounds get_struct(array, S, 0)
end

Base.@propagate_inbounds @generated function set_struct!(array::AbstractArray{T}, val::S, offset) where {T, S}
Base.@propagate_inbounds @generated function set_struct!(
array::AbstractArray{T},
val::S,
offset,
) where {T, S}
ex = quote
Base.@_propagate_inbounds_meta
end
Expand All @@ -217,7 +231,11 @@ Base.@propagate_inbounds @generated function set_struct!(array::AbstractArray{T}
return ex
end

Base.@propagate_inbounds function set_struct!(array::AbstractArray{S}, val::S, offset) where {S}
Base.@propagate_inbounds function set_struct!(
array::AbstractArray{S},
val::S,
offset,
) where {S}
@inbounds array[offset + 1] = val
val
end
Expand Down
38 changes: 22 additions & 16 deletions src/Geometry/axistensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,15 +134,15 @@ struct AxisTensor{
components::S
end

AxisTensor(
@inline AxisTensor(
axes::A,
components::S,
) where {
A <: Tuple{Vararg{AbstractAxis}},
S <: StaticArray{<:Tuple, T, N},
} where {T, N} = AxisTensor{T, N, A, S}(axes, components)

AxisTensor(axes::Tuple{Vararg{AbstractAxis}}, components) =
@inline AxisTensor(axes::Tuple{Vararg{AbstractAxis}}, components) =
AxisTensor(axes, SArray{Tuple{map(length, axes)...}}(components))

# if the axes are already defined
Expand Down Expand Up @@ -172,22 +172,23 @@ Returns a `StaticArray` containing the components of `a` in its stored basis.
"""
@inline components(a::AxisTensor) = getfield(a, :components)

@inline Base.getindex(v::AxisTensor, i...) = getindex(components(v), i...)
@inline Base.getindex(v::AxisTensor, i...) =
@inbounds getindex(components(v), i...)


@inline function Base.getindex(
v::AxisTensor{<:Any, 2, Tuple{A1, A2}},
::Colon,
i::Integer,
) where {A1, A2}
AxisVector(axes(v, 1), getindex(components(v), :, i))
@inbounds AxisVector(axes(v, 1), getindex(components(v), :, i))
end
function Base.getindex(
@inline function Base.getindex(
v::AxisTensor{<:Any, 2, Tuple{A1, A2}},
i::Integer,
::Colon,
) where {A1, A2}
AxisVector(axes(v, 2), getindex(components(v), i, :))
@inbounds AxisVector(axes(v, 2), getindex(components(v), i, :))
end


Expand Down Expand Up @@ -260,9 +261,10 @@ end
const AdjointAxisVector{T, A1, S} = Adjoint{T, AxisVector{T, A1, S}}

@inline components(va::AdjointAxisVector) = components(parent(va))'
@inline Base.getindex(va::AdjointAxisVector, i::Int) = getindex(components(va), i)
@inline Base.getindex(va::AdjointAxisVector, i::Int) =
@inbounds getindex(components(va), i)
@inline Base.getindex(va::AdjointAxisVector, i::Int, j::Int) =
getindex(components(va), i, j)
@inbounds getindex(components(va), i, j)

# 2-tensors
const Axis2Tensor{T, A, S} = AxisTensor{T, 2, A, S}
Expand Down Expand Up @@ -506,7 +508,10 @@ end
if $errcond
throw(InexactError(:transform, Ato, x))
end
@inbounds Axis2Tensor((ato, axes(x, 2)), SMatrix{$(length(Ito)), $M}($(vals...)))
@inbounds Axis2Tensor(
(ato, axes(x, 2)),
SMatrix{$(length(Ito)), $M}($(vals...)),
)
end
end

Expand Down Expand Up @@ -545,10 +550,11 @@ end
@inline transform(ato::CartesianAxis, v::CartesianTensor) = _transform(ato, v)
@inline transform(ato::LocalAxis, v::LocalTensor) = _transform(ato, v)

project(ato::CovariantAxis, v::CovariantTensor) = _project(ato, v)
project(ato::ContravariantAxis, v::ContravariantTensor) = _project(ato, v)
project(ato::CartesianAxis, v::CartesianTensor) = _project(ato, v)
project(ato::LocalAxis, v::LocalTensor) = _project(ato, v)
@inline project(ato::CovariantAxis, v::CovariantTensor) = _project(ato, v)
@inline project(ato::ContravariantAxis, v::ContravariantTensor) =
_project(ato, v)
@inline project(ato::CartesianAxis, v::CartesianTensor) = _project(ato, v)
@inline project(ato::LocalAxis, v::LocalTensor) = _project(ato, v)


"""
Expand Down Expand Up @@ -579,12 +585,12 @@ julia> [1.0,2.0] ⊗ (1.0, (a=2.0, b=3.0))
function outer end
const = outer

function outer(x::AbstractVector, y::AbstractVector)
@inline function outer(x::AbstractVector, y::AbstractVector)
x * y'
end
function outer(x::AbstractVector, y::Number)
@inline function outer(x::AbstractVector, y::Number)
x * y
end
function outer(x::AbstractVector, y)
@inline function outer(x::AbstractVector, y)
RecursiveApply.rmap(y -> x y, y)
end
44 changes: 27 additions & 17 deletions src/Geometry/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,10 @@ end
@inbounds transform(Contravariant123Axis(), u, local_geometry)[3, :]
end

Base.@propagate_inbounds Jcontravariant3(u::AxisTensor, local_geometry::LocalGeometry) =
local_geometry.J * contravariant3(u, local_geometry)
Base.@propagate_inbounds Jcontravariant3(
u::AxisTensor,
local_geometry::LocalGeometry,
) = local_geometry.J * contravariant3(u, local_geometry)

# required for curl-curl
@inline covariant3(
Expand Down Expand Up @@ -240,18 +242,22 @@ for op in (:transform, :project)
ax,
local_geometry.∂x∂ξ' * $op(dual(axes(local_geometry.∂x∂ξ, 1)), v),
)
@inline $op(ax::LocalAxis, v::CovariantTensor, local_geometry::LocalGeometry) =
$op(
ax,
local_geometry.∂ξ∂x' *
$op(dual(axes(local_geometry.∂ξ∂x, 1)), v),
)
@inline $op(ax::CovariantAxis, v::LocalTensor, local_geometry::LocalGeometry) =
$op(
ax,
local_geometry.∂x∂ξ' *
$op(dual(axes(local_geometry.∂x∂ξ, 1)), v),
)
@inline $op(
ax::LocalAxis,
v::CovariantTensor,
local_geometry::LocalGeometry,
) = $op(
ax,
local_geometry.∂ξ∂x' * $op(dual(axes(local_geometry.∂ξ∂x, 1)), v),
)
@inline $op(
ax::CovariantAxis,
v::LocalTensor,
local_geometry::LocalGeometry,
) = $op(
ax,
local_geometry.∂x∂ξ' * $op(dual(axes(local_geometry.∂x∂ξ, 1)), v),
)

# Contravariant <-> Cartesian
@inline $op(
Expand Down Expand Up @@ -312,11 +318,15 @@ for op in (:transform, :project)

@inline $op(ato::CovariantAxis, v::CovariantTensor, ::LocalGeometry) =
$op(ato, v)
@inline $op(ato::ContravariantAxis, v::ContravariantTensor, ::LocalGeometry) =
$op(ato, v)
@inline $op(
ato::ContravariantAxis,
v::ContravariantTensor,
::LocalGeometry,
) = $op(ato, v)
@inline $op(ato::CartesianAxis, v::CartesianTensor, ::LocalGeometry) =
$op(ato, v)
@inline $op(ato::LocalAxis, v::LocalTensor, ::LocalGeometry) = $op(ato, v)
@inline $op(ato::LocalAxis, v::LocalTensor, ::LocalGeometry) =
$op(ato, v)
end
end

Expand Down
3 changes: 2 additions & 1 deletion src/RecursiveApply/RecursiveApply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ const ⊞ = radd
# Adapted from Base/operators.jl for general nary operator fallbacks
for op in (:rmul, :radd)
@eval begin
@inline ($op)(a, b, c, xs...) = Base.afoldl($op, ($op)(($op)(a, b), c), xs...)
@inline ($op)(a, b, c, xs...) =
Base.afoldl($op, ($op)(($op)(a, b), c), xs...)
end
end

Expand Down
10 changes: 7 additions & 3 deletions test/Operators/finitedifference/column_benchmark_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,8 @@ function benchmark_operators(z_elems, ::Type{FT}) where {FT}

(; cfield, ffield) = get_fields(z_elems, FT, :has_h_space)
benchmark_operators_base(trials, t_ave, cfield, ffield, :has_h_space)
test_results(t_ave)
return (; trials, t_ave)
end

function benchmark_operators_base(trials, t_ave, cfield, ffield, h_space)
Expand Down Expand Up @@ -388,13 +390,17 @@ function benchmark_operators_base(trials, t_ave, cfield, ffield, h_space)
benchmark_func!(t_ave, trials, op, cfield, ffield, h_space, #= verbose = =# false)
end

return nothing
end

function test_results(t_ave)
@test t_ave[(:no_h_space, op_GradientF2C!, :none)] < 500
@test t_ave[(:no_h_space, op_GradientF2C!, :SetValue, :SetValue)] < 500
@test t_ave[(:no_h_space, op_GradientC2F!, :SetGradient, :SetGradient)] < 500
@test t_ave[(:no_h_space, op_GradientC2F!, :SetValue, :SetValue)] < 500
@test t_ave[(:no_h_space, op_DivergenceF2C!, :none)] < 900
@test t_ave[(:no_h_space, op_DivergenceF2C!, :Extrapolate, :Extrapolate)] < 900
@test t_ave[(:no_h_space, op_DivergenceC2F!, :SetDivergence, :SetDivergence)] < 800
@test t_ave[(:no_h_space, op_DivergenceC2F!, :SetDivergence, :SetDivergence)] < 900
@test t_ave[(:no_h_space, op_InterpolateF2C!, :none)] < 500
@test t_ave[(:no_h_space, op_InterpolateC2F!, :SetValue, :SetValue)] < 500
@test t_ave[(:no_h_space, op_InterpolateC2F!, :Extrapolate, :Extrapolate)] < 500
Expand Down Expand Up @@ -448,8 +454,6 @@ function benchmark_operators_base(trials, t_ave, cfield, ffield, h_space)
@test_broken t_ave[(:has_h_space, op_div_interp_FF!, :none, :SetValue, :SetValue)] < 500
@test_broken t_ave[(:has_h_space, op_divgrad_uₕ!, :none, :SetValue, :Extrapolate)] < 500
@test_broken t_ave[(:has_h_space, op_divgrad_uₕ!, :none, :SetValue, :SetValue)] < 500 # different with/without h_space

return nothing
end

#! format: on

0 comments on commit 5645fe3

Please sign in to comment.