From b65d0ff0e6776b543e590089e0c2d2d26a1044a9 Mon Sep 17 00:00:00 2001
From: RunningLeon
Date: Wed, 11 Dec 2024 18:03:53 +0800
Subject: [PATCH 1/3] fix lora name and rearange lora_b for wqkv

---
 lmdeploy/pytorch/models/internlm2.py | 26 ++++++++++++++++++++++++++
 lmdeploy/pytorch/models/patch.py     |  4 ++++
 2 files changed, 30 insertions(+)

diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py
index 6cbc2ccff3..3cfcc92bb3 100644
--- a/lmdeploy/pytorch/models/internlm2.py
+++ b/lmdeploy/pytorch/models/internlm2.py
@@ -395,6 +395,32 @@ def prepare_inputs_for_generation(
             inputs_embeds=inputs_embeds,
         )
 
+    def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
+                          adapter_id: int):
+        """load lora weights."""
+
+        from lmdeploy.pytorch.adapter.adapter import load_lora_weights
+
+        num_heads = self.config.num_attention_heads
+        num_key_value_heads = self.config.num_key_value_heads
+        hidden_size = self.config.hidden_size
+        head_dim = hidden_size // num_heads
+        group_size = num_heads // num_key_value_heads
+
+        def _rearange_wqkv(weights):
+            for name, loaded_weight in weights:
+                if 'wqkv.lora_B' in name:
+                    loaded_weight = loaded_weight.unflatten(
+                        0, (-1, 2 + group_size, head_dim))
+                    q = loaded_weight[:, :-2].flatten(0, 2)
+                    k = loaded_weight[:, -2].flatten(0, 1)
+                    v = loaded_weight[:, -1].flatten(0, 1)
+                    loaded_weight = torch.cat([q, k, v], dim=0)
+                yield name, loaded_weight
+
+        weights_iter = _rearange_wqkv(weights)
+        load_lora_weights(self, weights_iter, adapter_id)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         """load weights."""
         # modify from vllm
diff --git a/lmdeploy/pytorch/models/patch.py b/lmdeploy/pytorch/models/patch.py
index a7fe4431ed..9604b19af5 100644
--- a/lmdeploy/pytorch/models/patch.py
+++ b/lmdeploy/pytorch/models/patch.py
@@ -251,6 +251,10 @@ def add_adapters(model: torch.nn.Module,
         ranks, scalings = get_ranks_and_scalings(target_name,
                                                  adapter_cfgs,
                                                  device=device)
+        # split in case target_name has '.' like 'attention.wo'
+        # which cannot be used as name of a module
+        # and it's not aligned with key in model.packed_modules_mapping
+        target_name = target_name.split('.')[-1]
         found_mods, pack_idx = find_all_target(model, target_name)
         sum_rank = ranks.sum().item()
 

From b3eda3cc390d62dd2fe27cbf9b9c3703eed8b73b Mon Sep 17 00:00:00 2001
From: RunningLeon
Date: Mon, 16 Dec 2024 15:53:48 +0800
Subject: [PATCH 2/3] update for internvl

---
 lmdeploy/pytorch/models/internvl.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py
index 79a796f7a2..d9d1553de3 100644
--- a/lmdeploy/pytorch/models/internvl.py
+++ b/lmdeploy/pytorch/models/internvl.py
@@ -516,6 +516,17 @@ def prepare_inputs_for_generation(
             inputs_embeds=inputs_embeds,
         )
 
+    def load_lora_weights(self, weights: Iterable[Tuple[str, torch.Tensor]],
+                          adapter_id: int):
+        """load lora weights."""
+
+        if hasattr(self.language_model, 'load_lora_weights'):
+            return self.language_model.load_lora_weights(weights, adapter_id)
+        else:
+            from lmdeploy.pytorch.adapter.adapter import load_lora_weights
+
+            return load_lora_weights(self.language_model, weights, adapter_id)
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         """load weights."""
 

From e078bddd1c7ce03ac934e6ebfc74cb7b18f11158 Mon Sep 17 00:00:00 2001
From: RunningLeon
Date: Wed, 18 Dec 2024 20:14:51 +0800
Subject: [PATCH 3/3] fix torchvision mismatch torch

---
 requirements/runtime_ascend.txt | 2 +-
 requirements/runtime_cuda.txt   | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/requirements/runtime_ascend.txt b/requirements/runtime_ascend.txt
index 965175faf3..81f538275c 100644
--- a/requirements/runtime_ascend.txt
+++ b/requirements/runtime_ascend.txt
@@ -18,6 +18,6 @@ shortuuid
 tiktoken
 torch<=2.4.0,>=2.3.1
 torch-npu==2.3.1
-torchvision<=0.19.0,>=0.15.0
+torchvision<=0.19.0,>=0.18.1
 transformers
 uvicorn
diff --git a/requirements/runtime_cuda.txt b/requirements/runtime_cuda.txt
index a11a749424..41af6039bd 100644
--- a/requirements/runtime_cuda.txt
+++ b/requirements/runtime_cuda.txt
@@ -16,7 +16,7 @@ sentencepiece
 shortuuid
 tiktoken
 torch<=2.5.1,>=2.0.0
-torchvision<=0.19.0,>=0.15.0
+torchvision<=0.20.1,>=0.15.0
 transformers
 triton==3.0.0; sys_platform == "linux"
 uvicorn
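
The snippet below is not part of the patches above; it is a minimal standalone sketch of the wqkv lora_B regrouping that _rearange_wqkv performs in PATCH 1/3. The head counts, head_dim, and LoRA rank are hypothetical values chosen only for illustration, and the layout assumption (q * group_size, then k, then v per kv group) is the one the patch itself relies on.

import torch

# Hypothetical InternLM2-style config values, for illustration only.
num_heads = 32
num_key_value_heads = 8
head_dim = 128
rank = 16
group_size = num_heads // num_key_value_heads

# wqkv packs each kv group as [q * group_size, k, v], so lora_B has
# num_key_value_heads * (group_size + 2) * head_dim output rows.
lora_b = torch.randn(num_key_value_heads * (group_size + 2) * head_dim, rank)

# Same regrouping as _rearange_wqkv: view the rows as
# (kv_group, slot, head_dim), then gather the q, k and v slots into
# contiguous [q | k | v] blocks.
w = lora_b.unflatten(0, (-1, 2 + group_size, head_dim))
q = w[:, :-2].flatten(0, 2)  # (num_heads * head_dim, rank)
k = w[:, -2].flatten(0, 1)   # (num_key_value_heads * head_dim, rank)
v = w[:, -1].flatten(0, 1)   # (num_key_value_heads * head_dim, rank)
rearranged = torch.cat([q, k, v], dim=0)

assert rearranged.shape == lora_b.shape
assert q.shape[0] == num_heads * head_dim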