Skip to content

Commit

Permalink
Merge pull request #221 from abstractqqq/rolling_lr
Browse files Browse the repository at this point in the history
  • Loading branch information
abstractqqq committed Aug 2, 2024
2 parents 747b69e + b25f114 commit 49eca96
Show file tree
Hide file tree
Showing 10 changed files with 758 additions and 545 deletions.
556 changes: 279 additions & 277 deletions examples/basics.ipynb

Large diffs are not rendered by default.

84 changes: 76 additions & 8 deletions python/polars_ds/num.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"query_principal_components",
"query_lstsq",
"query_recursive_lstsq",
"query_rolling_lstsq",
"query_lstsq_report",
"query_jaccard_row",
"query_jaccard_col",
Expand Down Expand Up @@ -330,9 +331,7 @@ def query_lstsq(


def query_recursive_lstsq(
*x: str | pl.Expr,
target: str | pl.Expr,
start_at: int,
*x: str | pl.Expr, target: str | pl.Expr, start_at: int, null_policy: NullPolicy = "raise"
):
"""
Using the first `start_at` rows of data as basis, start computing the least square solutions
Expand All @@ -355,22 +354,91 @@ def query_recursive_lstsq(
The target variable
start_at: int
Must be >= 1. Rows before start_at will be used as the first initial fit on the data.
null_policy: Literal['raise', 'skip', 'zero', 'one']
Currently, this only supports raise and any kind of direct fill strategy. `skip` doesn't work.
This won't fill target, and if target has null, an error will be thrown.
"""

if null_policy == "skip":
raise NotImplementedError

features = [str_to_expr(z) for z in x]
if start_at >= 1:
start = start_at
if start_at < len(features):
import warnings

warnings.warn(
f"Input `start_at` must be >= the number of features. It is reset to {len(features)}",
stacklevel=2,
)

start = max(start_at, len(features))
else:
raise ValueError("You must start at >=1 for recursive lstsq.")

recursive_lr_kwargs = {"null_policy": "raise", "n": start}

kwargs = {"null_policy": "raise", "n": start}
t = str_to_expr(target).cast(pl.Float64)
cols = [t]
cols.extend(str_to_expr(z) for z in x)
cols.extend(features)
return pl_plugin(
symbol="pl_recursive_lstsq",
args=cols,
kwargs=recursive_lr_kwargs,
kwargs=kwargs,
is_elementwise=True,
pass_name_to_apply=True,
)


def query_rolling_lstsq(
    *x: str | pl.Expr, target: str | pl.Expr, window_size: int, null_policy: NullPolicy = "raise"
):
    """
    Builds a rolling-window least squares expression over the given features and target.

    The actual computation is delegated to the `pl_rolling_lstsq` plugin. Per the original
    author's notes, the plugin solves a least squares problem on every window of `window_size`
    rows, also emits a prediction for each row, and relies on the Sherman-Morrison-Woodbury
    formula for the rolling updates.

    Note: Currently this requires all input data to have no nulls.

    Note: Recursive L2 regularized lstsq is on the roadmap and will come in later versions.

    Parameters
    ----------
    x : str | pl.Expr
        The variables used to predict target
    target : str | pl.Expr
        The target variable. It is cast to Float64 before being handed to the plugin.
    window_size: int
        Must be >= 1, otherwise a ValueError is raised. If it is smaller than the number of
        features, a warning is emitted and the number of features is used as the window size.
    null_policy: Literal['raise', 'skip', 'zero', 'one']
        `skip` is rejected with NotImplementedError. The target is never filled.
        NOTE(review): the value is not forwarded to the plugin; "raise" is always sent.
        Confirm whether the fill strategies are meant to be supported here.
    """
    if null_policy == "skip":
        raise NotImplementedError

    features = [str_to_expr(z) for z in x]
    n_features = len(features)

    # Guard clause: reject non-positive windows before doing anything else.
    if not (window_size >= 1):
        raise ValueError("You must window_size >=1 for rolling lstsq.")

    if window_size < n_features:
        import warnings

        warnings.warn(
            f"Input `window_size` must be >= the number of features. It is reset to {n_features}",
            stacklevel=2,
        )

    # The window can never be smaller than the number of features.
    effective_window = max(window_size, n_features)

    # Plugin argument layout: target first, then every feature in the order given.
    plugin_args = [str_to_expr(target).cast(pl.Float64), *features]
    return pl_plugin(
        symbol="pl_rolling_lstsq",
        args=plugin_args,
        kwargs={"null_policy": "raise", "n": effective_window},
        is_elementwise=True,
        pass_name_to_apply=True,
    )
Expand Down
33 changes: 10 additions & 23 deletions src/arkadia/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,7 @@ pub trait KDTQ<'a, T: Float + 'static, A> {
radius: T,
);

fn within_count_one_step(
&self,
pending: &mut Vec<(T, &Self)>,
point: &[T],
radius: T,
) -> u32;
fn within_count_one_step(&self, pending: &mut Vec<(T, &Self)>, point: &[T], radius: T) -> u32;

fn knn(&self, k: usize, point: &[T], epsilon: T) -> Option<Vec<NB<T, A>>> {
if k == 0 || (point.len() != self.dim()) || (point.iter().any(|x| !x.is_finite())) {
Expand All @@ -83,20 +78,19 @@ pub trait KDTQ<'a, T: Float + 'static, A> {
let mut pending = Vec::with_capacity(k + 1);
pending.push((T::min_value(), self));
while !pending.is_empty() {
self.knn_one_step(
&mut pending,
&mut top_k,
k,
point,
T::max_value(),
epsilon,
);
self.knn_one_step(&mut pending, &mut top_k, k, point, T::max_value(), epsilon);
}
Some(top_k)
}
}

fn knn_bounded(&self, k: usize, point: &[T], max_dist_bound: T, epsilon:T) -> Option<Vec<NB<T, A>>> {
fn knn_bounded(
&self,
k: usize,
point: &[T],
max_dist_bound: T,
epsilon: T,
) -> Option<Vec<NB<T, A>>> {
if k == 0
|| (point.len() != self.dim())
|| (point.iter().any(|x| !x.is_finite()))
Expand All @@ -109,14 +103,7 @@ pub trait KDTQ<'a, T: Float + 'static, A> {
let mut pending = Vec::with_capacity(k + 1);
pending.push((T::min_value(), self));
while !pending.is_empty() {
self.knn_one_step(
&mut pending,
&mut top_k,
k,
point,
max_dist_bound,
epsilon,
);
self.knn_one_step(&mut pending, &mut top_k, k, point, max_dist_bound, epsilon);
}
Some(top_k)
}
Expand Down
4 changes: 1 addition & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,11 @@ mod utils;

use faer_ext::{IntoFaer, IntoNdarray};
use numpy::{Ix1, Ix2, PyArray, PyReadonlyArray2, ToPyArray};
use pyo3::{types::PyModule, pymodule, Bound, PyResult, Python};

use pyo3::{pymodule, types::PyModule, Bound, PyResult, Python};

#[pymodule]
#[pyo3(name = "_polars_ds")]
fn _polars_ds(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {

// How do I factor out this? I don't want to put all code here.
#[pyfn(m)]
#[pyo3(name = "pds_faer_lr")]
Expand Down
Loading

0 comments on commit 49eca96

Please sign in to comment.