Skip to content

Commit

Permalink
Added lambda parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Sep 22, 2023
1 parent 19ff9f6 commit 6cd4a14
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 2 deletions.
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -87,6 +87,7 @@ Set MSTL parameters [unreleased]
```rust
Mstl::params()
.iterations(2) // number of iterations
.lambda(0.5) // lambda for Box-Cox transformation
.stl_params(Stl::params()) // STL params
```

Expand Down
19 changes: 19 additions & 0 deletions src/mstl.rs
Expand Up @@ -63,6 +63,25 @@ mod tests {
assert_elements_in_delta(&[-1.835711, 1.4661198, -1.3784716, 3.319045, -1.3605475], &result.remainder()[..5]);
}

#[test]
fn test_lambda() {
let result = Mstl::params().lambda(0.5).fit(&generate_series(), &[6, 10]).unwrap();
assert_elements_in_delta(&[0.43371448, 0.10503793, -0.7178911, 1.2356076, -1.8253292], &result.seasonal()[0][..5]);
assert_elements_in_delta(&[1.0437742, 0.8650516, 0.07303603, -1.428663, -1.1990008], &result.seasonal()[1][..5]);
assert_elements_in_delta(&[2.0748303, 2.1291165, 2.1834028, 2.2330272, 2.2826517], &result.trend()[..5]);
assert_elements_in_delta(&[-1.0801829, 0.900794, -0.7101207, 1.9600279, -1.2583216], &result.remainder()[..5]);
}

#[test]
fn test_lambda_zero() {
let series: Vec<f32> = generate_series().iter().map(|&v| v + 1.0).collect();
let result = Mstl::params().lambda(0.0).fit(&series, &[6, 10]).unwrap();
assert_elements_in_delta(&[0.18727916, 0.029921893, -0.2716494, 0.47748315, -0.7320051], &result.seasonal()[0][..5]);
assert_elements_in_delta(&[0.42725056, 0.32145387, -0.019030934, -0.56607914, -0.46765903], &result.seasonal()[1][..5]);
assert_elements_in_delta(&[1.592807, 1.6144379, 1.6360688, 1.6559447, 1.6758206], &result.trend()[..5]);
assert_elements_in_delta(&[-0.41557717, 0.33677137, -0.24677622, 0.7352363, -0.47615635], &result.remainder()[..5]);
}

#[test]
fn test_empty_periods() {
let periods: Vec<usize> = Vec::new();
Expand Down
22 changes: 20 additions & 2 deletions src/mstl_params.rs
Expand Up @@ -6,13 +6,15 @@ use super::{Error, MstlResult, StlParams};

pub struct MstlParams {
iterate: usize,
lambda: Option<f32>,
stl_params: StlParams,
}

impl MstlParams {
pub fn new() -> Self {
Self {
iterate: 2,
lambda: None,
stl_params: StlParams::new(),
}
}
Expand All @@ -22,6 +24,11 @@ impl MstlParams {
self
}

pub fn lambda(&mut self, lambda: f32) -> &mut Self {
self.lambda = Some(lambda);
self
}

pub fn stl_params(&mut self, stl_params: StlParams) -> &mut Self {
self.stl_params = stl_params;
self
Expand Down Expand Up @@ -59,8 +66,11 @@ impl MstlParams {
let mut seasonality = Vec::with_capacity(seas_ids.len());
let mut trend = Vec::new();

// TODO add lambda param
let mut deseas = x.to_vec();
let mut deseas = if let Some(lambda) = self.lambda {
box_cox(x, lambda)
} else {
x.to_vec()
};

if !seas_ids.is_empty() {
for _ in 0..seas_ids.len() {
Expand Down Expand Up @@ -115,3 +125,11 @@ impl Default for MstlParams {
Self::new()
}
}

fn box_cox(y: &[f32], lambda: f32) -> Vec<f32> {
if lambda > 0.0 {
y.iter().map(|yi| (yi.powf(lambda) - 1.0) / lambda).collect()
} else {
y.iter().map(|yi| yi.ln()).collect()
}
}

0 comments on commit 6cd4a14

Please sign in to comment.