Skip to content

Commit

Permalink
oversample and undersample always return classes as well (#116)
Browse files Browse the repository at this point in the history
  • Loading branch information
CarloLucibello committed Oct 28, 2022
1 parent 855f95b commit 125205b
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 26 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLUtils"
uuid = "f1d291b0-491e-4a28-83b9-f70985020b54"
authors = ["Carlo Lucibello <carlo.lucibello@gmail.com> and contributors"]
version = "0.2.12"
version = "0.3.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
48 changes: 26 additions & 22 deletions src/resample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ resulting data will be shuffled after its creation; if it is not
shuffled then all the repeated samples will be together at the
end, sorted by class. Defaults to `true`.
The output will contain both the resampled data and classes.
```julia
# 6 observations with 3 features each
X = rand(3, 6)
Expand All @@ -40,14 +42,7 @@ X_bal, Y_bal = oversample(X, Y)
```
For this function to work, the type of `data` must implement
[`numobs`](@ref) and [`getobs`](@ref). For example, the following
code allows `oversample` to work on a `DataFrame`.
```julia
# Make DataFrames.jl work
MLUtils.getobs(data::DataFrame, i) = data[i,:]
MLUtils.numobs(data::DataFrame) = nrow(data)
```
[`numobs`](@ref) and [`getobs`](@ref).
Note that if `data` is a tuple and `classes` is not given,
then it will be assumed that the last element of the tuple contains the classes.
Expand Down Expand Up @@ -98,16 +93,22 @@ function oversample(data, classes; fraction=1, shuffle::Bool=true)
append!(inds, inds_for_lbl)
end
if num_extra_needed > 0
append!(inds, sample(inds_for_lbl, num_extra_needed; replace=false))
if shuffle
append!(inds, sample(inds_for_lbl, num_extra_needed; replace=false))
else
append!(inds, inds_for_lbl[1:num_extra_needed])
end
end
end

shuffle && shuffle!(inds)
return obsview(data, inds)
return obsview(data, inds), obsview(classes, inds)
end

oversample(data::Tuple; kws...) = oversample(data, data[end]; kws...)

# Tuple method: no separate `classes` argument is supplied, so the last element
# of `data` is used as the classes and the leading elements are resampled
# against it.
function oversample(data::Tuple; kws...)
    features = data[1:end-1]
    labels = data[end]
    resampled, resampled_labels = oversample(features, labels; kws...)
    # Splat the resampled leading elements and put the classes back in last position.
    return (resampled..., resampled_labels)
end

"""
undersample(data, classes; shuffle=true)
Expand All @@ -123,6 +124,8 @@ resulting data will be shuffled after its creation; if it is not
shuffled then all the observations will be in their original
order. Defaults to `true`.
The output will contain both the resampled data and classes.
```julia
# 6 observations with 3 features each
X = rand(3, 6)
Expand All @@ -142,14 +145,8 @@ X_bal, Y_bal = undersample(X, Y)
```
For this function to work, the type of `data` must implement
[`numobs`](@ref) and [`getobs`](@ref). For example, the following
code allows `undersample` to work on a `DataFrame`.
[`numobs`](@ref) and [`getobs`](@ref).
```julia
# Make DataFrames.jl work
MLUtils.getobs(data::DataFrame, i) = data[i,:]
MLUtils.numobs(data::DataFrame) = nrow(data)
```
Note that if `data` is a tuple and `classes` is not given, then it will be
assumed that the last element of the tuple contains the classes.
Expand Down Expand Up @@ -186,11 +183,18 @@ function undersample(data, classes; shuffle::Bool=true)
inds = Int[]

for (lbl, inds_for_lbl) in lm
append!(inds, sample(inds_for_lbl, mincount; replace=false))
if shuffle
append!(inds, sample(inds_for_lbl, mincount; replace=false))
else
append!(inds, inds_for_lbl[1:mincount])
end
end

shuffle ? shuffle!(inds) : sort!(inds)
return obsview(data, inds)
return obsview(data, inds), obsview(classes, inds)
end

undersample(data::Tuple; kws...) = undersample(data, data[end]; kws...)
# Tuple method: no separate `classes` argument is supplied, so the last element
# of `data` is used as the classes and the leading elements are resampled
# against it.
function undersample(data::Tuple; kws...)
    features = data[1:end-1]
    labels = data[end]
    resampled, resampled_labels = undersample(features, labels; kws...)
    # Splat the resampled leading elements and put the classes back in last position.
    return (resampled..., resampled_labels)
end
10 changes: 7 additions & 3 deletions test/resample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
y2 = ["c", "c", "c", "a", "b"]

o = oversample((x, ya), fraction=1, shuffle=false)
@test o == oversample((x, ya), ya, shuffle=false)
@test o == oversample((x, ya), ya, shuffle=false)[1]
xo, yo = oversample(x, ya, shuffle=false)
@test (xo, yo) == o
ox, oy = getobs(o)
@test ox isa Matrix
@test oy isa Vector
Expand All @@ -15,7 +17,7 @@
@test oy[1:5] == ya
@test oy[6] == ya[5]

o = oversample((x, ya), y2, shuffle=false)
o = oversample((x, ya), y2, shuffle=false)[1]
ox, oy = getobs(o)
@test ox isa Matrix
@test oy isa Vector
Expand All @@ -35,14 +37,16 @@ end
y2 = ["c", "c", "c", "a", "b"]

o = undersample((x, ya), shuffle=false)
xo, yo = undersample(x, ya, shuffle=false)
@test (xo, yo) == o
ox, oy = getobs(o)
@test ox isa Matrix
@test oy isa Vector
@test size(ox) == (2, 3)
@test size(oy) == (3,)
@test ox[:,3] == x[:,5]

o = undersample((x, ya), y2, shuffle=false)
o = undersample((x, ya), y2, shuffle=false)[1]
ox, oy = getobs(o)
@test ox isa Matrix
@test oy isa Vector
Expand Down

2 comments on commit 125205b

@CarloLucibello
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/71209

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.0 -m "<description of version>" 125205bbeacfe3c8a2dd1ece2b5a188f4850bafc
git push origin v0.3.0

Please sign in to comment.