Skip to content
This repository has been archived by the owner on May 21, 2022. It is now read-only.

Commit

Permalink
make splitobs return tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
Evizero committed Mar 18, 2017
1 parent 94d5a54 commit 1ad70a2
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 39 deletions.
41 changes: 23 additions & 18 deletions src/accesspattern/datasubset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -575,12 +575,13 @@ Split the `data` into multiple subsets proportional to the
value(s) of `at`.
Note that this function will perform the splits statically and
thus not perform any randomization. The function creates a vector
`DataSubset` in which the first N-1 elements/subsets contain the
fraction of observations of `data` that is specified by `at`.
thus not perform any randomization. The function creates a
`NTuple` of data subsets in which the first N-1 elements/subsets
contain the fraction of observations of `data` that is specified
by `at`.
For example, if `at` is a `Float64` then the return-value will be
a vector with two elements (i.e. subsets), in which the first
a tuple with two elements (i.e. subsets), in which the first
element contains the fraction of observations specified by `at`
and the second element contains the rest. In the following code
the first subset `train` will contain the first 70% of the
Expand All @@ -598,9 +599,9 @@ observations, `val` will have next 30%, and `test` the last 20%
train, val, test = splitobs(X, at = (0.5, 0.3))
```
It is also possible to call it with multiple data arguments as
tuple, which all must have the same number of total observations.
This is useful for labeled data.
It is also possible to call `splitobs` with multiple data
arguments as a tuple, all of which must have the same total
number of observations. This is useful for labeled data.
```julia
train, test = splitobs((X, y), at = 0.7)
Expand Down Expand Up @@ -639,21 +640,25 @@ splitobs(data; at = 0.7, obsdim = default_obsdim(data)) =

# partition into 2 sets
# The first set holds roughly the fraction `at` of the observations (counted
# from the start), the second set holds the remaining observations.
# Throws an ArgumentError unless `at` lies strictly inside (0, 1).
# `obsdim` is passed through unchanged to `nobs` and `datasubset`.
function splitobs(data, at::AbstractFloat, obsdim=default_obsdim(data))
0 < at < 1 || throw(ArgumentError("The parameter \"at\" must be in interval (0, 1)"))
n = nobs(data, obsdim)
# size of the first set: at*n rounded to an Int and kept within 1..n
n1 = clamp(round(Int, at*n), 1, n)
# NOTE(review): this page is a rendered diff without +/- markers; the
# comprehension below looks like the pre-change, Vector-returning line.
# As written here its value is discarded — confirm against the commit.
[datasubset(data, idx, obsdim) for idx in (1:n1, n1+1:n)]
# result is a 2-tuple: observations 1:n1 and observations n1+1:n
# (the second range is empty when n1 == n)
datasubset(data, 1:n1, obsdim), datasubset(data, n1+1:n, obsdim)
end

# partition into length(at)+1 sets
# NOTE(review): this span is a rendered diff that interleaves two versions of
# the method, so it is not well-formed Julia as a whole. The lines from the
# first `function` header down to the `@generated` header look like the start
# of the pre-change, Vector-returning method — confirm against the commit.
function splitobs{T<:AbstractFloat}(data, at::NTuple{T}, obsdim=default_obsdim(data))
n = nobs(data, obsdim)
nleft = n
lst = UnitRange{Int}[]
for (i,sz) in enumerate(at)
ni = clamp(round(Int, sz*n), 0, nleft)
push!(lst, n-nleft+1:n-nleft+ni)
nleft -= ni
# Generated method: the tuple length `N` is a type parameter here, so the
# returned expression can spell out a tuple with exactly N+1 elements.
@generated function splitobs{N,T<:AbstractFloat}(data, at::NTuple{N,T}, obsdim=default_obsdim(data))
quote
# every fraction must be positive and the fractions must sum to less than 1
(all(map(x->x>0, at)) && sum(at) < 1) || throw(ArgumentError("All elements in \"at\" must be positive and their sum must be smaller than 1"))
n = nobs(data, obsdim)
# number of observations not yet assigned to a subset
nleft = n
# index ranges of the subsets, in order
lst = UnitRange{Int}[]
for (i,sz) in enumerate(at)
# size of this subset, capped by the number of observations still unassigned
ni = clamp(round(Int, sz*n), 0, nleft)
# range starts right after the observations that were already assigned
push!(lst, n-nleft+1:n-nleft+ni)
nleft -= ni
end
# the final subset takes whatever is left over
push!(lst, n-nleft+1:n)
# splice in a tuple of N+1 `datasubset` calls, one per range in `lst`
$(Expr(:tuple, (:(datasubset(data, lst[$i], obsdim)) for i in 1:N+1)...))
end
# NOTE(review): the next two lines look like the tail of the pre-change
# method (Vector result built by a comprehension) — confirm against the commit.
push!(lst, n-nleft+1:n)
[datasubset(data, idx, obsdim) for idx in lst]
end
50 changes: 29 additions & 21 deletions test/tst_datasubset.jl
Original file line number Diff line number Diff line change
Expand Up @@ -813,26 +813,34 @@ println("<HEARTBEAT>")

@testset "typestability" begin
for var in vars
@test_throws ArgumentError splitobs(var, 0.)
@test_throws ArgumentError splitobs(var, 1.)
@test_throws ArgumentError splitobs(var, (0.2,0.0))
@test_throws ArgumentError splitobs(var, (0.2,0.8))
@test_throws MethodError splitobs(var, 0.5, ObsDim.Undefined())
@test typeof(@inferred(splitobs(var))) <: Vector
@test typeof(@inferred(splitobs(var))) <: NTuple{2}
@test eltype(@inferred(splitobs(var))) <: SubArray
@test typeof(@inferred(splitobs(var, 0.5))) <: Vector
@test typeof(@inferred(splitobs(var, (0.5,0.2)))) <: Vector
@test typeof(@inferred(splitobs(var, 0.5))) <: NTuple{2}
@test typeof(@inferred(splitobs(var, (0.5,0.2)))) <: NTuple{3}
@test eltype(@inferred(splitobs(var, 0.5))) <: SubArray
@test eltype(@inferred(splitobs(var, (0.5,0.2)))) <: SubArray
@test typeof(@inferred(splitobs(var, 0.5, ObsDim.Last()))) <: Vector
@test typeof(@inferred(splitobs(var, 0.5, ObsDim.First()))) <: Vector
@test typeof(@inferred(splitobs(var, 0.5, ObsDim.Last()))) <: NTuple{2}
@test typeof(@inferred(splitobs(var, 0.5, ObsDim.First()))) <: NTuple{2}
@test eltype(@inferred(splitobs(var, 0.5, ObsDim.First()))) <: SubArray
@test_throws ErrorException @inferred(splitobs(var, at=0.5))
@test_throws ErrorException @inferred(splitobs(var, obsdim=:last))
@test_throws ErrorException @inferred(splitobs(var, obsdim=1))
end
for tup in tuples
@test typeof(@inferred(splitobs(tup, 0.5))) <: Vector
@test typeof(@inferred(splitobs(tup, (0.5,0.2)))) <: Vector
@test_throws ArgumentError splitobs(tup, 0.)
@test_throws ArgumentError splitobs(tup, 1.)
@test_throws ArgumentError splitobs(tup, (0.2,0.0))
@test_throws ArgumentError splitobs(tup, (0.2,0.8))
@test typeof(@inferred(splitobs(tup, 0.5))) <: NTuple{2}
@test typeof(@inferred(splitobs(tup, (0.5,0.2)))) <: NTuple{3}
@test eltype(@inferred(splitobs(tup, 0.5))) <: Tuple
@test eltype(@inferred(splitobs(tup, (0.5,0.2)))) <: Tuple
@test typeof(@inferred(splitobs(tup, 0.5, ObsDim.Last()))) <: Vector
@test typeof(@inferred(splitobs(tup, 0.5, ObsDim.Last()))) <: NTuple{2}
@test eltype(@inferred(splitobs(tup, 0.5, ObsDim.Last()))) <: Tuple
@test_throws ErrorException @inferred(splitobs(tup, obsdim=:last))
end
Expand All @@ -843,12 +851,12 @@ println("<HEARTBEAT>")
@test splitobs(var) == splitobs(var, 0.7, ObsDim.Last())
@test splitobs(var, at=0.5) == splitobs(var, 0.5, ObsDim.Last())
@test splitobs(var, obsdim=1) == splitobs(var, 0.7, ObsDim.First())
@test nobs.(splitobs(var)) == [105,45]
@test nobs.(splitobs(var, at=(.2,.3))) == [30,45,75]
@test nobs.(splitobs(var, at=(.2,.3), obsdim=:last)) == [30,45,75]
@test nobs.(splitobs(var, at=(.1,.2,.3))) == [15,30,45,60]
@test nobs.(splitobs(var)) == (105,45)
@test nobs.(splitobs(var, at=(.2,.3))) == (30,45,75)
@test nobs.(splitobs(var, at=(.2,.3), obsdim=:last)) == (30,45,75)
@test nobs.(splitobs(var, at=(.1,.2,.3))) == (15,30,45,60)
end
@test nobs.(splitobs(X', obsdim=1),obsdim=1) == [105,45]
@test nobs.(splitobs(X', obsdim=1),obsdim=1) == (105,45)
# tests if all obs are still present and none duplicated
@test sum(vec.(sum.(getobs.(splitobs(sparse(X1))),2))) == fill(11325,10)
@test sum(vec.(sum.(splitobs(X1),2))) == fill(11325,10)
Expand All @@ -857,9 +865,9 @@ println("<HEARTBEAT>")
@test sum(vec.(sum.(splitobs(X1,at=(.1,.4,.2)),2))) == fill(11325,10)
@test sum(vec.(sum.(getobs.(splitobs(sparse(X1),at=(.2,.1))),2))) == fill(11325,10)
@test sum(vec.(sum.(splitobs(X1',obsdim=1),1))) == fill(11325,10)
@test sum.(splitobs(Y1)) == [5565, 5760]
@test sum.(getobs.(splitobs(sparse(Y1)))) == [5565, 5760]
@test sum.(splitobs(Y1, obsdim=:first)) == [5565, 5760]
@test sum.(splitobs(Y1)) == (5565, 5760)
@test sum.(getobs.(splitobs(sparse(Y1)))) == (5565, 5760)
@test sum.(splitobs(Y1, obsdim=:first)) == (5565, 5760)
end

println("<HEARTBEAT>")
Expand All @@ -871,12 +879,12 @@ println("<HEARTBEAT>")
@test_throws MethodError splitobs(tup...)
@test all(map(_->(typeof(_)<:Tuple), splitobs(tup)))
@test all(map(_->(typeof(_)<:Tuple), splitobs(tup,at=0.5)))
@test nobs.(splitobs(tup)) == [105,45]
@test nobs.(splitobs(tup, at=(.2,.3))) == [30,45,75]
@test nobs.(splitobs(tup, at=(.2,.3), obsdim=:last)) == [30,45,75]
@test nobs.(splitobs(tup, at=(.1,.2,.3))) == [15,30,45,60]
@test nobs.(splitobs(tup)) == (105,45)
@test nobs.(splitobs(tup, at=(.2,.3))) == (30,45,75)
@test nobs.(splitobs(tup, at=(.2,.3), obsdim=:last)) == (30,45,75)
@test nobs.(splitobs(tup, at=(.1,.2,.3))) == (15,30,45,60)
end
@test nobs.(splitobs((X',y), obsdim=1),obsdim=1) == [105,45]
@test nobs.(splitobs((X',y), obsdim=1),obsdim=1) == (105,45)
# tests if all obs are still present and none duplicated
# also tests that both parameters are split disjointly
train,test = splitobs((X1,Y1,X1))
Expand Down

0 comments on commit 1ad70a2

Please sign in to comment.