[RFC] Decompile TensorFlow Control Flow Primitives to Relay #2812

Closed
zhiics opened this issue Mar 14, 2019 · 4 comments


@zhiics
Member

zhiics commented Mar 14, 2019

Motivation

TVM currently focuses on models that are free of control flow. To support control-flow operations from other frameworks, we can map their control-flow operators, such as tf.while_loop, to Relay expressions. These operations can be represented using a combination of Relay's branches and recursive function calls.

This RFC proposes converting TensorFlow (1.x) control-flow constructs to Relay control flow. The challenge is that TensorFlow implements its control-flow operators (i.e., cond and while_loop) with low-level dataflow primitives such as Merge, Exit, Switch, NextIteration, and Enter. Recovering the original control-flow operators from these primitives is not trivial.

Proposal

We propose a decompilation strategy that reconstructs the original high-level control-flow statements by pattern matching on the TensorFlow graph. The reconstruction translates low-level dataflow into the corresponding Relay constructs, i.e., loops into recursive functions and conditions into if expressions. Nested loops, nested conditions, combinations of the two, and multi-level nesting complicate the problem; fortunately, we have not seen any nested cases in real applications. Furthermore, we believe nested translation is straightforward to support in many cases.
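To make the two target shapes concrete, here is a minimal plain-Python sketch using the RFC's own examples. Python only stands in for Relay: `decompiled_cond` and `make_loop` are hypothetical names, not TVM APIs, and the loop returns its variables as a tuple rather than as Relay values.

```python
# Hypothetical sketch: plain-Python stand-ins for the Relay forms the
# decompiler targets. A cond becomes one conditional expression, and a
# while_loop becomes a function that calls itself.

def decompiled_cond(x, y, z):
    # Relay analogue: if (x < y) { x + z } else { y * y }
    return x + z if x < y else y * y


def make_loop(cond, body):
    """Turn a (cond, body) pair into a self-recursive loop function."""
    def loop(*loop_vars):
        if cond(*loop_vars):
            # The body's outputs become the next iteration's loop variables.
            return loop(*body(*loop_vars))
        # Condition is false: the current loop variables are the result.
        return loop_vars
    return loop


# tf.while_loop(lambda i: i < 10, lambda i: i + 1, [i]) rewritten as recursion.
loop = make_loop(lambda i: i < 10, lambda i: (i + 1,))

print(decompiled_cond(10, 15, 20))  # 30
print(loop(0))                      # (10,)
```

Recursion is used instead of a native loop because the RFC maps while_loop onto Relay's recursive function calls; each recursive call corresponds to one pass through NextIteration.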

Our proposal is based on the following observations:

  1. A TF cond is composed only of Switch and Merge primitives, and there is only one Merge.

```python
import tensorflow as tf  # TF 1.x graph mode, where tf.cond is lowered to Switch/Merge

x = tf.constant(10)
y = tf.constant(15)
z = tf.constant(20)
r = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))
# The returned value r is tf.add(x, z) if the predicate "x < y" is true, else tf.square(y)
```

[Screenshot (2019-02-25): graph of the tf.cond example above]

  2. A while_loop is constructed from the five primitives mentioned above plus LoopCond. There may be multiple occurrences of Enter, Merge, Exit, Switch, and NextIteration, depending on the number of loop variables, but there is only one LoopCond.

```python
import tensorflow as tf

i = tf.constant(0)
c = lambda i: tf.less(i, 10)
b = lambda i: tf.add(i, 1)
r = tf.while_loop(c, b, [i])
# Repeat the loop body b while the loop condition c is true
```

[Screenshot (2019-03-13): graph of the tf.while_loop example above]
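The two observations can be checked with a toy model of the dataflow semantics. This is a simplified sketch, not the TensorFlow runtime: it assumes a `DEAD` marker for values on an untaken path, and `switch`, `merge`, `lift`, `cond_graph`, and `while_graph` are hypothetical names. Only the (output_false, output_true) order of Switch's outputs follows TensorFlow's Switch op.

```python
# Toy model of the dataflow semantics (not the TensorFlow runtime).
# DEAD marks a value on an untaken path: ops that receive a DEAD input emit
# DEAD, and Merge forwards whichever of its inputs is live.
DEAD = object()


def switch(data, pred):
    """Route data to (output_false, output_true); the other output is DEAD."""
    return (DEAD, data) if pred else (data, DEAD)


def merge(*inputs):
    """Forward the first live input, if any."""
    for value in inputs:
        if value is not DEAD:
            return value
    return DEAD


def lift(fn):
    """Wrap an ordinary op so that any DEAD input yields a DEAD output."""
    def op(*args):
        return DEAD if any(a is DEAD for a in args) else fn(*args)
    return op


add = lift(lambda a, b: a + b)
square = lift(lambda a: a * a)


def cond_graph(x, y, z):
    """tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))."""
    pred = x < y
    _, x_t = switch(x, pred)    # tensors used by the true branch
    _, z_t = switch(z, pred)
    y_f, _ = switch(y, pred)    # tensor used by the false branch
    # Only one branch receives live inputs; the single Merge picks it.
    return merge(square(y_f), add(x_t, z_t))


def while_graph(i):
    """tf.while_loop(lambda i: i < 10, lambda i: i + 1, [i])."""
    value = merge(i, DEAD)  # Merge(Enter(i), NextIteration): first pass uses Enter
    while True:             # stands in for the executor re-firing the frame
        pred = value < 10   # LoopCond
        to_exit, to_body = switch(value, pred)
        if to_exit is not DEAD:
            return to_exit  # Exit: the value leaves the frame
        value = merge(DEAD, add(to_body, 1))  # body output via NextIteration


print(cond_graph(10, 15, 20))  # 30
print(while_graph(0))          # 10
```

The sketch shows why a single Merge suffices for a cond (exactly one branch stays live) and why a while_loop needs LoopCond, Switch, NextIteration, and Exit to form its cycle.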

Each cond or while_loop is usually created within its own execution frame (similar to a scope). Therefore, by identifying these frames, we can correctly reconstruct the cond and while_loop constructs.
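A hypothetical sketch of frame identification follows. It assumes a frame can be approximated by the node-name scope, as in default-named TF 1.x graphs (e.g. `cond/Merge`, `while/LoopCond`); a real converter might instead read the `frame_name` attribute of Enter nodes. The node list is hand-written rather than dumped from TensorFlow, and nesting is not handled.

```python
from collections import defaultdict

# Hypothetical, hand-written (name, op) pairs resembling the two examples above.
NODES = [
    ("x", "Const"), ("y", "Const"), ("z", "Const"), ("Less", "Less"),
    ("cond/Switch", "Switch"), ("cond/Add", "Add"),
    ("cond/Square", "Square"), ("cond/Merge", "Merge"),
    ("while/Enter", "Enter"), ("while/Merge", "Merge"),
    ("while/Less", "Less"), ("while/LoopCond", "LoopCond"),
    ("while/Switch", "Switch"), ("while/Add", "Add"),
    ("while/NextIteration", "NextIteration"), ("while/Exit", "Exit"),
]


def group_by_scope(nodes):
    """Bucket op types by name-scope prefix (a stand-in for the frame)."""
    frames = defaultdict(list)
    for name, op in nodes:
        scope = name.rsplit("/", 1)[0] if "/" in name else ""
        frames[scope].append(op)
    return frames


def classify(ops):
    """Classify a frame by the control-flow primitives it contains."""
    if "LoopCond" in ops:
        return "while_loop"  # a while frame has exactly one LoopCond
    if "Merge" in ops and "Switch" in ops:
        return "cond"        # Switch nodes plus a single Merge
    return "plain"


for scope, ops in group_by_scope(NODES).items():
    print(repr(scope), classify(ops))
# Prints, one per line: '' plain, 'cond' cond, 'while' while_loop
```

Checking for LoopCond first matters because a while frame also contains Switch and Merge, so the cond test alone would misclassify it.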

  1. When we encounter a Merge primitive in a cond execution frame, we instantiate a branch and save the Merge's inputs to capture the two branches. The inputs of Switch carry the variables used by the branches, and the inputs of Merge give the true and false bodies. Using this information, we can build a Relay if expression.

  2. We generate a loop when we traverse into a while_loop execution frame and encounter the first occurrence of LoopCond. The input of NextIteration gives the loop body, the inputs of Switch give the loop variables, LoopCond gives the condition, and Exit marks the completion of the execution frame, from which we extract the frame's output. From the collected information, we construct a specialized while loop using recursion.
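Both steps can be sketched over a toy graph format. Everything here is hypothetical: `Node`, `If`, `merge_to_if`, and `collect_loop` are illustrative names, not TVM or TensorFlow APIs, and the hand-written graphs omit nodes a real TF graph contains (such as Identity). The one convention taken from TensorFlow is that Switch output port 0 is output_false and port 1 is output_true.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    op: str
    inputs: list = field(default_factory=list)  # (producer_name, output_port) pairs


@dataclass
class If:
    """Stand-in for a Relay if expression; fields hold node names."""
    cond: str
    true_branch: str
    false_branch: str


def switch_port(graph, name):
    """Return the Switch output port (0 = false, 1 = true) feeding `name`.

    Only used on cond frames, which are acyclic, so the walk terminates.
    """
    for producer, port in graph[name].inputs:
        if graph[producer].op == "Switch":
            return port
        found = switch_port(graph, producer)
        if found is not None:
            return found
    return None


def merge_to_if(graph, merge_name):
    """Step 1: rebuild an If from a cond frame's single Merge."""
    branches = {switch_port(graph, p): p for p, _ in graph[merge_name].inputs}
    # Every Switch in the frame takes the predicate as its second input.
    switch = next(node for node in graph.values() if node.op == "Switch")
    return If(cond=switch.inputs[1][0],
              true_branch=branches[1], false_branch=branches[0])


def collect_loop(graph):
    """Step 2: gather the role each primitive plays in a while frame."""
    def named(op):
        return [n for n, node in graph.items() if node.op == op]
    return {
        "loop_vars": [graph[s].inputs[0][0] for s in named("Switch")],
        "cond": graph[named("LoopCond")[0]].inputs[0][0],
        "body": [graph[n].inputs[0][0] for n in named("NextIteration")],
        "outputs": named("Exit"),
    }


COND = {
    "x": Node("Const"), "y": Node("Const"), "z": Node("Const"),
    "Less": Node("Less", [("x", 0), ("y", 0)]),
    "cond/Switch": Node("Switch", [("x", 0), ("Less", 0)]),
    "cond/Switch_1": Node("Switch", [("z", 0), ("Less", 0)]),
    "cond/Switch_2": Node("Switch", [("y", 0), ("Less", 0)]),
    "cond/Add": Node("Add", [("cond/Switch", 1), ("cond/Switch_1", 1)]),
    "cond/Square": Node("Square", [("cond/Switch_2", 0)]),
    "cond/Merge": Node("Merge", [("cond/Square", 0), ("cond/Add", 0)]),
}

WHILE = {
    "i": Node("Const"), "limit": Node("Const"), "one": Node("Const"),
    "while/Enter": Node("Enter", [("i", 0)]),
    "while/Merge": Node("Merge", [("while/Enter", 0), ("while/NextIteration", 0)]),
    "while/Less": Node("Less", [("while/Merge", 0), ("limit", 0)]),
    "while/LoopCond": Node("LoopCond", [("while/Less", 0)]),
    "while/Switch": Node("Switch", [("while/Merge", 0), ("while/LoopCond", 0)]),
    "while/Add": Node("Add", [("while/Switch", 1), ("one", 0)]),
    "while/NextIteration": Node("NextIteration", [("while/Add", 0)]),
    "while/Exit": Node("Exit", [("while/Switch", 0)]),
}

print(merge_to_if(COND, "cond/Merge"))
print(collect_loop(WHILE))
```

Tracing each Merge input back to a Switch port, instead of trusting the order of Merge's inputs, keeps the true/false assignment independent of how the graph happens to list them.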

@zhiics
Member Author

zhiics commented Mar 14, 2019

@zhiics zhiics changed the title Decompile TensorFlow Control Flow Primitives to Relay [RFC] Decompile TensorFlow Control Flow Primitives to Relay Mar 14, 2019
@yidawang
Contributor

How should we accommodate the control-flow primitive changes introduced in TF 2.0? Should we address them later in a separate PR?

@yongwww
Member

yongwww commented Mar 19, 2019

@yidawang thanks for pointing it out. This RFC mainly targets TF 1.x, but it also works for TF 2.0 graphs. The primitive changes in 2.0 concern TF execution; fortunately, the generated Python functions no longer contain control-flow primitives, so they are relatively easy to handle in Relay. We will add support once 2.0 is officially released.

@zhiics
Member Author

zhiics commented Mar 24, 2019

closed by #2830

@zhiics zhiics closed this as completed Mar 24, 2019

4 participants