Merge pull request #222 from abstractqqq/knn_regression
Knn regression
abstractqqq committed Aug 3, 2024
2 parents 49eca96 + 56843a9 commit 1fcab73
Showing 10 changed files with 599 additions and 332 deletions.
392 changes: 196 additions & 196 deletions examples/basics.ipynb

Large diffs are not rendered by default.

99 changes: 85 additions & 14 deletions python/polars_ds/knn_queries.py
@@ -4,12 +4,13 @@

from __future__ import annotations
import polars as pl
from typing import Iterable
from typing import Iterable, List
from .type_alias import StrOrExpr, str_to_expr, Distance
from ._utils import pl_plugin

__all__ = [
"query_knn_ptwise",
"query_knn_avg",
"is_knn_from",
"within_dist_from",
"query_radius_ptwise",
@@ -34,14 +35,14 @@ def query_knn_ptwise(
to each row. By default, this will return k + 1 neighbors, because the point (the row) itself
is a neighbor to itself and this returns k additional neighbors. The only exception to this
is when data_mask excludes the point from being a neighbor, in which case, k + 1 distinct neighbors will
be returned.
be returned. Any row with a null/NaN will never be a neighbor and will have null as its neighbor.
Note that the index column must be convertible to u32. If you do not have a u32 column,
you can generate one using pl.int_range(..), which should be a step before this. The index column
must not contain nulls.
Note that a default max distance bound of 99999.0 is applied. This means that if we cannot find
k-neighbors within `max_bound`, then there will be < k neighbors returned.
k neighbors within `max_bound`, then there will be < k neighbors returned.
Also note that this internally builds a kd-tree for fast querying and deallocates it once we
are done. If you need to repeatedly run the same query on the same data, then it is not
@@ -81,25 +82,25 @@ def query_knn_ptwise(

idx = str_to_expr(index).cast(pl.UInt32).rechunk()
cols = [idx]
if eval_mask is None:
skip_eval = False
else:
skip_eval = True
cols.append(str_to_expr(eval_mask))
feats: List[pl.Expr] = [str_to_expr(e) for e in features]

if data_mask is None:
skip_data = False
skip_data = data_mask is not None
if skip_data:
keep_mask = pl.all_horizontal(str_to_expr(data_mask), *(f.is_not_null() for f in feats))
else:
skip_data = True
cols.append(str_to_expr(data_mask))
keep_mask = pl.all_horizontal(f.is_not_null() for f in feats)

cols.extend(str_to_expr(x) for x in features)
cols.append(keep_mask)
skip_eval = eval_mask is not None
if skip_eval:
cols.append(str_to_expr(eval_mask))

cols.extend(feats)
kwargs = {
"k": k,
"metric": str(dist).lower(),
"parallel": parallel,
"skip_eval": skip_eval,
"skip_data": skip_data,
"max_bound": max_bound,
"epsilon": abs(epsilon),
}
@@ -119,6 +120,76 @@
)
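
For a sense of how the reworked query_knn_ptwise is called, here is a minimal usage sketch. The import path, column names, and output alias are illustrative assumptions, not part of this diff:

import polars as pl
from polars_ds import query_knn_ptwise  # import path assumed for illustration

df = pl.DataFrame({
    "x1": [0.1, 0.2, 0.5, 0.9],
    "x2": [1.0, 1.1, 0.4, 0.3],
})

out = df.with_columns(
    # The index column must be u32 and contain no nulls, e.g. generated with pl.int_range.
    pl.int_range(pl.len(), dtype=pl.UInt32).alias("id")
).with_columns(
    query_knn_ptwise(
        "x1", "x2",    # feature columns
        index="id",    # u32 row index
        k=2,           # 2 additional neighbors besides the point itself
        dist="sql2",   # squared l2
        parallel=True,
    ).alias("knn_ids")
)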


def query_knn_avg(
*features: StrOrExpr,
target: StrOrExpr,
k: int,
dist: Distance = "sql2",
weighted: bool = False,
parallel: bool = False,
min_bound: float = 1e-9,
max_bound: float = 99999.0,
) -> pl.Expr:
"""
Takes the target column, and uses feature columns to determine the k nearest neighbors
to each row. By default, this will return k + 1 neighbors, because the point (the row) itself
is a neighbor to itself and this returns k additional neighbors. Any row with a null/NaN will
never be a neighbor and will get null as the average.
Note that a default max distance bound of 99999.0 is applied. This means that if we cannot find
k neighbors within `max_bound`, then there will be < k neighbors returned.
This is also known as KNN Regression, but really it is just the average of the K nearest neighbors.
Parameters
----------
*features : str | pl.Expr
Other columns used as features
target : str | pl.Expr
Float, must be castable to f64. This should not contain null.
k : int
Number of neighbors to query
dist : Literal[`l1`, `l2`, `sql2`, `inf`, `cosine`]
Note `sql2` stands for squared l2.
weighted : bool
If weighted, it will use 1/distance as weights to compute the KNN average. If min_bound is
an extremely small value, this will default to 1/(1+distance) as weights to avoid division by 0.
parallel : bool
Whether to run the k-nearest neighbor query in parallel. This is recommended when you
are running only this expression, and not in group_by context.
min_bound
Min distance (>=) for a neighbor to be part of the average calculation. This prevents "identical"
points from being part of the average and prevents division by 0. Note that this filter is applied
after getting k nearest neighbors.
max_bound
Max distance the neighbors must be within (<)
"""
if k < 1:
raise ValueError("Input `k` must be >= 1.")

idx = str_to_expr(target).cast(pl.Float64).rechunk()
feats = [str_to_expr(f) for f in features]
keep_data = ~pl.any_horizontal(f.is_null() for f in feats)
cols = [idx, keep_data]
cols.extend(feats)

kwargs = {
"k": k,
"metric": str(dist).lower(),
"weighted": weighted,
"parallel": parallel,
"min_bound": min_bound,
"max_bound": max_bound,
}

return pl_plugin(
symbol="pl_knn_avg",
args=cols,
kwargs=kwargs,
is_elementwise=True,
)
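
A similarly hedged sketch of calling the new query_knn_avg; again, the import path, column names, and alias are made up for illustration:

import polars as pl
from polars_ds import query_knn_avg  # assumed export, mirroring query_knn_ptwise

df = pl.DataFrame({
    "x1": [0.1, 0.2, 0.5, 0.9],
    "x2": [1.0, 1.1, 0.4, 0.3],
    "y": [2.0, 2.1, 3.5, 4.0],
})

out = df.with_columns(
    query_knn_avg(
        "x1", "x2",      # feature columns
        target="y",      # must be castable to f64
        k=2,
        dist="sql2",
        weighted=True,   # 1/distance weights; falls back to 1/(1+distance) when min_bound is ~0
        min_bound=1e-9,  # exclude "identical" points and avoid division by zero
    ).alias("y_knn_avg")
)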


def within_dist_from(
*features: StrOrExpr,
pt: Iterable[float],
9 changes: 9 additions & 0 deletions src/arkadia/arkadia_any.rs
@@ -1,7 +1,11 @@
use std::fmt::Debug;

/// A Kdtree
use crate::arkadia::{leaf::KdLeaf, suggest_capacity, Leaf, SplitMethod, KDTQ, NB};
use num::Float;

use super::KNNRegressor;

#[derive(Clone, PartialEq, Eq)]
pub enum DIST<T: Float + 'static> {
L1,
@@ -446,6 +450,11 @@ impl<'a, T: Float + 'static + std::fmt::Debug, A: Copy> KDTQ<'a, T, A> for AnyKD
}
}

impl<'a, T: Float + 'static + std::fmt::Debug + Into<f64>, A: Float + Into<f64>>
KNNRegressor<'a, T, A> for AnyKDT<'a, T, A>
{
}

#[cfg(test)]
mod tests {
use super::super::matrix_to_leaves;
116 changes: 75 additions & 41 deletions src/arkadia/mod.rs
@@ -20,6 +20,7 @@ pub mod utils;
pub use arkadia_any::{AnyKDT, DIST};
pub use leaf::{KdLeaf, Leaf};
pub use neighbor::NB;
use serde::Deserialize;
pub use utils::{
matrix_to_empty_leaves, matrix_to_leaves, matrix_to_leaves_w_row_num, suggest_capacity,
SplitMethod,
@@ -28,19 +29,24 @@ pub use utils::{
// ---------------------------------------------------------------------------------------------------------
use num::Float;

#[derive(Clone, Default)]
#[derive(Clone, Copy, Default, Deserialize)]
pub enum KNNMethod {
DInvW, // Distance Inversion Weighted. E.g. Use (1/(1+d)) to weight the regression / classification
P1Weighted, // Distance Inversion Weighted. E.g. Use (1/(1+d)) to weight the regression / classification
Weighted, // Distance Inversion Weighted. E.g. Use (1/d) to weight the regression / classification
#[default]
NoW, // No Weight
NotWeighted, // No Weight
}

impl From<bool> for KNNMethod {
fn from(weighted: bool) -> Self {
impl KNNMethod {
pub fn new(weighted: bool, min_dist: f64) -> Self {
if weighted {
KNNMethod::DInvW
if min_dist <= f64::epsilon() {
Self::P1Weighted
} else {
Self::Weighted
}
} else {
KNNMethod::NoW
Self::NotWeighted
}
}
}
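
To illustrate the rule encoded by KNNMethod::new above, here is a small Python sketch of the same decision with made-up distances; it is illustrative only and not part of the Rust crate:

import sys

def knn_weights(dists, weighted, min_dist):
    # Mirrors KNNMethod::new: NotWeighted, P1Weighted (1/(1+d)), or Weighted (1/d).
    if not weighted:
        return [1.0 for _ in dists]               # NotWeighted -> plain average
    if min_dist <= sys.float_info.epsilon:
        return [1.0 / (1.0 + d) for d in dists]   # P1Weighted: safe even when d == 0
    return [1.0 / d for d in dists]               # Weighted: fine since d >= min_dist > 0

# With the default min_bound of 1e-9 (> machine epsilon), weighted queries use 1/d.
print(knn_weights([0.5, 1.0, 2.0], weighted=True, min_dist=1e-9))  # [2.0, 1.0, 0.5]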
@@ -94,7 +100,7 @@ pub trait KDTQ<'a, T: Float + 'static, A> {
if k == 0
|| (point.len() != self.dim())
|| (point.iter().any(|x| !x.is_finite()))
|| max_dist_bound <= T::zero() + T::epsilon()
|| max_dist_bound <= T::epsilon()
{
None
} else {
@@ -161,31 +167,69 @@ pub trait KDTQ<'a, T: Float + 'static, A> {
pub trait KNNRegressor<'a, T: Float + Into<f64> + 'static, A: Float + Into<f64>>:
KDTQ<'a, T, A>
{
fn knn_regress(&self, k: usize, point: &[T], max_dist_bound: T, how: KNNMethod) -> Option<f64> {
let knn = self.knn_bounded(k, point, max_dist_bound, T::zero());
fn knn_regress(
&self,
k: usize,
point: &[T],
min_dist_bound: T,
max_dist_bound: T,
how: KNNMethod,
) -> Option<f64> {
let knn = self
.knn_bounded(k, point, max_dist_bound, T::zero())
.map(|nn| {
nn.into_iter()
.filter(|nb| nb.dist >= min_dist_bound)
.collect::<Vec<_>>()
});
match knn {
Some(nn) => match how {
KNNMethod::DInvW => {
let weights = nn
.iter()
.map(|nb| (nb.dist + T::one()).recip().into())
.collect::<Vec<f64>>();
let sum = weights.iter().copied().sum::<f64>();
Some(
nn.into_iter()
.zip(weights.into_iter())
.fold(0f64, |acc, (nb, w)| acc + w * nb.to_item().into())
/ sum,
)
KNNMethod::P1Weighted => {
if nn.is_empty() {
None
} else {
let weights = nn
.iter()
.map(|nb| (T::one() + nb.dist).recip().into())
.collect::<Vec<f64>>();
let sum = weights.iter().copied().sum::<f64>();
Some(
nn.into_iter()
.zip(weights.into_iter())
.fold(0f64, |acc, (nb, w)| acc + w * nb.to_item().into())
/ sum,
)
}
}
KNNMethod::Weighted => {
if nn.is_empty() {
None
} else {
let weights = nn
.iter()
.map(|nb| nb.dist.recip().into())
.collect::<Vec<f64>>();
let sum = weights.iter().copied().sum::<f64>();
Some(
nn.into_iter()
.zip(weights.into_iter())
.fold(0f64, |acc, (nb, w)| acc + w * nb.to_item().into())
/ sum,
)
}
}
KNNMethod::NoW => {
let n = nn.len() as f64;
Some(
nn.into_iter()
.fold(A::zero(), |acc, nb| acc + nb.to_item())
.into()
/ n,
)
KNNMethod::NotWeighted => {
if nn.is_empty() {
None
} else {
let n = nn.len() as f64;
Some(
nn.into_iter()
.fold(A::zero(), |acc, nb| acc + nb.to_item())
.into()
/ n,
)
}
}
},
None => None,
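
In formula form, the estimate knn_regress produces for a query point x is a (possibly weighted) mean over the set N(x) of neighbors that survive the distance filters; in LaTeX:

\hat{y}(x) = \frac{\sum_{i \in N(x)} w_i \, y_i}{\sum_{i \in N(x)} w_i},
\qquad
w_i =
\begin{cases}
1 & \text{NotWeighted} \\
1 / (1 + d_i) & \text{P1Weighted} \\
1 / d_i & \text{Weighted}
\end{cases}

Here N(x) holds at most k neighbors with d_i >= min_dist_bound (the max bound is enforced inside knn_bounded), and an empty N(x) yields None.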
@@ -196,17 +240,7 @@ pub trait KNNRegressor<'a, T: Float + Into<f64> + 'static, A: Float + Into<f64>>
pub trait KNNClassifier<'a, T: Float + 'static>: KDTQ<'a, T, u32> {
fn knn_classif(&self, k: usize, point: &[T], max_dist_bound: T, how: KNNMethod) -> Option<u32> {
let knn = self.knn_bounded(k, point, max_dist_bound, T::zero());
match knn {
Some(nn) => match how {
KNNMethod::DInvW => {
todo!()
}
KNNMethod::NoW => {
todo!()
}
},
None => None,
}
todo!()
}
}

22 changes: 10 additions & 12 deletions src/linalg/lstsq.rs
@@ -184,8 +184,8 @@ pub fn faer_rolling_lstsq(x: MatRef<f64>, y: MatRef<f64>, n: usize) -> Vec<Mat<f
let mut weights = &inv * x0t * y0;
coefficients.push(weights.to_owned());
for j in n..xn {
let remove_x = x.get(j-n..j-n+1, ..);
let remove_y = y.get(j-n..j-n+1, ..);
let remove_x = x.get(j - n..j - n + 1, ..);
let remove_y = y.get(j - n..j - n + 1, ..);
woodbury_step(inv.as_mut(), weights.as_mut(), remove_x, remove_y, -1.0);

let next_x = x.get(j..j + 1, ..); // 1 by m, m = # of columns
@@ -201,18 +201,17 @@
/// https://en.wikipedia.org/wiki/Woodbury_matrix_identity
#[inline(always)]
fn woodbury_step(
    inverse: MatMut<f64>,
    weights: MatMut<f64>,
    new_x: MatRef<f64>,
new_y: MatRef<f64>,
c: f64 // Should be +1 or -1, for a "update" and a "removal"
    c: f64, // Should be +1 or -1, for an "update" and a "removal"
) {

    // It is truly amazing that the C in the Woodbury identity essentially controls the update
    // and removal of a new record (rolling)... Linear regression seems to be designed by God to work so well

let left = &inverse * new_x.transpose(); // corresponding to u in the reference
    // right = left.transpose() by the fact that if A is symmetric, invertible, A-1 is also symmetric
let z = (c + (new_x * &left).read(0, 0)).recip();
// Update the inverse
faer::linalg::matmul::matmul(
@@ -225,7 +224,7 @@ fn woodbury_step(
); // inv is updated

    // Difference from estimate using prior weights vs. actual next y
    let y_diff = new_y - (new_x * &weights);
// Update weights
faer::linalg::matmul::matmul(
weights,
Expand All @@ -235,5 +234,4 @@ fn woodbury_step(
z,
faer::Parallelism::Rayon(0), //
); // weights are updated

}
}
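
For reference, writing P = (X^T X)^{-1} for the running inverse, beta for the weights, and x, y for the 1-by-m row being added (c = +1) or removed (c = -1), the rank-one Sherman-Morrison/Woodbury step that woodbury_step performs can be written in LaTeX as:

z = \frac{1}{c + x P x^{\top}}, \qquad
P \leftarrow P - z \, P x^{\top} x P, \qquad
\beta \leftarrow \beta + z \, P x^{\top} (y - x \beta)

The scalar z is exactly (c + (new_x * &left).read(0, 0)).recip() above, and y - x\beta is the y_diff fed into the second matmul. The loop in faer_rolling_lstsq uses the same helper twice per window position, once with c = -1 to drop the oldest row and once with c = +1 to add the newest one.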