-
-
Notifications
You must be signed in to change notification settings - Fork 214
/
Copy pathcodegen_utils.jl
294 lines (266 loc) · 11.8 KB
/
codegen_utils.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
"""
$(TYPEDSIGNATURES)

Build the symbol used as the name of argument number `i` in a function produced by
`build_function_wrapper`. For example, `generated_argument_name(2)` is `:__mtk_arg_2`.
"""
generated_argument_name(i::Int) = Symbol("__mtk_arg_", i)
"""
$(TYPEDSIGNATURES)

Given the arguments to `build_function_wrapper`, return a list of assignments which
reconstruct array variables if they are present scalarized in `args`.

An array variable is reconstructed only when every one of its elements appears in `args`.
If all of its elements come from the same argument, the array is rebuilt as a reshaped
`view` into that argument. The view uses a range when the element positions are
contiguous in increasing or decreasing order, and an `SArray` of positions otherwise.
If the elements are spread over several arguments, the array is rebuilt element by element
into an `SArray`. Arrays whose axes do not start at 1 are additionally wrapped in `Origin`.

# Keyword Arguments

- `argument_name` a function of the form `(::Int) -> Symbol` which takes the index of
  an argument to the generated function and returns the name of the argument in the
  generated function.
"""
function array_variable_assignments(args...; argument_name = generated_argument_name)
    # map array symbolic to an identically sized array where each element is
    # (buffer_idx, idx_in_buffer); `(0, 0)` marks an element not found in `args`
    var_to_arridxs = Dict{BasicSymbolic, Array{Tuple{Int, Int}}}()
    for (i, arg) in enumerate(args)
        # filter out non-arrays
        # any element of args which is not an array is assumed to not contain a
        # scalarized array symbolic. This works because the only non-array element
        # is the independent variable
        symbolic_type(arg) == NotSymbolic() || continue
        arg isa AbstractArray || continue
        # go through symbolics
        for (j, var) in enumerate(arg)
            var = unwrap(var)
            # filter out non-array-symbolics
            iscall(var) || continue
            operation(var) == getindex || continue
            arrvar = arguments(var)[1]
            # get and/or construct the buffer storing indexes
            idxbuffer = get!(
                () -> map(Returns((0, 0)), eachindex(arrvar)), var_to_arridxs, arrvar)
            # index with the array variable's own (possibly non-1-based) axes
            Origin(first.(axes(arrvar))...)(idxbuffer)[arguments(var)[2:end]...] = (i, j)
        end
    end
    assignments = Assignment[]
    for (arrvar, idxs) in var_to_arridxs
        # all elements of the array need to be present in `args` to form the
        # reconstructing assignment
        any(iszero ∘ first, idxs) && continue
        # if they are all in the same buffer, we can take a shortcut and `view` into it
        if allequal(Iterators.map(first, idxs))
            buffer_idx = first(first(idxs))
            idxs = map(last, idxs)
            lo, hi = first(idxs), last(idxs)
            # if all the elements are contiguous and ordered, turn the array of indexes into a
            # range to help reduce allocations
            if lo < hi && vec(idxs) == lo:hi
                idxs = lo:hi
            elseif vec(idxs) == lo:-1:hi
                # Decreasing contiguous run (or a single element). The range must step down
                # from the first position to the last; `hi:-1:lo` would be empty for any
                # run longer than one element and never match.
                idxs = lo:-1:hi
            else
                # Otherwise, turn the indexes into an `SArray` so they're stack-allocated
                idxs = SArray{Tuple{size(idxs)...}}(idxs)
            end
            # view and reshape
            expr = term(reshape, term(view, argument_name(buffer_idx), idxs),
                size(arrvar))
        else
            elems = map(idxs) do idx
                i, j = idx
                term(getindex, argument_name(i), j)
            end
            # use `MakeArray` syntax and generate a stack-allocated array
            expr = term(SymbolicUtils.Code.create_array, SArray, nothing,
                Val(ndims(arrvar)), Val(length(arrvar)), elems...)
        end
        # restore non-1-based axes of the original array variable
        if any(x -> !isone(first(x)), axes(arrvar))
            expr = term(Origin(first.(axes(arrvar))...), expr)
        end
        push!(assignments, arrvar ← expr)
    end
    return assignments
end
"""
$(TYPEDSIGNATURES)

A wrapper around `build_function` which performs the necessary transformations for
code generation of all types of systems. `expr` is the expression returned from the
generated functions, and `args` are the arguments.

# Keyword Arguments

- `p_start`, `p_end`: Denote the indexes in `args` where the buffers of the splatted
  `MTKParameters` object are present. These are collapsed into a single argument and
  destructured inside the function. `p_start` must also be provided for non-split systems
  since it is used by `wrap_delays`.
- `wrap_delays`: Whether to transform delayed unknowns of `sys` present in `expr` into
  calls to a history function. The history function is added to the list of arguments
  right before parameters, at the index `p_start`.
- `wrap_code`: Forwarded to `build_function`.
- `add_observed`: Whether to add assignment statements for observed equations in the
  generated code.
- `filter_observed`: A predicate function to filter out observed equations which should
  not be added to the generated code.
- `create_bindings`: Whether to explicitly destructure arrays of symbolics present in
  `args` in the generated code. If `false`, all usages of the individual symbolics will
  instead call `getindex` on the relevant argument. This is useful if the generated
  function writes to one of its arguments and expects subsequent code to use the new
  values. Note that the collapsed `MTKParameters` argument will always be explicitly
  destructured regardless of this keyword argument.
- `output_type`: The type of the output buffer. If `mkarray` (see below) is `nothing`,
  this will be passed to the `similarto` argument of `build_function`. If `output_type`
  is `Tuple`, `expr` will be wrapped in `SymbolicUtils.Code.MakeTuple` (regardless of
  whether it is scalar or an array).
- `mkarray`: A function which accepts `expr` and `output_type` and returns a code
  generation object similar to `MakeArray` or `MakeTuple` to be used to generate
  code for `expr`.
- `wrap_mtkparameters`: Whether to collapse parameter buffers for a split system into a
  single argument.
- `extra_assignments`: Extra `Assignment` statements to prefix to `expr`, after all other
  assignments.
- `cse`: Forwarded to `build_function`.

All other keyword arguments are forwarded to `build_function`.
"""
function build_function_wrapper(sys::AbstractSystem, expr, args...; p_start = 2,
        p_end = is_time_dependent(sys) ? length(args) - 1 : length(args),
        wrap_delays = is_dde(sys), wrap_code = identity,
        add_observed = true, filter_observed = Returns(true),
        create_bindings = false, output_type = nothing, mkarray = nothing,
        wrap_mtkparameters = true, extra_assignments = Assignment[], cse = true, kwargs...)
    isscalar = !(expr isa AbstractArray || symbolic_type(expr) == ArraySymbolic())
    # filter observed equations
    obs = filter(filter_observed, observed(sys))
    # turn delayed unknowns into calls to the history function
    if wrap_delays
        # `history_arg` must name the parameter argument as it is bound in the generated
        # function. For split systems that is the collapsed `MTKPARAMETERS_ARG`. For
        # non-split systems, the history function is inserted at index `p_start` below,
        # which shifts the parameter argument to `p_start + 1`; arguments are named by
        # their post-insertion index in the `ntuple` loop further down.
        history_arg = is_split(sys) ? MTKPARAMETERS_ARG :
                      generated_argument_name(p_start + 1)
        obs = map(obs) do eq
            delay_to_function(sys, eq; history_arg)
        end
        expr = delay_to_function(sys, expr; history_arg)
        # add extra argument
        args = (args[1:(p_start - 1)]..., DDE_HISTORY_FUN, args[p_start:end]...)
        p_start += 1
        p_end += 1
    end
    pdeps = parameter_dependencies(sys)
    # get the constants to add to the code
    cmap, _ = get_cmap(sys)
    extra_constants = collect_constants(expr)
    filter!(extra_constants) do c
        !any(x -> isequal(c, x.lhs), cmap)
    end
    for c in extra_constants
        push!(cmap, c ~ getdefault(c))
    end
    # only get the necessary observed equations, avoiding extra computation
    if add_observed && !isempty(obs)
        obsidxs = observed_equations_used_by(sys, expr; obs)
    else
        obsidxs = Int[]
    end
    # similarly for parameter dependency equations
    pdepidxs = observed_equations_used_by(sys, expr; obs = pdeps)
    for i in obsidxs
        union!(pdepidxs, observed_equations_used_by(sys, obs[i].rhs; obs = pdeps))
    end
    # assignments for reconstructing scalarized array symbolics
    assignments = array_variable_assignments(args...)
    for eq in Iterators.flatten((cmap, pdeps[pdepidxs], obs[obsidxs]))
        push!(assignments, eq.lhs ← eq.rhs)
    end
    append!(assignments, extra_assignments)
    args = ntuple(Val(length(args))) do i
        arg = args[i]
        # for time-dependent systems, all arguments are passed through `time_varying_as_func`
        # TODO: This is legacy behavior and a candidate for removal in v10 since we have callable
        # parameters now.
        if is_time_dependent(sys)
            arg = if symbolic_type(arg) == NotSymbolic()
                arg isa AbstractArray ?
                map(x -> time_varying_as_func(unwrap(x), sys), arg) : arg
            else
                time_varying_as_func(unwrap(arg), sys)
            end
        end
        # Make sure to use the proper names for arguments
        if symbolic_type(arg) == NotSymbolic() && arg isa AbstractArray
            DestructuredArgs(arg, generated_argument_name(i); create_bindings)
        else
            arg
        end
    end
    # wrap into a single MTKParameters argument
    if is_split(sys) && wrap_mtkparameters
        if p_start > p_end
            # In case there are no parameter buffers, still insert an argument
            args = (args[1:(p_start - 1)]..., MTKPARAMETERS_ARG, args[(p_end + 1):end]...)
        else
            # cannot apply `create_bindings` here since it doesn't nest
            args = (args[1:(p_start - 1)]...,
                DestructuredArgs(collect(args[p_start:p_end]), MTKPARAMETERS_ARG),
                args[(p_end + 1):end]...)
        end
    end
    # add preface assignments
    if has_preface(sys) && (pref = preface(sys)) !== nothing
        append!(assignments, pref)
    end
    wrap_code = wrap_code .∘ wrap_assignments(isscalar, assignments)
    # handling of `output_type` and `mkarray`
    similarto = nothing
    if output_type === Tuple
        expr = MakeTuple(Tuple(expr))
        wrap_code = wrap_code[1]
    elseif mkarray === nothing
        similarto = output_type
    else
        expr = mkarray(expr, output_type)
        wrap_code = wrap_code[2]
    end
    # scalar `build_function` only accepts a single function for `wrap_code`.
    if wrap_code isa Tuple && symbolic_type(expr) == ScalarSymbolic()
        wrap_code = wrap_code[1]
    end
    return build_function(expr, args...; wrap_code, similarto, cse, kwargs...)
end
"""
$(TYPEDEF)

Callable object bundling the out-of-place (`f_oop`) and in-place (`f_iip`) variants of a
generated function. The type parameter `P` is a 3-tuple with these entries:

1. the index of the parameter object among the arguments;
2. the number of arguments the out-of-place variant expects;
3. a `Bool` stating whether the functions were generated for a split system.

For scalar functions, `f_iip` may be `nothing`. Calling the wrapper forwards all
arguments to `_generated_call`.
"""
struct GeneratedFunctionWrapper{P, O, I} <: Function
    f_oop::O
    f_iip::I
end

# Convenience constructor: only `P` is given explicitly; `O` and `I` are inferred.
GeneratedFunctionWrapper{P}(f_oop::O, f_iip::I) where {P, O, I} =
    GeneratedFunctionWrapper{P, O, I}(f_oop, f_iip)

(gfw::GeneratedFunctionWrapper)(args...) = _generated_call(gfw, args...)
"""
Dispatch a call on a `GeneratedFunctionWrapper{P}` to its in-place or out-of-place
variant, chosen at compile time from the number of arguments. For split systems whose
parameter argument is neither a `Tuple` nor an `MTKParameters`, or is a tuple of numbers,
that argument is wrapped as `(p, nothing)` before the call.
"""
@generated function _generated_call(gfw::GeneratedFunctionWrapper{P}, args...) where {P}
    paramidx, nargs, issplit = P
    # An in-place call has exactly one extra argument, which comes before the parameter
    # object, so the parameter index shifts right by one.
    iip = length(args) == nargs + 1
    if iip
        nargs += 1
        paramidx += 1
    end
    length(args) == nargs ||
        throw(ArgumentError("Expected $nargs arguments, got $(length(args))."))
    fn = iip ? :(gfw.f_iip) : :(gfw.f_oop)
    # non-split systems forward the arguments untouched
    issplit || return :($fn(args...))
    ptype = args[paramidx]
    # a Tuple / MTKParameters is already a full parameter object; a tuple of numbers is not
    passthrough = ptype <: Union{Tuple, MTKParameters} &&
                  !(ptype <: Tuple{Vararg{Number}})
    passthrough && return :($fn(args...))
    # a single user-provided buffer is wrapped into a tuple before calling
    callargs = [k == paramidx ? :((args[$k], nothing)) : :(args[$k])
                for k in 1:length(args)]
    return :($fn($(callargs...)))
end