Add interface points for accessing the internal state of an exported learning network composite model #644

Merged
8 changes: 8 additions & 0 deletions src/MLJBase.jl
@@ -336,6 +336,14 @@ export pdf, sampler, mode, median, mean, shuffle!, categorical, shuffle,
# ===================================================================
## CONSTANTS

const PREDICT_OPERATIONS = (:predict,
:predict_mean,
:predict_mode,
:predict_median,
:predict_joint)

const OPERATIONS = (PREDICT_OPERATIONS..., :transform, :inverse_transform)

# the directory containing this file: (.../src/)
const MODULE_DIR = dirname(@__FILE__)

199 changes: 152 additions & 47 deletions src/composition/learning_networks/machines.jl
@@ -1,25 +1,70 @@
## LEARNING NETWORK MACHINES
# # SIGNATURES

surrogate(::Type{<:Deterministic}) = Deterministic()
surrogate(::Type{<:Probabilistic}) = Probabilistic()
surrogate(::Type{<:Unsupervised}) = Unsupervised()
surrogate(::Type{<:Static}) = Static()
const DOC_SIGNATURES =
"""
A learning network *signature* is an intermediate object defined when
a user constructs a learning network machine, `mach`. It is a named
tuple whose values are the nodes constituting interface points
between the network and the machine. Examples are

(predict=yhat, )
(transform=Xsmall,)
(predict=yhat, transform=W, report=(loss=loss_node,))

where `yhat`, `Xsmall`, `W` and `loss_node` are nodes in the network.

If a key `k` is the name of an operation (such as `:predict`,
`:predict_mode`, `:transform`, `:inverse_transform`) then `k(mach, X)`
returns `n(X)` where `n` is the corresponding node value. Each such
node must have a unique origin (`length(origins(n)) === 1`).

The only other allowed key is `:report`, whose corresponding value
must be a named tuple

(k1=n1, k2=n2, ...)

whose keys are arbitrary, and whose values are nodes of the
network. For each such key-value pair `k=n`, the value returned by
`n()` is included in the named tuple `report(mach)`, with
corresponding key `k`. So, in the third example above,
`report(mach).loss` will return the value of `loss_node()`.

"""

function _operation_part(signature)
ops = filter(in(OPERATIONS), keys(signature))
return NamedTuple{ops}(map(op->getproperty(signature, op), ops))
end
function _report_part(signature)
:report in keys(signature) || return NamedTuple()
return signature.report
end

_operations(signature) = keys(_operation_part(signature))

function _nodes(signature)
return (values(_operation_part(signature))...,
values(_report_part(signature))...)
end

function _call(nt::NamedTuple)
_call(n) = deepcopy(n())
_keys = keys(nt)
_values = values(nt)
return NamedTuple{_keys}(_call.(_values))
end
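
In other words (a sketch, with `loss_node` as above): `_call`
materializes each node and deep-copies the result, so a stored report
cannot be mutated by later retraining.

```julia
nt = (loss=loss_node,)
MLJBase._call(nt)   # == (loss = deepcopy(loss_node()),)
```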

"""
model_supertype(signature)

Return, if this can be deduced, which of `Deterministic`,
Return, if this can be inferred, which of `Deterministic`,
`Probabilistic` and `Unsupervised` is the appropriate supertype for a
composite model obtained by exporting a learning network with the
specified `signature`.

A learning network *signature* is a named tuple, such as
`(predict=yhat, transform=W)`, specifying what nodes of the network
are called to produce output of each operation represented by the
keys, in an exported version of the network.
$DOC_SIGNATURES

If a supertype cannot be deduced, `nothing` is returned.
If a supertype cannot be inferred, `nothing` is returned.

If the network with given `signature` is not exportable, this method
will not error but it will not give a meaningful return value either.
@@ -29,7 +74,7 @@
"""
function model_supertype(signature)

operations = keys(signature)
operations = _operations(signature)

length(intersect(operations, (:predict_mean, :predict_median))) == 1 &&
return Deterministic
@@ -50,6 +95,39 @@ function model_supertype(signature)

end


# # FITRESULTS FOR COMPOSITE MODELS

mutable struct CompositeFitresult
signature
glb
report_additions
function CompositeFitresult(signature)
glb = MLJBase.glb(_nodes(signature)...)
new(signature, glb)
end
end
signature(c::CompositeFitresult) = getfield(c, :signature)
glb(c::CompositeFitresult) = getfield(c, :glb)
report_additions(c::CompositeFitresult) = getfield(c, :report_additions)

update!(c::CompositeFitresult) =
setfield!(c, :report_additions, _call(_report_part(signature(c))))

# To accommodate pre-existing design (operations.jl) arrange
# that `fitresult.predict` returns the predict node, etc:
Base.propertynames(c::CompositeFitresult) = keys(signature(c))
Base.getproperty(c::CompositeFitresult, name::Symbol) =
getproperty(signature(c), name)
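
A sketch of these accessors at work (assuming `yhat` and `loss_node`
from the earlier sketch):

```julia
sig = (predict=yhat, report=(loss=loss_node,))
fr = MLJBase.CompositeFitresult(sig)

MLJBase.signature(fr) === sig   # true
fr.predict === yhat             # true: property access forwards to the signature
MLJBase.glb(fr)                 # the cached glb node over `yhat` and `loss_node`
```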


# # LEARNING NETWORK MACHINES

surrogate(::Type{<:Deterministic}) = Deterministic()
surrogate(::Type{<:Probabilistic}) = Probabilistic()
surrogate(::Type{<:Unsupervised}) = Unsupervised()
surrogate(::Type{<:Static}) = Static()

caches_data_by_default(::Type{<:Surrogate}) = false

const ERR_MUST_PREDICT = ArgumentError(
@@ -61,25 +139,31 @@ const ERR_MUST_OPERATE = ArgumentError(
const ERR_MUST_SPECIFY_SOURCES = ArgumentError(
"You must specify at least one source `Xs`, as in "*
"`machine(surrogate_model, Xs, ...; kwargs...)`. ")
const ERR_BAD_SIGNATURE = ArgumentError(
"Only the following keyword arguments are supported in learning network "*
"machine constructors: `report` or one of: `$OPERATIONS`. ")
const ERR_EXPECTED_NODE_IN_SIGNATURE = ArgumentError(
"Learning network machine constructor syntax error. "*
"Did not encounter `Node` where one was expected. ")

function check_surrogate_machine(::Surrogate, signature, _sources)
isempty(signature) && throw(ERR_MUST_OPERATE)
isempty(_operations(signature)) && throw(ERR_MUST_OPERATE)
isempty(_sources) && throw(ERR_MUST_SPECIFY_SOURCES)
return nothing
end

function check_surrogate_machine(::Union{Supervised,SupervisedAnnotator},
signature,
_sources)
isempty(signature) && throw(ERR_MUST_PREDICT)
isempty(_operations(signature)) && throw(ERR_MUST_PREDICT)
length(_sources) > 1 || throw(err_supervised_nargs())
return nothing
end

function check_surrogate_machine(::Union{Unsupervised},
signature,
_sources)
isempty(signature) && throw(ERR_MUST_TRANSFORM)
isempty(_operations(signature)) && throw(ERR_MUST_TRANSFORM)
length(_sources) < 2 || throw(err_unsupervised_nargs())
return nothing
end
@@ -88,16 +172,26 @@ function machine(model::Surrogate, _sources::Source...; pair_itr...)

# named tuple, such as `(predict=yhat, transform=W)`:
signature = (; pair_itr...)
for op in keys(signature)
op in OPERATIONS || throw(ArgumentError(
"`$op` is not an admissible operation. "))

# signature checks:
isempty(_operations(signature)) && throw(ERR_MUST_OPERATE)
for k in keys(signature)
if k in OPERATIONS
getproperty(signature, k) isa AbstractNode ||
throw(ERR_EXPECTED_NODE_IN_SIGNATURE)
elseif k === :report
all(v->v isa AbstractNode, values(signature.report)) ||
throw(ERR_EXPECTED_NODE_IN_SIGNATURE)
else
throw(ERR_BAD_SIGNATURE)
end
end

check_surrogate_machine(model, signature, _sources)

mach = Machine(model, _sources...)

mach.fitresult = signature
mach.fitresult = CompositeFitresult(signature)

return mach
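
A hedged illustration of what the checks above accept and reject
(`Xs`, `ys`, `yhat`, `loss_node` as in the earlier sketch):

```julia
machine(Deterministic(), Xs, ys; predict=yhat)              # accepted
machine(Deterministic(), Xs, ys; report=(loss=loss_node,))  # ERR_MUST_OPERATE: no operation key
machine(Deterministic(), Xs, ys; predict=42)                # ERR_EXPECTED_NODE_IN_SIGNATURE
machine(Deterministic(), Xs, ys;
        predict=yhat, reprot=(loss=loss_node,))             # ERR_BAD_SIGNATURE: unknown key
```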

@@ -107,8 +201,6 @@ function machine(_sources::Source...; pair_itr...)

signature = (; pair_itr...)

isempty(signature) && throw(ERR_MUST_OPERATE)

T = model_supertype(signature)
if T == nothing
@warn "Unable to infer surrogate model type. \n"*
@@ -125,16 +217,17 @@ end
end

"""
N = glb(mach::Machine{<:Surrogate})
N = glb(mach::Machine{<:Union{Composite,Surrogate}})

A greatest lower bound for the nodes appearing in the signature of
`mach`.

$DOC_SIGNATURES

**Private method.**

"""
glb(mach::Machine{<:Union{Composite,Surrogate}}) =
glb(values(mach.fitresult)...)
glb(mach::Machine{<:Union{Composite,Surrogate}}) = glb(mach.fitresult)


"""
@@ -144,34 +237,31 @@
verbosity=1,
force=false))

Train the complete learning network wrapped by the machine
`mach`.
Train the complete learning network wrapped by the machine `mach`.

More precisely, if `s` is the learning network signature used to
construct `mach`, then call `fit!(N)`, where `N = glb(values(s)...)`
is a greatest lower bound on the nodes appearing in the signature. For
example, if `s = (predict=yhat, transform=W)`, then call
`fit!(glb(yhat, W))`. Here `glb` is `tuple` overloaded for nodes.
construct `mach`, then call `fit!(N)`, where `N` is a greatest lower
bound of the nodes appearing in the signature (values in the signature
that are not `AbstractNode` are ignored). For example, if `s =
(predict=yhat, transform=W)`, then call `fit!(glb(yhat, W))`.

See also [`machine`](@ref)

"""
function fit!(mach::Machine{<:Surrogate}; kwargs...)

glb_node = glb(mach)
fit!(glb_node; kwargs...)

update!(mach.fitresult) # updates `report_additions`
mach.state += 1
mach.report = report(glb_node)
mach.report = merge(report(glb_node), report_additions(mach.fitresult))
return mach

end

MLJModelInterface.fitted_params(mach::Machine{<:Surrogate}) =
fitted_params(glb(mach))


## CONSTRUCTING THE RETURN VALUE FOR A COMPOSITE FIT METHOD
# # CONSTRUCTING THE RETURN VALUE FOR A COMPOSITE FIT METHOD

# Identify which properties of `model` have, as values, a model in the
# learning network wrapped by `mach`, and check that no two such
@@ -182,7 +272,6 @@ MLJModelInterface.fitted_params(mach::Machine{<:Surrogate}) =
function network_model_names(model::M,
mach::Machine{<:Surrogate}) where M<:Model

signature = mach.fitresult
network_model_ids = objectid.(MLJBase.models(glb(mach)))

names = propertynames(model)
@@ -247,8 +336,8 @@ appearing in the `MLJBase.fit` signature, while `mach` is a learning
network machine constructed using `model`. Not relevant when defining
composite models using `@pipeline` or `@from_network`.

For usage, see the example given below. Specificlly, the call does the
following:
For usage, see the example given below. Specifically, the call does
the following:

- Determines which hyper-parameters of `model` point to model
instances in the learning network wrapped by `mach`, for recording
@@ -342,10 +431,11 @@ specified are replaced with empty source nodes.
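
A hypothetical calling pattern (`knn` denotes an atomic model instance
appearing in the network wrapped by `mach`; `KNNRegressor` is
illustrative only, not part of this diff):

```julia
# copy the network, swapping out one atomic model and emptying the
# unspecified data sources, e.g. before attaching fresh data:
mach2 = replace(mach, knn => KNNRegressor(K=7);
                empty_unspecified_sources=true)
```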
function Base.replace(mach::Machine{<:Surrogate},
pairs::Pair...; empty_unspecified_sources=false)

signature = mach.fitresult
interface_nodes = values(signature)
signature = MLJBase.signature(mach.fitresult)
operation_nodes = values(_operation_part(signature))
report_nodes = values(_report_part(signature))

W = glb(interface_nodes...)
W = glb(operation_nodes..., report_nodes...)

# Note: We construct nodes of the new network as values of a
# dictionary keyed on the nodes of the old network. Additionally,
@@ -394,7 +484,9 @@ function Base.replace(mach::Machine{<:Surrogate},
newnode_given_old =
IdDict{AbstractNode,AbstractNode}(all_source_pairs)
newsources = [newnode_given_old[s] for s in sources_]
newinterface_node_given_old =
newoperation_node_given_old =
IdDict{AbstractNode,AbstractNode}()
newreport_node_given_old =
IdDict{AbstractNode,AbstractNode}()
newmach_given_old = IdDict{Machine,Machine}()

@@ -414,14 +506,27 @@ function Base.replace(mach::Machine{<:Surrogate},
end
newnode_given_old[N] = N.operation(m, args...)
end
if N in interface_nodes
newinterface_node_given_old[N] = newnode_given_old[N]
if N in operation_nodes
newoperation_node_given_old[N] = newnode_given_old[N]
elseif N in report_nodes
newreport_node_given_old[N] = newnode_given_old[N]
end
end

newinterface_nodes = Tuple(newinterface_node_given_old[N] for N in
interface_nodes)
newsignature = NamedTuple{keys(signature)}(newinterface_nodes)
newoperation_nodes = Tuple(newoperation_node_given_old[N] for N in
operation_nodes)
newreport_nodes = Tuple(newreport_node_given_old[N] for N in
report_nodes)
report_tuple =
NamedTuple{keys(_report_part(signature))}(newreport_nodes)
operation_tuple =
NamedTuple{keys(_operation_part(signature))}(newoperation_nodes)

newsignature = if isempty(report_tuple)
operation_tuple
else
merge(operation_tuple, (report=report_tuple,))
end

return machine(mach.model, newsources...; newsignature...)

5 changes: 3 additions & 2 deletions src/composition/learning_networks/nodes.jl
@@ -419,8 +419,9 @@ end
"""
glb(N1, N2, ...)

Construct a node `N` with the behaviour `N() = (N1(), N2(),
...)`. That is, `glb` is `tuple` overloaded for nodes.
Given nodes `N1`, `N2`, ... , construct a node `N` with the behaviour
`N() = (N1(), N2(), ...)`. That is, `glb` is `tuple` overloaded for
nodes.

Equivalent to `@tuple N1 N2 ...`
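
A minimal sketch of the behaviour (hypothetical nodes):

```julia
Xs = source([1.0, 2.0, 3.0])
N1 = node(x -> 2 .* x, Xs)
N2 = node(x -> x .+ 1, Xs)

N = glb(N1, N2)
N()   # == (N1(), N2()) == ([2.0, 4.0, 6.0], [2.0, 3.0, 4.0])
```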

6 changes: 2 additions & 4 deletions src/composition/models/inspection.jl
@@ -29,14 +29,12 @@ end
function report(mach::Machine{<:Composite})
machines = mach.report.machines
dict = mach.report.report_given_machine
return merge(tuple_keyed_on_model_names(dict, mach),
(machines=machines, report_given_machine=dict,))
return merge(tuple_keyed_on_model_names(dict, mach), mach.report)
end

function fitted_params(mach::Machine{<:Composite})
fp = fitted_params(mach.model, mach.fitresult)
_machines = fp.machines
dict = fp.fitted_params_given_machine
return merge(MLJBase.tuple_keyed_on_model_names(dict, mach),
(machines=_machines, fitted_params_given_machine=dict,))
return merge(MLJBase.tuple_keyed_on_model_names(dict, mach), fp)
end
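
The practical effect, sketched for a hypothetical fitted composite
machine `mach` whose underlying network signature included
`report=(loss=loss_node,)`:

```julia
r = report(mach)
r.machines               # machines in the wrapped network, as before
r.report_given_machine   # per-machine reports, as before
r.loss                   # new: value exposed through the signature's `report` part
```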