Skip to content

Commit

Permalink
[Contrib] Support NDArray cache taking generator (apache#16693)
Browse files Browse the repository at this point in the history
This PR enhances the `dump_ndarray_cache` function to accept a
generator as input. Previously, it could only take a dictionary.

Sometimes the total NDArray size cannot fit in main CPU memory,
in which case we can turn to generators so that some NDArray
memory can be freed on the fly.
This PR adds support for dumping the NDArray cache from generators.
  • Loading branch information
MasterJH5574 authored and Lunderberg committed Mar 12, 2024
1 parent 9ab4059 commit 426b639
Showing 1 changed file with 30 additions and 13 deletions.
43 changes: 30 additions & 13 deletions python/tvm/contrib/tvmjs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@
"""Namespace to store utilities for building web runtime."""
import hashlib
import json
import math
import os
import shutil

# pylint: disable=unused-import
import sys
from typing import Mapping, Union
from types import GeneratorType
from typing import Iterator, Mapping, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -149,37 +151,48 @@ def pending_nbytes(self):


def dump_ndarray_cache(
params: Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
params: Union[
Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
Iterator[Tuple[str, Union[np.ndarray, tvm.runtime.NDArray]]],
],
cache_dir: str,
encode_format="f32-to-bf16",
meta_data=None,
shard_cap_mb=32,
show_progress: bool = True,
):
"""Dump parameters to NDArray cache.
Parameters
----------
params: Mapping[str, tvm.runtime.NDArray],
The parameter dictionary
params: Union[
Mapping[str, Union[np.ndarray, tvm.runtime.NDArray]],
Iterator[Tuple[str, Union[np.ndarray, tvm.runtime.NDArray]]],
]
The parameter dictionary or generator
cache_dir: str
The path to the cache
encode_format: {"f32-to-bf16", "raw"}
Encoding format.
meta_data: json-compatible-struct
Extra meta_data to be stored in the cache json file.
meta_data: json-compatible-struct or Callable[[], Any]
Extra meta_data to be stored in the cache json file,
or a callable that returns the metadata.
shard_cap_mb: int
Maxinum number of MB to be kept per shard
show_progress: bool
A boolean indicating if to show the dump progress.
"""
if encode_format not in ("raw", "f32-to-bf16"):
raise ValueError(f"Invalie encode_format {encode_format}")

meta_data = {} if meta_data is None else meta_data
records = []
total = len(params)
from_generator = isinstance(params, GeneratorType)
total_bytes = 0
counter = 0
max_out_length = 0

Expand All @@ -193,14 +206,16 @@ def dump_ndarray_cache(

shard_manager = NDArrayCacheShardingManager(cache_dir, "params_shard", shard_cap_nbytes)

for k, origin_v in params.items():
param_generator = params.items() if not from_generator else params
for k, origin_v in param_generator:
shape = list(origin_v.shape)
v = origin_v
if not isinstance(v, np.ndarray):
v = v.numpy()

# prefer to preserve original dtype, especially if the format was bfloat16
dtype = str(origin_v.dtype) if isinstance(origin_v, tvm.nd.NDArray) else str(v.dtype)
total_bytes += math.prod(v.shape) * np.dtype(v.dtype).itemsize

# convert fp32 to bf16
if encode_format == "f32-to-bf16" and dtype == "float32":
Expand All @@ -212,12 +227,14 @@ def dump_ndarray_cache(
shard_manager.append(data, name=k, shape=shape, dtype=dtype, encode_format=encode_format)

counter += 1
last_cmd = "[%04d/%04d] saving %s" % (counter, total, k)
flush = "\r" + (" " * max_out_length) + "\r"
max_out_length = max(len(last_cmd), max_out_length)
sys.stdout.write(flush + last_cmd)
if show_progress:
last_cmd = "[%04d] saving %s" % (counter, k)
flush = "\r" + (" " * max_out_length) + "\r"
max_out_length = max(len(last_cmd), max_out_length)
sys.stdout.write(flush + last_cmd)

records = shard_manager.finish()
meta_data = {} if meta_data is None else meta_data if not callable(meta_data) else meta_data()

nd_cache_json = os.path.join(cache_dir, "ndarray-cache.json")

Expand Down

0 comments on commit 426b639

Please sign in to comment.