Skip to content

Commit

Permalink
Try #829:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Jul 26, 2022
2 parents 000daee + e8f8e13 commit 6f23efe
Show file tree
Hide file tree
Showing 14 changed files with 390 additions and 244 deletions.
12 changes: 4 additions & 8 deletions src/DataLayouts/struct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,12 +193,12 @@ function get_struct(array::AbstractArray{T}, ::Type{S}, offset) where {T, S}
end

# recursion base case: hit array type is the same as the struct leaf type
@propagate_inbounds function get_struct(
@inline function get_struct(
array::AbstractArray{S},
::Type{S},
offset,
) where {S}
return array[offset + 1]
return @inbounds array[offset + 1]
end

@inline function get_struct(array::AbstractArray{T}, ::Type{S}) where {T, S}
Expand Down Expand Up @@ -235,12 +235,8 @@ function set_struct!(array::AbstractArray{T}, val::S, offset) where {T, S}
end
end

@propagate_inbounds function set_struct!(
array::AbstractArray{S},
val::S,
offset,
) where {S}
array[offset + 1] = val
@inline function set_struct!(array::AbstractArray{S}, val::S, offset) where {S}
@inbounds array[offset + 1] = val
val
end

Expand Down
4 changes: 2 additions & 2 deletions src/Fields/Fields.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ const ColumnField{V, S} =
slab(field::Field, inds...) =
Field(slab(field_values(field), inds...), slab(axes(field), inds...))

column(field::Field, inds...) =
@inline column(field::Field, inds...) =
Field(column(field_values(field), inds...), column(axes(field), inds...))
column(field::FiniteDifferenceField, inds...) = field
@inline column(field::FiniteDifferenceField, inds...) = field



Expand Down
2 changes: 1 addition & 1 deletion src/Fields/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ function slab(
Base.Broadcast.Broadcasted{Style}(bc.f, _args, _axes)
end

function column(
@inline function column(
bc::Base.Broadcast.Broadcasted{Style},
i,
j,
Expand Down
14 changes: 10 additions & 4 deletions src/Fields/indices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,22 @@ end

Base.getindex(field::Field, colidx::ColumnIndex) = column(field, colidx)

function column(field::SpectralElementField1D, colidx::ColumnIndex{1})
@inline function column(field::SpectralElementField1D, colidx::ColumnIndex{1})
column(field, colidx.ij[1], colidx.h)
end
function column(field::ExtrudedFiniteDifferenceField, colidx::ColumnIndex{1})
@inline function column(
field::ExtrudedFiniteDifferenceField,
colidx::ColumnIndex{1},
)
column(field, colidx.ij[1], colidx.h)
end
function column(field::SpectralElementField2D, colidx::ColumnIndex{2})
@inline function column(field::SpectralElementField2D, colidx::ColumnIndex{2})
column(field, colidx.ij[1], colidx.ij[2], colidx.h)
end
function column(field::ExtrudedFiniteDifferenceField, colidx::ColumnIndex{2})
@inline function column(
field::ExtrudedFiniteDifferenceField,
colidx::ColumnIndex{2},
)
column(field, colidx.ij[1], colidx.ij[2], colidx.h)
end

Expand Down
72 changes: 51 additions & 21 deletions src/Geometry/axistensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,19 @@ ClimaCore.Geometry.CartesianAxis{(1, 2, 3)}()
function dual end

struct CovariantAxis{I} <: AbstractAxis{I} end
symbols(::CovariantAxis) = (:u₁, :u₂, :u₃)
@inline symbols(::CovariantAxis) = (:u₁, :u₂, :u₃)

struct ContravariantAxis{I} <: AbstractAxis{I} end
symbols(::ContravariantAxis) = (:u¹, :u², :u³)
@inline symbols(::ContravariantAxis) = (:u¹, :u², :u³)
dual(::CovariantAxis{I}) where {I} = ContravariantAxis{I}()
dual(::ContravariantAxis{I}) where {I} = CovariantAxis{I}()

struct LocalAxis{I} <: AbstractAxis{I} end
symbols(::LocalAxis) = (:u, :v, :w)
@inline symbols(::LocalAxis) = (:u, :v, :w)
dual(::LocalAxis{I}) where {I} = LocalAxis{I}()

struct CartesianAxis{I} <: AbstractAxis{I} end
symbols(::CartesianAxis) = (:u1, :u2, :u3)
@inline symbols(::CartesianAxis) = (:u1, :u2, :u3)
dual(::CartesianAxis{I}) where {I} = CartesianAxis{I}()

coordinate_axis(::Type{<:XPoint}) = (1,)
Expand All @@ -59,21 +59,46 @@ coordinate_axis(::Type{<:LatLongPoint}) = (1, 2)

coordinate_axis(coord::AbstractPoint) = coordinate_axis(typeof(coord))

@inline function idxin(I::Tuple{Vararg{Int}}, i::Int)
N = length(I)
for n in 1:N
if I[n] == i
return n
@inline idxin(I::Tuple{Int}, i::Int) = 1

@inline function idxin(I::Tuple{Int, Int}, i::Int)
@inbounds begin
if I[1] == i
return 1
else
return 2
end
end
end

@inline function idxin(I::Tuple{Int, Int, Int}, i::Int)
@inbounds begin
if I[1] == i
return 1
elseif I[2] == i
return 2
else
return 3
end
end
return nothing
end

#= For avoiding JET failures =#
error_on_no_name_found() = true

@inline function symidx(ax::AbstractAxis{I}, name::Symbol) where {I}
S = symbols(ax)
name == S[1] ? idxin(I, 1) :
name == S[2] ? idxin(I, 2) :
name == S[3] ? idxin(I, 3) : error("$ax has no symbol $name")
if name == S[1]
return idxin(I, 1)
elseif name == S[2]
return idxin(I, 2)
elseif name == S[3]
return idxin(I, 3)
elseif error_on_no_name_found()
error("$ax has no symbol $name")
else
return -1 # for type stability
end
end

# most of these are required for printing
Expand Down Expand Up @@ -222,9 +247,13 @@ const CartesianVector{T, I, S} = AxisVector{T, CartesianAxis{I}, S}
const LocalVector{T, I, S} = AxisVector{T, LocalAxis{I}, S}

Base.propertynames(x::AxisVector) = symbols(axes(x, 1))
function Base.getproperty(x::AxisVector, name::Symbol)
@inline function Base.getproperty(x::AxisVector, name::Symbol)
n = symidx(axes(x, 1), name)
isnothing(n) ? zero(eltype(x)) : components(x)[n]
if isnothing(n)
zero(eltype(x))
else
@inbounds components(x)[n]
end
end


Expand Down Expand Up @@ -349,7 +378,7 @@ function Base.:(-)(A::Axis2Tensor, b::LinearAlgebra.UniformScaling)
AxisTensor(axes(A), components(A) - b)
end

function _transform(
@inline function _transform(
ato::Ato,
x::AxisVector{T, Afrom, SVector{N, T}},
) where {Ato <: AbstractAxis{I}, Afrom <: AbstractAxis{I}} where {I, T, N}
Expand Down Expand Up @@ -417,7 +446,7 @@ end
return :(AxisVector(ato, SVector($(vals...))))
end

function _transform(
@inline function _transform(
ato::Ato,
x::Axis2Tensor{T, Tuple{Afrom, A2}},
) where {
Expand Down Expand Up @@ -508,10 +537,11 @@ end
))
end

transform(ato::CovariantAxis, v::CovariantTensor) = _transform(ato, v)
transform(ato::ContravariantAxis, v::ContravariantTensor) = _transform(ato, v)
transform(ato::CartesianAxis, v::CartesianTensor) = _transform(ato, v)
transform(ato::LocalAxis, v::LocalTensor) = _transform(ato, v)
@inline transform(ato::CovariantAxis, v::CovariantTensor) = _transform(ato, v)
@inline transform(ato::ContravariantAxis, v::ContravariantTensor) =
_transform(ato, v)
@inline transform(ato::CartesianAxis, v::CartesianTensor) = _transform(ato, v)
@inline transform(ato::LocalAxis, v::LocalTensor) = _transform(ato, v)

project(ato::CovariantAxis, v::CovariantTensor) = _project(ato, v)
project(ato::ContravariantAxis, v::ContravariantTensor) = _project(ato, v)
Expand Down
43 changes: 24 additions & 19 deletions src/Geometry/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,36 +133,41 @@ LocalVector(u::ContravariantVector{<:Any, (3,)}, ::LocalGeometry{(1, 2)}) =
AxisVector(WAxis(), components(u))


covariant1(u::AxisVector, local_geometry::LocalGeometry) =
@inline covariant1(u::AxisVector, local_geometry::LocalGeometry) =
CovariantVector(u, local_geometry).u₁
covariant2(u::AxisVector, local_geometry::LocalGeometry) =
@inline covariant2(u::AxisVector, local_geometry::LocalGeometry) =
CovariantVector(u, local_geometry).u₂
covariant3(u::AxisVector, local_geometry::LocalGeometry) =
@inline covariant3(u::AxisVector, local_geometry::LocalGeometry) =
CovariantVector(u, local_geometry).u₃

contravariant1(u::AxisVector, local_geometry::LocalGeometry) =
@inline contravariant1(u::AxisVector, local_geometry::LocalGeometry) =
transform(Contravariant123Axis(), u, local_geometry).
contravariant2(u::AxisVector, local_geometry::LocalGeometry) =
@inline contravariant2(u::AxisVector, local_geometry::LocalGeometry) =
transform(Contravariant123Axis(), u, local_geometry).
contravariant3(u::AxisVector, local_geometry::LocalGeometry) =
@inline contravariant3(u::AxisVector, local_geometry::LocalGeometry) =
transform(Contravariant123Axis(), u, local_geometry).

contravariant1(u::Axis2Tensor, local_geometry::LocalGeometry) =
transform(Contravariant123Axis(), u, local_geometry)[1, :]
contravariant2(u::Axis2Tensor, local_geometry::LocalGeometry) =
transform(Contravariant123Axis(), u, local_geometry)[2, :]
contravariant3(u::Axis2Tensor, local_geometry::LocalGeometry) =
transform(Contravariant123Axis(), u, local_geometry)[3, :]
@inline function contravariant1(u::Axis2Tensor, local_geometry::LocalGeometry)
@inbounds transform(Contravariant123Axis(), u, local_geometry)[1, :]
end
@inline function contravariant2(u::Axis2Tensor, local_geometry::LocalGeometry)
@inbounds transform(Contravariant123Axis(), u, local_geometry)[2, :]
end
@inline function contravariant3(u::Axis2Tensor, local_geometry::LocalGeometry)
@inbounds transform(Contravariant123Axis(), u, local_geometry)[3, :]
end

Jcontravariant3(u::AxisTensor, local_geometry::LocalGeometry) =
@inline Jcontravariant3(u::AxisTensor, local_geometry::LocalGeometry) =
local_geometry.J * contravariant3(u, local_geometry)

# required for curl-curl
covariant3(u::Contravariant3Vector, local_geometry::LocalGeometry{(1, 2)}) =
contravariant3(u, local_geometry)
@inline covariant3(
u::Contravariant3Vector,
local_geometry::LocalGeometry{(1, 2)},
) = contravariant3(u, local_geometry)

# workarounds for using a Covariant12Vector/Covariant123Vector in a UW space:
function LocalVector(
@inline function LocalVector(
vector::CovariantVector{<:Any, (1, 2, 3)},
local_geometry::LocalGeometry{(1, 3)},
)
Expand All @@ -171,23 +176,23 @@ function LocalVector(
u, w = components(transform(LocalAxis{(1, 3)}(), vector2, local_geometry))
return UVWVector(u, v, w)
end
function contravariant1(
@inline function contravariant1(
vector::CovariantVector{<:Any, (1, 2, 3)},
local_geometry::LocalGeometry{(1, 3)},
)
u₁, _, u₃ = components(vector)
vector2 = Covariant13Vector(u₁, u₃)
return transform(Contravariant13Axis(), vector2, local_geometry).
end
function contravariant3(
@inline function contravariant3(
vector::CovariantVector{<:Any, (1, 2)},
local_geometry::LocalGeometry{(1, 3)},
)
u₁, _ = components(vector)
vector2 = Covariant13Vector(u₁, zero(u₁))
return transform(Contravariant13Axis(), vector2, local_geometry).
end
function ContravariantVector(
@inline function ContravariantVector(
vector::CovariantVector{<:Any, (1, 2)},
local_geometry::LocalGeometry{(1, 3)},
)
Expand Down
2 changes: 1 addition & 1 deletion src/Geometry/localgeometry.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ struct LocalGeometry{I, C <: AbstractPoint, FT, S}
∂ξ∂x::Axis2Tensor{FT, Tuple{ContravariantAxis{I}, LocalAxis{I}}, S}
end

LocalGeometry(coordinates, J, WJ, ∂x∂ξ) =
@inline LocalGeometry(coordinates, J, WJ, ∂x∂ξ) =
LocalGeometry(coordinates, J, WJ, ∂x∂ξ, inv(∂x∂ξ))

"""
Expand Down
Loading

0 comments on commit 6f23efe

Please sign in to comment.