78 changes: 78 additions & 0 deletions docs/en/advance/update_weights.md
@@ -0,0 +1,78 @@
# Update Weights

LMDeploy supports updating model weights online, which is useful for scenarios such as RL training. The steps are as follows.

## Step 1: Launch server

For the PyTorch backend, you must add `--distributed-executor-backend ray`.

```shell
lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333 --distributed-executor-backend ray # for pytorch backend
```
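
If you want to confirm the server is up before proceeding, a minimal readiness probe is sketched below; it assumes the server exposes a `/health` endpoint on the port passed to `--server-port`.

```python
import requests

# probe the api_server on the port chosen at launch
response = requests.get('http://0.0.0.0:23333/health')
assert response.status_code == 200, response.status_code
```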

## Step 2: Offload weights & kv cache

Before updating the model weights, the server should offload its weights and kv cache.

```python
from lmdeploy.utils import serialize_state_dict
import requests

BASE_URL = 'http://0.0.0.0:23333'
api_key = 'sk-xxx'

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}

# offload weights and kv cache with level=2
response = requests.post(f"{BASE_URL}/sleep", headers=headers, params=dict(tags=['weights', 'kv_cache'], level=2))
assert response.status_code == 200, response.status_code

# wake up the weights; the server is then ready for weight updates
response = requests.post(f"{BASE_URL}/wakeup", headers=headers, params=dict(tags=['weights']))
assert response.status_code == 200, response.status_code
```
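
`level=2` is read here as dropping the weights entirely rather than parking them in host memory, which fits this workflow since they are about to be replaced. This reading follows vLLM-style sleep semantics and is an assumption; check the LMDeploy API reference for the exact level definitions.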

## Step 3: Update weights

Split the model weights into multiple segments and upload them through the `update_weights` endpoint.

```python
from typing import Dict, List

import torch

# serialize_state_dict, requests, BASE_URL and headers come from the Step 2 snippet
segmented_state_dict: List[Dict[str, torch.Tensor]] = ...
num_segment = len(segmented_state_dict)
for seg_idx in range(num_segment):
    serialized_data = serialize_state_dict(segmented_state_dict[seg_idx])
    # finished=True on the last segment tells the server the update is complete
    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment - 1)
    response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data)
    assert response.status_code == 200, f"response.status_code = {response.status_code}"
```
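
How `segmented_state_dict` is built is left to the caller. One possible approach, sketched with an illustrative helper (`segment_state_dict` and its byte budget are not part of LMDeploy), is to greedily pack parameters into segments up to a fixed size:

```python
from typing import Dict, List

import torch


def segment_state_dict(state_dict: Dict[str, torch.Tensor],
                       max_bytes: int = 512 * 1024 * 1024) -> List[Dict[str, torch.Tensor]]:
    """Greedily pack tensors into segments of at most max_bytes each."""
    segments: List[Dict[str, torch.Tensor]] = [{}]
    current_bytes = 0
    for name, tensor in state_dict.items():
        size = tensor.numel() * tensor.element_size()
        # start a new segment once the current one would exceed the budget
        if segments[-1] and current_bytes + size > max_bytes:
            segments.append({})
            current_bytes = 0
        segments[-1][name] = tensor
        current_bytes += size
    return segments
```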

**Note**: For the PyTorch backend, LMDeploy also supports transferring weights as flattened bucket tensors:

```python
from typing import Dict, List

import torch

from lmdeploy.utils import FlattenedTensorBucket, serialize_state_dict

segmented_state_dict: List[Dict[str, torch.Tensor]] = ...
num_segment = len(segmented_state_dict)
for seg_idx in range(num_segment):
    named_tensors = list(segmented_state_dict[seg_idx].items())
    # pack every tensor in this segment into a single contiguous buffer
    bucket = FlattenedTensorBucket(named_tensors=named_tensors)
    metadata = bucket.get_metadata()
    flattened_tensor_data = dict(flattened_tensor=bucket.get_flattened_tensor(), metadata=metadata)
    serialized_data = serialize_state_dict(flattened_tensor_data)
    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment - 1, load_format='flattened_bucket')
    response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data)
    assert response.status_code == 200, f"response.status_code = {response.status_code}"
```
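
Packing each segment into one contiguous buffer trades an extra copy on the client for fewer, larger tensors on the wire, which typically reduces per-tensor serialization and dispatch overhead when a segment contains many small parameters.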

## Step 4: Wake up the server

After the weights are updated, the server should load the kv cache back and resume serving with the updated weights.

```python
response = requests.post(f"{BASE_URL}/wakeup", headers=headers, params=dict(tags=['kv_cache']))
assert response.status_code == 200, response.status_code
```
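
To verify that serving has resumed with the new weights, you can issue an ordinary request, for example through the OpenAI-compatible chat completions endpoint, reusing `BASE_URL` and `headers` from Step 2 (the model name is the one passed at launch):

```python
response = requests.post(
    f'{BASE_URL}/v1/chat/completions',
    headers=headers,
    json=dict(model='internlm/internlm2_5-7b-chat',
              messages=[dict(role='user', content='Hello!')]),
)
assert response.status_code == 200, response.status_code
print(response.json()['choices'][0]['message']['content'])
```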
1 change: 1 addition & 0 deletions docs/en/index.rst
@@ -105,6 +105,7 @@ Documentation
advance/metrics.md
advance/context_parallel.md
advance/spec_decoding.md
advance/update_weights.md

.. toctree::
:maxdepth: 1
78 changes: 78 additions & 0 deletions docs/zh_cn/advance/update_weights.md
@@ -0,0 +1,78 @@
# Update Weights

LMDeploy supports updating model weights online, which is convenient for scenarios such as RL training. The steps are as follows:

## Step 1: Launch the server

For the PyTorch backend, you must add `--distributed-executor-backend ray`.

```shell
lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333 --distributed-executor-backend ray # for pytorch backend
```

## Step 2: Offload weights and kv cache

Before updating the weights, call the API to offload the weights and kv cache so that the inference engine is in an updatable state:

```python
from lmdeploy.utils import serialize_state_dict
import requests

BASE_URL = 'http://0.0.0.0:23333'
api_key = 'sk-xxx'

headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}",
}

# offload weights and kv cache with level=2
response = requests.post(f"{BASE_URL}/sleep", headers=headers, params=dict(tags=['weights', 'kv_cache'], level=2))
assert response.status_code == 200, response.status_code

# wake up the weights; the server is then ready for weight updates
response = requests.post(f"{BASE_URL}/wakeup", headers=headers, params=dict(tags=['weights']))
assert response.status_code == 200, response.status_code
```

## Step 3: Update weights

Split the model weights into segments, then call the `update_weights` API to update them.

```python
from typing import Dict, List

import torch

# serialize_state_dict, requests, BASE_URL and headers come from the Step 2 snippet
segmented_state_dict: List[Dict[str, torch.Tensor]] = ...
num_segment = len(segmented_state_dict)
for seg_idx in range(num_segment):
    serialized_data = serialize_state_dict(segmented_state_dict[seg_idx])
    # finished=True on the last segment tells the server the update is complete
    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment - 1)
    response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data)
    assert response.status_code == 200, f"response.status_code = {response.status_code}"
```
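
Weight segments travel over plain HTTP, so transient failures are possible inside a training loop. A minimal retry wrapper is sketched below as an illustration only; `post_with_retry` and its backoff policy are not part of LMDeploy, and whether a failed segment can safely be re-sent depends on the server-side update protocol.

```python
import time

import requests


def post_with_retry(url: str, max_retries: int = 3, **kwargs) -> requests.Response:
    """POST with exponential backoff; raise after max_retries failed attempts."""
    for attempt in range(max_retries):
        try:
            response = requests.post(url, **kwargs)
            if response.status_code == 200:
                return response
        except requests.RequestException:
            pass
        time.sleep(2**attempt)  # back off: 1s, 2s, 4s, ...
    raise RuntimeError(f'POST {url} failed after {max_retries} attempts')
```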

**Note**: For the PyTorch backend, LMDeploy also supports transferring weights as flattened bucket tensors:

```python
from typing import Dict, List

import torch

from lmdeploy.utils import FlattenedTensorBucket, serialize_state_dict

segmented_state_dict: List[Dict[str, torch.Tensor]] = ...
num_segment = len(segmented_state_dict)
for seg_idx in range(num_segment):
    named_tensors = list(segmented_state_dict[seg_idx].items())
    # pack every tensor in this segment into a single contiguous buffer
    bucket = FlattenedTensorBucket(named_tensors=named_tensors)
    metadata = bucket.get_metadata()
    flattened_tensor_data = dict(flattened_tensor=bucket.get_flattened_tensor(), metadata=metadata)
    serialized_data = serialize_state_dict(flattened_tensor_data)
    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment - 1, load_format='flattened_bucket')
    response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data)
    assert response.status_code == 200, f"response.status_code = {response.status_code}"
```

## Step 4: Wake up the engine

After the weights are updated, call the API to rebuild the kv cache and wake up the engine so that inference service resumes.

```python
response = requests.post(f"{BASE_URL}/wakeup", headers=headers, params=dict(tags=['kv_cache']))
assert response.status_code == 200, response.status_code
```
1 change: 1 addition & 0 deletions docs/zh_cn/index.rst
@@ -106,6 +106,7 @@ LMDeploy toolkit provides the following core features:
advance/metrics.md
advance/context_parallel.md
advance/spec_decoding.md
advance/update_weights.md

.. toctree::
:maxdepth: 1
1 change: 1 addition & 0 deletions lmdeploy/cli/serve.py
@@ -99,6 +99,7 @@ def add_parser_api_server():
ArgumentHelper.dllm_denoising_steps(pt_group)
ArgumentHelper.dllm_confidence_threshold(pt_group)
ArgumentHelper.enable_return_routed_experts(pt_group)
ArgumentHelper.distributed_executor_backend(pt_group)

# common engine args
dtype_act = ArgumentHelper.dtype(pt_group)
10 changes: 10 additions & 0 deletions lmdeploy/cli/utils.py
@@ -699,6 +699,7 @@ def enable_return_routed_experts(parser):
default=False,
help='Whether to output routed expert ids for replay')

@staticmethod
def add_spec_group(parser):
spec_group = parser.add_argument_group('Speculative decoding arguments')
spec_group.add_argument('--speculative-algorithm',
@@ -719,6 +720,15 @@ def add_spec_group(parser):

return spec_group

@staticmethod
def distributed_executor_backend(parser):
"""Distributed_executor_backend."""
return parser.add_argument('--distributed-executor-backend',
type=str,
default=None,
choices=['uni', 'mp', 'ray'],
help='The distributed executor backend for pytorch engine.')


# adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py
class FlexibleArgumentParser(argparse.ArgumentParser):