[Disco] Loading-time sharding support #15826

Merged
junrushao merged 1 commit into apache:unity from LeshengJin:Disco/shard on Oct 3, 2023
Conversation


@LeshengJin LeshengJin commented Sep 26, 2023

In our previous implementation, parameter sharding relied on pre-quantization weight processing: each set of quantized weights was tied to a hardcoded `num_shards` constant, so every change in the number of GPUs (e.g. moving from a 4-GPU to an 8-GPU setting) required re-quantization. This PR moves parameter sharding to post-quantization loading time. During loading, we iterate over all parameters and apply the sharding operation specified by the provided sharding information.

To make this happen, this PR enhances the existing `shard_info.json` to include the sharding function used at loading time. Each parameter is associated with a list of loading-time preprocessing methods that are applied serially to transform the parameter to the desired shape, as shown in the example below:

```python
shard_info = {
  "x_0": [  # name of the parameter
    [  # a list of preprocessing functions to be applied serially
      "tests.disco.shard_dim_1",  # name of the sharding function
      [(num_shards, 64, 64), "float16"],  # output shape/dtype of `tests.disco.shard_dim_1`
      num_shards,  # extra inputs to `tests.disco.shard_dim_1`
    ],
  ],
  "x_1": [...],
}
```

For parameter `x_0`, this means we call the method `tests.disco.shard_dim_1`, which has the signature:

```python
def shard_dim_1(
  input: NDArray,
  num_shards,  # extra input
  output: NDArray,  # shape (num_shards, 64, 64), dtype "float16"
) -> None: ...
```
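The PR description does not include the body of `tests.disco.shard_dim_1`. As a hedged sketch only, a NumPy stand-in (with `np.ndarray` in place of TVM's `NDArray`) that splits a parameter along dim 1 into one slice per device might look like:

```python
import numpy as np

def shard_dim_1(input: np.ndarray, num_shards: int, output: np.ndarray) -> None:
    """Illustrative sketch: split `input` along dim 1 into `num_shards` pieces.

    For an input of shape (a, num_shards * b), the output has shape
    (num_shards, a, b), where output[i] is the shard for device i.
    """
    a, total = input.shape
    b = total // num_shards
    # (a, num_shards*b) -> (a, num_shards, b), then move the shard axis first.
    output[:] = input.reshape(a, num_shards, b).transpose(1, 0, 2)
```

The destination-passing style (writing into a preallocated `output` and returning `None`) matches the signature above, where the output shape and dtype come from `shard_info.json`.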

This approach simplifies parameter sharding for users and ensures correctness.
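To illustrate how a loader might consume such a spec, here is a minimal self-contained sketch. The registry `FUNCS`, the helper `apply_preprocs`, and the function name `demo.shard_rows` are all hypothetical, not TVM's actual loader API; only the per-entry layout `[name, [shape, dtype], *extra_args]` follows the format shown above.

```python
import numpy as np

def shard_rows(input: np.ndarray, num_shards: int, output: np.ndarray) -> None:
    # Toy preprocessing function: split dim 0 into `num_shards` row blocks.
    output[:] = input.reshape(num_shards, -1, input.shape[1])

# Hypothetical registry mapping names in shard_info.json to callables.
FUNCS = {"demo.shard_rows": shard_rows}

def apply_preprocs(param: np.ndarray, preprocs: list) -> np.ndarray:
    """Serially apply each [name, [shape, dtype], *extra_args] entry to `param`."""
    for name, (shape, dtype), *extra in preprocs:
        out = np.empty(shape, dtype=dtype)  # shape/dtype come from the spec
        FUNCS[name](param, *extra, out)     # destination-passing call
        param = out                         # feed the result to the next step
    return param

# One parameter's entry, in the shape of shard_info.json:
spec = [["demo.shard_rows", ((2, 2, 4), "float16"), 2]]
sharded = apply_preprocs(np.ones((4, 4), dtype="float16"), spec)
```

Because each entry carries its own output shape and dtype, the loader can chain several preprocessing steps without knowing anything about the functions beyond their names.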

@LeshengJin LeshengJin changed the title [Disco] Advanced sharding support [Disco] Loading-time sharding support Sep 26, 2023
@LeshengJin LeshengJin force-pushed the Disco/shard branch 2 times, most recently from 0e1a24c to 191d3f7 Compare September 27, 2023 00:39
@LeshengJin LeshengJin marked this pull request as ready for review September 27, 2023 01:14
@junrushao
Member

CC: @jinhongyii


@junrushao junrushao left a comment


Let's unify sharding and reordering into a single function - no need for shard3d any more :)

@LeshengJin LeshengJin force-pushed the Disco/shard branch 3 times, most recently from 52450e8 to 594f7e0 Compare September 28, 2023 06:04
@jinhongyii
Contributor

jinhongyii commented Sep 30, 2023

Can we separate the reorganizing function and the sharding function again? I'm thinking of a situation where there can be 2 reorganizing functions (doing nothing, and combining QKV) and n different sharding functions (different tensor shapes and different sharding dimensions). In that case, we would need at most 2n functions in total in the IRModule. Another benefit of separating them is that DistIR would not need to handle merging the reorganizing and sharding functions.

@junrushao junrushao force-pushed the Disco/shard branch 3 times, most recently from 24bfad2 to 959dd58 Compare October 1, 2023 15:12
@junrushao junrushao force-pushed the Disco/shard branch 2 times, most recently from 98177e5 to bfd621d Compare October 1, 2023 16:18
@junrushao junrushao force-pushed the Disco/shard branch 2 times, most recently from 6730cfe to 10b3c18 Compare October 3, 2023 01:29
@junrushao junrushao merged commit 88a08ae into apache:unity Oct 3, 2023