Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/contrib/tvmjs.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def dump_ndarray_cache(
v = v.numpy()

# prefer to preserve original dtype, especially if the format was bfloat16
dtype = str(origin_v.dtype) if isinstance(origin_v, tvm.nd.NDArray) else v.dtype
dtype = str(origin_v.dtype) if isinstance(origin_v, tvm.nd.NDArray) else str(v.dtype)

# convert fp32 to bf16
if encode_format == "f32-to-bf16" and dtype == "float32":
Expand Down
204 changes: 204 additions & 0 deletions src/runtime/relax_vm/ndarray_cache_support.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/runtime/relax_vm/ndarray_cache_support.cc
* \brief Runtime to support ndarray cache file loading.
*
* This file provides a minimum support for ndarray cache file loading.
*
* The main focus of this implementation is to enable loading
* with minimum set of intermediate files while also being
* compatible to some of the multi-shard files that are more
* friendly in some of the environments.
*
* NDArray cache also provides a way to do system-wide
* parameter sharing across multiple VMs.
*
 * There are likely other ways to load the parameters from an ndarray cache.
 * We will keep the impact minimal by putting it as a private
 * runtime builtin provided in this file.
*/
#define PICOJSON_USE_INT64

#include <picojson.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/registry.h>

#include <cstdint>
#include <cstring>
#include <sstream>
#include <string>
#include <vector>

#include "../../support/utils.h"
#include "../file_utils.h"

namespace tvm {
namespace runtime {
namespace relax_vm {

/*!
 * \brief A process-wide cache of pre-loaded NDArrays keyed by name.
 *
 * The cache provides a way to do system-wide parameter sharing
 * across multiple VMs: arrays loaded once can be looked up by name
 * from any VM in the process.
 */
class NDArrayCache {
 public:
  /*! \brief Return the process-wide singleton (intentionally never freed). */
  static NDArrayCache* Global() {
    static NDArrayCache* inst = new NDArrayCache();
    return inst;
  }

  /*!
   * \brief Insert or replace a cache entry.
   * \param name The key of the array.
   * \param arr The array to store.
   * \param override Whether replacing an existing entry is allowed;
   *        if false, a duplicate name is a fatal error.
   */
  static void Update(String name, NDArray arr, bool override) {
    NDArrayCache* pool = Global();
    if (!override) {
      ICHECK_EQ(pool->pool_.count(name), 0) << "Name " << name << " already exists in the cache";
    }
    pool->pool_.Set(name, arr);
  }

  /*!
   * \brief Look up a cache entry.
   * \param name The key of the array.
   * \return The array, or NullOpt if the name is not present.
   */
  static Optional<NDArray> Get(String name) {
    NDArrayCache* pool = Global();
    auto it = pool->pool_.find(name);
    if (it != pool->pool_.end()) {
      return (*it).second;
    } else {
      return NullOpt;
    }
  }

  /*! \brief Remove an entry by name; no-op if it does not exist. */
  static void Remove(String name) {
    NDArrayCache* pool = Global();
    pool->pool_.erase(name);
  }

  /*! \brief Remove all entries from the cache. */
  static void Clear() { Global()->pool_.clear(); }

  /*!
   * \brief Load parameters from an ndarray-cache directory and append them.
   *
   * Reads `<cache_path>/ndarray-cache.json`, then every raw shard file it
   * references, decoding records stored in "f32-to-bf16" format back to
   * float32. Each decoded array is inserted via Update() with override=true.
   *
   * \param cache_path The path to the cache directory.
   * \param device_type The type of device to load the arrays onto.
   * \param device_id The device id.
   */
  static void Load(const std::string& cache_path, int device_type, int device_id) {
    DLDevice device{static_cast<DLDeviceType>(device_type), device_id};
    std::string json_str;
    LoadBinaryFromFile(cache_path + "/ndarray-cache.json", &json_str);
    picojson::value json_info;
    // Fail loudly on malformed metadata instead of faulting on a bad
    // get<>() access further down.
    std::string parse_err = picojson::parse(json_info, json_str);
    CHECK(parse_err.empty()) << "Failed to parse ndarray-cache.json: " << parse_err;
    auto shard_records = json_info.get<picojson::object>()["records"].get<picojson::array>();

    for (const auto& shard_item : shard_records) {
      auto shard_rec = shard_item.get<picojson::object>();
      ICHECK(shard_rec["dataPath"].is<std::string>());
      std::string data_path = shard_rec["dataPath"].get<std::string>();

      std::string raw_data;
      LoadBinaryFromFile(cache_path + "/" + data_path, &raw_data);
      CHECK_EQ(shard_rec["format"].get<std::string>(), "raw-shard");
      int64_t raw_nbytes = shard_rec["nbytes"].get<int64_t>();
      CHECK_EQ(raw_nbytes, raw_data.length());

      for (const auto& nd_item : shard_rec["records"].get<picojson::array>()) {
        auto nd_rec = nd_item.get<picojson::object>();
        CHECK(nd_rec["name"].is<std::string>());
        String name = nd_rec["name"].get<std::string>();

        std::vector<int64_t> shape;
        for (const auto& value : nd_rec["shape"].get<picojson::array>()) {
          shape.push_back(value.get<int64_t>());
        }

        DataType dtype(String2DLDataType(nd_rec["dtype"].get<std::string>()));
        std::string encode_format = nd_rec["format"].get<std::string>();
        int64_t offset = nd_rec["byteOffset"].get<int64_t>();
        int64_t nbytes = nd_rec["nbytes"].get<int64_t>();
        // Guard against records that point outside the shard buffer.
        CHECK_GE(offset, 0) << "Record " << name << " has negative byteOffset";
        CHECK_LE(offset + nbytes, static_cast<int64_t>(raw_data.length()))
            << "Record " << name << " overflows shard " << data_path;
        NDArray arr = NDArray::Empty(ShapeTuple(shape.begin(), shape.end()), dtype, device);

        if (dtype == DataType::Float(32) && encode_format == "f32-to-bf16") {
          // The stored payload is bf16 (2 bytes per element); widen each
          // element to f32 by shifting it into the high 16 bits.
          std::vector<uint16_t> buffer(nbytes / 2);
          std::vector<uint32_t> decoded(nbytes / 2);
          std::memcpy(buffer.data(), raw_data.data() + offset, nbytes);
          for (size_t i = 0; i < buffer.size(); ++i) {
            decoded[i] = static_cast<uint32_t>(buffer[i]) << 16;
          }
          arr.CopyFromBytes(decoded.data(), decoded.size() * sizeof(uint32_t));
        } else {
          arr.CopyFromBytes(raw_data.data() + offset, nbytes);
        }
        Update(name, arr, true);
      }
    }
  }

 private:
  /*! \brief The underlying name -> array table. */
  Map<String, NDArray> pool_;
};

// Expose the cache operations as VM builtins so they can be invoked
// by name through the PackedFunc registry (e.g. from language bindings).
TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.get").set_body_typed(NDArrayCache::Get);
TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.update").set_body_typed(NDArrayCache::Update);
TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.remove").set_body_typed(NDArrayCache::Remove);
TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.clear").set_body_typed(NDArrayCache::Clear);
TVM_REGISTER_GLOBAL("vm.builtin.ndarray_cache.load").set_body_typed(NDArrayCache::Load);

/*!
 * \brief A runtime module that wraps a fixed array of parameters.
 *
 * Useful to obtain a param dict in RPC mode when the remote side has
 * already loaded the parameters into the NDArrayCache from file.
 */
class ParamModuleNode : public runtime::ModuleNode {
 public:
  const char* type_key() const final { return "param_module"; }

  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
    // Only one function is exposed; anything else resolves to nothing.
    if (name != "get_params") {
      return PackedFunc();
    }
    // Capture a copy of the array so the closure stays valid on its own.
    Array<NDArray> captured = params_;
    return PackedFunc([captured](TVMArgs args, TVMRetValue* rv) { *rv = captured; });
  }

  /*!
   * \brief Collect parameters "<prefix>_0" .. "<prefix>_{num_params-1}"
   *        from the NDArrayCache; fatal error if any is missing.
   */
  static Array<NDArray> GetParams(const std::string& prefix, int num_params) {
    Array<NDArray> collected;
    for (int index = 0; index < num_params; ++index) {
      std::string key = prefix + "_" + std::to_string(index);
      Optional<NDArray> entry = NDArrayCache::Get(key);
      if (entry) {
        collected.push_back(entry.value());
      } else {
        LOG(FATAL) << "Cannot find " << key << " in cache";
      }
    }
    return collected;
  }

  /*! \brief Build a param module holding the named parameters. */
  static Module Create(const std::string& prefix, int num_params) {
    ObjectPtr<ParamModuleNode> node = make_object<ParamModuleNode>();
    node->params_ = GetParams(prefix, num_params);
    return Module(node);
  }

 private:
  /*! \brief The wrapped parameter arrays. */
  Array<NDArray> params_;
};

// Register the param-module helpers as VM builtins.
TVM_REGISTER_GLOBAL("vm.builtin.param_module_from_cache").set_body_typed(ParamModuleNode::Create);
TVM_REGISTER_GLOBAL("vm.builtin.param_array_from_cache").set_body_typed(ParamModuleNode::GetParams);

} // namespace relax_vm
} // namespace runtime
} // namespace tvm
22 changes: 22 additions & 0 deletions tests/python/relax/test_runtime_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
# under the License.
import tvm
import tvm.testing
from tvm.contrib import tvmjs, utils

import pytest
import numpy as np

Expand Down Expand Up @@ -166,5 +168,25 @@ def test_attention_kv_cache():
assert res[i][1] == i


def test_ndarray_cache():
    """Round-trip a param dict through dump_ndarray_cache and the VM loader.

    Float32 entries are stored in "f32-to-bf16" format, so they are compared
    against the value after an explicit bf16 encode/decode round trip.
    """
    load_cache = tvm.get_global_func("vm.builtin.ndarray_cache.load")
    params_from_cache = tvm.get_global_func("vm.builtin.param_array_from_cache")

    param_dict = {
        "x_0": np.array([1, 2, 3], dtype="int32"),
        "x_1": np.random.uniform(size=[10, 20]).astype("float32"),
    }

    temp = utils.tempdir()
    tvmjs.dump_ndarray_cache(param_dict, temp.path, encode_format="f32-to-bf16")
    load_cache(str(temp.path), tvm.cpu().device_type, 0)

    loaded = params_from_cache("x", 2)
    for idx, arr in enumerate(loaded):
        expected = param_dict[f"x_{idx}"]
        if expected.dtype == "float32":
            # f32 params lose precision through the bf16 encoding, so
            # compare against the encode/decode of the original value.
            expected = tvmjs._convert_bf16_to_f32(tvmjs._convert_f32_to_bf16(expected))
        np.testing.assert_allclose(arr.numpy(), expected, atol=1e-6, rtol=1e-6)


if __name__ == "__main__":
tvm.testing.main()
3 changes: 2 additions & 1 deletion web/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
TVM_ROOT=$(shell cd ..; pwd)

INCLUDE_FLAGS = -I$(TVM_ROOT) -I$(TVM_ROOT)/include\
-I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include -I$(TVM_ROOT)/3rdparty/compiler-rt
-I$(TVM_ROOT)/3rdparty/dlpack/include -I$(TVM_ROOT)/3rdparty/dmlc-core/include\
-I$(TVM_ROOT)/3rdparty/compiler-rt -I$(TVM_ROOT)/3rdparty/picojson

.PHONY: clean all rmtypedep preparetest

Expand Down
80 changes: 1 addition & 79 deletions web/emcc/wasm_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
#include "src/runtime/relax_vm/executable.cc"
#include "src/runtime/relax_vm/lm_support.cc"
#include "src/runtime/relax_vm/memory_manager.cc"
#include "src/runtime/relax_vm/ndarray_cache_support.cc"
#include "src/runtime/relax_vm/vm.cc"

// --- Implementations of backend and wasm runtime API. ---
Expand Down Expand Up @@ -120,50 +121,6 @@ TVM_REGISTER_GLOBAL("testing.object_use_count").set_body([](TVMArgs args, TVMRet
*ret = (obj.use_count() - 1);
});

/*!
* A NDArray cache to store pre-loaded arrays in the system.
*/
class NDArrayCache {
public:
static NDArrayCache* Global() {
static NDArrayCache* inst = new NDArrayCache();
return inst;
}

static void Update(String name, NDArray arr, bool override) {
NDArrayCache* pool = Global();
if (!override) {
ICHECK_EQ(pool->pool_.count(name), 0) << "Name " << name << " already exists in the cache";
}
pool->pool_.Set(name, arr);
}

static Optional<NDArray> Get(String name) {
NDArrayCache* pool = Global();
auto it = pool->pool_.find(name);
if (it != pool->pool_.end()) {
return (*it).second;
} else {
return NullOpt;
}
}

static void Remove(String name) {
NDArrayCache* pool = Global();
pool->pool_.erase(name);
}

static void Clear() { Global()->pool_.clear(); }

private:
Map<String, NDArray> pool_;
};

TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.get").set_body_typed(NDArrayCache::Get);
TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.update").set_body_typed(NDArrayCache::Update);
TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.remove").set_body_typed(NDArrayCache::Remove);
TVM_REGISTER_GLOBAL("tvmjs.ndarray_cache.clear").set_body_typed(NDArrayCache::Clear);

void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format) {
if (format == "f32-to-bf16") {
std::vector<uint16_t> buffer(bytes.length() / 2);
Expand All @@ -186,40 +143,5 @@ void ArrayDecodeStorage(NDArray cpu_arr, std::string bytes, std::string format)
}

TVM_REGISTER_GLOBAL("tvmjs.array.decode_storage").set_body_typed(ArrayDecodeStorage);

class ParamModuleNode : public runtime::ModuleNode {
public:
const char* type_key() const final { return "param_module"; }

PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
if (name == "get_params") {
auto params = params_;
return PackedFunc([params](TVMArgs args, TVMRetValue* rv) { *rv = params; });
} else {
return PackedFunc();
}
}

static Module Create(std::string prefix, int num_params) {
Array<NDArray> params;
for (int i = 0; i < num_params; ++i) {
std::string name = prefix + "_" + std::to_string(i);
auto opt = NDArrayCache::Get(name);
if (opt) {
params.push_back(opt.value());
} else {
LOG(FATAL) << "Cannot find " << name << " in cache";
}
}
auto n = make_object<ParamModuleNode>();
n->params_ = params;
return Module(n);
}

private:
Array<NDArray> params_;
};

TVM_REGISTER_GLOBAL("tvmjs.param_module_from_cache").set_body_typed(ParamModuleNode::Create);
} // namespace runtime
} // namespace tvm
10 changes: 5 additions & 5 deletions web/src/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -161,12 +161,12 @@ class RuntimeContext implements Disposable {
this.arrayGetSize = getGlobalFunc("runtime.ArraySize");
this.arrayMake = getGlobalFunc("runtime.Array");
this.getSysLib = getGlobalFunc("runtime.SystemLib");
this.arrayCacheGet = getGlobalFunc("tvmjs.ndarray_cache.get");
this.arrayCacheRemove = getGlobalFunc("tvmjs.ndarray_cache.remove");
this.arrayCacheUpdate = getGlobalFunc("tvmjs.ndarray_cache.update");
this.arrayCacheClear = getGlobalFunc("tvmjs.ndarray_cache.clear");
this.arrayCacheGet = getGlobalFunc("vm.builtin.ndarray_cache.get");
this.arrayCacheRemove = getGlobalFunc("vm.builtin.ndarray_cache.remove");
this.arrayCacheUpdate = getGlobalFunc("vm.builtin.ndarray_cache.update");
this.arrayCacheClear = getGlobalFunc("vm.builtin.ndarray_cache.clear");
this.arrayDecodeStorage = getGlobalFunc("tvmjs.array.decode_storage");
this.paramModuleFromCache = getGlobalFunc("tvmjs.param_module_from_cache");
this.paramModuleFromCache = getGlobalFunc("vm.builtin.param_module_from_cache");
this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple");
this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView");
this.sampleTopPFromLogits = getGlobalFunc("vm.builtin.sample_top_p_from_logits");
Expand Down