
[Relay, TOPI] Make Softmax op fusible with elemwise ops #8909

Merged: 15 commits, Sep 6, 2021
6 changes: 3 additions & 3 deletions python/tvm/relay/op/nn/_nn.py
@@ -39,17 +39,17 @@

# softmax
reg.register_strategy("nn.softmax", strategy.softmax_strategy)
reg.register_pattern("nn.softmax", OpPattern.OPAQUE)
reg.register_pattern("nn.softmax", OpPattern.OUT_ELEMWISE_FUSABLE)


# fast softmax
reg.register_strategy("nn.fast_softmax", strategy.fast_softmax_strategy)
reg.register_pattern("nn.fast_softmax", OpPattern.OPAQUE)
reg.register_pattern("nn.fast_softmax", OpPattern.OUT_ELEMWISE_FUSABLE)


# log_softmax
reg.register_strategy("nn.log_softmax", strategy.log_softmax_strategy)
reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
reg.register_pattern("nn.log_softmax", OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_legalize("nn.matmul")
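Note (not part of the diff): switching these patterns from OPAQUE to OUT_ELEMWISE_FUSABLE is what allows the FuseOps pass to pull trailing elementwise ops into the same primitive function as the softmax. A minimal Relay-level sketch of the new behavior; the cast op, shapes, and fuse_opt_level here are purely illustrative assumptions:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 128), dtype="float32")
    # An elementwise op right after softmax; with OUT_ELEMWISE_FUSABLE the fusion
    # pass can merge it into the softmax's primitive function.
    y = relay.cast(relay.nn.softmax(x), "float16")
    mod = tvm.IRModule.from_expr(relay.Function([x], y))
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.FuseOps(fuse_opt_level=2)(mod)
    print(mod)  # expect a single fused primitive function containing softmax and cast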
108 changes: 66 additions & 42 deletions python/tvm/topi/cuda/softmax.py
@@ -21,43 +21,26 @@
from tvm.contrib import cudnn
from .. import generic
from .injective import schedule_injective_from_existing
from ..utils import traverse_inline


def schedule_softmax(outs):
"""Schedule for softmax op.

Parameters
----------
outs: Array of Tensor
The computation graph description of softmax in the format
of an array of tensors.

Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])
softmax = outs[0]
tgt = Target.current(allow_none=False)

op_tag = softmax.op.tag
def _schedule_softmax(softmax_op, s, outs, tgt):
op_tag = softmax_op.tag
if op_tag == "softmax_output":
expsum = softmax.op.input_tensors[1]
exp = softmax.op.input_tensors[0]
expsum = softmax_op.input_tensors[1]
exp = softmax_op.input_tensors[0]
max_elem = s[exp].op.input_tensors[1]
delta = None
elif op_tag == "fast_softmax_output":
expsum = softmax.op.input_tensors[1]
exp = softmax.op.input_tensors[0]
expsum = softmax_op.input_tensors[1]
exp = softmax_op.input_tensors[0]
delta = s[exp].op.input_tensors[0]
max_elem = s[delta].op.input_tensors[1]
elif op_tag == "log_softmax_output":
exp = None
delta = None
max_elem = softmax.op.input_tensors[1]
expsum = softmax.op.input_tensors[2]
max_elem = softmax_op.input_tensors[1]
expsum = softmax_op.input_tensors[2]
else:
raise ValueError(
"Tag is expected to be softmax_output or log_softmax_output. \
@@ -71,41 +54,53 @@ def schedule_softmax(outs):
#
# TODO(tvm-team) Fix nvptx codegen or deprecate nvptx backend.
def sched_warp_softmax():
if tgt.kind.name == "nvptx" or tgt.kind.name == "rocm":
return softmax.dtype == "float32" or softmax.dtype == "int32"
if tgt.kind.name in ["nvptx", "rocm"]:
dtype = softmax_op.output(0).dtype
return dtype in ["float32", "int32"]
if tgt.kind.name != "cuda":
# this is used as the gpu schedule for other arches which may not have warp reductions
# this is used as the gpu schedule for other arches which
# may not have warp reductions
return False
return True

if len(softmax.shape) > 2:
ops = [max_elem.op, expsum.op, softmax.op]
if len(outs[0].shape) > 2:
ops = [max_elem.op, expsum.op, softmax_op]
if delta is not None:
ops.append(delta.op)
if exp is not None:
ops.append(exp.op)
if softmax_op != outs[0].op:
ops.append(outs[0].op)

for op in ops:
s = schedule_injective_from_existing(s, op.output(0))

elif sched_warp_softmax():
elif sched_warp_softmax() and softmax_op == outs[0].op:
# TODO(masahi): Fix LowerThreadAllreduce pass to remove
# softmax_op == outs[0].op condition

# A warp of 32 threads performs a row reduction.
num_thread = tgt.thread_warp_size
block_x = te.thread_axis("blockIdx.x")
thread_x = te.thread_axis((0, num_thread), "threadIdx.x")

# (4) softmax
xo, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
_, xii = s[softmax].split(xi, factor=4)
s[softmax].vectorize(xii)
s[softmax].bind(xo, thread_x)
s[softmax].bind(softmax.op.axis[0], block_x)
output = outs[0]
xo, xi = s[output].split(output.op.axis[1], nparts=num_thread)
xio, xii = s[output].split(xi, factor=4)
s[output].vectorize(xii)
s[output].bind(xo, thread_x)
s[output].bind(output.op.axis[0], block_x)

if softmax_op != outs[0].op:
s[softmax_op].compute_at(s[output], xio)
s[softmax_op].vectorize(softmax_op.axis[1]) # vec_len == 4

# (3) expsum
k = expsum.op.reduce_axis[0]
ko, _ = s[expsum].split(k, nparts=num_thread)
s[expsum].bind(ko, thread_x)
s[expsum].compute_at(s[softmax], xo)
s[expsum].compute_at(s[output], xo)

# (2) exp
if delta is not None:
@@ -117,7 +112,7 @@ def sched_warp_softmax():
s[exp].vectorize(xii)
s[exp].bind(xo, thread_x)
s[exp].compute_at(s[expsum], expsum.op.axis[0])
s[exp].compute_at(s[softmax], softmax.op.axis[0])
s[exp].compute_at(s[output], output.op.axis[0])
s[exp].set_scope("warp")

# (1) max_elem
@@ -149,10 +144,39 @@ def sched_warp_softmax():
s[expsum].bind(s[expsum].op.reduce_axis[0], thread_x)
s[EF].compute_at(s[expsum], s[expsum].op.reduce_axis[0])
s[expsum].set_store_predicate(thread_x.var.equal(0))
tx, xi = s[softmax].split(softmax.op.axis[1], nparts=num_thread)
s[softmax].bind(softmax.op.axis[0], block_x)
s[softmax].bind(tx, thread_x)

output = outs[0]
tx, xi = s[output].split(output.op.axis[1], nparts=num_thread)
s[output].bind(output.op.axis[0], block_x)
s[output].bind(tx, thread_x)

if softmax_op != outs[0].op:
s[softmax_op].compute_at(s[output], tx)


def schedule_softmax(outs):
"""Schedule for softmax op.

Parameters
----------
outs: Array of Tensor
The computation graph description of softmax in the format
of an array of tensors.

Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])
tgt = Target.current(allow_none=False)

def _callback(op):
if "softmax" in op.tag:
_schedule_softmax(op, s, outs, tgt)

traverse_inline(s, outs[0].op, _callback)
return s
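Note (not part of the diff): the rewritten entry point locates the softmax stage by tag via traverse_inline, so the same schedule applies when outs[0] is a trailing elementwise op fused after the softmax rather than the softmax itself. A lowering-only usage sketch, where the explicit "elemwise" tag on the trailing compute and the 2-D shape are illustrative assumptions:

    import tvm
    from tvm import te, topi

    x = te.placeholder((64, 1024), name="x", dtype="float32")
    sm = topi.nn.softmax(x, axis=1)
    # Trailing elementwise stage; tagging it "elemwise" lets traverse_inline walk
    # through it and reach the tagged softmax op.
    out = te.compute(sm.shape, lambda i, j: sm[i, j] * 2.0, name="scale", tag="elemwise")
    with tvm.target.Target("cuda"):
        s = topi.cuda.schedule_softmax(out)
        print(tvm.lower(s, [x, out], simple_mode=True))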


81 changes: 48 additions & 33 deletions python/tvm/topi/x86/nn.py
@@ -17,44 +17,28 @@
# pylint: disable=invalid-name,too-many-locals,unused-variable
"""x86 nn operators"""
from tvm import te
from ..utils import traverse_inline


def schedule_softmax(outs):
"""Schedule for softmax

Parameters
----------
outs: Array of Tensor
The computation graph description of softmax
in the format of an array of tensors.

Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
softmax = outs[0]
s = te.create_schedule([x.op for x in outs])

op_tag = softmax.op.tag
def _schedule_softmax(softmax_op, s, outs):
op_tag = softmax_op.tag
if op_tag == "softmax_output":
exp = softmax.op.input_tensors[0]
expsum = softmax.op.input_tensors[1]
exp = softmax_op.input_tensors[0]
expsum = softmax_op.input_tensors[1]
max_elem = s[exp].op.input_tensors[1]
delta = None
axis = int(softmax.op.attrs["axis"])
axis = int(softmax_op.attrs["axis"])
elif op_tag == "fast_softmax_output":
exp = softmax.op.input_tensors[0]
expsum = softmax.op.input_tensors[1]
exp = softmax_op.input_tensors[0]
expsum = softmax_op.input_tensors[1]
delta = s[exp].op.input_tensors[0]
max_elem = s[delta].op.input_tensors[1]
axis = int(softmax.op.attrs["axis"])
axis = int(softmax_op.attrs["axis"])
elif op_tag == "log_softmax_output":
exp = None
delta = None
max_elem = softmax.op.input_tensors[1]
expsum = softmax.op.input_tensors[2]
max_elem = softmax_op.input_tensors[1]
expsum = softmax_op.input_tensors[2]
axis = 1
else:
raise ValueError(
@@ -65,18 +49,49 @@ def schedule_softmax(outs):
)

# only parallelize outer dimensions up to axis
outer_axes = [s[softmax].op.axis[i] for i in range(0, axis)]
fused_outer_axes = s[softmax].fuse(*outer_axes)
s[softmax].parallel(fused_outer_axes)
outer_axes = [s[softmax_op].op.axis[i] for i in range(0, axis)]
fused_outer_axes = s[softmax_op].fuse(*outer_axes)
s[softmax_op].parallel(fused_outer_axes)

# move computations with the same outer dimensions under the same root
s[max_elem].compute_at(s[softmax], fused_outer_axes)
s[expsum].compute_at(s[softmax], fused_outer_axes)
s[max_elem].compute_at(s[softmax_op], fused_outer_axes)
s[expsum].compute_at(s[softmax_op], fused_outer_axes)

if delta is not None:
s[exp].compute_inline()
s[delta].compute_inline()
if exp is not None:
s[exp].compute_at(s[softmax], fused_outer_axes)
s[exp].compute_at(s[softmax_op], fused_outer_axes)

if softmax_op != outs[0].op:
# fuse softmax output with following elemwise ops.
output = outs[0]
outer_axes = [s[output].op.axis[i] for i in range(0, axis)]
fused_outer_axes = s[output].fuse(*outer_axes)
s[output].parallel(fused_outer_axes)
s[softmax_op].compute_at(s[output], fused_outer_axes)


def schedule_softmax(outs):
"""Schedule for softmax

Parameters
----------
outs: Array of Tensor
The computation graph description of softmax
in the format of an array of tensors.

Returns
-------
sch: Schedule
The computation schedule for the op.
"""
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if "softmax" in op.tag:
_schedule_softmax(op, s, outs)

traverse_inline(s, outs[0].op, _callback)
return s
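Note (not part of the diff): the x86 schedule follows the same pattern; when the final output is a fused elementwise op, that output is parallelized and the softmax stage is computed at its fused outer axes. A lowering-only sketch, with the trailing add and its "elemwise" tag as illustrative assumptions:

    import tvm
    from tvm import te, topi

    x = te.placeholder((8, 512), name="x", dtype="float32")
    sm = topi.nn.softmax(x, axis=1)
    # Trailing elementwise op fused after softmax.
    out = te.compute(sm.shape, lambda i, j: sm[i, j] + 1.0, name="add_one", tag="elemwise")
    s = topi.x86.schedule_softmax(out)
    print(tvm.lower(s, [x, out], simple_mode=True))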
12 changes: 11 additions & 1 deletion tests/micro/arduino/test_arduino_workflow.py
@@ -17,6 +17,7 @@

import datetime
import pathlib
import re
import shutil
import sys

@@ -80,7 +81,16 @@ def test_model_header_templating(project_dir, project):
# Ensure model.h was templated with correct WORKSPACE_SIZE
with (project_dir / "src" / "model.h").open() as f:
model_h = f.read()
assert "#define WORKSPACE_SIZE 21312" in model_h
workspace_size_defs = re.findall(r"\#define WORKSPACE_SIZE ([0-9]*)", model_h)
assert workspace_size_defs
assert len(workspace_size_defs) == 1

# Make sure the WORKSPACE_SIZE we define is a reasonable size. We don't want
# to set an exact value, as this test shouldn't break if an improvement to
# TVM causes the amount of memory needed to decrease.
workspace_size = int(workspace_size_defs[0])
assert workspace_size < 30000
assert workspace_size > 10000
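Note (not part of the diff): a tiny illustration of what the relaxed check extracts, reusing the previously hard-coded value 21312 as a sample:

    import re
    sample_model_h = "#define WORKSPACE_SIZE 21312\n"
    print(re.findall(r"\#define WORKSPACE_SIZE ([0-9]*)", sample_model_h))  # ['21312']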


def test_import_rerouting(project_dir, project):
@@ -96,7 +96,7 @@ def is_conv_add(func):
def test_extract_resnet():
mod, _params = get_workload()
items = relay.analysis.extract_fused_functions(mod)
assert len(items) == 6
assert len(items) == 7


if __name__ == "__main__":