Skip to content

Commit

Permalink
Fix #64 cat bug and support merging non-vector axes
Browse files Browse the repository at this point in the history
  • Loading branch information
Gord Stephen authored and timholy committed Mar 25, 2017
1 parent 8f8272d commit 5255e09
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/combine.jl
Expand Up @@ -11,7 +11,7 @@ end #equalvalued

sizes{T<:AxisArray}(As::T...) = tuple(zip(map(a -> map(length, indices(a)), As)...)...)
matchingdims{N,T<:AxisArray}(As::NTuple{N,T}) = all(equalvalued, sizes(As...))
matchingdimsexcept{N,T<:AxisArray}(As::NTuple{N,T}, n::Int) = all(equalvalued, sizes(As[[1:n-1; n+1:end]]...))
matchingdimsexcept{N,T<:AxisArray}(As::NTuple{N,T}, n::Int) = all(equalvalued, sizes(As...)[[1:n-1; n+1:end]])

function Base.cat{T}(n::Integer, As::AxisArray{T}...)
if n <= ndims(As[1])
Expand Down Expand Up @@ -53,7 +53,7 @@ function combineaxes{T,N,D,Ax}(method::Symbol, As::AxisArray{T,N,D,Ax}...)
return resultaxes, resultaxeslengths, axismaps
end #combineaxes

function mergevalues{T}(values::Tuple{Vararg{Vector{T}}}, method::Symbol)
function mergevalues{T}(values::Tuple{Vararg{AbstractVector{T}}}, method::Symbol)
if method == :inner
intersect(values...)
elseif method == :left
Expand Down
14 changes: 13 additions & 1 deletion test/combine.jl
@@ -1,3 +1,4 @@
# cat
A1data, A2data = [1 3; 2 4], [5 7; 6 8]

A1 = AxisArray(A1data, Axis{:Row}([:First, :Second]), Axis{:Col}([:A, :B]))
Expand All @@ -17,7 +18,13 @@ A2 = AxisArray(A2data, Axis{:Row}([:First, :Second]), Axis{:Col}([:A, :B]))
Axis{:Row}([:First, :Second]), Axis{:Col}([:A, :B]),
Axis{:page}(1:2))

Adata, Bdata = randn(4,4,2), randn(4,4,2)
A1 = AxisArray(A1data, :Row, :Col)
A2 = AxisArray(A2data, :Row, :Col)
@test_throws ArgumentError cat(2, A1, A2)
@test cat(3, A1, A2) == AxisArray(cat(3, A1data, A2data), :Row, :Col)

# merge
Adata, Bdata, Cdata = randn(4,4,2), randn(4,4,2), randn(4,4,2)
A = AxisArray(Adata, Axis{:X}([1,2,3,4]), Axis{:Y}([10.,20,30,40]), Axis{:Z}([:First, :Second]))
B = AxisArray(Bdata, Axis{:X}([3,4,5,6]), Axis{:Y}([30.,40,50,60]), Axis{:Z}([:First, :Second]))

Expand All @@ -26,6 +33,11 @@ ABdata[1:4,1:4,:] = Adata
ABdata[3:6,3:6,:] = Bdata
@test merge(A,B) == AxisArray(ABdata, Axis{:X}([1,2,3,4,5,6]), Axis{:Y}([10.,20,30,40,50,60]), Axis{:Z}([:First, :Second]))

AC = AxisArray(cat(3, Adata, Cdata), :X, :Y, :Z)
B2 = AxisArray(Bdata, :X, :Y, :Z)
@test merge(AC,B2) == AxisArray(cat(3, Bdata, Cdata), :X, :Y, :Z)

# join
ABdata = zeros(6,6,2,2)
ABdata[1:4,1:4,:,1] = Adata
ABdata[3:6,3:6,:,2] = Bdata
Expand Down

0 comments on commit 5255e09

Please sign in to comment.