diff --git a/paddlenlp/transformers/convert/doc/README.md b/paddlenlp/transformers/convert/doc/README.md new file mode 100644 index 000000000000..8623ddff561c --- /dev/null +++ b/paddlenlp/transformers/convert/doc/README.md @@ -0,0 +1,268 @@ + + +## 使用方法: + +### 1构建modular__**.py文件 + +#### 1.1分析基础模型 + +​ 在开始构建 `modular_xxx.py`文件之前,首先需要深入分析要基于的**基础模型**。这个基础模型通常是 paddle/ Transformers 库中已有的成熟模型。 + +##### 1.1.1 选择合适的基础模型 + +**选择标准:** + +- **架构相似性**:新模型与基础模型的架构应尽可能相似 +- **任务类型**:基础模型应支持相同的任务(如文本生成、分类等) +- **代码质量**:选择代码结构清晰、文档完善的模型 + +**常见基础模型选择:** + +``` +# 基于BERT架构的模型 +基础模型:BertModel, RobertaModel, DebertaModel + +# 基于GPT架构的模型 +基础模型:GPT2Model, LlamaModel, GPTNeoXModel + +# 基于Encoder-Decoder架构的模型 +基础模型:T5Model, BartModel, PegasusModel +``` + +##### 1.1.2 分析基础模型的关键组件 + +对于选定的基础模型,需要分析其核心组件: + +###### **1. 配置文件 (`configuration_xxx.py`)** + +``` +# 分析配置参数 +# 关注:hidden_size, num_attention_heads, num_hidden_layers, +# vocab_size, max_position_embeddings 等关键参数 +``` + +###### **2. 模型架构 (`modeling_xxx.py`)** + +``` +# 分析模型类结构 +import inspect +from transformers import BertModel + +# 查看类的方法和属性 +print(inspect.getmembers(BertModel, predicate=inspect.ismethod)) +# 重点关注:__init__, forward, 以及其他关键方法 +``` + +##### 1.1.3 识别需要修改的部分 + +基于分析结果,确定哪些部分需要自定义: + +| 组件 | 是否需要修改 | 修改原因 | +| :--------------- | :----------- | :------------------------- | +| **配置参数** | ✅ 通常需要 | 调整模型尺寸、注意力头数等 | +| **前向传播逻辑** | ✅ 通常需要 | 适配新的架构变化 | +| **注意力机制** | ⚠️ 可能需要 | 如果使用不同的注意力机制 | +| **位置编码** | ⚠️ 可能需要 | 如果使用不同的位置编码方案 | +| **输出头** | ✅ 通常需要 | 适配不同的任务需求 | +| **初始化方法** | ⚠️ 可能需要 | 如果使用不同的初始化策略 | + +#### 1.2编写modular文件结构 + +​ 在完成基础模型分析后,您需要创建一个结构清晰、符合规范的 `modular_xxx.py`文件。这个文件是代码生成器的模板,其结构直接决定了最终输出的 `modeling_xxx.py`文件的质量。 + +##### 1.2.1 文件基本结构 + +一个标准的 `modular_xxx.py`文件应包含以下部分,按顺序排列: + +``` +# coding=utf-8 +# 版权声明 (可选) +""" 新模型的简要文档字符串 (可选) """ + +# 1. 导入部分 +from typing import Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss + +# 从基础模型导入必要的组件 +from transformers.models.llama.modeling_llama import ( + LlamaConfig, + LlamaModel, + LlamaForCausalLM, + LlamaDecoderLayer, + # ... 其他需要继承或引用的组件 +) +from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +# 3. 注意力机制 (如果需要自定义) +class MyNewAttention(nn.Module): + """自定义注意力机制""" + def __init__(self, config: MyNewModelConfig): + super().__init__() + # 实现自定义注意力逻辑 + pass + + def forward(self, hidden_states, attention_mask=None): + # 实现前向传播 + pass + + +# 4. 解码器层 (如果需要修改层结构) +class MyNewDecoderLayer(LlamaDecoderLayer): + """ + 自定义解码器层,继承自LlamaDecoderLayer + 重写需要修改的方法 + """ + def __init__(self, config: MyNewModelConfig): + super().__init__(config) + # 替换或修改注意力机制 + if config.use_custom_attention: + self.self_attn = MyNewAttention(config) + + def forward(self, hidden_states, attention_mask=None): + # 可以完全重写或部分修改父类逻辑 + if self.config.use_custom_attention: + # 自定义逻辑 + return self._custom_forward(hidden_states, attention_mask) + else: + # 回退到父类逻辑 + return super().forward(hidden_states, attention_mask) + + def _custom_forward(self, hidden_states, attention_mask): + """自定义前向传播实现""" + pass + + +# 5. 
主模型类 +class MyNewModel(LlamaModel): + """ + 我的新模型主类,继承自LlamaModel + 通常需要重写 __init__ 和 forward 方法 + """ + def __init__(self, config: MyNewModelConfig): + super().__init__(config) + # 替换解码器层 + self.layers = nn.ModuleList([ + MyNewDecoderLayer(config) for _ in range(config.num_hidden_layers) + ]) + # 其他自定义初始化 + self.custom_layer = nn.Linear(config.hidden_size, config.custom_param) + + def forward(self, input_ids, attention_mask=None): + # 调用父类获取基础输出 + super().forward(input_ids, attention_mask) + + # 添加自定义处理 + hidden_states = outputs[0] + custom_output = self.custom_layer(hidden_states) + + # 返回修改后的输出 + return (custom_output,) + outputs[1:] + + +# 6. 任务特定模型 (如用于因果语言建模) +class MyNewForCausalLM(LlamaForCausalLM): + """ + 用于因果语言建模的我的新模型 + """ + def __init__(self, config: MyNewModelConfig): + super().__init__(config) + # 替换主模型 + self.model = MyNewModel(config) + + def forward(self, input_ids, attention_mask=None, labels=None): + # 可以完全重写或扩展父类逻辑 + outputs = self.model(input_ids, attention_mask=attention_mask) + + # 计算损失等 + loss = None + if labels is not None: + # 计算损失逻辑 + pass + + return {"loss": loss, "logits": outputs[0]} + + +# 8. 更新 __all__ 列表,声明哪些类应该被导出 +__all__ = [ + "MyNewModelConfig", + "MyNewModel", + "MyNewForCausalLM", + "MyNewDecoderLayer", +] +``` + +##### 1.2.2 关键编写原则 + +**清晰的继承关系** + +``` +# ✅ 正确:明确继承关系 +class MyNewModel(LlamaModel): + pass + +# ❌ 避免:直接继承过于通用的基类 +class MyNewModel(PreTrainedModel): + pass # 这会导致需要实现大量抽象方法 +``` + +**最小化重写** + +``` +# ✅ 正确:只重写需要修改的方法 +class MyNewDecoderLayer(LlamaDecoderLayer): + def __init__(self, config): + super().__init__(config) # 先调用父类初始化 + # 只修改需要定制的部分 + if config.use_custom_attention: + self.self_attn = CustomAttention(config) + +# ❌ 避免:完全重写整个类,除非必要 +``` + +**保持接口一致性**: + +``` +def forward(self, input_ids, attention_mask=None, **kwargs): + # 处理自定义逻辑 + result = custom_processing(input_ids) + # 调用父类实现剩余逻辑 + super().forward(result, attention_mask, **kwargs) +``` + +**充分利用现有组件**: + +``` +# ✅ 正确:复用基础模型的组件 +from transformers.models.llama.modeling_llama import ( + LlamaRMSNorm, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, +) +``` + +### 2.**执行转换命令** + +通过一个用户主脚本main来驱动整个流程。其标准使用方式如下: + +``` +#自动查找各个模型文件下的modular__**.py模块化构建代码,执行转换生成modeling__***.py文件 +python main.py +``` + +### **自动化处理流水线** + + + +### 最终输出 + +最终,在模型对应的目录下(如 `src/transformers/models/qwen2/`)会生成目标文件: + +- **`modeling_qwen2.py`**:**这是唯一的输出文件,也是最终成果。** 它包含了:**模型架构**(如 `Qwen2Model`, `Qwen2ForCausalLM`)**内联的配置类**(如 `Qwen2Config`)**所有相关的函数、工具类和常量****正确的导入语句**(只导入标准库或 transformers 的通用组件)**文件顶部的警告注释**:明确告知开发者此文件为自动生成,不可手动编辑。 \ No newline at end of file diff --git a/paddlenlp/transformers/convert/doc/llama_analysis.md b/paddlenlp/transformers/convert/doc/llama_analysis.md new file mode 100644 index 000000000000..603e95d422e3 --- /dev/null +++ b/paddlenlp/transformers/convert/doc/llama_analysis.md @@ -0,0 +1,60 @@ +以下是针对Llama模型架构中各组件的功能解析,按模块分类说明其核心作用: + +------ + +### **核心工具函数** + +| 函数/常量 | 作用 | 与Qwen2的差异 | +| :----------------------------: | :-----------------------------------: | :--------------------: | +| `swiglu` | 实现SwiGLU激活函数:`x * silu(gate)` | Qwen2使用标准GLU | +| `rms_norm_fused` | 启用融合的RMSNorm计算(CUDA优化) | 实现相同但配置参数不同 | +| `__all__` | 定义模块的公开接口 | - | +| `_get_interleave` | 生成交错注意力头索引(用于长序列) | Llama特有 | +| `build_alibi_tensor` | 构建ALiBi位置偏置张量(相对位置编码) | Qwen2未使用 | +| `get_triangle_upper_mask` | 生成因果上三角掩码 | 实现逻辑相同 | +| `assign_kv_heads` | 分配KV头的索引(支持GQA/MQA) | - | +| `parallel_matmul` | 并行矩阵乘法(张量并行) | - | +| `scaled_dot_product_attention` | 核心注意力计算 | Llama支持更多掩码类型 | +| `_make_causal_mask` 
| 动态生成因果掩码(考虑padding) | Qwen2更简化 | + +### **归一化层** + +| 类/函数 | 作用 | 差异点 | +| :------------: | :-------------------: | :-------------: | +| `LlamaRMSNorm` | 带融合优化的RMS归一化 | 与Qwen2实现相同 | + +### **位置编码(核心差异)** + +| 类 | 作用 | 特性 | +| :-------------------------------------: | :------------------------: | :---------------: | +| `LlamaRotaryEmbedding` | 基础RoPE实现 | - | +| `LlamaLinearScalingRotaryEmbedding` | 线性缩放RoPE(扩展上下文) | Qwen2无此变体 | +| `LlamaNTKScalingRotaryEmbedding` | NTK-aware缩放RoPE | 动态调整高频/低频 | +| `LlamaDynamicNTKScalingRotaryEmbedding` | 动态NTK缩放(训练自适应) | Llama特有 | +| `Llama3RotaryEmbedding` | Llama3专用RoPE | 改进的旋转策略 | + +### **前馈网络** + +| 类 | 作用 | 差异 | +| :--------: | :-----------------: | :------------: | +| `LlamaMLP` | 使用SwiGLU的门控FFN | Qwen2用普通GLU | + +### **注意力机制** + +| 类 | 核心改进 | 说明 | +| :-----------------: | :----------------------------------------: | :------------: | +| `LlamaAttention` | - 多版本RoPE支持 - ALiBi融合 - 动态NTK缩放 | 比Qwen2更复杂 | +| `LlamaDecoderLayer` | 深度优化层实现 | 支持梯度检查点 | + +### **预训练基础** + +| 类 | 关键功能 | 扩展性 | +| :--------------------: | :--------------------------: | :-----------: | +| `LlamaPretrainedModel` | - 多设备加载 - FLOPs计算工具 | 比Qwen2更完善 | + +### **任务模块** + +| 类 | 用途 | 特色 | +| :----------------: | :------------: | :---------------: | +| `LlamaForCausalLM` | 语言建模 | 支持静态图导出 | +| `ConcatMaskedLoss` | 多任务损失合并 | 处理padding的梯度 | \ No newline at end of file diff --git a/paddlenlp/transformers/convert/doc/process.png b/paddlenlp/transformers/convert/doc/process.png new file mode 100644 index 000000000000..317248b7d040 Binary files /dev/null and b/paddlenlp/transformers/convert/doc/process.png differ diff --git a/paddlenlp/transformers/convert/doc/qwen2_analysis.md b/paddlenlp/transformers/convert/doc/qwen2_analysis.md new file mode 100644 index 000000000000..4ac53f797ee7 --- /dev/null +++ b/paddlenlp/transformers/convert/doc/qwen2_analysis.md @@ -0,0 +1,87 @@ +![img](https://i-blog.csdnimg.cn/blog_migrate/4757ad5c98d4344547227fc52684bac1.png) + +是Qwen2模型中各组件和函数的详细作用说明,按模块分类整理: + +### **核心工具函数** + +| 函数/常量 | 作用 | +| :----------------------------: | :------------------------------------------------------: | +| `__all__` | 定义模块的公开接口,控制`from module import *`时的可见性 | +| `get_triangle_upper_mask` | 生成上三角因果注意力掩码(防止未来信息泄露) | +| `assign_kv_heads` | 分配Key/Value头的索引(用于GQA/MQA) | +| `parallel_matmul` | 并行矩阵乘法(支持张量并行) | +| `scaled_dot_product_attention` | 实现缩放点积注意力核心计算 | +| `masked_fill` | 按掩码填充张量(如将padding位置设为负无穷) | +| `is_casual_mask` | 判断是否为因果注意力掩码 | +| `_make_causal_mask` | 创建因果注意力掩码(考虑padding) | +| `_expand_2d_mask` | 将2D掩码扩展为4D(适配多头注意力) | +| `repeat_kv` | 重复Key/Value头(用于GQA/MQA) | + +### **归一化层** + +| 类/函数 | 作用 | +| :------------: | :------------------------------: | +| `Qwen2RMSNorm` | **RMS归一化层**(替代LayerNorm) | + +### **位置编码** + +| 类/函数 | 作用 | +| :----------------------: | :--------------------------------: | +| `Qwen2RotaryEmbedding` | **旋转位置编码(RoPE)** | +| - `rotate_half` | 旋转向量的后半部分(RoPE核心操作) | +| - `apply_rotary_pos_emb` | 将旋转位置编码应用到注意力分数 | + +### **前馈网络** + +| 类/函数 | 作用 | +| :--------: | :---------------------------: | +| `Qwen2MLP` | **门控线性单元(GLU)前馈网络** | + +### **注意力机制** + +| 类/函数 | 作用 | +| :-----------------: | :------------------------------------------------: | +| `Qwen2Attention` | **多头注意力机制** | +| - `__init__` | 初始化Q/K/V投影层、输出层和RoPE | +| - `forward` | 处理输入序列,计算注意力分数并聚合值向量 | +| `Qwen2DecoderLayer` | **Transformer解码层** | +| - `__init__` | 组合自注意力层和前馈网络 | +| - `forward` | 执行:`LN -> Attention -> Add -> LN -> MLP -> Add` | + +### **预训练基础** + +| 类/函数 | 作用 | +| 
:--------------------: | :--------------------------------: | +| `Qwen2PretrainedModel` | **预训练模型基类** | +| - `config_class` | 关联的配置类(Qwen2Config) | +| - `_get_name_mappings` | 定义参数名称映射(用于加载检查点) | +| - `_init_weights` | 参数初始化策略 | +| - `_get_model_flops` | 计算模型FLOPs | + +### **主干模型** + +| 类/函数 | 作用 | +| :---------------------------------: | :-----------------------: | +| `Qwen2Model` | **模型主干架构** | +| - `_prepare_decoder_attention_mask` | 生成解码器掩码 | +| - `forward` | 执行完整的Transformer堆栈 | +| `Qwen2ForCausalLM` | **因果语言模型** | +| - `prepare_inputs_for_generation` | 处理生成时的输入格式 | +| - `forward` | 计算语言建模损失 | + +### **任务特定头部** + +| 类 | 作用 | +| :------------------------------: | :----------------------: | +| `Qwen2LMHead` | 语言模型头部(词表投影) | +| `Qwen2ForSequenceClassification` | 序列分类任务适配 | +| `Qwen2ForTokenClassification` | 标记分类任务适配 | +| `Qwen2SentenceEmbedding` | 句子向量提取 | + +### **训练相关** + +| 类/函数 | 作用 | +| :-------------------------: | :------------------------: | +| `Qwen2PretrainingCriterion` | 预训练损失计算 | +| `recompute_training_full` | 激活重计算策略 | +| `create_custom_forward` | 为梯度检查点创建自定义前向 | diff --git "a/paddlenlp/transformers/convert/doc/\351\241\271\347\233\256\346\212\245\345\221\212\357\274\232\351\243\236\346\241\250PaddleNLP-\345\211\215\346\262\277\346\250\241\345\236\213\346\250\241\345\235\227\345\214\226\350\256\276\350\256\241.md" "b/paddlenlp/transformers/convert/doc/\351\241\271\347\233\256\346\212\245\345\221\212\357\274\232\351\243\236\346\241\250PaddleNLP-\345\211\215\346\262\277\346\250\241\345\236\213\346\250\241\345\235\227\345\214\226\350\256\276\350\256\241.md" new file mode 100644 index 000000000000..97f0e1382d41 --- /dev/null +++ "b/paddlenlp/transformers/convert/doc/\351\241\271\347\233\256\346\212\245\345\221\212\357\274\232\351\243\236\346\241\250PaddleNLP-\345\211\215\346\262\277\346\250\241\345\236\213\346\250\241\345\235\227\345\214\226\350\256\276\350\256\241.md" @@ -0,0 +1,150 @@ +## 项目报告:飞桨PaddleNLP-前沿模型模块化设计 + +### 项目信息 + +* 项目名称:飞桨PaddleNLP-前沿模型模块化设计 +* 方案描述: + * 文件解析与并行处理:自动查找并并行处理 modular_*.py 文件。 + * 转换流程 (run_converter):封装了单个模块化文件到独立模型文件的完整转换逻辑,包括动态确定模型名称和输出路径、收集并展开导入、重写子类并生成中间文件、移除导入并重写、全局重命名以及最终文件写入和临时文件清理。 + * 辅助工具 (until 目录):包含 collect_import_modeling.py、rename_identifiers.py 和 rewrite_child_classes.py 等脚本,用于支持转换流程中的导入处理、标识符重命名和子类重写。 + * modular_qwen2.py:作为 convert 工具的输入,该文件继承了 Llama 模型的许多组件,并进行了 Qwen2 特有的修改和优化,如 Qwen2RMSNorm、Qwen2RotaryEmbedding、Qwen2MLP 和 Qwen2Attention。它还包含了大量与 PaddlePaddle 分布式训练(如张量并行、序列并行、重计算)和性能优化(如 Flash Attention、融合操作)相关的导入和逻辑。 + * configuration.py:定义了 Qwen2Config 类,存储 Qwen2 模型的所有配置参数,包括词汇表大小、隐藏层维度、注意力头数量、激活函数、最大位置嵌入长度等,并支持滑动窗口注意力等高级特性。 + * modeling__qwen2.py:是经过 convert 工具处理后生成的最终模型文件,包含了 Qwen2 模型的完整实现,包括核心模型类、组件实现、辅助函数以及分布式训练和性能优化相关的代码。 + +* 时间规划: + + * 需求分析与方案设计(7月1日-7月15日) + + 详细调研和对比分析至少两种主流LLM(如Llama系列、Qwen系列)的架构细节、实现差异 和共通之处。 设计LLM的模块化组件体系,明确各模块的功能边界、输入输出接口、可配置参数以及模块 间的依赖关系。制定基于libcst的源码分析策略,确定需要识别的代码模式和转换规则。 初步规划模型并行能力的模块化方案和自动化集成思路 + + * 模块化核心功能开发(7月15日-8月15日) + + 利用libcst等工具,开发源码分析和转换工具的原型,能够解析现有模型代码并提取关键 结构信息,或根据配置生成初步的模块化代码片段。 搭建单元测试和集成测试框架,确保各模块和工具的正确性。 + + * Qwen2模型自动化构造(8月15日-9月7日) + + 以Llama模型结构为蓝本,利用阶段二开发的工具和模块库,自动化生成Qwen2模型的完整结构代码 + + * 精度对齐及模型并行能力验证(9月7日-9月21日) + + 对生成的Qwen2模型进行细致的功能测试和精度验证,通过在测试数据上与手动实现的 Qwen2模型进行效果对比,确保数值精度对齐。 + + * 文档撰写与项目总结(9月21日-9月30日) + + 编写详细的设计文档、用户手册、上手教程以及最佳实践案例。 整理项目代码,按照PaddleNLP社区规范准备Pull Request,将核心成果贡献给社区。完成项目总结报告。 + +### 项目进度 + +* 已完成工作: + + 对照项目申请书的方案,我完成了预定任务,主要工作成果如下: + + * 核心转换流水线 + - 实现了`run_converter()`函数,执行完整的三阶段转换流程 + - 支持并行处理多个模型文件转换 + - 自动生成带警告标识的输出文件 + * 导入扩展系统 + - 
实现`expand_modeling_imports()`函数进行递归依赖解析 + - 支持模块化导入的自动展开和集成 + * 标识符重命名系统 + - 实现`rename_identifiers()`函数进行智能重命名 + - 支持大小写保持的标识符转换 + * 类重写系统 + - 实现完整的类重写工具,支持继承关系扁平化 + - 集成依赖分析和合并引擎 + * 以Llama为蓝本的Qwen2模型自动化生成 + * 实现了基于llama的modular__qwen2.py + * 通过转换系统自动生成完整的Qwen2模型实现 + * 对生成的模型进行了精度验证和并行能力验证 + * 与原代码进行精度对齐 + * 进行了并行能力验证 + * 项目文档 + * 转换工具的使用方法 + * 模块化构建的流程 + * 精度及并行能力验证报告 + +* 遇到的问题以及解决方案 + + * Import导入项收集的复杂性问题 + + 存在的难点: + + - (1)重复导入识别:同一个模块可能通过不同路径被多次导入,需要去重处理 + - (2)导入格式多样性:存在相对导入(`..modeling`)、绝对导入等多种格式,解析复杂 + - (3)循环依赖检测:模块间可能存在相互依赖,导致无限递归 + - (4)无效导入过滤:需要区分真正的modeling导入和其他类型的导入 + + 针对以上问题,我提出了分层过滤解决方案: + + - (1)专门的导入收集器: collect_import_modeling实现`ModelingImportCollector`专门识别包含"modeling"关键字的导入语句 + - (2)路径标准化处理: collect_import_modeling通过`resolve_file_path()`函数统一处理相对导入路径转换 + - (3)循环依赖避免机制: collect_import_modeling使用`seen`集合记录已处理的依赖项,防止无限递归 + - (4)严格模式过滤:filter_specific_modeling_imports()`只保留严格符合相对导入模式的modeling导入 + + * 大规模文件处理效率 + + 存在的难点: + + - (1)单线程处理效率低:大量模型文件的串行处理耗时过长 + - (2)内存占用过高:同时加载多个大型模型文件导致内存压力 + + 针对以上问题,我构建了并行处理架构: + + - (1)多进程并行化: main.py使用`multiprocessing.Pool`实现多进程并行转换,动态调整工作进程数量 + - (2)临时文件管理: main.py:65-68 处理完成后自动清理临时文件,减少内存占用 + + * 标识符重命名一致性 + + 存在的难点: + + - (1)大小写风格保持:需要保持原代码的命名风格(如llama→qwen2, Llama→Qwen2, LLAMA→QWEN2) + - (2)误替换风险:字符串级别的替换容易产生误替换和语法错误 + - (3)冲突检测复杂:需要避免与已存在的标识符产生命名冲突 + + 针对以上问题,我开发了AST级别的智能重命名系统: + + - (1)大小写保持算法: rename_identifiers.py通过`_case_preserving_replace()`方法检测原标识符的大小写模式并应用到目标名称 + - (2)AST精确转换: rename_identifiers.py 使用`GenericRenamerTransformer`在AST层面进行精确替换,避免误替换 + - (3)智能冲突检测: rewrite_child_classes.py维护`existing_names`集合检测命名冲突,只注入不冲突的依赖项 + + * 注入依赖时的位置问题 + + 存在的难点: + + - (1)依赖注入顺序混乱:不同类型的依赖(方法、类)混合注入导致代码结构不清晰 + - (2)父子类位置关系错误:子类可能在父类定义之前被注入,导致引用错误 + - (3)代码可读性差:依赖项随意插入破坏了代码的逻辑结构和可维护性 + + 针对以上问题,我实现了智能的分层注入策略: + + - (1)依赖类型分类注入: 系统将注入的依赖分为方法和类两类,方法优先注入在imports之后,类按依赖关系分层注入 + - (2)父子类依赖关系排序: 通过分析类的继承关系,将有父类依赖的类和无父类依赖的类分开处理,确保父类先于子类定义 + - (3)动态位置插入机制: 在遍历主逻辑时,当遇到父类定义后立即插入其对应的子类,保证依赖关系的正确性和代码的逻辑连贯性 + +* 测试用例 + + * 模型转换正确性验证 + + 测试用例的核心是验证从`modular_qwen2.py`转换生成的`modeling_qwen2.py`的功能正确性 + + * 双模型对比测试:加载原始的modular_qwen2模型和转换后的modeling_qwen2模型进行数值对比 + * 精度验证:使用相同的输入数据,两个模型输出的数值使用`numpy.allclose`进行数值对比,相对容差`rtol=1e-5`,绝对容差`atol=1e-3`;转换后的模型与原模型输出相同,模型转换正确。 + + * 精度对齐与并行能力验证 + + 测试用例的核心是验证从`modular_qwen2.py`转换生成的`modeling_qwen2.py`的并行能力正确性 + + - 分布式训练兼容性:创建分布式并行环境,并使用paddle.distributed.launch进行启动运行,在张量并行度为2的配置下,转换后模型与原模型输出相对容差`rtol=1e-5`,绝对容差`atol=1e-3` + - 对于相同的输入产生了完全相同的输出,分布式能力验证成功。 + +* 后续工作安排 + + * 对于大规模文件的处理依赖 + + 对于多文件建立依赖关系图,从底层文件一次向上开始转换 + + * 完善pre-commit,实现自动化转换模块化文件并进行验证 + + * 扩展测试用例覆盖,完善精度对齐和并行能力验证文档中的测试场景 + + * 优化基础模型的构建,真正把基础模型标准化,模块化 \ No newline at end of file diff --git a/paddlenlp/transformers/convert/main.py b/paddlenlp/transformers/convert/main.py new file mode 100644 index 000000000000..9e66d9647abd --- /dev/null +++ b/paddlenlp/transformers/convert/main.py @@ -0,0 +1,158 @@ +import argparse +import glob +import os +import multiprocessing as mp +from pathlib import Path +import re + +from until.collect_import_modeling import expand_modeling_imports,remove_imports_and_rewrite,save_results_to_txt +from until.rewrite_child_classes import rewrite_child_classes +from until.rename_identifiers import rename_identifiers + +AUTO_GENERATED_MESSAGE = """# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from {relative_path}. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. 
If any change should be done, please apply the change to the +# {short_name} file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +""" + +# --- 核心转换逻辑封装 --- +def run_converter(file_to_parse_str: str): + """ + 对单个 modular 文件执行完整的转换流程。 + 这个函数是并行处理的基本单元。 + """ + print(file_to_parse_str) + file_to_parse = Path(file_to_parse_str) + + # --- 动态确定模型名称和输出路径 --- + # 假设:目标模型名可以从文件名推断,例如 "modular_qwen2.py" -> "qwen2" + to_name = file_to_parse.stem.replace("modular_", "") + # 假设:源模型名在模块化文件中有定义或约定俗成(这里我们先硬编码为 "llama" 作为示例) + # 一个更健壮的实现会从 file_to_parse 文件内容中解析出 from_name + + # 输出文件将与输入文件在同一目录下,例如 "modeling_qwen2.py" + output_file = file_to_parse.parent / f"modeling_{to_name}.py" + temp_merged_file = file_to_parse.parent / (output_file.stem + "_temp_merged.py") + + print(f"--- 开始转换: {to_name} ---") + print(f" 输入文件: '{file_to_parse}'") + print(f" 最终输出: '{output_file}'") + + # 步骤 1: 收集并展开 import + expanded_code , dependencies,from_name= expand_modeling_imports(file_to_parse) + #save_results_to_txt(dependencies, "modeling_imports_dependencies.txt") + #save_results_to_txt(expanded_code, "modeling_imports_results.txt") + #print(from_name) + # 步骤 2: 重写子类并生成中间文件 + relative_path = re.search( + r"(transformers/.*|examples/.*)", os.path.abspath(file_to_parse).replace("\\", "/") + ).group(1) + formatted_message=AUTO_GENERATED_MESSAGE.format(relative_path=relative_path, short_name=os.path.basename(relative_path)) + rewrite_child_classes(expanded_code, file_to_parse,formatted_message, temp_merged_file,rename_map={ + "llama": "qwen2" # 只需要提供小写形式! + }) + remove_imports_and_rewrite(temp_merged_file) + # 步骤 3: 全局重命名 + try: + merged_code = temp_merged_file.read_text(encoding="utf-8") + final_code = rename_identifiers(merged_code, from_name, to_name) + output_file.write_text(final_code, encoding="utf-8") + print(f" ✅ 转换成功,最终代码已写入 '{output_file}'。") + except FileNotFoundError: + print(f" ❌ [错误] 找不到中间文件 '{temp_merged_file}',无法进行重命名。") + finally: + # 清理临时文件 + if temp_merged_file.exists(): + temp_merged_file.unlink() + + print(f"--- 转换结束: {to_name} ---\n") + + +# --- 主执行逻辑 --- +def main(): + parser = argparse.ArgumentParser( + description="将模块化的模型定义文件(modular_*.py)转换为独立的模型文件(modeling_*.py)。" + ) + parser.add_argument( + "files", + nargs="*", + help="要转换的模块化文件列表(可选的位置参数)。", + ) + parser.add_argument( + "--files-to-parse", "-f", + dest="files_to_parse", # 明确指定存储的目的地 + default=[], # 默认值改为空列表 + nargs="+", + help="要转换的模块化文件列表。可使用 'all' 或 'examples' 关键字。", + ) + parser.add_argument( + "--num_workers", "-w", + default=-1, + type=int, + help="使用的进程数。默认为 -1,代表使用所有 CPU核心。", + ) + args = parser.parse_args() + + # 合并位置参数和可选参数,以可选参数优先 + files_to_parse = args.files_to_parse if args.files_to_parse else args.files + if not files_to_parse: + files_to_parse = ["all"] # 如果未提供任何文件,则默认为 'all' + + num_workers = mp.cpu_count() if args.num_workers == -1 else args.num_workers + + # --- 解析文件路径 --- + print(">>> 正在解析需要转换的文件...") + if files_to_parse == ["all"]: + from pathlib import Path + # 确定项目根目录 + SCRIPT_DIR = Path(__file__).resolve().parent + PROJECT_ROOT = SCRIPT_DIR.parent.parent.parent + # 使用绝对路径进行搜索 + search_path = PROJECT_ROOT / "paddleformers/transformers/**/modular_*.py" + files_to_parse = glob.glob(str(search_path), recursive=True) + elif files_to_parse == ["examples"]: + # 查找所有 examples 目录下的 modular 文件 + files_to_parse = glob.glob("examples/**/modular_*.py", recursive=True) + else: + # 将模型简称(如 qwen2)解析为完整路径 + resolved_files = [] + for model_name in files_to_parse: + if not os.path.exists(model_name): 
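+                # If the argument is a model short name (e.g. "qwen2") rather than an existing path,
+                # probe the two conventional locations below and use the first modular_<name>.py that exists.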
+ # 尝试在 models 目录下构建路径 + full_path = os.path.join("src", "transformers", "models", model_name, f"modular_{model_name}.py") + if not os.path.isfile(full_path): + # 如果找不到,尝试在 examples 目录下构建 + full_path = os.path.join("examples", "modular-transformers", f"modular_{model_name}.py") + + if not os.path.isfile(full_path): + raise ValueError(f"无法为 '{model_name}' 找到模块化文件。请提供完整路径或确认文件名正确。") + resolved_files.append(full_path) + else: + resolved_files.append(model_name) + files_to_parse = resolved_files + + if not files_to_parse: + print("未找到任何需要转换的文件。") + return + + print(f"发现 {len(files_to_parse)} 个文件待处理。") + + + ordered_files= [files_to_parse] + print(ordered_files) + + # --- 按依赖顺序并行处理 --- + for i, dependency_level_files in enumerate(ordered_files): + print(f"\n>>> 开始处理依赖层级 {i+1}/{len(ordered_files)} ({len(dependency_level_files)} 个文件)...") + workers = min(num_workers, len(dependency_level_files)) + if workers > 0: + with mp.Pool(workers) as pool: + pool.map(run_converter, dependency_level_files) + + print("\n--- 所有转换任务已完成 ---") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/paddlenlp/transformers/convert/test_ability/test_base/test_modeling.py b/paddlenlp/transformers/convert/test_ability/test_base/test_modeling.py new file mode 100644 index 000000000000..0b1cdf229b4b --- /dev/null +++ b/paddlenlp/transformers/convert/test_ability/test_base/test_modeling.py @@ -0,0 +1,45 @@ +"""模型组网正确性验证 +【基本流程】 + +定义原模型,加载权重,固定seed,基于numpy生成随机数,转换为PyTorch可以处理的tensor,送入网络,获取输出。 + +定义模块化转换后modeling模型,加载权重,固定seed,基于numpy生成随机数,转换为PaddlePaddle可以处理的tensor,送入网络,获取输出。 + +排查diff,小于阈值,即可完成自测。 +""" +import numpy as np +import paddle +from paddleformers.transformers.qwen2 import Qwen2Config +from paddleformers.transformers.qwen2.modeling import Qwen2ForCausalLM +from paddleformers.transformers import Qwen2Config as Qwen2Config_hf +from paddleformers.transformers import Qwen2ForCausalLM as Qwen2ForCausalLM_hf +#from paddleformers.transformers.qwen2.test_model_expanded import Qwen2ForCausalLM as Qwen2ForCausalLM_hf + + + +def eval_model_convert(): + paddle_input_ids = paddle.to_tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]]) + torch_input_ids = paddle.to_tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]]) + + # paddle model + paddle_ckpt_path = "Qwen/Qwen2-0.5B" + config_paddle = Qwen2Config.from_pretrained(paddle_ckpt_path) + model_paddle = Qwen2ForCausalLM.from_pretrained(paddle_ckpt_path, config=config_paddle, dtype="float32") + + # torch model + + torch_ckpt_path = "Qwen/Qwen2-0.5B" + config_torch = Qwen2Config_hf.from_pretrained(torch_ckpt_path) + config_torch.dtype = "float32" + model_torch = Qwen2ForCausalLM_hf.from_pretrained(torch_ckpt_path, config=config_torch, dtype="float32") + + model_paddle.eval() + model_torch.eval() + + out_paddle = model_paddle(paddle_input_ids)[0] + out_torch = model_torch(torch_input_ids, return_dict=False)[0] + print(out_paddle) + print(out_torch) + assert np.allclose(out_paddle.numpy(), out_torch.detach().numpy(), rtol=1e-5, atol=1e-3) + +eval_model_convert() \ No newline at end of file diff --git a/paddlenlp/transformers/convert/test_ability/test_paralle/compare_torch_with_paddle.py b/paddlenlp/transformers/convert/test_ability/test_paralle/compare_torch_with_paddle.py new file mode 100644 index 000000000000..f8414ba936d6 --- /dev/null +++ b/paddlenlp/transformers/convert/test_ability/test_paralle/compare_torch_with_paddle.py @@ -0,0 +1,50 @@ +import numpy as np +import paddle +from paddle.distributed import fleet +from 
paddleformers.transformers.qwen2 import Qwen2Config +from paddleformers.transformers.qwen2.modeling import Qwen2ForCausalLM +from paddleformers.transformers import Qwen2Config as Qwen2Config_hf +from paddleformers.transformers import Qwen2ForCausalLM as Qwen2ForCausalLM_hf + +def eval_model_convert_parallel(mp_degree=1): + paddle_input_ids = paddle.to_tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]]) + torch_input_ids = paddle.to_tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 1588, 2]]) + + strategy = fleet.DistributedStrategy() + strategy.hybrid_configs = { + "dp_degree": 1, + "mp_degree": mp_degree, + "pp_degree": 1, + "sharding_degree": 1, + } + fleet.init(is_collective=True, strategy=strategy) + hcg = fleet.get_hybrid_communicate_group() + + # paddle model + paddle_ckpt_path = "Qwen/Qwen2-0.5B" + config_paddle = Qwen2Config.from_pretrained(paddle_ckpt_path) + config_paddle.tensor_parallel_degree = hcg.get_model_parallel_world_size() + config_paddle.tensor_parallel_rank = hcg.get_model_parallel_rank() + config_paddle.tensor_parallel_output = False + model_paddle = Qwen2ForCausalLM.from_pretrained(paddle_ckpt_path, config=config_paddle, dtype="float32") + + # torch model + torch_ckpt_path = "Qwen/Qwen2-0.5B" + config_torch = Qwen2Config_hf.from_pretrained(torch_ckpt_path) + config_torch = Qwen2Config.from_pretrained(paddle_ckpt_path) + config_torch.tensor_parallel_degree = hcg.get_model_parallel_world_size() + config_torch.tensor_parallel_rank = hcg.get_model_parallel_rank() + config_torch.tensor_parallel_output = False + model_torch = Qwen2ForCausalLM_hf.from_pretrained(torch_ckpt_path, config=config_torch, dtype="float32") + + model_paddle.eval() + model_torch.eval() + + # 手动验证 + out_paddle = model_paddle(paddle_input_ids)[0] + out_torch = model_torch(torch_input_ids)[0] + print(out_paddle) + print(out_torch) + assert np.allclose(out_paddle.numpy(), out_torch.detach().numpy(), rtol=1e-5, atol=1e-4) + +eval_model_convert_parallel(mp_degree=2) \ No newline at end of file diff --git a/paddlenlp/transformers/convert/until/collect_import_modeling.py b/paddlenlp/transformers/convert/until/collect_import_modeling.py new file mode 100644 index 000000000000..892bfa8e3f93 --- /dev/null +++ b/paddlenlp/transformers/convert/until/collect_import_modeling.py @@ -0,0 +1,259 @@ +import libcst as cst +import os +from pathlib import Path +from typing import Dict, Set, Union, List, Tuple + +# ============================================================================== +# 以下所有函数和类均保持您提供的原始版本,没有任何改动 +# ============================================================================== +def get_unique_module_names(imports_dict: Dict[str, str]) -> Set[str]: + """ + 从字典的值中提取出所有唯一的、纯净的模块名。 + 它会移除前缀 '..' 和末尾的 '.'。 + """ + unique_names = set() + + for prefix_value in imports_dict.values(): + temp_name = prefix_value + + # 1. 移除开头的 '..' + if temp_name.startswith(".."): + temp_name = temp_name[2:] + + # 2. 移除末尾的 '.' + final_name = temp_name.rstrip('.') + + # 3. 将最终结果添加到集合中,自动保证唯一性 + if final_name: + unique_names.add(final_name) + + return unique_names +def get_full_name(node: Union[cst.Name, cst.Attribute, cst.ImportFrom]) -> str: + if isinstance(node, cst.Name): + return node.value + elif isinstance(node, cst.Attribute): + return get_full_name(node.value) + "." + node.attr.value + elif isinstance(node, cst.ImportFrom): + module_parts = [] + if node.relative: + module_parts.append("." 
* len(node.relative)) + if node.module: + module_parts.append(get_full_name(node.module)) + return "".join(module_parts) + else: + return "" + +class ModelingImportCollector(cst.CSTVisitor): + def __init__(self): + self.imports: Dict[str, str] = {} # name -> module_path + self.prefixes_before_modeling: Dict[str, str] = {} + def visit_ImportFrom(self, node: cst.ImportFrom) -> None: + modname = get_full_name(node) + if "modeling" in modname: + modeling_index = modname.find("modeling") + prefix = modname[:modeling_index] + for alias in node.names: + name_in_scope = alias.evaluated_name + self.imports[alias.evaluated_name] = modname + self.prefixes_before_modeling[name_in_scope] = prefix + +class DependencyCollector(cst.CSTVisitor): + def __init__(self): + self.names: Set[str] = set() + def visit_Name(self, node: cst.Name) -> None: + self.names.add(node.value) + +class ModuleInfoCollector(cst.CSTVisitor): + def __init__(self): + self.defs: Dict[str, Union[cst.ClassDef, cst.FunctionDef, cst.Assign]] = {} + self.imports: Dict[str, Union[cst.Import, cst.ImportFrom]] = {} + self.class_stack: List[str] = [] + def visit_ClassDef(self, node: cst.ClassDef) -> None: + self.defs[node.name.value] = node + self.class_stack.append(node.name.value) + def leave_ClassDef(self, original_node: cst.ClassDef) -> None: + self.class_stack.pop() + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + if not self.class_stack: + self.defs[node.name.value] = node + else: + fullname = ".".join(self.class_stack + [node.name.value]) + self.defs[fullname] = node + def visit_Assign(self, node: cst.Assign) -> None: + if not self.class_stack: + for target_wrapper in node.targets: + if isinstance(target_wrapper.target, cst.Name): + self.defs[target_wrapper.target.value] = node + def visit_Import(self, node: cst.Import) -> None: + for alias in node.names: + name_in_scope = alias.asname.name.value if alias.asname else alias.name.value + self.imports[name_in_scope] = node + def visit_ImportFrom(self, node: cst.ImportFrom) -> None: + for alias in node.names: + name_in_scope = alias.asname.name.value if alias.asname else alias.name.value + self.imports[name_in_scope] = node + +def parse_file(file_path: str) -> Tuple[Dict, Dict, cst.Module]: + with open(file_path, "r", encoding="utf-8") as f: + code = f.read() + module = cst.parse_module(code) + collector = ModuleInfoCollector() + module.visit(collector) + return collector.defs, collector.imports, module + +def collect_recursive( + name: str, defs: Dict[str, cst.CSTNode], imports: Dict[str, cst.CSTNode], + seen: Set[str], module: cst.Module, +) -> Tuple[Dict[str, str], Set[str], Dict[str, List[str]]]: + if name in seen or name not in defs: + return {}, set(), {} + seen.add(name) + node = defs[name] + dependencies = {name: []} + dep_collector = DependencyCollector() + node.visit(dep_collector) + results = {name: module.code_for_node(node)} + collected_imports = set() + for dep in dep_collector.names: + if dep in defs and dep not in seen: + dep_results, dep_imports , dep_deps = collect_recursive(dep, defs, imports, seen, module) + results.update(dep_results) + collected_imports.update(dep_imports) + dependencies.update(dep_deps) + dependencies[name].append(dep) # 记录依赖关系 A -> B + elif dep in imports: + import_node = imports[dep] + import_code = module.code_for_node(import_node) + collected_imports.add(import_code) + dependencies[name].append(dep) + return results, collected_imports, dependencies + +def resolve_file_path(current_file: str, modpath: str) -> Path: + dots = 
len(modpath) - len(modpath.lstrip(".")) + parts = modpath.lstrip(".").split(".") + cur_dir = Path(current_file).parent + for _ in range(dots - 1): + cur_dir = cur_dir.parent + file_path = cur_dir.joinpath(*parts).with_suffix(".py") + return file_path if file_path.exists() else None + +def expand_modeling_imports(file_path: str) -> Dict[str, str]: + with open(file_path, "r", encoding="utf-8") as f: + code = f.read() + module = cst.parse_module(code) + imp_collector = ModelingImportCollector() + module.visit(imp_collector) + expanded_defs = {} + all_imports = set() + seen = set() + dependencies = {} + for name, modpath in imp_collector.imports.items(): + target_file = resolve_file_path(file_path, modpath) + if not target_file: continue + defs, imports, parsed_module = parse_file(str(target_file)) + if name in defs: + new_defs, new_imports, new_deps = collect_recursive(name, defs, imports, seen, parsed_module) + expanded_defs.update(new_defs) + all_imports.update(new_imports) + dependencies.update(new_deps) + expanded = {} + for i, import_code in enumerate(sorted(list(all_imports))): + expanded[f"__import_{i}__"] = import_code + expanded.update(expanded_defs) + unique_modules = get_unique_module_names(imp_collector.prefixes_before_modeling) + return expanded, dependencies,unique_modules # 返回代码和依赖关系 + +def save_results_to_txt(result: Dict[str, str], output_file: str): + imports_to_write = [] + defs_to_write = {} + for key, value in result.items(): + if key.startswith("__import_"): + imports_to_write.append(value) + else: + defs_to_write[key] = value + with open(output_file, "w", encoding="utf-8") as f: + if imports_to_write: + f.write("### === Imports === ###\n") + for imp in imports_to_write: + f.write(f"{imp}\n") + f.write("\n" + "="*50 + "\n\n") + if defs_to_write: + f.write("### === Definitions === ###\n") + for k, v in sorted(defs_to_write.items()): + f.write(f"=== {k} ===\n") + f.write(f"{v}\n\n") + +# ============================================================================== +# ### NEW ### 以下是为“文件重写”这一新增功能而添加的全新、独立的模块 +# ============================================================================== + +class ModelingImportNodeCollector(cst.CSTVisitor): + """一个专门用于收集待删除 import 节点的新 Visitor。""" + def __init__(self): + self.nodes_to_remove: Set[cst.ImportFrom] = set() + + def visit_ImportFrom(self, node: cst.ImportFrom) -> None: + modname = get_full_name(node) + if "modeling" in modname: + self.nodes_to_remove.add(node) + +class ImportRemover(cst.CSTTransformer): + """一个独立的转换器,用于从语法树中删除指定的import节点。""" + def __init__(self, nodes_to_remove: Set[cst.ImportFrom]): + self.nodes_to_remove = nodes_to_remove + + def leave_ImportFrom( + self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom + ) -> Union[cst.ImportFrom, cst.RemovalSentinel]: + if original_node in self.nodes_to_remove: + return cst.RemoveFromParent() + return updated_node + +def remove_imports_and_rewrite(file_path: str): + """ + 一个独立的函数,封装了文件读取、收集待删除节点、转换和重写的操作。 + """ + # 1. 再次读取和解析文件,以启动独立的重写流程 + with open(file_path, "r", encoding="utf-8") as f: + code = f.read() + module = cst.parse_module(code) + + # 2. 收集需要删除的节点 + node_collector = ModelingImportNodeCollector() + module.visit(node_collector) + + nodes_to_remove = node_collector.nodes_to_remove + if not nodes_to_remove: + print(f"No 'modeling' imports found in '{file_path}' to remove.") + return + + # 3. 
使用转换器生成修改后的代码
+    print(f"Removing {len(nodes_to_remove)} 'modeling' import(s) from '{file_path}'...")
+    remover = ImportRemover(nodes_to_remove)
+    modified_tree = module.visit(remover)
+
+    # 4. 将修改后的代码写回原文件
+    with open(file_path, "w", encoding="utf-8") as f:
+        f.write(modified_tree.code)
+    print("File rewrite complete.")
+
+
+# ==============================================================================
+# 主程序块:先提取 modeling 依赖,再移除相关导入并重写源文件
+# ==============================================================================
+
+if __name__ == "__main__":
+    file_to_parse = "/home/hsz/PaddleFormers/PaddleFormers/paddleformers/transformers/convert/example/test_model.py"
+    output_filename = "modeling_imports_results.txt"
+
+    # --- 步骤 1: 提取 modeling 依赖代码 ---
+    # 注意:expand_modeling_imports 返回 (代码字典, 依赖关系, 模块名) 三元组,
+    # 这里解包后只使用其中的代码字典
+    combined_results, _, _ = expand_modeling_imports(file_to_parse)
+
+    # 保存提取结果
+    save_results_to_txt(combined_results, output_filename)
+    print(f"Code extraction complete. Results saved to {output_filename}")
+
+    # --- 步骤 2: 从源文件中移除 modeling 导入并重写 ---
+    remove_imports_and_rewrite(file_to_parse)
\ No newline at end of file
diff --git a/paddlenlp/transformers/convert/until/rename_identifiers.py b/paddlenlp/transformers/convert/until/rename_identifiers.py
new file mode 100644
index 000000000000..c00440976a90
--- /dev/null
+++ b/paddlenlp/transformers/convert/until/rename_identifiers.py
@@ -0,0 +1,109 @@
+import libcst as cst
+from libcst import CSTTransformer
+import re
+import os
+from typing import Set, List, Union
+
+class GenericRenamerTransformer(CSTTransformer):
+    """
+    一个通用的CST转换器,用于安全地将代码中的标识符从多个源名称替换为同一个目标名称,
+    并能智能地保留原始名称的大小写风格。
+    """
+    def __init__(self, from_names: Union[Set[str], List[str]], to_name: str):
+        """
+        Args:
+            from_names: 要被替换的源名称集合或列表 (例如 {'t5', 'llama', 'utils'})。
+            to_name: 用于替换的目标名称 (例如 'qwen2')。
+        """
+        self.to_name = to_name
+
+        # 1. 构建一个包含所有源名称的正则表达式 | (OR 逻辑)
+        #    - 使用 re.escape() 确保特殊字符被正确处理。
+        #    - 使用 | 符号连接所有名称,实现多选一匹配。
+        #    - 确保列表非空
+        if not from_names:
+            raise ValueError("from_names 列表不能为空。")
+
+        escaped_names = [re.escape(name) for name in from_names]
+        pattern = "|".join(escaped_names)
+
+        # 2. 
编译一个不区分大小写 (re.IGNORECASE) 的正则表达式 + self.regex = re.compile(pattern, re.IGNORECASE) + + def _case_preserving_replace(self, match: re.Match) -> str: + """ + 这是一个自定义的替换函数,它根据匹配到的字符串的大小写风格, + 来决定 to_name 应该使用哪种大小写形式。 + """ + found_str = match.group(0) + # 如果找到的是全大写 (e.g., LLAMA) + if found_str.isupper(): + return self.to_name.upper() + # 如果找到的是首字母大写 (e.g., Llama) + if found_str.istitle(): + return self.to_name.title() + # 默认情况,包括全小写 (e.g., llama),返回全小写 + return self.to_name.lower() + + def leave_Name( + self, original_node: cst.Name, updated_node: cst.Name + ) -> cst.Name: + """ + 当访问离开一个名称节点时,使用正则表达式和自定义替换函数执行重命名。 + """ + # 使用 regex.sub() 和我们的自定义函数来进行替换 + new_name_str = self.regex.sub(self._case_preserving_replace, updated_node.value) + + # 仅在名称确实发生改变时才创建一个新节点 + if new_name_str != updated_node.value: + if not new_name_str.isidentifier(): + original_name = original_node.value + # 警告,而不是跳过,因为这在依赖于上下文的重命名中可能是允许的。 + # 但对于 cst.Name 节点,它必须是有效标识符。 + print(f"警告:尝试将 '{original_name}' 重命名为无效标识符 '{new_name_str}'。跳过此重命名。") + return updated_node + return updated_node.with_changes(value=new_name_str) + + return updated_node + +def rename_identifiers(source_code: str, from_names: Union[Set[str], List[str]], to_name: str) -> str: + """ + 接收一段Python源代码,将其中的所有 from_names 相关标识符安全地重命名为 to_name。 + + Args: + source_code: 包含Python代码的字符串。 + from_names: 要被替换的源名称集合或列表 (例如 {"t5", "llama"})。 + to_name: 用于替换的目标名称 (例如 "qwen2")。 + + Returns: + 重构后的Python代码字符串。 + """ + try: + module = cst.parse_module(source_code) + transformer = GenericRenamerTransformer(from_names, to_name) + modified_module = module.visit(transformer) + return modified_module.code + except cst.ParserSyntaxError as e: + print(f"Error: Failed to parse the source code. {e}") + return source_code + except ValueError as e: + print(f"Error in rename process: {e}") + return source_code + +# --- 示例用法 --- +# source_code = """ +# class LlamaModel(T5Model): +# def forward(self, input_ids): +# return self.llama_layer(input_ids) +# LLAMA_CONFIG = 1 +# """ +# from_list = ['llama', 't5'] +# to_name = 'qwen2' + +# new_code = rename_identifiers(source_code, from_list, to_name) +# print(new_code) +# # 预期输出: +# # class Qwen2Model(Qwen2Model): +# # def forward(self, input_ids): +# # return self.qwen2_layer(input_ids) +# # QWEN2_CONFIG = 1 \ No newline at end of file diff --git a/paddlenlp/transformers/convert/until/rewrite_child_classes.py b/paddlenlp/transformers/convert/until/rewrite_child_classes.py new file mode 100644 index 000000000000..f5aca4ee72f2 --- /dev/null +++ b/paddlenlp/transformers/convert/until/rewrite_child_classes.py @@ -0,0 +1,649 @@ +import libcst as cst +from typing import Dict, Optional, List, Set, Union +from libcst import matchers as m +import builtins +import os + +# ============================================================================== +# SECTION 1: 智能类合并引擎 +# ============================================================================== + + +def get_node_code(node: cst.CSTNode) -> str: + """辅助函数,用于获取CST节点的代码字符串,以便比较。""" + return cst.Module(body=[node]).code.strip() + +def merge_parameters( + child_params: cst.Parameters, parent_params: cst.Parameters +) -> cst.Parameters: + """智能合并两个方法的参数列表。""" + child_param_map = {p.name.value: p for p in child_params.params} + + insertion_point = len(child_params.params) + for i, p in enumerate(child_params.params): + if p.star: + insertion_point = i + break + + new_params_from_parent = [] + for p in parent_params.params: + if p.name.value not in child_param_map and p.default is not None: + 
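+            # Only parent parameters that the child does not already define and that have a default
+            # value are carried over, so the merged signature stays compatible with the child's
+            # existing call sites.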
new_params_from_parent.append(p) + + final_params_list = list(child_params.params) + final_params_list[insertion_point:insertion_point] = new_params_from_parent + + return child_params.with_changes(params=tuple(final_params_list)) + +def _get_class_var_names(class_body: list) -> set: + """从类的 body 中提取所有类变量的名称。""" + var_names = set() + for stmt in class_body: + if m.matches(stmt, m.SimpleStatementLine(body=[m.Assign()])): + assign_node = stmt.body[0] + for target in assign_node.targets: + if isinstance(target.target, cst.Name): + var_names.add(target.target.value) + return var_names + +def merge_parent_class_final( + child_class: cst.ClassDef, parent_class: cst.ClassDef +) -> cst.ClassDef: + """ + 类合并主函数(最终智能版): + - 智能展开super()调用,避免代码冗余。 + - 智能合并方法的参数列表,防止运行时错误。 + - 正确处理类变量和未覆盖方法的继承。 + """ + child_body_list = list(child_class.body.body) + parent_body_map = { + stmt.name.value: stmt + for stmt in parent_class.body.body + if hasattr(stmt, 'name') and isinstance(stmt.name, cst.Name) + } + + final_body = list(child_body_list) + + # 1. 处理被子类覆盖的方法 (包括 __init__) + for i, child_stmt in enumerate(child_body_list): + if not isinstance(child_stmt, cst.FunctionDef): + continue + + method_name = child_stmt.name.value + parent_method = parent_body_map.get(method_name) + + if not parent_method or not isinstance(parent_method, cst.FunctionDef): + continue + + # 1a. 智能展开 super() + child_method_body = list(child_stmt.body.body) + parent_method_body = list(parent_method.body.body) + + super_call_index = -1 + for j, stmt in enumerate(child_method_body): + if m.matches(stmt, m.SimpleStatementLine(body=[m.Expr(value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")))))]) ) \ + or m.matches(stmt, m.Return(value=m.Call(func=m.Attribute(value=m.Call(func=m.Name("super")))))): + super_call_index = j + break + + new_method_body_stmts = child_method_body + if super_call_index != -1: + child_prefix_stmts = child_method_body[:super_call_index] + child_suffix_stmts = child_method_body[super_call_index + 1:] + child_prefix_codes = [get_node_code(s) for s in child_prefix_stmts] + + divergence_index = 0 + for k, parent_stmt in enumerate(parent_method_body): + if k < len(child_prefix_codes) and get_node_code(parent_stmt) == child_prefix_codes[k]: + divergence_index += 1 + else: + break + + parent_suffix_stmts = parent_method_body[divergence_index:] + new_method_body_stmts = child_prefix_stmts + parent_suffix_stmts + child_suffix_stmts + + # 1b. 合并参数列表 + new_params = merge_parameters(child_stmt.params, parent_method.params) + + # 1c. 创建最终的方法节点 + new_body_block = child_stmt.body.with_changes(body=tuple(new_method_body_stmts)) + final_method = child_stmt.with_changes(body=new_body_block, params=new_params) + + final_body[i] = final_method + + # 2. 添加父类中未被覆盖的成员 + child_member_names = {stmt.name.value for stmt in final_body if hasattr(stmt, 'name')} + child_class_var_names = _get_class_var_names(final_body) + + for parent_stmt in parent_class.body.body: + if hasattr(parent_stmt, 'name') and parent_stmt.name.value in child_member_names: + continue + + if m.matches(parent_stmt, m.SimpleStatementLine(body=[m.Assign()])): + parent_var_names = _get_class_var_names([parent_stmt]) + if not parent_var_names.isdisjoint(child_class_var_names): + continue + + final_body.append(parent_stmt) + + # 3. 
清理 pass 语句 + pass_matcher = m.SimpleStatementLine(body=[m.Pass()]) + non_pass_statements = [stmt for stmt in final_body if not m.matches(stmt, pass_matcher)] + + if not non_pass_statements: + cleaned_body = (cst.SimpleStatementLine(body=(cst.Pass(),)),) + else: + cleaned_body = tuple(non_pass_statements) + + # 4. 返回最终结果 + return child_class.with_changes( + bases=parent_class.bases, + body=child_class.body.with_changes(body=cleaned_body) + ) + +# ============================================================================== +# SECTION 2:代码重构工具框架 (已集成新逻辑) +# ============================================================================== + +class ComprehensiveRenamer(cst.CSTTransformer): + """智能、大小写敏感地重命名所有匹配的名称。""" + def __init__(self, rename_map: Dict[str, str]): + self.rename_pairs = [] + for from_sub, to_sub in rename_map.items(): + self.rename_pairs.append((from_sub.lower(), to_sub.lower())) + self.rename_pairs.append((from_sub.capitalize(), to_sub.capitalize())) + self.rename_pairs.append((from_sub.upper(), to_sub.upper())) + self.rename_pairs.sort(key=lambda x: len(x[0]), reverse=True) + + def leave_Name(self, original_node: cst.Name, updated_node: cst.Name) -> cst.Name: + for from_name, to_name in self.rename_pairs: + if from_name in original_node.value: + new_value = original_node.value.replace(from_name, to_name) + return updated_node.with_changes(value=new_value) + return updated_node + +def get_base_class_name(base: cst.BaseExpression) -> Optional[str]: + """提取基类名称。""" + if isinstance(base, cst.Name): + return base.value + elif isinstance(base, cst.Attribute): + parts = [] + node = base + while isinstance(node, cst.Attribute): + parts.append(node.attr.value) + node = node.value + if isinstance(node, cst.Name): + parts.append(node.value) + return ".".join(reversed(parts)) + return None + +def find_class_in_source(module_node: cst.Module) -> Optional[cst.ClassDef]: + """从模块节点中提取第一个类定义。""" + for node in module_node.body: + if isinstance(node, cst.ClassDef): + return node + return None + +class DependencyVisitor(cst.CSTVisitor): + """扫描代码以查找所有潜在的外部引用。""" + def __init__(self): + self.scopes: List[Set[str]] = [set()] + self.dependencies: Set[str] = set() + self.builtins = set(dir(builtins)) + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + param_names = {p.name.value for p in node.params.params} + self.scopes.append(param_names) + + def leave_FunctionDef(self, original_node: cst.FunctionDef) -> None: + self.scopes.pop() + + def visit_Assign(self, node: cst.Assign) -> None: + for target in node.targets: + if isinstance(target.target, cst.Name): + self.scopes[-1].add(target.target.value) + + def visit_Name(self, node: cst.Name) -> None: + is_local = any(node.value in scope for scope in self.scopes) + if not is_local and node.value not in self.builtins: + self.dependencies.add(node.value) + +def find_usage_dependencies(node: Union[cst.ClassDef, cst.FunctionDef], expanded: Dict[str, str]) -> Set[str]: + """分析节点的CST,找出其使用到的其他实体。""" + visitor = DependencyVisitor() + node.visit(visitor) + return {dep for dep in visitor.dependencies if dep in expanded} + +def get_full_name(node: Union[cst.Name, cst.Attribute, cst.ImportFrom]) -> str: + """ + 从CST节点递归获取完整名称,如 a.b.c 或 ..a.b + """ + if isinstance(node, cst.Name): + return node.value + elif isinstance(node, cst.Attribute): + # 递归获取基础部分 (a.b) + base_name = get_full_name(node.value) + # 拼接当前属性 (.c) + return f"{base_name}.{node.attr.value}" if base_name else node.attr.value + elif isinstance(node, cst.ImportFrom): + # 处理 from ... 
import ... 语句的模块路径 + module_parts = [] + if node.relative: + module_parts.append("." * len(node.relative)) + if node.module: + module_parts.append(get_full_name(node.module)) + return "".join(module_parts) + return "" + +def filter_specific_modeling_imports( + import_nodes: Union[Dict[str, cst.BaseSmallStatement], List[cst.BaseSmallStatement]] +) -> Dict[str, cst.BaseSmallStatement]: + """ + 【修正版】只移除严格符合 `from ..***.modeling import ...` 模式的导入。 + + 这个版本可以智能处理输入是字典或列表的情况,并且总是返回一个字典。 + """ + kept_imports_dict: Dict[str, cst.BaseSmallStatement] = {} + + # 【核心修正】: 检查输入类型,并确保我们总是遍历 CST 节点 + nodes_to_iterate = [] + if isinstance(import_nodes, dict): + # 如果输入是字典,我们只关心它的值(CST 节点) + nodes_to_iterate = list(import_nodes.values()) + elif isinstance(import_nodes, list): + # 如果输入已经是列表,直接使用 + nodes_to_iterate = import_nodes + + for node in nodes_to_iterate: + should_keep = True + + if isinstance(node, cst.ImportFrom): + is_two_dots_relative = node.relative and len(node.relative) == 2 + + if is_two_dots_relative: + module_path = get_full_name(node.module) if node.module else "" + + if module_path.endswith(".modeling"): + should_keep = False + + if should_keep: + kept_imports_dict[get_node_code(node)] = node + + return kept_imports_dict + +class EntityFinder(cst.CSTVisitor): + """ + A visitor to find the first ClassDef or FunctionDef node in a CST. + """ + def __init__(self): + self.found_node = None + + def visit_ClassDef(self, node: cst.ClassDef) -> bool: + # Found a class, store it and stop searching + if self.found_node is None: + self.found_node = node + return False # Return False to stop traversing deeper + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + # Found a function, store it and stop searching + if self.found_node is None: + self.found_node = node + return False # Return False to stop traversing deeper + +def find_entity_in_source(source_cst_node: cst.Module) -> Optional[cst.CSTNode]: + """ + Parses a CST module to find the first class or function definition. + + Args: + source_cst_node: The parsed Concrete Syntax Tree of the source file. + + Returns: + The found ClassDef or FunctionDef node, or None if not found. 
+ """ + if not isinstance(source_cst_node, cst.Module): + # Ensure we have a valid CST to visit + return None + + finder = EntityFinder() + source_cst_node.visit(finder) + return finder.found_node + +def rewrite_child_classes( + expanded: Dict[str, str], + target_file: str, + template_comment: str, + output_file: str, + rename_map: Optional[Dict[str, str]] = None +): + """完整的类重写工具 (已集成VFinal版合并引擎)。""" + if rename_map is None: rename_map = {} + + # --- 阶段一 & 二:解析代码 --- + print("阶段一:正在预解析所有父类代码...") + parsed_expanded: Dict[str, cst.Module] = {} + imports_to_inject: Dict[str, cst.BaseSmallStatement] = {} + for name, source in expanded.items(): + try: + module_node = cst.parse_module(source) + parsed_expanded[name] = module_node + for node in module_node.body: + if m.matches(node, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): + imports_to_inject[module_node.code_for_node(node)] = node + except Exception as e: + print(f"警告:预解析 {name} 失败: {e}") + + print("\n阶段二:正在分析目标文件...") + with open(target_file, "r", encoding="utf-8") as f: + module = cst.parse_module(f.read()) + + imports_from_target: Dict[str, cst.SimpleStatementLine] = {} + body_statements: List[cst.BaseStatement] = [] + for stmt in module.body: + # 匹配导入语句 + if m.matches(stmt, m.SimpleStatementLine(body=[m.Import() | m.ImportFrom()])): + imports_from_target[module.code_for_node(stmt)] = stmt + + # 匹配 try-except 块(通常用于可选导入) + elif isinstance(stmt, cst.Try): + imports_from_target[module.code_for_node(stmt)] = stmt + + # 匹配 __all__ 定义 + elif m.matches(stmt, m.SimpleStatementLine(body=[m.Assign(targets=[m.AssignTarget(target=m.Name("__all__"))])])): + imports_from_target[module.code_for_node(stmt)] = stmt + + # 其他语句放入主体 + else: + body_statements.append(stmt) + imports_from_target=filter_specific_modeling_imports(imports_from_target) + # --- 阶段三 & 四:依赖分析与合并 --- + nodes_to_inject: Dict[str, Union[cst.ClassDef, cst.FunctionDef]] = {} + existing_names: Set[str] = {stmt.name.value for stmt in body_statements if hasattr(stmt, 'name')} + visiting: Set[str] = set() + + def collect_dependencies(name: str): + # 1. 边界检查 (完全不变) + # 无论是类还是函数,这些检查(是否已解析、已收集、已存在、正在访问)都同样适用。 + if name not in parsed_expanded or name in nodes_to_inject or name in existing_names or name in visiting: + return + + # 2. 查找实体节点 (需要泛化) + # find_entity_in_source 现在可以返回 ClassDef 或 FunctionDef 节点。 + entity_node = find_entity_in_source(parsed_expanded[name]) + if not entity_node: + return + + # 3. 标记正在访问 (完全不变) + visiting.add(name) + + # 4. 处理类特有的依赖:继承 (只对类执行) + # 如果实体是类,才处理其父类依赖。函数没有继承,会自然跳过此块。 + if isinstance(entity_node, cst.ClassDef): + for base in entity_node.bases: + if base_name := get_base_class_name(base.value): + collect_dependencies(base_name) + + # 5. 处理通用依赖:使用关系 (对类和函数都执行) + # 这里的 `find_usage_dependencies` 函数也必须是通用的, + # 它需要能解析类和函数体内的依赖。 + # - 对于类: 查找成员变量的类型注解等。 + # - 对于函数: 查找参数的类型注解、返回值的类型注解、函数体内调用的其他函数、实例化的类等。 + for dep_name in find_usage_dependencies(entity_node, expanded): + collect_dependencies(dep_name) + + # 6. 
完成处理,加入结果集 (完全不变) + # 无论是类还是函数,都在其所有依赖项被处理完毕后,才将自身加入结果集。 + visiting.remove(name) + nodes_to_inject[name] = entity_node + print("\n阶段三:正在进行全局依赖扫描...") + for stmt in body_statements: + if isinstance(stmt, cst.ClassDef): + for base in stmt.bases: + if base_name := get_base_class_name(base.value): + collect_dependencies(base_name) + for dep_name in find_usage_dependencies(stmt, expanded): + collect_dependencies(dep_name) + + print("\n阶段四:正在执行类合并操作...") + processed_body_statements = [] + merged_parents: Set[str] = set() + for stmt in body_statements: + if isinstance(stmt, cst.ClassDef) and stmt.bases: + if base_name := get_base_class_name(stmt.bases[0].value): + if base_name in parsed_expanded: + parent_module = parsed_expanded[base_name] + if parent_class_node := find_class_in_source(parent_module): + print(f" > 正在合并 {base_name} -> {stmt.name.value}...") + # <<<--- ★★★核心修改点:调用新的合并函数★★★ + stmt = merge_parent_class_final(stmt, parent_class_node) + merged_parents.add(base_name) + processed_body_statements.append(stmt) + + # --- 阶段五:按正确顺序重新组装文件 --- + print("\n阶段五:正在生成最终文件...") + + nodes_to_inject_after_merge = {k: v for k, v in nodes_to_inject.items() if k not in merged_parents} + main_defined_names = {stmt.name.value for stmt in processed_body_statements if hasattr(stmt, 'name')} + + print(" > 正在应用智能重命名规则并检测冲突...") + final_nodes_to_inject = {} + renamer = ComprehensiveRenamer(rename_map) + + for original_name, node in nodes_to_inject_after_merge.items(): + renamed_node = node.visit(renamer) + new_name = renamed_node.name.value + if new_name in main_defined_names: + print(f" - 检测到主代码中已存在 '{new_name}',将跳过注入 '{original_name}'") + continue + print(f" - 正在处理依赖 '{original_name}'...") + final_nodes_to_inject[new_name] = renamed_node + + final_imports = {**imports_from_target, **imports_to_inject} + new_body = [] + new_header = [] + #加转换注释 + for line in template_comment.splitlines(): + stripped_line = line.strip() + if stripped_line: + comment_node = cst.Comment(stripped_line) + new_header.append(cst.EmptyLine( + comment=comment_node, + indent=True, + whitespace=cst.SimpleWhitespace(value="") + )) + for item in module.header: + if isinstance(item, cst.EmptyLine) and item.comment: + new_header.append(item) + elif isinstance(item, cst.TrailingWhitespace) and item.comment: + new_header.append(item) + + if final_imports: + unique_imports = {module.code_for_node(n): n for n in final_imports.values()} + new_body.extend(unique_imports.values()) + + injected_items = sorted(final_nodes_to_inject.values(), key=lambda n: n.name.value) + # 2. 分类依赖项:方法和类 + methods_to_inject = [] + classes_to_inject = [] + for node in injected_items: + if isinstance(node, cst.FunctionDef): + print(node.name.value) + methods_to_inject.append(node) + elif isinstance(node, cst.ClassDef): + classes_to_inject.append(node) + else: + print(f"警告:遇到未知类型的节点,无法分类: {type(node.name.value)}") + # 3. 注入方法(放在 imports 之后,主逻辑之前) + if methods_to_inject: + new_body.extend([cst.EmptyLine(), cst.EmptyLine(comment=cst.Comment("# --- Injected Methods ---"))]) + new_body.extend(methods_to_inject) + # 4. 
处理类的注入顺序 + # 分组:有父类在主逻辑中的类 vs 没有的 + classes_with_parent_in_main = [] + classes_without_parent_in_main = [] + if classes_to_inject: + # 获取主逻辑中的所有类名 + main_classes = {stmt.name.value for stmt in processed_body_statements if isinstance(stmt, cst.ClassDef)} + + + + for cls_node in classes_to_inject: + has_parent_in_main = False + if isinstance(cls_node, cst.ClassDef) and cls_node.bases: + for base in cls_node.bases: + if base_name := get_base_class_name(base.value): + if base_name in main_classes: + has_parent_in_main = True + break + + if has_parent_in_main: + classes_with_parent_in_main.append(cls_node) + else: + classes_without_parent_in_main.append(cls_node) + + # 4.1 先注入没有父类依赖的类(放在 imports 之后) + if classes_without_parent_in_main: + new_body.extend([cst.EmptyLine(), cst.EmptyLine(comment=cst.Comment("# --- Injected Classes ---"))]) + new_body.extend(classes_without_parent_in_main) + + + # 4. 动态遍历主逻辑,在父类定义后插入其子类 + if processed_body_statements: + # 4.1 收集所有主逻辑的类名 + classes_with_parent_in_main = { + cls for cls in classes_with_parent_in_main + if isinstance(cls, cst.ClassDef) + } + + # 4.2 按顺序处理主逻辑的语句 + for stmt in processed_body_statements: + new_body.append(stmt) + + # 如果是类定义,检查是否有子类需要注入 + if isinstance(stmt, cst.ClassDef): + parent_name = stmt.name.value + # 查找依赖此父类的子类 + child_classes = [ + cls for cls in classes_with_parent_in_main + if any( + get_base_class_name(base.value) == parent_name + for base in cls.bases + ) + ] + # 注入子类 + if child_classes: + new_body.extend([ + cst.EmptyLine(), + cst.EmptyLine(comment=cst.Comment(f"# --- Children of {parent_name} ---")), + *child_classes + ]) + # 从待注入列表中移除已处理的子类 + classes_with_parent_in_main = [ + cls for cls in classes_with_parent_in_main + if cls not in child_classes + ] + +# 5. 注入剩余未处理的依赖主逻辑的类(可能是跨文件的依赖) + if classes_with_parent_in_main: + new_body.extend([cst.EmptyLine(), cst.EmptyLine(comment=cst.Comment("# --- Remaining Injected Child Classes ---"))]) + new_body.extend(classes_with_parent_in_main) + + """ + if injected_items: + new_body.extend([cst.EmptyLine(), cst.EmptyLine(comment=cst.Comment("# --- Injected Dependencies ---"))]) + new_body.extend(injected_items) + + if processed_body_statements: + new_body.extend([cst.EmptyLine(), cst.EmptyLine(comment=cst.Comment("# --- Main Application Logic ---"))]) + new_body.extend(processed_body_statements) + """ + new_module = module.with_changes( + header=tuple(new_header), # 使用新的头部注释 + body=tuple(new_body) # 使用新的主体内容 +) + with open(output_file, "w", encoding="utf-8") as f: + f.write(new_module.code) + + print(f"\n成功生成合并后的文件: {output_file}") + +# ============================================================================== +# SECTION 3: 演示 +# ============================================================================== +if __name__ == "__main__": + + # --- 步骤1: 准备演示环境 --- + # 创建一个虚拟的 child_class.py 文件供脚本读取 + child_class_content = """ +class MyChildClass(ParentClass): + def __init__(self, config, child_param): + # 与父类重复的语句 + if config.flag: + self.param1 = config.param1 + else: + self.param1 = config.default_param1 + + # 调用super + super().__init__(config) + + # 新增的属性和逻辑 + self.child_param = child_param + print("Child class logic executed.") + + def child_method(self): + return "子类方法" +""" + with open("child_class.py", "w", encoding="utf-8") as f: + f.write(child_class_content) + + # --- 步骤2: 定义父类和祖父类源代码 --- + expanded_parents = { + "ParentClass": ''' +class ParentClass(GrandParentClass): + def __init__(self, config): + # 条件语句 + if config.flag: + self.param1 = config.param1 + else: + self.param1 
= config.default_param1 + + # 循环语句 + for i in range(5): + self.param2 = i + + # 方法调用 + self.initialize(config) + + # super调用(指向祖父类) + super().__init__() + + def initialize(self, config): + self.param3 = config.param3 + + def parent_method(self): + return "父类方法" +''', + "GrandParentClass": ''' +class GrandParentClass: + def __init__(self): + self.grand_param = "祖父参数" + + def grand_method(self): + return "祖父方法" +''' + } + + # --- 步骤3: 运行重写工具 --- + print("--- 开始运行代码重写工具 ---") + rewrite_child_classes( + expanded=expanded_parents, + target_file="child_class.py", + output_file="merged_class.py" + ) + + # --- 步骤4: 打印结果 --- + print("\n--- 查看生成的 merged_class.py 文件 ---") + with open("merged_class.py", "r", encoding="utf-8") as f: + print(f.read()) + + # --- 步骤5: 清理 --- + os.remove("child_class.py") + os.remove("merged_class.py") \ No newline at end of file diff --git a/paddlenlp/transformers/qwen2/configuration.py b/paddlenlp/transformers/qwen2/configuration.py index c076857647ce..19fdcfdf12aa 100644 --- a/paddlenlp/transformers/qwen2/configuration.py +++ b/paddlenlp/transformers/qwen2/configuration.py @@ -16,10 +16,53 @@ from ..configuration_utils import PretrainedConfig + __all__ = [ + "QWEN2_PRETRAINED_INIT_CONFIGURATION", "Qwen2Config", + "QWEN2_PRETRAINED_RESOURCE_FILES_MAP", ] - +QWEN2_PRETRAINED_INIT_CONFIGURATION = { + # Hypothetical model weights (tiny-random-llama & micro-random-llama) for test only + "__internal_testing__/micro-random-llama": { + "architectures": ["LlamaForCausalLM"], + "hidden_size": 64, + "initializer_range": 0.02, + "intermediate_size": 1000, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 8, + "num_hidden_layers": 1, + "rms_norm_eps": 1e-06, + "vocab_size": 32000, + "bos_token_id": 1, + "eos_token_id": 2, + "pad_token_id": 0, + }, + "__internal_testing__/tiny-random-llama": { + "architectures": ["LlamaForCausalLM"], + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 2048, + "model_type": "llama", + "num_attention_heads": 8, + "num_hidden_layers": 2, + "rms_norm_eps": 1e-06, + "vocab_size": 32000, + "bos_token_id": 1, + "eos_token_id": 2, + "pad_token_id": 0, + }, +} + +# Hypothetical model weights (tiny-random-llama) for test only +QWEN2_PRETRAINED_RESOURCE_FILES_MAP = { + "model_state": { + "__internal_testing__/micro-random-llama": "https://bj.bcebos.com/paddlenlp/models/community/__internal_testing__/micro-random-llama/model_state.pdparams", + "__internal_testing__/tiny-random-llama": "https://bj.bcebos.com/paddlenlp/models/community/__internal_testing__/tiny-random-llama/model_state.pdparams", + }, +} class Qwen2Config(PretrainedConfig): r""" @@ -113,6 +156,9 @@ def __init__( use_sliding_window=False, sliding_window=4096, max_window_layers=28, + use_flash_attention_for_generation=False, + alibi=False, + use_last_token_for_generation=False, attention_bias=True, attention_dropout=0.0, rope_scaling_factor=1.0, @@ -153,7 +199,10 @@ def __init__( self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id self.dpo_config = dpo_config + self.use_flash_attention_for_generation = use_flash_attention_for_generation + self.alibi = alibi self.use_fused_head_and_loss_fn = use_fused_head_and_loss_fn + self.use_last_token_for_generation = use_last_token_for_generation super().__init__( pad_token_id=pad_token_id, diff --git a/paddlenlp/transformers/qwen2/modeling_qwen2.py b/paddlenlp/transformers/qwen2/modeling_qwen2.py new file mode 100644 index 000000000000..4a38139353e7 --- 
/dev/null +++ b/paddlenlp/transformers/qwen2/modeling_qwen2.py @@ -0,0 +1,2243 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from transformers/qwen2/modular_qwen2.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_qwen2.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# This code ijins based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import math +import warnings +from functools import partial +from typing import Dict, List, Optional, Tuple, Union + +import paddle +import paddle.distributed as dist +import paddle.distributed.fleet.meta_parallel as mpu +import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.fleet.recompute.recompute import recompute + +from ...utils.tools import get_env_device +from .. import linear_utils +from ..activations import ACT2FN +from ..contrastive_loss import SimpleContrastiveLoss +from ..conversion_utils import StateDictNameMapping, init_name_mappings +from ..embedding_utils import dist_gather_tensor_with_gradient +from ..linear_utils import Linear +from ..model_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ..model_utils import PretrainedModel, register_base_model +from ..refined_recompute import ( + RRColumnParallelLinear, + RRColumnSequenceParallelLinear, + RRRowParallelLinear, + RRRowSequenceParallelLinear, + get_skip_recompute_ops, +) +from ..refined_recompute import recompute as rr_recompute +from ..utils import caculate_llm_per_token_flops, logger +from .configuration import Qwen2Config + + +try: + from paddle.incubate.nn.functional import fused_rotary_position_embedding +except ImportError: + fused_rotary_position_embedding = None + +try: + from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + GatherOp, + ScatterOp, + mark_as_sequence_parallel_parameter, + ) +except: + pass + +try: + from paddle.nn.functional.flash_attention import flash_attention +except: + flash_attention = None +__all__ = [ + "Qwen2Model", + "Qwen2PretrainedModel", + "Qwen2ForCausalLM", + "Qwen2PretrainingCriterion", + "Qwen2ForSequenceClassification", + "Qwen2ForTokenClassification", + "Qwen2SentenceEmbedding", +] +from . 
import fusion_ops +from ...utils.log import logger +from ...utils.tools import get_env_device +from ..conversion_utils import split_or_fuse_func +from ..conversion_utils import split_or_merge_func +from ..long_sequence_strategies import LongSequenceStrategies +from ..model_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, +) +from ..segment_parallel_utils import ReshardLayer +from ..utils import caculate_llm_per_token_flops +from .configuration import ( + QWEN2_PRETRAINED_INIT_CONFIGURATION, + QWEN2_PRETRAINED_RESOURCE_FILES_MAP, + Qwen2Config, +) +from paddle.autograd import PyLayer +from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + GatherOp, + ScatterOp, + mark_as_sequence_parallel_parameter, +) +from paddle.incubate.nn.functional import fused_rotary_position_embedding +from paddle.incubate.nn.functional import swiglu +from paddle.nn.functional.flash_attention import flash_attention +from typing import Optional, Tuple +import math +import numpy as np +import os +import paddle + +# --- Injected Methods --- +def _expand_2d_mask(mask, dtype, tgt_length): + """ + Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. + """ + batch_size, src_length = mask.shape[0], mask.shape[-1] + tgt_length = tgt_length if tgt_length is not None else src_length + + if get_env_device() == "npu" or get_env_device() == "mlu": + mask = mask[:, None, None, :].astype(dtype) + else: + mask = mask[:, None, None, :].astype("bool") + mask.stop_gradient = True + expanded_mask = mask.expand([batch_size, 1, tgt_length, src_length]) + + return expanded_mask +def _get_interleave(n): + def _get_interleave_power_of_2(n): + start = 2 ** (-(2 ** -(np.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if np.log2(n).is_integer(): + return _get_interleave_power_of_2(n) + else: + closest_power_of_2 = int(2 ** np.floor(np.log2(n))) + return ( + _get_interleave_power_of_2(closest_power_of_2) + + _get_interleave(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) +def _get_interleave_power_of_2(n): + start = 2 ** (-(2 ** -(np.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] +def _make_causal_mask(input_ids_shape, past_key_values_length): + """ + Make casual mask used for self-attention + """ + batch_size, target_length = input_ids_shape # target_length: seq_len + + if get_env_device() in ["npu", "mlu", "intel_hpu"]: + mask = paddle.tril(paddle.ones((target_length, target_length))).astype("int32") + else: + mask = paddle.tril(paddle.ones((target_length, target_length), dtype="bool")) + + if past_key_values_length > 0: + # [tgt_len, tgt_len + past_len] + mask = paddle.concat([paddle.ones([target_length, past_key_values_length], dtype="bool"), mask], axis=-1) + + # [bs, 1, tgt_len, tgt_len + past_len] + return mask[None, None, :, :].expand([batch_size, 1, target_length, target_length + past_key_values_length]) +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + + if position_ids is None: + # Note: Only for LlamaForCausalLMPipe model pretraining + cos = cos[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] + sin = sin[:, : q.shape[1], :, :] # [bs, seq_len, 1, dim] + else: + cos = cos.squeeze(axis=[0, 2]) # [seq_len, dim] + sin = sin.squeeze(axis=[0, 2]) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k 
* cos) + (rotate_half(k) * sin) + return q_embed, k_embed +def assign_kv_heads(num_kv_heads: int, num_gpus: int): + # Initialize the assignment list + """ + Assign kv heads to different GPUs in the Tensor Parallel Setup + + Examples: + assign_kv_heads(num_kv_heads=1, num_gpus=2): [[0], [0]] + assign_kv_heads(num_kv_heads=2, num_gpus=2): [[0], [1]] + assign_kv_heads(num_kv_heads=4, num_gpus=2): [[0,1], [2,3]] + assign_kv_heads(num_kv_heads=1, num_gpus=4): [[0],[0],[0],[0]] + assign_kv_heads(num_kv_heads=2, num_gpus=4): [[0],[0],[1],[1]] + assign_kv_heads(num_kv_heads=4, num_gpus=4): [[0],[1],[2],[3]] + """ + assignment_list = [[] for _ in range(num_gpus)] + # Case 1: more heads than cards + if num_kv_heads > num_gpus: + num_heads_per_card = num_kv_heads // num_gpus + for i in range(num_gpus): + for j in range(num_heads_per_card): + assignment_list[i].append(i * num_heads_per_card + j) + # Case 2: more cards than heads. each card get only 1 head. + else: + num_card_per_heads = num_gpus // num_kv_heads + for i in range(num_kv_heads): + for j in range(num_card_per_heads): + assignment_list[i * num_card_per_heads + j].append(i) + return assignment_list +def build_alibi_tensor( + bool_attention_mask: Tensor, num_heads: int, dtype: paddle.dtype, tensor_parallel_degree=1 +) -> Tensor: + batch_size, seq_length = bool_attention_mask.shape[0], bool_attention_mask.shape[-1] + slopes = paddle.to_tensor(_get_interleave(num_heads), dtype="float32") + alibi = slopes.unsqueeze(axis=[1, 2]) * paddle.arange(seq_length, dtype="float32").unsqueeze(axis=[0, 1]).expand( + [num_heads, -1, -1] + ) + alibi = alibi.reshape(shape=(1, num_heads, 1, seq_length)).expand([batch_size, -1, -1, -1]) + return paddle.cast(alibi, dtype) +def get_use_casual_mask(): + """Get the value of the 'USE_CASUAL_MASK' environment variable.""" + return os.getenv("USE_CASUAL_MASK", "False") == "True" +def is_casual_mask(attention_mask): + """ + Upper triangular of attention_mask equals to attention_mask is casual + """ + return (paddle.triu(attention_mask) == attention_mask).all().item() +def parallel_matmul(x: Tensor, y: Tensor, transpose_y=False, tensor_parallel_output=True): + is_fleet_init = True + tensor_parallel_degree = 1 + try: + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() + tensor_parallel_degree = hcg.get_model_parallel_world_size() + except: + is_fleet_init = False + + if paddle.in_dynamic_mode(): + y_is_distributed = y.is_distributed + else: + y_is_distributed = tensor_parallel_degree > 1 + + if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed: + # if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg' + input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group) + logits = paddle.matmul(input_parallel, y, transpose_y=transpose_y) + + if tensor_parallel_output: + return logits + + return paddle.distributed.collective._c_concat(logits, group=model_parallel_group) + + else: + logits = paddle.matmul(x, y, transpose_y=transpose_y) + return logits +def repeat_kv(hidden_states: paddle.Tensor, n_rep: int) -> paddle.Tensor: + """ + This is the equivalent of paddle.repeat_interleave(hidden_states, n_rep, axis=1). 
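+    It is applied when the number of key/value heads is smaller than the number of query heads (GQA/MQA), so the cached K/V tensors can be matched one-to-one against the query heads.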
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, slen, num_key_value_heads, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + + hidden_states = hidden_states.unsqueeze(-2).tile([1, 1, 1, n_rep, 1]) + return hidden_states.reshape([batch, slen, num_key_value_heads * n_rep, head_dim]) +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return paddle.concat([-x2, x1], axis=-1) # shape is the same as x +def swiglu(x, y=None): + if y is None: + x, y = paddle.chunk(x, chunks=2, axis=-1) + return F.silu(x) * y + +# --- Injected Classes --- +class ConcatMaskedLoss(PyLayer): + @staticmethod + def forward(ctx, inp, axis, group): + inputs = [] + paddle.distributed.all_gather(inputs, inp, group=group) + with paddle.no_grad(): + cat = paddle.concat(inputs, axis=axis) + ctx.args_axis = axis + ctx.args_group = group + return cat + + @staticmethod + def backward(ctx, grad): + axis = ctx.args_axis + group = ctx.args_group + with paddle.no_grad(): + grads = paddle.split(grad, paddle.distributed.get_world_size(group), axis=axis) + grad = grads[paddle.distributed.get_rank(group)] + return grad +"""Paddle Qwen2 model.""" + +def get_triangle_upper_mask(x, mask=None): + if mask is not None: + return mask + # [bsz, n_head, q_len, kv_seq_len] + shape = x.shape + # [bsz, 1, q_len, kv_seq_len] + shape[1] = 1 + mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype) + mask = paddle.triu(mask, diagonal=1) + mask.stop_gradient = True + return mask + +def scaled_dot_product_attention( + query_states, + config, + key_states, + value_states, + attention_mask, + output_attentions, + attn_mask_startend_row_indices=None, + training=True, + sequence_parallel=False, + skip_recompute=False, +): + bsz, q_len, num_heads, head_dim = query_states.shape + _, kv_seq_len, _, _ = value_states.shape + + if config.use_flash_attention and flash_attention: + # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim] + # Torch Flash Attention input [ bz, nhead, seqlen, head_dim] + + return fusion_ops.fusion_flash_attention( + query_states, + config, + key_states, + value_states, + attention_mask, + output_attentions, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + sequence_parallel=sequence_parallel, + skip_recompute=skip_recompute, + ) + else: + # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim] + query_states = paddle.transpose(query_states, [0, 2, 1, 3]) + # merge with the next transpose + key_states = paddle.transpose(key_states, [0, 2, 1, 3]) + value_states = paddle.transpose(value_states, [0, 2, 1, 3]) + + # Add pre divided factor to fix nan under float16. 
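+        # Under float16 in dynamic mode, an extra constant (32) is folded into the 1/sqrt(head_dim)
+        # scaling of the query so the QK^T logits stay inside the float16 range; the same constant
+        # is multiplied back in right before the softmax, restoring the usual 1/sqrt(head_dim) scaling.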
+ if paddle.in_dynamic_mode() and query_states.dtype == paddle.float16: + pre_divided_factor = 32 + else: + pre_divided_factor = 1 + + attn_weights = paddle.matmul( + query_states / (math.sqrt(head_dim) * pre_divided_factor), key_states.transpose([0, 1, 3, 2]) + ) + + if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]: + raise ValueError( + f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.shape}" + ) + + if attention_mask is None: + attention_mask = get_triangle_upper_mask(attn_weights) + + attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len]) + if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]: + raise ValueError( + f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}" + ) + + attn_weights = attn_weights + attention_mask + + if not paddle.in_dynamic_mode(): + attn_weights = F.softmax(attn_weights * pre_divided_factor, axis=-1, dtype="float32").astype( + query_states.dtype + ) + else: + with paddle.amp.auto_cast(False): + attn_weights = F.softmax( + attn_weights.astype("float32") * pre_divided_factor, axis=-1, dtype="float32" + ).astype(query_states.dtype) + + attn_weights = F.dropout(attn_weights, p=config.attention_dropout, training=training) + + attn_output = paddle.matmul(attn_weights, value_states) + attn_output = attn_output.transpose([0, 2, 1, 3]) + + if sequence_parallel: + attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads]) + else: + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return (attn_output, attn_weights) if output_attentions else attn_output + + +class Qwen2RMSNorm(nn.Layer): + """Qwen2的RMSNorm,继承自LlamaRMSNorm""" + def __init__(self, config: Qwen2Config): + super().__init__() + self.hidden_size = config.hidden_size + self.weight = paddle.create_parameter( + shape=[self.hidden_size], + dtype=paddle.get_default_dtype(), + default_initializer=nn.initializer.Constant(1.0), + ) + self.variance_epsilon = config.rms_norm_eps + self.config = config + + if config.sequence_parallel: + mark_as_sequence_parallel_parameter(self.weight) + + def forward(self, hidden_states): + if self.config.use_fused_rms_norm: + return fusion_ops.fusion_rms_norm(hidden_states, self.weight, self.variance_epsilon) + + if paddle.in_dynamic_mode(): + with paddle.amp.auto_cast(False): + # hidden_states = hidden_states.astype("float32") + # variance = hidden_states.pow(2).mean(-1, keepdim=True) + variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + else: + hidden_states = hidden_states.astype("float32") + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = paddle.rsqrt(variance + self.variance_epsilon) * hidden_states + + if self.weight.dtype in [paddle.float16, paddle.bfloat16]: + hidden_states = paddle.cast(hidden_states, self.weight.dtype) + return hidden_states * self.weight +class Qwen2RotaryEmbedding(nn.Layer): + def __init__(self, dim, max_position_embeddings=2048, base=10000): + super().__init__() + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + # [dim / 2] + self.inv_freq = 1.0 / (self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)) + self._set_cos_sin_cache(seq_len=max_position_embeddings) + def _set_cos_sin_cache(self, seq_len): + self.max_seq_len_cached = seq_len + if self.inv_freq.dtype != paddle.float32: + self.inv_freq = 
1.0 / ( + self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim) + ) + # [seq_len] + t = paddle.arange(seq_len, dtype="float32") + # [seq_len, dim/2] + if get_env_device() == "intel_hpu": + # fallback einsum to intel Gaudi TPC since MME doesn't support FP32 + freqs = t.unsqueeze(1) * self.inv_freq.unsqueeze(0) + else: + freqs = paddle.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + # [seq_len, dim] + emb = paddle.concat([freqs, freqs], axis=-1) + # [1, seqlen, 1, dim] + self.cos_cached = emb.cos()[None, :, None, :] + self.sin_cached = emb.sin()[None, :, None, :] + self.cos_sin_table = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1) + def forward(self, x, seq_len=None): + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len) + # x: [bs, num_attention_heads, seq_len, head_size] + if self.cos_cached.dtype != x.dtype and get_env_device() == "intel_hpu": + self.cos_cached = self.cos_cached.cast(x.dtype) + self.sin_cached = self.sin_cached.cast(x.dtype) + cos = self.cos_cached[:, :seq_len, :, :] + sin = self.sin_cached[:, :seq_len, :, :] + return ( + cos.cast(x.dtype) if cos.dtype != x.dtype else cos, + sin.cast(x.dtype) if sin.dtype != x.dtype else sin, + ) + + def get_fused_cos_sin(self, x, seq_len=None): + if self.cos_sin_table is not None and self.cos_sin_table.dtype != x.dtype: + return self.cos_sin_table.cast(x.dtype) + else: + return self.cos_sin_table + +# --- Children of Qwen2RotaryEmbedding --- +class Qwen23RotaryEmbedding(Qwen2RotaryEmbedding): + def __init__( + self, + dim, + max_position_embeddings=8192, + base=500000, + factor=8.0, + low_freq_factor=1.0, + high_freq_factor=4.0, + original_max_position_embeddings=8192, + ): + self.factor = factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor = high_freq_factor + self.original_max_position_embeddings = original_max_position_embeddings + super().__init__(dim, max_position_embeddings, base) + + def _set_cos_sin_cache(self, seq_len): + low_freq_wavelen = self.original_max_position_embeddings / self.low_freq_factor + high_freq_wavelen = self.original_max_position_embeddings / self.high_freq_factor + new_freqs = [] + for freq in self.inv_freq: + wavelen = 2 * math.pi / freq + if wavelen < high_freq_wavelen: + new_freqs.append(freq) + elif wavelen > low_freq_wavelen: + new_freqs.append(freq / self.factor) + else: + assert low_freq_wavelen != high_freq_wavelen + smooth = (self.original_max_position_embeddings / wavelen - self.low_freq_factor) / ( + self.high_freq_factor - self.low_freq_factor + ) + new_freqs.append((1 - smooth) * freq / self.factor + smooth * freq) + self.inv_freq = paddle.to_tensor(new_freqs, dtype=self.inv_freq.dtype) + super()._set_cos_sin_cache(seq_len=seq_len) +class Qwen2DynamicNTKScalingRotaryEmbedding(Qwen2RotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. 
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base) + + def _scale_cos_sin(self, seq_len): + # [seq_len] + t = paddle.arange(seq_len, dtype="float32") + # [seq_len, dim/2] + alpha = (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + base = self.base * alpha ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim)) + freqs = paddle.einsum("i,j->ij", t, inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + # [seq_len, dim] + emb = paddle.concat([freqs, freqs], axis=-1) + # [1, seqlen, 1, dim] + scale_cos = emb.cos()[None, :, None, :] + scale_sin = emb.sin()[None, :, None, :] + scale_cos_sin = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1) + return scale_cos, scale_sin, scale_cos_sin + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_position_embeddings: + scale_cos, scale_sin, _ = self._scale_cos_sin(seq_len=seq_len) + else: + scale_cos, scale_sin = self.cos_cached, self.sin_cached + cos = scale_cos[:, :seq_len, :, ...] + sin = scale_sin[:, :seq_len, :, ...] + return ( + cos.cast(x.dtype) if cos.dtype != x.dtype else cos, + sin.cast(x.dtype) if sin.dtype != x.dtype else sin, + ) + + def get_fused_cos_sin(self, x, seq_len=None): + if seq_len > self.max_position_embeddings: + _, _, scale_cos_sin = self._scale_cos_sin(seq_len=seq_len) + else: + scale_cos_sin = self.cos_sin_table + if scale_cos_sin is not None and scale_cos_sin.dtype != x.dtype: + return scale_cos_sin.cast(x.dtype) + else: + return scale_cos_sin +class Qwen2LinearScalingRotaryEmbedding(Qwen2RotaryEmbedding): + def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings * scaling_factor, base) + + def _set_cos_sin_cache(self, seq_len): + self.max_seq_len_cached = seq_len + # [seq_len] + t = paddle.arange(seq_len, dtype="float32") + t = t / self.scaling_factor + # [seq_len, dim/2] + if get_env_device() == "intel_hpu": + # fallback einsum to intel Gaudi TPC since MME doesn't support FP32 + freqs = t.unsqueeze(1) * self.inv_freq.unsqueeze(0) + else: + freqs = paddle.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + # [seq_len, dim] + emb = paddle.concat([freqs, freqs], axis=-1) + # [1, seqlen, 1, dim] + self.cos_cached = emb.cos()[None, :, None, :] + self.sin_cached = emb.sin()[None, :, None, :] + self.cos_sin_table = None if get_env_device() != "gcu" else paddle.concat([freqs.cos(), freqs.sin()], axis=-1) +class Qwen2NTKScalingRotaryEmbedding(Qwen2RotaryEmbedding): + """LlamaRotaryEmbedding extended with NTK scaling. 
https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, scaling_factor=1.0): + base = base * scaling_factor ** (dim / (dim - 2)) + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings * scaling_factor, base) + +class Qwen2MLP(nn.Layer): + """Qwen2的MLP,继承自LlamaMLP""" + def __init__(self, config: Qwen2Config,is_shared=False, skip_recompute_ops=None): + super().__init__() + if skip_recompute_ops is None: + skip_recompute_ops = {} + self.skip_recompute_ops = skip_recompute_ops + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.tensor_parallel_degree = config.tensor_parallel_degree + self.fuse_attention_ffn = config.fuse_attention_ffn + + if config.sequence_parallel: + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear + + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if skip_recompute_ops.get("mlp_column_ln", False): + ColumnParallelLinear = RRColumnSequenceParallelLinear + if skip_recompute_ops.get("mlp_row_ln", False): + RowParallelLinear = RRRowSequenceParallelLinear + else: + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear + + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if skip_recompute_ops.get("mlp_column_ln", False): + ColumnParallelLinear = RRColumnParallelLinear + if skip_recompute_ops.get("mlp_row_ln", False): + RowParallelLinear = RRRowParallelLinear + + if config.tensor_parallel_degree > 1: + if config.fuse_attention_ffn: + self.gate_up_fused_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size * 2, + gather_output=False, + has_bias=False, + ) + else: + self.gate_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + self.up_proj = ColumnParallelLinear( + self.hidden_size, + self.intermediate_size, + gather_output=False, + has_bias=False, + ) + + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + input_is_parallel=True, + has_bias=False, + ) + else: + if config.fuse_attention_ffn: + self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False) + else: + self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + + self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False) + if config.hidden_act == "silu": + self.act_fn = fusion_ops.swiglu + self.fuse_swiglu = True + else: + self.act_fn = ACT2FN[config.hidden_act] + self.fuse_swiglu = False + # Qwen2的MLP结构与Llama相同,但使用不同的配置 + def forward(self, x): + if self.fuse_attention_ffn: + x = self.gate_up_fused_proj(x) + if self.fuse_swiglu: + y = None + else: + x, y = x.chunk(2, axis=-1) + else: + x, y = self.gate_proj(x), self.up_proj(x) + + if self.fuse_swiglu: + x = self.act_fn(x, y) + else: + x = self.act_fn(x) * y + + return self.down_proj(x) + +class Qwen2Attention(nn.Layer): + """ + Multi-headed attention from 'Attention Is All You Need' paper. 
Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True, skip_recompute_ops=None): + super().__init__() + if skip_recompute_ops is None: + skip_recompute_ops = {} + self.config = config + self.skip_recompute_ops = skip_recompute_ops + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_attention_heads = config.num_attention_heads + + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + + self.num_key_value_heads = config.num_key_value_heads + assert config.num_attention_heads // config.num_key_value_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + # self.seq_length = config.seq_length + self.sequence_parallel = config.sequence_parallel + self.has_bias = config.attention_bias + self.fuse_attention_qkv = config.fuse_attention_qkv + + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + if config.tensor_parallel_degree > 1: + assert ( + self.num_heads % config.tensor_parallel_degree == 0 + ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_heads = self.num_heads // config.tensor_parallel_degree + + assert ( + self.num_key_value_heads % config.tensor_parallel_degree == 0 + ), f"num_key_value_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree + + self.use_fused_rope = config.use_fused_rope + if self.use_fused_rope: + if get_env_device() not in ["gpu", "xpu"] or fused_rotary_position_embedding is None: + warnings.warn( + "Enable fuse rope in the config, but fuse rope is not available. " + "Will disable fuse rope. Try using latest gpu version of Paddle." 
+ ) + self.use_fused_rope = False + + if config.sequence_parallel: + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear + + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if skip_recompute_ops.get("attention_column_ln", False): + ColumnParallelLinear = RRColumnSequenceParallelLinear + if skip_recompute_ops.get("attention_row_ln", False): + RowParallelLinear = RRRowSequenceParallelLinear + else: + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear + + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if skip_recompute_ops.get("attention_column_ln", False): + ColumnParallelLinear = RRColumnParallelLinear + if skip_recompute_ops.get("attention_row_ln", False): + RowParallelLinear = RRRowParallelLinear + + if config.tensor_parallel_degree > 1: + if self.fuse_attention_qkv: + self.qkv_proj = ColumnParallelLinear( + self.hidden_size, + self.num_attention_heads * self.head_dim + 2 * self.config.num_key_value_heads * self.head_dim, + has_bias=self.has_bias, + gather_output=False, + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_attention_heads * self.head_dim, + has_bias=self.has_bias, + gather_output=False, + ) + self.k_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=self.has_bias, gather_output=False) # fmt:skip + self.v_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=self.has_bias, gather_output=False) # fmt:skip + self.o_proj = RowParallelLinear(self.hidden_size, self.hidden_size, has_bias=False, input_is_parallel=True) + else: + if self.fuse_attention_qkv: + self.qkv_proj = Linear( + self.hidden_size, + self.num_attention_heads * self.head_dim + 2 * self.config.num_key_value_heads * self.head_dim, + ) + else: + self.q_proj = Linear( + self.hidden_size, self.num_attention_heads * self.head_dim, bias_attr=self.has_bias + ) + self.k_proj = Linear( + self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=self.has_bias + ) + self.v_proj = Linear( + self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=self.has_bias + ) + self.o_proj = Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias_attr=False) + + self.rotary_emb = Qwen2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + self.attn_func = scaled_dot_product_attention + + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant and skip_recompute_ops.get("flash_attn", False): + self.attn_func = partial(scaled_dot_product_attention, skip_recompute=True) + + def forward( + self, + hidden_states, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + batch_size: Optional[int] = None, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + """Input shape: Batch x Time x Channel""" + # [bs, 
seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism) + + if self.fuse_attention_qkv: + mix_layer = self.qkv_proj(hidden_states) + if self.sequence_parallel: + target_shape = [ + batch_size, + -1, + self.num_key_value_heads, + (self.num_key_value_groups + 2) * self.head_dim, + ] + else: + target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim] + mix_layer = paddle.reshape_(mix_layer, target_shape) + query_states, key_states, value_states = paddle.split( + mix_layer, + num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim], + axis=-1, + ) + if self.gqa_or_mqa: + query_states = paddle.reshape_(query_states, [0, 0, self.num_heads, self.head_dim]) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.sequence_parallel: + target_query_shape = [batch_size, -1, self.num_heads, self.head_dim] + target_key_value_shape = [batch_size, -1, self.num_key_value_heads, self.head_dim] + else: + target_query_shape = [0, 0, self.num_heads, self.head_dim] + target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim] + query_states = query_states.reshape(shape=target_query_shape) + key_states = key_states.reshape(shape=target_key_value_shape) + value_states = value_states.reshape(shape=target_key_value_shape) + + if position_ids is not None and not self.use_fused_rope: + kv_seq_len = position_ids.max().item() + 1 + else: + kv_seq_len = key_states.shape[-3] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-3] + if self.use_fused_rope: + assert past_key_value is None, "fuse rotary not support cache kv for now" + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states, _ = fused_rotary_position_embedding( + query_states, + key_states, + v=None, + sin=sin, + cos=cos, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # [bs, seq_len, num_head, head_dim] + if past_key_value is not None: + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states = paddle.concat([past_key_value[1], value_states], axis=1) + past_key_value = (key_states, value_states) if use_cache else None + + # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1 + # repeat k/v heads if n_kv_heads < n_heads + paddle_version = float(paddle.__version__[:3]) + if not self.config.use_flash_attention or ((paddle_version != 0.0) and (paddle_version <= 2.6)): + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient) + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "core_attn" + ): + recompute_fn = rr_recompute if any(self.skip_recompute_ops.values()) else recompute + outputs = recompute_fn( + self.attn_func, + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + training=self.training, + sequence_parallel=self.sequence_parallel, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = 
self.attn_func( + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + training=self.training, + sequence_parallel=self.sequence_parallel, + ) + if output_attentions: + attn_output, attn_weights = outputs + else: + attn_output = outputs + + # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim] + # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + outputs = (attn_output,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class Qwen2DecoderLayer(nn.Layer): + """Qwen2的解码器层,继承自LlamaDecoderLayer""" + def __init__(self, config: Qwen2Config, layerwise_recompute: bool = False, skip_recompute_ops=None): + super().__init__() + self.config = config + if skip_recompute_ops is None: + skip_recompute_ops = {} + self.skip_recompute_ops = skip_recompute_ops + self.hidden_size = config.hidden_size + self.self_attn = Qwen2Attention(config, layerwise_recompute, skip_recompute_ops=skip_recompute_ops) + self.mlp = Qwen2MLP(config, skip_recompute_ops=skip_recompute_ops) + self.input_layernorm = Qwen2RMSNorm(config) + self.post_attention_layernorm = Qwen2RMSNorm(config) + self.sequence_parallel = config.sequence_parallel + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + batch_size: Optional[int] = None, + alibi: Optional[paddle.Tensor] = None, + npu_is_casual: bool = False, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`paddle.Tensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + + # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel) + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + recompute_fn = rr_recompute if any(self.skip_recompute_ops.values()) else recompute + outputs = recompute_fn( + self.self_attn, + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + attn_mask_startend_row_indices, + batch_size, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = self.self_attn( + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + batch_size=batch_size, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + if output_attentions: + self_attn_weights = outputs[1] + + if use_cache: + present_key_value = outputs[2 if output_attentions else 1] + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class Qwen2PretrainedModel(PretrainedModel): + """Qwen2预训练模型基类,继承自LlamaPretrainedModel""" + config_class = Qwen2Config + base_model_prefix = "qwen2" + _keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"] + @classmethod + def _get_name_mappings(cls, config: Qwen2Config) -> list[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + for layer_index in range(config.num_hidden_layers): + layer_mappings = [ + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.q_proj.bias", None], + [f"layers.{layer_index}.self_attn.k_proj.bias", None], + [f"layers.{layer_index}.self_attn.v_proj.bias", None], + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + init_name_mappings(mappings=model_mappings) + # base-model prefix "Qwen2MoEModel" + if "Qwen2Model" not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = "qwen2." 
+ mapping[1] + if not config.tie_word_embeddings: + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config: Qwen2Config, is_split=True): + + from ..conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + + base_actions = { + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + } + + if config.tie_word_embeddings: + base_actions["lm_head.weight"] = partial(fn, is_column=False) + else: + base_actions["lm_head.weight"] = partial(fn, is_column=True) + + if not config.vocab_size % config.tensor_parallel_degree == 0: + base_actions.pop("lm_head.weight") + base_actions.pop("embed_tokens.weight") + # Column Linear + if config.fuse_attention_qkv: + base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.qkv_proj.bias"] = partial(fn, is_column=True) + else: + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True) + # if we have enough num_key_value_heads to split, then split it. + if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True) + + if config.fuse_attention_ffn: + base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial( + fn, is_column=True, is_naive_2fuse=True + ) + else: + base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings + @classmethod + def _get_fuse_or_split_param_mappings(cls, config: Qwen2Config, is_fuse=False): + # return parameter fuse utils + from ..conversion_utils import split_or_fuse_func + + fn = split_or_fuse_func(is_fuse=is_fuse) + + # last key is fused key, other keys are to be fused. 
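+        # Each tuple lists the separate projection parameters followed by the fused parameter name.
+        # When is_fuse=True the separate q/k/v (or gate/up) weights are concatenated into the fused
+        # key; otherwise the fused key is split back into its components (split_nums=3 for qkv, 2 for gate/up).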
+ fuse_qkv_keys = [ + ( + "layers.0.self_attn.q_proj.weight", + "layers.0.self_attn.k_proj.weight", + "layers.0.self_attn.v_proj.weight", + "layers.0.self_attn.qkv_proj.weight", + ), + ( + "layers.0.self_attn.q_proj.bias", + "layers.0.self_attn.k_proj.bias", + "layers.0.self_attn.v_proj.bias", + "layers.0.self_attn.qkv_proj.bias", + ), + ] + + fuse_gate_up_keys = ( + "layers.0.mlp.gate_proj.weight", + "layers.0.mlp.up_proj.weight", + "layers.0.mlp.gate_up_fused_proj.weight", + ) + num_heads = config.num_attention_heads + num_key_value_heads = getattr(config, "num_key_value_heads", num_heads) + fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False) + fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False) + + final_actions = {} + if is_fuse: + if fuse_attention_qkv: + for i in range(config.num_hidden_layers): + for fuse_keys in fuse_qkv_keys: + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_keys]) + final_actions[keys] = partial( + fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + if fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]) + final_actions[keys] = fn + else: + if not fuse_attention_qkv: + for i in range(config.num_hidden_layers): + for fuse_keys in fuse_qkv_keys: + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_keys]) + final_actions[keys] = partial( + fn, split_nums=3, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + if not fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]) + final_actions[keys] = partial(fn, split_nums=2) + return final_actions + + def _get_model_flops(self): + if hasattr(self.config, "seq_length"): + seq_length = self.config.seq_length + else: + seq_length = 2048 + + return caculate_llm_per_token_flops( + hidden_size=self.config.hidden_size, + intermediate_size=self.config.intermediate_size, + layer_num=self.config.num_hidden_layers, + vocab_size=self.config.vocab_size, + seq_length=seq_length, + recompute=False, + ) + + def _get_hardware_flops(self): + if hasattr(self.config, "seq_length"): + seq_length = self.config.seq_length + else: + seq_length = 2048 + + return caculate_llm_per_token_flops( + hidden_size=self.config.hidden_size, + intermediate_size=self.config.intermediate_size, + layer_num=self.config.num_hidden_layers, + vocab_size=self.config.vocab_size, + seq_length=seq_length, + recompute=self.config.recompute, + recompute_granularity=self.config.recompute_granularity, + ) + pretrained_init_configuration = QWEN2_PRETRAINED_INIT_CONFIGURATION + pretrained_resource_files_map = QWEN2_PRETRAINED_RESOURCE_FILES_MAP + + def _init_weights(self, layer): + """Initialization hook""" + if self.config.tensor_parallel_degree > 1: + rng_tracker = get_rng_state_tracker().rng_state + if isinstance( + layer, + ( + nn.Linear, + nn.Embedding, + mpu.VocabParallelEmbedding, + mpu.RowParallelLinear, + mpu.ColumnParallelLinear, + linear_utils.RowSequenceParallelLinear, + linear_utils.ColumnSequenceParallelLinear, + Qwen2LMHead, + ), + ): + # In the dygraph mode, use the `set_value` to reset the parameter directly, + # and reset the `state_dict` to update parameter in static mode. 
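+                # Distributed (tensor-parallel) weights are re-initialized under the model-parallel RNG
+                # state tracker so initialization is coordinated across ranks, drawing from
+                # N(0, initializer_range); non-distributed weights are reset with the same normal initializer directly.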
+ if isinstance(layer.weight, paddle.Tensor): + if layer.weight.is_distributed: + with rng_tracker(): + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.qwen2.config.initializer_range, + shape=layer.weight.shape, + ) + ) + else: + layer.weight.set_value( + paddle.tensor.normal( + mean=0.0, + std=self.config.initializer_range + if hasattr(self.config, "initializer_range") + else self.qwen2.config.initializer_range, + shape=layer.weight.shape, + ) + ) + # Layer.apply is DFS https://github.com/PaddlePaddle/Paddle/blob/a6f5021fcc58b21f4414bae6bf4731ef6971582c/python/paddle/nn/layer/layers.py#L527-L530 + # sublayer is init first + # scale RowParallelLinear weight + with paddle.no_grad(): + if isinstance(layer, Qwen2MLP): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.down_proj.weight.scale_(factor) + if isinstance(layer, Qwen2Attention): + factor = 1 / math.sqrt(2 * self.config.num_hidden_layers) + layer.o_proj.weight.scale_(factor) + + +@register_base_model +class Qwen2Model(Qwen2PretrainedModel): + """Qwen2模型,继承自LlamaModel""" + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.sequence_parallel = config.sequence_parallel + self.recompute_granularity = config.recompute_granularity + self.no_recompute_layers = config.no_recompute_layers if config.no_recompute_layers is not None else [] + self.config = config + + # Recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + self.embed_tokens = mpu.VocabParallelEmbedding( + self.vocab_size, + self.hidden_size, + weight_attr=paddle.ParamAttr(initializer=nn.initializer.XavierNormal()), + ) + else: + self.embed_tokens = nn.Embedding( + self.vocab_size, + self.hidden_size, + ) + + self.layers = nn.LayerList( + [ + Qwen2DecoderLayer( + config=config, + layerwise_recompute=layer_idx not in self.no_recompute_layers, + skip_recompute_ops=get_skip_recompute_ops(config, layer_idx), + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self.norm = Qwen2RMSNorm(config) + + self.gradient_checkpointing = False + self.padding_idx = config.pad_token_id + @paddle.jit.not_to_static + def recompute_training_full( + self, + layer_module: nn.Layer, + hidden_states: Tensor, + position_ids: Optional[Tensor], + attention_mask: Tensor, + output_attentions: bool, + past_key_value: Tensor, + use_cache: bool, + attn_mask_startend_row_indices=None, + batch_size: int = None, + alibi=None, + ): + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + recompute_fn = rr_recompute if any(layer_module.skip_recompute_ops.values()) else recompute + hidden_states = recompute_fn( + create_custom_forward(layer_module), + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + batch_size, + use_reentrant=self.config.recompute_use_reentrant, + ) + + return hidden_states + def forward( + self, + input_ids: paddle.Tensor = None, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + 
output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + attn_mask_startend_row_indices=None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + if self.sequence_parallel and use_cache: + raise ValueError("We currently only support sequence parallel without cache.") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + if past_key_values is None: + past_key_values = tuple([None] * len(self.layers)) + # NOTE: to make cache can be clear in-time + past_key_values = list(past_key_values) + + seq_length_with_past = seq_length + cache_length = 0 + if past_key_values[0] is not None: + cache_length = past_key_values[0][0].shape[1] + seq_length_with_past += cache_length + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self.sequence_parallel: + # [bs, seq_len, num_head * head_dim] -> [bs * seq_len, num_head * head_dim] + bs, seq_len, hidden_size = inputs_embeds.shape + inputs_embeds = paddle.reshape_(inputs_embeds, [bs * seq_len, hidden_size]) + # [seq_len * bs / n, num_head * head_dim] (n is mp parallelism) + inputs_embeds = ScatterOp.apply(inputs_embeds) + + if self.config.context_parallel_degree > 1 and (attention_mask is not None or self.config.alibi): + raise NotImplementedError("Ring FlashAttention doesn't support attention_mask or alibi") + + # embed positions + if self.config.use_flash_attention_for_generation: + attention_mask = None + elif attn_mask_startend_row_indices is None and attention_mask is None: + # [bs, seq_len] + attention_mask = paddle.ones((batch_size, seq_length_with_past), dtype=paddle.bool) + if attn_mask_startend_row_indices is None and self.config.alibi: + if self.config.use_long_sequence_strategies: + alibi_layer = LongSequenceStrategies.build_long_sequence_strategy( + self.config.long_sequence_strategy_type, + self.config.long_sequence_strategy_name, + **self.config.long_sequence_init_args, + ) + alibi = alibi_layer(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype) + else: + alibi = build_alibi_tensor(attention_mask, self.config.num_attention_heads, dtype=inputs_embeds.dtype) + if self.config.tensor_parallel_degree > 1: + block_size = self.config.num_attention_heads // self.config.tensor_parallel_degree + alibi = alibi[ + :, + self.config.tensor_parallel_rank + * block_size : (self.config.tensor_parallel_rank + 1) + * block_size, + ] + alibi = alibi.reshape([batch_size * block_size, 1, seq_length_with_past]) + else: + alibi = alibi.reshape([batch_size * self.config.num_attention_heads, 1, seq_length_with_past]) + else: + alibi = None + + if position_ids is None: + position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, 
seq_length)) + + use_casual_mask = get_use_casual_mask() and not self.config.alibi + + if self.config.use_flash_attention_for_generation or use_casual_mask: + attention_mask = None + elif attn_mask_startend_row_indices is None: + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype + ) # [bs, 1, seq_len, seq_len] + + is_casual = False + + if ( + attn_mask_startend_row_indices is None + and self.config.use_flash_attention + and get_env_device() not in ["gcu", "intel_hpu"] + ): + if self.config.use_flash_attention_for_generation or use_casual_mask: + is_casual = True + else: + is_casual = is_casual_mask(attention_mask) + if get_env_device() not in ["npu", "mlu"]: + if is_casual and alibi is None: + attention_mask = None + else: + attention_mask = None if attention_mask is None else attention_mask.astype("bool") + hidden_states = inputs_embeds + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, (decoder_layer) in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + past_key_value = past_key_values[idx] if past_key_values is not None else None + + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and idx not in self.no_recompute_layers + and has_gradient + and self.recompute_granularity == "full" + ): + layer_outputs = self.recompute_training_full( + decoder_layer, + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + alibi=alibi, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + position_ids, + attention_mask, + output_attentions, + past_key_value, + use_cache, + alibi=alibi, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + npu_is_casual=is_casual, + ) + + # NOTE: clear outdate cache after it has been used for memory saving + past_key_value = past_key_values[idx] = None + if type(layer_outputs) is tuple: + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if self.config.use_last_token_for_generation: + hidden_states = paddle.unsqueeze(hidden_states[:, -1, :], 1) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=None, + ) + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`LlamaDecoderLayer`] + Args: + config: LlamaConfig + """ + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @staticmethod + def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values_length, dtype): + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + if len(attention_mask.shape) == 2: + expanded_attn_mask = _expand_2d_mask(attention_mask, dtype, tgt_length=input_shape[-1]) + # For decoding phase in generation, seq_length = 1, we don't need to add causal mask + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, past_key_values_length=past_key_values_length + ) + if get_env_device() in ["npu", "mlu", "intel_hpu"]: + expanded_attn_mask = expanded_attn_mask.astype("bool") & combined_attention_mask.astype("bool") + else: + expanded_attn_mask = expanded_attn_mask & combined_attention_mask + # [bsz, seq_len, seq_len] -> [bsz, 1, seq_len, seq_len] + elif len(attention_mask.shape) == 3: + expanded_attn_mask = attention_mask.unsqueeze(1).astype("bool") + # if attention_mask is already 4-D, do nothing + else: + expanded_attn_mask = attention_mask + else: + expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) + # Convert bool attention_mask to float attention mask, which will be added to attention_scores later + if get_env_device() in ["npu", "mlu", "intel_hpu"]: + x = paddle.to_tensor(0.0, dtype="float32") + y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32") + expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype) + elif get_env_device() == "xpu": + x = paddle.to_tensor(0.0, dtype="float32") + y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32") + expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y) + elif get_env_device() == "gcu": + min_val = paddle.finfo(dtype).min + x = paddle.to_tensor(0.0, dtype=dtype) + y = paddle.to_tensor(min_val, dtype=dtype) + expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype) + else: + expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min) + expanded_attn_mask = expanded_attn_mask.astype(dtype) + return expanded_attn_mask +class Qwen2PretrainingCriterion(paddle.nn.Layer): + """Qwen2的预训练损失计算,继承自LlamaPretrainingCriterion""" + def __init__(self, config: Qwen2Config): + + super(Qwen2PretrainingCriterion, self).__init__() + self.ignore_index = getattr(config, "ignore_index", -100) + self.config = config + self.enable_parallel_cross_entropy = ( + config.tensor_parallel_degree > 1 + and config.vocab_size % config.tensor_parallel_degree == 0 + and config.tensor_parallel_output + ) + + if self.enable_parallel_cross_entropy: # and False: # and lm_head is distributed + self.loss_func = mpu.ParallelCrossEntropy(ignore_index=self.ignore_index) + else: + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output + """ + Criterion for Llama. + It calculates the final loss. 
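+    The summed per-token loss is divided by the number of tokens whose loss is
+    non-zero, so positions masked with `ignore_index` do not dilute the average.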
+ """ + + def forward(self, prediction_scores, masked_lm_labels): + if self.enable_parallel_cross_entropy: + if prediction_scores.shape[-1] == self.config.vocab_size: + warnings.warn( + f"enable_parallel_cross_entropy, the vocab_size should be splited: {prediction_scores.shape[-1]}, {self.config.vocab_size}" + ) + self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index) + + with paddle.amp.auto_cast(False): + masked_lm_loss = self.loss_func(prediction_scores.astype("float32"), masked_lm_labels.unsqueeze(2)) + + if self.config.sep_parallel_degree > 1 or self.config.context_parallel_degree > 1: + _hcg = fleet.get_hybrid_communicate_group() + masked_lm_loss = ConcatMaskedLoss.apply(masked_lm_loss, axis=1, group=_hcg.get_sep_parallel_group()) + # skip ignore_index which loss == 0 + # masked_lm_loss = masked_lm_loss[masked_lm_loss > 0] + # loss = paddle.mean(masked_lm_loss) + binary_sequence = paddle.where( + masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss) + ) + count = paddle.sum(binary_sequence) + if count == 0: + loss = paddle.sum(masked_lm_loss * binary_sequence) + else: + loss = paddle.sum(masked_lm_loss * binary_sequence) / count + + return loss + +class Qwen2LMHead(nn.Layer): + """Qwen2的语言模型头,继承自LlamaLMHead""" + def __init__(self, config: Qwen2Config, embedding_weights=None, transpose_y=False): + super(Qwen2LMHead, self).__init__() + self.config = config + if config.tensor_parallel_degree > 1 and config.vocab_size % config.tensor_parallel_degree == 0: + vocab_size = config.vocab_size // config.tensor_parallel_degree + else: + vocab_size = config.vocab_size + + self.transpose_y = transpose_y + if transpose_y: + if embedding_weights is not None: + self.weight = embedding_weights + else: + self.weight = self.create_parameter( + shape=[vocab_size, config.hidden_size], + dtype=paddle.get_default_dtype(), + ) + else: + if vocab_size != config.vocab_size: + with get_rng_state_tracker().rng_state(): + self.weight = self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + ) + else: + self.weight = self.create_parameter( + shape=[config.hidden_size, vocab_size], + dtype=paddle.get_default_dtype(), + ) + # Must set distributed attr for Tensor Parallel ! 
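+        # The parameter is only marked as distributed when the vocab dimension was
+        # actually split across tensor-parallel ranks (vocab_size != config.vocab_size).
+        # split_axis records which axis holds the vocab shards: axis 0 when the weight
+        # is stored transposed (as with tied input embeddings), axis 1 otherwise.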
+ self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False + if self.weight.is_distributed: + # for tie_word_embeddings + self.weight.split_axis = 0 if self.transpose_y else 1 + if get_env_device() == "xpu": + try: + from paddle_xpu.layers.nn import ( # noqa: F401 + parallel_matmul as xpu_parallel_matmul, + ) + + self.xpu_parallel_matmul = xpu_parallel_matmul() + except ImportError: + self.xpu_parallel_matmul = None + def forward(self, hidden_states, tensor_parallel_output=None, batch_size=None): + if self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + hidden_states = paddle.reshape_(hidden_states, [batch_size, -1, self.config.hidden_size]) + + if tensor_parallel_output is None: + tensor_parallel_output = self.config.tensor_parallel_output + + logits = parallel_matmul( + hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output + ) + return logits +class Qwen2ForCausalLM(Qwen2PretrainedModel): + """用于因果语言建模的Qwen2模型,继承自LlamaForCausalLM""" + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.config = config + + self.qwen2 = Qwen2Model(config) + if config.tie_word_embeddings: + self.lm_head = Qwen2LMHead(config, embedding_weights=self.qwen2.embed_tokens.weight, transpose_y=True) + self.tie_weights() + else: + self.lm_head = Qwen2LMHead(config) + self.criterion = Qwen2PretrainingCriterion(config) + @staticmethod + def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False): + # update cache + if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor): + model_kwargs["past_key_values"] = outputs[1] + + if isinstance(outputs, CausalLMOutputWithPast) and "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs.past_key_values + + # update position_ids + if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None: + position_ids = model_kwargs["position_ids"] + model_kwargs["position_ids"] = paddle.concat([position_ids, position_ids[..., -1:] + 1], axis=-1) + + if not is_encoder_decoder and "attention_mask" in model_kwargs: + # TODO: support attention mask for other models + attention_mask = model_kwargs["attention_mask"] + if len(attention_mask.shape) == 2: + model_kwargs["attention_mask"] = paddle.concat( + [attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype)], + axis=-1, + ) + elif len(attention_mask.shape) == 4: + model_kwargs["attention_mask"] = paddle.concat( + [attention_mask, paddle.ones([*attention_mask.shape[:3], 1], dtype=attention_mask.dtype)], + axis=-1, + )[:, :, -1:, :] + + return model_kwargs + + def forward( + self, + input_ids: paddle.Tensor = None, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + labels: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + attn_mask_startend_row_indices=None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). 
Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, Qwen2ForCausalLM + + >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if attn_mask_startend_row_indices is not None and attention_mask is not None: + logger.warning( + "You have provided both attn_mask_startend_row_indices and attention_mask. " + "The attn_mask_startend_row_indices will be used." + ) + attention_mask = None + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.qwen2( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + past_key_values=past_key_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + ) + + hidden_states = outputs[0] + + # add this for fused_head_and_loss_fn + if self.config.use_fused_head_and_loss_fn and self.training: + if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + hidden_states = hidden_states.reshape( + [ + batch_size, + -1, + hidden_states.shape[-1], + ] + ) + return hidden_states, self.lm_head.weight, None, self.lm_head.transpose_y + + # if labels is None,means we need full output, instead of tensor_parallel_output + # tensor_parallel_output is together with ParallelCrossEntropy + tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1 + + if labels is not None and self.config.use_fused_linear_cross_entropy: + from paddlenlp_kernel.triton.cut_cross_entropy import linear_cross_entropy + + assert ( + self.config.tensor_parallel_degree <= 1 + ), "The argument `use_fused_linear_cross_entropy` is imcompatiable with tensor parallel " + + masked_lm_loss = linear_cross_entropy(hidden_states, self.lm_head.weight, targets=labels) + + binary_sequence = paddle.where( + masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss) + ) + count = paddle.sum(binary_sequence) + if count == 0: + loss = 
paddle.sum(masked_lm_loss * binary_sequence) + else: + loss = paddle.sum(masked_lm_loss * binary_sequence) / count + logits = None + else: + logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output, batch_size=batch_size) + + loss = None + if labels is not None: + loss = self.criterion(logits, labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + enable_to_static_method = True + _tied_weights_keys = ["lm_head.weight"] + + def get_input_embeddings(self): + return self.qwen2.embed_tokens + + def set_input_embeddings(self, value): + self.qwen2.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.qwen2 = decoder + + def get_decoder(self): + return self.qwen2 + + def prepare_inputs_for_generation( + self, input_ids, use_cache=False, past_key_values=None, inputs_embeds=None, **kwargs + ): + batch_size, seq_length = input_ids.shape + position_ids = kwargs.get("position_ids", paddle.arange(seq_length).expand((batch_size, seq_length))) + attention_mask = kwargs.get("attention_mask", None) + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(axis=-1) + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + def _get_model_inputs_spec(self, dtype: str): + return { + "input_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + "attention_mask": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + "position_ids": paddle.static.InputSpec(shape=[None, None], dtype="int64"), + } + + +class Qwen2ForSequenceClassification(Qwen2PretrainedModel): + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.num_labels = config.num_labels + self.qwen2 = Qwen2Model(config) + self.score = Linear(config.hidden_size, self.num_labels, bias_attr=False) + + def get_input_embeddings(self): + return self.qwen2.embed_tokens + + def set_input_embeddings(self, value): + self.qwen2.embed_tokens = value + + def forward( + self, + input_ids: paddle.Tensor = None, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + labels: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`paddle.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. 
If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.qwen2( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = paddle.equal(input_ids, self.config.pad_token_id).astype("int32").argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths + else: + sequence_lengths = -1 + + # pooled_logits = logits[paddle.arange(batch_size), sequence_lengths] + pooled_logits = logits.gather_nd(paddle.stack([paddle.arange(logits.shape[0]), sequence_lengths], axis=-1)) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == paddle.int64 or labels.dtype == paddle.int32): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = nn.MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(pooled_logits.reshape([-1, self.num_labels]), labels.reshape([-1])) + elif self.config.problem_type == "multi_label_classification": + loss_fct = nn.BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2 +class Qwen2ForTokenClassification(Qwen2PretrainedModel): + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.num_labels = config.num_labels + self.qwen2 = Qwen2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = Linear(config.hidden_size, config.num_labels) + + def get_input_embeddings(self): + return self.qwen2.embed_tokens + + def 
set_input_embeddings(self, value): + self.qwen2.embed_tokens = value + + def forward( + self, + input_ids: paddle.Tensor = None, + attention_mask: Optional[paddle.Tensor] = None, + position_ids: Optional[paddle.Tensor] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + labels: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`paddle.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.qwen2( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(logits.reshape([-1, self.num_labels]), labels.reshape([-1])) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class Qwen2SentenceEmbedding(Qwen2PretrainedModel): + def __init__( + self, + config: Qwen2Config, + embedding_temperature: float = 0.02, + ): + """Qwen2SentenceEmbedding + For getting larger batch_size, we use tensor parallel to get larger batch_size. + + Args: + config (Qwen2Config): _description_ + model (Qwen2Model): _description_ + embedding_temperature (float, optional): _description_. Defaults to 0.02. 
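+
+        When `config.embedding_negatives_cross_device` is enabled and training runs on
+        more than one device, query and passage representations are gathered across
+        devices before the in-batch negative loss, enlarging the pool of negatives.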
+ """ + super(Qwen2SentenceEmbedding, self).__init__(config) + self.config = config + self.qwen2 = Qwen2Model(config) + self.in_batch_negative_loss = SimpleContrastiveLoss(embedding_temperature) + self.world_size = dist.get_world_size() + self.process_rank = dist.get_rank() + self.embedding_negatives_cross_device = config.embedding_negatives_cross_device + if self.world_size <= 1: + self.embedding_negatives_cross_device = False + + def forward( + self, + query: Optional[Dict[str, paddle.Tensor]] = None, + passages: Optional[Dict[str, paddle.Tensor]] = None, + return_encode=False, + ): + """forward""" + q_reps = self.encode(**query) + p_reps = self.encode(**passages) + + q_reps = nn.functional.normalize(q_reps, axis=-1) + p_reps = nn.functional.normalize(p_reps, axis=-1) + + if return_encode: + return q_reps, p_reps + + if self.embedding_negatives_cross_device: + q_reps = dist_gather_tensor_with_gradient(q_reps) + p_reps = dist_gather_tensor_with_gradient(p_reps) + + loss = self.in_batch_negative_loss(q_reps, p_reps) + return loss + + def encode( + self, + input_ids, + position_ids=None, + embedding_indices=None, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + **kwargs, + ): + """encode""" + input_type = type(input_ids) + outputs = self.qwen2( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + if isinstance(outputs, input_type): + hidden_states = outputs + else: + hidden_states = outputs[0] + last_hidden_states = hidden_states.gather_nd(embedding_indices) + return last_hidden_states + diff --git a/paddlenlp/transformers/qwen2/modular_qwen2.py b/paddlenlp/transformers/qwen2/modular_qwen2.py new file mode 100644 index 000000000000..f0e2ddcf52ae --- /dev/null +++ b/paddlenlp/transformers/qwen2/modular_qwen2.py @@ -0,0 +1,1313 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved. +# +# This code ijins based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
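+#
+# NOTE: this modular definition builds Qwen2 on top of the PaddleNLP Llama
+# implementation: most components (Qwen2DecoderLayer, Qwen2MLP, Qwen2RMSNorm,
+# Qwen2ForCausalLM, ...) subclass their Llama counterparts, while Qwen2Attention
+# and scaled_dot_product_attention are written out in full.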
+"""Paddle Qwen2 model.""" +from __future__ import annotations + +import math +import warnings +from functools import partial +from typing import Dict, List, Optional, Tuple, Union + +import paddle +import paddle.distributed as dist +import paddle.distributed.fleet.meta_parallel as mpu +import paddle.nn.functional as F +from paddle import Tensor, nn +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.distributed.fleet.recompute.recompute import recompute + +from ...utils.tools import get_env_device +from .. import linear_utils +from ..activations import ACT2FN +from ..contrastive_loss import SimpleContrastiveLoss +from ..conversion_utils import StateDictNameMapping, init_name_mappings +from ..embedding_utils import dist_gather_tensor_with_gradient +from ..linear_utils import Linear +from ..model_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ..model_utils import PretrainedModel, register_base_model +from ..refined_recompute import ( + RRColumnParallelLinear, + RRColumnSequenceParallelLinear, + RRRowParallelLinear, + RRRowSequenceParallelLinear, + get_skip_recompute_ops, +) +from ..refined_recompute import recompute as rr_recompute +from ..utils import caculate_llm_per_token_flops, logger +from .configuration import Qwen2Config +from ..llama.modeling import ( + LlamaPretrainedModel, + LlamaModel, + LlamaForCausalLM, + LlamaMLP, + LlamaDecoderLayer, + LlamaRMSNorm, + LlamaRotaryEmbedding, + LlamaPretrainingCriterion, + LlamaLMHead, + fusion_ops, +) +from ..llama.modeling import apply_rotary_pos_emb, repeat_kv + + +try: + from paddle.incubate.nn.functional import fused_rotary_position_embedding +except ImportError: + fused_rotary_position_embedding = None + +try: + from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + GatherOp, + ScatterOp, + mark_as_sequence_parallel_parameter, + ) +except: + pass + +try: + from paddle.nn.functional.flash_attention import flash_attention +except: + flash_attention = None +__all__ = [ + "Qwen2Model", + "Qwen2PretrainedModel", + "Qwen2ForCausalLM", + "Qwen2PretrainingCriterion", + "Qwen2ForSequenceClassification", + "Qwen2ForTokenClassification", + "Qwen2SentenceEmbedding", +] + +def get_triangle_upper_mask(x, mask=None): + if mask is not None: + return mask + # [bsz, n_head, q_len, kv_seq_len] + shape = x.shape + # [bsz, 1, q_len, kv_seq_len] + shape[1] = 1 + mask = paddle.full(shape, paddle.finfo(x.dtype).min, dtype=x.dtype) + mask = paddle.triu(mask, diagonal=1) + mask.stop_gradient = True + return mask + +def scaled_dot_product_attention( + query_states, + config, + key_states, + value_states, + attention_mask, + output_attentions, + attn_mask_startend_row_indices=None, + training=True, + sequence_parallel=False, + skip_recompute=False, +): + bsz, q_len, num_heads, head_dim = query_states.shape + _, kv_seq_len, _, _ = value_states.shape + + if config.use_flash_attention and flash_attention: + # Paddle Flash Attention input [ bz, seqlen, nhead, head_dim] + # Torch Flash Attention input [ bz, nhead, seqlen, head_dim] + + return fusion_ops.fusion_flash_attention( + query_states, + config, + key_states, + value_states, + attention_mask, + output_attentions, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + sequence_parallel=sequence_parallel, + skip_recompute=skip_recompute, + ) + else: + # [ bz, seqlen, nhead, head_dim] -> [bs, nhead, seq_len, head_dim] + 
query_states = paddle.transpose(query_states, [0, 2, 1, 3]) + # merge with the next transpose + key_states = paddle.transpose(key_states, [0, 2, 1, 3]) + value_states = paddle.transpose(value_states, [0, 2, 1, 3]) + + # Add pre divided factor to fix nan under float16. + if paddle.in_dynamic_mode() and query_states.dtype == paddle.float16: + pre_divided_factor = 32 + else: + pre_divided_factor = 1 + + attn_weights = paddle.matmul( + query_states / (math.sqrt(head_dim) * pre_divided_factor), key_states.transpose([0, 1, 3, 2]) + ) + + if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]: + raise ValueError( + f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.shape}" + ) + + if attention_mask is None: + attention_mask = get_triangle_upper_mask(attn_weights) + + attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len]) + if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]: + raise ValueError( + f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}" + ) + + attn_weights = attn_weights + attention_mask + + if not paddle.in_dynamic_mode(): + attn_weights = F.softmax(attn_weights * pre_divided_factor, axis=-1, dtype="float32").astype( + query_states.dtype + ) + else: + with paddle.amp.auto_cast(False): + attn_weights = F.softmax( + attn_weights.astype("float32") * pre_divided_factor, axis=-1, dtype="float32" + ).astype(query_states.dtype) + + attn_weights = F.dropout(attn_weights, p=config.attention_dropout, training=training) + + attn_output = paddle.matmul(attn_weights, value_states) + attn_output = attn_output.transpose([0, 2, 1, 3]) + + if sequence_parallel: + attn_output = attn_output.reshape([bsz * q_len, head_dim * num_heads]) + else: + attn_output = attn_output.reshape([bsz, q_len, head_dim * num_heads]) + return (attn_output, attn_weights) if output_attentions else attn_output + + +class Qwen2RMSNorm(LlamaRMSNorm): + """Qwen2的RMSNorm,继承自LlamaRMSNorm""" + def __init__(self, config: Qwen2Config): + super().__init__(config) +class Qwen2RotaryEmbedding(LlamaRotaryEmbedding): + def __init__(self, dim, max_position_embeddings=2048, base=10000): + super().__init__(dim, max_position_embeddings, base) + def _set_cos_sin_cache(self, seq_len): + self.max_seq_len_cached = seq_len + if self.inv_freq.dtype != paddle.float32: + self.inv_freq = 1.0 / ( + self.base ** (paddle.cast(paddle.arange(0, self.dim, 2), dtype="float32") / self.dim) + ) + super()._set_cos_sin_cache(seq_len) + def forward(self, x, seq_len=None): + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len) + super().forward(x, seq_len) + +class Qwen2MLP(LlamaMLP): + """Qwen2的MLP,继承自LlamaMLP""" + def __init__(self, config: Qwen2Config,is_shared=False, skip_recompute_ops=None): + super().__init__(config) + if config.hidden_act == "silu": + self.act_fn = fusion_ops.swiglu + self.fuse_swiglu = True + else: + self.act_fn = ACT2FN[config.hidden_act] + self.fuse_swiglu = False + # Qwen2的MLP结构与Llama相同,但使用不同的配置 + def forward(self, x): + if self.fuse_attention_ffn: + x = self.gate_up_fused_proj(x) + if self.fuse_swiglu: + y = None + else: + x, y = x.chunk(2, axis=-1) + else: + x, y = self.gate_proj(x), self.up_proj(x) + + if self.fuse_swiglu: + x = self.act_fn(x, y) + else: + x = self.act_fn(x) * y + + return self.down_proj(x) + +class Qwen2Attention(nn.Layer): + """ + Multi-headed attention from 'Attention Is All You Need' paper. 
Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: Qwen2Config, layerwise_recompute: bool = True, skip_recompute_ops=None): + super().__init__() + if skip_recompute_ops is None: + skip_recompute_ops = {} + self.config = config + self.skip_recompute_ops = skip_recompute_ops + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.num_attention_heads = config.num_attention_heads + + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + + self.num_key_value_heads = config.num_key_value_heads + assert config.num_attention_heads // config.num_key_value_heads + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.gqa_or_mqa = config.num_attention_heads != config.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + # self.seq_length = config.seq_length + self.sequence_parallel = config.sequence_parallel + self.has_bias = config.attention_bias + self.fuse_attention_qkv = config.fuse_attention_qkv + + # Note that we will actually perform a recompute only if both enable_recompute and layerwise_recompute are set to True + # Enable_recompute defaults to False and is controlled by Trainer + self.enable_recompute = False + self.layerwise_recompute = layerwise_recompute + self.recompute_granularity = config.recompute_granularity + if config.tensor_parallel_degree > 1: + assert ( + self.num_heads % config.tensor_parallel_degree == 0 + ), f"num_heads: {self.num_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_heads = self.num_heads // config.tensor_parallel_degree + + assert ( + self.num_key_value_heads % config.tensor_parallel_degree == 0 + ), f"num_key_value_heads: {self.num_key_value_heads}, tensor_parallel_degree: {config.tensor_parallel_degree}" + self.num_key_value_heads = self.num_key_value_heads // config.tensor_parallel_degree + + self.use_fused_rope = config.use_fused_rope + if self.use_fused_rope: + if get_env_device() not in ["gpu", "xpu"] or fused_rotary_position_embedding is None: + warnings.warn( + "Enable fuse rope in the config, but fuse rope is not available. " + "Will disable fuse rope. Try using latest gpu version of Paddle." 
+ ) + self.use_fused_rope = False + + if config.sequence_parallel: + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear + + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if skip_recompute_ops.get("attention_column_ln", False): + ColumnParallelLinear = RRColumnSequenceParallelLinear + if skip_recompute_ops.get("attention_row_ln", False): + RowParallelLinear = RRRowSequenceParallelLinear + else: + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear + + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant: + if skip_recompute_ops.get("attention_column_ln", False): + ColumnParallelLinear = RRColumnParallelLinear + if skip_recompute_ops.get("attention_row_ln", False): + RowParallelLinear = RRRowParallelLinear + + if config.tensor_parallel_degree > 1: + if self.fuse_attention_qkv: + self.qkv_proj = ColumnParallelLinear( + self.hidden_size, + self.num_attention_heads * self.head_dim + 2 * self.config.num_key_value_heads * self.head_dim, + has_bias=self.has_bias, + gather_output=False, + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_attention_heads * self.head_dim, + has_bias=self.has_bias, + gather_output=False, + ) + self.k_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=self.has_bias, gather_output=False) # fmt:skip + self.v_proj = ColumnParallelLinear(self.hidden_size, self.config.num_key_value_heads * self.head_dim, has_bias=self.has_bias, gather_output=False) # fmt:skip + self.o_proj = RowParallelLinear(self.hidden_size, self.hidden_size, has_bias=False, input_is_parallel=True) + else: + if self.fuse_attention_qkv: + self.qkv_proj = Linear( + self.hidden_size, + self.num_attention_heads * self.head_dim + 2 * self.config.num_key_value_heads * self.head_dim, + ) + else: + self.q_proj = Linear( + self.hidden_size, self.num_attention_heads * self.head_dim, bias_attr=self.has_bias + ) + self.k_proj = Linear( + self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=self.has_bias + ) + self.v_proj = Linear( + self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=self.has_bias + ) + self.o_proj = Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias_attr=False) + + self.rotary_emb = Qwen2RotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + + self.attn_func = scaled_dot_product_attention + + # NOTE: refined_recompute is only supported when `recompute_use_reentrant=False` + if config.recompute and not config.recompute_use_reentrant and skip_recompute_ops.get("flash_attn", False): + self.attn_func = partial(scaled_dot_product_attention, skip_recompute=True) + + def forward( + self, + hidden_states, + position_ids: Optional[Tuple[paddle.Tensor]] = None, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: bool = False, + use_cache: bool = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + batch_size: Optional[int] = None, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[paddle.Tensor], Optional[Tuple[paddle.Tensor]]]: + """Input shape: Batch x Time x Channel""" + # [bs, 
seq_len, num_head * head_dim] -> [seq_len / n, bs, num_head * head_dim] (n is model parallelism) + + if self.fuse_attention_qkv: + mix_layer = self.qkv_proj(hidden_states) + if self.sequence_parallel: + target_shape = [ + batch_size, + -1, + self.num_key_value_heads, + (self.num_key_value_groups + 2) * self.head_dim, + ] + else: + target_shape = [0, 0, self.num_key_value_heads, (self.num_key_value_groups + 2) * self.head_dim] + mix_layer = paddle.reshape_(mix_layer, target_shape) + query_states, key_states, value_states = paddle.split( + mix_layer, + num_or_sections=[self.num_key_value_groups * self.head_dim, self.head_dim, self.head_dim], + axis=-1, + ) + if self.gqa_or_mqa: + query_states = paddle.reshape_(query_states, [0, 0, self.num_heads, self.head_dim]) + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.sequence_parallel: + target_query_shape = [batch_size, -1, self.num_heads, self.head_dim] + target_key_value_shape = [batch_size, -1, self.num_key_value_heads, self.head_dim] + else: + target_query_shape = [0, 0, self.num_heads, self.head_dim] + target_key_value_shape = [0, 0, self.num_key_value_heads, self.head_dim] + query_states = query_states.reshape(shape=target_query_shape) + key_states = key_states.reshape(shape=target_key_value_shape) + value_states = value_states.reshape(shape=target_key_value_shape) + + if position_ids is not None and not self.use_fused_rope: + kv_seq_len = position_ids.max().item() + 1 + else: + kv_seq_len = key_states.shape[-3] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-3] + if self.use_fused_rope: + assert past_key_value is None, "fuse rotary not support cache kv for now" + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states, _ = fused_rotary_position_embedding( + query_states, + key_states, + v=None, + sin=sin, + cos=cos, + position_ids=position_ids, + use_neox_rotary_style=False, + ) + else: + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # [bs, seq_len, num_head, head_dim] + if past_key_value is not None: + key_states = paddle.concat([past_key_value[0], key_states], axis=1) + value_states = paddle.concat([past_key_value[1], value_states], axis=1) + past_key_value = (key_states, value_states) if use_cache else None + + # TODO(wj-Mcat): use broadcast strategy when n_kv_heads = 1 + # repeat k/v heads if n_kv_heads < n_heads + paddle_version = float(paddle.__version__[:3]) + if not self.config.use_flash_attention or ((paddle_version != 0.0) and (paddle_version <= 2.6)): + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + has_gradient = not (query_states.stop_gradient and key_states.stop_gradient and value_states.stop_gradient) + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "core_attn" + ): + recompute_fn = rr_recompute if any(self.skip_recompute_ops.values()) else recompute + outputs = recompute_fn( + self.attn_func, + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + training=self.training, + sequence_parallel=self.sequence_parallel, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = 
self.attn_func( + query_states, + self.config, + key_states, + value_states, + attention_mask, + output_attentions, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + training=self.training, + sequence_parallel=self.sequence_parallel, + ) + if output_attentions: + attn_output, attn_weights = outputs + else: + attn_output = outputs + + # if sequence_parallel is true, out shape are [q_len / n, bs, num_head * head_dim] + # else their shape are [bs, q_len, num_head * head_dim], n is mp parallelism. + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + outputs = (attn_output,) + + if output_attentions: + outputs += (attn_weights,) + + if use_cache: + outputs += (past_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class Qwen2DecoderLayer(LlamaDecoderLayer): + """Qwen2的解码器层,继承自LlamaDecoderLayer""" + def __init__(self, config: Qwen2Config, layerwise_recompute: bool = False, skip_recompute_ops=None): + super().__init__(config) + def forward( + self, + hidden_states: paddle.Tensor, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + output_attentions: Optional[bool] = False, + past_key_value: Optional[Tuple[paddle.Tensor]] = None, + use_cache: Optional[bool] = False, + attn_mask_startend_row_indices: Optional[paddle.Tensor] = None, + batch_size: Optional[int] = None, + **kwargs, + ) -> Tuple[paddle.Tensor, Optional[Tuple[paddle.Tensor, paddle.Tensor]]]: + """ + Args: + hidden_states (`paddle.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`paddle.Tensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(paddle.Tensor)`, *optional*): cached past key and value projection states + """ + + # [bs * seq_len, embed_dim] -> [seq_len * bs / n, embed_dim] (sequence_parallel) + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + has_gradient = not hidden_states.stop_gradient + if ( + self.enable_recompute + and self.layerwise_recompute + and has_gradient + and self.recompute_granularity == "full_attn" + ): + recompute_fn = rr_recompute if any(self.skip_recompute_ops.values()) else recompute + outputs = recompute_fn( + self.self_attn, + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + attn_mask_startend_row_indices, + batch_size, + use_reentrant=self.config.recompute_use_reentrant, + ) + else: + outputs = self.self_attn( + hidden_states, + position_ids, + past_key_value, + attention_mask, + output_attentions, + use_cache, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + batch_size=batch_size, + ) + + if type(outputs) is tuple: + hidden_states = outputs[0] + else: + hidden_states = outputs + + if output_attentions: + self_attn_weights = outputs[1] + + if use_cache: + present_key_value = outputs[2 if output_attentions else 1] + + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + if type(outputs) is tuple and len(outputs) == 1: + outputs = outputs[0] + + return outputs + + +class Qwen2PretrainedModel(LlamaPretrainedModel): + """Qwen2预训练模型基类,继承自LlamaPretrainedModel""" + config_class = Qwen2Config + base_model_prefix = "qwen2" + _keys_to_ignore_on_load_unexpected = [r"self_attn.rotary_emb.inv_freq"] + @classmethod + def _get_name_mappings(cls, config: Qwen2Config) -> list[StateDictNameMapping]: + mappings: list[StateDictNameMapping] = [] + model_mappings = [ + ["embed_tokens.weight"], + ["norm.weight"], + ] + for layer_index in range(config.num_hidden_layers): + layer_mappings = [ + [f"layers.{layer_index}.self_attn.q_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.k_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.v_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.q_proj.bias", None], + [f"layers.{layer_index}.self_attn.k_proj.bias", None], + [f"layers.{layer_index}.self_attn.v_proj.bias", None], + [f"layers.{layer_index}.self_attn.o_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.up_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.gate_proj.weight", None, "transpose"], + [f"layers.{layer_index}.mlp.down_proj.weight", None, "transpose"], + [f"layers.{layer_index}.self_attn.rotary_emb.inv_freq"], + [f"layers.{layer_index}.input_layernorm.weight"], + [f"layers.{layer_index}.post_attention_layernorm.weight"], + ] + model_mappings.extend(layer_mappings) + + init_name_mappings(mappings=model_mappings) + # base-model prefix "Qwen2MoEModel" + if "Qwen2Model" not in config.architectures: + for mapping in model_mappings: + mapping[0] = "model." + mapping[0] + mapping[1] = "qwen2." 
+ mapping[1] + if not config.tie_word_embeddings: + model_mappings.append(["lm_head.weight", "lm_head.weight", "transpose"]) + + mappings = [StateDictNameMapping(*mapping, index=index) for index, mapping in enumerate(model_mappings)] + return mappings + + @classmethod + def _get_tensor_parallel_mappings(cls, config: Qwen2Config, is_split=True): + + from ..conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + + base_actions = { + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + } + + if config.tie_word_embeddings: + base_actions["lm_head.weight"] = partial(fn, is_column=False) + else: + base_actions["lm_head.weight"] = partial(fn, is_column=True) + + if not config.vocab_size % config.tensor_parallel_degree == 0: + base_actions.pop("lm_head.weight") + base_actions.pop("embed_tokens.weight") + # Column Linear + if config.fuse_attention_qkv: + base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.qkv_proj.bias"] = partial(fn, is_column=True) + else: + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True) + # if we have enough num_key_value_heads to split, then split it. + if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True) + + if config.fuse_attention_ffn: + base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial( + fn, is_column=True, is_naive_2fuse=True + ) + else: + base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings + @classmethod + def _get_fuse_or_split_param_mappings(cls, config: Qwen2Config, is_fuse=False): + # return parameter fuse utils + from ..conversion_utils import split_or_fuse_func + + fn = split_or_fuse_func(is_fuse=is_fuse) + + # last key is fused key, other keys are to be fused. 
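+        # Each tuple lists the separate per-layer parameters first and the fused
+        # parameter name last; the "layers.0." prefix is a template that is rewritten
+        # to "layers.{i}." for every decoder layer in the loops below.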
+ fuse_qkv_keys = [ + ( + "layers.0.self_attn.q_proj.weight", + "layers.0.self_attn.k_proj.weight", + "layers.0.self_attn.v_proj.weight", + "layers.0.self_attn.qkv_proj.weight", + ), + ( + "layers.0.self_attn.q_proj.bias", + "layers.0.self_attn.k_proj.bias", + "layers.0.self_attn.v_proj.bias", + "layers.0.self_attn.qkv_proj.bias", + ), + ] + + fuse_gate_up_keys = ( + "layers.0.mlp.gate_proj.weight", + "layers.0.mlp.up_proj.weight", + "layers.0.mlp.gate_up_fused_proj.weight", + ) + num_heads = config.num_attention_heads + num_key_value_heads = getattr(config, "num_key_value_heads", num_heads) + fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False) + fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False) + + final_actions = {} + if is_fuse: + if fuse_attention_qkv: + for i in range(config.num_hidden_layers): + for fuse_keys in fuse_qkv_keys: + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_keys]) + final_actions[keys] = partial( + fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + if fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]) + final_actions[keys] = fn + else: + if not fuse_attention_qkv: + for i in range(config.num_hidden_layers): + for fuse_keys in fuse_qkv_keys: + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_keys]) + final_actions[keys] = partial( + fn, split_nums=3, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + if not fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]) + final_actions[keys] = partial(fn, split_nums=2) + return final_actions + + def _get_model_flops(self): + if hasattr(self.config, "seq_length"): + seq_length = self.config.seq_length + else: + seq_length = 2048 + + return caculate_llm_per_token_flops( + hidden_size=self.config.hidden_size, + intermediate_size=self.config.intermediate_size, + layer_num=self.config.num_hidden_layers, + vocab_size=self.config.vocab_size, + seq_length=seq_length, + recompute=False, + ) + + def _get_hardware_flops(self): + if hasattr(self.config, "seq_length"): + seq_length = self.config.seq_length + else: + seq_length = 2048 + + return caculate_llm_per_token_flops( + hidden_size=self.config.hidden_size, + intermediate_size=self.config.intermediate_size, + layer_num=self.config.num_hidden_layers, + vocab_size=self.config.vocab_size, + seq_length=seq_length, + recompute=self.config.recompute, + recompute_granularity=self.config.recompute_granularity, + ) + + +@register_base_model +class Qwen2Model(LlamaModel): + """Qwen2模型,继承自LlamaModel""" + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + @paddle.jit.not_to_static + def recompute_training_full( + self, + layer_module: nn.Layer, + hidden_states: Tensor, + position_ids: Optional[Tensor], + attention_mask: Tensor, + output_attentions: bool, + past_key_value: Tensor, + use_cache: bool, + attn_mask_startend_row_indices=None, + batch_size: int = None, + ): + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + recompute_fn = rr_recompute if any(layer_module.skip_recompute_ops.values()) else recompute + hidden_states = recompute_fn( + create_custom_forward(layer_module), + hidden_states, + position_ids, + attention_mask, + 
output_attentions, + past_key_value, + use_cache, + attn_mask_startend_row_indices, + batch_size, + use_reentrant=self.config.recompute_use_reentrant, + ) + + return hidden_states + def forward( + self, + input_ids: paddle.Tensor = None, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + attn_mask_startend_row_indices=None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + super().forward() +class Qwen2PretrainingCriterion(LlamaPretrainingCriterion): + """Qwen2的预训练损失计算,继承自LlamaPretrainingCriterion""" + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.enable_parallel_cross_entropy = config.tensor_parallel_degree > 1 and config.tensor_parallel_output + +class Qwen2LMHead(LlamaLMHead): + """Qwen2的语言模型头,继承自LlamaLMHead""" + def __init__(self, config: Qwen2Config): + super().__init__(config) + def forward(self, hidden_states, tensor_parallel_output=None, batch_size=None): + if self.config.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + hidden_states = paddle.reshape_(hidden_states, [batch_size, -1, self.config.hidden_size]) + + if tensor_parallel_output is None: + tensor_parallel_output = self.config.tensor_parallel_output + + logits = parallel_matmul( + hidden_states, self.weight, transpose_y=self.transpose_y, tensor_parallel_output=tensor_parallel_output + ) + return logits +class Qwen2ForCausalLM(LlamaForCausalLM): + """用于因果语言建模的Qwen2模型,继承自LlamaForCausalLM""" + def __init__(self, config: Qwen2Config): + super().__init__(config) + @staticmethod + def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False): + # update cache + if isinstance(outputs, tuple) and len(outputs) > 1 and not isinstance(outputs[1], paddle.Tensor): + model_kwargs["past_key_values"] = outputs[1] + + if isinstance(outputs, CausalLMOutputWithPast) and "past_key_values" in outputs: + model_kwargs["past_key_values"] = outputs.past_key_values + + # update position_ids + if "position_ids" in model_kwargs and model_kwargs["position_ids"] is not None: + position_ids = model_kwargs["position_ids"] + model_kwargs["position_ids"] = paddle.concat([position_ids, position_ids[..., -1:] + 1], axis=-1) + + if not is_encoder_decoder and "attention_mask" in model_kwargs: + # TODO: support attention mask for other models + attention_mask = model_kwargs["attention_mask"] + if len(attention_mask.shape) == 2: + model_kwargs["attention_mask"] = paddle.concat( + [attention_mask, paddle.ones([attention_mask.shape[0], 1], dtype=attention_mask.dtype)], + axis=-1, + ) + elif len(attention_mask.shape) == 4: + model_kwargs["attention_mask"] = paddle.concat( + [attention_mask, paddle.ones([*attention_mask.shape[:3], 1], dtype=attention_mask.dtype)], + axis=-1, + )[:, :, -1:, :] + + return model_kwargs + + def forward( + self, + input_ids: paddle.Tensor = None, + position_ids: Optional[paddle.Tensor] = None, + attention_mask: Optional[paddle.Tensor] = None, + inputs_embeds: Optional[paddle.Tensor] = None, + labels: Optional[paddle.Tensor] = None, + use_cache: Optional[bool] = None, + past_key_values: Optional[List[paddle.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + 
+
+    def forward(
+        self,
+        input_ids: paddle.Tensor = None,
+        position_ids: Optional[paddle.Tensor] = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        inputs_embeds: Optional[paddle.Tensor] = None,
+        labels: Optional[paddle.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        past_key_values: Optional[List[paddle.Tensor]] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        attn_mask_startend_row_indices=None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from paddlenlp.transformers import AutoTokenizer, Qwen2ForCausalLM
+
+        >>> model = Qwen2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pd")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs["input_ids"], max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if attn_mask_startend_row_indices is not None and attention_mask is not None:
+            logger.warning(
+                "You have provided both attn_mask_startend_row_indices and attention_mask. "
+                "The attn_mask_startend_row_indices will be used."
+            )
+            attention_mask = None
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size = input_ids.shape[0]
+        elif inputs_embeds is not None:
+            batch_size = inputs_embeds.shape[0]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.qwen2(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            past_key_values=past_key_values,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            attn_mask_startend_row_indices=attn_mask_startend_row_indices,
+        )
+
+        hidden_states = outputs[0]
+
+        # add this for fused_head_and_loss_fn
+        if self.config.use_fused_head_and_loss_fn and self.training:
+            if self.config.tensor_parallel_degree > 1 and self.config.sequence_parallel:
+                hidden_states = GatherOp.apply(hidden_states)
+                hidden_states = hidden_states.reshape(
+                    [
+                        batch_size,
+                        -1,
+                        hidden_states.shape[-1],
+                    ]
+                )
+            return hidden_states, self.lm_head.weight, None, self.lm_head.transpose_y
+
+        # if labels is None, we need the full output instead of tensor_parallel_output;
+        # tensor_parallel_output goes together with ParallelCrossEntropy
+        tensor_parallel_output = self.config.tensor_parallel_output and self.config.tensor_parallel_degree > 1
+
+        if labels is not None and self.config.use_fused_linear_cross_entropy:
+            from paddlenlp_kernel.triton.cut_cross_entropy import linear_cross_entropy
+
+            assert (
+                self.config.tensor_parallel_degree <= 1
+            ), "The argument `use_fused_linear_cross_entropy` is incompatible with tensor parallel."
+            masked_lm_loss = linear_cross_entropy(hidden_states, self.lm_head.weight, targets=labels)
+
+            # Average the loss over the tokens that actually contribute (positions with non-zero loss).
+            binary_sequence = paddle.where(
+                masked_lm_loss > 0, paddle.ones_like(masked_lm_loss), paddle.zeros_like(masked_lm_loss)
+            )
+            count = paddle.sum(binary_sequence)
+            if count == 0:
+                loss = paddle.sum(masked_lm_loss * binary_sequence)
+            else:
+                loss = paddle.sum(masked_lm_loss * binary_sequence) / count
+            logits = None
+        else:
+            logits = self.lm_head(hidden_states, tensor_parallel_output=tensor_parallel_output, batch_size=batch_size)
+
+            loss = None
+            if labels is not None:
+                loss = self.criterion(logits, labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
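+# Note (comment only): the classification heads below wrap Qwen2Model directly.
+# Qwen2ForSequenceClassification pools the hidden state of the last non-padding token of each
+# sequence (index -1, i.e. the last position, when no pad_token_id is configured), while
+# Qwen2ForTokenClassification scores every position after dropout.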
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.qwen2( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = paddle.equal(input_ids, self.config.pad_token_id).astype("int32").argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths + else: + sequence_lengths = -1 + + # pooled_logits = logits[paddle.arange(batch_size), sequence_lengths] + pooled_logits = logits.gather_nd(paddle.stack([paddle.arange(logits.shape[0]), sequence_lengths], axis=-1)) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == paddle.int64 or labels.dtype == paddle.int32): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = nn.MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct(pooled_logits.reshape([-1, self.num_labels]), labels.reshape([-1])) + elif self.config.problem_type == "multi_label_classification": + loss_fct = nn.BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2 +class Qwen2ForTokenClassification(Qwen2PretrainedModel): + def __init__(self, config: Qwen2Config): + super().__init__(config) + self.num_labels = config.num_labels + self.qwen2 = Qwen2Model(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = Linear(config.hidden_size, config.num_labels) + + def get_input_embeddings(self): + return self.qwen2.embed_tokens + + def set_input_embeddings(self, value): + self.qwen2.embed_tokens = value + + def forward( + self, + input_ids: paddle.Tensor = None, + attention_mask: 
+# Copied from transformers.models.llama.modeling_llama.LlamaForTokenClassification with Llama->Qwen2, LLAMA->QWEN2
+class Qwen2ForTokenClassification(Qwen2PretrainedModel):
+    def __init__(self, config: Qwen2Config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.qwen2 = Qwen2Model(config)
+        if getattr(config, "classifier_dropout", None) is not None:
+            classifier_dropout = config.classifier_dropout
+        elif getattr(config, "hidden_dropout", None) is not None:
+            classifier_dropout = config.hidden_dropout
+        else:
+            classifier_dropout = 0.1
+        self.dropout = nn.Dropout(classifier_dropout)
+        self.score = Linear(config.hidden_size, config.num_labels)
+
+    def get_input_embeddings(self):
+        return self.qwen2.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.qwen2.embed_tokens = value
+
+    def forward(
+        self,
+        input_ids: paddle.Tensor = None,
+        attention_mask: Optional[paddle.Tensor] = None,
+        position_ids: Optional[paddle.Tensor] = None,
+        past_key_values: Optional[List[paddle.Tensor]] = None,
+        inputs_embeds: Optional[paddle.Tensor] = None,
+        labels: Optional[paddle.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, TokenClassifierOutput]:
+        r"""
+        labels (`paddle.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.qwen2(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        sequence_output = self.dropout(sequence_output)
+        logits = self.score(sequence_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+            loss = loss_fct(logits.reshape([-1, self.num_labels]), labels.reshape([-1]))
+
+        if not return_dict:
+            output = (logits,) + outputs[2:]
+            return ((loss,) + output) if loss is not None else output
+
+        return TokenClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
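+# Note (comment only): Qwen2SentenceEmbedding below is a bi-encoder head trained with an
+# in-batch-negative contrastive loss; when `embedding_negatives_cross_device` is enabled and
+# more than one rank is available, query/passage representations are gathered across ranks
+# before the loss so that every rank sees a larger pool of negatives.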
+ """ + super(Qwen2SentenceEmbedding, self).__init__(config) + self.config = config + self.qwen2 = Qwen2Model(config) + self.in_batch_negative_loss = SimpleContrastiveLoss(embedding_temperature) + self.world_size = dist.get_world_size() + self.process_rank = dist.get_rank() + self.embedding_negatives_cross_device = config.embedding_negatives_cross_device + if self.world_size <= 1: + self.embedding_negatives_cross_device = False + + def forward( + self, + query: Optional[Dict[str, paddle.Tensor]] = None, + passages: Optional[Dict[str, paddle.Tensor]] = None, + return_encode=False, + ): + """forward""" + q_reps = self.encode(**query) + p_reps = self.encode(**passages) + + q_reps = nn.functional.normalize(q_reps, axis=-1) + p_reps = nn.functional.normalize(p_reps, axis=-1) + + if return_encode: + return q_reps, p_reps + + if self.embedding_negatives_cross_device: + q_reps = dist_gather_tensor_with_gradient(q_reps) + p_reps = dist_gather_tensor_with_gradient(p_reps) + + loss = self.in_batch_negative_loss(q_reps, p_reps) + return loss + + def encode( + self, + input_ids, + position_ids=None, + embedding_indices=None, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=False, + **kwargs, + ): + """encode""" + input_type = type(input_ids) + outputs = self.qwen2( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + if isinstance(outputs, input_type): + hidden_states = outputs + else: + hidden_states = outputs[0] + last_hidden_states = hidden_states.gather_nd(embedding_indices) + return last_hidden_states +