[Contrib] Save/load ShapeTuple in tvm.contrib.tvmjs (#15700)
Lunderberg wants to merge 1 commit into apache:unity.
Conversation
Looking at the background PR, I am not sure we want to hardcode the slice index in the parameters, because in many cases it can be given dynamically by the runtime. A better way could be to provide these particular parameters as part of the runtime builtins, e.g. some form of `get_rank` function, and pass them into the function. Or we can simply have the generated function call the builtin itself to figure out its rank.
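The suggestion above can be sketched in plain Python. Here `get_rank` is a hypothetical stand-in for a runtime builtin (not an existing TVM API): the generated function queries the runtime for its rank instead of reading a hardcoded slice index from the serialized parameters.

```python
# Sketch: the runtime supplies a builtin (here a plain callable) that the
# generated function queries, instead of a slice index baked into the
# saved parameters.  `get_rank` is a hypothetical name for illustration.

def make_shard_loader(get_rank, num_shards):
    """Return a function that slices out this worker's portion of a weight."""
    def load_my_shard(weight):
        rank = get_rank()  # queried at runtime, never serialized
        shard_size = len(weight) // num_shards
        return weight[rank * shard_size:(rank + 1) * shard_size]
    return load_my_shard

# Worker 1 of 2 sees the second half of the weight.
loader = make_shard_loader(get_rank=lambda: 1, num_shards=2)
print(loader([10, 20, 30, 40]))  # [30, 40]
```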
@tqchen For the background PR, the slice index is only included in the saved parameters if two conditions are met.

I don't think we should require a separate channel for passing this information. This is a more general approach that handles any symbolic variable that may need to be passed across the boundary of a segmented compute graph, not just this particular case.

I think there's still room for improvement, as it may be useful to expose these parameters in an inspectable manner (e.g. a JSON dictionary), so that a user can identify which parameters were used to generate a transformed parameter set.
Definitely agree that parameters like these are useful. One thing to note is that they usually need to be decoupled from the weights themselves: in most cases we would like to keep the same set of weights while changing these parameters.

The way to realize such decoupling is to provide mechanisms to pass in a small set of parameters from the app/controller side during runtime, rather than relying on weight-parameter serialization to handle these cases. The disco runtime brought support for such cases.
To expand a bit, since I think the examples are great for illustrating the overall case: there are usually two categories of parameter inputs to a function.

- C0: weight parameters, which are large and can be transformed and set up once.
- C1: configuration parameters, which are small and are supplied at runtime from the app configuration.

A typical function signature should ideally separate C0 and C1:

```python
def f(input, weight_params, config_params):
    pass
```

So we will be able to transform and set up the weight parameters once, while passing `config_params` during runtime from the app configuration.
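The C0/C1 split above can be demonstrated with a toy, pure-Python example (no TVM APIs): the weight transformation runs once, while `config_params` vary per call at runtime.

```python
# Toy illustration of the C0/C1 split: transform weights once, then
# vary the runtime config across calls without touching the weights.

def transform_weights(weight_params):
    # Stand-in for expensive one-time preprocessing (layout change, etc.).
    return [w * 2 for w in weight_params]

def f(inputs, transformed_weights, config_params):
    scale = config_params["scale"]
    return [x * w * scale for x, w in zip(inputs, transformed_weights)]

transformed = transform_weights([1, 2, 3])        # done once
out_a = f([1, 1, 1], transformed, {"scale": 1})   # [2, 4, 6]
out_b = f([1, 1, 1], transformed, {"scale": 10})  # [20, 40, 60]
```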
I agree, generally these parameters would be configurable at run-time. In that case, they would be parameters prior to the transform:

```python
# Initial model definition
class InitialModelDefinition:
    def end_to_end_model(input1, ..., inputN, config1, ..., configN, weight1, ..., weightN):
        ...


# Option 1: Retains separate C0 and C1
class AfterLiftingCompileTimeKnowns:
    def end_to_end_model(input1, ..., inputN, config1, ..., configN, transformed_weights):
        ...

    def compile_time_transform(weight1, ..., weightN) -> R.Tuple:
        ...


# Option 2: Optional initialization, merges C0 and C1
class AfterLiftingInitializationTimeKnowns:
    def end_to_end_model(input1, ..., inputN, transformed_weights_and_config):
        ...

    def compile_time_transform(config1, ..., configN, weight1, ..., weightN) -> R.Tuple:
        ...


# Option 3: C0 is pre-computed at compile-time, C1 is pre-computed during initialization
class AfterLiftingInitializationTimeKnowns:
    def end_to_end_model(input1, ..., inputN, transformed_weights_and_config):
        ...

    def compile_time_transform(weight1, ..., weightN) -> R.Tuple:
        ...

    def initialization_time_transform(config1, ..., configN, transformed_weights) -> R.Tuple:
        ...
```

For this specific PR, by making the […]
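A runnable toy version of the Option 3 pattern (pure Python, with invented function bodies): C0 is folded at compile time, C1 is bound once at initialization, and the end-to-end model only ever sees the fully pre-computed tuple.

```python
# Toy two-stage lifting: weights folded at compile time, config folded
# at initialization time.  Bodies are illustrative, not a TVM pass.

def compile_time_transform(*weights):
    return tuple(w + 1 for w in weights)   # pretend weight preprocessing

def initialization_time_transform(config, transformed_weights):
    # Bind the runtime config to the pre-transformed weights once, at startup.
    return tuple(w * config["multiplier"] for w in transformed_weights)

def end_to_end_model(x, transformed_weights_and_config):
    return x + sum(transformed_weights_and_config)

stage1 = compile_time_transform(1, 2, 3)                            # (2, 3, 4)
stage2 = initialization_time_transform({"multiplier": 10}, stage1)  # (20, 30, 40)
print(end_to_end_model(5, stage2))  # 95
```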
Regarding disco, I agree that the distribution will be handled through the disco runtime, but I don't see any conflict between that and this approach:

```python
# Load weights, then use the disco runtime to perform sharding.  The
# sharding function uses rank_dref to determine which portion of
# each weight to retain.
rank_dref = ...
weights = tvm.nd.array(...)
sharding_function = disco_session.get_func(...)
sharded_weights_dref = sharding_function(weights, rank_dref)
```

```python
# Load pre-sharded weights through the disco runtime.  The load-weights
# function uses rank_dref to determine which set of pre-sharded
# weights to load.
rank_dref = ...
load_weights_function = disco_session.get_func(...)
sharded_weights_dref = load_weights_function(rank_dref)
```

Rather than loading all the weights onto one GPU and then sharding them at runtime, using both tools together allows the pre-sharded weights to be loaded only onto the GPU that will use that shard.
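The two flows above can be emulated in plain Python (no real disco APIs; all names are illustrative): runtime sharding of a full weight versus loading a pre-sharded artifact selected by rank. Both produce the same shard, but only the second avoids materializing the full weight on every worker.

```python
# Plain-Python emulation of the two disco flows: shard at runtime vs.
# load a pre-sharded artifact chosen by rank.  Names are invented.

def shard(weight, rank, num_shards):
    """Runtime sharding: every worker starts from the full weight."""
    n = len(weight) // num_shards
    return weight[rank * n:(rank + 1) * n]

PRESHARDED = {0: [1, 2], 1: [3, 4]}  # stand-in for per-rank weight files

def load_presharded(rank):
    """Pre-sharded load: each worker only ever touches its own shard."""
    return PRESHARDED[rank]

full = [1, 2, 3, 4]
assert shard(full, 1, 2) == load_presharded(1) == [3, 4]
```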
In the case where it is desirable to retain C1 at compile-time, perhaps the simplest way is to embed C1 as a compile-time constant object (so it can be embedded in the VM bytecode), without relying on an external format like the ndarray store. The main rationale here is to keep the weight format simple and intended primarily for weights, because they are big and need special consideration. Keeping the weight format limited to weights also means loaders do not need to worry about other cases. The embedded constant section is actually great for the rest of the use-cases, where constants are small and can be handled natively by VM loading.

Coming back to the rank in particular (and not other info): it is indeed useful for each worker to know its rank. This is somewhat different from both C0 and C1, since it is runtime information that can be queried. Perhaps the simplest way is to enable a […]
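The split between the two storage channels described above can be sketched as follows. The dictionary layouts are invented for illustration; this is not the actual TVM VM bytecode or ndarray-store format.

```python
# Toy contrast of the two channels: small C1 constants travel inside
# the compiled artifact, while the external weight store holds only
# the (large) weights, keeping the weight loader simple.

compiled_artifact = {
    "bytecode": "...",                  # opaque program text
    "constants": {"seq_len": 2048},     # C1 embedded at compile time
}
weight_store = {"w0": [0.1] * 4}        # weights only

def load(artifact, store):
    # The weight loader never needs to understand the constants section.
    return {**artifact["constants"], **store}

env = load(compiled_artifact, weight_store)
assert env["seq_len"] == 2048 and len(env["w0"]) == 4
```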
@tqchen, is your primary feedback that you would prefer to see […]? I want to clarify your perspective, as we hope you are not making a stronger argument against the use of LiftTransformParams for build-time parameter manipulation. I think @Lunderberg's comments are focused on the concern that you are making this stronger argument. Because LiftTransformParams may lift symbolic variables, he is concerned that we would be throwing away information present at build time if we didn't have a way to serialize it.

Does this capture both your views correctly?
@csullivan That's an accurate statement for me. I've been viewing […]
Thanks for the clarification. I think LiftTransformParams is useful. The discussion here likely suggests that: […]
Force-pushed from d23f8b1 to 8674378.
Prior to this commit, only tensor parameters could be saved and loaded using `tvm.contrib.tvmjs`. This commit extends the functionality to also support `ShapeTuple`. This is intended for use alongside the `LiftTransformParams` functionality introduced in apache#15699.
Force-pushed from 8674378 to 67d729e.
Rebased onto main to keep various dev branches up-to-date. The need for this change may have been obsoleted by #15957, but I need to test locally to check.
Force-pushed from c95d45f to 45eeb8c.
Prior to this commit, only tensor parameters could be saved and loaded using `tvm.contrib.tvmjs`. This commit extends the functionality to also support `ShapeTuple`. This is intended for use alongside the `LiftTransformParams` functionality introduced in #15699.
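The extension described in the commit message can be illustrated with a toy cache format. The schema below is invented here and differs from the real `tvm.contrib.tvmjs` ndarray-cache layout; it only shows the idea of shape-tuple entries stored alongside tensor records.

```python
import json

# Invented miniature of a parameter cache holding both tensor records
# and shape tuples.  The real tvmjs on-disk format is different.
records = [
    {"name": "w0", "kind": "ndarray", "shape": [2, 2], "dtype": "float32"},
    {"name": "slice_index", "kind": "shape_tuple", "value": [1]},
]
blob = json.dumps({"records": records})

loaded = json.loads(blob)["records"]
shapes = {r["name"]: tuple(r["value"]) for r in loaded if r["kind"] == "shape_tuple"}
assert shapes["slice_index"] == (1,)
```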