-
CC @junpenglao
-
From what I remember, […] Aside from a guess at the "historical" reasons why […]

Using the auto-updates mechanism as an example, I don't see why it wouldn't be possible. We don't want the […] We really need to do a full assessment of the […] The point is that we shouldn't need the caller to manage the […]

Aside from whether or not we can technically remove updates at the user level, we need to consider the options/flexibility that's offered by the current approach and address that. For instance, if a user omits the […]

I don't think there needs to be much of a challenge involved in addressing that, though. For example, we could provide a keyword that specifies the shared variables to be updated, and our new interface would automatically handle the rest; otherwise, by default it would perform all updates.

After considering all this, I'm starting to think that removing […]
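The keyword idea mentioned above could look roughly like the following plain-Python sketch. Everything here is hypothetical: `apply_updates` and its `update_vars` keyword are illustrative names, and shared variables are modelled as entries of a dictionary rather than as aesara shared variables.

```python
def apply_updates(shared, updates, update_vars=None):
    """Apply update rules to a dict of shared values.

    `update_vars` (hypothetical keyword) names the entries to update;
    `None` means that every available update is applied.
    """
    if update_vars is None:
        selected = updates
    else:
        selected = {name: rule for name, rule in updates.items() if name in update_vars}

    new_shared = dict(shared)  # leave the caller's dictionary untouched
    for name, rule in selected.items():
        new_shared[name] = rule(shared[name])
    return new_shared


shared = {"rng_state": 0, "counter": 10}
updates = {"rng_state": lambda s: s + 1, "counter": lambda c: c + 5}

all_updated = apply_updates(shared, updates)  # default: perform all updates
only_rng = apply_updates(shared, updates, update_vars={"rng_state"})
```

With this shape, callers who do not care about updates never see them, while callers who need finer control can still opt in.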
-
The types of the outputs for the inner/loop function do not need to be specified separately. All one needs to do is construct dummy variables for the appropriate "inner-inputs" (i.e. the inputs to the inner function) and evaluate the inner function with those variables to get an "inner-graph" that represents the loop body. The output […]

This is what the […]

That's really all there is to […]
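As a toy illustration of the dummy-variable idea, the sketch below evaluates a loop body on typed placeholders and reads the output types off the result. It is not aesara's machinery: `Dummy`, its promotion rules, and `infer_output_types` are all hypothetical and only model a dtype and a number of dimensions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Dummy:
    """Stand-in for a symbolic variable: it carries a type but no value."""

    dtype: str
    ndim: int

    def _combine(self, other):
        # Toy promotion and broadcasting rules, for illustration only.
        if not isinstance(other, Dummy):
            other = Dummy("float64" if isinstance(other, float) else "int64", 0)
        dtype = "float64" if "float64" in (self.dtype, other.dtype) else "int64"
        return Dummy(dtype, max(self.ndim, other.ndim))

    __add__ = __radd__ = __mul__ = __rmul__ = _combine


def infer_output_types(inner_fn, *inner_inputs):
    """Evaluate inner_fn on dummy inner-inputs and return its output types."""
    outputs = inner_fn(*inner_inputs)
    return outputs if isinstance(outputs, tuple) else (outputs,)


def body(carry, x):
    # An ordinary loop body; it never needs to declare its output types.
    return carry + x, 2.0 * x


out_types = infer_output_types(body, Dummy("int64", 0), Dummy("int64", 1))
```

The body is written once, as ordinary code, and the output types fall out of evaluating it on placeholders.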
-
To clarify, the only substantial difference (in terms of the broad design) between the general "carry-[…]

That was likely done to "simplify" the implementation and interpretation of […]

Regardless, to determine that something is a "TAP" term according to […]
-
Here is another possibility that fully embraces the fact that looping primitives are […]

```python
loop = aesara.while_loop(cond_fn, body_fn)
last_value = loop(init_value)
```

And

```python
scan = aesara.scan(body_fn)
last_value, acc = scan(init_value, sequences)
```
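To make the shape of this proposal concrete, here is a minimal plain-Python sketch of the two constructors. It is not aesara code: `while_loop` and `scan` below are hypothetical stand-ins that run eagerly on Python values, and the `body_fn(carry, x) -> (carry, y)` convention used for `scan` is an assumption borrowed from JAX.

```python
def while_loop(cond_fn, body_fn):
    """Return a callable that iterates body_fn while cond_fn holds."""

    def loop(init_value):
        value = init_value
        # The condition is checked before the first call to body_fn.
        while cond_fn(value):
            value = body_fn(value)
        return value

    return loop


def scan(body_fn):
    """Return a callable that threads a carry through a sequence and collects outputs."""

    def run(init_value, sequences):
        carry, acc = init_value, []
        for x in sequences:
            carry, y = body_fn(carry, x)
            acc.append(y)
        return carry, acc

    return run


# Usage: double a value until it reaches 100, then compute a running sum.
loop = while_loop(lambda v: v < 100, lambda v: v * 2)
last_value = loop(1)

running_sum = scan(lambda carry, x: (carry + x, carry + x))
total, partial_sums = running_sum(0, [1, 2, 3, 4])
```

Separating construction from application means the loop object can be built once and applied to several initial values.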
-
I've just added two issues that, when resolved, would almost entirely remove the need to use shared variables when combining […]

With those in place (well, just #739, really), […]
-
Let's open a discussion on a new API for the looping primitives in `aesara`.

Looping in `aesara` revolves around one `Op`, the `Scan` op. One can create a `Scan` `Op` by calling the function `scan`:

[…]

With the following arguments:
- `fn` is the function that is called at each iteration;
- `sequences` is the list of sequences over which to iterate;
- `outputs_info` are the variables to be carried around;
- `non_sequences` is the list of arguments that are passed to `fn` at each step;
- `n_steps` (required) is the number of steps to take.

It returns two values:
- `results`, a `Variable` or a list of `Variable`s;
- `updates`, with the update rules for all the shared variables used in `scan`.
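As a reading aid, the calling convention described above can be modelled in plain Python. This is only a sketch under stated assumptions: `scan_model` is a hypothetical name, it runs eagerly on Python lists rather than symbolic variables, it treats every output as carried, it ignores taps and `until`, and it always returns an empty `updates` dictionary. What it does illustrate is the order in which `fn` receives its arguments: the current element of each sequence, then the carried values, then the non-sequences.

```python
def scan_model(fn, sequences=(), outputs_info=(), non_sequences=(), n_steps=None):
    """Eager, list-based model of scan's argument passing (no taps, no updates)."""
    if n_steps is None:
        # Without an explicit step count, iterate over the shortest sequence.
        n_steps = min(len(seq) for seq in sequences)

    carried = list(outputs_info)
    results = [[] for _ in carried]
    for i in range(n_steps):
        step_inputs = [seq[i] for seq in sequences]
        new_values = fn(*step_inputs, *carried, *non_sequences)
        if not isinstance(new_values, (list, tuple)):
            new_values = [new_values]
        carried = list(new_values)
        for collected, value in zip(results, carried):
            collected.append(value)

    updates = {}  # placeholder: shared-variable updates are not modelled
    return (results[0] if len(results) == 1 else results), updates


# Usage: successive powers of 2, computed as a carried value.
powers, updates = scan_model(
    fn=lambda prior, base: prior * base,
    outputs_info=[1],
    non_sequences=[2],
    n_steps=5,
)

# Usage: cumulative sum over a sequence (n_steps is taken from the sequence).
cumsum, _ = scan_model(lambda x, total: total + x, sequences=[[1, 2, 3]], outputs_info=[0])
```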
Issues with the user interface

The user interface makes very little sense when you do not know what it is actually doing internally. Nevertheless, users should be able to use the main features of `aesara` without having to know what happens in the background, the same way one does not have to know what tracing and XLA are all about when using `JAX`.

I have also identified the following difficulties with the user interface; most stem from the fact that the current implementation tries to do too many things with a single operator.

- […] `outputs_info`. It can be a list of `Variable` or a dictionary that specifies the `initial` value and `taps`;
- […] `outputs_info` to be able to make the distinction between which returned values are carried over and which are just collected;
- […] `Scan` represents is determined by which arguments are provided;
- […] `Scan` behave as a while loop the body function needs to return an `aesara.scan.utils.until` object;
- […] `aesara.scan.utils.until` instance to make `Scan` behave as a while loop the body function is executed at least once;
- […] `updates` returned value is (unnecessarily?) cumbersome;
- […] `scan` (such as `outputs_info`, which is used very often) as a kwarg;
- `scan` only accepts `TensorVariable`s as arguments. While the design choice simplifies many things, it pushes complexity out to downstream applications. This can be a real issue, see aehmc's codebase.

I think it is possible to address all these shortcomings with minimal change to the existing `Scan` operator. I suggest providing user-facing APIs for condition- and collection-controlled loops.

Modus operandi
The idea is to deprecate the current functions and Ops progressively:
- […] `Scan`;
- […] `Scan` operator. Rename;

Proposal
Python has two different kinds of loop:
- `while`
- `for x in iterable`

and to reduce cognitive load we should provide `aesara` primitives that map to these constructions. I also suggest we try to simplify the interface as much as we can.

Automatic `updates` if possible

We should aim at completely removing the manual update logic.
Deprecate taps

`Theano` was used for deep learning, and implementing taps probably made sense if the group worked with RNNs a lot. I suggest we remove this feature: it is nice to have, but a refactor should focus on what is essential. It should be easy for libraries that make heavy use of taps to implement a wrapper around the new `scan` primitive.

While loop
The first step is to implement an Op that maps to the standard `while` loop. We take inspiration from JAX and TensorFlow to propose the following:

[…]

With the following rough type signatures:

[…]

We could then replace the following example (taken from the documentation)

[…]

with the more explicit:

[…]
With the following arguments to `while_loop`:
- `cond_fn`: function such that we keep iterating while `cond_fn(variables)=True`;
- `body_fn`: function called at each step, `body_fn(variables)`;
- `init_vals`: the initial value(s) of the variable(s), either as a variable, a tuple or list (`*args`) or a dictionary (`**kwargs`);
- `n_steps`: maximum number of steps (@brandonwillard can this be made optional, or does it have implications for e.g. gradient computation?)

Comments:
- `cond_fn` is checked before calling `body_fn` the first time;
- […] `init_value`'s […]

Scan
We take inspiration from the JAX and TensorFlow implementations:

[…]

with the following rough type signature:

[…]

Here I don't have enough experience with `aesara` to know if we need to provide the type of `B` via the user interface or if it can be inferred. In the following I assume I can infer it.

Simple loop
We can implement a simple loop by setting `sequences` to `None` and giving a value to `n_steps`. So that:

[…]

becomes:

[…]

Iterate over the first dimension of a `TensorVariable`

[…]

Becomes:

[…]
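For illustration, the two proposed primitives can be modelled in a few lines of plain Python. This is a sketch of the intended semantics only, not aesara code and not the snippets from the proposal: both functions run eagerly on Python values, and the `fn(carry, x) -> (carry, y)` convention and the argument names of `scan` are assumptions modelled on JAX.

```python
def while_loop(cond_fn, body_fn, init_vals, n_steps=None):
    """Iterate body_fn while cond_fn holds, for at most n_steps iterations."""
    vals, step = init_vals, 0
    # cond_fn is evaluated before the first call to body_fn.
    while cond_fn(vals) and (n_steps is None or step < n_steps):
        vals = body_fn(vals)
        step += 1
    return vals


def scan(fn, init, sequences=None, n_steps=None):
    """Model of a collection-controlled loop; returns (last_carry, outputs)."""
    if sequences is None:
        if n_steps is None:
            raise ValueError("n_steps is required when sequences is None")
        sequences = [None] * n_steps  # simple loop: nothing to iterate over
    carry, ys = init, []
    for x in sequences:
        carry, y = fn(carry, x)
        ys.append(y)
    return carry, ys


# Condition-controlled loop, with and without a cap on the number of steps.
uncapped = while_loop(lambda v: v < 100, lambda v: v * 2, 1)
capped = while_loop(lambda v: v < 100, lambda v: v * 2, 1, n_steps=3)

# Simple loop: sequences=None and a fixed number of steps.
last, history = scan(lambda c, _: (c * 2, c * 2), init=1, n_steps=4)

# Iterating over the first dimension of a nested list (the rows of a matrix).
rows = [[1, 2], [3, 4], [5, 6]]
total, row_sums = scan(lambda c, row: (c + sum(row), sum(row)), init=0, sequences=rows)
```

Keeping the two loops separate makes each signature small: `while_loop` only carries state, while `scan` carries state and collects one output per step.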