From de37081d206b0732dddb5fd69c64399ac4bc3390 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Wed, 24 Jun 2020 17:49:00 +0800 Subject: [PATCH] support random fill --- cmake/config.cmake | 2 +- include/tvm/runtime/ndarray.h | 8 +++ .../contrib/random/mt_random_engine.cc | 46 +++++++++++++++ src/runtime/contrib/random/random.cc | 6 ++ src/runtime/ndarray.cc | 2 +- tests/python/contrib/test_random.py | 57 +++++++++++++++++++ 6 files changed, 119 insertions(+), 2 deletions(-) diff --git a/cmake/config.cmake b/cmake/config.cmake index 6963ece47f8a1..c46f9c4eae2ab 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -140,7 +140,7 @@ set(USE_MKLDNN OFF) set(USE_OPENMP none) # Whether use contrib.random in runtime -set(USE_RANDOM OFF) +set(USE_RANDOM ON) # Whether use NNPack set(USE_NNPACK OFF) diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index e69d802652fda..67747d05fe0d2 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -306,6 +306,14 @@ inline size_t GetDataSize(const DLTensor& arr) { return size; } +/*! + * \brief return the alignment data the DLTensor hold + * + * \param arr the input DLTensor + * \return alignment of data in the DLTensor. + */ +size_t GetDataAlignment(const DLTensor& arr); + /*! * \brief check if a DLTensor is contiguous. * \param arr The input DLTensor. diff --git a/src/runtime/contrib/random/mt_random_engine.cc b/src/runtime/contrib/random/mt_random_engine.cc index c628e327643e4..f300eece26822 100644 --- a/src/runtime/contrib/random/mt_random_engine.cc +++ b/src/runtime/contrib/random/mt_random_engine.cc @@ -22,11 +22,15 @@ * \brief mt19937 random engine */ #include +#include +#include #include #include #include +#include "../3rdparty/compiler-rt/builtin_fp16.h" + namespace tvm { namespace contrib { @@ -111,6 +115,48 @@ class RandomEngine { } } + void RandomFill(DLTensor* data) { + DLTensor local; + local.shape = data->shape; + local.ndim = data->ndim; + local.dtype = data->dtype; + local.strides = data->strides; + local.byte_offset = data->byte_offset; + local.ctx = {kDLCPU, 0}; + local.data = runtime::DeviceAPI::Get(local.ctx)->AllocDataSpace( + {kDLCPU, 0}, runtime::GetDataSize(local), runtime::GetDataAlignment(local), local.dtype); + + int64_t size = 1; + for (int i = 0; i < data->ndim; ++i) { + size *= data->shape[i]; + } + + // Make the value be 1.0 - 10.0, not (0.0 - 1.0) so that we could satisfy + // quantized dtype (uint8 / int8) data non-empty requirement + std::uniform_real_distribution<> dist(1.0, 10.0); + // Use float representation could make us work well on float / int type too. + for (int i = 0; i < size; ++i) { + if (local.dtype.bits == 1) { + (reinterpret_cast(local.data))[i] = dist(rnd_engine_); + } else if (local.dtype.bits == 8) { + (reinterpret_cast(local.data))[i] = dist(rnd_engine_); + } else if (local.dtype.bits == 16) { + (reinterpret_cast(local.data))[i] = + __truncXfYf2__( + static_cast(dist(rnd_engine_))); + } else if (local.dtype.bits == 32) { + (reinterpret_cast(local.data))[i] = dist(rnd_engine_); + } else if (local.dtype.bits == 64) { + (reinterpret_cast(local.data))[i] = dist(rnd_engine_); + } else { + LOG(FATAL) << "Doesn't support dtype code " << local.dtype.code << " dtype bits " + << local.dtype.bits; + } + } + + runtime::NDArray::CopyFromTo(&local, data); + } + private: std::mt19937 rnd_engine_; unsigned rseed_; diff --git a/src/runtime/contrib/random/random.cc b/src/runtime/contrib/random/random.cc index acba193c12305..14bdd267d38c4 100644 --- a/src/runtime/contrib/random/random.cc +++ b/src/runtime/contrib/random/random.cc @@ -117,5 +117,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.normal").set_body([](TVMArgs args, TVMRe entry->random_engine.SampleNormal(out, loc, scale); }); +TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill").set_body([](TVMArgs args, TVMRetValue* ret) { + RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal(); + DLTensor* out = args[0]; + entry->random_engine.RandomFill(out); +}); + } // namespace contrib } // namespace tvm diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 800a9167dadc5..e460cf3ea7302 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -58,7 +58,7 @@ inline void VerifyDataType(DLDataType dtype) { CHECK_EQ(dtype.bits & (dtype.bits - 1), 0); } -inline size_t GetDataAlignment(const DLTensor& arr) { +size_t GetDataAlignment(const DLTensor& arr) { size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes; if (align < kAllocAlignment) return kAllocAlignment; return align; diff --git a/tests/python/contrib/test_random.py b/tests/python/contrib/test_random.py index 9efdc3e5a7631..81a47e1d43304 100644 --- a/tests/python/contrib/test_random.py +++ b/tests/python/contrib/test_random.py @@ -18,6 +18,22 @@ from tvm import te import numpy as np from tvm.contrib import random +from tvm import rpc + +def enabled_ctx_list(): + ctx_list = [('cpu', tvm.cpu(0)), + ('gpu', tvm.gpu(0)), + ('cl', tvm.opencl(0)), + ('metal', tvm.metal(0)), + ('rocm', tvm.rocm(0)), + ('vulkan', tvm.vulkan(0)), + ('vpi', tvm.vpi(0))] + for k, v in ctx_list: + assert tvm.context(k, 0) == v + ctx_list = [x[1] for x in ctx_list if x[1].exist] + return ctx_list + +ENABLED_CTX_LIST = enabled_ctx_list() def test_randint(): m = 1024 @@ -89,8 +105,49 @@ def verify(target="llvm"): assert abs(np.std(na) - 4) < 1e-2 verify() +def test_random_fill(): + def test_local(ctx, dtype): + if not tvm.get_global_func("tvm.contrib.random.random_fill", True): + print("skip because extern function is not available") + return + np_ones = np.ones((512, 512), dtype=dtype) + value = tvm.nd.empty(np_ones.shape, np_ones.dtype, ctx) + random_fill = tvm.get_global_func("tvm.contrib.random.random_fill") + random_fill(value) + + assert np.count_nonzero(value.asnumpy()) == 512 * 512 + + # make sure arithmentic doesn't overflow too + np_values = value.asnumpy() + assert np.isfinite(np_values * np_values + np_values).any() + + def test_rpc(dtype): + if not tvm.get_global_func("tvm.contrib.random.random_fill", True): + print("skip because extern function is not available") + return + if not tvm.runtime.enabled("rpc") or not tvm.runtime.enabled("llvm"): + return + np_ones = np.ones((512, 512), dtype=dtype) + server = rpc.Server("localhost") + remote = rpc.connect(server.host, server.port) + value = tvm.nd.empty(np_ones.shape, np_ones.dtype, remote.cpu()) + random_fill = tvm.get_global_func("tvm.contrib.random.random_fill") + random_fill(value) + + assert np.count_nonzero(value.asnumpy()) == 512 * 512 + + # make sure arithmentic doesn't overflow too + np_values = value.asnumpy() + assert np.isfinite(np_values * np_values + np_values).any() + + for dtype in ["bool", "int8", "uint8", "int16", "uint16", "int32", "int32", + "int64", "uint64", "float16", "float32", "float64"]: + for ctx in ENABLED_CTX_LIST: + test_local(ctx, dtype) + test_rpc(dtype) if __name__ == "__main__": test_randint() test_uniform() test_normal() + test_random_fill()