
[AUTOTVM] Core part of auto-tuning module #1312

Merged (46 commits) on Jul 12, 2018

Conversation

@merrymercy (Member) commented Jun 21, 2018

This PR is step 1 in #1311.

It includes (a usage sketch follows the list):

  • Tuning space definition API
  • Basic tuners: RandomTuner, GridSearchTuner, XGBTuner
  • Measurement executor (local mode and distributed mode through RPC)
  • Tuning results log file
  • Tutorial on how to write a tunable schedule

Some code is contributed by @eqy and @tqchen.
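
For orientation, a minimal sketch of how these pieces are meant to fit together, assuming the public API shape of the tvm.autotvm module (names such as autotvm.task.create, measure_option, LocalBuilder/LocalRunner, XGBTuner, and callback.log_to_file follow the later released API and may differ slightly from this initial PR):

    import tvm
    from tvm import autotvm

    # create a tuning task from a registered template
    # (here a hypothetical "matmul" template taking shape/dtype args)
    task = autotvm.task.create("matmul",
                               args=(512, 512, 512, "float32"),
                               target="llvm")

    # measurement executor: build and run locally (RPC mode also exists)
    measure_option = autotvm.measure_option(
        builder=autotvm.LocalBuilder(),
        runner=autotvm.LocalRunner(number=5))

    # pick a tuner and run it, logging results to a file
    tuner = autotvm.tuner.XGBTuner(task)
    tuner.tune(n_trial=100,
               measure_option=measure_option,
               callbacks=[autotvm.callback.log_to_file("matmul.log")])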

@tqchen (Member) commented Jun 23, 2018

@eqy @Laurawly please review

@@ -0,0 +1,83 @@
Auto-tuning API

Member:

just use tvm.autotvm

~~~~~~~~~~~~~~~~~~~~
.. automodule:: tvm.autotvm.measure.measure

.. class:: tvm.autotvm.MeasureInput(target, task, config)

Member:

Let us keep the named tuple documentation in place in the source. Do:

class X(namedtuple("X", fields)):
    """docstring
    """
    __slots__ = ()
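
Applied to the MeasureInput tuple documented above, that pattern might look like the sketch below (the field summary is illustrative, not copied from the PR):

    from collections import namedtuple

    class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
        """Input of a measurement: the compilation target, the tuning
        task, and the specific configuration to measure."""
        __slots__ = ()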

:members:

tvm.autotvm.task
~~~~~~~~~~~~~~~~~~~~~

Member:

rst requires the underline to have the same length as the title

docs/conf.py Outdated
@@ -189,6 +189,7 @@ def run_doxygen(folder):
subsection_order = ExplicitOrder(
['../tutorials/language',
'../tutorials/optimize',
'../tutorials/autotuning',

Member:

autotuning->autotvm

"""
partial_results = [None] * len(measure_inputs)
unsaved = list()
for i in range(len(measure_inputs)):

Member:

for i, inp in enumerate(measure_inputs)
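
Applied to the loop above, the suggested rewrite reads (a sketch; the rest of the function body is elided):

    partial_results = [None] * len(measure_inputs)
    unsaved = []
    for i, inp in enumerate(measure_inputs):
        # use inp directly instead of indexing measure_inputs[i]
        ...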

return _do_reg


def create(func_name, args):

Member:

create and create_task can be super confusing. I would prefer to rename create_task -> create and remove create.

Allow passing None as the target, and call init_space later, which returns a new task.
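
A sketch of the proposed API shape (the names come from the comment; the exact signatures are assumptions, not code from the PR):

    import tvm
    from tvm import autotvm

    N = L = M = 512

    # create the task without binding a target yet
    task = autotvm.task.create("matmul", args=(N, L, M, "float32"),
                               target=None)

    # later, bind a concrete target; per the comment, init_space
    # returns a new task instead of mutating in place
    task = task.init_space(target=tvm.target.cuda())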

Its arguments should be hashable values.
Its return value should be a Tuple(Schedule, Array of Tensor)

Returns

Member:

need an example block here.

simple_template deserves its own file.
Need to explain why it is not in template
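
For reference, a minimal example of the contract described in the docstring above, i.e. hashable arguments in, a (Schedule, array of Tensors) pair out (a sketch using a trivial operator):

    import tvm
    from tvm import autotvm

    @autotvm.template
    def vector_add(n, dtype):
        # arguments (n, dtype) are hashable values
        A = tvm.placeholder((n,), name="A", dtype=dtype)
        B = tvm.compute((n,), lambda i: A[i] + 1, name="B")
        s = tvm.create_schedule(B.op)
        # return value: a tuple of (Schedule, array of Tensors)
        return s, [A, B]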

@@ -0,0 +1,830 @@
# pylint: disable=too-few-public-methods,invalid-name,unused-argument,arguments-differ

Member:

be careful about arguments-differ

pass

"""
We can regard our schedule code as a transformation graph of axes.

Member:

put this as the docstring of the transform space class

def has_next(self):
    return len(self.visited) < len(self.space)

def save_state(self, filename):

Member:

save_state/load_state are not pythonic.

Instead, directly make the objects picklable: https://docs.python.org/2/library/pickle.html#object.__getstate__
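
A sketch of that pattern for a tuner holding an unpicklable member (the attribute names here are hypothetical):

    import pickle

    class Tuner(object):
        def __init__(self, space):
            self.space = space
            self.visited = set()
            self._log_file = open("tuning.log", "a")  # unpicklable handle

        def __getstate__(self):
            # drop the unpicklable file handle, keep the tuning state
            state = self.__dict__.copy()
            del state["_log_file"]
            return state

        def __setstate__(self, state):
            self.__dict__.update(state)
            self._log_file = open("tuning.log", "a")  # reopen on load

    # saving/restoring then becomes plain pickle:
    # pickle.dump(tuner, open("tuner.pkl", "wb"))
    # tuner = pickle.load(open("tuner.pkl", "rb"))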

@tqchen (Member) commented Jun 23, 2018

please always make Docker changes first, as a separate PR; this helps avoid the out-of-space errors we might otherwise encounter

# In autotvm, we can define a tunable parameter, or a "knob" for such kind of value.

# Matmul V1: List candidate values
@autotvm.simple_template # 1. use a decorator

Member:

need a simple template testcase
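
For context, a sketch of what the "list candidate values" knob might look like inside such a template, using the get_config/define_knob API (the decorator and knob names follow the tutorial excerpt; treat the exact calls as assumptions):

    import tvm
    from tvm import autotvm

    @autotvm.simple_template       # 1. use a decorator
    def matmul_v1(N, L, M, dtype):
        A = tvm.placeholder((N, L), name="A", dtype=dtype)
        B = tvm.placeholder((L, M), name="B", dtype=dtype)
        k = tvm.reduce_axis((0, L), name="k")
        C = tvm.compute((N, M),
                        lambda i, j: tvm.sum(A[i, k] * B[k, j], axis=k),
                        name="C")
        s = tvm.create_schedule(C.op)
        y, x = s[C].op.axis

        # 2. define knobs by listing candidate values
        cfg = autotvm.get_config()
        cfg.define_knob("tile_y", [1, 2, 4, 8, 16])
        cfg.define_knob("tile_x", [1, 2, 4, 8, 16])

        # 3. apply the chosen values in the schedule
        yo, yi = s[C].split(y, cfg["tile_y"].val)
        xo, xi = s[C].split(x, cfg["tile_x"].val)
        s[C].reorder(yo, xo, yi, xi)
        return s, [A, B, C]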

# ---------------------------------
# There are plenty of useful schedule primitives in tvm. You can also find
# some tutorials that describe them in more details, such as
# (1). `Optimizing Conv2d on CUDA GPU <https://docs.tvm.ai/tutorials/optimize/opt_conv_cuda.html#sphx-glr-tutorials-optimize-opt-conv-cuda-py>`_

Member:

do internal doc references using ref

@tqchen added labels "status: review in progress" and "status: need update" on Jun 24, 2018
@tqchen requested a review from Laurawly on June 24, 2018
@merrymercy force-pushed the autotvm branch 6 times, most recently from e960356 to 260d67c on June 26, 2018

return names


_get_buffer_curve_sample_flatten = get_global_func(

Member:

avoid getting the global function eagerly at the root namespace; it won't work for the runtime-only env
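
A sketch of the deferred lookup the comment asks for; the registered function name is truncated in the excerpt above, so the string below is a placeholder, not the real name:

    from tvm import get_global_func

    _curve_sample_flatten_func = None  # cached handle, looked up lazily

    def _get_curve_sample_flatten():
        # fetch the packed function on first use instead of at import
        # time, so importing this module works in a runtime-only build
        global _curve_sample_flatten_func
        if _curve_sample_flatten_func is None:
            _curve_sample_flatten_func = get_global_func(
                "autotvm.feature.curve_sample_flatten")  # placeholder name
        return _curve_sample_flatten_func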

keep += 1
ret[i] = keep + 1
return ret / len(trial_ranks)

Member:

two blank lines between global functions


@tqchen (Member) commented Jul 5, 2018

@were please review

i.e. Do not eliminate loop with extent of 1, do not vectorize, unroll or inject virtual threads
"""
binds, _ = build_module.get_binds(args, binds)
# cfg = current_build_config()

Contributor:

Why do you still keep these comments? Useless code should be removed; if you really want it, you can find it in version control.

# back to use low level API.

@autotvm.template
def matmul(N, L, M, dtype):

Contributor:

Every time, you need to rewrite the whole algorithm part, which is redundant. Why not get it from topi.dense (or wrap it in your own function in this .py) and apply some default schedule first?

Contributor:

I think the intent of the simple template is to provide a standalone example, e.g., to tune an operator that has not been upstreamed into topi yet. The example uses a common operator, but that is just for illustration purposes.

Contributor:

My point is that rewriting the op weakens the point you want to make. We can just take advantage of the decoupling of algorithm description and scheduling.

Member Author:

The core part of autotvm is independent of topi. We don't want to introduce topi for this basic tutorial.


new_scores = model.predict(new_points)

ac_prob = np.exp((new_scores - scores) / t)

@were (Contributor) commented Jul 6, 2018:

I am a little bit curious how you normalize this score and temperature. I wrote an SA before; it seems that when the temperature is high, it just acts like random search. The acceptance ratio is too high, nearly 100% every time.

Member Author:

I normalize the scores to [0, 1].
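
A sketch of the Metropolis-style acceptance step under discussion, with scores normalized to [0, 1] (the function and variable names are illustrative, not from the PR):

    import numpy as np

    def accept(scores, new_scores, t):
        # accept with probability exp((new - old) / t), capped at 1;
        # with scores in [0, 1] the difference is at most 1, so t sets
        # a meaningful scale: at high t nearly everything is accepted
        # (random-search-like), at low t only improvements survive
        ac_prob = np.exp(np.minimum((new_scores - scores) / t, 0.0))
        return np.random.random(len(scores)) < ac_prob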

def update(self, inputs, results):
    for inp, res in zip(inputs, results):
        if res.error_no == 0:
            y = inp.task.flop / np.mean(res.costs)

Contributor:

Finally, relying on the parameter of the hardware mode is resolved.

Member Author:

I cannot understand this comment, can you elaborate?
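
For context on the snippet above: the regression target y is measured throughput, i.e. the task's total FLOPs divided by the mean measured runtime, so higher y is better. A sketch of an update along these lines (self.feature and self.model are hypothetical placeholders, not the PR's actual members):

    import numpy as np

    def update(self, inputs, results):
        xs, ys = [], []
        for inp, res in zip(inputs, results):
            if res.error_no == 0:             # keep only valid runs
                xs.append(self.feature(inp))  # hypothetical featurizer
                # throughput = total FLOPs / mean runtime in seconds
                ys.append(inp.task.flop / np.mean(res.costs))
        if xs:
            self.model.fit(np.array(xs), np.array(ys))  # hypothetical model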

@merrymercy (Member Author) commented:

ready for merge

@tqchen (Member) commented Jul 11, 2018

if there are no further comments, I am going to merge this in an hour

msg = msg[:msg.index("Stack trace returned")]
res_pack.append(MeasureResult((RuntimeError(msg),),
                              MeasureErrorNo.COMPILE_HOST,
                              tstamp - tic, tstamp))

Contributor:

When gpu_verify_pass() raises InstantiationError, we reach here, but it looks a bit confusing.
How about doing something like the following to create a more appropriate result?

            try:
                # extract error information
                _e, _msg = exc.message.split('\n')[-2].split(': ', 1)

                if _e == "InstantiationError":
                    res_pack.append(MeasureResult(_msg,
                                            MeasureErrorNo.INSTANTIATION_ERROR,
                                            tstamp - tic, tstamp))
                    continue
            except Error as _:
                pass

def verify_pass(stmt):
    valid = ir_pass.VerifyGPUCode(stmt, kwargs)
    if not valid:
        raise InstantiationError("invalid gpu kernel")

Contributor:

I think the message should be more informative, e.g. "Skip execution because of invalid gpu kernel config." or something like that?
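
Applying the suggestion, the pass might read (a sketch):

    def verify_pass(stmt):
        valid = ir_pass.VerifyGPUCode(stmt, kwargs)
        if not valid:
            # more informative, as suggested above
            raise InstantiationError(
                "Skip execution because of invalid gpu kernel config")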

@tqchen (Member) commented Jul 11, 2018

OK, @merrymercy please act on @kazum's comments

**Author**: `Lianmin Zheng <https://github.com/merrymercy>`_

This is an advanced tutorial for writing high performance tunable template for
CUDA GPU. By running auto-tuner on this template, we can outperform the

Contributor:

Replacing CUDA GPU with Nvidia GPUs is more accurate.

# the techniques used in these tutorials. Then we rely on the efficient auto-tuner
# to search through this space and pick some good configurations.
#
# If you are familiar with wring cude schedule, you can find the following

Contributor:

Spelling typos.

@merrymercy (Member Author) commented:

ready for merge

@tqchen (Member) commented Jul 12, 2018

@kazum (Contributor) left a comment:

Looks good to me, thanks!

@tqchen (Member) commented Jul 12, 2018

@merrymercy can you rebase against master to resolve the conflict?

@tqchen tqchen merged commit 5980b5d into apache:master Jul 12, 2018
@tqchen (Member) commented Jul 12, 2018

Thanks, @Laurawly @kazum for the reviews!

tqchen pushed a commit to tqchen/tvm that referenced this pull request Aug 4, 2018
sergei-mironov pushed a commit to sergei-mironov/tvm that referenced this pull request Aug 8, 2018