/
inspection.jl
152 lines (125 loc) · 4.03 KB
/
inspection.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
## INSPECTING LEARNING NETWORKS
"""
tree(N)
Return a named-tuple respresentation of the ancestor tree `N`
(including training edges)
"""
function tree(W::Node)
mach = W.machine
if mach === nothing
value2 = nothing
endkeys = []
endvalues = []
else
value2 = mach.model
endkeys = (Symbol("train_arg", i) for i in eachindex(mach.args))
endvalues = (tree(arg) for arg in mach.args)
end
keys = tuple(:operation, :model,
(Symbol("arg", i) for i in eachindex(W.args))...,
endkeys...)
values = tuple(W.operation, value2,
(tree(arg) for arg in W.args)...,
endvalues...)
return NamedTuple{keys}(values)
end
tree(s::Source) = (source = s,)
# """
# args(tree; train=false)
# Return a vector of the top level args of the tree associated with a node.
# If `train=true`, return the `train_args`.
# """
# function args(tree; train=false)
# keys_ = filter(keys(tree) |> collect) do key
# match(Regex("^$("train_"^train)arg[0-9]*"), string(key)) !== nothing
# end
# return [getproperty(tree, key) for key in keys_]
# end
"""
MLJBase.models(N::AbstractNode)
A vector of all models referenced by a node `N`, each model appearing
exactly once.
"""
function models(W::AbstractNode)
models_ = filter(flat_values(tree(W)) |> collect) do model
model isa Union{Model,Symbol}
end
return unique(models_)
end
"""
sources(N::AbstractNode)
A vector of all sources referenced by calls `N()` and `fit!(N)`. These
are the sources of the ancestor graph of `N` when including training
edges.
Not to be confused with `origins(N)`, in which training edges are
excluded.
See also: [`origins`](@ref), [`source`](@ref).
"""
function sources(W::AbstractNode; kind=:any)
if kind == :any
sources_ = filter(flat_values(tree(W)) |> collect) do value
value isa Source
end
else
sources_ = filter(flat_values(tree(W)) |> collect) do value
value isa Source && value.kind == kind
end
end
return unique(sources_)
end
"""
machines(N::AbstractNode [, model::Model])
List all machines in the ancestor graph of node `N`, optionally
restricting to those machines whose corresponding model matches the
specifed `model`.
Here two models *match* if they have the same, possibly nested
hyperparameters, or, more precisely, if
`MLJModelInterface.is_same_except(m1, m2)` is `true`.
See also `MLJModelInterface.is_same_except`.
"""
function machines(W::Node, model=nothing)
if W.machine === nothing
machs = vcat((machines(arg) for arg in W.args)...) |> unique
else
machs = vcat(Machine[W.machine, ],
(machines(arg) for arg in W.args)...,
(machines(arg) for arg in W.machine.args)...) |> unique
end
model === nothing && return machs
return filter(machs) do mach
mach.model == model
end
end
# `args(N)`: the ordinary (non-training) arguments of `N`. A `Source` has none.
args(::Source) = Any[]
args(node::Node) = node.args
# `train_args(N)`: the training arguments of `N`, i.e. the arguments of the
# node's machine. Sources and machine-less nodes have none.
train_args(::Source) = Any[]
function train_args(node::Node{<:Machine})
    mach = node.machine
    return mach.args
end
train_args(::Node{Nothing}) = Any[]
"""
children(N::AbstractNode, y::AbstractNode)
List all (immediate) children of node `N` in the ancestor graph of `y`
(training edges included).
"""
children(N::AbstractNode, y::AbstractNode) = filter(nodes(y)) do W
N in args(W) || N in train_args(W)
end |> unique
"""
lower_bound(type_itr)
Return the minimum type in the collection `type_itr` if one exists
(mininum in the sense of `<:`). If `type_itr` is empty, return `Any`,
and in all other cases return the universal lower bound `Union{}`.
"""
function lower_bound(Ts)
isempty(Ts) && return Any
sorted = sort(collect(Ts), lt=<:)
candidate = first(sorted)
all(T -> candidate <: T, sorted[2:end]) && return candidate
return Union{}
end
# Variant of `lower_bound` in which `Unknown` dominates: if `Unknown` occurs
# in `Ts`, return `Unknown`; otherwise defer to `lower_bound(Ts)`.
_lower_bound(Ts) = Unknown in Ts ? Unknown : lower_bound(Ts)
# Fallback for nodes not covered by a more specific method below (e.g. nodes
# without a machine): the input scitype is `Unknown`.
function MLJModelInterface.input_scitype(N::Node)
    return Unknown
end
# A node with a machine reports the input scitype of that machine's model.
function MLJModelInterface.input_scitype(N::Node{<:Machine})
    model = N.machine.model
    return input_scitype(model)
end