diff --git a/src/splitobs.jl b/src/splitobs.jl
index 7f1c4e9..47b58c9 100644
--- a/src/splitobs.jl
+++ b/src/splitobs.jl
@@ -1,3 +1,60 @@
+"""
+    splitobs(n::Int, [at = 0.7]) -> Tuple
+
+Compute the index ranges that partition the observation indices
+`1:n` into consecutive, non-overlapping subsets. No data is
+accessed; the ranges are returned as a tuple of `UnitRange{Int}`.
+
+- If `at` is a float in the interval `(0, 1)`, two ranges are
+  returned. The first covers `round(Int, at*n)` indices (clamped
+  to at least 1 and at most `n`), the second covers the rest.
+
+- If `at` is a tuple of positive floats whose sum is smaller than
+  1, `length(at)+1` ranges are returned. The i-th range covers
+  `round(Int, at[i]*n)` indices (clamped to what is still left),
+  and the final range covers all remaining indices.
+
+An `ArgumentError` is thrown if `at` violates these constraints.
+
+```julia
+splitobs(10)             # (1:7, 8:10)
+splitobs(10, 0.5)        # (1:5, 6:10)
+splitobs(10, (0.5, 0.3)) # (1:5, 6:8, 9:10)
+splitobs(10, at = 0.5)   # same as splitobs(10, 0.5)
+```
+"""
+splitobs(n::Int; at = 0.7) = splitobs(n, at)
+
+# partition into 2 sets
+function splitobs(n::Int, at::AbstractFloat)
+    0 < at < 1 || throw(ArgumentError("the parameter \"at\" must be in interval (0, 1)"))
+    # clamp so that the first set holds at least 1 and at most n indices
+    n1 = clamp(round(Int, at*n), 1, n)
+    (1:n1, n1+1:n)
+end
+
+# has to be outside the generated function
+_ispos(x) = x > 0
+# partition into length(at)+1 sets
+# we use @generated because we compute "N+1"
+@generated function splitobs{N}(n::Int, at::NTuple{N,AbstractFloat})
+    quote
+        (all(map(_ispos, at)) && sum(at) < 1) || throw(ArgumentError("all elements in \"at\" must be positive and their sum must be smaller than 1"))
+        nleft = n # number of indices not yet assigned to a set
+        lst = UnitRange{Int}[]
+        for sz in at
+            # never hand out more indices than are left
+            ni = clamp(round(Int, sz*n), 0, nleft)
+            push!(lst, n-nleft+1:n-nleft+ni)
+            nleft -= ni
+        end
+        # the last set receives whatever remains
+        push!(lst, n-nleft+1:n)
+        # expands to the tuple (lst[1], ..., lst[N+1]) when the method is generated
+        $(Expr(:tuple, (:(lst[$i]) for i in 1:N+1)...))
+    end
+end
+
 """
     splitobs(data, [at = 0.7], [obsdim])
 
@@ -70,27 +127,15 @@ splitobs(data; at = 0.7, obsdim = default_obsdim(data)) =
 
 # partition into 2 sets
 function splitobs(data, at::AbstractFloat, obsdim=default_obsdim(data))
-    0 < at < 1 || throw(ArgumentError("the parameter \"at\" must be in interval (0, 1)"))
     n = nobs(data, obsdim)
-    n1 = clamp(round(Int, at*n), 1, n)
-    datasubset(data, 1:n1, obsdim), datasubset(data, n1+1:n, obsdim)
+    # validation of "at" and the index computation happen in splitobs(n, at)
+    idx1, idx2 = splitobs(n, at)
+    datasubset(data, idx1, obsdim), datasubset(data, idx2, obsdim)
 end
 
-# has to be outside the generated function
-_ispos(x) = x > 0
 # partition into length(at)+1 sets
-@generated function splitobs{N,T<:AbstractFloat}(data, at::NTuple{N,T}, obsdim=default_obsdim(data))
-    quote
-        (all(map(_ispos, at)) && sum(at) < 1) || throw(ArgumentError("all elements in \"at\" must be positive and their sum must be smaller than 1"))
-        n = nobs(data, obsdim)
-        nleft = n
-        lst = UnitRange{Int}[]
-        for (i,sz) in enumerate(at)
-            ni = clamp(round(Int, sz*n), 0, nleft)
-            push!(lst, n-nleft+1:n-nleft+ni)
-            nleft -= ni
-        end
-        push!(lst, n-nleft+1:n)
-        $(Expr(:tuple, (:(datasubset(data, lst[$i], obsdim)) for i in 1:N+1)...))
-    end
+function splitobs{N,T<:AbstractFloat}(data, at::NTuple{N,T}, obsdim=default_obsdim(data))
+    n = nobs(data, obsdim)
+    # one datasubset per index range computed by splitobs(n, at)
+    map(idx->datasubset(data, idx, obsdim), splitobs(n, at))
 end
diff --git a/test/tst_splitobs.jl b/test/tst_splitobs.jl
index 5434e53..453dbb1 100644
--- a/test/tst_splitobs.jl
+++ b/test/tst_splitobs.jl
@@ -2,6 +2,20 @@
 @test_throws DimensionMismatch splitobs((X, rand(149)), obsdim=:last)
 
 @testset "typestability" begin
+    @testset "Int" begin
+        @test_throws ArgumentError splitobs(10, 0.)
+        @test_throws ArgumentError splitobs(10, 1.)
+        @test_throws ArgumentError splitobs(10, (0.2,0.0))
+        @test_throws ArgumentError splitobs(10, (0.2,0.8))
+        @test_throws MethodError splitobs(10, 0.5, ObsDim.Undefined())
+        @test typeof(@inferred(splitobs(10))) <: NTuple{2}
+        @test eltype(@inferred(splitobs(10))) <: UnitRange
+        @test typeof(@inferred(splitobs(10, 0.5))) <: NTuple{2}
+        @test typeof(@inferred(splitobs(10, (0.5,0.2)))) <: NTuple{3}
+        @test eltype(@inferred(splitobs(10, 0.5))) <: UnitRange
+        @test eltype(@inferred(splitobs(10, (0.5,0.2)))) <: UnitRange
+        @test_throws ErrorException @inferred(splitobs(10, at=0.5))
+    end
     for var in vars
         @test_throws ArgumentError splitobs(var, 0.)
         @test_throws ArgumentError splitobs(var, 1.)
@@ -36,10 +50,34 @@
     end
 end
 
+# `splitobs` applied to a plain number of observations yields index ranges
+@testset "Int" begin
+    @test splitobs(10) == (1:7, 8:10)
+    @test splitobs(10, 0.5) == (1:5, 6:10)
+    @test splitobs(10, (0.5, 0.3)) == (1:5, 6:8, 9:10)
+    # the default, positional and keyword forms of "at" have to agree
+    @test splitobs(150) == splitobs(150, 0.7)
+    @test splitobs(150, at = 0.5) == splitobs(150, 0.5)
+    @test splitobs(150, at = (0.5, 0.2)) == splitobs(150, (0.5, 0.2))
+    # sizes of the resulting partitions
+    @test nobs.(splitobs(150)) == (105, 45)
+    @test nobs.(splitobs(150, at = (0.2, 0.3))) == (30, 45, 75)
+    @test nobs.(splitobs(150, at = (0.1, 0.2, 0.3))) == (15, 30, 45, 60)
+    # checks that all obs are still present and none duplicated (sum(1:150) == 11325)
+    @test sum(sum.(getobs.(splitobs(150)))) == 11325
+    @test sum(sum.(splitobs(150, at = 0.1))) == 11325
+    @test sum(sum.(splitobs(150, at = (0.2, 0.1)))) == 11325
+    @test sum(sum.(splitobs(150, at = (0.1, 0.4, 0.2)))) == 11325
+    @test sum.(splitobs(150)) == (5565, 5760)
+end
+
+println("")
+
 @testset "Array, SparseArray, and SubArray" begin
     for var in (Xs, ys, vars...)
         @test splitobs(var) == splitobs(var, 0.7, ObsDim.Last())
         @test splitobs(var, at=0.5) == splitobs(var, 0.5, ObsDim.Last())
+        @test splitobs(var, at = (0.5, 0.2)) == splitobs(var, (0.5, 0.2), ObsDim.Last())
         @test splitobs(var, obsdim=1) == splitobs(var, 0.7, ObsDim.First())
         @test nobs.(splitobs(var)) == (105,45)
         @test nobs.(splitobs(var, at=(.2,.3))) == (30,45,75)