Skip to content
This repository has been archived by the owner on May 21, 2022. It is now read-only.

Commit

Permalink
add new signature splitobs(n::Int,...)
Browse files Browse the repository at this point in the history
allows to precompute index assignments witout a dataset
  • Loading branch information
Evizero committed Apr 5, 2017
1 parent f7994ed commit 1abca77
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 19 deletions.
57 changes: 38 additions & 19 deletions src/splitobs.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,36 @@
"""
splitobs(n::Int, [at = 0.7]) -> Tuple
TODO
"""
splitobs(n::Int; at = 0.7) = splitobs(n, at)

# partition into 2 sets
function splitobs(n::Int, at::AbstractFloat)
0 < at < 1 || throw(ArgumentError("the parameter \"at\" must be in interval (0, 1)"))
n1 = clamp(round(Int, at*n), 1, n)
(1:n1, n1+1:n)
end

# has to be outside the generated function
_ispos(x) = x > 0
# partition into length(at)+1 sets
# we use @generated because we compute "N+1"
@generated function splitobs{N}(n::Int, at::NTuple{N,AbstractFloat})
quote
(all(map(_ispos, at)) && sum(at) < 1) || throw(ArgumentError("all elements in \"at\" must be positive and their sum must be smaller than 1"))
nleft = n
lst = UnitRange{Int}[]
for (i, sz) in enumerate(at)
ni = clamp(round(Int, sz*n), 0, nleft)
push!(lst, n-nleft+1:n-nleft+ni)
nleft -= ni
end
push!(lst, n-nleft+1:n)
$(Expr(:tuple, (:(lst[$i]) for i in 1:N+1)...))
end
end

"""
splitobs(data, [at = 0.7], [obsdim])
Expand Down Expand Up @@ -70,27 +103,13 @@ splitobs(data; at = 0.7, obsdim = default_obsdim(data)) =

# partition into 2 sets
function splitobs(data, at::AbstractFloat, obsdim=default_obsdim(data))
0 < at < 1 || throw(ArgumentError("the parameter \"at\" must be in interval (0, 1)"))
n = nobs(data, obsdim)
n1 = clamp(round(Int, at*n), 1, n)
datasubset(data, 1:n1, obsdim), datasubset(data, n1+1:n, obsdim)
idx1, idx2 = splitobs(n, at)
datasubset(data, idx1, obsdim), datasubset(data, idx2, obsdim)
end

# has to be outside the generated function
_ispos(x) = x > 0
# partition into length(at)+1 sets
@generated function splitobs{N,T<:AbstractFloat}(data, at::NTuple{N,T}, obsdim=default_obsdim(data))
quote
(all(map(_ispos, at)) && sum(at) < 1) || throw(ArgumentError("all elements in \"at\" must be positive and their sum must be smaller than 1"))
n = nobs(data, obsdim)
nleft = n
lst = UnitRange{Int}[]
for (i,sz) in enumerate(at)
ni = clamp(round(Int, sz*n), 0, nleft)
push!(lst, n-nleft+1:n-nleft+ni)
nleft -= ni
end
push!(lst, n-nleft+1:n)
$(Expr(:tuple, (:(datasubset(data, lst[$i], obsdim)) for i in 1:N+1)...))
end
function splitobs{N,T<:AbstractFloat}(data, at::NTuple{N,T}, obsdim=default_obsdim(data))
n = nobs(data, obsdim)
map(idx->datasubset(data, idx, obsdim), splitobs(n, at))
end
35 changes: 35 additions & 0 deletions test/tst_splitobs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,20 @@
@test_throws DimensionMismatch splitobs((X, rand(149)), obsdim=:last)

@testset "typestability" begin
@testset "Int" begin
@test_throws ArgumentError splitobs(10, 0.)
@test_throws ArgumentError splitobs(10, 1.)
@test_throws ArgumentError splitobs(10, (0.2,0.0))
@test_throws ArgumentError splitobs(10, (0.2,0.8))
@test_throws MethodError splitobs(10, 0.5, ObsDim.Undefined())
@test typeof(@inferred(splitobs(10))) <: NTuple{2}
@test eltype(@inferred(splitobs(10))) <: UnitRange
@test typeof(@inferred(splitobs(10, 0.5))) <: NTuple{2}
@test typeof(@inferred(splitobs(10, (0.5,0.2)))) <: NTuple{3}
@test eltype(@inferred(splitobs(10, 0.5))) <: UnitRange
@test eltype(@inferred(splitobs(10, (0.5,0.2)))) <: UnitRange
@test_throws ErrorException @inferred(splitobs(10, at=0.5))
end
for var in vars
@test_throws ArgumentError splitobs(var, 0.)
@test_throws ArgumentError splitobs(var, 1.)
Expand Down Expand Up @@ -36,10 +50,31 @@
end
end

@testset "Int" begin
@test splitobs(10) == (1:7,8:10)
@test splitobs(10, 0.5) == (1:5,6:10)
@test splitobs(10, (0.5,0.3)) == (1:5,6:8,9:10)
@test splitobs(150) == splitobs(150, 0.7)
@test splitobs(150, at=0.5) == splitobs(150, 0.5)
@test splitobs(150, at=(0.5,0.2)) == splitobs(150, (0.5,0.2))
@test nobs.(splitobs(150)) == (105,45)
@test nobs.(splitobs(150, at=(.2,.3))) == (30,45,75)
@test nobs.(splitobs(150, at=(.1,.2,.3))) == (15,30,45,60)
# tests if all obs are still present and none duplicated
@test sum(sum.(getobs.(splitobs(150)))) == 11325
@test sum(sum.(splitobs(150,at=.1))) == 11325
@test sum(sum.(splitobs(150,at=(.2,.1)))) == 11325
@test sum(sum.(splitobs(150,at=(.1,.4,.2)))) == 11325
@test sum.(splitobs(150)) == (5565, 5760)
end

println("<HEARTBEAT>")

@testset "Array, SparseArray, and SubArray" begin
for var in (Xs, ys, vars...)
@test splitobs(var) == splitobs(var, 0.7, ObsDim.Last())
@test splitobs(var, at=0.5) == splitobs(var, 0.5, ObsDim.Last())
@test splitobs(var, at=(0.5,0.2)) == splitobs(var, (0.5,0.2), ObsDim.Last())
@test splitobs(var, obsdim=1) == splitobs(var, 0.7, ObsDim.First())
@test nobs.(splitobs(var)) == (105,45)
@test nobs.(splitobs(var, at=(.2,.3))) == (30,45,75)
Expand Down

0 comments on commit 1abca77

Please sign in to comment.