[Contrib] Save/load ShapeTuple in tvm.contrib.tvmjs#15700

Closed
Lunderberg wants to merge 1 commit into apache:unity from Lunderberg:unity_save_load_shapetuple

Conversation

@Lunderberg
Contributor

Prior to this commit, only tensor parameters could be saved and loaded using tvm.contrib.tvmjs. This commit extends the functionality to also support ShapeTuple. This is intended for use alongside the LiftTransformParams functionality introduced in #15699.
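As a rough illustration of the extension (a plain-Python sketch with invented helper names and format, not the actual `tvm.contrib.tvmjs` NDArray-cache layout), the saved metadata can tag each record with its type so the loader knows whether to rebuild a tensor or a `ShapeTuple`:

```python
import json

# Toy sketch of a parameter store whose metadata records both
# tensor-like entries and shape tuples.  Record names and layout are
# illustrative only.
def save_params(params):
    records = []
    for name, value in params.items():
        if isinstance(value, tuple):  # stands in for tvm.runtime.ShapeTuple
            records.append({"name": name, "type": "ShapeTuple",
                            "data": list(value)})
        else:  # stands in for an NDArray payload
            records.append({"name": name, "type": "Tensor",
                            "data": value})
    return json.dumps({"records": records})

def load_params(blob):
    out = {}
    for rec in json.loads(blob)["records"]:
        if rec["type"] == "ShapeTuple":
            out[rec["name"]] = tuple(rec["data"])
        else:
            out[rec["name"]] = rec["data"]
    return out
```

With this shape of metadata, a round-trip preserves both the tensor payloads and the shape tuples, which is the behavior this PR adds to the real cache format.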

@tqchen
Member

tqchen commented Sep 8, 2023

Looking at the background PR, I am not sure we want to hardcode the slice index in the parameters, because in many cases it can be given dynamically by the runtime.

I think a better way could be to provide these particular values as part of the runtime builtins, e.g. some form of get_rank function, and then pass them into the function. Or we could simply have the generated function call the builtin to figure out its rank.

@Lunderberg
Contributor Author

@tqchen For the background PR, the slice index is only included in the saved parameters if two conditions are met.

  1. The symbolic variable is required for a computation that occurs during runtime.
  2. The symbolic variable cannot be inferred at runtime, either from runtime inputs or from the tensor parameters produced by the lifted *_transform_params function.

I don't think we should require a separate channel for passing this information. This is a more general approach that handles any symbolic variable that may need to be passed across the boundary of a segmented compute graph, not just a rank variable. If we have other variables that we want to enclose with the parameters, such as a max_seq_len, lora_scaling or temperature, this approach would allow them to be included with the associated parameter set.

I think there's still room for improvement, as it may be useful to expose these parameters in an inspectable manner (e.g. a JSON dictionary), such that a user can identify which parameters were used to generate a transformed parameter set.

@tqchen
Member

tqchen commented Sep 8, 2023

Definitely agree that parameters like lora_scaling and temperature need to be passed in and configurable.

One thing to note is that these are usually parameters that need to be decoupled from the weights themselves. In most cases, we would like to keep the same set of weights and change temperature/lora_scaling in the application without recomputing the weight parameter set.

The way to realize such decoupling is to enable mechanisms to pass in a small set of parameters from the app/controller side during runtime, rather than relying on weight-parameter serialization to handle these cases. The disco runtime brought support for such cases.

@tqchen
Member

tqchen commented Sep 8, 2023

To expand a bit, since I think the examples are great for illustrating the overall case: there are usually two categories of parameter inputs to a function

  • C0: weights, usually fixed for all applications during inference
  • C1: config settings (temperature, scaling factor) that are changeable during inference.

A typical function signature should ideally separate C0 and C1:

def f(input, weight_params, config_params):
    pass

So we will be able to transform and set up the weight parameters once, while passing config_params at runtime from the app configuration.
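This bind-weights-once, vary-config-per-call pattern can be sketched in plain Python (illustrative only; `functools.partial` stands in for the one-time weight-binding step that a compiled module would perform):

```python
from functools import partial

# C0/C1 split: weights are bound once, config stays a free runtime argument.
def f(inputs, weight_params, config_params):
    scale = config_params["temperature"]
    return [x * w * scale for x, w in zip(inputs, weight_params)]

# Transform and bind the weights once (the C0 step)...
weights = [2.0, 3.0]
model = partial(f, weight_params=weights)

# ...then vary config (C1) per call without touching the weights.
out_cold = model([1.0, 1.0], config_params={"temperature": 0.5})
out_hot = model([1.0, 1.0], config_params={"temperature": 1.0})
```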

@Lunderberg
Contributor Author

I agree, generally these parameters would be configurable at run-time. In that case, they would be parameters prior to the "num_input" used by LiftTransformParams to identify liftable parameters. However, we should have the ability to bind any parameter at any stage in the lowering. This enables the most common case, where C0 is provided at compile-time and C1 is provided at run-time, but also enables C1 to be provided at compile-time, or at a model-initialization step.

# Initial model definition
class InitialModelDefinition:
    def end_to_end_model(input1, ..., inputN, config1, ..., configN, weight1, ..., weightN):
        ...

# Option 1: Retains separate C0 and C1
class AfterLiftingCompileTimeKnowns:
    def end_to_end_model(input1, ..., inputN, config1, ..., configN, transformed_weights):
        ...

    def compile_time_transform(weight1, ..., weightN) -> R.Tuple:
        ...

# Option 2: Optional initialization, merges C0 and C1
class AfterLiftingInitializationTimeKnowns:
    def end_to_end_model(input1, ..., inputN, transformed_weights_and_config):
        ...

    def compile_time_transform(config1, ..., configN, weight1, ..., weightN) -> R.Tuple:
        ...

# Option 3: C0 is pre-computed at compile-time, C1 is pre-computed during initialization
class AfterLiftingCompileAndInitializationTimeKnowns:
    def end_to_end_model(input1, ..., inputN, transformed_weights_and_config):
        ...

    def compile_time_transform(weight1, ..., weightN) -> R.Tuple:
        ...

    def initialization_time_transform(config1, ...,  configN, transformed_weights) -> R.Tuple:
        ...

For this specific PR, by making tvmjs a general utility to save/restore the output of a relax function, rather than a specific utility to save/restore only relax tensors, we avoid locking ourselves into a specific use case.

@Lunderberg
Contributor Author

Regarding disco, I agree that the distribution will be handled through the disco runtime, but I don't see any conflict between that and the LiftTransformParams approach. By using both at the same time, we can have a simpler and more efficient startup.

# Load weights, then use disco runtime to perform sharding.  The
# sharding function uses the rank_dref to determine which portion of
# each weight to retain.
rank_dref = ...
weights = tvm.nd.array(...)
sharding_function = disco_session.get_func(...)
sharded_weights_dref = sharding_function(weights, rank_dref)

# Load sharded weights through disco runtime.  The load weights
# function uses the rank_dref to determine which set of pre-sharded
# weights to load.
rank_dref = ...
load_weights_function = disco_session.get_func(...)
sharded_weights_dref = load_weights_function(rank_dref)

Rather than loading all the weights onto one GPU, then sharding them at runtime, using both tools together allows us to have the pre-sharded weights loaded only onto the GPU that will be using that shard.
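The difference between the two flows above can be modeled with a minimal sketch (plain Python with invented helpers; the real flow goes through disco sessions and DRefs):

```python
# Toy model of the "load only your shard" flow: the weights are
# pre-sharded once at parameter-transform time, and each worker then
# reads just the slice it owns, instead of loading everything onto one
# device and sharding afterwards.
def shard_weights(weights, num_workers):
    """Pre-shard at transform time: one contiguous chunk per rank."""
    n = len(weights)
    chunk = (n + num_workers - 1) // num_workers
    return {rank: weights[rank * chunk:(rank + 1) * chunk]
            for rank in range(num_workers)}

def load_shard(shards, rank):
    """What each worker would run at startup: fetch only its own shard."""
    return shards[rank]
```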

@tqchen
Member

tqchen commented Sep 8, 2023

In the case where it is desirable to retain C1 at compile-time, perhaps the simplest way is to embed C1 as a compile-time constant object (so it can be embedded in the VM bytecode) without relying on an external format like the ndarray store.

The main rationale here is to keep the weight format simple and intended primarily for weights, because they are big and need some special considerations. Keeping the weight format limited to weights also means the loader does not need to worry about other cases. The embedded constant section is actually great for the rest of the use cases, where constants are small and can be handled natively by VM loading.

Coming back to the rank in particular (and not other info): indeed it is useful for each worker to know its rank. This is somewhat different from both C0 and C1, since it is runtime information that can be queried. Perhaps the simplest way is to enable an R.disco.get_rank builtin that can be queried by the loader, so it loads the respective pre-sharded data.
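A minimal sketch of this rank-as-runtime-builtin idea (plain Python; `get_rank` here is a hypothetical stand-in for an R.disco.get_rank-style builtin, not an existing TVM API):

```python
import threading

# Per-worker state, set by the runtime when a worker starts, queried by
# builtins rather than serialized alongside the weights.
_worker_state = threading.local()

def set_rank(rank):
    _worker_state.rank = rank

def get_rank():
    return _worker_state.rank

def load_weights_for_this_worker(shard_files):
    # The loader consults the builtin instead of a serialized slice index.
    return shard_files[get_rank()]
```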

@csullivan
Contributor

@tqchen, is your primary feedback that you would prefer to see the rank/slice index determined from the runtime (e.g., as you suggest, R.disco.get_rank), as opposed to loading it from the serialized params file, in order to keep the weight format simple?

I want to clarify your perspective, as we are hoping you are not making a stronger argument about the use of LiftTransformParams for build-time parameter manipulation.

I think @Lunderberg's comments are focused on the concern that you are making this stronger argument. Because LiftTransformParams may lift symbolic variables, he is concerned that we would be throwing away information present at build time if we didn't have a way to serialize it. For rank, we can agree this could be derived from the runtime. For other cases, you are suggesting we can put this information into the embedded constant sections.

Does this capture both your views correctly?

@Lunderberg
Contributor Author

@csullivan That's an accurate statement for me. I've been viewing LiftTransformParams as a general utility to segment a computational graph, and tvm.contrib.tvmjs as a general library to save/load a precomputed fragment of the computational graph. Since a graph could require passing tensors, symbolic variables, or both across the boundary of a partition, I wanted the general utilities to support the general case.

@tqchen
Member

tqchen commented Sep 8, 2023

Thanks for the clarification. I think LiftTransformParams is useful. The discussion here likely suggests that:

  • LiftTransformParams should ideally be able to separate C0 and C1, so runtime-dependent variables can be passed as a separate category
  • When they are available at compile time, they can be embedded in the constant section.

@Lunderberg Lunderberg force-pushed the unity_save_load_shapetuple branch from d23f8b1 to 8674378 on September 13, 2023 15:07
@Lunderberg Lunderberg force-pushed the unity_save_load_shapetuple branch from 8674378 to 67d729e on November 9, 2023 16:12
@Lunderberg
Copy link
Contributor Author

Rebased onto main to keep various dev branches up-to-date. The need for this change may have been obsoleted by #15957, but I need to test locally to check.

@junrushao junrushao force-pushed the unity branch 2 times, most recently from c95d45f to 45eeb8c on December 18, 2023 21:00
@tqchen tqchen deleted the branch apache:unity March 29, 2024 12:18
@tqchen tqchen closed this Mar 29, 2024