1 change: 1 addition & 0 deletions .JuliaFormatter.toml
@@ -0,0 +1 @@
style = "sciml"
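
With this file at the repository root, JuliaFormatter applies the SciML style to the whole tree. A minimal local sketch (assumes JuliaFormatter is installed from the General registry):

using Pkg; Pkg.add("JuliaFormatter")   # one-time setup
using JuliaFormatter

# format(".") walks the repository and applies the style declared in
# .JuliaFormatter.toml (here style = "sciml"); it returns true when
# the tree was already formatted.
format(".")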
43 changes: 43 additions & 0 deletions .github/workflows/FormatCheck.yml
@@ -0,0 +1,43 @@
name: format-check

on:
pull_request:
branches:
- master
push:
branches:
- master

jobs:
build:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
version:
- '1.6' # Replace this with the minimum Julia version that your package supports. E.g. if your package requires Julia 1.5 or higher, change this to '1.5'.
os:
- ubuntu-latest
arch:
- x64
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- name: Install JuliaFormatter and Format
run: |
julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))'
julia -e 'using JuliaFormatter; format(".", verbose=true)'
- name: Format check
run: |
julia -e '
out = Cmd(`git diff`) |> read |> String
if out == ""
exit(0)
else
@error "Some files have not been formatted !!!"
write(stdout, out)
exit(1)
end'
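
The job is a format-then-diff gate: CI reformats the tree in place, then fails if `git diff` produces any output. A sketch of running the same gate locally (assumes a clean working tree, so any resulting diff comes from the formatter):

using JuliaFormatter

format("."; verbose = true)   # the same call the workflow runs
# Afterwards, run `git diff` in a shell: any output means the
# format-check job above would print the diff and exit(1).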
48 changes: 19 additions & 29 deletions docs/make.jl
@@ -6,34 +6,24 @@ using Documenter, MLDatasets
# Build documentation.
# ====================

makedocs(
modules = [MLDatasets],
doctest = true,
clean = false,
sitename = "MLDatasets.jl",
format = Documenter.HTML(
canonical = "https://juliadata.github.io/MLDatasets.jl/stable/",
assets = ["assets/favicon.ico"],
prettyurls = get(ENV, "CI", nothing) == "true",
collapselevel=3,
),

authors = "Hiroyuki Shindo, Christof Stocker, Carlo Lucibello",

pages = Any[
"Home" => "index.md",
"Datasets" => Any[
"Graphs" => "datasets/graphs.md",
"Meshes" => "datasets/meshes.md",
"Miscellaneous" => "datasets/misc.md",
"Text" => "datasets/text.md",
"Vision" => "datasets/vision.md",
],
"Creating Datasets" => Any["containers/overview.md"], # still experimental
"LICENSE.md",
],
strict = true,
checkdocs = :exports
)
makedocs(modules = [MLDatasets],
doctest = true,
clean = false,
sitename = "MLDatasets.jl",
format = Documenter.HTML(canonical = "https://juliadata.github.io/MLDatasets.jl/stable/",
assets = ["assets/favicon.ico"],
prettyurls = get(ENV, "CI", nothing) == "true",
collapselevel = 3),
authors = "Hiroyuki Shindo, Christof Stocker, Carlo Lucibello",
pages = Any["Home" => "index.md",
"Datasets" => Any["Graphs" => "datasets/graphs.md",
"Meshes" => "datasets/meshes.md",
"Miscellaneous" => "datasets/misc.md",
"Text" => "datasets/text.md",
"Vision" => "datasets/vision.md"],
"Creating Datasets" => Any["containers/overview.md"], # still experimental
"LICENSE.md"],
strict = true,
checkdocs = :exports)

deploydocs(repo = "github.com/JuliaML/MLDatasets.jl.git")
15 changes: 7 additions & 8 deletions src/MLDatasets.jl
@@ -21,15 +21,15 @@ include("require.jl") # export @require
# In the other case instead, use `require import SomePkg` to force
# the user to manually import.

@require import JSON3="0f8b85d8-7281-11e9-16c2-39a750bddbf1"
@require import DataFrames="a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
@require import ImageShow="4e3cecfd-b093-5904-9786-8bbb286a6a31"
@require import Chemfiles="46823bd8-5fb3-5f92-9aa0-96921f3dd015"
@require import JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
@require import DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
@require import ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
@require import Chemfiles = "46823bd8-5fb3-5f92-9aa0-96921f3dd015"

# @lazy import NPZ # lazy imported by FileIO
@lazy import Pickle="fbb45041-c46e-462f-888f-7c521cafbc2c"
@lazy import MAT="23992714-dd62-5051-b70f-ba57cb901cac"
@lazy import HDF5="f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
@lazy import Pickle = "fbb45041-c46e-462f-888f-7c521cafbc2c"
@lazy import MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
@lazy import HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
# @lazy import JLD2

export getobs, numobs # From MLUtils.jl
@@ -93,7 +93,6 @@ export Omniglot
include("datasets/vision/svhn2.jl")
export SVHN2


## Text

include("datasets/text/ptblm.jl")
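
Per the comments above, `@require`d packages are not loaded automatically; the user must import them before the corresponding functionality works. A hedged sketch (the `Iris` dataset and its tabular features are assumptions about MLDatasets, not shown in this diff):

using MLDatasets
using DataFrames   # a @require dependency: import it yourself before use

iris = Iris()    # hypothetical example of a tabular dataset
iris.features    # assumed to be a DataFrames-backed table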
30 changes: 15 additions & 15 deletions src/abstract_datasets.jl
@@ -9,15 +9,14 @@ Implements the following functionality:
"""
abstract type AbstractDataset <: AbstractDataContainer end


MLUtils.getobs(d::AbstractDataset) = d[:]
MLUtils.getobs(d::AbstractDataset, i) = d[i]

function Base.show(io::IO, d::D) where D <: AbstractDataset
function Base.show(io::IO, d::D) where {D <: AbstractDataset}
print(io, "$(D.name.name)()")
end

function Base.show(io::IO, ::MIME"text/plain", d::D) where D <: AbstractDataset
function Base.show(io::IO, ::MIME"text/plain", d::D) where {D <: AbstractDataset}
recur_io = IOContext(io, :compact => false)

print(io, "dataset $(D.name.name):") # if the type is parameterized don't print the parameters
@@ -38,7 +37,7 @@ function leftalign(s::AbstractString, n::Int)
if m > n
return s[1:n]
else
return s * repeat(" ", n-m)
return s * repeat(" ", n - m)
end
end

@@ -59,19 +58,23 @@ a `features` and a `targets` fields.
"""
abstract type SupervisedDataset <: AbstractDataset end


Base.length(d::SupervisedDataset) = Tables.istable(d.features) ? numobs_table(d.features) :
numobs((d.features, d.targets))

function Base.length(d::SupervisedDataset)
Tables.istable(d.features) ? numobs_table(d.features) :
numobs((d.features, d.targets))
end

# We return named tuples
Base.getindex(d::SupervisedDataset, ::Colon) = Tables.istable(d.features) ?
(features = d.features, targets=d.targets) :
function Base.getindex(d::SupervisedDataset, ::Colon)
Tables.istable(d.features) ?
(features = d.features, targets = d.targets) :
getobs((; d.features, d.targets))
end

Base.getindex(d::SupervisedDataset, i) = Tables.istable(d.features) ?
(features = getobs_table(d.features, i), targets=getobs_table(d.targets, i)) :
function Base.getindex(d::SupervisedDataset, i)
Tables.istable(d.features) ?
(features = getobs_table(d.features, i), targets = getobs_table(d.targets, i)) :
getobs((; d.features, d.targets), i)
end

"""
UnsupervisedDataset <: AbstractDataset
@@ -81,13 +84,11 @@ Concrete dataset types inheriting from it must provide a `features` field.
"""
abstract type UnsupervisedDataset <: AbstractDataset end


Base.length(d::UnsupervisedDataset) = numobs(d.features)

Base.getindex(d::UnsupervisedDataset, ::Colon) = getobs(d.features)
Base.getindex(d::UnsupervisedDataset, i) = getobs(d.features, i)


### DOCSTRING TEMPLATES ######################

# SUPERVISED TABLE
@@ -110,7 +111,6 @@ const METHODS_SUPERVISED_TABLE = """
- `length(dataset)`: Number of observations.
"""


# SUPERVISED ARRAY DATASET

const ARGUMENTS_SUPERVISED_ARRAY = """
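
To make the `SupervisedDataset` fallbacks above concrete, a hypothetical subtype (`ToyData` is illustrative only, not part of the package):

using MLDatasets, MLUtils   # numobs/getobs come from MLUtils

# Two fields are all the fallbacks need.
struct ToyData <: MLDatasets.SupervisedDataset
    features::Matrix{Float64}   # 2 features x 10 observations
    targets::Vector{Int}
end

d = ToyData(rand(2, 10), rand(0:9, 10))
length(d)   # 10, via numobs((d.features, d.targets))
d[1:3]      # (features = 2x3 Matrix, targets = 3-element Vector)
d[:]        # all observations in the same named-tuple layout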
3 changes: 2 additions & 1 deletion src/containers/cacheddataset.jl
@@ -28,8 +28,9 @@ end

CachedDataset(source, cachesize::Int) = CachedDataset(source, 1:cachesize)

CachedDataset(source, cacheidx::AbstractVector{<:Integer} = 1:numobs(source)) =
function CachedDataset(source, cacheidx::AbstractVector{<:Integer} = 1:numobs(source))
CachedDataset(source, collect(cacheidx), make_cache(source, cacheidx))
end

function Base.getindex(dataset::CachedDataset, i::Integer)
_i = findfirst(==(i), dataset.cacheidx)
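
A usage sketch combining the constructors above with `FileDataset` from the next file in this diff (the directory and sizes are hypothetical):

ds = CachedDataset(FileDataset("data/images", "*.png"), 1:100)
img = ds[42]   # index 42 is in cacheidx: served from memory, not from disk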
19 changes: 12 additions & 7 deletions src/containers/filedataset.jl
@@ -17,19 +17,24 @@ Wrap a set of file `paths` as a dataset (traversed in the same order as `paths`)
Alternatively, specify a `dir` and collect all paths that match a glob `pattern`
(recursively globbing by `depth`). The glob order determines the traversal order.
"""
struct FileDataset{F, T<:AbstractString} <: AbstractDataContainer
struct FileDataset{F, T <: AbstractString} <: AbstractDataContainer
loadfn::F
paths::Vector{T}
end

FileDataset(paths) = FileDataset(FileIO.load, paths)
FileDataset(loadfn,
dir::AbstractString,
pattern::AbstractString = "*",
depth = 4) = FileDataset(loadfn, rglob(pattern, string(dir), depth))
FileDataset(dir::AbstractString, pattern::AbstractString = "*", depth = 4) =
function FileDataset(loadfn,
dir::AbstractString,
pattern::AbstractString = "*",
depth = 4)
FileDataset(loadfn, rglob(pattern, string(dir), depth))
end
function FileDataset(dir::AbstractString, pattern::AbstractString = "*", depth = 4)
FileDataset(FileIO.load, dir, pattern, depth)
end

Base.getindex(dataset::FileDataset, i::Integer) = dataset.loadfn(dataset.paths[i])
Base.getindex(dataset::FileDataset, is::AbstractVector) = map(Base.Fix1(getobs, dataset), is)
function Base.getindex(dataset::FileDataset, is::AbstractVector)
map(Base.Fix1(getobs, dataset), is)
end
Base.length(dataset::FileDataset) = length(dataset.paths)
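
A construction sketch for the methods above (the directory layout is assumed):

using FileIO   # supplies the default loadfn

ds = FileDataset("data/images", "*.png")   # recursive glob, default depth = 4
x  = ds[1]       # loadfn(paths[1]), i.e. FileIO.load on the first match
xs = ds[1:8]     # maps getobs over a vector of indices
n  = length(ds)  # number of matched paths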
19 changes: 13 additions & 6 deletions src/containers/hdf5dataset.jl
@@ -19,12 +19,15 @@ See [`close(::HDF5Dataset)`](@ref) for closing the underlying HDF5 file pointer.
For array datasets, the last dimension is assumed to be the observation dimension.
For scalar datasets, the stored value is returned by `getobs` for any index.
"""
struct HDF5Dataset{T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}} <: AbstractDataContainer
struct HDF5Dataset{T <: Union{HDF5.Dataset, Vector{HDF5.Dataset}}} <: AbstractDataContainer
fid::HDF5.File
paths::T
shapes::Vector{Tuple}

function HDF5Dataset(fid::HDF5.File, paths::T, shapes::Vector) where T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}
function HDF5Dataset(fid::HDF5.File, paths::T,
shapes::Vector) where {
T <:
Union{HDF5.Dataset, Vector{HDF5.Dataset}}}
_check_hdf5_shapes(shapes) ||
throw(ArgumentError("Cannot create HDF5Dataset for datasets with mismatched number of observations."))

@@ -33,11 +36,13 @@ struct HDF5Dataset{T<:Union{HDF5.Dataset, Vector{HDF5.Dataset}}} <: AbstractData
end

HDF5Dataset(fid::HDF5.File, path::HDF5.Dataset) = HDF5Dataset(fid, path, [size(path)])
HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset}) =
function HDF5Dataset(fid::HDF5.File, paths::Vector{HDF5.Dataset})
HDF5Dataset(fid, paths, map(size, paths))
end
HDF5Dataset(fid::HDF5.File, path::AbstractString) = HDF5Dataset(fid, fid[path])
HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString}) =
function HDF5Dataset(fid::HDF5.File, paths::Vector{<:AbstractString})
HDF5Dataset(fid, map(p -> fid[p], paths))
end
HDF5Dataset(file::AbstractString, paths) = HDF5Dataset(h5open(file, "r"), paths)

_getobs_hdf5(dataset::HDF5.Dataset, ::Tuple{}, i) = read(dataset)
@@ -46,10 +51,12 @@ function _getobs_hdf5(dataset::HDF5.Dataset, shape, i)

return dataset[I..., i]
end
Base.getindex(dataset::HDF5Dataset{HDF5.Dataset}, i) =
function Base.getindex(dataset::HDF5Dataset{HDF5.Dataset}, i)
_getobs_hdf5(dataset.paths, only(dataset.shapes), i)
Base.getindex(dataset::HDF5Dataset{<:Vector}, i) =
end
function Base.getindex(dataset::HDF5Dataset{<:Vector}, i)
Tuple(map((p, s) -> _getobs_hdf5(p, s, i), dataset.paths, dataset.shapes))
end
Base.length(dataset::HDF5Dataset) = last(first(filter(!isempty, dataset.shapes)))

"""
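
A usage sketch for the constructors above (the file name and dataset paths are assumed):

using HDF5

ds = HDF5Dataset("data.h5", ["features", "targets"])   # opens with h5open(file, "r")
x, y = ds[1]     # one observation from each wrapped dataset, returned as a Tuple
n = length(ds)   # observation count, taken from the stored shapes
close(ds)        # close(::HDF5Dataset), as referenced in the docstring above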
6 changes: 4 additions & 2 deletions src/containers/jld2dataset.jl
@@ -10,7 +10,7 @@ Calling `getobs` on a `JLD2Dataset` is equivalent to mapping `getobs` on
each dataset in `paths`.
See [`close(::JLD2Dataset)`](@ref) for closing the underlying JLD2 file pointer.
"""
struct JLD2Dataset{T<:JLD2.JLDFile, S<:Tuple} <: AbstractDataContainer
struct JLD2Dataset{T <: JLD2.JLDFile, S <: Tuple} <: AbstractDataContainer
fid::T
paths::S

@@ -27,7 +27,9 @@ end
JLD2Dataset(file::JLD2.JLDFile, path::String) = JLD2Dataset(file, (path,))
JLD2Dataset(file::AbstractString, paths) = JLD2Dataset(jldopen(file, "r"), paths)

Base.getindex(dataset::JLD2Dataset{<:JLD2.JLDFile, <:NTuple{1}}, i) = getobs(only(dataset.paths), i)
function Base.getindex(dataset::JLD2Dataset{<:JLD2.JLDFile, <:NTuple{1}}, i)
getobs(only(dataset.paths), i)
end
Base.getindex(dataset::JLD2Dataset, i) = map(Base.Fix2(getobs, i), dataset.paths)
Base.length(dataset::JLD2Dataset) = numobs(dataset.paths[1])

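
The JLD2 counterpart as a sketch (the file name and keys are assumed):

using JLD2

ds = JLD2Dataset("data.jld2", ("features", "targets"))   # opens with jldopen(file, "r")
x, y = ds[1]   # getobs mapped over both stored arrays
close(ds)      # close(::JLD2Dataset), as referenced in the docstring above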
2 changes: 0 additions & 2 deletions src/containers/tabledataset.jl
@@ -21,7 +21,6 @@ end
TableDataset(table::T) where {T} = TableDataset{T}(table)
TableDataset(path::AbstractString) = TableDataset(read_csv(path))


# slow accesses based on Tables.jl
_getobs_row(x, i) = first(Iterators.peel(Iterators.drop(x, i - 1)))
function _getobs_column(x, i)
@@ -55,7 +54,6 @@ end
Base.getindex(dataset::TableDataset, i) = getobs_table(dataset.table, i)
Base.length(dataset::TableDataset) = numobs_table(dataset.table)


# fast access for DataFrame
# Base.getindex(dataset::TableDataset{<:DataFrame}, i) = dataset.table[i, :]
# Base.length(dataset::TableDataset{<:DataFrame}) = nrow(dataset.table)
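
A sketch with an in-memory table (values hypothetical). With the DataFrame fast path commented out above, access goes through the generic Tables.jl helpers:

using DataFrames

ds = TableDataset(DataFrame(a = 1:4, b = ["w", "x", "y", "z"]))
length(ds)   # 4, via numobs_table
ds[2]        # the second observation, via getobs_table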