Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Contrib] Support NDArray cache taking generator #16693

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions python/tvm/contrib/tvmjs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
"""Namespace to store utilities for building web runtime."""
import hashlib
import json
import math
import os
import shutil

# pylint: disable=unused-import
import sys
from typing import Mapping, Union
from types import GeneratorType
from typing import Iterator, Mapping, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -149,37 +151,48 @@ def pending_nbytes(self):


def dump_ndarray_cache(
params: Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
params: Union[
Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
Iterator[Tuple[str, Union[np.ndarray, tvm.runtime.NDArray]]],
],
cache_dir: str,
encode_format="f32-to-bf16",
meta_data=None,
shard_cap_mb=32,
show_progress: bool = True,
):
"""Dump parameters to NDArray cache.
Parameters
----------
params: Mapping[str, tvm.runtime.NDArray],
The parameter dictionary
params: Union[
Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
Iterator[Tuple[str, Union[np.ndarray, tvm.runtime.NDArray]]],
]
The parameter dictionary or generator
cache_dir: str
The path to the cache
encode_format: {"f32-to-bf16", "raw"}
Encoding format.
meta_data: json-compatible-struct
Extra meta_data to be stored in the cache json file.
meta_data: json-compatible-struct or Callable[[], Any]
MasterJH5574 marked this conversation as resolved.
Show resolved Hide resolved
Extra meta_data to be stored in the cache json file,
or a callable that returns the metadata.
shard_cap_mb: int
Maximum number of MB to be kept per shard
show_progress: bool
A boolean indicating whether to show the dump progress.
"""
if encode_format not in ("raw", "f32-to-bf16"):
raise ValueError(f"Invalie encode_format {encode_format}")

meta_data = {} if meta_data is None else meta_data
records = []
total = len(params)
from_generator = isinstance(params, GeneratorType)
total_bytes = 0
counter = 0
max_out_length = 0

Expand All @@ -193,14 +206,16 @@ def dump_ndarray_cache(

shard_manager = NDArrayCacheShardingManager(cache_dir, "params_shard", shard_cap_nbytes)

for k, origin_v in params.items():
param_generator = params.items() if not from_generator else params
for k, origin_v in param_generator:
shape = list(origin_v.shape)
v = origin_v
if not isinstance(v, np.ndarray):
v = v.numpy()

# prefer to preserve original dtype, especially if the format was bfloat16
dtype = str(origin_v.dtype) if isinstance(origin_v, tvm.nd.NDArray) else str(v.dtype)
total_bytes += math.prod(v.shape) * np.dtype(v.dtype).itemsize

# convert fp32 to bf16
if encode_format == "f32-to-bf16" and dtype == "float32":
Expand All @@ -212,12 +227,14 @@ def dump_ndarray_cache(
shard_manager.append(data, name=k, shape=shape, dtype=dtype, encode_format=encode_format)

counter += 1
last_cmd = "[%04d/%04d] saving %s" % (counter, total, k)
flush = "\r" + (" " * max_out_length) + "\r"
max_out_length = max(len(last_cmd), max_out_length)
sys.stdout.write(flush + last_cmd)
if show_progress:
last_cmd = "[%04d] saving %s" % (counter, k)
flush = "\r" + (" " * max_out_length) + "\r"
max_out_length = max(len(last_cmd), max_out_length)
sys.stdout.write(flush + last_cmd)

records = shard_manager.finish()
meta_data = {} if meta_data is None else meta_data if not callable(meta_data) else meta_data()

nd_cache_json = os.path.join(cache_dir, "ndarray-cache.json")

Expand Down
Loading