diff --git a/src/machine_learning/optimization/adam.rs b/src/machine_learning/optimization/adam.rs
index 6fbebc6d39d..724b56e75f1 100644
--- a/src/machine_learning/optimization/adam.rs
+++ b/src/machine_learning/optimization/adam.rs
@@ -5,12 +5,19 @@
 //! learning problems. Boasting memory-efficient fast convergence rates, it sets and iteratively
 //! updates learning rates individually for each model parameter based on the gradient history.
 //!
+//! Setting `weight_decay > 0.0` enables the AdamW variant (Loshchilov & Hutter, 2019), which
+//! applies weight decay directly to the parameters rather than folding it into the gradients.
+//! This keeps the decay rate constant and independent of the gradient history; coupling the
+//! two is the flaw of naive L2 regularization inside Adam that AdamW corrects. With
+//! `weight_decay = 0.0` (the default), the two algorithms are identical.
+//!
 //! ## Algorithm:
 //!
 //! Given:
 //! - α is the learning rate
 //! - (β_1, β_2) are the exponential decay rates for moment estimates
 //! - ϵ is any small value to prevent division by zero
+//! - λ is the weight decay coefficient (0.0 for standard Adam, > 0.0 for AdamW)
 //! - g_t are the gradients at time step t
 //! - m_t are the biased first moment estimates of the gradient at time step t
 //! - v_t are the biased second raw moment estimates of the gradient at time step t
@@ -28,20 +35,25 @@
 //! while θ_t not converged do
 //!     m_t = β_1 * m_{t−1} + (1 − β_1) * g_t
 //!     v_t = β_2 * v_{t−1} + (1 − β_2) * g_t^2
-//!     m_hat_t = m_t / 1 - β_1^t
-//!     v_hat_t = v_t / 1 - β_2^t
-//!     θ_t = θ_{t-1} − α * m_hat_t / (sqrt(v_hat_t) + ϵ)
+//!     m_hat_t = m_t / (1 - β_1^t)
+//!     v_hat_t = v_t / (1 - β_2^t)
+//!     θ_t = θ_{t-1} − α * (m_hat_t / (sqrt(v_hat_t) + ϵ) + λ * θ_{t-1})
 //!
 //! ## Resources:
 //! - Adam: A Method for Stochastic Optimization (by Diederik P. Kingma and Jimmy Ba):
 //!   - [https://arxiv.org/abs/1412.6980]
+//! - Decoupled Weight Decay Regularization (by Ilya Loshchilov and Frank Hutter):
+//!   - [https://arxiv.org/abs/1711.05101]
 //! - PyTorch Adam optimizer:
-//!   - [https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam]
+//!   - [https://pytorch.org/docs/stable/generated/torch.optim.Adam.html]
+//! - PyTorch AdamW optimizer:
+//!   - [https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html]
 //!
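+//! ## Example:
+//!
+//! A minimal sketch of how `new` and `step` fit together (values are illustrative only;
+//! `weight_decay > 0.0` selects AdamW, `None` keeps standard Adam):
+//!
+//! ```ignore
+//! // Two parameters, default betas/epsilon, small decoupled weight decay (AdamW).
+//! let mut optimizer = Adam::new(Some(1e-3), None, None, Some(0.01), 2);
+//!
+//! let params = vec![0.5, -0.25];
+//! let gradients = vec![0.1, -0.2];
+//!
+//! // One step returns the updated parameters; `params` itself is not mutated.
+//! let updated = optimizer.step(&gradients, &params);
+//! assert_eq!(updated.len(), params.len());
+//! ```
+//!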
 pub struct Adam {
     learning_rate: f64, // alpha: initial step size for iterative optimization
     betas: (f64, f64),  // betas: exponential decay rates for moment estimates
     epsilon: f64,       // epsilon: prevent division by zero
+    weight_decay: f64,  // lambda: decoupled weight decay coefficient (0.0 = standard Adam)
     m: Vec<f64>,        // m: biased first moment estimate of the gradient vector
     v: Vec<f64>,        // v: biased second raw moment estimate of the gradient vector
     t: usize,           // t: time step
@@ -52,20 +64,38 @@ impl Adam {
         learning_rate: Option<f64>,
         betas: Option<(f64, f64)>,
         epsilon: Option<f64>,
+        weight_decay: Option<f64>,
         params_len: usize,
     ) -> Self {
         Adam {
             learning_rate: learning_rate.unwrap_or(1e-3), // typical good default lr
             betas: betas.unwrap_or((0.9, 0.999)),         // typical good default decay rates
             epsilon: epsilon.unwrap_or(1e-8),             // typical good default epsilon
+            weight_decay: weight_decay.unwrap_or(0.0),    // 0.0 = standard Adam, > 0.0 = AdamW
             m: vec![0.0; params_len], // first moment vector elements all initialized to zero
             v: vec![0.0; params_len], // second moment vector elements all initialized to zero
             t: 0,                     // time step initialized to zero
         }
     }

-    pub fn step(&mut self, gradients: &[f64]) -> Vec<f64> {
-        let mut model_params = vec![0.0; gradients.len()];
+    /// Computes one update step.
+    ///
+    /// `params` holds the current parameter values θ_{t-1}. When `weight_decay`
+    /// is `0.0` the update is standard Adam; any positive value applies the AdamW
+    /// decoupled decay term `λ * θ_{t-1}` directly to the parameters, independent
+    /// of the adaptive scaling.
+    ///
+    /// # Panics
+    ///
+    /// Panics if `gradients` and `params` have different lengths.
+    pub fn step(&mut self, gradients: &[f64], params: &[f64]) -> Vec<f64> {
+        assert_eq!(
+            gradients.len(),
+            params.len(),
+            "gradients and params must have the same length"
+        );
+
+        let mut updated_params = vec![0.0; params.len()];
         self.t += 1;

         for i in 0..gradients.len() {
@@ -77,10 +107,15 @@ impl Adam {
             let m_hat = self.m[i] / (1.0 - self.betas.0.powi(self.t as i32));
             let v_hat = self.v[i] / (1.0 - self.betas.1.powi(self.t as i32));

-            // update model parameters
-            model_params[i] -= self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon);
+            // Adaptive gradient step: preserves the original (lr * m_hat) / denom
+            // operator order so floating-point results are identical to standard Adam
+            // when weight_decay = 0.0. The decoupled decay term is added separately
+            // so it does not interact with the adaptive scaling.
+            updated_params[i] = params[i]
+                - self.learning_rate * m_hat / (v_hat.sqrt() + self.epsilon)
+                - self.learning_rate * self.weight_decay * params[i];
         }

-        model_params // return updated model parameters
+        updated_params // return updated model parameters
     }
 }

@@ -88,13 +123,16 @@
 mod tests {
     use super::*;

+    // ── Initialisation ────────────────────────────────────────────────────────
+
     #[test]
     fn test_adam_init_default_values() {
-        let optimizer = Adam::new(None, None, None, 1);
+        let optimizer = Adam::new(None, None, None, None, 1);

         assert_eq!(optimizer.learning_rate, 0.001);
         assert_eq!(optimizer.betas, (0.9, 0.999));
         assert_eq!(optimizer.epsilon, 1e-8);
+        assert_eq!(optimizer.weight_decay, 0.0);
         assert_eq!(optimizer.m, vec![0.0; 1]);
         assert_eq!(optimizer.v, vec![0.0; 1]);
         assert_eq!(optimizer.t, 0);
@@ -102,11 +140,12 @@

     #[test]
     fn test_adam_init_custom_lr_value() {
-        let optimizer = Adam::new(Some(0.9), None, None, 2);
+        let optimizer = Adam::new(Some(0.9), None, None, None, 2);

         assert_eq!(optimizer.learning_rate, 0.9);
         assert_eq!(optimizer.betas, (0.9, 0.999));
         assert_eq!(optimizer.epsilon, 1e-8);
+        assert_eq!(optimizer.weight_decay, 0.0);
         assert_eq!(optimizer.m, vec![0.0; 2]);
         assert_eq!(optimizer.v, vec![0.0; 2]);
         assert_eq!(optimizer.t, 0);
@@ -114,11 +153,12 @@

     #[test]
     fn test_adam_init_custom_betas_value() {
-        let optimizer = Adam::new(None, Some((0.8, 0.899)), None, 3);
+        let optimizer = Adam::new(None, Some((0.8, 0.899)), None, None, 3);

         assert_eq!(optimizer.learning_rate, 0.001);
         assert_eq!(optimizer.betas, (0.8, 0.899));
         assert_eq!(optimizer.epsilon, 1e-8);
+        assert_eq!(optimizer.weight_decay, 0.0);
         assert_eq!(optimizer.m, vec![0.0; 3]);
         assert_eq!(optimizer.v, vec![0.0; 3]);
         assert_eq!(optimizer.t, 0);
@@ -126,34 +166,52 @@

     #[test]
     fn test_adam_init_custom_epsilon_value() {
-        let optimizer = Adam::new(None, None, Some(1e-10), 4);
+        let optimizer = Adam::new(None, None, Some(1e-10), None, 4);

         assert_eq!(optimizer.learning_rate, 0.001);
         assert_eq!(optimizer.betas, (0.9, 0.999));
         assert_eq!(optimizer.epsilon, 1e-10);
+        assert_eq!(optimizer.weight_decay, 0.0);
         assert_eq!(optimizer.m, vec![0.0; 4]);
         assert_eq!(optimizer.v, vec![0.0; 4]);
         assert_eq!(optimizer.t, 0);
     }

+    #[test]
+    fn test_adam_init_custom_weight_decay_value() {
+        let optimizer = Adam::new(None, None, None, Some(0.1), 3);
+
+        assert_eq!(optimizer.learning_rate, 0.001);
+        assert_eq!(optimizer.betas, (0.9, 0.999));
+        assert_eq!(optimizer.epsilon, 1e-8);
+        assert_eq!(optimizer.weight_decay, 0.1);
+        assert_eq!(optimizer.m, vec![0.0; 3]);
+        assert_eq!(optimizer.v, vec![0.0; 3]);
+        assert_eq!(optimizer.t, 0);
+    }
+
     #[test]
     fn test_adam_init_all_custom_values() {
-        let optimizer = Adam::new(Some(1.0), Some((0.001, 0.099)), Some(1e-1), 5);
+        let optimizer = Adam::new(Some(1.0), Some((0.001, 0.099)), Some(1e-1), Some(0.05), 5);

         assert_eq!(optimizer.learning_rate, 1.0);
         assert_eq!(optimizer.betas, (0.001, 0.099));
         assert_eq!(optimizer.epsilon, 1e-1);
+        assert_eq!(optimizer.weight_decay, 0.05);
         assert_eq!(optimizer.m, vec![0.0; 5]);
         assert_eq!(optimizer.v, vec![0.0; 5]);
         assert_eq!(optimizer.t, 0);
     }

+    // ── Step: standard Adam (weight_decay = 0.0) ──────────────────────────────
+
     #[test]
     fn test_adam_step_default_params() {
         let gradients = vec![-1.0, 2.0, -3.0, 4.0, -5.0, 6.0, -7.0, 8.0];
+        let params = vec![0.0; 8];

-        let mut optimizer = Adam::new(None, None, None, 8);
-        let updated_params = optimizer.step(&gradients);
+        let mut optimizer = Adam::new(None, None, None, None, 8);
+        let updated_params = optimizer.step(&gradients, &params);

         assert_eq!(
             updated_params,
@@ -173,9 +231,10 @@
     #[test]
     fn test_adam_step_custom_params() {
         let gradients = vec![9.0, -8.0, 7.0, -6.0, 5.0, -4.0, 3.0, -2.0, 1.0];
+        let params = vec![0.0; 9];

-        let mut optimizer = Adam::new(Some(0.005), Some((0.5, 0.599)), Some(1e-5), 9);
-        let updated_params = optimizer.step(&gradients);
+        let mut optimizer = Adam::new(Some(0.005), Some((0.5, 0.599)), Some(1e-5), None, 9);
+        let updated_params = optimizer.step(&gradients, &params);

         assert_eq!(
             updated_params,
@@ -195,24 +254,93 @@

     #[test]
     fn test_adam_step_empty_gradients_array() {
-        let gradients = vec![];
+        let gradients: Vec<f64> = vec![];
+        let params: Vec<f64> = vec![];

-        let mut optimizer = Adam::new(None, None, None, 0);
-        let updated_params = optimizer.step(&gradients);
+        let mut optimizer = Adam::new(None, None, None, None, 0);
+        let updated_params = optimizer.step(&gradients, &params);

         assert_eq!(updated_params, vec![]);
     }

+    // ── Step: AdamW (weight_decay > 0.0) ─────────────────────────────────────
+
+    #[test]
+    fn test_adamw_step_nonzero_params_applies_decay() {
+        // When params are non-zero and weight_decay > 0.0, the decay term must pull
+        // every parameter strictly closer to zero than the plain adaptive step would.
+        // Comparing against a no-decay run avoids replicating the internal floating
+        // point computation path and tests the property that actually matters.
+        let gradients = vec![1.0, -2.0, 3.0];
+        let params = vec![0.5, -0.5, 1.0];
+
+        let mut with_decay = Adam::new(None, None, None, Some(0.01), 3);
+        let decayed = with_decay.step(&gradients, &params);
+
+        let mut no_decay = Adam::new(None, None, None, None, 3);
+        let not_decayed = no_decay.step(&gradients, &params);
+
+        for i in 0..params.len() {
+            assert!(
+                decayed[i].abs() < not_decayed[i].abs(),
+                "param[{i}]: with_decay={}, no_decay={}",
+                decayed[i],
+                not_decayed[i]
+            );
+        }
+    }
+
+    #[test]
+    fn test_adamw_step_weight_decay_zero_matches_adam() {
+        // weight_decay = 0.0 must be numerically identical to standard Adam.
+        let gradients = vec![9.0, -8.0, 7.0, -6.0, 5.0, -4.0, 3.0, -2.0, 1.0];
+        let params = vec![0.0; 9];
+
+        let mut adamw = Adam::new(Some(0.005), Some((0.5, 0.599)), Some(1e-5), Some(0.0), 9);
+        let mut adam = Adam::new(Some(0.005), Some((0.5, 0.599)), Some(1e-5), None, 9);
+
+        assert_eq!(
+            adamw.step(&gradients, &params),
+            adam.step(&gradients, &params)
+        );
+    }
+
+    #[test]
+    fn test_adamw_step_decay_pulls_params_toward_zero() {
+        // Each updated parameter must end up closer to zero than its starting value.
+        let gradients = vec![1.0, -1.0, 2.0, -2.0];
+        let params = vec![0.1, -0.1, 0.2, -0.2];
+
+        let mut optimizer = Adam::new(Some(0.01), Some((0.9, 0.999)), Some(1e-8), Some(0.01), 4);
+        let updated = optimizer.step(&gradients, &params);
+
+        assert!(updated[0] < params[0]); // positive param, positive grad → decrease
+        assert!(updated[1] > params[1]); // negative param, negative grad → increase
+        assert!(updated[2] < params[2]);
+        assert!(updated[3] > params[3]);
+    }
+
+    // ── Step: shared edge cases ───────────────────────────────────────────────
+
+    #[test]
+    #[should_panic(expected = "gradients and params must have the same length")]
+    fn test_step_mismatched_lengths_panics() {
+        let mut optimizer = Adam::new(None, None, None, None, 3);
+        optimizer.step(&[1.0, 2.0, 3.0], &[0.0, 0.0]); // params too short
+    }
+
+    // ── Convergence (slow; marked #[ignore]) ─────────────────────────────────
+
     #[ignore]
     #[test]
     fn test_adam_step_iteratively_until_convergence_with_default_params() {
         const CONVERGENCE_THRESHOLD: f64 = 1e-5;
         let gradients = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];

-        let mut optimizer = Adam::new(None, None, None, 6);
+        let mut optimizer = Adam::new(None, None, None, None, 6);

         let mut model_params = vec![0.0; 6];
-        let mut updated_params = optimizer.step(&gradients);
+        let mut updated_params = optimizer.step(&gradients, &model_params);

         while (updated_params
             .iter()
@@ -226,7 +354,7 @@
             > CONVERGENCE_THRESHOLD
         {
             model_params = updated_params;
-            updated_params = optimizer.step(&gradients);
+            updated_params = optimizer.step(&gradients, &model_params);
         }

         assert!(updated_params < vec![CONVERGENCE_THRESHOLD; 6]);
@@ -250,10 +378,10 @@
         const CONVERGENCE_THRESHOLD: f64 = 1e-7;
         let gradients = vec![7.0, -8.0, 9.0, -10.0, 11.0, -12.0, 13.0];

-        let mut optimizer = Adam::new(Some(0.005), Some((0.8, 0.899)), Some(1e-5), 7);
+        let mut optimizer = Adam::new(Some(0.005), Some((0.8, 0.899)), Some(1e-5), None, 7);

         let mut model_params = vec![0.0; 7];
-        let mut updated_params = optimizer.step(&gradients);
+        let mut updated_params = optimizer.step(&gradients, &model_params);

         while (updated_params
             .iter()
@@ -267,7 +395,7 @@
             > CONVERGENCE_THRESHOLD
         {
             model_params = updated_params;
-            updated_params = optimizer.step(&gradients);
+            updated_params = optimizer.step(&gradients, &model_params);
         }

         assert!(updated_params < vec![CONVERGENCE_THRESHOLD; 7]);
@@ -285,4 +413,33 @@
             ]
         );
     }
+
+    #[ignore]
+    #[test]
+    fn test_adamw_step_iteratively_until_convergence() {
+        const CONVERGENCE_THRESHOLD: f64 = 1e-5;
+        let gradients = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
+
+        let mut optimizer = Adam::new(None, None, None, Some(0.0), 6);
+
+        let mut params = vec![0.0; 6];
+        let mut updated = optimizer.step(&gradients, &params);
+
+        while (updated
+            .iter()
+            .zip(params.iter())
+            .map(|(x, y)| x - y)
+            .collect::<Vec<f64>>())
+        .iter()
+        .map(|&x| x.powi(2))
+        .sum::<f64>()
+        .sqrt()
+            > CONVERGENCE_THRESHOLD
+        {
+            params = updated;
+            updated = optimizer.step(&gradients, &params);
+        }
+
+        assert_ne!(updated, params);
+    }
 }
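For a quick hand-check of the update rule, the first step (t = 1) has a closed form: the bias corrections cancel the (1 − β) factors, so m_hat_1 = g_1 and v_hat_1 = g_1^2, and the update reduces to θ_1 = θ_0 − α * g_1 / (|g_1| + ϵ) − α * λ * θ_0. The sketch below checks `step` against that closed form; it assumes the `Adam` type from this file is in scope, and the weight decay of 0.01 is an illustrative value only.

```rust
// Hand-check of the first AdamW step against the closed form above.
// Assumes `Adam` (as defined in this file) is in scope; all values are illustrative.
fn main() {
    let (lr, eps, wd) = (1e-3, 1e-8, 0.01);
    let (theta_0, g): (f64, f64) = (0.5, 0.2);

    // Defaults give lr = 1e-3, betas = (0.9, 0.999), eps = 1e-8.
    let mut optimizer = Adam::new(None, None, None, Some(wd), 1);
    let theta_1 = optimizer.step(&[g], &[theta_0])[0];

    // At t = 1: m_hat = g and v_hat = g^2, so the update reduces to the line below.
    let expected = theta_0 - lr * g / (g.abs() + eps) - lr * wd * theta_0;
    assert!((theta_1 - expected).abs() < 1e-12);
}
```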