diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py index 4cef868cfd72..8d8bd1b0510b 100644 --- a/python/tvm/contrib/tvmjs.py +++ b/python/tvm/contrib/tvmjs.py @@ -17,12 +17,14 @@ """Namespace to store utilities for building web runtime.""" import hashlib import json +import math import os import shutil # pylint: disable=unused-import import sys -from typing import Mapping, Union +from types import GeneratorType +from typing import Iterator, Mapping, Tuple, Union import numpy as np @@ -149,18 +151,25 @@ def pending_nbytes(self): def dump_ndarray_cache( - params: Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]], + params: Union[ + Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]], + Iterator[Tuple[str, Union[np.ndarray, tvm.runtime.NDArray]]], + ], cache_dir: str, encode_format="f32-to-bf16", meta_data=None, shard_cap_mb=32, + show_progress: bool = True, ): """Dump parameters to NDArray cache. Parameters ---------- - params: Mapping[str, tvm.runtime.NDArray], - The parameter dictionary + params: Union[ + Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]], + Iterator[Tuple[str, Union[np.ndarray, tvm.runtime.NDArray]]], + ] + The parameter dictionary or generator cache_dir: str The path to the cache @@ -168,18 +177,22 @@ def dump_ndarray_cache( encode_format: {"f32-to-bf16", "raw"} Encoding format. - meta_data: json-compatible-struct - Extra meta_data to be stored in the cache json file. + meta_data: json-compatible-struct or Callable[[], Any] + Extra meta_data to be stored in the cache json file, + or a callable that returns the metadata. shard_cap_mb: int Maxinum number of MB to be kept per shard + + show_progress: bool + A boolean indicating if to show the dump progress. """ if encode_format not in ("raw", "f32-to-bf16"): raise ValueError(f"Invalie encode_format {encode_format}") - meta_data = {} if meta_data is None else meta_data records = [] - total = len(params) + from_generator = isinstance(params, GeneratorType) + total_bytes = 0 counter = 0 max_out_length = 0 @@ -193,7 +206,8 @@ def dump_ndarray_cache( shard_manager = NDArrayCacheShardingManager(cache_dir, "params_shard", shard_cap_nbytes) - for k, origin_v in params.items(): + param_generator = params.items() if not from_generator else params + for k, origin_v in param_generator: shape = list(origin_v.shape) v = origin_v if not isinstance(v, np.ndarray): @@ -201,6 +215,7 @@ def dump_ndarray_cache( # prefer to preserve original dtype, especially if the format was bfloat16 dtype = str(origin_v.dtype) if isinstance(origin_v, tvm.nd.NDArray) else str(v.dtype) + total_bytes += math.prod(v.shape) * np.dtype(v.dtype).itemsize # convert fp32 to bf16 if encode_format == "f32-to-bf16" and dtype == "float32": @@ -212,12 +227,14 @@ def dump_ndarray_cache( shard_manager.append(data, name=k, shape=shape, dtype=dtype, encode_format=encode_format) counter += 1 - last_cmd = "[%04d/%04d] saving %s" % (counter, total, k) - flush = "\r" + (" " * max_out_length) + "\r" - max_out_length = max(len(last_cmd), max_out_length) - sys.stdout.write(flush + last_cmd) + if show_progress: + last_cmd = "[%04d] saving %s" % (counter, k) + flush = "\r" + (" " * max_out_length) + "\r" + max_out_length = max(len(last_cmd), max_out_length) + sys.stdout.write(flush + last_cmd) records = shard_manager.finish() + meta_data = {} if meta_data is None else meta_data if not callable(meta_data) else meta_data() nd_cache_json = os.path.join(cache_dir, "ndarray-cache.json")