[Disco] Loading-time sharding support#15826
Merged
junrushao merged 1 commit into apache:unity on Oct 3, 2023
Conversation
Member
CC: @jinhongyii
jinhongyii
reviewed
Sep 27, 2023
junrushao
requested changes
Sep 28, 2023
Member
junrushao
left a comment
Let's unify sharding and reordering into a single function - no need for shard3d any more :)
Contributor
Can we separate the reorganizing function and the sharding function again? I'm thinking of a situation where there can be 2 reorganizing functions (doing nothing, and combining QKV) and n different sharding functions (different tensor shapes and different sharding dimensions). In this case, we would need at most 2n functions in total in the IRModule. Another benefit of separating them is that DistIR will not need to handle merging the reorganizing function and the sharding function.
junrushao
approved these changes
Oct 1, 2023
jinhongyii
approved these changes
Oct 1, 2023
In our previous implementation, parameter sharding relied on pre-quantization weight processing, meaning each set of quantized weights corresponded strictly to a hardcoded constant `num_shards`, and re-quantization was required on every change in the number of GPUs, e.g. moving from a 4-GPU to an 8-GPU setting. This PR makes it possible to move parameter sharding to post-quantization loading time. During loading, we iterate over all parameters and apply the sharding operation based on the provided sharding information.
To make this happen, this PR enhances the existing `shard_info.json` to include the sharding function to be used at loading time. Each parameter is attached to a list of loading-time preprocessing methods that are applied serially to transform the parameter into the desired shape, as shown in the example below:
```python
shard_info = {
    "x_0": [  # name of the parameter
        [  # a list of preprocessing functions to be applied
            "tests.disco.shard_dim_1",          # name of the sharding function
            [(num_shards, 64, 64), "float16"],  # output shape/dtype of `tests.disco.shard_dim_1`
            num_shards,                         # extra inputs to `tests.disco.shard_dim_1`
        ],
    ],
    "x_1": [...],
}
```
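As a rough sketch (not the actual TVM loader code), the loading-time iteration described above could look like the following. Here `lookup_func` and the in-memory parameter dict are hypothetical stand-ins, and NumPy arrays replace TVM NDArrays:

```python
import numpy as np

def apply_shard_info(params, shard_info, lookup_func):
    """Apply each parameter's chain of preprocessing functions in order.

    params:      dict mapping parameter name -> array (stand-in for NDArray)
    shard_info:  dict in the format shown above
    lookup_func: callable mapping a function name -> the registered function
    """
    out = {}
    for name, funcs in shard_info.items():
        tensor = params[name]
        for func_name, (shape, dtype), *extra in funcs:
            # Pre-allocate the destination buffer with the declared shape/dtype,
            # then let the preprocessing function write into it.
            output = np.empty(shape, dtype=dtype)
            lookup_func(func_name)(tensor, *extra, output)
            tensor = output  # feed the result into the next function in the chain
        out[name] = tensor
    return out
```

Each entry drives one preprocessing step: the declared output shape/dtype tells the loader how much memory to allocate, and any extra inputs are forwarded to the function.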
For parameter `x_0`, this means we will call the method `tests.disco.shard_dim_1`, which has the signature:
```python
def shard_dim_1(
    input: NDArray,
    num_shards,       # extra input
    output: NDArray,  # its shape is (num_shards, 64, 64), and its dtype is "float16"
) -> None: ...
```
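For illustration only, here is a toy implementation matching that signature, with NumPy arrays standing in for NDArrays (the real `tests.disco.shard_dim_1` is a function registered with TVM, not this code):

```python
import numpy as np

def shard_dim_1(input: np.ndarray, num_shards: int, output: np.ndarray) -> None:
    # Split the tensor along dim 1 into `num_shards` equal pieces and stack
    # them into a new leading dimension, writing into the pre-allocated
    # output buffer in place.
    output[:] = np.stack(np.split(input, num_shards, axis=1))
```

With this sketch, a `(64, 128)` float16 weight and `num_shards = 2` produces a `(2, 64, 64)` output, matching the shape/dtype declared for `x_0` in `shard_info.json` above.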
This approach simplifies parameter sharding for users and ensures correctness.