issue/1180 refactor(nn): decouple RoPE scaling logic with polymorphic interfaces#1181

Open
rubik-hua wants to merge 1 commit into InfiniTensor:main from rubik-hua:refactor_rope

@rubik-hua

  • Extract RopeScalingConfig and LongRopeConfig to dedicated rope_scaling_configs.hpp/.cc files.
  • Introduce get_freq_scale and get_magnitude_scale virtual methods to eliminate type-checking branches in the core initialize_cache loop.
  • Add Llama3Config skeleton for future Llama 3/3.1 support.

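ScalingConfig hierarchy refactor (strategy abstraction and centralized management)
In infinicore/nn/rope.hpp, the frequency scaling and magnitude scaling introduced by the different scaling schemes are abstracted as virtual functions, and the former ScalingType enum is removed: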
namespace infinicore::nn {
class ScalingConfig {
public:
    virtual ~ScalingConfig() = default;

    // Frequency scaling factor (defaults to 1.0, i.e. no scaling)
    virtual float get_freq_scale(size_t j, size_t head_dim, float theta) const { return 1.0f; }

    // Magnitude scaling factor (defaults to 1.0, i.e. no scaling; e.g. LongRoPE's table_factor)
    virtual float get_magnitude_scale(size_t pos) const { return 1.0f; }
};
} // namespace infinicore::nn
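To illustrate how a subclass plugs into this interface, here is a minimal sketch of a linear (position-interpolation) scaling config. It is not the PR's code: the `sketch_linear` namespace, the standalone copy of `ScalingConfig`, and the single `factor` constructor argument are assumptions made so the example compiles on its own.

```cpp
#include <cstddef>

namespace sketch_linear {

// Standalone copy of the interface described above, so this example is self-contained.
class ScalingConfig {
public:
    virtual ~ScalingConfig() = default;
    virtual float get_freq_scale(std::size_t /*j*/, std::size_t /*head_dim*/, float /*theta*/) const { return 1.0f; }
    virtual float get_magnitude_scale(std::size_t /*pos*/) const { return 1.0f; }
};

// Hypothetical linear scaling: every rotary frequency is divided by `factor`,
// which is equivalent to dividing the position index by `factor`.
class LinearConfig : public ScalingConfig {
public:
    explicit LinearConfig(float factor) : factor_(factor) {}

    float get_freq_scale(std::size_t, std::size_t, float) const override {
        return 1.0f / factor_;
    }
    // get_magnitude_scale is inherited: linear scaling leaves sin/cos amplitudes unchanged.

private:
    float factor_;
};

} // namespace sketch_linear
```

Because only `get_freq_scale` is overridden, the base-class default of 1.0 still applies to the magnitude, which is the point of giving both virtual methods neutral defaults.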
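All concrete ScalingConfig subclasses are defined together in infinicore/nn/rope_scaling_configs.hpp. These classes encapsulate only the pure math and the parameters it needs, and their constructors take only raw numeric values: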
// infinicore/nn/rope_scaling_configs.hpp
#pragma once
#include "infinicore/nn/rope.hpp"
#include <vector>
namespace infinicore::nn {
class LongRopeConfig : public ScalingConfig {
public:
    LongRopeConfig(std::vector<float> short_factor, std::vector<float> long_factor,
                   size_t original_max_pos, float factor) /* ... */
    float get_freq_scale(size_t j, size_t head_dim, float theta) const override {
        // Note: `pos` is not among this method's parameters; the short/long choice
        // needs a position or sequence length supplied from elsewhere.
        float ext = (pos < original_max_pos) ? short_factor_[j] : long_factor_[j];
        return 1.0f / ext;
    }
    float get_magnitude_scale(size_t pos) const override { return factor_; }
    // ...
};
class Llama3Config : public ScalingConfig { /* ... */ };
class LinearConfig : public ScalingConfig { /* ... */ };
} // namespace infinicore::nn
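The `(j, head_dim, theta)` signature matters for schemes whose scale depends on each frequency's wavelength. As a rough sketch of what the elided `Llama3Config` could look like, the block below follows the publicly documented Llama 3.1 rule (leave high frequencies unchanged, divide low frequencies by `factor`, interpolate in between). The `sketch_llama3` namespace, the constructor parameters, and the member names are assumptions, not the PR's implementation.

```cpp
#include <cmath>
#include <cstddef>

namespace sketch_llama3 {

// Standalone copy of the interface, so this example is self-contained.
class ScalingConfig {
public:
    virtual ~ScalingConfig() = default;
    virtual float get_freq_scale(std::size_t, std::size_t, float) const { return 1.0f; }
    virtual float get_magnitude_scale(std::size_t) const { return 1.0f; }
};

// Hypothetical Llama 3.1 style frequency scaling.
class Llama3Config : public ScalingConfig {
public:
    Llama3Config(float factor, float low_freq_factor, float high_freq_factor,
                 std::size_t original_max_pos)
        : factor_(factor), low_freq_factor_(low_freq_factor),
          high_freq_factor_(high_freq_factor),
          original_max_pos_(static_cast<float>(original_max_pos)) {}

    float get_freq_scale(std::size_t j, std::size_t head_dim, float theta) const override {
        const float pi = 3.14159265358979323846f;
        // Recompute the unscaled inverse frequency for index j to obtain its wavelength.
        float inv_freq = 1.0f / std::pow(theta, 2.0f * static_cast<float>(j) / static_cast<float>(head_dim));
        float wavelen = 2.0f * pi / inv_freq;
        float low_freq_wavelen = original_max_pos_ / low_freq_factor_;
        float high_freq_wavelen = original_max_pos_ / high_freq_factor_;

        if (wavelen < high_freq_wavelen) return 1.0f;          // high frequency: unchanged
        if (wavelen > low_freq_wavelen) return 1.0f / factor_; // low frequency: fully scaled
        // Middle band: blend linearly between the scaled and unscaled frequency.
        float smooth = (original_max_pos_ / wavelen - low_freq_factor_)
                     / (high_freq_factor_ - low_freq_factor_);
        return (1.0f - smooth) / factor_ + smooth;
    }

private:
    float factor_, low_freq_factor_, high_freq_factor_, original_max_pos_;
};

} // namespace sketch_llama3
```

For instance, with `factor = 8`, `low_freq_factor = 1`, `high_freq_factor = 4`, `original_max_pos = 8192`, `theta = 500000` and `head_dim = 128` (illustrative values), index 0 is left unscaled while index 63 is scaled by 1/8.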

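RoPE::initialize_cache() refactor (polymorphic dispatch)
The refactored initialize_cache dispatches through virtual functions and no longer contains any if-else or dynamic_pointer_cast on the scaling type, so adding a scaling scheme later requires no change to this function: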
void RoPE::initialize_cache() {
    size_t cache_dim = head_dim_ / 2;
    // ... ...
    for (size_t pos = 0; pos < max_seq_len_; pos++) {
        for (size_t j = 0; j < cache_dim; j++) {
            float base_inv_freq = 1.0f / std::pow(theta_, 2.0f * j / head_dim_);

            // Polymorphic call, routed automatically to LongRopeConfig, Llama3Config, etc.
            float freq_scale = scaling_ ? scaling_->get_freq_scale(j, head_dim_, theta_) : 1.0f;
            float mag_scale = scaling_ ? scaling_->get_magnitude_scale(pos) : 1.0f;
            float angle = static_cast<float>(pos) * base_inv_freq * freq_scale;
            sin_data[pos * cache_dim + j] = std::sin(angle) * mag_scale;
            cos_data[pos * cache_dim + j] = std::cos(angle) * mag_scale;
        }
    }
}
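The loop above can be tried outside the class with a small standalone sketch. The free function `build_cache`, the `RopeCache` struct, the `HalfFreq` test config, and the `sketch_cache` namespace are all hypothetical names introduced for illustration; only the loop body mirrors the snippet above.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

namespace sketch_cache {

// Standalone copy of the interface, so this example is self-contained.
class ScalingConfig {
public:
    virtual ~ScalingConfig() = default;
    virtual float get_freq_scale(std::size_t, std::size_t, float) const { return 1.0f; }
    virtual float get_magnitude_scale(std::size_t) const { return 1.0f; }
};

// Hypothetical config that halves every frequency, used only to exercise dispatch.
class HalfFreq : public ScalingConfig {
public:
    float get_freq_scale(std::size_t, std::size_t, float) const override { return 0.5f; }
};

struct RopeCache {
    std::vector<float> sin_data;
    std::vector<float> cos_data;
};

// Builds sin/cos tables of shape [max_seq_len, head_dim / 2]; `scaling` may be null.
RopeCache build_cache(std::size_t max_seq_len, std::size_t head_dim, float theta,
                      const ScalingConfig *scaling) {
    std::size_t cache_dim = head_dim / 2;
    RopeCache cache{std::vector<float>(max_seq_len * cache_dim),
                    std::vector<float>(max_seq_len * cache_dim)};
    for (std::size_t pos = 0; pos < max_seq_len; pos++) {
        for (std::size_t j = 0; j < cache_dim; j++) {
            float base_inv_freq = 1.0f / std::pow(theta, 2.0f * static_cast<float>(j) / static_cast<float>(head_dim));
            // The only branch is the null check; the scaling type is never inspected.
            float freq_scale = scaling ? scaling->get_freq_scale(j, head_dim, theta) : 1.0f;
            float mag_scale = scaling ? scaling->get_magnitude_scale(pos) : 1.0f;
            float angle = static_cast<float>(pos) * base_inv_freq * freq_scale;
            cache.sin_data[pos * cache_dim + j] = std::sin(angle) * mag_scale;
            cache.cos_data[pos * cache_dim + j] = std::cos(angle) * mag_scale;
        }
    }
    return cache;
}

} // namespace sketch_cache
```

Calling `build_cache(4, 8, 10000.0f, nullptr)` yields the unscaled table, while passing `&half` for a `HalfFreq half;` makes the entry at position 2 match the unscaled entry at position 1 for `j = 0`, since the angle is halved.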

After the refactor, adding a new scaling implementation only requires changes in rope_scaling_configs.hpp/.cc; the other files stay decoupled from it.

After the refactor, all currently supported models were run once (18 result screenshots attached).
