Skip to content

Commit

Permalink
Support bitstypes
Browse files Browse the repository at this point in the history
  • Loading branch information
Affie committed May 11, 2023
1 parent 10aab94 commit 5fe29e5
Showing 1 changed file with 38 additions and 1 deletion.
39 changes: 38 additions & 1 deletion src/Interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,43 @@ function setPointPartial!(Mdest::AbstractManifold,
return dest
end

function setPointPartial!(
Mdest::AbstractManifold,
dest::AbstractArray{T},
Msrc::AbstractManifold,
src::AbstractArray{U},
partial::AbstractVector{<:Integer},
destIdx,
srcIdx=destIdx,
asPartial::Bool=true
) where {T<:AbstractArray,U<:AbstractArray}

if isbitstype(T)
#TODO needs cleanup, this is copied from setPointPartial! above with index changes
if length(partial) == 0
return dest[destIdx]
end
dest_coords = collect(AMP.makeCoordsFromPoint(Mdest, dest[destIdx]))
src_coords = AMP.makeCoordsFromPoint(Msrc, src[srcIdx])
dest_coords[partial] .= asPartial ? src_coords : view(src_coords, partial)
return dest[destIdx] = makePointFromCoords(Mdest, dest_coords, dest[destIdx])

else
return setPointPartial!(Mdest, dest[destIdx], Msrc, src[srcIdx], partial, asPartial)
end

end

#TODO workaround for supporting bitstypes, need rewrite, can consider `PowerManifoldNestedReplacing` or similar
function setPointsMani!(dest::AbstractArray{T}, src::AbstractArray{U}, destIdx, srcIdx=destIdx) where {T<:AbstractArray,U<:AbstractArray}
if isbitstype(T)
dest[destIdx] = src[srcIdx]
else
setPointsMani!(dest[destIdx],src[srcIdx])
end
end


#TODO ArrayPartition should work for now as it's an AbstractVector, but it won't remain mutable
setPointsMani!(dest::AbstractVector, src::AbstractVector) = (dest .= src)
setPointsMani!(dest::AbstractMatrix, src::AbstractMatrix) = (dest .= src)
Expand Down Expand Up @@ -165,7 +202,7 @@ function Base.replace( dest::ManifoldKernelDensity{M,<:BallTreeDensity,Nothing},
ipc[pl] .= src.infoPerCoord[pl]

# and _u0 point is a bit more tricky
c0 = vee(dest.manifold, dest._u0, log(dest.manifold, dest._u0, dest._u0))
c0 = collect(vee(dest.manifold, dest._u0, log(dest.manifold, dest._u0, dest._u0)))
c_ = vee(dest.manifold, dest._u0, log(dest.manifold, dest._u0, src._u0))
c0[pl] .= c_[pl]
u0 = exp(dest.manifold, dest._u0, hat(dest.manifold, dest._u0, c0))
Expand Down

0 comments on commit 5fe29e5

Please sign in to comment.