Skip to content

Commit

Permalink
support random fill
Browse files Browse the repository at this point in the history
  • Loading branch information
FrozenGene committed Aug 10, 2020
1 parent 7926a5d commit de37081
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 2 deletions.
2 changes: 1 addition & 1 deletion cmake/config.cmake
Expand Up @@ -140,7 +140,7 @@ set(USE_MKLDNN OFF)
set(USE_OPENMP none)

# Whether use contrib.random in runtime
set(USE_RANDOM OFF)
set(USE_RANDOM ON)

# Whether use NNPack
set(USE_NNPACK OFF)
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/runtime/ndarray.h
Expand Up @@ -306,6 +306,14 @@ inline size_t GetDataSize(const DLTensor& arr) {
return size;
}

/*!
* \brief return the alignment data the DLTensor hold
*
* \param arr the input DLTensor
* \return alignment of data in the DLTensor.
*/
size_t GetDataAlignment(const DLTensor& arr);

/*!
* \brief check if a DLTensor is contiguous.
* \param arr The input DLTensor.
Expand Down
46 changes: 46 additions & 0 deletions src/runtime/contrib/random/mt_random_engine.cc
Expand Up @@ -22,11 +22,15 @@
* \brief mt19937 random engine
*/
#include <dmlc/logging.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/ndarray.h>

#include <algorithm>
#include <ctime>
#include <random>

#include "../3rdparty/compiler-rt/builtin_fp16.h"

namespace tvm {
namespace contrib {

Expand Down Expand Up @@ -111,6 +115,48 @@ class RandomEngine {
}
}

void RandomFill(DLTensor* data) {
DLTensor local;
local.shape = data->shape;
local.ndim = data->ndim;
local.dtype = data->dtype;
local.strides = data->strides;
local.byte_offset = data->byte_offset;
local.ctx = {kDLCPU, 0};
local.data = runtime::DeviceAPI::Get(local.ctx)->AllocDataSpace(
{kDLCPU, 0}, runtime::GetDataSize(local), runtime::GetDataAlignment(local), local.dtype);

int64_t size = 1;
for (int i = 0; i < data->ndim; ++i) {
size *= data->shape[i];
}

// Make the value be 1.0 - 10.0, not (0.0 - 1.0) so that we could satisfy
// quantized dtype (uint8 / int8) data non-empty requirement
std::uniform_real_distribution<> dist(1.0, 10.0);
// Use float representation could make us work well on float / int type too.
for (int i = 0; i < size; ++i) {
if (local.dtype.bits == 1) {
(reinterpret_cast<bool*>(local.data))[i] = dist(rnd_engine_);
} else if (local.dtype.bits == 8) {
(reinterpret_cast<uint8_t*>(local.data))[i] = dist(rnd_engine_);
} else if (local.dtype.bits == 16) {
(reinterpret_cast<uint16_t*>(local.data))[i] =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
static_cast<float>(dist(rnd_engine_)));
} else if (local.dtype.bits == 32) {
(reinterpret_cast<float*>(local.data))[i] = dist(rnd_engine_);
} else if (local.dtype.bits == 64) {
(reinterpret_cast<double*>(local.data))[i] = dist(rnd_engine_);
} else {
LOG(FATAL) << "Doesn't support dtype code " << local.dtype.code << " dtype bits "
<< local.dtype.bits;
}
}

runtime::NDArray::CopyFromTo(&local, data);
}

private:
std::mt19937 rnd_engine_;
unsigned rseed_;
Expand Down
6 changes: 6 additions & 0 deletions src/runtime/contrib/random/random.cc
Expand Up @@ -117,5 +117,11 @@ TVM_REGISTER_GLOBAL("tvm.contrib.random.normal").set_body([](TVMArgs args, TVMRe
entry->random_engine.SampleNormal(out, loc, scale);
});

TVM_REGISTER_GLOBAL("tvm.contrib.random.random_fill").set_body([](TVMArgs args, TVMRetValue* ret) {
RandomThreadLocalEntry* entry = RandomThreadLocalEntry::ThreadLocal();
DLTensor* out = args[0];
entry->random_engine.RandomFill(out);
});

} // namespace contrib
} // namespace tvm
2 changes: 1 addition & 1 deletion src/runtime/ndarray.cc
Expand Up @@ -58,7 +58,7 @@ inline void VerifyDataType(DLDataType dtype) {
CHECK_EQ(dtype.bits & (dtype.bits - 1), 0);
}

inline size_t GetDataAlignment(const DLTensor& arr) {
size_t GetDataAlignment(const DLTensor& arr) {
size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes;
if (align < kAllocAlignment) return kAllocAlignment;
return align;
Expand Down
57 changes: 57 additions & 0 deletions tests/python/contrib/test_random.py
Expand Up @@ -18,6 +18,22 @@
from tvm import te
import numpy as np
from tvm.contrib import random
from tvm import rpc

def enabled_ctx_list():
ctx_list = [('cpu', tvm.cpu(0)),
('gpu', tvm.gpu(0)),
('cl', tvm.opencl(0)),
('metal', tvm.metal(0)),
('rocm', tvm.rocm(0)),
('vulkan', tvm.vulkan(0)),
('vpi', tvm.vpi(0))]
for k, v in ctx_list:
assert tvm.context(k, 0) == v
ctx_list = [x[1] for x in ctx_list if x[1].exist]
return ctx_list

ENABLED_CTX_LIST = enabled_ctx_list()

def test_randint():
m = 1024
Expand Down Expand Up @@ -89,8 +105,49 @@ def verify(target="llvm"):
assert abs(np.std(na) - 4) < 1e-2
verify()

def test_random_fill():
def test_local(ctx, dtype):
if not tvm.get_global_func("tvm.contrib.random.random_fill", True):
print("skip because extern function is not available")
return
np_ones = np.ones((512, 512), dtype=dtype)
value = tvm.nd.empty(np_ones.shape, np_ones.dtype, ctx)
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill")
random_fill(value)

assert np.count_nonzero(value.asnumpy()) == 512 * 512

# make sure arithmentic doesn't overflow too
np_values = value.asnumpy()
assert np.isfinite(np_values * np_values + np_values).any()

def test_rpc(dtype):
if not tvm.get_global_func("tvm.contrib.random.random_fill", True):
print("skip because extern function is not available")
return
if not tvm.runtime.enabled("rpc") or not tvm.runtime.enabled("llvm"):
return
np_ones = np.ones((512, 512), dtype=dtype)
server = rpc.Server("localhost")
remote = rpc.connect(server.host, server.port)
value = tvm.nd.empty(np_ones.shape, np_ones.dtype, remote.cpu())
random_fill = tvm.get_global_func("tvm.contrib.random.random_fill")
random_fill(value)

assert np.count_nonzero(value.asnumpy()) == 512 * 512

# make sure arithmentic doesn't overflow too
np_values = value.asnumpy()
assert np.isfinite(np_values * np_values + np_values).any()

for dtype in ["bool", "int8", "uint8", "int16", "uint16", "int32", "int32",
"int64", "uint64", "float16", "float32", "float64"]:
for ctx in ENABLED_CTX_LIST:
test_local(ctx, dtype)
test_rpc(dtype)

if __name__ == "__main__":
test_randint()
test_uniform()
test_normal()
test_random_fill()

0 comments on commit de37081

Please sign in to comment.