Merge remote-tracking branch 'apache-upstream/main' into unity
junrushao committed Jul 18, 2023
2 parents 959b7e5 + 2eca9f0 commit d8f1ac4
Showing 66 changed files with 2,563 additions and 267 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -176,6 +176,7 @@ We do encourage everyone to work anything they are interested in.
- [Janet Schneider](https://github.com/janetsc): @janetsc
- [Junru Shao](https://github.com/junrushao): @junrushao
- [Haichen Shen](https://github.com/icemelon): @icemelon
+ - [Qingchao Shen](https://github.com/jikechao): @jikechao
- [Xingjian Shi](https://github.com/sxjscience): @sxjscience
- [Yuanjing Shi](https://github.com/shingjan): @shingjan
- [Mark Shields](https://github.com/mbs-octoml): @mbs-octoml
9 changes: 5 additions & 4 deletions apps/microtvm/arduino/template_project/microtvm_api_server.py
@@ -586,10 +586,11 @@ def _get_arduino_port(
    def _get_board_from_makefile(self, makefile_path: pathlib.Path) -> str:
        """Get Board from generated Makefile."""
        with open(makefile_path) as makefile_f:
-           line = makefile_f.readline()
-           if "BOARD" in line:
-               board = re.sub(r"\s", "", line).split(":=")[1]
-               return board
+           lines = makefile_f.readlines()
+           for line in lines:
+               if "BOARD" in line:
+                   board = re.sub(r"\s", "", line).split(":=")[1]
+                   return board
        raise RuntimeError("Board was not found in Makefile: {}".format(makefile_path))

FLASH_TIMEOUT_SEC = 60
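
The fix scans every line of the generated Makefile instead of only the first. A standalone sketch of the parsing logic, using a hypothetical Makefile snippet (the board name is illustrative):

import re

makefile_text = "TARGET := arduino\nBOARD := nano33ble\n"  # hypothetical Makefile content
board = None
for line in makefile_text.splitlines():
    if "BOARD" in line:
        # strip all whitespace, then take the value after ":="
        board = re.sub(r"\s", "", line).split(":=")[1]
        break
print(board)  # -> nano33ble
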
2 changes: 1 addition & 1 deletion ci/jenkins/docker-images.ini
@@ -18,7 +18,7 @@
# This data file is read when Jenkins runs a job to determine docker images.
[jenkins]
ci_arm: tlcpack/ci-arm:20230615-060132-62a5e7acf
- ci_cortexm: tlcpack/ci-cortexm:20230613-060122-21361a63a
+ ci_cortexm: tlcpack/ci-cortexm:20230710-060128-a60cd0fec
ci_cpu: tlcpack/ci-cpu:20230604-060130-0af9ff90e
ci_gpu: tlcpack/ci-gpu:20230504-142417-4d37a0a0
ci_hexagon: tlcpack/ci-hexagon:20230504-142417-4d37a0a0
2 changes: 1 addition & 1 deletion docker/install/ubuntu_install_tensorflow_aarch64.sh
@@ -26,4 +26,4 @@ apt-install-and-clear -y --no-install-recommends libhdf5-dev
pip3 install \
numpy==1.23.5 \
keras==2.9 \
-     tensorflow-aarch64==2.9.1
+     tensorflow-aarch64~=2.9.3
4 changes: 2 additions & 2 deletions docs/arch/index.rst
@@ -152,8 +152,8 @@ The main goal of TVM's runtime is to provide a minimal API for loading and execu
mod: tvm.runtime.Module = tvm.runtime.load_module("compiled_artifact.so")
arr: tvm.runtime.NDArray = tvm.nd.array([1, 2, 3], device=tvm.cuda(0))
fun: tvm.runtime.PackedFunc = mod["addone"]
- fun(a)
- print(a.numpy())
+ fun(arr)
+ print(arr.numpy())
:py:class:`tvm.runtime.Module` encapsulates the result of compilation. A runtime.Module contains a GetFunction method to obtain PackedFuncs by name.
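
For illustration, a minimal sketch of the same lookup done via the get_function wrapper (the artifact path and function name are hypothetical; assumes a CPU-visible array):

import numpy as np
import tvm

mod = tvm.runtime.load_module("compiled_artifact.so")  # hypothetical artifact
addone = mod.get_function("addone")  # Python wrapper over GetFunction

a = tvm.nd.array(np.array([1, 2, 3], dtype="float32"))
addone(a)  # PackedFuncs are called like ordinary Python callables
print(a.numpy())
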
3 changes: 2 additions & 1 deletion include/tvm/runtime/device_api.h
@@ -48,7 +48,8 @@ enum DeviceAttrKind : int {
kMaxRegistersPerBlock = 9,
kGcnArch = 10,
kApiVersion = 11,
-  kDriverVersion = 12
+  kDriverVersion = 12,
+  kL2CacheSizeBytes = 13,
};

#ifdef TVM_KALLOC_ALIGNMENT
4 changes: 3 additions & 1 deletion include/tvm/runtime/profiling.h
@@ -579,13 +579,15 @@ PackedFunc ProfileFunction(Module mod, std::string func_name, int device_type, i
* defined by `repeats_to_cooldown`.
* \param repeats_to_cooldown The number of repeats before the
* cooldown is activated.
+ * \param cache_flush_bytes The number of bytes to flush from cache before every repeat.
* \param f_preproc The function to be executed before we execute time
* evaluator.
* \return f_timer A timer function.
*/
PackedFunc WrapTimeEvaluator(PackedFunc f, Device dev, int number, int repeat, int min_repeat_ms,
                             int limit_zero_time_iterations, int cooldown_interval_ms,
-                            int repeats_to_cooldown, PackedFunc f_preproc = nullptr);
+                            int repeats_to_cooldown, int cache_flush_bytes = 0,
+                            PackedFunc f_preproc = nullptr);

} // namespace profiling
} // namespace runtime
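
A sketch of how the new knob could be used from Python, assuming the Python-side time_evaluator wrapper exposes the new cache_flush_bytes parameter (artifact path and function name are hypothetical):

import tvm

mod = tvm.runtime.load_module("compiled_artifact.so")  # hypothetical artifact
dev = tvm.cuda(0)

# Flush roughly one L2 cache's worth of data before every repeat so each
# measurement starts cold; fall back to 0 when the size is unavailable.
l2_bytes = dev.l2_cache_size_bytes or 0
timer = mod.time_evaluator("addone", dev, number=10, repeat=3, cache_flush_bytes=l2_bytes)
# report = timer(arr)  # returns per-repeat timing statistics
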
96 changes: 96 additions & 0 deletions include/tvm/topi/nn/rms_norm.h
@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \brief root mean square normalization op constructions
* \file nn/rms_norm.h
*/
#ifndef TVM_TOPI_NN_RMS_NORM_H_
#define TVM_TOPI_NN_RMS_NORM_H_

#include <tvm/te/operation.h>
#include <tvm/topi/reduction.h>
#include <tvm/topi/tags.h>

#include <string>

namespace tvm {
namespace topi {
namespace nn {

using namespace tvm::te;

/*!
* \brief Root mean square normalization.
* \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
* \param weight K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
* d_{axis_k} == r_k
* \param bias Optional, K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where
* d_{axis_k} == r_k
* \param axis The axis to normalize over.
* \param epsilon The epsilon value to avoid division by zero.
* \param name The name of the operation.
* \param tag The tag to mark the operation.
* \return The normalized tensor, with the same shape as data.
*/
inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Tensor& bias,
                       const Array<Integer>& axis, double epsilon, std::string name = "T_rms_norm",
                       std::string tag = kInjective) {
  const auto& data_type = data->dtype;
  const auto& weight_type = weight.defined() ? weight->dtype : data_type;
  ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type";
  const auto& bias_type = bias.defined() ? bias->dtype : data_type;
  ICHECK(data_type == bias_type) << "rms_norm: data and bias must have the same type";

  auto square = multiply(data, data);
  auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true);

  auto ndim = data->shape.size();
  ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
  auto reduce_extent = make_const(data->dtype, 1);
  for (int i : real_axis) {
    reduce_extent *= data->shape[i];
  }
  auto rms_norm_func = [&](const Array<Var>& indices) {
    Array<Var> reduce_indices, non_reduce_indices;
    for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
      if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
        reduce_indices.push_back(indices[i]);
      } else {
        non_reduce_indices.push_back(indices[i]);
      }
    }
    auto output =
        data(indices) * weight(reduce_indices) *
        tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
    if (bias.defined()) {
      output += bias(reduce_indices);
    }
    return output;
  };
  auto rms_norm = tvm::te::compute(data->shape, rms_norm_func, name, tag);
  return rms_norm;
}

} // namespace nn
} // namespace topi
} // namespace tvm

#endif // TVM_TOPI_NN_RMS_NORM_H_
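
For intuition, a NumPy reference of the same computation for the common last-axis case (a sketch of the semantics, not the TOPI API; names are illustrative):

import numpy as np

def rms_norm_ref(data, weight, bias=None, epsilon=1e-5):
    # out = data * weight * rsqrt(mean(data^2, axis) + epsilon) (+ bias)
    mean_sq = np.mean(data * data, axis=-1, keepdims=True)
    out = data * weight / np.sqrt(mean_sq + epsilon)
    if bias is not None:
        out = out + bias
    return out

x = np.random.rand(2, 8).astype("float32")
w = np.ones(8, dtype="float32")
print(rms_norm_ref(x, w).shape)  # (2, 8), same shape as data
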
2 changes: 1 addition & 1 deletion jvm/native/osx-x86_64/pom.xml
@@ -120,7 +120,7 @@ under the License.
<compilerEndOptions>
<compilerEndOption>-I../../../include</compilerEndOption>
<compilerEndOption>-I${JAVA_HOME}/include</compilerEndOption>
- <compilerEndOption>-I${JAVA_HOME}/include/linux</compilerEndOption>
+ <compilerEndOption>-I${JAVA_HOME}/include/darwin</compilerEndOption>
<compilerEndOption>${cflags}</compilerEndOption>
</compilerEndOptions>
<linkerStartOptions>
2 changes: 1 addition & 1 deletion jvm/pom.xml
@@ -172,7 +172,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
- <version>2.9.1</version>
+ <version>3.5.0</version>
<executions>
<execution>
<id>attach-javadocs</id>
8 changes: 8 additions & 0 deletions python/setup.py
@@ -95,6 +95,14 @@ def get_lib_path():
        # remove large files
        _remove_path(os.path.join(candidate_path, "cutlass", "docs"))
        _remove_path(os.path.join(candidate_path, "cutlass", "media"))
        _remove_path(
            os.path.join(candidate_path, "cutlass_fpA_intB_gemm", "cutlass", "docs")
        )
        _remove_path(
            os.path.join(
                candidate_path, "cutlass_fpA_intB_gemm", "cutlass", "media"
            )
        )
        break
    else:
        libs = None
18 changes: 18 additions & 0 deletions python/tvm/_ffi/runtime_ctypes.py
@@ -488,6 +488,24 @@ def driver_version(self):
"""
return self._GetDeviceAttr(self.device_type, self.device_id, 12)

    @property
    def l2_cache_size_bytes(self):
        """Return the size of the device L2 cache in bytes.

        Supported devices include CUDA/ROCM/OpenCL.

        Returns
        -------
        l2_cache_size_bytes : int or None
            The size of the device L2 cache in bytes returned by device runtime API.
            Return None if the device does not support this feature.

        Note
        ----
        The value returned by opencl's API is smaller than actual device L2 cache size.
        """
        return self._GetDeviceAttr(self.device_type, self.device_id, 13)

    def texture_spatial_limit(self):
        """Returns limits for textures by spatial dimensions
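
A quick usage sketch of the new attribute (assuming a CUDA-enabled build; the property returns None where the runtime does not report it):

import tvm

dev = tvm.cuda(0)
# DeviceAttrKind 13 (kL2CacheSizeBytes), surfaced as a Python property.
print(dev.l2_cache_size_bytes)  # e.g. 41943040 (40 MB) on an A100; None if unsupported
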
105 changes: 102 additions & 3 deletions python/tvm/contrib/hexagon/transform.py
@@ -21,8 +21,16 @@

import tvm
from tvm import relay
- from tvm.relay.dataflow_pattern import DFPatternCallback, rewrite, wildcard
- from tvm.relay.dataflow_pattern import is_constant, is_op, is_tuple
+ from tvm.relay.dataflow_pattern import (
+     DFPatternCallback,
+     is_constant,
+     is_op,
+     is_tuple,
+     rewrite,
+     wildcard,
+ )
+ from tvm.relay.expr import Call

from ..._ffi.registry import register_func

### VTCM
@@ -43,7 +51,6 @@ def mem_info_vtcm():


def lower_vtcm_(get_alloc, get_free, def_align, func, mod, ctx):  # pylint: disable=unused-argument
-
    """Generic VTCM allocation

    Parameters
@@ -311,3 +318,95 @@ def remove_empty_pad(mod):
"""Remove the empty pad operator."""
mod["main"] = rewrite(remove_empty_pad_callback(), mod["main"])
return mod


class simplify_qnn_concat_in_func(DFPatternCallback):
    """
    Propagate qnn.concat's quantization params to its inputs,
    and try to avoid redundant requantization while doing so.

    Replace
    def @main(%q1: Tensor[(1, 64, 35, 35), uint8],
        %q2: Tensor[(1, 64, 35, 35), uint8], %q3: Tensor[(1, 32, 35, 35), uint8]) {
      %0 = nn.max_pool2d(%q1, pool_size=[3, 3], padding=[1, 1, 1, 1], layout="NHWC");
      %1 = qnn.requantize(%q2, 0.000109401f, 0, 0.00345f, 0, axis=1, out_dtype="uint8");
      %2 = (%0, %1, %q3);
      %3 = (0.0425042f, 0.00345f, 0.0486874f);
      %4 = (0, 0, 0);
      qnn.concatenate(%2, %3, %4, 0.0486874f, 0, axis=1)
    }

    with

    def @main(%q1: Tensor[(1, 64, 35, 35), uint8],
        %q2: Tensor[(1, 64, 35, 35), uint8], %q3: Tensor[(1, 32, 35, 35), uint8]) {
      %0 = nn.max_pool2d(%q1, pool_size=[3, 3], padding=[1, 1, 1, 1], layout="NHWC");
      %1 = qnn.requantize(%0, 0.0425042f, 0, 0.0486874f, 0, axis=1, out_dtype="uint8");
      %2 = qnn.requantize(%q2, 0.000109401f, 0, 0.0486874f, 0, axis=1, out_dtype="uint8");
      %3 = (%1, %2, %q3);
      concatenate(%3, axis=1)
    }
    """

    def __init__(self):
        super(simplify_qnn_concat_in_func, self).__init__()
        self.qvals = wildcard()
        self.scales = wildcard()
        self.zps = wildcard()
        self.out_scale = wildcard()
        self.out_zp = wildcard()
        self.pattern = is_op("qnn.concatenate")(
            self.qvals, self.scales, self.zps, self.out_scale, self.out_zp
        )

    def callback(self, pre, post, node_map):
        in_qvals = node_map[self.qvals][0]
        in_scales = node_map[self.scales][0]
        in_zps = node_map[self.zps][0]
        new_qvals = []
        for i in range(len(in_qvals)):
            new_requant_args = []
            # TODO Generalize for all qnn ops
            if isinstance(in_qvals[i], Call) and (in_qvals[i].op.name == "qnn.requantize"):
                # propagate scale/zp of qnn.concat to this requantize op
                for j in range(3):
                    new_requant_args.append(in_qvals[i].args[j])
                new_requant_args += [node_map[self.out_scale][0], node_map[self.out_zp][0]]
                new_qvals.append(relay.qnn.op.requantize(*new_requant_args, **(in_qvals[i].attrs)))
            else:
                # simply create a new requantize op if there is a change in quantization params
                # if not, just retain the old qval
                if (in_scales[i] == node_map[self.out_scale][0]) and (
                    in_zps[i] == node_map[self.out_zp][0]
                ):
                    new_qvals.append(in_qvals[i])
                else:
                    new_requant_args += [
                        in_qvals[i],
                        in_scales[i],
                        in_zps[i],
                        node_map[self.out_scale][0],
                        node_map[self.out_zp][0],
                    ]
                    new_qvals.append(
                        relay.qnn.op.requantize(
                            *new_requant_args,
                            axis=post.attrs["axis"],
                            out_dtype=post.checked_type.dtype,
                        )
                    )

        new_op = relay.op.concatenate(
            new_qvals,
            node_map[self.pattern][0].attrs["axis"],
        )
        return new_op


# Right now context is ignored
@tvm.transform.module_pass(opt_level=1)
def simplify_qnn_concat(mod, _=None):
    for global_var in mod.functions.keys():
        mod[global_var] = rewrite(simplify_qnn_concat_in_func(), mod[global_var])
    return mod
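
A usage sketch, assuming mod is a Relay IRModule containing qnn.concatenate such as the @main shown in the docstring above:

from tvm.contrib.hexagon.transform import simplify_qnn_concat

# simplify_qnn_concat is a module pass; calling it rewrites every function in mod.
mod = simplify_qnn_concat(mod)
print(mod["main"])  # qnn.concatenate replaced by per-input requantize + concatenate
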
