Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FRONTEND] A Python hybrid frontend #1251

Merged
merged 31 commits into from
Jun 22, 2018
Merged

[FRONTEND] A Python hybrid frontend #1251

merged 31 commits into from
Jun 22, 2018

Conversation

were
Copy link
Contributor

@were were commented Jun 8, 2018

Hybrid Frontend Developer Guide

This hybrid frontend is aimed at:

  1. Building IR in a more intuitive way
  2. Writing preliminary versions of some idioms that yet have not been supported by

Features

Software emulation

This feature supports both software emulation and compilation of the code.

To define a function, you need to use tvm.hybrid.script decorator to indicate this is a hybrid function:

@tvm.hybrid.script
def outer_product(a, b, c):
    for i in range(a.shape[0]):
        for j in range(b.shape[0]):
            c[i, j] = a[i] * b[j]
a = numpy.random.rand(100)
b = numpy.random.rand(99)
c = numpy.zeros((100, 99))
outer_product(a, b)

This decorator will help you to import key words required spontaneously when software emulation.
Every element in the argument list is either a python variable or numpy tensor.

Backend Compilation

The current parse interface looks like:

a = tvm.placeholder((100, ), name='a')
b = tvm.placeholder((99, ), name='b')
c = tvm.placeholder((100, 99), name='c')
tvm.hybrid.parse(outer_product, [a, b, c]) # return an ir root of this function

TODO: If we pass these tvm tensors to this function, it returns a op node:

a = tvm.placeholder((100, ), name='a')
b = tvm.placeholder((99, ), name='b')
c = tvm.placeholder((100, 99), name='c')
op = outer_product(a, b, c) # return the corresponding op node

Scheduling

Under construction, not truly supported yet.

Follow up the example above, you can use some tvm like interfaces to manipulate the structure of IR:

sch = tvm.create_schedule(op)
jo, ji = sch.split(j, 4)
sch.vectorize(ji)

split, reorder, and loop_annotation will be supported!

Attributes

So far, ONLY tensors' shape attribute is supported!

Loops

In HalideIR, loops have in total 4 types: serail, unrolled, parallel, and vectorized.

Here we use range, serial, unroll, parallel, and vectorize, these 5 keywords to annotate the types of for loops.

NOTE: In HalideIR those are enums, they are in passive form. Here we use active form to annotate loops, because they are ready to run.

NOTE: Unlike what that is in HalideIR, in loop_type(a, b), a is the starting point and b is the trip count of iterations. Here loop_type(a, b) indicates [a, b).

Variables

Because there is no variables in HalideIR, all the mutatable variables will be lowered to an array with size 1.
It takes the first store of a variable as its declaration.
NOTE: Unlike conventional Python, the declared array can only be used in the scope level it is declared.

for i in range(5):
    sum = 0
    for j in range(5):
    	sum += a[i, j] #do something with sum
    b[i] = sum #you can still use sum in this level
#you can NEVER use some here, even though it is allowed in conventional Python
a[0] = sum

Conditional Statement and Expression

if condition:
    # do something
a = b if condition else c

However, NO True and False keyword supported yet.

Math intrinsics

So far, these math intrinsics, log, exp, sigmoid, tanh, power, and popcount, are supported. No import is required, just use it!

Array allocation

TODO: Use a function call allocation(shape, type, share/local) to declare an array buffer. The basic usage is roughly the same as variables

Thread bind

You can also do loop-thread bind by writing code like this:

for tx in bind("threadIdx.x", 100):
    a[tx] = b[tx]

Appendix

Keywords

  • Statement keywords: for, in, if, else
  • For keywords: serial, range, unroll, parallel, vectorize, bind
  • Math keywords: log, exp, sigmoid, tanh, power, popcount

@were were mentioned this pull request Jun 8, 2018
17 tasks
@were were changed the title cleanup a branch I messed up [FRONTEND] A Python hybrid frontend Jun 8, 2018
@were
Copy link
Contributor Author

were commented Jun 8, 2018

@tqchen @merrymercy @kevinthesun @xqdan @ZihengJiang
I think these are reviews? I hope I did not miss any one.

a = numpy.random.rand(100)
b = numpy.random.rand(99)
c = numpy.zeros((100, 99))
outer_product(a, b)
Copy link
Member

@merrymercy merrymercy Jun 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be outer_product(a, b, c)

Copy link
Member

@merrymercy merrymercy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other parts look good to me

@tqchen
Copy link
Member

tqchen commented Jun 12, 2018

@were please comment when the PR is ready to be reviewed. We shouldn't put TODO in the docs, but you can put a note saying that the feature is yet to be supported.

We will need a tutorial in tutorials/language/hybrid_script.py

Copy link

@liao02x liao02x left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a few typos


### Loops

In HalideIR, loops have in total 4 types: `serail`, `unrolled`, `parallel`, and `vectorized`.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

serail should be serial

for j in range(5):
sum += a[i, j] #do something with sum
b[i] = sum #you can still use sum in this level
#you can NEVER use some here, even though it is allowed in conventional Python
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is that some, or should it be sum?

@tqchen
Copy link
Member

tqchen commented Jun 13, 2018

general progress tracked in #1213

stmt = schedule.ScheduleOps(sch, bounds)
stmt = ir_pass.InjectPrefetch(stmt)
else:
#So far there is no op for hybrid script, so a plain ir body is given
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check if sch is stmt, if not raise error

from ..tensor import Tensor

# Useful constants
NOP = _make.Evaluate(_api.const(0, dtype='int32'))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try not to construct IR object in global, since we sometimes will have runtime only object and we need to support non-dependency when we only import but did not use hybrid.

Construct a Constant class single and overload property to do it lazily when any of them is requested

NUMPY_ARG_TYPES = (float, int, long, numpy.float32, numpy.int32, numpy.ndarray)

def _is_tvm_arg_types(args):
"""Determine a list of element is either a list of tvm arguments of a list of numpy arguments.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when possible, document the arg types


TVM_ARG_TYPES = (_expr.Var, Tensor)
if sys.version_info[0] == 3:
NUMPY_ARG_TYPES = (float, int, numpy.float32, numpy.int32, numpy.ndarray)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should go to _ffi.base

@@ -0,0 +1,69 @@
"""Intrinsics of Python-Halide DSL for Python runtime"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Intrinsics of TVM Hybrid script

Returns
-------
func_name: str
The name of the function to be lowered; if not provided,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indent the description


def parse_python(src, args):
""" The helper function of calling the AST visitor"""
root = ast.parse(src)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems most of this can be merged into parser construction


#pylint: disable=missing-docstring, invalid-name
#pylint: disable=consider-merging-isinstance, no-else-return
#pylint: disable=inconsistent-return-statements
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check inconsistent-return-statements

from ._intrin import HYBRID_GLOBALS

class PyVariableUsage(ast.NodeVisitor):
"""The vistor class to determine the declaration, r/w status, and last use of each variable"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

document arguments

@@ -0,0 +1,68 @@
"""Determines the declaration, r/w status, and last use of each variable"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import absolute import

@tqchen tqchen added the status: need update need update based on feedbacks label Jun 14, 2018
@@ -0,0 +1,66 @@
Hybrid Frontend Developer Guide
-------------------------------
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let us start from first level title =======

@@ -0,0 +1,169 @@
Hybrid Frontend Language Reference
----------------------------------
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

first level title is =====, second level title is -------, third level title is ~~~~~

@tqchen
Copy link
Member

tqchen commented Jun 21, 2018

@Laurawly please also take a look

If you are a developer:

1. who is trying writing some preliminary patterns that have not been supported by TVM yet,
maybe ``lang_ref/hybrid_script.rst`` is a better place for you.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use ref feature of rst


In HalideIR, loops have in total 4 types: ``serial``, ``unrolled``, ``parallel``, and ``vectorized``.

Here we use ``range``, ``serial``, ``unroll``, ``parallel``, and ``vectorize``,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove serial


**NOTE**: In HalideIR those are enums, they are in passive form.
Here we use active form to annotate loops, because they are ready to run.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since this is useer facing doc, no need to refer back to HalideIR, instead, directly say the loop_type corresponds to range. Use note block construction of rst .

All the mutatable variables will be lowered to an array with size 1.
It regards the first store of a variable as its declaration.

**NOTE**: Unlike conventional Python, in hybrid script, the declared variable
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use note block rst

If you are a developer:

1. who is trying writing some preliminary patterns that have not been supported by TVM yet,
maybe `language ref <../langref/hybrid_script.rst>`_ is a better place for you.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use ref. see http://www.sphinx-doc.org/en/stable/markup/inline.html cross reference arbitrary locations

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a ref to a file instead of an anchor. Can I still use label to do it? I suppose label is only for intra file instead of inter.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

label works for inter file

@@ -0,0 +1,170 @@
Hybrid Frontend Language Reference
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hybrid Script Language Reference

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since this is the first level, change to ======



def allocate(shape, dtype=None):
"""Allocate a buffer with given shape"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.



def popcount(x):
"""Software emulated popcount function which counts 1's in a number's binary representation."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return cnt


def sigmoid(x):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@@ -0,0 +1,73 @@
"""Internal utilities for parsing Python subset to HalideIR"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider just make it util.py



# Useful constants. In avoid of runtime dependences, we use function calls to return them.
def make_nop():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

document functions

-------
(halide_ir, parser) : (Stmt, PyAST2HalideIR)
The result Halide IR and the parser class instance.
TODO: Later we deprecate this return value, use a dedicated OP node type instead
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parse should always return Stmt only

@@ -0,0 +1,342 @@
"""Compiling a TVM Hybrid Script Python to HalideIR"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hybrid Script Parser

return res


def parse_python(src, args):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this function necessary? Can we just do and inline it into parse?

HybridParser parser(src, args)
return parser.parse(root)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just want to make the code style uniform, each IR pass has its own helper function. Just like that is in Halide style.

@@ -0,0 +1,10 @@
"""Hybrid Programming APIs of TVM Python Package.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Laurawly
Copy link
Contributor

@were Could there be more gpu testings regarding vthread injection and usage of shared / warp memory?

@tqchen tqchen merged commit 290226e into apache:master Jun 22, 2018
@tqchen
Copy link
Member

tqchen commented Jun 22, 2018

Thanks! This is merged. Please followup with PRs on gpu support

@were were deleted the pyfrontend branch June 26, 2018 01:06
tqchen pushed a commit to tqchen/tvm that referenced this pull request Jul 6, 2018
mnuyens pushed a commit to mnuyens/tvm that referenced this pull request Jul 10, 2018
sergei-mironov pushed a commit to sergei-mironov/tvm that referenced this pull request Aug 8, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants