Make NS coefficients parameter 2D in Python API #2904
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Greptile Summary
This PR refactors the Newton-Schulz coefficient API to use a list of
Confidence Score: 4/5
Safe to merge once the breaking-change nature is acknowledged and communicated to users. The refactoring logic is internally consistent and correct; the C++ backend path is unchanged. The single P1 concern is that the public transformer_engine/pytorch/newton_schulz.py - the
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["Caller: newton_schulz(x, ctx, num_iterations, coefficients)"]
B{coefficients is None?}
C["get_coefficients(num_iterations) - returns list of CoeffT tuples"]
D{len == num_iterations?}
E["ValueError: wrong length"]
F["Flatten: validate each tuple len==3, extend flat_coefficients"]
G["tex.newton_schulz - C++ backend receives flat list, unchanged"]
A --> B
B -- Yes --> C --> D
B -- No --> D
D -- No --> E
D -- Yes --> F --> G
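
The validate-and-flatten path in the flowchart can be sketched in plain Python as follows. This is only an illustration of the flow, not the code in transformer_engine/pytorch/newton_schulz.py: `default_coefficients` is a hypothetical stand-in for `get_coefficients`, its values are placeholders rather than the library's defaults, and the error messages are made up. The returned flat list stands in for what would be handed to `tex.newton_schulz`.

```python
from typing import List, Optional, Sequence, Tuple

CoeffT = Tuple[float, float, float]


def default_coefficients(num_iterations: int) -> List[CoeffT]:
    # Hypothetical stand-in for get_coefficients(); the values are
    # placeholders, not the defaults used by the library.
    return [(1.5, -0.5, 0.0) for _ in range(num_iterations)]


def flatten_coefficients(
    num_iterations: int,
    coefficients: Optional[Sequence[Sequence[float]]] = None,
) -> List[float]:
    """Validate per-iteration coefficient tuples and flatten them."""
    if coefficients is None:
        coefficients = default_coefficients(num_iterations)
    # One tuple is expected per iteration.
    if len(coefficients) != num_iterations:
        raise ValueError(
            f"expected {num_iterations} coefficient tuples, "
            f"got {len(coefficients)}"
        )
    flat: List[float] = []
    for i, coeffs in enumerate(coefficients):
        # Each iteration carries exactly three coefficients.
        if len(coeffs) != 3:
            raise ValueError(
                f"coefficients[{i}] must have 3 values, got {len(coeffs)}"
            )
        flat.extend(coeffs)
    return flat
```

Under these assumptions, `flatten_coefficients(2, [(1.0, 2.0, 3.0), (4.0, 5.0, 6.0)])` yields a six-element flat list, while a list of the wrong length raises `ValueError` before anything reaches the backend.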
/te-ci pytorch L1
Description
Make the way coefficients are passed to Newton-Schulz more consistent with the one in EmergingOptimizers.
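
For callers affected by the change, a small migration helper might look like the sketch below. It assumes the earlier Python API accepted a flat sequence with three values per iteration, which is inferred from the PR title rather than stated in this description; `unflatten_coefficients` is a hypothetical name and not part of Transformer Engine.

```python
from typing import List, Sequence, Tuple


def unflatten_coefficients(flat: Sequence[float]) -> List[Tuple[float, ...]]:
    """Group a flat coefficient sequence into one 3-tuple per iteration.

    Hypothetical helper: assumes the legacy layout was three consecutive
    values per Newton-Schulz iteration.
    """
    if len(flat) % 3 != 0:
        raise ValueError(
            f"flat coefficient list length must be a multiple of 3, "
            f"got {len(flat)}"
        )
    # Slice the flat sequence into consecutive groups of three.
    return [tuple(flat[i:i + 3]) for i in range(0, len(flat), 3)]
```

The resulting list has one entry per iteration, so its length can be passed as `num_iterations` alongside it.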
Fixes # (issue)
Type of change
Changes
Checklist: