Motivation
TVM currently focuses on models that are free of control flow. To support control-flow operations from other frameworks, we can map their control-flow operators, such as tf.while_loop, to Relay expressions. These operations can be represented using a combination of Relay’s branches and recursive function calls.
This RFC proposes converting TensorFlow (1.x) control-flow constructs to Relay control flow. The challenge is that TensorFlow implements its control-flow operators (i.e. cond and while_loop) with low-level data-flow primitives such as Merge, Exit, Switch, NextIteration, and Enter. Recovering the original control-flow operators from these primitives is not trivial.
Proposal
We propose a decompilation strategy that reconstructs the original high-level control-flow statements via pattern matching on the TensorFlow graph. The reconstruction translates the low-level dataflow into the corresponding Relay constructs, i.e. loops become recursive functions and conditions become if expressions. Nested loops, nested conditions, mixtures of the two, and multi-level nesting complicate the problem. Fortunately, we haven’t seen any nested cases in real applications so far, and we believe nested translation is straightforward to support in many cases.
Our proposal is based on the following observations:
A TF cond is composed only of Merge and Switch primitives, and there is only one Merge.
import tensorflow as tf
x = tf.constant(10)
y = tf.constant(15)
z = tf.constant(20)
r = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
# The returned value r is tf.add(x, z) if the predicate "x < y" is true, else tf.square(y)
A while_loop is constructed from the five aforementioned primitives plus LoopCond. There can be multiple occurrences of Enter, Merge, Exit, Switch, and NextIteration, depending on the number of loop variables, but there is only one occurrence of LoopCond.
i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: tf.add(i, 1)
r = tf.while_loop(c, b, [i])
# Repeat the loop body b while the loop condition c is true
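The two observations suggest that a frame can be classified from its op types alone. The following is a minimal sketch in plain Python, not TVM or TensorFlow code; the function name `classify_frame` and the flat list of op-type strings it takes are assumptions made purely for illustration.

```python
from collections import Counter

# Control-flow primitives named in this RFC.
CONTROL_PRIMITIVES = {"Enter", "Merge", "Exit", "Switch", "NextIteration", "LoopCond"}


def classify_frame(op_types):
    """Guess which control-flow construct a list of op type names came from.

    `op_types` is a hypothetical flat list such as ["Switch", "Add", "Merge"].
    """
    counts = Counter(op for op in op_types if op in CONTROL_PRIMITIVES)
    if counts["LoopCond"] == 1:
        # Exactly one LoopCond marks a while_loop, however many loop variables exist.
        return "while_loop"
    only_merge_and_switch = not (set(counts) - {"Merge", "Switch"})
    if counts["Merge"] == 1 and counts["Switch"] >= 1 and only_merge_and_switch:
        # Only Merge and Switch, with a single Merge, matches the cond observation.
        return "cond"
    return "unknown"
```

For example, `classify_frame(["Switch", "Switch", "Add", "Square", "Merge"])` is classified as a cond, while any list containing a single `LoopCond` is classified as a while_loop.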
A cond or while_loop is usually created within an execution frame (similar to a scope). Therefore, by identifying the scope, we can correctly create the cond and while_loop constructs.
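One way to identify a scope is to group nodes by a shared name prefix. The sketch below assumes, for illustration only, that each node name carries its scope as the part before the first "/" (e.g. "while/Enter"); the helper `group_by_frame` is hypothetical and handles only the outermost scope, not nested frames.

```python
from collections import defaultdict


def group_by_frame(node_names):
    """Group node names by their outermost name scope.

    Assumes the scope is the part of the name before the first "/".
    Names without a "/" are placed in the top-level frame "".
    """
    frames = defaultdict(list)
    for name in node_names:
        scope, sep, _ = name.partition("/")
        frames[scope if sep else ""].append(name)
    return dict(frames)


names = ["Const", "cond/Switch", "cond/Merge", "while/Enter", "while/LoopCond"]
frames = group_by_frame(names)
```

Each resulting group can then be inspected on its own to decide whether it holds a cond or a while_loop.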
We should instantiate a branch statement when we meet a Merge primitive in a cond execution frame, and save the inputs of the Merge to capture the branches. The input of Switch contains the variables used, and the inputs of Merge indicate the true and false bodies. Using this information, we can build a Relay if expression.
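The assembly step can be illustrated with a toy stand-in for the if expression. This is a plain-Python sketch, not the Relay API: `IfExpr` and `build_if` are hypothetical names, the branches are modelled as callables over an environment dict, and the Merge inputs are assumed to arrive already ordered as (true body, false body).

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass
class IfExpr:
    """Hypothetical stand-in for a Relay if expression."""

    pred: Callable[[Dict[str, Any]], bool]
    true_branch: Callable[[Dict[str, Any]], Any]
    false_branch: Callable[[Dict[str, Any]], Any]

    def evaluate(self, env):
        # Only the selected branch is evaluated, as in a real conditional.
        branch = self.true_branch if self.pred(env) else self.false_branch
        return branch(env)


def build_if(switch_pred, merge_inputs):
    """Build an IfExpr from the Switch predicate and the two Merge inputs.

    Assumes merge_inputs is a (true_body, false_body) pair.
    """
    true_body, false_body = merge_inputs
    return IfExpr(switch_pred, true_body, false_body)


# Mirrors the tf.cond example: x + z if x < y else y squared.
expr = build_if(
    lambda env: env["x"] < env["y"],
    (lambda env: env["x"] + env["z"], lambda env: env["y"] ** 2),
)
result = expr.evaluate({"x": 10, "y": 15, "z": 20})
```

With x = 10, y = 15, and z = 20 the predicate holds, so the true body runs and `result` is 30.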
We can generate a loop when we traverse into a while_loop execution frame and see the first occurrence of LoopCond. The input of NextIteration gives the loop body, the inputs of Switch give the loop variables, Exit marks the completion of the execution frame and is where we extract its outputs, and LoopCond gives the loop condition. From the collected information, we can construct a specialized while loop using recursion.
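The idea of expressing a while loop through recursion can be sketched in plain Python. This is only an illustration of the shape of the translation, not the Relay implementation; `make_recursive_loop` is a hypothetical helper, and the body is assumed to return a tuple of updated loop variables.

```python
def make_recursive_loop(cond, body):
    """Turn a (cond, body) pair into a recursive function over the loop variables.

    `cond` maps the loop variables to a bool; `body` maps them to a tuple of
    updated loop variables.
    """

    def loop(*loop_vars):
        if cond(*loop_vars):
            # Recurse with the updated variables (the role of NextIteration).
            return loop(*body(*loop_vars))
        # Condition is false: return the current variables (the role of Exit).
        return loop_vars

    return loop


# Mirrors the tf.while_loop example: increment i while i < 10.
loop = make_recursive_loop(lambda i: i < 10, lambda i: (i + 1,))
final_vars = loop(0)
```

Here `final_vars` is `(10,)`. Plain Python recursion is bounded by the interpreter's recursion limit, so this sketch only suits small iteration counts.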
zhiics changed the title from "Decompile TensorFlow Control Flow Primitives to Relay" to "[RFC] Decompile TensorFlow Control Flow Primitives to Relay" on Mar 14, 2019
@yidawang thanks for pointing it out. The RFC here is mainly for TF 1.x, but it works for TF 2.0 graphs too. The changes to the primitives in 2.0 are for TF execution; fortunately, the generated Python function won't have control-flow primitives in it anymore, so it is relatively easy to handle with Relay. We will add support once 2.0 is officially released.