
issue/401 refactor(rope): add scaling factory and TP-safe RoPE cache #402

Open
rubik-hua wants to merge 1 commit into InfiniTensor:main from rubik-hua:refactor_rope

@rubik-hua rubik-hua commented May 28, 2026

  • Decouple scaling config instantiation from ModelConfig via factory and registry pattern.
  • Add thread-local RoPE cache with device-scoped keys to reduce VRAM usage and ensure TP safety.
  • Centralize rotary dimension calculation into ModelConfig.
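The cache implementation itself is not shown in this description. As a minimal sketch of what a thread-local cache with device-scoped keys could look like, assuming one worker thread per tensor-parallel rank: `RopeTable`, `RopeCacheKey`, and `get_or_build_rope_table` are hypothetical names used only for illustration, not the PR's actual types.

```cpp
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <tuple>
#include <vector>

// Hypothetical stand-in for a device-resident sin/cos table.
struct RopeTable {
    std::vector<float> sin_cos;
};

// Hypothetical cache key. The device index is part of the key, so two
// devices never share a table even if all other parameters match.
struct RopeCacheKey {
    int device_id;
    std::size_t rotary_dim;
    std::size_t max_seq_len;
    double theta;

    bool operator<(const RopeCacheKey& o) const {
        return std::tie(device_id, rotary_dim, max_seq_len, theta)
             < std::tie(o.device_id, o.rotary_dim, o.max_seq_len, o.theta);
    }
};

// Returns the cached table for `key`, calling `build` only on a miss.
// The map is thread_local, so each thread owns its own cache and no
// locking is needed.
std::shared_ptr<RopeTable> get_or_build_rope_table(
    const RopeCacheKey& key,
    const std::function<std::shared_ptr<RopeTable>()>& build) {
    thread_local std::map<RopeCacheKey, std::shared_ptr<RopeTable>> cache;
    auto it = cache.find(key);
    if (it != cache.end()) {
        return it->second;  // hit: layers with the same key share one table
    }
    auto table = build();
    cache.emplace(key, table);
    return table;
}
```

In this sketch, layers that request the same key on the same thread share one table, which is where a memory saving would come from, while a different `device_id` always produces a separate entry.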

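ModelConfig extension (pure data carrier)
ModelConfig no longer contains any model-specific business logic; it only provides default values and accessors. Differences in RoPE::Algo are set explicitly by each model's construction entry point (for example csrc/models/chatglm/chatglm_for_causal_lm.cpp):
// model_config.hpp
class ModelConfig {
private:
    infinicore::nn::RoPE::Algo rope_algo_ = infinicore::nn::RoPE::Algo::GPT_NEOX; // default
public:
    infinicore::nn::RoPE::Algo get_rope_algo() const { return rope_algo_; }
    // Setter called by the model-specific builders.
    void set_rope_algo(infinicore::nn::RoPE::Algo algo) { rope_algo_ = algo; }
};
// csrc/models/chatglm/chatglm_for_causal_lm.cpp
std::shared_ptr<ModelConfig> create_chatglm_config(const json& hf_config) {
    auto config = std::make_shared<ModelConfig>(hf_config);
    // Only ChatGLM/GLM4 need GPT_J; inject it here explicitly so the base class stays clean.
    config->set_rope_algo(infinicore::nn::RoPE::Algo::GPT_J);
    return config;
}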
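For readers unfamiliar with the two algorithms named above, the following is a minimal sketch of how they usually differ, assuming infinicore's `Algo` values follow the common convention: GPT-NeoX style rotates element `i` together with element `i + d/2`, while GPT-J style rotates adjacent elements `2i` and `2i + 1`. The `Algo` enum and `apply_rope` here are local stand-ins, not the infinicore API.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Local stand-in that mirrors the enum names used in the PR.
enum class Algo { GPT_NEOX, GPT_J };

// Applies rotary position embedding to one head vector `x` at position `pos`.
// Assumes x.size() is even. Only the choice of which two elements form a
// rotated pair differs between the two algorithms.
std::vector<float> apply_rope(const std::vector<float>& x, std::size_t pos,
                              Algo algo, double theta = 10000.0) {
    const std::size_t d = x.size();
    const std::size_t half = d / 2;
    std::vector<float> out(d);
    for (std::size_t i = 0; i < half; ++i) {
        // Frequency of the i-th pair: theta^(-2i/d).
        const double freq = std::pow(theta, -2.0 * static_cast<double>(i) / static_cast<double>(d));
        const double angle = static_cast<double>(pos) * freq;
        const double c = std::cos(angle);
        const double s = std::sin(angle);
        // GPT_J: adjacent pair (2i, 2i+1). GPT_NEOX: split-half pair (i, i + d/2).
        const std::size_t a = (algo == Algo::GPT_J) ? 2 * i : i;
        const std::size_t b = (algo == Algo::GPT_J) ? 2 * i + 1 : i + half;
        out[a] = static_cast<float>(x[a] * c - x[b] * s);
        out[b] = static_cast<float>(x[a] * s + x[b] * c);
    }
    return out;
}
```

Because the two layouts place the rotated values at different indices, weights trained with one convention produce wrong results under the other, which is why a model family has to select its algorithm explicitly.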
Factory and registry mechanism (string-based dispatch)
A registry maps the strings found in the JSON (such as "longrope") to the corresponding object construction logic, replacing a long if-else chain.
// rotary_embedding_factory.hpp
using ScalingCreator = std::function<std::shared_ptr<infinicore::nn::ScalingConfig>(
    const std::shared_ptr<infinilm::config::ModelConfig>&)>;
std::unordered_map<std::string, ScalingCreator>& get_scaling_registry();
std::shared_ptr<infinicore::nn::RoPE> make_rope(/* ... */);
The factory core is minimal: it only handles assembly and routing, and does not need to change when a new type is added:
// rotary_embedding_factory.cpp
std::shared_ptr<infinicore::nn::ScalingConfig>
make_scaling_config(const std::shared_ptr<infinilm::config::ModelConfig>& model_config) {
    std::string scaling_type = model_config->get_or<std::string>("rope_scaling_type", "default");

    // Dispatch point: the registry routes the string to the matching Creator function.
    auto& registry = get_scaling_registry();
    auto it = registry.find(scaling_type);
    if (it != registry.end()) {
        return it->second(model_config);
    }
    throw std::runtime_error("Unsupported rope_scaling_type: " + scaling_type);
}
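The dispatch above can be exercised end to end with a self-contained sketch. `ScalingConfig`, `ModelConfig`, and the two creators below are simplified stand-ins, assumed only for illustration; the real infinicore and infinilm types carry more state, and the real `get_or` is a template.

```cpp
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Simplified stand-ins for the real config types (assumptions).
struct ScalingConfig {
    virtual ~ScalingConfig() = default;
    virtual std::string name() const = 0;
};
struct DefaultScaling : ScalingConfig {
    std::string name() const override { return "default"; }
};
struct LongRopeScaling : ScalingConfig {
    std::string name() const override { return "longrope"; }
};

struct ModelConfig {
    std::unordered_map<std::string, std::string> values;
    // Returns the stored value for `key`, or `fallback` when it is absent.
    std::string get_or(const std::string& key, const std::string& fallback) const {
        auto it = values.find(key);
        return it == values.end() ? fallback : it->second;
    }
};

using ScalingCreator =
    std::function<std::shared_ptr<ScalingConfig>(const std::shared_ptr<ModelConfig>&)>;

// Function-local static: the map is constructed on first use.
std::unordered_map<std::string, ScalingCreator>& get_scaling_registry() {
    static std::unordered_map<std::string, ScalingCreator> registry = {
        {"default", [](const std::shared_ptr<ModelConfig>&) -> std::shared_ptr<ScalingConfig> {
             return std::make_shared<DefaultScaling>();
         }},
        {"longrope", [](const std::shared_ptr<ModelConfig>&) -> std::shared_ptr<ScalingConfig> {
             return std::make_shared<LongRopeScaling>();
         }},
    };
    return registry;
}

// Looks up the creator for the configured type; unknown types are an error.
std::shared_ptr<ScalingConfig> make_scaling_config(const std::shared_ptr<ModelConfig>& model_config) {
    const std::string scaling_type = model_config->get_or("rope_scaling_type", "default");
    auto& registry = get_scaling_registry();
    auto it = registry.find(scaling_type);
    if (it != registry.end()) {
        return it->second(model_config);
    }
    throw std::runtime_error("Unsupported rope_scaling_type: " + scaling_type);
}
```

With this shape, supporting another scaling type means adding one registry entry; `make_scaling_config` stays untouched.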

This needs to be merged together with the following InfiniCore PR:
InfiniTensor/InfiniCore#1181

After the refactor, new RoPE implementations are all added in csrc/layers/rotary_embedding/rope_scaling_creators.cpp; nothing else needs to change, and the code is decoupled from rotary_embedding.cpp and model_config.cpp.
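How creators are registered inside rope_scaling_creators.cpp is not shown here. One common way to keep additions confined to a single file is a static registrar object, sketched below under that assumption; `ScalingRegistrar`, the `"my_new_scaling"` key, and the simplified `ScalingConfig` and `ModelConfig` are hypothetical.

```cpp
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Simplified stand-ins for the real types (assumptions).
struct ScalingConfig {
    std::string kind;
};
struct ModelConfig {};

using ScalingCreator =
    std::function<std::shared_ptr<ScalingConfig>(const std::shared_ptr<ModelConfig>&)>;

// Function-local static avoids static-initialization-order problems:
// the map exists by the time any registrar touches it.
std::unordered_map<std::string, ScalingCreator>& get_scaling_registry() {
    static std::unordered_map<std::string, ScalingCreator> registry;
    return registry;
}

// Hypothetical helper: constructing one of these registers a creator.
struct ScalingRegistrar {
    ScalingRegistrar(const std::string& type, ScalingCreator creator) {
        get_scaling_registry()[type] = std::move(creator);
    }
};

// Adding a new scaling type is one more static object in this file.
static const ScalingRegistrar kMyNewScalingRegistrar(
    "my_new_scaling",
    [](const std::shared_ptr<ModelConfig>&) {
        return std::make_shared<ScalingConfig>(ScalingConfig{"my_new_scaling"});
    });
```

One caveat with this pattern: when the file is built into a static library, the linker may drop an object file that nothing references, so the registrars never run. An explicit registration function called from the factory avoids that.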

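Previously @pengcheng888 suggested folding the algo parameter into model_config and setting it in xx_for_causal_lm.cpp. I implemented a version of that, but it felt quite awkward; my understanding is that model_config is better kept pure, holding only what can be read or derived from the JSON. I later changed it again, and passing the algo as a runtime argument seems more elegant.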

After the refactor, all currently supported models run successfully.

[Attached images]

@rubik-hua rubik-hua requested a review from a team May 28, 2026 09:36
@rubik-hua
Author

@wooway777 @pengcheng888 The RoPE refactor is ready for review; there is also a corresponding PR on InfiniCore.

@wooway777
Collaborator

> @wooway777 @pengcheng888 The RoPE refactor is ready for review; there is also a corresponding PR on InfiniCore.

Thanks, I'm reviewing it now.
