Skip to content

Commit

Permalink
Merge d4847fb into 22badcb
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolasmagalhaes committed Sep 5, 2020
2 parents 22badcb + d4847fb commit 69de79b
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 3 deletions.
24 changes: 23 additions & 1 deletion src/sklearn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -238,4 +238,26 @@ function generate_classification(; n_samples::Int = 100,
random_state = random_state)

return convert(features, labels)
end
end

"""
function generate_swiss_roll(; n_samples::Int = 100,
noise::Float64 = 0.0,
random_state::Union{Int,Nothing} = nothing)
Generate a swiss roll dataset.
#Arguments
- `n_samples::Int = 100`: The number of samples.
- `noise::Float64 = 0.0 : Standard deviation of Gaussian noise added to the data.
- `random_state::Union{Int, Nothing} = nothing`: Determines random number generation for dataset creation. Pass an int for reproducible output across multiple function calls. See Glossary.
Reference: [link](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.make_swiss_roll.htmll)
"""
function generate_swiss_roll(; n_samples::Int = 100,
noise::Float64 = 0.0,
random_state::Union{Int,Nothing} = nothing)

(features, labels) = datasets.make_swiss_roll( n_samples = n_samples,
noise = noise,
random_state = random_state)

return convert(features, labels)
end
9 changes: 7 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,13 @@ using Test
n_features = features,
n_classes = 1)


@test size(data)[1] == samples
@test size(data)[2] == features + 1

end
data = SyntheticDatasets.generate_swiss_roll(n_samples =samples,
noise = 2.2,
random_state = 5)

@test size(data)[1] == samples
@test size(data)[2] == 4
end

0 comments on commit 69de79b

Please sign in to comment.