diff --git a/llm/data.py b/llm/data.py
index 767bd5a88a29..c49a84cc0caa 100644
--- a/llm/data.py
+++ b/llm/data.py
@@ -45,16 +45,53 @@ def get_convert_example(model):
     if base_model_prefix == "chatglm":
         return convert_example_chatglm
     elif base_model_prefix in ["chatglm_v2", "llama", "bloom", "opt", "qwen", "mixtral", "gemma"]:
-        return convert_example_common
+        return convert_example_common_meta_text
     else:
         raise ValueError(
             f"Unknown base_model_prefix: {model.base_model_prefix}. Supported base_model_prefix list: chatglm, bloom, llama, qwen, mixtral, gemma"
         )


-class DataFormatError(ValueError):
+class DataFormatError( ValueError):
     pass


+def tokenize_example_meta_text(tokenizer, example, data_args):
+    if "text" in example:
+        text = example["text"]
+        source = text
+        words = text.split(' ')
+        source = ' '.join(words)
+        if len(words) > 1:
+            # remove the first word in the sentence
+            target = ' '.join(words[1:])
+        else:
+            target = ''
+    else:
+        raise DataFormatError(
+            f"Example format is wrong, please check: {example} or rewrite tokenize_example in data.py "
+        )
+    tokenized_source = tokenizer(
+        source,
+        max_length=data_args.src_length,
+        truncation=True,
+        truncation_side="left",
+        add_special_tokens=True,
+    )
+    tgt_max_length = data_args.max_length - len(tokenized_source["input_ids"])
+    tokenized_target = tokenizer(
+        target,
+        max_length=tgt_max_length,
+        truncation=True,
+        truncation_side="left",
+        add_special_tokens=False,
+    )
+    tokenized_target_input_ids = tokenized_target["input_ids"]
+    # Add eos_token_id at the end of sequence if the sentence is not truncated.
+    # Attention! In some cases(ex. ChatGLMv2), tokenized eos_token is not equal to eos_token_id.
+    if len(tokenized_target_input_ids) < tgt_max_length:
+        tokenized_target_input_ids += [tokenizer.eos_token_id]
+    return tokenized_source, tokenized_target_input_ids
+

 def tokenize_example(tokenizer, example, data_args):
     if "src" in example and "tgt" in example:
@@ -77,7 +114,7 @@ def tokenize_example(tokenizer, example, data_args):
         target,
         max_length=tgt_max_length,
         truncation=True,
-        truncation_side="right",
+        truncation_side="left",
         add_special_tokens=False,
     )

@@ -86,10 +123,31 @@ def tokenize_example(tokenizer, example, data_args):
     # Attention! In some cases(ex. ChatGLMv2), tokenized eos_token is not equal to eos_token_id.
     if len(tokenized_target_input_ids) < tgt_max_length:
         tokenized_target_input_ids += [tokenizer.eos_token_id]
-    return tokenized_source, tokenized_target_input_ids
-
+def convert_example_common_meta_text(example, tokenizer, data_args, is_test=True, intokens=False):
+    if tokenizer.chat_template is not None:
+        return convert_rounds_example_common(example, tokenizer, data_args, is_test, intokens)
+    tokenized_source, tokenized_target_input_ids = tokenize_example_meta_text(tokenizer, example, data_args)
+    if is_test:
+        return {
+            **tokenized_source,
+            "labels": tokenized_target_input_ids,
+        }
+    else:
+        input_ids = tokenized_source["input_ids"] + tokenized_target_input_ids
+        source_length = len(tokenized_source["input_ids"])
+        labels = [-100] * source_length + input_ids[source_length:]
+        # shift input_ids and labels
+        input_ids, labels = input_ids[:-1], labels[1:]
+        seq_length = len(input_ids)
+        features = {"input_ids": input_ids, "labels": labels}
+        if "position_ids" in tokenized_source:
+            features["position_ids"] = list(range(seq_length))
+        if intokens:
+            features["attention_mask"] = np.tri(seq_length, seq_length, dtype=bool)
+        return features
+

 def tokenize_rounds_example(tokenizer, example, data_args):
     """tokenize multi-rounds examples with chat_template.json
diff --git a/llm/finetune_generation.py b/llm/finetune_generation.py
index df7a22a0cb95..0bd2ac34476a 100644
--- a/llm/finetune_generation.py
+++ b/llm/finetune_generation.py
@@ -16,7 +16,7 @@
 import sys
 from dataclasses import dataclass, field
 from functools import partial
-
+import math
 import paddle
 from argument import (
     DataArgument,
@@ -34,8 +34,9 @@
     get_prefix_tuning_params,
     init_chat_template,
 )
-
-from paddlenlp.data import DataCollatorForSeq2Seq
+from glm.utils import GLMTrainer
+# from llama_attn_replace_paddle import replace_llama_attn
+from paddlenlp.data import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
 from paddlenlp.datasets import InTokensIterableDataset, InTokensMapDataset, load_dataset
 from paddlenlp.metrics import BLEU, Rouge1, Rouge2, RougeL
 from paddlenlp.peft import LoRAConfig, LoRAModel, PrefixConfig, PrefixModelForCausalLM
@@ -45,7 +46,6 @@
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
-    Llama3Tokenizer,
     LlamaTokenizer,
 )
 from paddlenlp.utils.log import logger
@@ -66,6 +66,14 @@ class FinetuneArguments(TrainingArguments):
         default=0,
         metadata={"help": "The steps use to control the learing rate."},
     )
+    sparse: bool = field(
+        default=True,
+        metadata={"help": "Whether to use sparse (shifted, grouped) attention during fine-tuning."},
+    )
+    trainable_params: str = field(
+        default="embed,norm",
+        metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."},
+    )


 def read_local_dataset(path):
@@ -115,6 +123,7 @@ def main():
             "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
         )

+
     # Load model
     if training_args.fp16_opt_level == "O2":
         if training_args.fp16:
@@ -193,10 +202,20 @@ def main():
     # Config for model using dropout, such as GPT.
     model_config.hidden_dropout_prob = model_args.hidden_dropout_prob
     model_config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob
-    model_config.sep_parallel_degree = training_args.sep_parallel_degree
     model_config.tensor_parallel_output = True
     model_config.seq_length = data_args.max_length
+
+    # set RoPE scaling factor
+    orig_rope_scaling_factor = model_config.rope_scaling_factor
+    orig_ctx_len = model_config.max_position_embeddings
+    if orig_ctx_len:
+        orig_ctx_len *= orig_rope_scaling_factor
+        if data_args.max_length > orig_ctx_len:
+            scaling_factor = float(math.ceil(data_args.max_length / orig_ctx_len))
+            model_config.rope_scaling_factor = scaling_factor
+            model_config.rope_scaling_type = "linear"
+
     if not training_args.autotuner_benchmark:
         model = AutoModelForCausalLM.from_pretrained(
             model_args.model_name_or_path,
@@ -233,7 +252,7 @@ def neft_post_hook(module, input, output):
     if tokenizer.chat_template is not None:
         data_args.eval_with_do_generation = False

-    if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, Llama3Tokenizer):
+    if isinstance(tokenizer, LlamaTokenizer):
         tokenizer.pad_token_id = tokenizer.eos_token_id

     if data_args.dataset_name_or_path is None:
@@ -297,11 +316,7 @@ def neft_post_hook(module, input, output):
         else:
             train_ds = None
         if training_args.do_eval:
-            dev_ds = load_dataset(
-                "json",
-                data_files=glob.glob(os.path.join(data_args.dataset_name_or_path, "dev", "*.json")),
-                lazy=data_args.lazy,
-            )[0]
+            dev_ds = load_dataset('/home/mg/proof-pile', splits=["test"], cache_dir='proof-pile')[0]
         else:
             dev_ds = None
         if quant_args.do_ptq or quant_args.do_gptq:
@@ -330,7 +345,7 @@ def neft_post_hook(module, input, output):
         else:
             train_ds = None
         if training_args.do_eval:
-            dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["dev"])[0]
+            dev_ds = load_dataset(data_args.dataset_name_or_path, splits=["test"])[0]
         else:
             dev_ds = None
         if quant_args.do_ptq or quant_args.do_gptq:
@@ -361,7 +376,6 @@ def neft_post_hook(module, input, output):

     if training_args.pipeline_parallel_degree > 1:
         from data import convert_example_common
-
         trans_func = partial(convert_example_common, tokenizer=tokenizer, data_args=data_args)
     else:
         trans_func = partial(get_convert_example(model), tokenizer=tokenizer, data_args=data_args)
@@ -388,6 +402,7 @@ def neft_post_hook(module, input, output):
                 "`zero_padding` conflicts with `eval_with_do_generation`. Setting zero_padding to False for the eval_dataset."
             )
             eval_intokens = False
+
     dev_ds = (
         dev_ds.map(partial(trans_func, is_test=data_args.eval_with_do_generation, intokens=eval_intokens))
         if dev_ds is not None
@@ -458,7 +473,7 @@ def neft_post_hook(module, input, output):
     if model_args.lora:
         if model_args.lora_path is None:
-            target_modules = get_lora_target_modules(model)
+            target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
             lora_config = LoRAConfig(
                 target_modules=target_modules,
                 r=model_args.lora_rank,
@@ -474,9 +489,27 @@ def neft_post_hook(module, input, output):
                 use_quick_lora=model_args.use_quick_lora,
             )
             model = LoRAModel(model, lora_config)
+
+            # model.mark_only_lora_as_trainable()
+            # model.print_trainable_parameters()
+            trainable_keywords = ["embed", "norm"]
+            # set embedding and norm trainable
+            for name, param in model.named_parameters():
+                make_trainable = False
+                for keyword in trainable_keywords:
+                    if keyword in name:
+                        make_trainable = True
+                        break
+                if make_trainable:
+                    param.stop_gradient = False
+            model.config.use_cache = False
+
+            model.recompute_enable()
+            for param in model.parameters():
+                if not param.stop_gradient and param.grad is None:
+                    param.clear_gradient()
         else:
             model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)
-
     model.print_trainable_parameters()

     def compute_metrics_do_generation(eval_preds):
@@ -532,13 +565,18 @@ def compute_metrics_do_generation(eval_preds):
         eval_dataset=dev_ds,
         tokenizer=tokenizer,
         compute_metrics=metrics,
+        # data_collator=DataCollatorForLanguageModeling(
+        #     tokenizer=tokenizer,
+        #     return_tensors="np",
+        #     mlm=False,
+        #     pad_to_multiple_of=data_args.max_length,
+        # ),
         data_collator=DataCollatorForSeq2Seq(
             tokenizer=tokenizer,
-            max_length=max_length,
-            padding=padding,
-            max_label_length=max_length,
             return_tensors="np",
-            pad_to_multiple_of=data_args.pad_to_multiple_of,
+            max_length=max_length,
+            padding=True,
+            pad_to_multiple_of=data_args.max_length,
         ),
         do_generation=data_args.eval_with_do_generation,
         callbacks=[InTokensIterDatasetCallback()] if isinstance(train_ds, InTokensIterableDataset) else None,
@@ -553,7 +591,8 @@ def compute_metrics_do_generation(eval_preds):
             checkpoint = training_args.resume_from_checkpoint
         elif last_checkpoint is not None:
             checkpoint = last_checkpoint
-        train_result = trainer.train(resume_from_checkpoint=checkpoint)
+        # train_result = trainer.train(resume_from_checkpoint=checkpoint)
+        train_result = trainer.train()
         if model_args.neftune:
             neft_post_hook_handle.remove()
         if training_args.benchmark:
diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py
index 81eee3f83539..cc48ad77aac3 100755
--- a/paddlenlp/transformers/llama/modeling.py
+++ b/paddlenlp/transformers/llama/modeling.py
@@ -62,10 +62,6 @@ def swiglu(x, y=None):
     init_name_mappings,
 )
 from paddlenlp.transformers.long_sequence_strategies import LongSequenceStrategies
-from paddlenlp.transformers.mc2_parallel_linear import (
-    MC2ColumnSeqParallelLinear,
-    MC2RowSeqParallelLinear,
-)
 from paddlenlp.transformers.model_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithCrossAttentions,
@@ -100,6 +96,13 @@ def swiglu(x, y=None):
 ]


+def is_mc2_valid():
+    current_device = get_env_device()
+    if current_device == "npu":
+        return True
+    return False
+
+
 def _get_interleave(n):
     def _get_interleave_power_of_2(n):
         start = 2 ** (-(2 ** -(math.log2(n) - 3)))
@@ -300,7 +303,6 @@ def scaled_dot_product_attention(
     # In sep mode, the attenion mask should be created in the runtime.
     if reshard_layer is not None:
         attention_mask = None
-
     # NOTE: we only call get_triangle_upper_mask under PP setup
     # FIXME ZHUI when we use pipeline parallel, the attention_mask can be None
     # we just make it triangle_upper_mask
@@ -311,17 +313,17 @@ def scaled_dot_product_attention(
             raise ValueError(
                 f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}"
             )
-
         attn_weights = attn_weights + attention_mask
         if not paddle.in_dynamic_mode():
             attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype)
         else:
             with paddle.amp.auto_cast(False):
                 attn_weights = F.softmax(attn_weights, axis=-1, dtype="float32").astype(query_states.dtype)
-
         attn_output = paddle.matmul(attn_weights, value_states)
         attn_output = attn_output.transpose([0, 2, 1, 3])
-
+        # attn_output = attn_output.reshape((bsz, q_len, self.num_heads, self.head_dim))
+        # shift back
+        # attn_output[:, :, num_heads//2:] = attn_output[:, :, num_heads//2:].roll(shifts=q_len//8, axis=1)
         if reshard_layer is not None:
             attn_output = reshard_layer(
                 attn_output,
@@ -330,7 +332,6 @@ def scaled_dot_product_attention(
             )
             q_len = q_len // config.sep_parallel_degree
             num_heads = num_heads * config.sep_parallel_degree
-
         if sequence_parallel:
             attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads])
         else:
@@ -562,6 +563,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     return q_embed, k_embed


+
 class LlamaMLP(nn.Layer):
     def __init__(self, config):
         super().__init__()
@@ -571,7 +573,12 @@ def __init__(self, config):
         self.fuse_attention_ffn = config.fuse_attention_ffn

         if config.sequence_parallel:
-            if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
+            if is_mc2_valid() and int(os.getenv("FLAGS_NPU_MC2", 0)):
+                from paddlenlp.transformers.mc2_seqence_parallel_linear import (
+                    MC2ColumnSeqParallelLinear,
+                    MC2RowSeqParallelLinear,
+                )
+
                 ColumnParallelLinear = MC2ColumnSeqParallelLinear
                 RowParallelLinear = MC2RowSeqParallelLinear
             else:
@@ -689,7 +696,12 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             self.use_fused_rope = False

         if config.sequence_parallel:
-            if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None:
+            if is_mc2_valid() and int(os.getenv("FLAGS_NPU_MC2", 0)):
+                from paddlenlp.transformers.mc2_seqence_parallel_linear import (
+                    MC2ColumnSeqParallelLinear,
+                    MC2RowSeqParallelLinear,
+                )
+
                 ColumnParallelLinear = MC2ColumnSeqParallelLinear
                 RowParallelLinear = MC2RowSeqParallelLinear
             else:
@@ -800,28 +812,24 @@ def _init_rope(self):
             self.rotary_emb = LlamaRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
-                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "linear":
             self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
-                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "ntk":
             self.rotary_emb = LlamaNTKScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
-                base=self.config.rope_theta,
             )
         elif self.config.rope_scaling_type == "dynamic_ntk":
             self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 scaling_factor=self.config.rope_scaling_factor,
-                base=self.config.rope_theta,
             )
         else:
             raise ValueError(f"Unknown RoPE scaling type {self.config.rope_scaling_type}")
@@ -838,20 +846,34 @@ def forward(
     ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
         # [bs, seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism)
-
+        def create_attention_mask(input_shape, dtype):
+            """
+            Prepare the decoder attention mask where the diagonal and its lower side are 0,
+            and the upper side is -inf.
+
+            Args:
+                input_shape (tuple): Shape of the input tensor, typically (batch_size, sequence_length).
+                dtype (paddle.dtype): Data type of the mask, usually 'float32' or 'float16'.
+
+            Returns:
+                paddle.Tensor: Attention mask with shape [batch_size, 1, sequence_length, sequence_length].
+            """
+            # Assuming input_shape = (batch_size, sequence_length)
+            batch_size, seq_length = input_shape
+            # Create a lower triangular matrix including the diagonal, where the diagonal and below are 0 (allowed positions)
+            lower_triangular = paddle.tril(paddle.zeros((seq_length, seq_length), dtype=dtype))
+            # Create the opposite mask for positions above the diagonal, set these to -inf
+            upper_triangular = paddle.triu(paddle.ones((seq_length, seq_length), dtype=dtype), diagonal=1)
+            # paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
+            upper_triangular = paddle.where(upper_triangular == 1, paddle.finfo(dtype).min, 0.0).astype(dtype)
+            # Combine the two masks
+            attention_mask = lower_triangular + upper_triangular
+            # Add batch and head dimensions
+            attention_mask = attention_mask[None, None, :, :]  # Expanding the dimensions to [1, 1, seq_length, seq_length]
+            attention_mask = paddle.tile(attention_mask, [batch_size, 1, 1, 1])  # Tiling for the batch size
+            return attention_mask
+
         if self.fuse_attention_qkv:
             mix_layer = self.qkv_proj(hidden_states)
-            # NOTE for GQA attention fusion (compatible with MHA and MQA):
-            # The weight for qkv_proj is in shape like [hidden_size, hidden_size + 2 * num_kv_heads * head_dim].
-            # After the projection, the mix_layer is in shape like [b, s, hidden_size + 2 * num_kv_heads * head_dim].
-            # Reshape the mix_layer into a shape like [b, s, num_kv_heads, (num_groups + 2) * head_dim],
-            # where num_groups = num_q_heads // num_kv_heads.
-            # Split the mix_layer on the last axis into three sections [num_groups * head_dim, head_dim, head_dim]
-            # to represent the q, k and v respectively.
-            # The q is in the shape like [b, s, num_kv_heads, num_groups * head_dim].
-            # The k and v are in the shape like [b, s, num_kv_heads, head_dim].
-            # Under MHA, the q is ready for the following calculation since num_kv_heads == num_q_heads,
-            # But for the GQA or MQA, q should be reshaped into [b, s, num_q_heads, head_dim].
             if self.reshard_layer is not None:
                 if self.sequence_parallel:
                     assert self.seq_length % self.config.sep_parallel_degree == 0
@@ -894,7 +916,6 @@ def forward(
             query_states = self.q_proj(hidden_states)
             key_states = self.k_proj(hidden_states)
             value_states = self.v_proj(hidden_states)
-
             if self.reshard_layer is not None:
                 if self.sequence_parallel:
                     assert self.seq_length % self.config.sep_parallel_degree == 0
@@ -948,12 +969,11 @@ def forward(
                 query_states = query_states.reshape(shape=target_query_shape)
                 key_states = key_states.reshape(shape=target_key_value_shape)
                 value_states = value_states.reshape(shape=target_key_value_shape)
-
+        bsz, q_len, _ = hidden_states.shape
+        group_size = int(q_len * 1 / 4)
         kv_seq_len = key_states.shape[-3]
-
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-3]
-
         if self.config.rope:
             if self.reshard_layer is not None:
                 batch_size, seq_length, _, _ = query_states.shape
@@ -1019,6 +1039,7 @@ def forward(
                 value_states = paddle.concat([past_key_value[1], value_states], axis=1)
         past_key_value = (key_states, value_states) if use_cache else None
+
         if self.kv_indices is not None:
             key_states = paddle.index_select(key_states, self.kv_indices, axis=2)
             value_states = paddle.index_select(value_states, self.kv_indices, axis=2)
@@ -1027,10 +1048,27 @@ def forward(
         # repeat k/v heads if n_kv_heads < n_heads
         # paddle version > 2.6 or develop support flash-attn with gqa/mqa
         paddle_version = float(paddle.__version__[:3])
-        if not self.config.use_flash_attention or ((paddle_version != 0.0) and (paddle_version <= 2.6)):
+        if (paddle_version != 0.0) and (paddle_version <= 2.6):
             key_states = repeat_kv(key_states, self.num_key_value_groups)
             value_states = repeat_kv(value_states, self.num_key_value_groups)
-
+
+        def shift(qkv, bsz, q_len, group_size, num_heads, head_dim):
+            qkv[:, :, num_heads // 2:] = paddle.roll(qkv[:, :, num_heads // 2:], shifts=-group_size // 2, axis=1)
+            qkv = paddle.reshape(qkv, (bsz * (q_len // group_size), group_size, num_heads, head_dim))
+            return qkv
+
+        # default
+        is_shift = True
+        if is_shift:
+            query_states = shift(query_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
+            key_states = shift(key_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
+            value_states = shift(value_states, bsz, q_len, group_size, self.num_heads, self.head_dim)
+            _, group_size, _, _ = query_states.shape
+            # print(attention_mask)
+            num_group = q_len // group_size
+            attention_mask = create_attention_mask((bsz, q_len), dtype="float16")
+            attention_mask = attention_mask[:, :, :group_size, :group_size]
+            attention_mask = paddle.tile(attention_mask, repeat_times=(num_group, 1, 1, 1))
         has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient)
         if (
             self.enable_recompute
@@ -1067,11 +1105,19 @@ def forward(
             attn_output, attn_weights = outputs
         else:
             attn_output = outputs
-
+        # shift back
+        is_shift = True
+        if is_shift:
+            attn_output = paddle.reshape(
+                attn_output, (bsz * num_group, self.num_heads, group_size, self.head_dim)
+            )
+            attn_output = paddle.transpose(attn_output, [0, 2, 1, 3])
+            attn_output = paddle.reshape(attn_output, (bsz, q_len, self.num_heads, self.head_dim))
+            attn_output[:, :, self.num_heads // 2:] = paddle.roll(
+                attn_output[:, :, self.num_heads // 2:], shifts=group_size // 2, axis=1
+            )
+            attn_output = paddle.reshape(attn_output, (bsz, q_len, self.hidden_size))
         # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim]
         # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism.
         attn_output = self.o_proj(attn_output)
-
         if not output_attentions:
             attn_weights = None
@@ -1293,56 +1339,6 @@ def get_tensor_parallel_split_mappings(num_layers):

         return mappings

-    @classmethod
-    def _get_fuse_or_split_param_mappings(cls, config: LlamaConfig, is_fuse=False):
-        # return parameter fuse utils
-        from paddlenlp.transformers.conversion_utils import split_or_fuse_func
-
-        fn = split_or_fuse_func(is_fuse=is_fuse)
-
-        # last key is fused key, other keys are to be fused.
-        fuse_qkv_keys = (
-            "layers.0.self_attn.q_proj.weight",
-            "layers.0.self_attn.k_proj.weight",
-            "layers.0.self_attn.v_proj.weight",
-            "layers.0.self_attn.qkv_proj.weight",
-        )
-
-        fuse_gate_up_keys = (
-            "layers.0.mlp.gate_proj.weight",
-            "layers.0.mlp.up_proj.weight",
-            "layers.0.mlp.gate_up_fused_proj.weight",
-        )
-        num_heads = config.num_attention_heads
-        num_key_value_heads = getattr(config, "num_key_value_heads", num_heads)
-        fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False)
-        fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False)
-
-        final_actions = {}
-        if is_fuse:
-            if fuse_attention_qkv:
-                for i in range(config.num_hidden_layers):
-                    keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_qkv_keys])
-                    final_actions[keys] = partial(
-                        fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
-                    )
-            if fuse_attention_ffn:
-                for i in range(config.num_hidden_layers):
-                    keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys])
-                    final_actions[keys] = fn
-        else:
-            if not fuse_attention_qkv:
-                for i in range(config.num_hidden_layers):
-                    keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_qkv_keys])
-                    final_actions[keys] = partial(
-                        fn, split_nums=3, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
-                    )
-            if not fuse_attention_ffn:
-                for i in range(config.num_hidden_layers):
-                    keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys])
-                    final_actions[keys] = partial(fn, split_nums=2)
-        return final_actions
-
     def _init_weights(self, layer):
         """Initialization hook"""
         if self.config.tensor_parallel_degree > 1:
@@ -1589,7 +1585,6 @@ def forward(
         if position_ids is None:
             position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))
-
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype
         )  # [bs, 1, seq_len, seq_len]
@@ -1601,6 +1596,7 @@
         else:
             attention_mask = attention_mask.astype("bool")
         hidden_states = inputs_embeds
+
         # decoder layers
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
@@ -1706,7 +1702,6 @@ def forward(self, prediction_scores, masked_lm_labels):
             # skip ignore_index which loss == 0
            masked_lm_loss = masked_lm_loss[masked_lm_loss > 0]
             loss = paddle.mean(masked_lm_loss)
-
         return loss
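
Reviewer note (not part of the patch): the RoPE change in finetune_generation.py follows the usual linear position-interpolation recipe. When the requested data_args.max_length exceeds the checkpoint's effective context window, the factor is the smallest integer stretch that covers it. A minimal sketch of that arithmetic, with illustrative numbers (a 4096-token base model fine-tuned to 16384 tokens; these lengths are assumptions, not values from the patch):

# Illustrative only; mirrors the rope_scaling_factor logic added in finetune_generation.py.
import math

orig_ctx_len = 4096             # pretrained max_position_embeddings (assumed)
orig_rope_scaling_factor = 1.0  # factor already baked into the checkpoint (assumed)
max_length = 16384              # target fine-tuning sequence length (assumed)

effective_ctx = orig_ctx_len * orig_rope_scaling_factor
if max_length > effective_ctx:
    scaling_factor = float(math.ceil(max_length / effective_ctx))  # ceil(16384 / 4096) = 4.0
    # In the patch this value is written back as:
    #   model_config.rope_scaling_factor = scaling_factor
    #   model_config.rope_scaling_type = "linear"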
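
Reviewer note (not part of the patch): the modeling.py changes implement a shifted, grouped attention pattern (in the style of LongLoRA's S2-Attn): half of the heads are rolled back along the sequence by half a group before attention, each group of tokens is folded into the batch dimension, and the roll is undone afterwards. The standalone round-trip check below uses made-up sizes and the [bsz, q_len, num_heads, head_dim] layout from the patch; it is a sketch, not code taken from the PR.

# Standalone sketch: verify that shift + group-fold followed by the inverse
# ungroup + roll restores the original tensor.
import paddle

bsz, q_len, num_heads, head_dim = 2, 64, 8, 16   # illustrative sizes
group_size = q_len // 4                          # the patch uses group_size = int(q_len * 1 / 4)
num_group = q_len // group_size

x = paddle.rand([bsz, q_len, num_heads, head_dim])

# shift: roll the second half of the heads back by half a group,
# then fold each token group into the batch dimension.
shifted = x.clone()
shifted[:, :, num_heads // 2:] = paddle.roll(shifted[:, :, num_heads // 2:], shifts=-group_size // 2, axis=1)
grouped = paddle.reshape(shifted, [bsz * num_group, group_size, num_heads, head_dim])

# attention would run here on the grouped tensor, with the per-group causal mask
# that the patch builds via create_attention_mask and tiles num_group times.

# shift back: undo the grouping, then roll the same heads forward again.
restored = paddle.reshape(grouped, [bsz, q_len, num_heads, head_dim])
restored[:, :, num_heads // 2:] = paddle.roll(restored[:, :, num_heads // 2:], shifts=group_size // 2, axis=1)

assert paddle.allclose(restored, x).item()

The shift-back branch in the patch additionally passes through a [bsz * num_group, num_heads, group_size, head_dim] layout and a transpose before merging heads into hidden_size; the roll direction and magnitude match the sketch above.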