Description:
Background
The current RoPE implementation tightly couples the scaling logic with the core computation loop. Specific scaling types (like LongRopeConfig) are handled via if-else branches and std::dynamic_pointer_cast inside initialize_cache. This violates the Open/Closed Principle, making it cumbersome to extend support for new RoPE scaling variants (e.g., Llama 3, YaRN) without modifying the core loop.
Proposed Changes
- Architectural Decoupling: Extract the ScalingConfig base class and LongRopeConfig into dedicated rope_scaling_configs.hpp/.cc files.
- Polymorphic Interface: Introduce virtual methods get_freq_scale and get_magnitude_scale in the base class, providing default implementations that return 1.0f.
- Core Loop Simplification: Refactor initialize_cache in rope.cc to eliminate type-checking branches, relying instead on the polymorphic interface:
float base_inv_freq = 1.0f / std::pow(...);
float freq_scale = scaling_ ? scaling_->get_freq_scale(...) : 1.0f;
float mag_scale = scaling_ ? scaling_->get_magnitude_scale(...) : 1.0f;
float angle = static_cast<float>(pos) * base_inv_freq * freq_scale;
sin_data[...] = std::sin(angle) * mag_scale;
- Llama3 Skeleton: Add a skeleton Llama3Config class as the starting point for Llama 3/3.1 support. Its scaling logic is still a placeholder.
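The following is a minimal, self-contained sketch of the decoupled design. This description elides the argument lists of get_freq_scale and get_magnitude_scale, so the (dimension index, base inverse frequency) signature used here is an assumption. The free-function form of initialize_cache is also an assumption; in rope.cc it is a member function that reads scaling_. LinearScalingConfig is hypothetical and not part of this PR. It applies plain linear position interpolation and is included only to show that a new scheme overrides the hooks without touching the loop.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Base class: the 1.0f defaults mean "no scaling", so unscaled RoPE needs no
// special case. Argument lists are assumed; the PR elides them.
class ScalingConfig {
public:
    virtual ~ScalingConfig() = default;
    // Multiplier on the base inverse frequency of dimension pair i.
    virtual float get_freq_scale(std::size_t /*i*/, float /*base_inv_freq*/) const { return 1.0f; }
    // Multiplier on the sin/cos values (attention-magnitude correction).
    virtual float get_magnitude_scale() const { return 1.0f; }
};

// Skeleton mirroring the PR: inherits the 1.0f defaults for now.
class Llama3Config : public ScalingConfig {};

// Hypothetical variant (not in the PR): linear position interpolation
// divides every frequency by a constant factor.
class LinearScalingConfig : public ScalingConfig {
public:
    explicit LinearScalingConfig(float factor) : factor_(factor) {}
    float get_freq_scale(std::size_t, float) const override { return 1.0f / factor_; }
private:
    float factor_;
};

// Hypothetical free-function initialize_cache: fills [max_pos x head_dim/2]
// sin/cos tables through the virtual hooks, with no type checks on `scaling`.
void initialize_cache(const ScalingConfig* scaling, std::size_t max_pos,
                      std::size_t head_dim, float theta,
                      std::vector<float>& sin_data, std::vector<float>& cos_data) {
    const std::size_t half_dim = head_dim / 2;
    sin_data.assign(max_pos * half_dim, 0.0f);
    cos_data.assign(max_pos * half_dim, 0.0f);
    // Hoisted on the assumption that it depends on neither position nor dimension.
    const float mag_scale = scaling ? scaling->get_magnitude_scale() : 1.0f;
    for (std::size_t i = 0; i < half_dim; ++i) {
        const float base_inv_freq =
            1.0f / std::pow(theta, static_cast<float>(2 * i) / static_cast<float>(head_dim));
        const float freq_scale = scaling ? scaling->get_freq_scale(i, base_inv_freq) : 1.0f;
        for (std::size_t pos = 0; pos < max_pos; ++pos) {
            const float angle = static_cast<float>(pos) * base_inv_freq * freq_scale;
            sin_data[pos * half_dim + i] = std::sin(angle) * mag_scale;
            cos_data[pos * half_dim + i] = std::cos(angle) * mag_scale;
        }
    }
}
```

With LinearScalingConfig(2.0f), position 2 produces the same angles as position 1 does without scaling. Because Llama3Config only inherits the defaults, its cache is identical to the unscaled one, which matches its current skeleton status.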
Action Items / TODOs
- Llama3Config::get_freq_scale is a placeholder returning 1.0f. It needs to be implemented with the wavelength-based smooth interpolation logic to ensure correct context extension for Llama 3 models.
- In the LongRopeConfig constructor, explicitly cast original_max_position_embeddings_ to double when passing it to std::log to avoid implicit type conversion warnings.
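Both TODOs can be sketched as follows. This is a sketch, not the final implementation. The interpolation follows the published Llama 3.1 rope-scaling rule, which Hugging Face transformers exposes as the "llama3" rope type. Dimensions whose wavelength is shorter than original_max_position_embeddings / high_freq_factor are left unchanged. Dimensions whose wavelength is longer than original_max_position_embeddings / low_freq_factor are divided by factor. Dimensions in between are blended linearly. The constructor parameters and Llama 3.1 defaults mirror those config keys, while the member names and the get_freq_scale signature are assumptions.

The helper long_rope_magnitude_scale is hypothetical. It assumes the LongRopeConfig constructor computes the usual LongRoPE attention factor sqrt(1 + ln(s) / ln(original_max_position_embeddings)), and shows where the explicit cast goes.

```cpp
#include <cmath>
#include <cstddef>

// Minimal base class (same assumed interface as the refactor).
class ScalingConfig {
public:
    virtual ~ScalingConfig() = default;
    virtual float get_freq_scale(std::size_t /*i*/, float /*base_inv_freq*/) const { return 1.0f; }
    virtual float get_magnitude_scale() const { return 1.0f; }
};

// Llama 3.1-style scaling; defaults are the Llama 3.1 rope_scaling values.
class Llama3Config : public ScalingConfig {
public:
    explicit Llama3Config(float factor = 8.0f, float low_freq_factor = 1.0f,
                          float high_freq_factor = 4.0f, float original_max_pos = 8192.0f)
        : factor_(factor), low_freq_factor_(low_freq_factor),
          high_freq_factor_(high_freq_factor), original_max_pos_(original_max_pos) {}

    float get_freq_scale(std::size_t, float base_inv_freq) const override {
        const float two_pi = 2.0f * 3.14159265358979f;
        const float wavelen = two_pi / base_inv_freq;
        const float low_freq_wavelen = original_max_pos_ / low_freq_factor_;
        const float high_freq_wavelen = original_max_pos_ / high_freq_factor_;
        if (wavelen < high_freq_wavelen) return 1.0f;          // high frequency: unchanged
        if (wavelen > low_freq_wavelen) return 1.0f / factor_;  // low frequency: fully scaled
        // In between: blend from 1/factor (smooth = 0) to 1 (smooth = 1).
        const float smooth = (original_max_pos_ / wavelen - low_freq_factor_) /
                             (high_freq_factor_ - low_freq_factor_);
        return (1.0f - smooth) / factor_ + smooth;
    }
    // Llama 3 scaling does not change magnitudes, so get_magnitude_scale keeps 1.0f.

private:
    float factor_, low_freq_factor_, high_freq_factor_, original_max_pos_;
};

// Hypothetical LongRoPE attention factor; the static_cast<double> calls make
// the int-to-floating conversions explicit before std::log.
double long_rope_magnitude_scale(int max_pos, int original_max_pos) {
    const double scale = static_cast<double>(max_pos) / static_cast<double>(original_max_pos);
    if (scale <= 1.0) return 1.0;  // no extension, no correction
    return std::sqrt(1.0 + std::log(scale) / std::log(static_cast<double>(original_max_pos)));
}
```

The blend is continuous at both boundaries. At wavelen == high_freq_wavelen, smooth is 1 and the scale is 1. At wavelen == low_freq_wavelen, smooth is 0 and the scale is 1/factor.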