Skip to content

Commit

Permalink
Merge bb7b762 into b55a06f
Browse files Browse the repository at this point in the history
  • Loading branch information
tlienart committed Jul 24, 2019
2 parents b55a06f + bb7b762 commit d7f46df
Show file tree
Hide file tree
Showing 12 changed files with 483 additions and 362 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
4 changes: 2 additions & 2 deletions docs/src/internals.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function fit!(machine::Machine; rows=nothing, force=false, verbosity=1)
warning = clean!(mach.model)
isempty(warning) || verbosity < 0 || @warn warning

if rows == nothing
if rows === nothing
rows = (:)
end

Expand All @@ -61,7 +61,7 @@ function fit!(machine::Machine; rows=nothing, force=false, verbosity=1)
mach.rows = deepcopy(rows)
end

if report != nothing
if report !== nothing
merge!(mach.report, report)
end

Expand Down
92 changes: 48 additions & 44 deletions src/MLJ.jl
Original file line number Diff line number Diff line change
@@ -1,61 +1,60 @@
module MLJ

# defined in include files:
export @curve, @pcurve # utilities.jl
export mav, rms, rmsl, rmslp1, rmsp # metrics.jl
export misclassification_rate, cross_entropy # metrics.jl
export default_measure # metrics.jl
export coerce, supervised, unsupervised # tasks.jl
export report # machines.jl
export Holdout, CV, evaluate!, Resampler # resampling.jl
export Params, params, set_params! # parameters.jl
export strange, iterator # parameters.jl
export Grid, TunedModel, learning_curve! # tuning.jl
export EnsembleModel # ensembles.jl
export ConstantRegressor, ConstantClassifier # builtins/Constant.jl
export models, localmodels, @load # loading.jl
export KNNRegressor # builtins/KNN.jl
export @from_network, machines, sources # composites.jl
export @curve, @pcurve, # utilities.jl
mav, mae, rms, rmsl, rmslp1, rmsp, # metrics.jl
misclassification_rate, cross_entropy, # metrics.jl
default_measure, # metrics.jl
coerce, supervised, unsupervised, # tasks.jl
report, # machines.jl
Holdout, CV, evaluate!, Resampler, # resampling.jl
Params, params, set_params!, # parameters.jl
strange, iterator, # parameters.jl
Grid, TunedModel, learning_curve!, # tuning.jl
EnsembleModel, # ensembles.jl
ConstantRegressor, ConstantClassifier, # builtins/Constant.jl
models, localmodels, @load, # loading.jl
KNNRegressor, # builtins/KNN.jl
@from_network, machines, sources # composites.jl

# defined in include files "machines.jl" and "networks.jl":
export Machine, NodalMachine, machine, AbstractNode
export source, node, fit!, freeze!, thaw!, Node, sources, origins
export Machine, NodalMachine, machine, AbstractNode,
source, node, fit!, freeze!, thaw!, Node, sources, origins

# defined in include file "builtins/Transformers.jl":
export FeatureSelector
export UnivariateStandardizer, Standardizer
export UnivariateBoxCoxTransformer
export OneHotEncoder
# export IntegerToInt64Transformer
# export UnivariateDiscretizer, Discretizer
export FeatureSelector,
UnivariateStandardizer, Standardizer,
UnivariateBoxCoxTransformer,
OneHotEncoder
# IntegerToInt64Transformer,
# UnivariateDiscretizer, Discretizer

# reexport from Random, Statistics, Distributions, CategoricalArrays:
export pdf, mode, median, mean, shuffle!, categorical, shuffle

# reexport from MLJBase:
export nrows, nfeatures, info
export SupervisedTask, UnsupervisedTask, MLJTask
export Deterministic, Probabilistic, Unsupervised, Supervised
export DeterministicNetwork, ProbabilisticNetwork
export Found, Continuous, Finite, Infinite
export OrderedFactor, Unknown
export Count, Multiclass, Binary
export scitype, scitype_union, scitypes
export predict, predict_mean, predict_median, predict_mode
export transform, inverse_transform, se, evaluate, fitted_params
export @constant, @more, HANDLE_GIVEN_ID, UnivariateFinite
export partition, X_and_y
export load_boston, load_ames, load_iris, load_reduced_ames
export load_crabs, datanow
export features, X_and_y
export nrows, nfeatures, info,
SupervisedTask, UnsupervisedTask, MLJTask,
Deterministic, Probabilistic, Unsupervised, Supervised,
DeterministicNetwork, ProbabilisticNetwork,
Found, Continuous, Finite, Infinite,
OrderedFactor, Unknown,
Count, Multiclass, Binary,
scitype, scitype_union, scitypes,
predict, predict_mean, predict_median, predict_mode,
transform, inverse_transform, se, evaluate, fitted_params,
@constant, @more, HANDLE_GIVEN_ID, UnivariateFinite,
partition, X_and_y,
load_boston, load_ames, load_iris, load_reduced_ames,
load_crabs, datanow,
features, X_and_y

using MLJBase

# to be extended:
import MLJBase: fit, update, clean!
import MLJBase: predict, predict_mean, predict_median, predict_mode
import MLJBase: transform, inverse_transform, se, evaluate, fitted_params
import MLJBase: show_as_constructed, params
import MLJBase: fit, update, clean!,
predict, predict_mean, predict_median, predict_mode,
transform, inverse_transform, se, evaluate, fitted_params,
show_as_constructed, params

using RemoteFiles
import Pkg.TOML
Expand All @@ -65,6 +64,11 @@ import Distributions
import StatsBase
using ProgressMeter
import Tables
import Random

# convenience packages
using DocStringExtensions: SIGNATURES, TYPEDEF
using Parameters

# to be extended:
import Base.==
Expand All @@ -78,7 +82,7 @@ import Distributed: @distributed, nworkers, pmap
using RecipesBase # for plotting

# submodules of this module:
include("registry/src/Registry.jl")
include("registry/src/Registry.jl")
import .Registry

const srcdir = dirname(@__FILE__) # the directory containing this file:
Expand Down
73 changes: 32 additions & 41 deletions src/composites.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,56 +16,49 @@ MLJBase.predict(composite::SupervisedNetwork, fitresult, Xnew) =
fitresult(Xnew)

"""
    $SIGNATURES

Return a description of the tree defined by the learning network
terminating at a given node.
"""
function tree(W::MLJ.Node)
    mach = W.machine

    # Entries contributed by the node's own arguments: arg1, arg2, ...
    arg_labels = (Symbol("arg", i) for i in eachindex(W.args))
    arg_trees = (tree(arg) for arg in W.args)

    # Entries contributed by the machine's arguments: train_arg1, ...
    # A node whose `machine` field is `nothing` contributes none. Use `===`
    # so the check cannot be hijacked by a user-defined `==` method.
    if mach === nothing
        model = nothing
        train_labels = ()
        train_trees = ()
    else
        model = mach.model
        train_labels = (Symbol("train_arg", i) for i in eachindex(mach.args))
        train_trees = (tree(arg) for arg in mach.args)
    end

    # Labels are assembled before the (recursive) entries; both follow the
    # order: operation, model, node args, machine args.
    labels = (:operation, :model, arg_labels..., train_labels...)
    entries = (W.operation, model, arg_trees..., train_trees...)
    return NamedTuple{labels}(entries)
end
# A source is a leaf: its tree is a one-field named tuple holding the source.
tree(s::MLJ.Source) = (source = s,)

"""
    args(tree; train=false)

Return a vector of the top-level args of the tree associated with a node,
i.e. the values stored under the keys of `tree` that begin with `arg`, in
key order. If `train=true`, return instead the values stored under the keys
that begin with `train_arg`.
"""
function args(tree; train=false)
    # Choose the literal pattern once, rather than rebuilding a `Regex`
    # for every key inside the filter.
    pattern = train ? r"^train_arg[0-9]*" : r"^arg[0-9]*"
    selected = filter(key -> occursin(pattern, string(key)), collect(keys(tree)))
    return [getproperty(tree, key) for key in selected]
end

# Earlier entry point for the training arguments, kept so that existing
# callers of `train_args` continue to work.
train_args(tree) = args(tree; train=true)

"""
$SIGNATURES
models(N::AbstractNode)
A vector of all models referenced by node `N`, each model
appearing exactly once.
A vector of all models referenced by a node, each model appearing exactly once.
"""
function models(W::MLJ.AbstractNode)
models_ = filter(flat_values(tree(W)) |> collect) do model
Expand All @@ -75,16 +68,16 @@ function models(W::MLJ.AbstractNode)
end

"""
sources(N::AbstractNode)
$SIGNATURES
A vector of all sources referenced by calls `N()` and `fit!(N)`. These
are the sources of the directed acyclic graph associated with the
learning network terminating at `N`.
Not to be confused with `origins(N)` which refers to the same graph with edges corresponding to training arguments deleted.
Not to be confused with `origins(N)` which refers to the same graph with edges
corresponding to training arguments deleted.
See also: [`origins`](@ref), [`source`](@ref).
"""
function sources(W::MLJ.AbstractNode)
sources_ = filter(MLJ.flat_values(tree(W)) |> collect) do model
Expand All @@ -94,30 +87,28 @@ function sources(W::MLJ.AbstractNode)
end

"""
    $SIGNATURES

List all machines in the learning network terminating at a given node.
"""
function machines(W::MLJ.Node)
    mach = W.machine

    # Always accumulate into a flat `Vector{Any}`. Concatenating a single
    # collected vector of per-argument results (without splatting) would
    # yield a vector of vectors rather than a flat list of machines.
    found = mach === nothing ? Any[] : Any[mach]

    # Machines reachable through the node's own arguments ...
    for arg in W.args
        append!(found, machines(arg))
    end

    # ... followed by those reachable through the machine's arguments.
    if mach !== nothing
        for arg in mach.args
            append!(found, machines(arg))
        end
    end

    # `unique` keeps the first occurrence, so the order is: this node's
    # machine, then machines from `W.args`, then from `W.machine.args`.
    return unique(found)
end
# A source has no machine and no arguments, so it contributes nothing.
machines(W::MLJ.Source) = Any[]

"""
replace(W::MLJ.Node, a1=>b1, a2=>b2, ....)
replace(W::MLJ.Node, a1=>b1, a2=>b2, ...)
Create a deep copy of a node `W`, and thereby replicate the learning
network terminating at `W`, but replacing any specified sources and
models `a1, a2, ...` of the original network with the specified targets
`b1, b2, ...`.
"""
function Base.replace(W::Node, pairs::Pair...)

Expand Down Expand Up @@ -161,7 +152,7 @@ function Base.replace(W::Node, pairs::Pair...)
# build the new network:
for N in nodes_
args = [newnode_given_old[arg] for arg in N.args]
if N.machine == nothing
if N.machine === nothing
newnode_given_old[N] = node(N.operation, args...)
else
if N.machine in keys(newmach_given_old)
Expand Down Expand Up @@ -199,7 +190,7 @@ end
function supervised_fit_method(network_Xs, network_ys, network_N,
network_models...)

function fit(model::M, verbosity, X, y) where M <:Supervised
function fit(model::M, verbosity, X, y) where M <: Supervised
Xs = source(X)
ys = source(y)
replacement_models = [getproperty(model, fld)
Expand Down
2 changes: 1 addition & 1 deletion src/loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ to be specified.
"""
function MLJBase.info(model::String; pkg=nothing)
if pkg == nothing
if pkg === nothing
if model in string.(MLJBase.finaltypes(Model))
pkg = "MLJ"
else
Expand Down
Loading

0 comments on commit d7f46df

Please sign in to comment.