Skip to content

Commit

Permalink
Merge c313478 into ecf7441
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolasmagalhaes committed Sep 9, 2020
2 parents ecf7441 + c313478 commit b88c5a8
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
19 changes: 19 additions & 0 deletions src/sklearn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,22 @@ function generate_swiss_roll(; n_samples::Int = 100,

return convert(features, labels)
end


"""
    generate_hastie_10_2(; n_samples::Int = 12000,
                           random_state::Union{Int,Nothing} = nothing)

Generate the binary classification data used in Hastie et al. 2009, Example 10.2.

# Arguments
- `n_samples::Int = 12000`: The number of samples.
- `random_state::Union{Int, Nothing} = nothing`: Determines random number generation for dataset creation. Pass an int for reproducible output across multiple function calls. See the scikit-learn glossary.

Reference: [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_hastie_10_2.html)
"""
function generate_hastie_10_2(; n_samples::Int = 12000,
                                random_state::Union{Int,Nothing} = nothing)
    # Forward both keywords unchanged to the `make_hastie_10_2` generator exposed by
    # `datasets`; a `nothing` seed is passed through as-is.
    generated = datasets.make_hastie_10_2(n_samples = n_samples,
                                          random_state = random_state)

    # The generator hands back a (features, labels) pair; unpack it so both parts can be
    # given to `convert`, whose result is what callers receive.
    X, y = generated

    return convert(X, y)
end
10 changes: 8 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,17 @@ using Test

@test size(data)[1] == samples
@test size(data)[2] == features
data = SyntheticDatasets.generate_swiss_roll(n_samples =samples,

data = SyntheticDatasets.generate_swiss_roll(n_samples = samples,
noise = 2.2,
random_state = 5)

@test size(data)[1] == samples
@test size(data)[2] == 4

data = SyntheticDatasets.generate_hastie_10_2(n_samples = samples,
random_state = 5)

@test size(data)[1] == samples
@test size(data)[2] == 11
end

0 comments on commit b88c5a8

Please sign in to comment.