Skip to content

Commit

Permalink
enable more vdims
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrikekre committed May 19, 2023
1 parent c4c0d29 commit ff81ee5
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 22 deletions.
26 changes: 17 additions & 9 deletions src/FEValues/cell_values.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,12 @@ function CellValues(::Type{T}, qr::QR, ip::IP, gip::GIP = default_geometric_inte
return CellValues{IP, N_t, dNdx_t, dNdξ_t, M_t, dMdξ_t, QR, GIP}(qr, ip, gip)
end

# reinit! for regular (non-embedded) elements
function reinit!(cv::CellValues{<:Any, N_t, dNdx_t}, x::AbstractVector{Vec{dim,T}}) where {
dim, T,
N_t <: Union{Number, Vec{dim}},
dNdx_t <: Union{Vec{dim}, Tensor{2, dim}}
# reinit! for regular (non-embedded) elements (rdim == sdim)
function reinit!(cv::CellValues{<:Any, N_t, dNdx_t, dNdξ_t}, x::AbstractVector{Vec{dim,T}}) where {
dim, T, vdim,
N_t <: Union{Number, Vec{dim}, SVector{vdim} },
dNdx_t <: Union{Vec{dim}, Tensor{2, dim}, SMatrix{vdim, dim}},
dNdξ_t <: Union{Vec{dim}, Tensor{2, dim}, SMatrix{vdim, dim}},
}
n_geom_basefuncs = getngeobasefunctions(cv)
n_func_basefuncs = getnbasefunctions(cv)
Expand All @@ -156,14 +157,21 @@ function reinit!(cv::CellValues{<:Any, N_t, dNdx_t}, x::AbstractVector{Vec{dim,T
cv.detJdV[i] = detJ * w
Jinv = inv(fecv_J)
for j in 1:n_func_basefuncs
cv.dNdx[j, i] = cv.dNdξ[j, i] Jinv
# cv.dNdx[j, i] = cv.dNdξ[j, i] ⋅ Jinv
cv.dNdx[j, i] = dothelper(cv.dNdξ[j, i], Jinv)
end
end
end

# Hotfix to get the dots right for embedded elements until mixed tensors are merged.
@inline dothelper(x::SVector, A::SMatrix) = A' * x
@inline dothelper(B::SMatrix, A::SMatrix) = B * A
# Scalar/Vector interpolations with sdim == rdim (== vdim)
@inline dothelper(A, B) = A B
# Vector interpolations with sdim == rdim != vdim
@inline dothelper(A::SMatrix{vdim, dim}, B::Tensor{2, dim}) where {vdim, dim} = A * SMatrix{dim, dim}(B)
# Scalar interpolations with sdim > rdim
@inline dothelper(A::SVector{rdim}, B::SMatrix{rdim, sdim}) where {rdim, sdim} = B' * A
# Vector interpolations with sdim > rdim
@inline dothelper(B::SMatrix{vdim, rdim}, A::SMatrix{rdim, sdim}) where {vdim, rdim, sdim} = B * A

# Entrypoint for embedded `ScalarInterpolation`s (rdim < sdim)
function CellValues(::Type{T}, qr::QR, ip::IP, gip::VGIP) where {
Expand Down Expand Up @@ -238,7 +246,7 @@ Reinit for embedded elements, i.e. elements whose reference dimension is smaller
"""
function reinit!(cv::CellValues{<:Any, N_t, dNdx_t, dNdξ_t}, x::AbstractVector{Vec{sdim,T}}) where {
rdim, sdim, vdim, T,
N_t <: Union{Number, Vec{vdim}},
N_t <: Union{Number, SVector{vdim}},
dNdx_t <: Union{SVector{sdim, T}, SMatrix{vdim, sdim, T}},
dNdξ_t <: Union{SVector{rdim, T}, SMatrix{vdim, rdim, T}},
}
Expand Down
32 changes: 19 additions & 13 deletions test/test_cellvalues.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,16 +165,17 @@ end
end

@testset "Embedded elements" begin
@testset "Scalar on curves" begin
ue = [-1.5, 2.0]
ip = Lagrange{RefLine,1}()
@testset "Scalar/vector on curves (vdim = $vdim)" for vdim in (0, 1, 2, 3)
ip_base = Lagrange{RefLine,1}()
ip = vdim > 0 ? ip_base^vdim : ip_base
ue = 2 * rand(getnbasefunctions(ip))
qr = QuadratureRule{1,RefLine}(1)
# Reference values
csv1 = CellValues(qr, ip)
reinit!(csv1, [Vec((0.0,)), Vec((1.0,))])

## Consistency with 1D
csv2 = CellValues(qr, ip, ip^2)
## sdim = 2, Consistency with 1D
csv2 = CellValues(qr, ip, ip_base^2)
reinit!(csv2, [Vec((0.0, 0.0)), Vec((1.0, 0.0))])
# Test spatial interpolation
@test spatial_coordinate(csv2, 1, [Vec((0.0, 0.0)), Vec((1.0, 0.0))]) == Vec{2}((0.5, 0.0))
Expand All @@ -188,8 +189,8 @@ end
@test function_gradient(csv1, 1, ue)[1] == function_gradient(csv2, 1, ue)[1]
@test 0.0 == function_gradient(csv2, 1, ue)[2]

## Consistency with 1D
csv3 = CellValues(qr, ip, ip^3)
## sdim = 3, Consistency with 1D
csv3 = CellValues(qr, ip, ip_base^3)
reinit!(csv3, [Vec((0.0, 0.0, 0.0)), Vec((1.0, 0.0, 0.0))])
# Test spatial interpolation
@test spatial_coordinate(csv3, 1, [Vec((0.0, 0.0, 0.0)), Vec((1.0, 0.0, 0.0))]) == Vec{3}((0.5, 0.0, 0.0))
Expand All @@ -204,7 +205,7 @@ end
@test 0.0 == function_gradient(csv3, 1, ue)[2]
@test 0.0 == function_gradient(csv3, 1, ue)[3]

## Consistency in 2D
## sdim = 3, Consistency in 2D
reinit!(csv2, [Vec((-1.0, 2.0)), Vec((3.0, -4.0))])
reinit!(csv3, [Vec((-1.0, 2.0, 0.0)), Vec((3.0, -4.0, 0.0))])
# Test spatial interpolation
Expand All @@ -230,12 +231,13 @@ end
@test function_gradient(csv2, 1, ue)[2] == function_gradient(csv3, 1, ue)[3]
end

@testset "Scalar on surface" begin
ue = [-1.5, 2.0, 3.0, -1.0]
ip = Lagrange{RefQuadrilateral,1}()
@testset "Scalar/vector on surface (vdim = $vdim)" for vdim in (0, 1, 2, 3)
ip_base = Lagrange{RefQuadrilateral,1}()
ip = vdim > 0 ? ip_base^vdim : ip_base
ue = rand(getnbasefunctions(ip))
qr = QuadratureRule{2,RefQuadrilateral}(1)
csv2 = CellValues(qr, ip)
csv3 = CellValues(qr, ip, ip^3)
csv3 = CellValues(qr, ip, ip_base^3)
reinit!(csv2, [Vec((-1.0,-1.0)), Vec((1.0,-1.0)), Vec((1.0,1.0)), Vec((-1.0,1.0))])
reinit!(csv3, [Vec((-1.0,-1.0,0.0)), Vec((1.0,-1.0,0.0)), Vec((1.0,1.0,0.0)), Vec((-1.0,1.0,0.0))])
# Test spatial interpolation
Expand All @@ -247,7 +249,11 @@ end
@test function_value(csv2, 1, ue) == function_value(csv3, 1, ue)
@test function_gradient(csv2, 1, ue)[1] == function_gradient(csv3, 1, ue)[1]
@test function_gradient(csv2, 1, ue)[2] == function_gradient(csv3, 1, ue)[2]
@test 0.0 == function_gradient(csv3, 1, ue)[3]
if vdim != 2
@test 0.0 == function_gradient(csv3, 1, ue)[3]
else
@test_broken 0.0 == function_gradient(csv3, 1, ue)[3]
end
end
end

Expand Down

0 comments on commit ff81ee5

Please sign in to comment.