issue/1180 refactor(nn): decouple RoPE scaling logic with polymorphic interfaces#1181
Open
rubik-hua wants to merge 1 commit into
Open
issue/1180 refactor(nn): decouple RoPE scaling logic with polymorphic interfaces#1181rubik-hua wants to merge 1 commit into
rubik-hua wants to merge 1 commit into
Conversation
… interfaces - Extract `RopeScalingConfig` and `LongRopeConfig` to dedicated `rope_scaling_configs.hpp/.cc` files. - Introduce `get_freq_scale` and `get_magnitude_scale` virtual methods to eliminate type-checking branches in the core `initialize_cache` loop. - Add `Llama3Config` skeleton for future Llama 3/3.1 support.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
RopeScalingConfigandLongRopeConfigto dedicatedrope_scaling_configs.hpp/.ccfiles.get_freq_scaleandget_magnitude_scalevirtual methods to eliminate type-checking branches in the coreinitialize_cacheloop.Llama3Configskeleton for future Llama 3/3.1 support.ScalingConfig 体系重构(策略抽象与集中管理)
在 infinicore/nn/rope.hpp 中,将不同 Scaling 引起的频率缩放与幅度缩放抽象为虚函数,移除原有的 ScalingType 枚举:
namespace infinicore::nn {
class ScalingConfig {
public:
virtual ~ScalingConfig() = default;
};
} // namespace infinicore::nn
所有具体的 ScalingConfig 子类统一集中定义在 infinicore/nn/rope_scaling_configs.hpp 中。这些类仅封装纯数学逻辑与必要参数,构造函数只接收原始数值:
// infinicore/nn/rope_scaling_configs.hpp
#pragma once
#include "infinicore/nn/rope.hpp"
#include
namespace infinicore::nn {
class LongRopeConfig : public ScalingConfig {
public:
LongRopeConfig(std::vector short_factor, std::vector long_factor,
size_t original_max_pos, float factor) /* ... /
float get_freq_scale(size_t j, size_t head_dim, float theta) const override {
float ext = (pos < original_max_pos) ? short_factor_[j] : long_factor_[j];
return 1.0f / ext;
}
float get_magnitude_scale(size_t pos) const override { return factor; }
// ...
};
class Llama3Config : public ScalingConfig { / ... / };
class LinearConfig : public ScalingConfig { / ... */ };
} // namespace infinicore::nn
RoPE::initialize_cache() 重构(多态分发)
重构后的 initialize_cache 通过虚函数多态进行分发,不再有任何 if-else 或 dynamic_pointer_cast,彻底消除未来扩展带来的代码修改:
void RoPE::initialize_cache() {
size_t cache_dim = head_dim_ / 2;
// ... ...
for (size_t pos = 0; pos < max_seq_len_; pos++) {
for (size_t j = 0; j < cache_dim; j++) {
float base_inv_freq = 1.0f / std::pow(theta_, 2.0f * j / head_dim_);
}
重构后,新增一种实现,只需要在rope_scaling_configs.cc/h文件中增加就行,解耦其它文件实现。
重构后跑过一遍所有现有支持的模型:

















