
Rose

Range-Of-Slice Equilibration
PyTorch Optimizer

Stateless optimization through range-normalized gradient updates.

In loving memory of my mother, Rose Kieren.




📰 News

2026-04-17 [v1.0.0] – Initial public release

🌹 Introduction

Most adaptive optimizers (such as Adam, RMSprop, Adafactor, and their many variants) accumulate running statistics for every parameter: first-moment estimates, second-moment estimates, step counters, and sometimes more. These buffers can double or triple the memory footprint of a model's parameters and introduce temporal entanglements (bias correction, momentum decay, sensitivity to $\beta$ schedules) that make training dynamics harder to reason about.

Rose asks a simple question: how much can you accomplish with just the gradient you have right now?

At each step, Rose normalizes every gradient tensor by a per-slice range (the difference between the maximum and minimum values computed across all dimensions beyond the leading axis), yielding one adaptive scale factor per output unit. An optional coefficient-of-variation trust gate blends per-slice ranges with their global mean when the ranges are noisy, and optional gradient centralization removes shared directional bias before scaling.

In my own preliminary experiments, it has shown promising convergence and generalization.

📦 Installation

pip install git+https://github.com/MatthewK78/Rose

Usage:

from rose import Rose

optimizer = Rose(params, lr=1e-3)

Requires Python ≥ 3.10 and PyTorch ≥ 2.0.

✨ Features

| Feature | Detail |
| --- | --- |
| Zero optimizer state | No momentum, variance estimates, or step counters. Memory cost is parameters + gradients + processing, nothing else. |
| Gradient centralization | Removes the per-slice mean from gradients of rank ≥ 2, reducing internal covariate shift in the gradient signal and often improving stability and generalization. |
| CV trust gating | Automatically detects when per-slice ranges are noisy and gracefully falls back to a robust global estimate. No manual tuning required. |
| Decoupled weight decay | Standard or schedule-coupled weight decay, preventing late-training decay from overpowering vanishing learning rates. |
| BF16 stochastic rounding | Unbiased rounding for BFloat16 parameters eliminates systematic truncation drift, meaningfully improving low-precision training fidelity. |
| Configurable compute precision | Promotes intermediates to FP64 by default (FP32, BF16, FP16, or native dtype also supported) so that range and division arithmetic stays precise. |

🔬 Method

Consider a linear layer with weight matrix $W \in \mathbb{R}^{m \times n}$. Its gradient $G$ has the same shape: $m$ rows, one per output neuron. Rose computes the range (max $-$ min) across the $n$ input-facing elements of each row independently, producing $m$ per-neuron scale factors. For rank-1 parameters such as biases, the range is computed over the full tensor and damped by adding 1, which provides a smooth interpolation between SGD-like behavior (for small ranges) and range-normalized behavior (for large ranges). Truly scalar parameters receive a softsign-like update instead.

This is analogous to how Adam assigns each scalar parameter its own adaptive denominator via a running variance estimate. Rose instead assigns each output slice a denominator based on the instantaneous spread of its gradient, requiring no history at all.
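
A minimal PyTorch sketch of this scaling rule (my own illustration of the description above, not the library's code; the `eps` guard on rank ≥ 2 slices is an assumption I've added to avoid division by zero):

```python
import torch

def rose_scale(grad: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Range-normalize a gradient: one scale factor per leading-axis slice.

    Rank >= 2: divide each slice by its own (max - min) range.
    Rank 1:    divide by the full-tensor range damped by adding 1, which
               interpolates between SGD-like and range-normalized behavior.
    """
    if grad.dim() >= 2:
        flat = grad.flatten(start_dim=1)           # (m, n): one row per output unit
        rng = flat.amax(dim=1) - flat.amin(dim=1)  # per-slice range
        shape = (-1,) + (1,) * (grad.dim() - 1)
        return grad / (rng.view(shape) + eps)
    # rank-1 parameters (biases): global range, damped by +1
    rng = grad.amax() - grad.amin()
    return grad / (rng + 1.0)
```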

The trust gate addresses a practical concern: when per-slice ranges vary wildly (high coefficient of variation), the individual ranges may become unreliable. The trust factor $\tau = \mu / (\mu + \sigma)$ is close to 1 when ranges are self-consistent and close to 0 when they are noisy. The denominator smoothly interpolates between the local range (full detail) and the global mean range (maximum noise resistance).
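
A sketch of the gate, assuming the mean and standard deviation are taken over the per-slice range tensor as described (the exact statistics used by the library may differ):

```python
import torch

def gate_ranges(ranges: torch.Tensor) -> torch.Tensor:
    """Blend per-slice ranges with their global mean via the CV trust gate.

    tau = mu / (mu + sigma) is close to 1 when ranges are self-consistent
    and close to 0 when they are noisy; the denominator interpolates
    between local detail and the robust global mean accordingly.
    """
    mu = ranges.mean()
    sigma = ranges.std()
    tau = mu / (mu + sigma)
    return tau * ranges + (1.0 - tau) * mu
```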

Why range instead of variance? Range is cheaper to compute, requires no centering, and maps cleanly to the idea of scale equilibration: it answers "how wide is this gradient slice?" rather than "how energetic is it?" In practice, for the shapes common in deep learning, the two carry similar information, but range has the advantage of depending only on two order statistics and producing a scale factor that directly normalizes the gradient's dynamic range.

🎛️ Hyperparameters

lr: Learning Rate

The global step size. Start with values you would try for Adam (e.g., 1e-3). Because the denominator is range-based rather than RMS-based, effective update magnitudes differ; some tuning is expected, but the neighborhood is similar.

Rose(params, lr=1e-3)

weight_decay: Decoupled Weight Decay

Default 1e-4
Disable 0 or None

A decoupled multiplicative coefficient applied independently of the gradient step, shrinking weights toward zero each step. This is the same formulation used by AdamW.
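
Schematically, with toy numbers (an illustration only: the real step uses the range-normalized gradient rather than the raw one, and the exact decay scaling is an implementation detail not asserted here):

```python
lr, weight_decay = 1e-3, 1e-4
w, grad = 0.5, 0.2

w = w - lr * grad             # gradient step
w = w * (1.0 - weight_decay)  # decay applied independently of the gradient
```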

Rose(params, lr=1e-3, weight_decay=1e-4)  # default
Rose(params, lr=1e-3, weight_decay=0)     # disabled

wd_schedule: Schedule-Coupled Weight Decay

Default False

Scales weight decay proportionally with a learning-rate schedule so that decay weakens as the learning rate drops. This prevents weight decay from dominating the update in the late phase of training when the learning rate is small. The per-step multiplicative factor becomes:

$1 - \frac{\eta_t}{\eta_{\text{ref}}} \cdot \lambda$

| Value | Behavior |
| --- | --- |
| `False` | Standard decoupled weight decay. |
| `True` | $\eta_{\text{ref}}$ is resolved from `max_lr` → `initial_lr` → constructor `lr`. |
| float | The provided value is used directly as $\eta_{\text{ref}}$. |

Rose(params, lr=1e-3, weight_decay=1e-4, wd_schedule=True)  # auto reference
Rose(params, lr=1e-3, weight_decay=1e-4, wd_schedule=1e-3)  # explicit reference
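
With hypothetical numbers, the coupling weakens decay as the learning rate anneals:

```python
# Hypothetical values: lambda = 1e-4, reference lr 1e-3, current lr decayed to 1e-4.
lam, eta_ref, eta_t = 1e-4, 1e-3, 1e-4

standard = 1.0 - lam                     # decay factor without coupling
coupled = 1.0 - (eta_t / eta_ref) * lam  # schedule-coupled factor
# coupled is closer to 1: decay shrinks in step with the learning rate
```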

centralize: Gradient Centralization

Default True

Subtracts the mean of each gradient slice along the non-leading axes before the range computation. Only applies to parameters with rank ≥ 2; biases and other 1-D parameters are never centralized.

Gradient centralization constrains updates in the subspace orthogonal to the slice mean, which can act as a mild regularizer and improve training stability.
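
A minimal sketch of this transform (illustrative, not the library's implementation):

```python
import torch

def centralize(grad: torch.Tensor) -> torch.Tensor:
    """Subtract the per-slice mean along all non-leading axes."""
    if grad.dim() < 2:
        return grad  # 1-D parameters (biases) are never centralized
    dims = tuple(range(1, grad.dim()))
    return grad - grad.mean(dim=dims, keepdim=True)
```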

Rose(params, lr=1e-3, centralize=True)   # default
Rose(params, lr=1e-3, centralize=False)  # disabled

stabilize: Coefficient-of-Variation Trust Gating

Default True

Computes a trust factor from the coefficient of variation of the per-slice range tensor and interpolates between the local per-slice range and the global mean range.

  • Trust ≈ 1 (consistent ranges) → local detail preserved.
  • Trust ≈ 0 (noisy ranges) → smooth global fallback.
Rose(params, lr=1e-3, stabilize=True)   # default
Rose(params, lr=1e-3, stabilize=False)  # raw per-slice ranges only

bf16_sr: BFloat16 Stochastic Rounding

Default True

When a parameter is stored in BFloat16, promotes it to higher precision for the update, then stochastically rounds on write-back. This produces statistically unbiased rounding, correcting for the systematic truncation drift that BF16's limited mantissa otherwise introduces. Has no effect on parameters with any other dtype.

| Value | Effect |
| --- | --- |
| `False` | BF16 stochastic rounding is disabled. |
| `True` | Uses the default random-number generator. |
| `torch.Generator` | Treats `bf16_sr` as enabled and forwards the generator to `random_`. Useful for deterministic output. |

Rose(params, lr=1e-3, bf16_sr=True)   # default
Rose(params, lr=1e-3, bf16_sr=False)  # plain truncation

sr_gen = torch.Generator(device="cuda").manual_seed(0xd1ce)
Rose(params, lr=1e-3, bf16_sr=sr_gen)
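
The write-back can be sketched with a bit-level trick (my own illustration, not the library's code; it ignores inf/NaN edge cases): BF16 keeps the top 16 bits of an FP32 value, so adding uniform noise to the 16 bits that will be discarded, then truncating, is unbiased in expectation.

```python
import torch

def bf16_stochastic_round(x, generator=None):
    """Stochastically round an FP32 tensor to BFloat16 (unbiased in expectation)."""
    assert x.dtype == torch.float32
    bits = x.view(torch.int32)  # reinterpret the FP32 bit pattern as int32
    noise = torch.randint(0, 1 << 16, bits.shape, dtype=torch.int32,
                          device=x.device, generator=generator)
    # Add noise to the 16 bits BF16 discards, then zero them out (truncate).
    return ((bits + noise) & -(1 << 16)).view(torch.float32).to(torch.bfloat16)
```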

compute_dtype: Internal Compute Precision

Default "fp64"

The dtype to which parameters and gradients are promoted for the update step. FP64 is recommended; the intermediate range computation and division benefit from the extra mantissa bits, especially for parameters with large fan-in or near-zero gradient spread.

| Value | Effect |
| --- | --- |
| `torch.float64` / `"fp64"` | Full FP64 precision for all intermediates. |
| `torch.float32` / `"fp32"` | Reasonable fallback when FP64 is too costly. |
| `torch.float16` / `"fp16"` | Not generally recommended; listed for completeness. |
| `torch.bfloat16` / `"bf16"` | Not generally recommended; listed for completeness. |
| `None` / `"none"` | No promotion; compute in native dtype (however, BF16 params still use FP32 if `bf16_sr=True`). |

Rose(params, lr=1e-3, compute_dtype="fp64")         # default
Rose(params, lr=1e-3, compute_dtype=torch.float32)  # lighter
Rose(params, lr=1e-3, compute_dtype=None)           # native

💖 Acknowledgements

This optimizer is named in loving memory of my mother, Rose Kieren, who always listened intently with genuine interest as I rambled about AI and computers, and whose unconditional love and presence through both the best and hardest of times made all of this possible.

I am deeply grateful to my late father for always taking an interest in my exploration of technology, and for the encouragement and warmth he brought to every conversation.

To my wife, for her extraordinary patience through the countless days and nights I've spent absorbed in programming and research, and for her unwavering intellectual and emotional support.

To my son, a source of joy, perspective, and inspiration, whose curiosity and bright spirit remind me why building for the future matters.

As an independent researcher, I am grateful for the love and support of all of my family and friends. This work would not have been possible without them.

😊 A Kind Request

If you use Rose in your research, project, or product, I would be grateful if you would mention it by name and credit its author, Matthew E. Kieren. A citation (see below), a footnote, a line in your README; any acknowledgment, however small, helps motivate me to do more. And if you have a moment, I would love to hear your story.

If you'd like to support ongoing development of this project and others, you can send a donation here on GitHub, or through PayPal.

Your support and acknowledgment are sincerely appreciated! 😊

📄 Citation

@software{kieren2026rose,
  author       = {Kieren, Matthew E.},
  title        = {Rose: Range-Of-Slice Equilibration optimizer},
  year         = {2026},
  publisher    = {Zenodo},
  doi          = {10.5281/zenodo.19589765},
  url          = {https://doi.org/10.5281/zenodo.19589765}
}

📚 References

1. Kingma, D. P. & Ba, J. (2014). Adam: A Method for Stochastic Optimization. arXiv:1412.6980

2. Loshchilov, I. & Hutter, F. (2017). Decoupled Weight Decay Regularization. arXiv:1711.05101

3. Hazan, E., Levy, K. Y. & Shalev-Shwartz, S. (2015). Beyond Convexity: Stochastic Quasi-Convex Optimization. arXiv:1507.02030

4. You, Y., Gitman, I. & Ginsburg, B. (2017). Large Batch Training of Convolutional Networks. arXiv:1708.03888

5. Yong, H., Huang, J., Hua, X. & Zhang, L. (2020). Gradient Centralization: A New Optimization Technique for Deep Neural Networks. arXiv:2004.01461

6. Zamirai, P., Zhang, J., Aberger, C. R. & De Sa, C. (2020). Revisiting BFloat16 Training. arXiv:2010.06192

⚖️ License

Copyright © 2026 Matthew Everet Kieren

Licensed under the Apache License, Version 2.0. You may use, modify, and distribute this software in accordance with the license. See LICENSE for the full text.

