diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py
index c146427a9763..1f1d4a2246d0 100644
--- a/python/tvm/contrib/tvmjs.py
+++ b/python/tvm/contrib/tvmjs.py
@@ -135,7 +135,7 @@ def pending_nbytes(self):
 
 
 def dump_ndarray_cache(
-    params: Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
+    params: Mapping[str, Union[tuple, np.ndarray, tvm.runtime.NDArray]],
     cache_dir: str,
     encode_format="f32-to-bf16",
     meta_data=None,
@@ -145,14 +145,18 @@ def dump_ndarray_cache(
 
     Parameters
     ----------
-    params: Mapping[str, tvm.runtime.NDArray],
-        The parameter dictionary
+    params: Mapping[str, Union[tuple, np.ndarray, tvm.runtime.NDArray]],
+        The parameter dictionary. Array parameters should be passed
+        as np.ndarray or tvm.runtime.NDArray, and use the encoding
+        format specified by the "encode_format" parameter. Tuple
+        parameters, typically ShapeExpr, must contain only integer
+        values, and are encoded as raw int64 arrays.
 
     cache_dir: str
         The path to the cache
 
     encode_format: {"f32-to-bf16", "raw"}
-        Encoding format.
+        Encoding format for array parameters.
 
     meta_data: json-compatible-struct
         Extra meta_data to be stored in the cache json file.
@@ -179,26 +183,36 @@ def dump_ndarray_cache(
 
     shard_manager = NDArrayCacheShardingManager(cache_dir, "params_shard", shard_cap_nbytes)
 
-    for k, origin_v in params.items():
-        shape = list(origin_v.shape)
-        v = origin_v
-        if not isinstance(v, np.ndarray):
-            v = v.numpy()
+    for name, param in params.items():
+        if isinstance(param, (tuple, tvm.runtime.ShapeTuple)):
+            assert all(
+                isinstance(element, int) for element in param
+            ), "Encoded shape tuple must have integer elements"
+            param = np.array(param, dtype="int64")
+            item_encode_format = "raw-shape-tuple"
+        else:
+            item_encode_format = encode_format
 
-        # prefer to preserve original dtype, especially if the format was bfloat16
-        dtype = str(origin_v.dtype) if isinstance(origin_v, tvm.nd.NDArray) else str(v.dtype)
+        # Grab dtype before type conversion to preserve original
+        # dtype, especially if the format was bfloat16
+        dtype = str(param.dtype)
+        if not isinstance(param, np.ndarray):
+            param = param.numpy()
 
         # convert fp32 to bf16
         if encode_format == "f32-to-bf16" and dtype == "float32":
-            data = _convert_f32_to_bf16(v).tobytes()
+            data = _convert_f32_to_bf16(param).tobytes()
             f32_to_bf16_triggered = True
         else:
-            data = v.tobytes()
+            data = param.tobytes()
 
-        shard_manager.append(data, name=k, shape=shape, dtype=dtype, encode_format=encode_format)
+        shape = list(param.shape)
+        shard_manager.append(
+            data, name=name, shape=shape, dtype=dtype, encode_format=item_encode_format
+        )
 
         counter += 1
-        last_cmd = "[%04d/%04d] saving %s" % (counter, total, k)
+        last_cmd = "[%04d/%04d] saving %s" % (counter, total, name)
         flush = "\r" + (" " * max_out_length) + "\r"
         max_out_length = max(len(last_cmd), max_out_length)
         sys.stdout.write(flush + last_cmd)
@@ -242,6 +256,46 @@ def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device):
     if not cachepath.endswith(".json"):
         cachepath = os.path.join(cachepath, "ndarray-cache.json")
 
+    def _unpack_record(rec, raw_data):
+        """Decode one cache record from the raw bytes of its shard.
+
+        Returns a tvm.runtime.ShapeTuple for "raw-shape-tuple" records,
+        and an NDArray allocated on the enclosing `device` otherwise.
+        """
+        shape = rec["shape"]
+        dtype = rec["dtype"]
+        encode_format = rec["format"]
+        offset = rec["byteOffset"]
+        nbytes = rec["nbytes"]
+
+        assert offset + nbytes <= len(raw_data)
+        buffer_source = raw_data[offset : offset + nbytes]
+
+        if encode_format == "raw-shape-tuple":
+            assert len(shape) == 1, "Shape tuple must be 1-d"
+            assert dtype == "int64", "Shape tuple must have int64 datatype"
+            data = np.frombuffer(buffer_source, dtype=dtype).reshape(shape)
+            return tvm.runtime.ShapeTuple(data.tolist())
+
+        # dump_ndarray_cache passes the cache-wide encode_format for
+        # every array, so a record may be tagged "f32-to-bf16" even
+        # though its dtype is not float32 and its bytes were written
+        # unconverted. Only (f32-to-bf16, float32) needs decoding;
+        # any other combination is read back as-is.
+        arr = tvm.nd.empty(shape, dtype, device=device)
+        if encode_format == "f32-to-bf16" and dtype == "float32":
+            data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape)
+            arr.copyfrom(_convert_bf16_to_f32(data))
+        elif dtype == "bfloat16":
+            # bfloat16 data is read through a uint16 view.
+            data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape)
+            arr.copyfrom(data)
+        else:
+            data = np.frombuffer(buffer_source, dtype=dtype).reshape(shape)
+            arr.copyfrom(data)
+
+        return arr
+
     cachedir = os.path.dirname(cachepath)
     json_info = json.loads(open(cachepath, "r").read())
     result_dict = {}
@@ -255,24 +309,7 @@ def load_ndarray_cache(cachepath: str, device: tvm.runtime.Device):
 
         for rec in shard_rec["records"]:
             name = rec["name"]
-            shape = rec["shape"]
-            dtype = rec["dtype"]
-            encode_format = rec["format"]
-            offset = rec["byteOffset"]
-            nbytes = rec["nbytes"]
-
-            arr = tvm.nd.empty(shape, dtype, device=device)
-            assert offset + nbytes <= len(raw_data)
-            buffer_source = raw_data[offset : offset + nbytes]
-            if encode_format == "f32-to-bf16" and dtype == "float32":
-                data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape)
-                arr.copyfrom(_convert_bf16_to_f32(data))
-            elif dtype == "bfloat16":
-                data = np.frombuffer(buffer_source, dtype="uint16").reshape(shape)
-                arr.copyfrom(data)
-            else:
-                data = np.frombuffer(buffer_source, dtype=dtype).reshape(shape)
-                arr.copyfrom(data)
+            arr = _unpack_record(rec, raw_data)
             result_dict[name] = arr
 
     return result_dict, json_info["metadata"]