-
Notifications
You must be signed in to change notification settings - Fork 2.9k
/
lora_model.py
861 lines (783 loc) · 39.5 KB
/
lora_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import math
import os
import re
import tempfile
from collections import OrderedDict
from functools import partial
from typing import Dict, List, Union
import aistudio_sdk
import numpy as np
import paddle
import paddle.nn as nn
from paddle.distributed.fleet.meta_parallel import (
ColumnParallelLinear,
PipelineLayer,
RowParallelLinear,
)
from ...transformers import linear_utils
from ...transformers.conversion_utils import ConversionMixin
from ...transformers.model_utils import (
PretrainedModel,
_add_variant,
_load_state_dict_into_model,
dtype_guard,
load_state_dict,
)
from ...transformers.utils import get_checkpoint_shard_files, weight_name_suffix
from ...utils.distributed import distributed_allgather, distributed_gather
from ...utils.env import LORA_WEIGHTS_NAME, SAFE_PEFT_WEIGHTS_INDEX_NAME
from ...utils.log import logger
from ...utils.tools import get_env_device
from .lora_config import LoRAConfig
def get_lora_layers():
    """Return the LoRA layer classes to use, keyed by their canonical class name.

    On an XPU device the implementations from ``paddle_xpu`` are preferred
    (``LoRAConv2D`` always comes from PaddleNLP's own ``lora_layers``). When the
    device is not XPU, or any of those imports raises ``ImportError``, PaddleNLP's
    native LoRA layers are used instead.

    Returns:
        dict: Mapping from class name (e.g. ``"LoRALinear"``) to the layer class.
    """
    try:
        if get_env_device() != "xpu":
            # Not on XPU: deliberately jump to the native fallback below.
            raise ImportError
        # Raises ImportError when paddle_xpu is not installed, which also triggers the fallback.
        from paddle_xpu.layers.nn.lora_layers import (
            XPUColumnParallelLoRALinear as ColumnParallelLoRALinear,
            XPUColumnSequenceParallelLoRALinear as ColumnSequenceParallelLoRALinear,
            XPULoRALinear as LoRALinear,
            XPURowParallelLoRALinear as RowParallelLoRALinear,
            XPURowSequenceParallelLoRALinear as RowSequenceParallelLoRALinear,
        )

        from .lora_layers import LoRAConv2D
    except ImportError:
        from .lora_layers import (
            ColumnParallelLoRALinear,
            ColumnSequenceParallelLoRALinear,
            LoRAConv2D,
            LoRALinear,
            RowParallelLoRALinear,
            RowSequenceParallelLoRALinear,
        )

    return dict(
        ColumnParallelLoRALinear=ColumnParallelLoRALinear,
        ColumnSequenceParallelLoRALinear=ColumnSequenceParallelLoRALinear,
        LoRAConv2D=LoRAConv2D,
        LoRALinear=LoRALinear,
        RowParallelLoRALinear=RowParallelLoRALinear,
        RowSequenceParallelLoRALinear=RowSequenceParallelLoRALinear,
    )
# Resolve the (possibly XPU-specific) LoRA layer classes once at import time.
lora_layers = get_lora_layers()

# Every LoRA layer class known to this module, in a fixed order. Kept as a list
# so that optional layer families can be appended to it.
AVAILABLE_LAYERS = [
    lora_layers[layer_name]
    for layer_name in (
        "ColumnParallelLoRALinear",
        "ColumnSequenceParallelLoRALinear",
        "LoRAConv2D",
        "LoRALinear",
        "RowParallelLoRALinear",
        "RowSequenceParallelLoRALinear",
    )
]

# Re-export each resolved class under its canonical module-level name.
(
    ColumnParallelLoRALinear,
    ColumnSequenceParallelLoRALinear,
    LoRAConv2D,
    LoRALinear,
    RowParallelLoRALinear,
    RowSequenceParallelLoRALinear,
) = AVAILABLE_LAYERS
# Optional quantization-aware LoRA layers. When they cannot be imported, every
# related name is bound to ``None`` so that callers can test ``X is not None``
# before using it.
try:
    from ...quantization.quantization_linear import (
        ColumnParallelQuantizationLinear,
        QuantizationLinear,
        RowParallelQuantizationLinear,
    )
    from .lora_quantization_layers import (
        ColumnParallelQuantizationLoRALinear,
        QuantizationLoRALinear,
        RowParallelQuantizationLoRALinear,
    )

    AVAILABLE_LAYERS += [
        ColumnParallelQuantizationLoRALinear,
        QuantizationLoRALinear,
        RowParallelQuantizationLoRALinear,
    ]
except Exception:
    # Deliberately best-effort: any import-time failure of the quantization stack
    # disables these layers. ``Exception`` (not a bare ``except:``) lets
    # KeyboardInterrupt / SystemExit propagate.
    QuantizationLinear = None
    ColumnParallelQuantizationLinear = None
    RowParallelQuantizationLinear = None
    QuantizationLoRALinear = None
    ColumnParallelQuantizationLoRALinear = None
    RowParallelQuantizationLoRALinear = None
class LoRAModel(nn.Layer):
    """Wrapper that injects LoRA adapters into a model's target sub-layers.

    On construction, every sub-layer of ``model`` whose name fully matches one of
    ``lora_config.target_modules`` is replaced in place by its LoRA counterpart.
    Afterwards only LoRA parameters (plus optional biases / ``trainable_modules``)
    are left trainable. ``forward`` is delegated to the wrapped model, and
    attributes not found on the wrapper are forwarded to it.
    """

    # LoRA layer class -> plain layer class, used by ``_find_and_restore_module``.
    # TODO:lugimzzz support restore in following PR
    restore_layer_map: Dict[type, type] = {
        LoRALinear: nn.Linear,
        LoRAConv2D: nn.Conv2D,
        # ColumnParallelLoRALinear: ColumnParallelLinear,
        # RowParallelLoRALinear: RowParallelLinear,
        # QuantizationLoRALinear: QuantizationLinear,
    }

    def __init__(self, model, lora_config: LoRAConfig) -> None:
        """Build the LoRA wrapper around ``model``.

        Args:
            model: Model whose matching sub-layers are replaced in place.
            lora_config (LoRAConfig): LoRA hyper-parameters. Its ``dtype`` is filled
                with Paddle's default dtype when unset, and its
                ``tensor_parallel_degree`` is overwritten (with a warning) to match
                ``model.config.tensor_parallel_degree`` when they differ.
        """
        super().__init__()
        self.quantized = False
        self.lora_config = lora_config
        # parameter name -> is_column flag, consumed by _get_tensor_parallel_mappings.
        self.lora_split_mapping = {}
        if self.lora_config.dtype is None:
            self.lora_config.dtype = paddle.get_default_dtype()
        with dtype_guard(self.lora_config.dtype):
            self.model = self.get_lora_model(model, lora_config)
        self.is_pipelinemodel = False
        if issubclass(type(self.model), PipelineLayer):
            self.is_pipelinemodel = True
            self.model._single_to_pp_mapping = None
        if self.lora_config.tensor_parallel_degree != self.model.config.tensor_parallel_degree:
            self.lora_config.tensor_parallel_degree = self.model.config.tensor_parallel_degree
            logger.warning(
                f"Reset tensor_parallel_degree of lora_config to {self.model.config.tensor_parallel_degree}."
            )
        self.forward = self.model.forward

        logger.info("Mark only lora and trainable_module as trainable.")
        self.mark_only_lora_as_trainable()

    @staticmethod
    def _is_lora_layer(layer) -> bool:
        """Return True if ``layer`` is an instance of any class in ``AVAILABLE_LAYERS``."""
        return isinstance(layer, tuple(AVAILABLE_LAYERS))

    @staticmethod
    def _get_parent_module(root, module_name):
        """Walk the dotted ``module_name`` from ``root``.

        Returns:
            tuple: ``(parent_layer, last_attribute_name)``.
        """
        *parent_path, attr_name = module_name.split(".")
        parent_module = root
        for name in parent_path:
            parent_module = getattr(parent_module, name)
        return parent_module, attr_name

    def add_lora_split_mapping(self, module_name, is_column=False):
        """Record that parameter ``module_name`` must be split/merged across tensor-parallel ranks.

        Args:
            module_name (str): Full parameter name, e.g. ``"<layer path>.lora_B"``.
            is_column (bool): Forwarded as ``is_column`` to the split/merge function.
        """
        self.lora_split_mapping[module_name] = is_column

    def _rename_pp_lora_split_mapping(self):
        """Translate pipeline-layer keys of ``lora_split_mapping`` to single-model names.

        The leading numeric index (or, for the layout flagged as virtual pipeline,
        the sum of the first two indices) is looked up in
        ``self.model.get_sequential_name_prefixes()``.

        Raises:
            ValueError: If a key does not start with the expected numeric index.
        """
        prefixes = self.model.get_sequential_name_prefixes()
        first_key = next(iter(self.lora_split_mapping), "").split(".")
        # Keys whose first two components are both numeric are treated as a virtual-pipeline layout.
        use_virtual_pp_degree = len(first_key) > 1 and first_key[0].isdigit() and first_key[1].isdigit()

        renamed = {}
        for k, is_column in self.lora_split_mapping.items():
            name_splited = k.split(".")
            if use_virtual_pp_degree:
                if not name_splited[0].isdigit():
                    raise ValueError(f"Please check! {k} is not a valid key.")
                if name_splited[1].isdigit():
                    idx = str(int(name_splited[0]) + int(name_splited[1]))
                    single_name = [prefixes[idx]]
                    single_name.extend(name_splited[2:])
                else:
                    single_name = [prefixes[str(len(prefixes) - 1)]]
                    single_name.extend(name_splited[2:])
                    logger.warning(
                        f"Please check! we treat this key as last layer, get {k}, set origin name as {'.'.join(single_name)}"
                    )
            else:
                idx = name_splited[0]
                # for normal pp layer name
                if not idx.isdigit():
                    raise ValueError(f"Unexpected key: {k} for pp lora layer.")
                single_name = [prefixes[idx]]
                single_name.extend(name_splited[1:])
            renamed[".".join(single_name)] = is_column
        return renamed

    def _get_tensor_parallel_mappings(self, config, is_split=True):
        """Return ``{parameter name: split/merge callable}`` for the registered LoRA parameters.

        Args:
            config: Model config providing the tensor-parallel degree/rank and ``num_attention_heads``.
            is_split (bool): True to build split actions, False to build merge actions.
        """
        from paddlenlp.transformers.conversion_utils import split_or_merge_func

        fn = split_or_merge_func(
            is_split=is_split,
            tensor_parallel_degree=config.tensor_parallel_degree,
            tensor_parallel_rank=config.tensor_parallel_rank,
            num_attention_heads=config.num_attention_heads,
        )

        if issubclass(type(self.model), PipelineLayer):
            lora_split_mapping = self._rename_pp_lora_split_mapping()
        else:
            lora_split_mapping = self.lora_split_mapping

        return {key: partial(fn, is_column=is_col) for key, is_col in lora_split_mapping.items()}

    @classmethod
    def from_pretrained(cls, model, lora_path, **kwargs):
        """Wrap ``model`` with LoRA and load adapter weights from ``lora_path``.

        Args:
            model: Base model to wrap.
            lora_path (str): Directory holding the LoRA config and weights.
            **kwargs: ``lora_config`` (LoRAConfig) may be passed to skip reading the config from ``lora_path``.

        Returns:
            LoRAModel: The wrapped model. If no weights are found, the freshly initialised LoRA model is returned.

        Raises:
            ValueError: If a sharded checkpoint lacks some trainable keys.
            RuntimeError: If loading a shard into the model reports errors.
            NotImplementedError: If the checkpoint's tensor-parallel degree (>1) differs from the model's.
        """
        lora_config = kwargs.pop("lora_config", None)
        # init lora config & lora model
        if not isinstance(lora_config, LoRAConfig):
            lora_config = LoRAConfig.from_pretrained(lora_path)
        # Keep the checkpoint's tensor_parallel_degree: cls(...) may overwrite lora_config.tensor_parallel_degree.
        lora_config_tensor_parallel_degree = lora_config.tensor_parallel_degree
        lora_model = cls(model, lora_config)

        lora_model_index_file = os.path.join(lora_path, SAFE_PEFT_WEIGHTS_INDEX_NAME)
        if os.path.exists(lora_model_index_file):
            # load safetensors format file.
            resolved_archive_files, sharded_metadata = get_checkpoint_shard_files(
                pretrained_model_name_or_path=lora_path,
                index_filename=lora_model_index_file,
            )
            loaded_keys = sharded_metadata["all_checkpoint_keys"]
            expected_keys = set(lora_model.get_trainable_state_dict().keys())

            missing_keys = expected_keys - set(loaded_keys)
            if len(missing_keys) > 0:
                raise ValueError(f"missing_keys: {missing_keys}")

            # The split actions depend only on the model and the key set, so build them once for all shards.
            tp_actions = None
            if model.config.tensor_parallel_degree > 1:
                tp_actions = lora_model._get_tensor_parallel_convert_actions(loaded_keys, is_split=True)

            error_msgs = []
            for shard_file in resolved_archive_files:
                state_dict = load_state_dict(shard_file, tp_actions, expected_keys)
                error_msgs += _load_state_dict_into_model(lora_model.model, state_dict, "")
                # Release each shard before loading the next one.
                del state_dict
                gc.collect()

            if len(error_msgs) > 0:
                error_msg = "\n\t".join(error_msgs)
                raise RuntimeError(
                    f"Error(s) in loading state_dict for {lora_model.__class__.__name__}:\n\t{error_msg}"
                )

            return lora_model

        # define lora weight name
        if lora_config_tensor_parallel_degree > 1:
            lora_weight_name = _add_variant(LORA_WEIGHTS_NAME, f"tp{model.config.tensor_parallel_rank:0>2d}")
        else:
            lora_weight_name = LORA_WEIGHTS_NAME

        # load and set lora weight parameter
        lora_weight_path = os.path.join(lora_path, lora_weight_name)
        if os.path.exists(lora_weight_path):
            # NOTE(review): paddle.load may unpickle this file; only load checkpoints from trusted sources.
            lora_state_dict = paddle.load(lora_weight_path, return_numpy=True)
            logger.info(f"Loading the LoRA weights from {lora_weight_path}")

            if (
                lora_config_tensor_parallel_degree > 1
                and lora_config_tensor_parallel_degree != model.config.tensor_parallel_degree
            ):
                raise NotImplementedError(
                    f"{lora_config_tensor_parallel_degree} is not equal to {model.config.tensor_parallel_degree}. Please merge LoRA weights first."
                )

            # convert parameters to tensor parallel for mp model
            if lora_config_tensor_parallel_degree <= 1 and model.config.tensor_parallel_degree > 1:
                lora_state_dict = lora_model._convert_tensor_parallel(lora_state_dict=lora_state_dict)

            # set lora state dict
            lora_model.set_state_dict(lora_state_dict)
        else:
            logger.error(f"LoRA weights not found under {lora_path}, creating LoRA weights from scratch")

        return lora_model

    def set_state_dict(self, state_dict):
        """Load ``state_dict`` into the wrapped model.

        The "Skip loading for ..." warnings are ignored while loading, because a
        LoRA checkpoint covers only part of the model's parameters. The filter is
        scoped to this call instead of being installed process-wide.
        """
        import warnings

        with warnings.catch_warnings():
            warnings.filterwarnings(
                action="ignore", message=".*Skip loading for.*", category=Warning, lineno=0, append=False
            )
            self.model.set_state_dict(state_dict)
        logger.info("Load lora weight successfully")

    def _merge_trainable_tensor_parallel(self, trainable_state_dict):
        """Gather tensor-parallel shards of the trainable parameters onto model-parallel rank 0.

        Args:
            trainable_state_dict (dict): Name -> tensor; modified in place.

        Returns:
            dict: The same dict. On rank 0 of the model-parallel group, values are the
            merged weights. Names without a conversion action become numpy arrays.
            Row-parallel ``_scale`` entries are reduced with ``paddle.max``.
            On other ranks, every value is None.
        """
        trainable_name_action_mappings = self._get_tensor_parallel_convert_actions(
            trainable_state_dict.keys(), is_split=False
        )

        hcg = paddle.distributed.fleet.get_hybrid_communicate_group()
        mp_group = hcg.get_model_parallel_group()
        is_dst = paddle.distributed.get_rank(mp_group) == 0
        gather = distributed_allgather if get_env_device() == "xpu" else distributed_gather

        for key in trainable_state_dict:
            tensor = trainable_state_dict[key]
            if key in trainable_name_action_mappings:
                ret = gather(tensor, group=mp_group, offload=True)
                action = trainable_name_action_mappings[key]
                if key in self.lora_split_mapping and not self.lora_split_mapping[key] and "_scale" in key and is_dst:
                    # Quantization scales of row-parallel layers: keep the max across ranks instead of concatenating.
                    ret = paddle.to_tensor(ret)
                    tensor = paddle.max(ret, axis=0)
                else:
                    tensor = action(ret) if is_dst else None
                trainable_state_dict[key] = tensor
            else:
                trainable_state_dict[key] = tensor.cpu().numpy() if is_dst else None

        return trainable_state_dict

    def _get_tensor_parallel_convert_actions(self, loaded_keys, is_split=True, ignore_error=False, config=None):
        """Return name -> split/merge action for LoRA params and for model params present in ``loaded_keys``.

        Args:
            loaded_keys (Iterable[str]): Parameter names available in the checkpoint/state dict.
            is_split (bool): True for split actions, False for merge actions.
            ignore_error (bool): Forwarded to ``ConversionMixin._resolve_prefix_keys``.
            config: Model config; defaults to ``self.model.config``.
        """
        if config is None:
            config = self.model.config

        # Materialise once: loaded_keys may be a list, which makes repeated `in` checks O(n).
        loaded_keys = set(loaded_keys)
        specific_name_action_mappings = self._get_tensor_parallel_mappings(config, is_split=is_split)
        name_action_mappings = self.model._get_tensor_parallel_mappings(config, is_split=is_split)
        state_keys_map = ConversionMixin._resolve_prefix_keys(
            name_action_mappings.keys(), self.model.state_dict().keys(), ignore_error=ignore_error
        )
        for k, v in state_keys_map.items():
            if v in loaded_keys:
                specific_name_action_mappings[v] = name_action_mappings[k]
        return specific_name_action_mappings

    def _convert_tensor_parallel(self, lora_state_dict):
        """Split a non-tensor-parallel ``lora_state_dict`` in place for the current rank and return it."""
        lora_name_action_mappings = self._get_tensor_parallel_convert_actions(lora_state_dict.keys(), is_split=True)

        for name, action in lora_name_action_mappings.items():
            if name in lora_state_dict:
                tensor = lora_state_dict.pop(name)
                lora_state_dict[name] = action(tensor)
            else:
                logger.warning(f"{name} not found in lora_state_dict!")
        return lora_state_dict

    def save_pretrained(self, save_directory: str, merge_tensor_parallel: bool = False, **kwargs):
        """Save the trainable (LoRA) weights and the LoRA config to ``save_directory``.

        Args:
            save_directory (str): Target directory; created if missing and must not be a file.
            merge_tensor_parallel (bool): Gather tensor-parallel shards into a single file.
                It is forced to False for quantized or pipeline models with tensor_parallel_degree > 1.
            **kwargs: ``save_model_config`` (default True), ``variant``, and
                ``is_main_process`` (default: global rank 0).
        """
        save_model_config = kwargs.get("save_model_config", True)

        if self.is_pipelinemodel:
            self.model._single_to_pp_mapping = None
        if self.quantized and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
            merge_tensor_parallel = False
            logger.warning(
                "Quantized strategy does not support merge_tensor_parallel. Set merge_tensor_parallel to False."
            )
        if self.is_pipelinemodel and merge_tensor_parallel and self.lora_config.tensor_parallel_degree > 1:
            merge_tensor_parallel = False
            logger.warning(
                "Pipeline parallism does not support merge_tensor_parallel. Set merge_tensor_parallel to False."
            )

        variant = kwargs.get("variant", None)
        is_main_process = kwargs.get("is_main_process", paddle.distributed.get_rank() == 0)

        assert not os.path.isfile(
            save_directory
        ), f"Saving directory ({save_directory}) should be a directory, not a file"
        os.makedirs(save_directory, exist_ok=True)

        lora_config_to_save = LoRAConfig(**self.lora_config.to_dict())

        trainable_state_dict = self.get_trainable_state_dict()
        if merge_tensor_parallel and lora_config_to_save.tensor_parallel_degree > 1:
            trainable_state_dict = self._merge_trainable_tensor_parallel(trainable_state_dict)
            if not is_main_process:
                logger.info("Saving with merge_tensor_parallel, tensor_parallel_rank > 0 don't need save")
                return
            # Merged weights are rank-agnostic: drop any "tp.." component from the file variant.
            if variant is not None and "tp" in variant:
                variant = "_".join([x for x in variant.split("_") if "tp" not in x])
            lora_config_to_save.tensor_parallel_degree = -1
        elif lora_config_to_save.tensor_parallel_degree > 1 and variant is None:
            variant = weight_name_suffix()

        # save lora weight
        lora_weight_name = _add_variant(LORA_WEIGHTS_NAME, variant)
        weight_filename = os.path.join(save_directory, lora_weight_name)
        paddle.save(trainable_state_dict, weight_filename)

        # save lora config
        if is_main_process:
            lora_config_to_save.save_pretrained(save_directory)
            if save_model_config:
                model_config_to_save = copy.deepcopy(self.model.config)
                if merge_tensor_parallel:
                    model_config_to_save.tensor_parallel_degree = -1
                model_config_to_save.save_pretrained(save_directory)

    def _find_and_replace_module(self, model, module_name, lora_config, enable_lora):
        """Replace sub-layer ``module_name`` of ``model`` with the matching LoRA layer.

        The original weight (or quantized weight and scales) and bias are shared with
        the new layer. Parallel layers also register their LoRA (and, with
        ``do_qat``, quantizer scale) parameters in ``lora_split_mapping``.

        Args:
            model: Root layer that ``module_name`` is relative to.
            module_name (str): Dotted path of the sub-layer to replace.
            lora_config (LoRAConfig): LoRA hyper-parameters.
            enable_lora: Currently unused by this method.

        Raises:
            ValueError: If the sub-layer type has no LoRA counterpart.
        """
        parent_module, attr_name = self._get_parent_module(model, module_name)
        module = getattr(parent_module, attr_name)
        lora_module = None
        if isinstance(module, nn.Linear):
            # Paddle Linear weights are stored as [in_features, out_features].
            lora_module = LoRALinear(
                in_features=module.weight.shape[0],
                out_features=module.weight.shape[1],
                r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                lora_dropout=lora_config.lora_dropout,
                rslora=lora_config.rslora,
                lora_plus_scale=lora_config.lora_plus_scale,
                pissa=lora_config.pissa,
                bias_attr=False if module.bias is None else None,
                use_quick_lora=lora_config.use_quick_lora,
            )
        elif isinstance(module, nn.Conv2D):
            lora_module = LoRAConv2D(
                in_channels=module._in_channels,
                out_channels=module._out_channels,
                kernel_size=module._kernel_size,
                stride=module._stride,
                padding=module._padding,
                dilation=module._dilation,
                groups=module._groups,
                padding_mode=module._padding_mode,
                data_format=module._data_format,
                r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                lora_dropout=lora_config.lora_dropout,
                bias_attr=module._bias_attr,
            )
        elif isinstance(module, ColumnParallelLinear):
            # recover the original output_features
            output_features = module.weight.shape[1] * module.world_size
            lora_module = ColumnParallelLoRALinear(
                in_features=module.weight.shape[0],
                out_features=output_features,
                gather_output=module.gather_output,
                has_bias=module.bias is not None,
                r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                lora_dropout=lora_config.lora_dropout,
                rslora=lora_config.rslora,
                lora_plus_scale=lora_config.lora_plus_scale,
                pissa=lora_config.pissa,
                lora_A_weight_attr=paddle.ParamAttr(
                    initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
                ),
                use_quick_lora=lora_config.use_quick_lora,
            )
            # Lora column parallel will split lora B matrix
            self.add_lora_split_mapping(module_name + ".lora_B", is_column=True)

            # for lora qat
            if self.lora_config.do_qat:
                self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=True)
                self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
                self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
        elif isinstance(module, RowParallelLinear):
            # recover the original input_features
            lora_module = RowParallelLoRALinear(
                in_features=module.weight.shape[0] * module.world_size,
                out_features=module.weight.shape[1],
                has_bias=module.bias is not None,
                input_is_parallel=module.input_is_parallel,
                r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                lora_dropout=lora_config.lora_dropout,
                rslora=lora_config.rslora,
                lora_plus_scale=lora_config.lora_plus_scale,
                pissa=lora_config.pissa,
                use_quick_lora=lora_config.use_quick_lora,
            )
            # Lora row parallel will split lora A matrix
            self.add_lora_split_mapping(module_name + ".lora_A", is_column=False)

            # for lora qat
            if self.lora_config.do_qat:
                self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False)
                self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
                self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
        elif isinstance(module, linear_utils.ColumnSequenceParallelLinear):
            # recover the original output_features
            output_features = module.weight.shape[1] * module.world_size
            lora_module = ColumnSequenceParallelLoRALinear(
                in_features=module.weight.shape[0],
                out_features=output_features,
                gather_output=module.gather_output,
                has_bias=module.bias is not None,
                r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                lora_dropout=lora_config.lora_dropout,
                rslora=lora_config.rslora,
                lora_plus_scale=lora_config.lora_plus_scale,
                lora_A_weight_attr=paddle.ParamAttr(
                    initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
                ),
                use_quick_lora=lora_config.use_quick_lora,
            )
            # Lora column parallel will split lora B matrix
            self.add_lora_split_mapping(module_name + ".lora_B", is_column=True)

            # for lora qat
            if self.lora_config.do_qat:
                self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=True)
                self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
                self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
        elif isinstance(module, linear_utils.RowSequenceParallelLinear):
            # recover the original input_features
            lora_module = RowSequenceParallelLoRALinear(
                in_features=module.weight.shape[0] * module.world_size,
                out_features=module.weight.shape[1],
                has_bias=module.bias is not None,
                input_is_parallel=module.input_is_parallel,
                r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                lora_dropout=lora_config.lora_dropout,
                rslora=lora_config.rslora,
                lora_plus_scale=lora_config.lora_plus_scale,
                use_quick_lora=lora_config.use_quick_lora,
            )
            # Lora row parallel will split lora A matrix
            self.add_lora_split_mapping(module_name + ".lora_A", is_column=False)

            # for lora qat
            if self.lora_config.do_qat:
                self.add_lora_split_mapping(module_name + ".weight_quanter._scale", is_column=False)
                self.add_lora_split_mapping(module_name + ".activation_quanter._scale", is_column=False)
                self.add_lora_split_mapping(module_name + ".activation_quanter.quanter._scale", is_column=False)
        elif QuantizationLinear is not None and isinstance(module, QuantizationLinear):
            lora_module = QuantizationLoRALinear(
                in_features=module.in_features,
                out_features=module.out_features,
                quant_algo=module.quant_algo,
                dtype=module._dtype,
                bias_attr=False if module.bias is None else None,
                block_size=module.block_size,
                double_quant_block_size=module.double_quant_block_size,
                double_quant=module.double_quant,
                r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                lora_dropout=lora_config.lora_dropout,
            )
            self.quantized = True
        elif ColumnParallelQuantizationLinear is not None and isinstance(module, ColumnParallelQuantizationLinear):
            lora_module = ColumnParallelQuantizationLoRALinear(
                in_features=module.in_features,
                out_features=module.out_features,
                quant_algo=module.quant_algo,
                dtype=module._dtype,
                bias_attr=False if module.bias is None else None,
                gather_output=module.gather_output,
                r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                lora_dropout=lora_config.lora_dropout,
                lora_A_weight_attr=paddle.ParamAttr(
                    initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu")
                ),
            )
            self.quantized = True
        elif RowParallelQuantizationLinear is not None and isinstance(module, RowParallelQuantizationLinear):
            lora_module = RowParallelQuantizationLoRALinear(
                in_features=module.in_features,
                out_features=module.out_features,
                quant_algo=module.quant_algo,
                dtype=module._dtype,
                bias_attr=False if module.bias is None else None,
                input_is_parallel=module.input_is_parallel,
                r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                lora_dropout=lora_config.lora_dropout,
            )
            self.quantized = True
        if lora_module is None:
            raise ValueError(
                f"LoRA strategy only supports paddle.nn.Linear or paddle.distributed.fleet.meta_parallel.ColumnParallelLinear or paddlenlp.transformers.sequence_utils. {module}({module_name} {type(module).__name__}) is not supported."
            )
        if getattr(lora_module, "quant_weight", None) is not None:
            # Quantized layers: share the quantized weight plus whichever scale buffers the LoRA layer defines.
            # ("quant_sacle_offset" is spelled as the attribute is named on these layers.)
            lora_module.quant_weight = module.quant_weight
            for attr in ("quant_scale", "qquant_scale", "double_quant_scale", "quant_sacle_offset"):
                if getattr(lora_module, attr, None) is not None:
                    setattr(lora_module, attr, getattr(module, attr))
        else:
            # Reuse the original weight parameter object rather than copying it.
            lora_module.weight = module.weight
        if module.bias is not None:
            lora_module.bias = module.bias
        setattr(parent_module, attr_name, lora_module)

    def _find_and_restore_module(self, module_name):
        """Replace LoRA sub-layer ``module_name`` with its plain class from ``restore_layer_map``.

        The new layer reuses the LoRA layer's ``weight`` and ``bias`` parameters;
        the LoRA parameters are dropped.
        """
        parent_module, attr_name = self._get_parent_module(self.model, module_name)
        module = getattr(parent_module, attr_name)
        original_model_class = self.restore_layer_map[module.__class__]
        original_module = original_model_class(in_features=module.weight.shape[0], out_features=module.weight.shape[1])
        original_module.weight = module.weight
        if module.bias is not None:
            original_module.bias = module.bias
        setattr(parent_module, attr_name, original_module)

    def get_trainable_state_dict(self):
        """Return an ordered dict of trainable parameters plus all quantizer (QAT scale) entries."""
        trainable_state_dict = OrderedDict()
        for name, weight in self.model.state_dict().items():
            # get lora parameter & QAT scale parameter
            if not weight.stop_gradient or "activation_quanter" in name or "weight_quanter" in name:
                trainable_state_dict[name] = weight
        return trainable_state_dict

    def print_trainable_parameters(self) -> None:
        """Log (at debug level) frozen, trainable and total parameter counts and the trainable ratio."""
        freeze_numel = 0
        trainable_numel = 0
        for _, weight in self.model.state_dict().items():
            if weight.stop_gradient:
                freeze_numel += np.prod(weight.shape)
            else:
                trainable_numel += np.prod(weight.shape)
        total_numel = freeze_numel + trainable_numel
        # Guard against a model without parameters (0 / 0).
        trainable_ratio = trainable_numel / total_numel if total_numel else 0.0
        logger.debug(
            f"Frozen parameters: {freeze_numel:.2e} || Trainable parameters:{trainable_numel:.2e} || Total parameters:{total_numel:.2e}|| Trainable:{trainable_ratio:.2%}"
        )

    def mark_only_lora_as_trainable(self) -> None:
        """Freeze everything except LoRA parameters, selected biases and ``trainable_modules``.

        Inside LoRA layers, parameters whose name contains "lora" stay trainable, as do
        biases when ``trainable_bias`` is "lora" or "all". Elsewhere, only biases stay
        trainable, and only when ``trainable_bias == "all"``. Finally, any
        parameter whose full name matches a ``trainable_modules`` regex is made trainable.
        """
        # named_sublayers visits parents before children, so a LoRA layer's own
        # settings override what its ancestors set on the same parameters.
        for _, layer in self.model.named_sublayers():
            if self._is_lora_layer(layer):
                for name, weight in layer.state_dict().items():
                    if self.lora_config.trainable_bias in ["lora", "all"] and "bias" in name:
                        weight.stop_gradient = False
                    elif "lora" in name:
                        weight.stop_gradient = False
                    else:
                        weight.stop_gradient = True
            else:
                for name, weight in layer.state_dict().items():
                    if self.lora_config.trainable_bias == "all" and "bias" in name:
                        weight.stop_gradient = False
                    else:
                        weight.stop_gradient = True
        if self.lora_config.trainable_modules is not None:
            for name, weight in self.model.state_dict().items():
                if any(
                    re.fullmatch(trainable_module, name) for trainable_module in self.lora_config.trainable_modules
                ):
                    weight.stop_gradient = False

    def get_lora_model(self, model: Union[PretrainedModel, nn.Layer], lora_config: LoRAConfig):
        """Replace every sub-layer of ``model`` matching ``lora_config.target_modules`` and return ``model``.

        Raises:
            TypeError: If ``enable_lora_list`` does not match the form or length of ``target_modules``.
        """
        if lora_config.target_modules is None:
            return model
        elif isinstance(lora_config.target_modules, str):
            target_modules = [lora_config.target_modules]
            if lora_config.enable_lora_list is None or (
                isinstance(lora_config.enable_lora_list, List)
                and all(isinstance(item, bool) for item in lora_config.enable_lora_list)
            ):
                enable_lora_list = [lora_config.enable_lora_list]
            else:
                raise TypeError(
                    f"Invalid `enable_lora_list` value: {lora_config.enable_lora_list}. Since `target_modules` is `str`, `enable_lora_list` must be `None` or `List[bool]`"
                )
        else:
            target_modules = lora_config.target_modules
            if lora_config.enable_lora_list is None:
                enable_lora_list = [None] * len(target_modules)
            elif isinstance(lora_config.enable_lora_list, List):
                enable_lora_list = lora_config.enable_lora_list
                if len(enable_lora_list) != len(target_modules):
                    raise TypeError(
                        f"Invalid lora_config.enable_lora_list value: {lora_config.enable_lora_list}. Since lora_config.target_modules is `List[str]`, `enable_lora_list` should have the same length as `target_modules`"
                    )
                for enable_lora in enable_lora_list:
                    if not (
                        enable_lora is None
                        or (isinstance(enable_lora, List) and all(isinstance(item, bool) for item in enable_lora))
                    ):
                        raise TypeError(
                            f"Invalid `enable_lora_list` value: {lora_config.enable_lora_list}. Since `target_modules` is `List[str]`, `enable_lora_list` must be `None` or `List[Optional[List[bool]]]`"
                        )
            else:
                raise TypeError(
                    f"Invalid `enable_lora_list` value: {lora_config.enable_lora_list}. Since `target_modules` is `List[str]`, `enable_lora_list` must be `None` or `List[Optional[List[bool]]]`"
                )
        for target_module, enable_lora in zip(target_modules, enable_lora_list):
            for module_name, _ in model.named_sublayers():
                if re.fullmatch(target_module, module_name):
                    self._find_and_replace_module(model, module_name, lora_config, enable_lora)
        return model

    def restore_original_model(self):
        """Replace LoRALinear layers with plain layers and return the wrapped model.

        Raises:
            NotImplementedError: For any other LoRA layer type.
        """
        # make sure W and lora weights are not merged before we restore the original model
        for layer_name, layer in self.model.named_sublayers():
            if isinstance(layer, LoRALinear):
                self._find_and_restore_module(layer_name)
            elif self._is_lora_layer(layer):
                raise NotImplementedError(f"{layer} restoration is not supported yet.")
        return self.model

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Layer's logic
        except AttributeError:
            if name == "model":
                # ``model`` itself is not set (e.g. instance built without __init__ by copy/pickle);
                # forwarding would recurse forever.
                raise
            return getattr(self.model, name)

    def train(self):
        """Set the wrapper, the wrapped model and every sub-layer to training mode.

        ``train()`` is called on each sub-layer individually, presumably so that
        sub-layers overriding ``train`` run their own logic — TODO confirm.
        """
        self.training = True
        self.model.training = True
        for layer in self.model.sublayers():
            layer.training = True
            layer.train()

    def eval(self):
        """Set the wrapper, the wrapped model and every sub-layer to evaluation mode (see ``train``)."""
        self.training = False
        self.model.training = False
        for layer in self.model.sublayers():
            layer.training = False
            layer.eval()

    def save_to_aistudio(
        self,
        repo_id,
        private=True,
        license="Apache License 2.0",
        exist_ok=True,
        subfolder=None,
        merge_tensor_parallel=False,
        **kwargs
    ):
        """
        Uploads all elements of this model to a new AiStudio Hub repository.

        Args:
            repo_id (str): Repository name for your model/tokenizer in the Hub.
            private (bool, optional): Whether the model/tokenizer is set to private. Defaults to True.
            license (str): The license of your model/tokenizer. Defaults to: "Apache License 2.0".
            exist_ok (bool, optional): Whether to override existing repository. Defaults to: True.
            subfolder (str, optional): Push to a subfolder of the repo instead of the root
            merge_tensor_parallel (bool): Whether to merge the tensor parallel weights. Defaults to False.
            **kwargs: Forwarded to ``aistudio_sdk.hub.create_repo`` and ``aistudio_sdk.hub.upload``
                (e.g. an access token, if the SDK expects one there).
        """
        res = aistudio_sdk.hub.create_repo(repo_id=repo_id, private=private, license=license, **kwargs)
        if "error_code" in res:
            if res["error_code"] == 10003 and exist_ok:
                logger.info(
                    f"Repo {repo_id} already exists, it will override files with the same name. To avoid this, please set exist_ok=False"
                )
            else:
                logger.error(
                    f"Failed to create repo {repo_id}, error_code: {res['error_code']}, error_msg: {res.get('error_msg')}"
                )
        else:
            logger.info(f"Successfully created repo {repo_id}")

        with tempfile.TemporaryDirectory() as root_dir:
            if subfolder is not None:
                save_dir = os.path.join(root_dir, subfolder)
            else:
                save_dir = root_dir
            # save model
            self.save_pretrained(save_dir, merge_tensor_parallel=merge_tensor_parallel)

            # Upload model and return
            logger.info(f"Pushing to the {repo_id}. This might take a while")
            for filename in os.listdir(save_dir):
                res = aistudio_sdk.hub.upload(
                    repo_id=repo_id, path_or_fileobj=os.path.join(save_dir, filename), path_in_repo=filename, **kwargs
                )
                if "error_code" in res:
                    logger.error(
                        f"Failed to upload {filename}, error_code: {res['error_code']}, error_msg: {res.get('error_msg')}"
                    )
                else:
                    logger.info(f"{filename}: {res['message']}")

    def disable_lora(self):
        """Set ``disable_lora = True`` on every LoRA sub-layer."""
        for _, layer in self.model.named_sublayers():
            if self._is_lora_layer(layer):
                layer.disable_lora = True

    def enable_lora(self):
        """Set ``disable_lora = False`` on every LoRA sub-layer."""
        for _, layer in self.model.named_sublayers():
            if self._is_lora_layer(layer):
                layer.disable_lora = False

    def merge(self):
        """Call ``merge()`` on every LoRA sub-layer."""
        for _, layer in self.model.named_sublayers():
            if self._is_lora_layer(layer):
                layer.merge()

    def unmerge(self):
        """Call ``unmerge()`` on every LoRA sub-layer."""
        for _, layer in self.model.named_sublayers():
            if self._is_lora_layer(layer):
                layer.unmerge()