Skip to content

Commit

Permalink
Add least_dims
Browse files Browse the repository at this point in the history
  • Loading branch information
yuehhua committed Dec 28, 2020
1 parent 511e4c9 commit ab03446
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,14 @@
Safely divide `x` by `y`. If `y` is zero, return `x` directly.
"""
safe_div(x, y) = ifelse(iszero(y), x, x/y)

"""
least_dims(idxs)
Compute the least dimensions, of which array can be accessed by the indecies `idxs`.
"""
least_dims(idxs::AbstractArray{<:Integer}) = (maximum(idxs), )

function least_dims(idxs::AbstractArray{<:Tuple})
Tuple(maximum(xs) for xs in zip(idxs...))
end
9 changes: 9 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
@testset "least_dims" begin
ind1 = [1,2,3,4,5,6]
@test NNlib.least_dims(ind1) == (6,)
ind2 = [(3,4,5), (1,2,3), (2,3,9)]
@test NNlib.least_dims(ind2) == (3,4,9)
ind3 = [(3,4,5) (1,2,3) (2,3,9);
(4,6,2) (5,3,2) (4,4,4)]
@test NNlib.least_dims(ind3) == (5,6,9)
end

0 comments on commit ab03446

Please sign in to comment.