Skip to content

Commit

Permalink
more cosmetic fixes and replacing array comprehensions by generators …
Browse files Browse the repository at this point in the history
…where appropriate
  • Loading branch information
tlienart committed Jul 16, 2019
1 parent 0e2de10 commit 58d1b72
Showing 1 changed file with 29 additions and 38 deletions.
67 changes: 29 additions & 38 deletions src/composites.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,56 +16,49 @@ MLJBase.predict(composite::SupervisedNetwork, fitresult, Xnew) =
fitresult(Xnew)

"""
MLJ.tree(N::Node)
$SIGNATURES
Return a description of the tree defined by the learning network
terminating at node `N`.
terminating at a given node.
"""
tree(s::MLJ.Source) = (source = s,)
function tree(W::MLJ.Node)
mach = W.machine
if mach === nothing
value2 = nothing
endkeys=[]
endvalues=[]
endkeys = []
endvalues = []
else
value2 = mach.model
endkeys = [Symbol("train_arg", i) for i in eachindex(mach.args)]
endvalues = [tree(arg) for arg in mach.args]
endkeys = (Symbol("train_arg", i) for i in eachindex(mach.args))
endvalues = (tree(arg) for arg in mach.args)
end
keys = tuple(:operation, :model,
[Symbol("arg", i) for i in eachindex(W.args)]...,
(Symbol("arg", i) for i in eachindex(W.args))...,
endkeys...)
values = tuple(W.operation, value2,
[tree(arg) for arg in W.args]...,
(tree(arg) for arg in W.args)...,
endvalues...)
return NamedTuple{keys}(values)
end
tree(s::MLJ.Source) = (source = s,)

# get the top level args of the tree of some node:
function args(tree)
keys_ = filter(keys(tree) |> collect) do key
match(r"^arg[0-9]*", string(key)) !== nothing
end
return [getproperty(tree, key) for key in keys_]
end
"""
$SIGNATURES
# get the top level train_args of the tree of some node:
function train_args(tree)
Return a vector of the top level args of the tree associated with a node.
If `train=true`, return the `train_args`.
"""
function args(tree; train=false)
keys_ = filter(keys(tree) |> collect) do key
match(r"^train_arg[0-9]*", string(key)) !== nothing
match(Regex("^$("train_"^train)arg[0-9]*"), string(key)) !== nothing
end
return [getproperty(tree, key) for key in keys_]
end

"""
$SIGNATURES
models(N::AbstractNode)
A vector of all models referenced by node `N`, each model
appearing exactly once.
A vector of all models referenced by a node, each model appearing exactly once.
"""
function models(W::MLJ.AbstractNode)
models_ = filter(flat_values(tree(W)) |> collect) do model
Expand All @@ -75,16 +68,16 @@ function models(W::MLJ.AbstractNode)
end

"""
sources(N::AbstractNode)
$SIGNATURES
A vector of all sources referenced by calls `N()` and `fit!(N)`. These
are the sources of the directed acyclic graph associated with the
learning network terminating at `N`.
Not to be confused with `origins(N)` which refers to the same graph with edges corresponding to training arguments deleted.
Not to be confused with `origins(N)` which refers to the same graph with edges
corresponding to training arguments deleted.
See also: [`origins`](@ref), [`source`](@ref).
"""
function sources(W::MLJ.AbstractNode)
sources_ = filter(MLJ.flat_values(tree(W)) |> collect) do model
Expand All @@ -94,30 +87,28 @@ function sources(W::MLJ.AbstractNode)
end

"""
machines(N)
List all machines in the learning network terminating at node `N`.
$SIGNATURES
List all machines in the learning network terminating at a given node.
"""
machines(W::MLJ.Source) = Any[]
function machines(W::MLJ.Node)
if W.machine === nothing
return vcat([machines(arg) for arg in W.args]...) |> unique
return vcat((machines(arg) for arg in W.args) |> collect) |> unique
else
return vcat(Any[W.machine, ],
[machines(arg) for arg in W.args]...,
[machines(arg) for arg in W.machine.args]...) |> unique
(machines(arg) for arg in W.args)...,
(machines(arg) for arg in W.machine.args)...) |> unique
end
end
machines(W::MLJ.Source) = Any[]

"""
replace(W::MLJ.Node, a1=>b1, a2=>b2, ....)
replace(W::MLJ.Node, a1=>b1, a2=>b2, ...)
Create a deep copy of a node `W`, and thereby replicate the learning
network terminating at `W`, but replacing any specified sources and
models `a1, a2, ...` of the original network with the specified targets
`b1, b2, ...`.
"""
function Base.replace(W::Node, pairs::Pair...)

Expand Down Expand Up @@ -199,7 +190,7 @@ end
function supervised_fit_method(network_Xs, network_ys, network_N,
network_models...)

function fit(model::M, verbosity, X, y) where M <:Supervised
function fit(model::M, verbosity, X, y) where M <: Supervised
Xs = source(X)
ys = source(y)
replacement_models = [getproperty(model, fld)
Expand Down

0 comments on commit 58d1b72

Please sign in to comment.