# build_function.jl
using SymbolicUtils.Code
using Base.Threads
"""
Abstract supertype of the code-generation targets that `build_function` accepts
through its `target` keyword.
"""
abstract type BuildTargets end
"Generate Julia code (the default target of `build_function`)."
struct JuliaTarget <: BuildTargets end
"Generate a Stan function for use with Stan's differential equation solvers."
struct StanTarget <: BuildTargets end
"Generate a C function; with `expression = Val{false}` it is compiled with `gcc`."
struct CTarget <: BuildTargets end
"Generate an anonymous function for MATLAB / Octave."
struct MATLABTarget <: BuildTargets end
"""
Abstract supertype of the parallelization strategies accepted by the `parallel`
keyword of the array-output `JuliaTarget` build.
"""
abstract type ParallelForm end

"Serial execution: the generated code uses no parallelism."
struct SerialForm <: ParallelForm end

"""
    MultithreadedForm(ntasks)
    MultithreadedForm()

Split the generated work across `ntasks` tasks. The zero-argument constructor
uses twice the number of Julia threads available.
"""
struct MultithreadedForm <: ParallelForm
    ntasks::Int
end

function MultithreadedForm()
    ntasks = 2 * nthreads()
    return MultithreadedForm(ntasks)
end
"""
`build_function`
Generates a numerically-usable function from a Symbolics `Num`.
```julia
build_function(ex, args...;
expression = Val{true},
target = JuliaTarget(),
kwargs...)
```
Arguments:
- `ex`: The `Num` to compile
- `args`: The arguments of the function
- `expression`: Whether to generate code or whether to generate the compiled form.
By default, `expression = Val{true}`, which means that the code for the
function is returned. If `Val{false}`, then the returned value is compiled.
Keyword Arguments:
- `target`: The output target of the compilation process. Possible options are:
- `JuliaTarget`: Generates a Julia function
- `CTarget`: Generates a C function
- `StanTarget`: Generates a function for compiling with the Stan probabilistic
programming language
- `MATLABTarget`: Generates an anonymous function for use in MATLAB and Octave
environments
- `fname`: Used by some targets for the name of the function in the target space.
Note that not all build targets support the full compilation interface. Check the
individual target documentation for details.
"""
function build_function(args...;target = JuliaTarget(),kwargs...)
_build_function(target,args...;kwargs...)
end
"""
    unflatten_args(f, args, N = 4)

Rebuild the call `f(args...)` as a tree of `Term{Real}` calls of `f`, none of
which has more than `N` arguments. `args` are grouped into consecutive chunks of
`N`, each chunk becomes one `f`-term, and the process repeats on the chunk terms.
Only meaningful for associative operations such as `+` and `*`.

Throws an `ArgumentError` if `N < 2`, since grouping into chunks of one (or
fewer) arguments can never reduce the number of terms.
"""
function unflatten_args(f, args, N = 4)
    N >= 2 || throw(ArgumentError("unflatten_args needs N >= 2, got $N"))
    # `<=`: a list of exactly `N` arguments already fits in a single term; stopping
    # only below `N` would wrap it in a redundant one-argument `f` term.
    length(args) <= N && return Term{Real}(f, args)
    groups = [Term{Real}(f, group) for group in Iterators.partition(args, N)]
    return unflatten_args(f, groups, N)
end
"""
    termify(op)

Rebuild a tree-shaped symbolic `op` as a plain `Term` that has the same symbolic
type, operation, arguments and metadata. Any non-tree value is returned unchanged.

Used by `unflatten_long_ops`: converting first avoids repeated sorting when
`arguments` is called after children are edited by `Postwalk`.
"""
function termify(op)
    if istree(op)
        return Term{symtype(op)}(operation(op), arguments(op); metadata = op.metadata)
    end
    return op
end
"""
    unflatten_long_ops(op, N = 4)

Rewrite every `+` and `*` node of `op` that has more than `N` arguments into a
nested tree of calls with at most `N` arguments each (see `unflatten_args`), and
return the result wrapped as a `Num`. Non-tree inputs are returned as `Num(op)`.
`N` should be at least 2.
"""
function unflatten_long_ops(op, N = 4)
    op = termify(value(op))
    istree(op) || return Num(op)
    # The same `N` decides both when a node is split and how wide the regrouped
    # calls are, so the resulting calls really have at most `N` arguments.
    rule1 = @rule((+)(~~x) => length(~~x) > N ? unflatten_args(+, ~~x, N) : nothing)
    rule2 = @rule((*)(~~x) => length(~~x) > N ? unflatten_args(*, ~~x, N) : nothing)
    return Num(Rewriters.Postwalk(Rewriters.Chain([rule1, rule2]))(op))
end
# Argument handling for the Julia target: array and tuple arguments are unwrapped
# element-wise with `value` and packed into `DestructuredArgs` (the `inbounds` flag
# is forwarded to it); every other argument is passed through untouched.
function destructure_arg(arg::Union{AbstractArray, Tuple}, inbounds)
    unwrapped = map(value, arg)
    return DestructuredArgs(unwrapped, inbounds = inbounds)
end

function destructure_arg(arg, inbounds)
    return arg
end
# Scalar-output Julia target: builds a single function `(args...) -> op`.
# `expression = Val{true}` returns the `Expr`; anything else returns a
# RuntimeGeneratedFunction created in `expression_module`.
# NOTE(review): `conv` and `linenumbers` are accepted but not used by this method.
function _build_function(target::JuliaTarget, op, args...;
                         conv = toexpr,
                         expression = Val{true},
                         expression_module = @__MODULE__(),
                         checkbounds = false,
                         linenumbers = true)
    inbounds = !checkbounds
    destructured = map(a -> destructure_arg(a, inbounds), [args...])
    body = unflatten_long_ops(op)
    code = toexpr(Func(destructured, [], body))
    expression == Val{true} && return code
    return _build_and_inject_function(expression_module, code)
end
"""
    _build_and_inject_function(mod::Module, ex)

Turn the function expression `ex` into a RuntimeGeneratedFunction that uses `mod`
both as its cache module and as its context module.

A long-form anonymous `function (args...) ... end` is given a fresh `gensym`
name qualified by `mod` (the argument tuple of `ex` is rewritten in place). An
arrow function `(args...) -> body` is first converted into a long-form
`function` expression.
"""
function _build_and_inject_function(mod::Module, ex)
    if ex.head == :(->)
        # Normalise arrow syntax to a long-form `function` and handle that instead.
        return _build_and_inject_function(mod, Expr(:function, ex.args...))
    end
    if ex.head == :function
        signature = ex.args[1]
        if signature.head == :tuple
            ex.args[1] = Expr(:call, :($mod.$(gensym())), signature.args...)
        end
    end
    # XXX: Workaround to specify the module as both the cache module AND context module.
    # Currently, the @RuntimeGeneratedFunction macro only sets the context module.
    tag = getproperty(mod, RuntimeGeneratedFunctions._tagname)
    return RuntimeGeneratedFunctions.RuntimeGeneratedFunction(tag, tag, ex)
end
# A `Num` only wraps a symbolic value: unwrap it and convert that value instead.
function toexpr(n::Num, st)
    inner = value(n)
    return toexpr(inner, st)
end
"""
    fill_array_with_zero!(x::AbstractArray)

Set every element of `x` to zero in place and return `x`. When the element type
of `x` is itself an array type, each inner array is zeroed recursively instead.
"""
function fill_array_with_zero!(x::AbstractArray)
    if !(eltype(x) <: AbstractArray)
        # `false` converts to the zero of any numeric element type.
        return fill!(x, false)
    end
    for inner in x
        fill_array_with_zero!(inner)
    end
    return x
end
"""
Build function target: `JuliaTarget`
```julia
function _build_function(target::JuliaTarget, rhss, args...;
conv = toexpr, expression = Val{true},
checkbounds = false,
linenumbers = false,
headerfun = addheader, outputidxs=nothing,
convert_oop = true, force_SA = false,
skipzeros = outputidxs===nothing,
fillzeros = skipzeros && !(typeof(rhss)<:SparseMatrixCSC),
parallel=SerialForm(), kwargs...)
```
Generates a Julia function which can then be utilized for further evaluations.
If expression=Val{false}, the return is a Julia function which utilizes
RuntimeGeneratedFunctions.jl in order to be free of world-age issues.
If the `rhss` is a scalar, the generated function is a function
with a scalar output, otherwise if it's an `AbstractArray`, the output
is two functions, one for out-of-place AbstractArray output and a second which
is a mutating function. The outputted functions match the given argument order,
i.e., f(u,p,args...) for the out-of-place and scalar functions and
`f!(du,u,p,args..)` for the in-place version.
Special Keyword Argumnets:
- `parallel`: The kind of parallelism to use in the generated function. Defaults
to `SerialForm()`, i.e. no parallelism. Note that the parallel forms are not
exported and thus need to be chosen like `Symbolics.SerialForm()`.
The choices are:
- `SerialForm()`: Serial execution.
- `MultithreadedForm()`: Multithreaded execution with a static split, evenly
splitting the number of expressions per thread.
- `conv`: The conversion function of symbolic types to Expr. By default this uses
the `toexpr` function.
- `checkbounds`: For whether to enable bounds checking inside of the generated
function. Defaults to false, meaning that `@inbounds` is applied.
- `linenumbers`: Determines whether the generated function expression retains
the line numbers. Defaults to true.
- `convert_oop`: Determines whether the OOP version should try to convert
the output to match the type of the first input. This is useful for
cases like LabelledArrays or other array types that carry extra
information. Defaults to true.
- `force_SA`: Forces the output of the OOP version to be a StaticArray.
Defaults to `false`, and outputs a static array when the first argument
is a static array.
- `skipzeros`: Whether to skip filling zeros in the in-place version if the
filling function is 0.
- `fillzeros`: Whether to perform `fill(out,0)` before the calculations to ensure
safety with `skipzeros`.
"""
function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
expression = Val{true},
expression_module = @__MODULE__(),
checkbounds = false,
linenumbers = false,
outputidxs=nothing,
skipzeros = false,
wrap_code = (nothing, nothing),
fillzeros = skipzeros && !(typeof(rhss)<:SparseMatrixCSC),
parallel=SerialForm(), kwargs...)
dargs = map(arg -> destructure_arg(arg, !checkbounds), [args...])
i = findfirst(x->x isa DestructuredArgs, dargs)
similarto = i === nothing ? Array : dargs[i].name
oop_expr = Func(dargs, [], make_array(parallel, dargs, rhss, similarto))
if !isnothing(wrap_code[1])
oop_expr = wrap_code[1](oop_expr)
end
out = Sym{Any}(gensym("out"))
ip_expr = Func([out, dargs...], [], set_array(parallel, dargs, out, outputidxs, rhss, checkbounds, skipzeros))
if !isnothing(wrap_code[2])
ip_expr = wrap_code[2](ip_expr)
end
if expression == Val{true}
return toexpr(oop_expr), toexpr(ip_expr)
else
return _build_and_inject_function(expression_module, toexpr(oop_expr)),
_build_and_inject_function(expression_module, toexpr(ip_expr))
end
end
# Fallback for parallel forms without a dedicated `make_array` method: warn, then
# build the out-of-place array serially.
function make_array(s, dargs, arr, similarto)
    Base.@warn "Parallel form of $(typeof(s)) not implemented"
    return _make_array(arr, similarto)
end
# Serial form: the array construction is emitted directly, with no task splitting.
make_array(::SerialForm, dargs, arr, similarto) = _make_array(arr, similarto)
# Multithreaded out-of-place construction: split `arr` into at most `s.ntasks`
# consecutive slices, build each slice in its own `Func` closed over `closed_args`,
# and `vcat` the per-task results.
function make_array(s::MultithreadedForm, closed_args, arr, similarto)
    # At least one element per task: for an empty `arr` the ceiling division is 0,
    # and `Iterators.partition` rejects partitions of length 0.
    per_task = max(1, cld(length(arr), s.ntasks))
    # NOTE(review): `Iterators.partition` walks `arr` in linear order and the results are
    # joined with `vcat`, so a multi-dimensional `arr` appears to come back as a vector;
    # confirm callers only rely on this for vector outputs.
    slices = collect(Iterators.partition(arr, per_task))
    arrays = map(slices) do slice
        Func(closed_args, [], _make_array(slice, similarto)), closed_args
    end
    SpawnFetch{MultithreadedForm}(first.(arrays), last.(arrays), vcat)
end
"""
    Funcall(f, args)

A zero-argument callable that returns `f(args...)` when invoked. The generated
multithreaded code uses it to hand a function together with its arguments to a
`Task`.
"""
struct Funcall{F, T}
    f::F
    args::T
end

function (call::Funcall)()
    fn, fnargs = call.f, call.args
    return fn(fnargs...)
end
# Code generation for `SpawnFetch{MultithreadedForm}`: each thunk in `p.exprs` is
# compiled to a RuntimeGeneratedFunction and run in its own task, with the converted
# values of its argument list (an empty list when `p.args` is `nothing`). The fetched
# results are passed, in order, to `p.combine`.
function toexpr(p::SpawnFetch{MultithreadedForm}, st)
    args = isnothing(p.args) ?
        Iterators.repeated((), length(p.exprs)) : p.args
    spawns = map(p.exprs, args) do thunk, a
        ex = :($Funcall($(@RuntimeGeneratedFunction(toexpr(thunk, st))),
                       ($(toexpr.(a, (st,))...),)))
        quote
            let
                task = Base.Threads.Task($ex)
                # A `Task` is sticky (pinned to the thread that schedules it) by default.
                # Clear the flag, as `Threads.@spawn` does, so the work can run in parallel.
                task.sticky = false
                Base.Threads.schedule(task)
                task
            end
        end
    end
    quote
        $(toexpr(p.combine, st))(map(fetch, ($(spawns...),))...)
    end
end
# Sparse arrays: convert every entry, then keep the sparse structure when the result
# is still sparse. If `map` produced a dense array, use the dense construction instead.
function _make_array(rhss::AbstractSparseArray, similarto)
    converted = map(x -> _make_array(x, similarto), rhss)
    converted isa AbstractSparseArray || return _make_array(converted, similarto)
    return MakeSparseArray(converted)
end
# Dense arrays: convert every entry recursively and wrap the result in `MakeArray`.
function _make_array(rhss::AbstractArray, similarto)
    converted = map(x -> _make_array(x, similarto), rhss)
    # Mapping over e.g. a reshaped sparse array can itself yield a sparse array;
    # route that case through the sparse method.
    if converted isa AbstractSparseArray
        return _make_array(converted, similarto)
    end
    return MakeArray(converted, similarto)
end
# Leaf (non-array) entries only have their long `+`/`*` chains regrouped.
function _make_array(x, similarto)
    return unflatten_long_ops(x)
end
## In-place version

# Fallback for parallel forms without a dedicated `set_array` method: warn, then
# build the in-place body serially.
function set_array(p, closed_vars, args...)
    Base.@warn "Parallel form of $(typeof(p)) not implemented"
    return _set_array(args...)
end
# Serial form: the in-place assignments are emitted directly, with no task splitting.
set_array(::SerialForm, closed_vars, args...) = _set_array(args...)
# Multithreaded in-place body: split the `(outputidxs, rhss)` pairs into at most
# `s.ntasks` consecutive slices, emit one in-place `Func` per slice (taking `out` and
# the closed-over arguments), and run them as tasks whose results are discarded.
function set_array(s::MultithreadedForm, closed_args, out, outputidxs, rhss, checkbounds, skipzeros)
    if rhss isa AbstractSparseArray
        # Sparse outputs are written serially through their `nzval` buffer, like the
        # serial sparse `_set_array` method. `out.nzval` is not a plain argument name,
        # so it cannot be handed to the per-task `Func`s built below.
        return _set_array(LiteralExpr(:($out.nzval)), nothing, rhss.nzval,
                          checkbounds, skipzeros)
    end
    if outputidxs === nothing
        outputidxs = collect(eachindex(rhss))
    end
    # At least one entry per task: for an empty `rhss` the ceiling division is 0,
    # and `Iterators.partition` rejects partitions of length 0.
    per_task = max(1, cld(length(rhss), s.ntasks))
    # TODO: do better partitioning when skipzeros is present
    slices = collect(Iterators.partition(zip(outputidxs, rhss), per_task))
    arrays = map(slices) do slice
        idxs, vals = first.(slice), last.(slice)
        Func([out, closed_args...], [],
             _set_array(out, idxs, vals, checkbounds, skipzeros)), [out, closed_args...]
    end
    SpawnFetch{MultithreadedForm}(first.(arrays), last.(arrays), @inline noop(args...) = nothing)
end
# Sparse right-hand sides: write only the stored values, through `out.nzval`.
# NOTE(review): assumes `out` has the same sparsity pattern as `rhss`; confirm at call sites.
function _set_array(out, outputidxs, rhss::AbstractSparseArray, checkbounds, skipzeros)
    nzval_target = LiteralExpr(:($out.nzval))
    return _set_array(nzval_target, nothing, rhss.nzval, checkbounds, skipzeros)
end
"""
    _set_array(out, outputidxs, rhss::AbstractArray, checkbounds, skipzeros)

Build the in-place assignment code that writes `rhss[k]` to `out[outputidxs[k]]`
for every position `k`. `outputidxs === nothing` means `eachindex(rhss)`.

- Scalar entries are gathered into a single `SetArray` (with `@inbounds` unless
  `checkbounds`). When `skipzeros` is true, entries for which `_iszero` holds are
  left out.
- Entries that are themselves arrays are written recursively into the nested
  output `out[outputidxs[k]]`.
"""
function _set_array(out, outputidxs, rhss::AbstractArray, checkbounds, skipzeros)
    if outputidxs === nothing
        outputidxs = collect(eachindex(rhss))
    end
    # sometimes outputidxs is a Tuple
    ii = findall(i->!(rhss[i] isa AbstractArray) && !(skipzeros && _iszero(rhss[i])), eachindex(outputidxs))
    jj = findall(i->rhss[i] isa AbstractArray, eachindex(outputidxs))
    exprs = []
    rhss_scalar = unflatten_long_ops.(vec(rhss[ii]))
    setterexpr = SetArray(!checkbounds,
                          out,
                          AtIndex.(vec(collect(outputidxs[ii])),
                                   rhss_scalar))
    push!(exprs, setterexpr)
    for j in jj
        # Nested arrays go to the same output location a scalar at position `j` would
        # use, `outputidxs[j]`, so custom `outputidxs` are honoured for them too.
        push!(exprs, _set_array(LiteralExpr(:($out[$(outputidxs[j])])), nothing, rhss[j], checkbounds, skipzeros))
    end
    LiteralExpr(quote
        $(exprs...)
    end)
end
# Non-array right-hand side: no assignment is generated; the expression itself is
# returned with its long `+`/`*` chains regrouped.
function _set_array(out, outputidxs, rhs, checkbounds, skipzeros)
    return unflatten_long_ops(rhs)
end
"""
    vars_to_pairs(name, vs::Union{Tuple, AbstractArray}, symsdict = Dict())

For each variable in `vs`, store `var => Sym{symtype(var)}(tosymbol(var))` in
`symsdict`. Return the symbols of `vs` together with the indexing expressions
`name[1]`, `name[2]`, ... (one per element of `vs`).
"""
function vars_to_pairs(name, vs::Union{Tuple, AbstractArray}, symsdict=Dict())
    vs_names = tosymbol.(vs)
    for (sym, var) in zip(vs_names, vs)
        symsdict[var] = Sym{symtype(var)}(sym)
    end
    exs = [:($name[$i]) for i in 1:length(vs)]
    return vs_names, exs
end
"""
    vars_to_pairs(name, vs, symsdict = Dict())

Single-variable counterpart of the collection method. Store
`vs => Sym{symtype(vs)}(tosymbol(vs))` in `symsdict` and return
`([tosymbol(vs)], [name])`. `symsdict` defaults to a fresh `Dict()`, matching the
collection method.
"""
function vars_to_pairs(name, vs, symsdict=Dict())
    sym = tosymbol(vs)
    symsdict[vs] = Sym{symtype(vs)}(sym)
    return [sym], [name]
end
# Position of `varop` within the vector `vars` (compared with `isequal`), or `nothing`.
get_varnumber(varop, vars::Vector) = findfirst(isequal(varop), vars)

# For a single variable: `0` when it is `varop` (by `isequal`), otherwise `nothing`.
function get_varnumber(varop, var)
    isequal(var, varop) || return nothing
    return 0
end
"""
    buildvarnumbercache(args...)

Map each variable in `args` to its location. An element `el` at position `pos` of
the `k`-th argument, when that argument is an array, maps to `(k, pos)`. A non-array
`k`-th argument maps itself to `(k, 0)`. Non-array arguments are still iterated, so
they must be iterable (a `Number` iterates once).
"""
function buildvarnumbercache(args...)
    entries = [(a isa AbstractArray) ? (el => (k, pos)) : (a => (k, 0))
               for (k, a) in enumerate(args) for (pos, el) in enumerate(a)]
    return Dict(entries)
end
"""
    numbered_expr(O::Symbolic, varnumbercache, args...; varordering, offset, lhsname, rhsnames)

Convert the symbolic expression `O` into a Julia `Expr` in which every variable that
is a key of `varnumbercache` is replaced by a reference to the matching argument name.

`varnumbercache` maps a variable to `(j, i)`. `j` selects the name `rhsnames[j]`.
When `i == 0` that name is used directly; otherwise it is indexed as
`rhsnames[j][i + offset]` (e.g. `offset = -1` for the 0-based C target).
`lhsname` and `varordering` are only forwarded to the recursive calls.
"""
function numbered_expr(O::Symbolic,varnumbercache,args...;varordering = args[1],offset = 0,
lhsname=gensym("du"),rhsnames=[gensym("MTK") for i in 1:length(args)])
O = value(O)
# Only bare `Sym`s and calls whose operation is a `Sym` (presumably dependent
# variables such as `x(t)`) are looked up in the cache.
# NOTE(review): `operation(O)` is evaluated for every non-`Sym` `O`; confirm no non-tree leaf reaches here.
if O isa Sym || isa(operation(O), Sym)
(j,i) = get(varnumbercache, O, (nothing, nothing))
if !isnothing(j)
return i==0 ? :($(rhsnames[j])) : :($(rhsnames[j])[$(i+offset)])
end
end
# Calls are rebuilt as `Expr(:call, Symbol(op), ...)`, converting every argument recursively.
if istree(O)
Expr(:call, Symbol(operation(O)), (numbered_expr(x,varnumbercache,args...;offset=offset,lhsname=lhsname,
rhsnames=rhsnames,varordering=varordering) for x in arguments(O))...)
elseif O isa Sym
# A `Sym` not found in the cache becomes a plain, unescaped Julia symbol.
tosymbol(O, escape=false)
else
# Any other leaf is embedded unchanged.
O
end
end
"""
    numbered_expr(de::Equation, varnumbercache, args...; varordering, lhsname, rhsnames, offset)

Build the assignment `lhsname[i + offset] = <rhs>` for the equation `de`. Here `i` is
the position within `args[1]` of the variable whose derivative appears on `de.lhs`,
matched by unescaped symbol name, and `<rhs>` is `de.rhs` converted recursively with
`numbered_expr`.
"""
function numbered_expr(de::Equation,varnumbercache,args...;varordering = args[1],
lhsname=gensym("du"),rhsnames=[gensym("MTK") for i in 1:length(args)],offset=0)
# NOTE(review): the `varordering` keyword is always overwritten here with `value.(args[1])`.
varordering = value.(args[1])
# Presumably strips the (possibly nested) derivative on the left-hand side to get the variable; confirm.
var = var_from_nested_derivative(de.lhs)[1]
# Entries of `varordering` given as calls are compared through their operation.
# NOTE(review): if nothing matches, `i` is `nothing` and `i+offset` below throws.
i = findfirst(x->isequal(tosymbol(x isa Sym ? x : operation(x), escape=false), tosymbol(var, escape=false)),varordering)
:($lhsname[$(i+offset)] = $(numbered_expr(de.rhs,varnumbercache,args...;offset=offset,
varordering = varordering,
lhsname = lhsname,
rhsnames = rhsnames)))
end
# Leaves that are neither symbolic expressions nor equations pass through unchanged.
function numbered_expr(c, args...; kwargs...)
    return c
end

# A wrapped `Num` is rejected; callers unwrap with `value` before converting.
function numbered_expr(c::Num, args...; kwargs...)
    error("Num found")
end
"""
    coperators(expr)

Rewrite the Julia expression `expr` in place, recursively, so that `string(expr)` is
valid C in more cases, and return it. Non-`Expr` inputs are returned unchanged. Extra
factors of 1 are hopefully eliminated by the C compiler.

- `c * s`, with a numeric literal `c` and a symbol `s`, gets an extra factor `1`, so it
  is not printed as the juxtaposition `cs`.
- `x ^ n` becomes a product of copies of `x` when `n` is a small integer (`abs(n) <= 3`
  for a symbol `x`, `abs(n) <= 1` otherwise) or numerically zero (including `0.0`). For
  negative such `n` it becomes `1.0 / (...)` of the corresponding product. In all other
  cases it becomes `pow(x, n)`.
"""
function coperators(expr)
    expr isa Expr || return expr
    # Rewrite the children first, so nested powers and products are already C-friendly.
    for e in expr.args
        e isa Expr && coperators(e)
    end
    if expr.head == :call && expr.args[1] == :* && length(expr.args) == 3 &&
       expr.args[2] isa Real && expr.args[3] isa Symbol
        # Introduce another factor 1 to prevent contraction of terms like "5 * t" to "5t"
        # (not valid C code).
        push!(expr.args, 1)
    elseif expr.head == :call && expr.args[1] == :^
        # The power operator does not exist in C: replace it by multiplication or "pow".
        @assert length(expr.args)==3 "Don't know how to handle ^ operation with <> 2 arguments"
        x = expr.args[2]
        n = expr.args[3]
        empty!(expr.args)
        # Replace by multiplication/division if
        #   x is a symbol and n is a small integer,
        #   x is a more complex expression and n is ±1, or
        #   n is exactly 0.
        if (n isa Integer && ((x isa Symbol && abs(n) <= 3) || abs(n) <= 1)) || n == 0
            # `n == 0` also admits float zeros such as `0.0`; `fill` needs an integer count.
            k = Int(n)
            if k >= 0
                append!(expr.args, [:*, fill(x, k)...])
                # Pad with factors of 1 so this expr is still a valid multiplication.
                while length(expr.args) < 3
                    push!(expr.args, 1)
                end
            else # inverse of the above
                if k == -1
                    term = x
                else
                    term = :(($x) ^ ($(-k)))
                    coperators(term)
                end
                append!(expr.args, [:/, 1.0, term])
            end
        else
            # Otherwise use the C "pow" function.
            append!(expr.args, [:pow, x, n])
        end
    end
    return expr
end
"""
Build function target: `CTarget`
```julia
function _build_function(target::CTarget, eqs::Array{<:Equation}, args...;
conv = toexpr, expression = Val{true},
fname = :diffeqf,
lhsname=:du,rhsnames=[Symbol("RHS\$i") for i in 1:length(args)],
libpath=tempname(),compiler=:gcc)
```
This builds an in-place C function. Only works on arrays of equations. If
`expression == Val{false}`, then this builds a function in C, compiles it,
and returns a lambda to that compiled function. These special keyword arguments
control the compilation:
- libpath: the path to store the binary. Defaults to a temporary path.
- compiler: which C compiler to use. Defaults to :gcc, which is currently the
only available option.
"""
function _build_function(target::CTarget, eqs::Array{<:Equation}, args...;
conv = toexpr, expression = Val{true},
fname = :diffeqf,
lhsname=:du,rhsnames=[Symbol("RHS$i") for i in 1:length(args)],
libpath=tempname(),compiler=:gcc)
@warn "build_function(::Array{<:Equation}...) is deprecated. Use build_function(::AbstractArray...) instead."
varnumbercache = buildvarnumbercache(args...)
differential_equation = string(join([numbered_expr(eq,varnumbercache,args...,lhsname=lhsname,
rhsnames=rhsnames,offset=-1) for
(i, eq) ∈ enumerate(eqs)],";\n "),";")
argstrs = join(vcat("double* $(lhsname)",[typeof(args[i])<:Array ? "double* $(rhsnames[i])" : "double $(rhsnames[i])" for i in 1:length(args)]),", ")
ex = """
void $fname($(argstrs...)) {
$differential_equation
}
"""
if expression == Val{true}
return ex
else
@assert compiler == :gcc
ex = build_function(eqs,args...;target=Symbolics.CTarget())
open(`gcc -fPIC -O3 -msse3 -xc -shared -o $(libpath * "." * Libdl.dlext) -`, "w") do f
print(f, ex)
end
@RuntimeGeneratedFunction(:((du::Array{Float64},u::Array{Float64},p::Array{Float64},t::Float64) -> ccall(("diffeqf", $libpath), Cvoid, (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Float64), du, u, p, t)))
end
end
"""
Build function target: `CTarget`
```julia
function _build_function(target::CTarget, ex::AbstractArray, args...;
columnmajor = true,
conv = toexpr,
expression = Val{true},
fname = :diffeqf,
lhsname = :du,
rhsnames = [Symbol("RHS\$i") for i in 1:length(args)],
libpath = tempname(),
compiler = :gcc)
```
This builds an in-place C function. Only works on expressions. If
`expression == Val{false}`, then this builds a function in C, compiles it,
and returns a lambda to that compiled function. These special keyword arguments
control the compilation:
- libpath: the path to store the binary. Defaults to a temporary path.
- compiler: which C compiler to use. Defaults to :gcc, which is currently the
only available option.
"""
function _build_function(target::CTarget, ex::AbstractArray, args...;
columnmajor = true,
conv = toexpr,
expression = Val{true},
fname = :diffeqf,
lhsname = :du,
rhsnames = [Symbol("RHS$i") for i in 1:length(args)],
libpath = tempname(),
compiler = :gcc)
if !columnmajor
return _build_function(target, hcat([row for row ∈ eachrow(ex)]...), args...;
columnmajor = true,
conv = conv,
fname = fname,
lhsname = lhsname,
rhsnames = rhsnames,
libpath = libpath,
compiler = compiler)
end
varnumbercache = buildvarnumbercache(args...)
equations = Vector{String}()
for col ∈ 1:size(ex,2)
for row ∈ 1:size(ex,1)
lhs = string(lhsname, "[", (col-1) * size(ex,1) + row-1, "]")
rhs = numbered_expr(value(ex[row, col]), varnumbercache, args...;
lhsname = lhsname,
rhsnames = rhsnames,
offset = -1) |> coperators |> string # Filter through coperators to produce valid C code in more cases
push!(equations, string(lhs, " = ", rhs, ";"))
end
end
argstrs = join(vcat("double* $(lhsname)",[typeof(args[i])<:Array ? "const double* $(rhsnames[i])" : "const double $(rhsnames[i])" for i in 1:length(args)]),", ")
ccode = """
#include <math.h>
void $fname($(argstrs...)) {$([string("\n ", eqn) for eqn ∈ equations]...)\n}
"""
if expression == Val{true}
return ccode
else
@assert compiler == :gcc
open(`gcc -fPIC -O3 -msse3 -xc -shared -o $(libpath * "." * Libdl.dlext) -`, "w") do f
print(f, ccode)
end
@RuntimeGeneratedFunction(:((du::Array{Float64},u::Array{Float64},p::Array{Float64},t::Float64) -> ccall(("diffeqf", $libpath), Cvoid, (Ptr{Float64}, Ptr{Float64}, Ptr{Float64}, Float64), du, u, p, t)))
end
end
# A single C expression is compiled as a one-element output array.
function _build_function(target::CTarget, ex::Num, args...; kwargs...)
    return _build_function(target, [ex], args...; kwargs...)
end
"""
Build function target: `StanTarget`
```julia
function _build_function(target::StanTarget, eqs::Array{<:Equation}, vs, ps, iv;
conv = toexpr, expression = Val{true},
fname = :diffeqf, lhsname=:internal_var___du,
rhsnames=[:internal_var___u,:internal_var___p,:internal_var___t])
```
This builds an in-place Stan function compatible with the Stan differential equation solvers.
Unlike other build targets, this one requestions (vs, ps, iv) as the function arguments.
Only allowed on arrays of equations.
"""
function _build_function(target::StanTarget, eqs::Array{<:Equation}, vs, ps, iv;
conv = toexpr, expression = Val{true},
fname = :diffeqf, lhsname=:internal_var___du,
rhsnames=[:internal_var___u,:internal_var___p,:internal_var___t])
@warn "build_function(::Array{<:Equation}...) is deprecated. Use build_function(::AbstractArray...) instead."
@assert expression == Val{true}
varnumbercache = buildvarnumbercache(vs,ps)
differential_equation = string(join([numbered_expr(eq,varnumbercache,vs,ps,lhsname=lhsname,
rhsnames=rhsnames) for
(i, eq) ∈ enumerate(eqs)],";\n "),";")
"""
real[] $fname(real $(conv(iv)),real[] $(rhsnames[1]),real[] $(rhsnames[2]),real[] x_r,int[] x_i) {
real $lhsname[$(length(eqs))];
$differential_equation
return $lhsname;
}
"""
end
"""
Build function target: `StanTarget`
```julia
function _build_function(target::StanTarget, ex::AbstractArray, vs, ps, iv;
columnmajor = true,
conv = toexpr,
expression = Val{true},
fname = :diffeqf, lhsname=:internal_var___du,
rhsnames = [:internal_var___u,:internal_var___p,:internal_var___t])
```
This builds an in-place Stan function compatible with the Stan differential equation solvers.
Unlike other build targets, this one requestions (vs, ps, iv) as the function arguments.
Only allowed on expressions, and arrays of expressions.
"""
function _build_function(target::StanTarget, ex::AbstractArray, vs, ps, iv;
columnmajor = true,
conv = toexpr,
expression = Val{true},
fname = :diffeqf, lhsname=:internal_var___du,
rhsnames = [:internal_var___u,:internal_var___p,:internal_var___t])
@assert expression == Val{true}
if !columnmajor
return _build_function(target, hcat([row for row ∈ eachrow(ex)]...), vs, ps, iv;
columnmajor = true,
conv = conv,
expression = expression,
fname = fname,
lhsname = lhsname,
rhsnames = rhsnames)
end
varnumbercache = buildvarnumbercache(vs,ps,iv)
equations = Vector{String}()
for col ∈ 1:size(ex,2)
for row ∈ 1:size(ex,1)
lhs = string(lhsname, "[", (col-1) * size(ex,1) + row, "]")
rhs = numbered_expr(value(ex[row, col]), varnumbercache, vs, ps, iv;
lhsname = lhsname,
rhsnames = rhsnames,
offset = 0) |> string
push!(equations, string(lhs, " = ", rhs, ";"))
end
end
"""
real[] $fname(real $(conv(iv)),real[] $(rhsnames[1]),real[] $(rhsnames[2]),real[] x_r,int[] x_i) {
real $lhsname[$(length(equations))];
$([eqn == equations[end] ? string(" ", eqn) : string(" ", eqn, "\n") for eqn ∈ equations]...)
return $lhsname;
}
"""
end
# A single Stan expression is emitted as a one-element output array.
function _build_function(target::StanTarget, ex::Num, vs, ps, iv; kwargs...)
    return _build_function(target, [ex], vs, ps, iv; kwargs...)
end
"""
Build function target: `MATLABTarget`
```julia
function _build_function(target::MATLABTarget, eqs::Array{<:Equation}, args...;
conv = toexpr, expression = Val{true},
lhsname=:internal_var___du,
rhsnames=[:internal_var___u,:internal_var___p,:internal_var___t])
```
This builds an out of place anonymous function @(t,rhsnames[1]) to be used in MATLAB.
Compatible with the MATLAB differential equation solvers. Only allowed on expressions,
and arrays of expressions.
"""
function _build_function(target::MATLABTarget, eqs::Array{<:Equation}, args...;
conv = toexpr, expression = Val{true},
fname = :diffeqf, lhsname=:internal_var___du,
rhsnames=[:internal_var___u,:internal_var___p,:internal_var___t])
@warn "build_function(::Array{<:Equation}...) is deprecated. Use build_function(::AbstractArray...) instead."
@assert expression == Val{true}
varnumbercache = buildvarnumbercache(args...)
matstr = join([numbered_expr(eq.rhs,varnumbercache,args...,lhsname=lhsname,
rhsnames=rhsnames) for
(i, eq) ∈ enumerate(eqs)],"; ")
matstr = replace(matstr,"["=>"(")
matstr = replace(matstr,"]"=>")")
matstr = "$fname = @(t,$(rhsnames[1])) ["*matstr*"];"
matstr
end
"""
Build function target: `MATLABTarget`
```julia
function _build_function(target::MATLABTarget, ex::AbstractArray, args...;
columnmajor = true,
conv = toexpr,
expression = Val{true},
fname = :diffeqf,
lhsname = :internal_var___du,
rhsnames = [:internal_var___u,:internal_var___p,:internal_var___t])
```
This builds an out of place anonymous function @(t,rhsnames[1]) to be used in MATLAB.
Compatible with the MATLAB differential equation solvers. Only allowed on expressions,
and arrays of expressions.
"""
function _build_function(target::MATLABTarget, ex::AbstractArray, args...;
columnmajor = true,
conv = toexpr,
expression = Val{true},
fname = :diffeqf,
lhsname = :internal_var___du,
rhsnames = [:internal_var___u,:internal_var___p,:internal_var___t])
@assert expression == Val{true}
if !columnmajor
return _build_function(target, hcat([row for row ∈ eachrow(ex)]...), args...;
columnmajor = true,
conv = conv,
expression = expression,
fname = fname,
lhsname = lhsname,
rhsnames = rhsnames)
end
varnumbercache = buildvarnumbercache(args...)
matstr = ""
for row ∈ 1:size(ex,1)
row_strings = Vector{String}()
for col ∈ 1:size(ex,2)
lhs = string(lhsname, "[", (col-1) * size(ex,1) + row-1, "]")
rhs = numbered_expr(value(ex[row, col]), varnumbercache, args...;
lhsname = lhsname,
rhsnames = rhsnames,
offset = 0) |> string
push!(row_strings, rhs)
end
matstr = matstr * " " * join(row_strings, ", ") * ";\n"
end
matstr = replace(matstr,"["=>"(")
matstr = replace(matstr,"]"=>")")
matstr = "$fname = @(t,$(rhsnames[1])) [\n"*matstr*"];\n"
return matstr
end
# A single MATLAB expression is emitted as a one-element matrix literal.
function _build_function(target::MATLABTarget, ex::Num, args...; kwargs...)
    return _build_function(target, [ex], args...; kwargs...)
end