This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
/
executor.jl
254 lines (212 loc) · 9.1 KB
/
executor.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import Base: bind
"""
    Executor

An executor is a realization of a symbolic architecture defined by a `SymbolicNode`.
The actual forward and backward computation specified by the network architecture can
be carried out with an executor.
"""
mutable struct Executor
    handle      :: MX_ExecutorHandle              # opaque handle to the libmxnet executor
    symbol      :: SymbolicNode                   # the symbolic graph this executor realizes
    arg_arrays  :: VecOfNDArray                   # concrete arrays bound to the graph's arguments
    grad_arrays :: Vector{Union{Cvoid,<:NDArray}} # per-argument gradient arrays; `nothing` where no gradient was requested
    aux_arrays  :: VecOfNDArray                   # auxiliary-state arrays (states libmxnet updates internally)
    outputs     :: VecOfNDArray                   # output arrays, populated by `forward`
    arg_dict    :: Dict{Symbol}                   # argument name => bound NDArray
    aux_dict    :: Dict{Symbol}                   # auxiliary-state name => bound NDArray
end
# Assemble an `Executor` from a freshly bound libmxnet handle: query the C API
# for the executor's output arrays, then build name => array lookup tables for
# the arguments and auxiliary states of `sym`.
function Executor(hdl::MX_ExecutorHandle, sym::SymbolicNode,
                  arg_arrays::VecOfNDArray, grad_arrays::AbstractVector,
                  aux_arrays::VecOfNDArray)
    # fetch the output NDArray handles owned by this executor
    n_out    = Ref{MX_uint}(0)
    hdls_ptr = Ref{Ptr{MX_handle}}(C_NULL)
    @mxcall(:MXExecutorOutputs, (MX_handle, Ref{MX_uint}, Ref{Ptr{MX_handle}}),
            hdl, n_out, hdls_ptr)
    raw_outs   = unsafe_wrap(Array, hdls_ptr[], n_out[])
    out_arrays = map(h -> NDArray(MX_NDArrayHandle(h)), raw_outs)

    # names must be unique for the Dict lookups below to be faithful
    arg_names = list_arguments(sym)
    @assert(length(arg_names) == length(unique(arg_names)),
            "Duplicated names in arguments: $arg_names")
    arg_dict = Dict(zip(arg_names, arg_arrays))

    aux_names = list_auxiliary_states(sym)
    @assert(length(aux_names) == length(unique(aux_names)),
            "Duplicated names in auxiliary states: $aux_names")
    aux_dict = Dict(zip(aux_names, aux_arrays))

    return Executor(hdl, sym, arg_arrays, grad_arrays, aux_arrays, out_arrays,
                    arg_dict, aux_dict)
end
# Let an `Executor` be passed directly at `ccall` sites expecting an
# `MX_handle` by unwrapping the underlying executor handle.
function Base.unsafe_convert(::Type{MX_handle}, obj::Executor)
    return Base.unsafe_convert(MX_handle, obj.handle)
end
Base.convert(t::Type{MX_handle}, obj::Executor)  = Base.unsafe_convert(t, obj)
Base.cconvert(t::Type{MX_handle}, obj::Executor) = Base.unsafe_convert(t, obj)
# Vector form: arrays were supplied positionally; verify the count matches the
# declared argument names and expose the raw C handles alongside the arrays.
function _get_ndarray_inputs(arg_key::AbstractString, args::VecOfNDArray,
                             arg_names::Vector{Symbol}, allow_missing::Bool)
    @assert(length(args) == length(arg_names),
            "Length of $arg_key does not match number of arguments")
    handles = MX_handle[args...]
    return (handles, args)
end
# Dict form: look each declared argument name up in `args`. Missing entries are
# tolerated only when `allow_missing` is set; they become `nothing` and are
# handed to libmxnet as NULL handles.
function _get_ndarray_inputs(arg_key::AbstractString, args::Dict{Symbol},
                             arg_names::Vector{Symbol}, allow_missing::Bool)
    raw = map(arg_names) do name
        entry = get(args, name, nothing)
        if !allow_missing
            @assert(!isa(entry, Cvoid), "Must specify all arguments in $arg_key ($name is missing)")
        end
        entry
    end
    # pin down the element type to help downstream type inference
    args_vec = allow_missing ? Union{NDArray,Cvoid}[raw...] : NDArray[raw...]
    args_hdr = MX_handle[(isa(x, Cvoid) ? MX_handle(0) : x) for x in args_vec]
    return (args_hdr, args_vec)
end
"""
    bind(sym, ctx, args; args_grad=Dict(), aux_states=Dict(), grad_req=GRAD_WRITE)

Create an `Executor` by binding a `SymbolicNode` to concrete `NDArray`.

# Arguments
* `sym::SymbolicNode`: the network architecture describing the computation graph.
* `ctx::Context`: the context on which the computation should run.
* `args`: either a list of `NDArray` or a dictionary of name-array pairs. Concrete
  arrays for all the inputs in the network architecture. The inputs typically include
  network parameters (weights, bias, filters, etc.), data and labels.
  See [`list_arguments`](@ref) and [`infer_shape`](@ref).
* `args_grad`: a `Vector` of `NDArray` or a `Dict` contains `NDArray`
* `aux_states`: a `Vector` of `NDArray` or a `Dict` contains `NDArray`
* `grad_req`: single value, a `Vector` of `GRAD_REQ` or a `Dict{Symbol,GRAD_REQ}`
"""
function bind(self::SymbolicNode, ctx::Context, args;
              args_grad = Dict{Symbol,NDArray}(),
              aux_states = Dict{Symbol,NDArray}(),
              grad_req = GRAD_WRITE)
    arg_names = list_arguments(self)

    # normalize each kind of input into a (handle vector, array vector) pair
    args_hdr, args = _get_ndarray_inputs("args", args, arg_names, false)
    args_grad_hdr, args_grad = _get_ndarray_inputs("args_grad", args_grad, arg_names, true)
    aux_args_hdr, aux_states = _get_ndarray_inputs("aux_states", aux_states,
                                                   list_auxiliary_states(self), false)

    # expand grad_req into one MX_uint per argument
    if isa(grad_req, GRAD_REQ)
        reqs = MX_uint[MX_uint(grad_req) for i = 1:length(args)]
    elseif isa(grad_req, Vector{GRAD_REQ})
        @assert(length(grad_req) == length(args))
        reqs = MX_uint[MX_uint.(grad_req)...]
    elseif isa(grad_req, Dict{Symbol,GRAD_REQ})
        # arguments absent from the Dict default to no gradient
        reqs = MX_uint[MX_uint(get(grad_req, name, GRAD_NOP)) for name in arg_names]
    else
        # FIX: previously an unsupported grad_req type fell through, leaving
        # `reqs` undefined and raising a confusing UndefVarError below
        throw(ArgumentError("grad_req must be a GRAD_REQ, a Vector{GRAD_REQ} " *
                            "or a Dict{Symbol,GRAD_REQ}, got $(typeof(grad_req))"))
    end

    ref_hdr = Ref{MX_handle}(0)
    @mxcall(:MXExecutorBind,
            (MX_handle, Cint, Cint, MX_uint, Ptr{MX_handle}, Ptr{MX_handle}, Ptr{MX_uint},
             MX_uint, Ptr{MX_handle}, Ref{MX_handle}),
            self, ctx.device_type, ctx.device_id, length(args), args_hdr,
            args_grad_hdr, reqs, length(aux_states), aux_args_hdr, ref_hdr)

    args_grad = convert(Vector{Union{Cvoid,NDArray}}, args_grad)
    return Executor(MX_ExecutorHandle(ref_hdr[]), self, args, args_grad, aux_states)
end
# Keyword-style convenience wrapper: `bind(sym; context=..., args=..., ...)`.
# `args` is mandatory; every remaining keyword is forwarded to the main method.
function bind(x::SymbolicNode; context::Context = cpu(), kwargs...)
    kw = Dict(kwargs)
    @assert(haskey(kw, :args), "Must specify args")
    concrete_args = pop!(kw, :args)
    return bind(x, context, concrete_args; kw...)
end
# Allocate zero-initialized argument/auxiliary arrays from shape inference and
# bind them, so callers only need to supply input shapes as keyword arguments.
function simple_bind(self::SymbolicNode, ctx::Context;
                     grad_req::Union{GRAD_REQ,Dict{Symbol,GRAD_REQ}} = GRAD_WRITE,
                     kwargs...)
    arg_shapes, out_shapes, aux_shapes = infer_shape(self; kwargs...)
    @assert(!isa(arg_shapes, Cvoid), "Information not enough to perform complete shape inference")

    arg_arrays = NDArray[zeros(shape, ctx) for shape in arg_shapes]
    arg_names  = list_arguments(self)

    grad_arrays = Dict{Symbol,NDArray}()
    if grad_req != GRAD_NOP
        # arguments whose shapes were provided via kwargs are data inputs;
        # everything else is a parameter and gets a gradient buffer
        provided = [first(kv) for kv in kwargs]
        for (name, shape) in zip(arg_names, arg_shapes)
            in(name, provided) && continue
            grad_arrays[name] = zeros(shape, ctx)
        end
    end

    aux_arrays = NDArray[zeros(shape, ctx) for shape in aux_shapes]
    return bind(self, ctx, arg_arrays, args_grad = grad_arrays,
                grad_req = grad_req, aux_states = aux_arrays)
end
"""
    forward(self::Executor; is_train=false, kwargs...)

Run a forward pass of the bound computation graph. Keyword arguments name
argument arrays to update before execution (e.g. `:data => x`); each value is
copied into the matching bound array. Set `is_train = true` for training-mode
behavior of operators such as Dropout/BatchNorm. Returns `self.outputs`.
"""
function forward(self::Executor; is_train::Bool = false, kwargs...)
    for (k, v) in kwargs
        # FIX: `k ∈ self.arg_dict` errored for every keyword argument, since
        # `in` on an AbstractDict expects a Pair — use haskey for name lookup.
        @assert(haskey(self.arg_dict, k), "Unknown argument $k")
        @assert(isa(v, NDArray), "Keyword argument $k must be an NDArray")
        copy!(self.arg_dict[k], v)
    end
    @mxcall(:MXExecutorForward, (MX_handle, Cint), self, is_train)
    self.outputs
end
# Backward pass. With no head gradients (e.g. when the graph ends in a loss
# symbol) an empty list is passed; a single NDArray is promoted to a list.
backward(x::Executor) = backward(x, NDArray[])
backward(x::Executor, out_grad::NDArray) = backward(x, [out_grad])
function backward(x::Executor, out_grads::VecOfNDArray)
    @mxcall(:MXExecutorBackward, (MX_handle, MX_uint, Ptr{MX_handle}),
            x, length(out_grads), MX_handle[out_grads...])
end
# Copy parameter (and optionally auxiliary-state) values into this executor's
# bound arrays, matching by name. Names not present in the executor are an
# error unless `allow_extra_params` is set, in which case they are skipped.
function copy_params_from(self::Executor, arg_params::Dict{Symbol},
                          aux_params::Dict{Symbol} = Dict{Symbol,Any}();
                          allow_extra_params::Bool = false)
    for (name, value) in arg_params
        if haskey(self.arg_dict, name)
            copy!(self.arg_dict[name], value)
        else
            @assert(allow_extra_params, "Extra params $name not in the arguments")
        end
    end
    for (name, value) in aux_params
        if haskey(self.aux_dict, name)
            copy!(self.aux_dict[name], value)
        else
            @assert(allow_extra_params, "Extra auxiliary state $name not recognized")
        end
    end
end
# Compact one-line display, e.g. `mx.Executor Ptr{Nothing} @0x...`.
function Base.show(io::IO, x::Executor)
    typename = last(split(string(typeof(x)), '.'))
    print(io, "mx.", typename, " ", x.handle.value)
end
"""
    print([io::IO], x::Executor)

Get a debug string about internal execution plan.
Can be used to get an estimate of the memory cost.

```julia
julia> x = mx.Variable(:x)
MXNet.mx.SymbolicNode x

julia> exec = mx.bind(x + 1, mx.cpu(), Dict(:x => mx.ones(2,3)))
mx.Executor Ptr{Nothing} @0x000055c3dee9eb30

julia> print(exec)
Symbol Outputs:
	output[0]=_plus_scalar0(0)
Variable:x
--------------------
Op:_plus_scalar, Name=_plus_scalar0
Inputs:
	arg[0]=x(0) version=0
Attrs:
	scalar=1.00000000e+00
Total 0 MB allocated
Total 11 TempSpace resource requested
```
"""
Base.print(io::IO, x::Executor) = print(io, debug_str(x))
# FIX: `STDOUT` was removed in Julia 1.0; the lowercase `stdout` binding is the
# supported name (the rest of the file already uses post-0.7 syntax).
Base.print(x::Executor) = print(stdout, x)
# Fetch libmxnet's human-readable description of the execution plan as a
# Julia String (the C string is owned by libmxnet; we copy it out).
function debug_str(x::Executor)
    out = Ref{Cstring}(C_NULL)
    @mxcall(:MXExecutorPrint, (MX_handle, Ptr{Cstring}), x.handle, out)
    return unsafe_string(out[])
end