Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 63 additions & 33 deletions python/tvm/contrib/tvmjs.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def pending_nbytes(self):


def dump_ndarray_cache(
params: Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
params: Mapping[str, Union[tuple, np.ndarray, tvm.runtime.NDArray]],
cache_dir: str,
encode_format="f32-to-bf16",
meta_data=None,
Expand All @@ -145,14 +145,18 @@ def dump_ndarray_cache(

Parameters
----------
params: Mapping[str, tvm.runtime.NDArray],
The parameter dictionary
params: Mapping[str, Union[tuple, np.ndarray, tvm.runtime.NDArray]],
The parameter dictionary. Array parameters should be passed
as np.ndarray or tvm.runtime.NDArray, and use the encoding
format specified by the "encode_format" parameter. Tuple
parameters, typically ShapeExpr, must contain only integer
values, and are encoded as raw int64 arrays.

cache_dir: str
The path to the cache

encode_format: {"f32-to-bf16", "raw"}
Encoding format.
Encoding format for array parameters.

meta_data: json-compatible-struct
Extra meta_data to be stored in the cache json file.
Expand All @@ -179,26 +183,36 @@ def dump_ndarray_cache(

shard_manager = NDArrayCacheShardingManager(cache_dir, "params_shard", shard_cap_nbytes)

for k, origin_v in params.items():
shape = list(origin_v.shape)
v = origin_v
if not isinstance(v, np.ndarray):
v = v.numpy()
for name, param in params.items():
if isinstance(param, tvm.runtime.ShapeTuple):
assert all(
isinstance(element, int) for element in param
), "Encoded shape tuple must have integer elements"
param = np.array(param, dtype="int64")
item_encode_format = "raw-shape-tuple"
else:
item_encode_format = encode_format

# prefer to preserve original dtype, especially if the format was bfloat16
dtype = str(origin_v.dtype) if isinstance(origin_v, tvm.nd.NDArray) else str(v.dtype)
# Grab dtype before type conversion to preserve original
# dtype, especially if the format was bfloat16
dtype = str(param.dtype)
if not isinstance(param, np.ndarray):
param = param.numpy()

# convert fp32 to bf16
if encode_format == "f32-to-bf16" and dtype == "float32":
data = _convert_f32_to_bf16(v).tobytes()
data = _convert_f32_to_bf16(param).tobytes()
f32_to_bf16_triggered = True
else:
data = v.tobytes()
data = param.tobytes()

shard_manager.append(data, name=k, shape=shape, dtype=dtype, encode_format=encode_format)
shape = list(param.shape)
shard_manager.append(
data, name=name, shape=shape, dtype=dtype, encode_format=item_encode_format
)

counter += 1
last_cmd = "[%04d/%04d] saving %s" % (counter, total, k)
last_cmd = "[%04d/%04d] saving %s" % (counter, total, name)
flush = "\r" + (" " * max_out_length) + "\r"
max_out_length = max(len(last_cmd), max_out_length)
sys.stdout.write(flush + last_cmd)
Expand Down Expand Up @@ -242,6 +256,39 @@ def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device):
if not cachepath.endswith(".json"):
cachepath = os.path.join(cachepath, "ndarray-cache.json")

def _unpack_record(rec, raw_data):
shape = rec["shape"]
dtype = rec["dtype"]
encode_format = rec["format"]
offset = rec["byteOffset"]
nbytes = rec["nbytes"]

assert offset + nbytes <= len(raw_data)
buffer_source = raw_data[offset : offset + nbytes]

if encode_format == "raw-shape-tuple":
assert len(shape) == 1, "Shape tuple must be 1-d"
assert dtype == "int64", "Shape tuple must have int64 datatype"
data = np.frombuffer(buffer_source, dtype=dtype).reshape(shape)
return tvm.runtime.ShapeTuple(data.tolist())

arr = tvm.nd.empty(shape, dtype, device=device)
if encode_format == "f32-to-bf16" and dtype == "float32":
data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape)
arr.copyfrom(_convert_bf16_to_f32(data))
elif encode_format == "raw" and dtype == "bfloat16":
data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape)
arr.copyfrom(data)
elif encode_format == "raw":
data = np.frombuffer(buffer_source, dtype=dtype).reshape(shape)
arr.copyfrom(data)
else:
raise ValueError(
f"Unknown combination of encode format and dtype: ({encode_format}, {dtype})"
)

return arr

cachedir = os.path.dirname(cachepath)
json_info = json.loads(open(cachepath, "r").read())
result_dict = {}
Expand All @@ -255,24 +302,7 @@ def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device):

for rec in shard_rec["records"]:
name = rec["name"]
shape = rec["shape"]
dtype = rec["dtype"]
encode_format = rec["format"]
offset = rec["byteOffset"]
nbytes = rec["nbytes"]

arr = tvm.nd.empty(shape, dtype, device=device)
assert offset + nbytes <= len(raw_data)
buffer_source = raw_data[offset : offset + nbytes]
if encode_format == "f32-to-bf16" and dtype == "float32":
data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape)
arr.copyfrom(_convert_bf16_to_f32(data))
elif dtype == "bfloat16":
data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape)
arr.copyfrom(data)
else:
data = np.frombuffer(buffer_source, dtype=dtype).reshape(shape)
arr.copyfrom(data)
arr = _unpack_record(rec, raw_data)
result_dict[name] = arr
return result_dict, json_info["metadata"]

Expand Down