Skip to content

Commit

Permalink
Added support for unsorted periods
Browse files Browse the repository at this point in the history
  • Loading branch information
ankane committed Sep 22, 2023
1 parent 0ee2202 commit b122103
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
10 changes: 5 additions & 5 deletions src/lib.rs
Expand Up @@ -147,11 +147,11 @@ mod tests {

#[test]
fn test_mstl_unsorted_periods() {
let result = Mstl::fit(&generate_series(), &[10, 6]);
assert_eq!(
result.unwrap_err(),
Error::Parameter("periods must be sorted ascending".to_string())
);
let result = Mstl::fit(&generate_series(), &[10, 6]).unwrap();
assert_elements_in_delta(&[1.4130436, 1.6048906, 0.050958008, -1.8706754, -1.7704514], &result.seasonal()[0][..5]);
assert_elements_in_delta(&[0.28318232, 0.70529824, -1.980384, 2.1643379, -2.3356874], &result.seasonal()[1][..5]);
assert_elements_in_delta(&[5.139485, 5.223691, 5.3078976, 5.387292, 5.4666862], &result.trend()[..5]);
assert_elements_in_delta(&[-1.835711, 1.4661198, -1.3784716, 3.319045, -1.3605475], &result.remainder()[..5]);
}

#[test]
Expand Down
23 changes: 12 additions & 11 deletions src/params.rs
Expand Up @@ -226,16 +226,15 @@ impl MstlParams {
return Err(Error::Parameter("periods must not be empty".to_string()));
}

if !seas_ids.windows(2).all(|w| w[0] <= w[1]) {
return Err(Error::Parameter("periods must be sorted ascending".to_string()));
}

for np in seas_ids {
if k < np * 2 {
return Err(Error::Series("series has less than two periods".to_string()));
}
}

let mut indices: Vec<usize> = (0..seas_ids.len()).collect();
indices.sort_by_key(|&i| &seas_ids[i]);

let mut iterate = self.iterate;
if seas_ids.len() == 1 {
iterate = 1;
Expand All @@ -247,30 +246,32 @@ impl MstlParams {
// TODO add lambda param
let mut deseas = x.to_vec();

if seas_ids[0] > 1 {
if seas_ids[indices[0]] > 1 {
for _ in 0..seas_ids.len() {
seasonality.push(Vec::new());
}

for j in 0..iterate {
for (i, np) in seas_ids.iter().enumerate() {
for (i, idx) in indices.iter().enumerate() {
let np = seas_ids[*idx];

if j > 0 {
for (d, s) in deseas.iter_mut().zip(&seasonality[i]) {
for (d, s) in deseas.iter_mut().zip(&seasonality[*idx]) {
*d += s;
}
}

// TODO add seasonal_lengths param
let fit = if self.stl_params.ns.is_some() {
self.stl_params.fit(&deseas, *np)?
self.stl_params.fit(&deseas, np)?
} else {
let seasonal_length = 7 + 4 * (i + 1);
self.stl_params.clone().seasonal_length(seasonal_length).fit(&deseas, *np)?
self.stl_params.clone().seasonal_length(seasonal_length).fit(&deseas, np)?
};

(seasonality[i], trend, _, _) = fit.into_parts();
(seasonality[*idx], trend, _, _) = fit.into_parts();

for (d, s) in deseas.iter_mut().zip(&seasonality[i]) {
for (d, s) in deseas.iter_mut().zip(&seasonality[*idx]) {
*d -= s;
}
}
Expand Down

0 comments on commit b122103

Please sign in to comment.