From 125205bbeacfe3c8a2dd1ece2b5a188f4850bafc Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Fri, 28 Oct 2022 06:01:20 +0200 Subject: [PATCH] oversample and undersample always return classes as well (#116) --- Project.toml | 2 +- src/resample.jl | 48 ++++++++++++++++++++++++++---------------------- test/resample.jl | 10 +++++++--- 3 files changed, 34 insertions(+), 26 deletions(-) diff --git a/Project.toml b/Project.toml index 82ec98f..33e2a80 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLUtils" uuid = "f1d291b0-491e-4a28-83b9-f70985020b54" authors = ["Carlo Lucibello and contributors"] -version = "0.2.12" +version = "0.3.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/resample.jl b/src/resample.jl index 4c1805e..a9fa6c3 100644 --- a/src/resample.jl +++ b/src/resample.jl @@ -21,6 +21,8 @@ resulting data will be shuffled after its creation; if it is not shuffled then all the repeated samples will be together at the end, sorted by class. Defaults to `true`. +The output will contain both the resampled data and classes. + ```julia # 6 observations with 3 features each X = rand(3, 6) @@ -40,14 +42,7 @@ X_bal, Y_bal = oversample(X, Y) ``` For this function to work, the type of `data` must implement -[`numobs`](@ref) and [`getobs`](@ref). For example, the following -code allows `oversample` to work on a `DataFrame`. - -```julia -# Make DataFrames.jl work -MLUtils.getobs(data::DataFrame, i) = data[i,:] -MLUtils.numobs(data::DataFrame) = nrow(data) -``` +[`numobs`](@ref) and [`getobs`](@ref). Note that if `data` is a tuple and `classes` is not given, then it will be assumed that the last element of the tuple contains the classes. @@ -98,16 +93,22 @@ function oversample(data, classes; fraction=1, shuffle::Bool=true) append!(inds, inds_for_lbl) end if num_extra_needed > 0 - append!(inds, sample(inds_for_lbl, num_extra_needed; replace=false)) + if shuffle + append!(inds, sample(inds_for_lbl, num_extra_needed; replace=false)) + else + append!(inds, inds_for_lbl[1:num_extra_needed]) + end end end shuffle && shuffle!(inds) - return obsview(data, inds) + return obsview(data, inds), obsview(classes, inds) end -oversample(data::Tuple; kws...) = oversample(data, data[end]; kws...) - +function oversample(data::Tuple; kws...) + d, c = oversample(data[1:end-1], data[end]; kws...) + return (d..., c) +end """ undersample(data, classes; shuffle=true) @@ -123,6 +124,8 @@ resulting data will be shuffled after its creation; if it is not shuffled then all the observations will be in their original order. Defaults to `false`. +The output will contain both the resampled data and classes. + ```julia # 6 observations with 3 features each X = rand(3, 6) @@ -142,14 +145,8 @@ X_bal, Y_bal = undersample(X, Y) ``` For this function to work, the type of `data` must implement -[`numobs`](@ref) and [`getobs`](@ref). For example, the following -code allows `undersample` to work on a `DataFrame`. +[`numobs`](@ref) and [`getobs`](@ref). -```julia -# Make DataFrames.jl work -MLUtils.getobs(data::DataFrame, i) = data[i,:] -MLUtils.numobs(data::DataFrame) = nrow(data) -``` Note that if `data` is a tuple, then it will be assumed that the last element of the tuple contains the targets. @@ -186,11 +183,18 @@ function undersample(data, classes; shuffle::Bool=true) inds = Int[] for (lbl, inds_for_lbl) in lm - append!(inds, sample(inds_for_lbl, mincount; replace=false)) + if shuffle + append!(inds, sample(inds_for_lbl, mincount; replace=false)) + else + append!(inds, inds_for_lbl[1:mincount]) + end end shuffle ? shuffle!(inds) : sort!(inds) - return obsview(data, inds) + return obsview(data, inds), obsview(classes, inds) end -undersample(data::Tuple; kws...) = undersample(data, data[end]; kws...) +function undersample(data::Tuple; kws...) + d, c = undersample(data[1:end-1], data[end]; kws...) + return (d..., c) +end diff --git a/test/resample.jl b/test/resample.jl index 89ea557..eba222d 100644 --- a/test/resample.jl +++ b/test/resample.jl @@ -4,7 +4,9 @@ y2 = ["c", "c", "c", "a", "b"] o = oversample((x, ya), fraction=1, shuffle=false) - @test o == oversample((x, ya), ya, shuffle=false) + @test o == oversample((x, ya), ya, shuffle=false)[1] + xo, yo = oversample(x, ya, shuffle=false) + @test (xo, yo) == o ox, oy = getobs(o) @test ox isa Matrix @test oy isa Vector @@ -15,7 +17,7 @@ @test oy[1:5] == ya @test oy[6] == ya[5] - o = oversample((x, ya), y2, shuffle=false) + o = oversample((x, ya), y2, shuffle=false)[1] ox, oy = getobs(o) @test ox isa Matrix @test oy isa Vector @@ -35,6 +37,8 @@ end y2 = ["c", "c", "c", "a", "b"] o = undersample((x, ya), shuffle=false) + xo, yo = undersample(x, ya, shuffle=false) + @test (xo, yo) == o ox, oy = getobs(o) @test ox isa Matrix @test oy isa Vector @@ -42,7 +46,7 @@ end @test size(oy) == (3,) @test ox[:,3] == x[:,5] - o = undersample((x, ya), y2, shuffle=false) + o = undersample((x, ya), y2, shuffle=false)[1] ox, oy = getobs(o) @test ox isa Matrix @test oy isa Vector