Skip to content

Commit

Permalink
Merge pull request #189 from abstractqqq/mann_whitney
Browse files Browse the repository at this point in the history
  • Loading branch information
abstractqqq committed Jun 23, 2024
2 parents 7b2fcbc + 29935d8 commit 6a0ce2e
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 68 deletions.
59 changes: 56 additions & 3 deletions python/polars_ds/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"query_ttest_ind_from_stats",
"query_ks_2samp",
"query_f_test",
"query_mann_whitney_u",
"query_chi2",
"query_first_digit_cnt",
"query_c3_stats",
Expand Down Expand Up @@ -610,11 +611,12 @@ def query_ks_2samp(
sanitize data (only non-null finite values are used) before doing the computation. If
is_binary is true, it will compare the statistics by comparing var2(var1=0) and var2(var1=1).
Note, this returns a stastics and a threshold value. The threshold value is not the p-value, but
Note, this returns a statistic and a threshold value. The threshold is not the p-value, but
rather it is used in the following way: if the statistic is > the threshold value, then the null
hypothesis should be rejected. This is suitable only for large sameple sizes. See the reference.
hypothesis should be rejected. This is suitable only for large sample sizes. See more details
in the reference.
If either var1 or var2 has less than 30 values, a ks stats of INFINITY will be returned.
If either var1 or var2 has fewer than 20 values, a KS statistic of INFINITY will be returned.
Parameters
----------
Expand Down Expand Up @@ -740,6 +742,57 @@ def query_cid_ce(x: StrOrExpr, normalize: bool = False) -> pl.Expr:
return z.dot(z).sqrt()


def query_mann_whitney_u(
    var1: StrOrExpr,
    var2: StrOrExpr,
    alternative: Alternative = "two-sided",
) -> pl.Expr:
    """
    Computes the Mann-Whitney U statistic and the p-value. Note: this function will sanitize data
    (only finite values are kept) before computing the statistic. This implementation follows
    method 2 in the reference.

    The returned statistic is U1, the U statistic of `var1`. The p-value is computed by the
    `pl_mann_whitney_u` plugin using a normal approximation with a continuity correction. Tie
    correction is always applied, which may slow down the computation a little.

    Parameters
    ----------
    var1 : str | pl.Expr
        Either the name of the column or a Polars expression
    var2 : str | pl.Expr
        Either the name of the column or a Polars expression
    alternative : str
        The alternative for the test. `two-sided`, `greater` or `less`

    Reference
    ---------
    https://en.wikipedia.org/wiki/Mann%E2%80%93Whitney_U_test
    """
    x = str_to_expr(var1)
    y = str_to_expr(var2)
    xx = x.filter(x.is_finite())
    yy = y.filter(y.is_finite())
    n1 = xx.len().cast(pl.Float64)
    n2 = yy.len().cast(pl.Float64)

    # Rank the pooled sample; the first n1 ranks belong to var1 because xx is appended first.
    ranks = (xx.append(yy)).rank()

    u1 = ranks.slice(0, length=n1).sum() - (n1 * (n1 + 1)) / 2
    u2 = (n1 * n2) - u1
    # Mean of U under the null hypothesis.
    mean = (n1 * n2) / 2

    # The sorted ranks are passed along so that the plugin can count runs of tied ranks
    # and derive the tie-corrected standard deviation on the Rust side.
    return pl_plugin(
        symbol="pl_mann_whitney_u",
        args=[u1, u2, mean, ranks.sort(), pl.lit(alternative, dtype=pl.String)],
    )


def winsorize(
x: StrOrExpr,
lower: float = 0.05,
Expand Down
15 changes: 9 additions & 6 deletions src/num/tp_fp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ fn combo_output(_: &[Field]) -> PolarsResult<Field> {
Ok(Field::new("", DataType::Struct(v)))
}

fn tp_fp_frame(predicted: &Series, actual: &Series, positive_count:u32, as_ratio: bool) -> PolarsResult<LazyFrame> {
fn tp_fp_frame(
predicted: &Series,
actual: &Series,
positive_count: u32,
as_ratio: bool,
) -> PolarsResult<LazyFrame> {
// Checking for data quality issues
if (actual.len() != predicted.len())
|| actual.is_empty()
Expand Down Expand Up @@ -84,7 +89,7 @@ fn pl_combo_b(inputs: &[Series]) -> PolarsResult<Series> {

let positive_count = actual.sum::<u32>().unwrap_or(0);
if positive_count == 0 {
return Ok(Series::from_iter([f64::NAN]))
return Ok(Series::from_iter([f64::NAN]));
}

let mut binding = tp_fp_frame(predicted, actual, positive_count, true)?.collect()?;
Expand Down Expand Up @@ -147,7 +152,6 @@ fn binary_confusion_matrix(combined_series: &UInt32Chunked) -> [u32; 4] {
output
}


#[polars_expr(output_type=Float64)]
fn pl_roc_auc(inputs: &[Series]) -> PolarsResult<Series> {
// actual, when passed in, is always u32 (done in Python extension side)
Expand All @@ -156,7 +160,7 @@ fn pl_roc_auc(inputs: &[Series]) -> PolarsResult<Series> {

let positive_count = actual.sum::<u32>().unwrap_or(0);
if positive_count == 0 {
return Ok(Series::from_iter([f64::NAN]))
return Ok(Series::from_iter([f64::NAN]));
}

let mut binding = tp_fp_frame(predicted, actual, positive_count, true)?
Expand All @@ -179,7 +183,6 @@ fn pl_roc_auc(inputs: &[Series]) -> PolarsResult<Series> {
Ok(Series::from_iter([out]))
}


// bcm = binary confusion matrix
fn bcm_output(_: &[Field]) -> PolarsResult<Field> {
let tp = Field::new("tp", DataType::UInt32);
Expand Down Expand Up @@ -303,4 +306,4 @@ fn pl_binary_confusion_matrix(inputs: &[Series]) -> PolarsResult<Series> {

let out = result.into_struct("confusion_matrix");
Ok(out.into_series())
}
}
2 changes: 1 addition & 1 deletion src/stats/ks.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::simple_stats_output;
/// KS statistics.
use super::simple_stats_output;
use crate::stats::StatsResult;
use polars::prelude::*;
use pyo3_polars::derive::polars_expr;
Expand Down
65 changes: 65 additions & 0 deletions src/stats/mann_whitney_u.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/// Mann-Whitney U Statistics
use super::{simple_stats_output, Alternative};
use crate::stats_utils::{is_zero, normal};
use polars::prelude::*;
use pyo3_polars::derive::polars_expr;

/// Computes the tie-correction sum `Σ (t³ − t)` over runs of equal values in `ranks`.
///
/// `ranks` is expected to be sorted so that tied ranks are adjacent, and to contain no
/// nulls (it is iterated with `into_no_null_iter`). Each maximal run of `t` equal values
/// contributes `t * (t + 1) * (t - 1)`; runs of length one contribute zero, so the result
/// is `0.0` when there are no ties or when the input is empty.
fn mann_whitney_tie_sum(ranks: &Float64Chunked) -> f64 {
    // Contribution of one run of `t` tied values: t^3 - t.
    let tie_term = |t: f64| t * (t + 1.0) * (t - 1.0);

    // NaN never compares equal to anything, so the first value always opens a new run.
    let mut current_rank = f64::NAN;
    let mut run_len = 0f64;
    let mut total = 0f64;
    for v in ranks.into_no_null_iter() {
        if v == current_rank {
            run_len += 1.0;
        } else {
            // A new rank value closes the previous run.
            total += tie_term(run_len);
            current_rank = v;
            run_len = 1.0;
        }
    }
    // Flush the final run. Without this, ties at the largest rank are silently dropped.
    total + tie_term(run_len)
}

#[polars_expr(output_type_func=simple_stats_output)]
fn pl_mann_whitney_u(inputs: &[Series]) -> PolarsResult<Series> {
// Reference: https://github.com/scipy/scipy/blob/v1.13.1/scipy/stats/_mannwhitneyu.py#L177

let u1 = inputs[0].f64().unwrap();
let u1 = u1.get(0).unwrap();

let u2 = inputs[1].f64().unwrap();
let u2 = u2.get(0).unwrap();

let mean = inputs[2].f64().unwrap();
let mean = mean.get(0).unwrap();

// Custom RLE
let sorted_ranks = inputs[3].f64().unwrap();
let n = sorted_ranks.len() as f64;
let tie_term_sum = mann_whitney_tie_sum(sorted_ranks);
let std_ = ((mean / 6.0) * ((n + 1.0) - tie_term_sum / (n * (n - 1.0)))).sqrt();

let alt = inputs[4].str()?;
let alt = alt.get(0).unwrap();
let alt = Alternative::from(alt);

let (u, factor) = match alt {
// if I use min here, always wrong p value. But wikipedia says it is min. I wonder wtf..
Alternative::TwoSided => (u1.max(u2), 2.0),
Alternative::Less => (u2, 1.0),
Alternative::Greater => (u1, 1.0),
};

let p = if is_zero(std_) {
0.
} else {
// -0.5 is some continuity adjustment. See Scipy's impl
(factor * normal::sf_unchecked(u, mean + 0.5, std_)).clamp(0., 1.)
};

let s = Series::from_vec("statistic", vec![u1]);
let p = Series::from_vec("pvalue", vec![p]);
let out = StructChunked::new("", &[s, p])?;
Ok(out.into_series())
}
1 change: 1 addition & 0 deletions src/stats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ mod chi2;
mod fstats;
mod kendall_tau;
mod ks;
mod mann_whitney_u;
mod normal_test;
mod sample;
mod t_test;
Expand Down
61 changes: 3 additions & 58 deletions tests/test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
"metadata": {},
"outputs": [],
"source": [
"df = pds.random_data(size=2000, n_cols=0).select(\n",
"df = pds.random_data(size=500_000, n_cols=0).select(\n",
" pds.random(0.0, 1.0).alias(\"x1\"),\n",
" pds.random(0.0, 1.0).alias(\"x2\"),\n",
" pds.random(0.0, 1.0).alias(\"x3\"),\n",
" pds.random_int(0, 3).cast(pl.String).alias(\"str\"),\n",
" pds.random_int(0, 2).alias(\"target\"),\n",
" pds.random_int(0, 100).alias(\"test\"),\n",
")\n",
"df"
]
Expand All @@ -33,65 +33,10 @@
"outputs": [],
"source": [
"df.select(\n",
" pl.col(\"x1\").cut(breaks=[0.2, 0.4], left_closed=True, include_breaks=True)\n",
" pds.query_mann_whitney_u(\"x1\", \"x2\")\n",
").unnest(\"x1\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"(\n",
" df\n",
" .select(pds.query_lstsq(*[\"x1\", \"x2\"], target=\"x3\", add_bias=False, skip_null=True))\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = pl.DataFrame({\n",
" \"a\": [1,2,3,4,5, None],\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df.select(\n",
" pds.query_longest_streak_below(\"a\", 2).alias(\"<=2\"),\n",
" pds.query_longest_streak_below(\"a\", 6).alias(\"<=6\"),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df = pl.DataFrame({\"actual\": [1, 0, 1], \"pred\": [0.4, 0.6, 0.9]})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"df.select(\n",
" pds.query_confusion_matrix(\"actual\", \"pred\").alias(\"metrics\")\n",
").unnest(\"metrics\")"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
25 changes: 25 additions & 0 deletions tests/test_many.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,31 @@ def test_f_test(df):
assert np.isclose(pvalue, scikit_p)


@pytest.mark.parametrize(
    "df",
    [
        pl.DataFrame(
            {
                "x1": np.random.normal(size=1_000),
                "x2": np.random.normal(size=1_000),
            }
        ),
    ],
)
def test_mann_whitney_u(df):
    """Checks the statistic and p-value of query_mann_whitney_u against scipy's mannwhitneyu."""
    from scipy.stats import mannwhitneyu

    # The single struct cell comes back as a dictionary with `statistic` and `pvalue`.
    ours = df.select(pds.query_mann_whitney_u("x1", "x2")).item(0, 0)
    our_statistic, our_pvalue = ours["statistic"], ours["pvalue"]

    expected = mannwhitneyu(df["x1"].to_numpy(), df["x2"].to_numpy())
    assert np.isclose(our_statistic, expected.statistic)
    assert np.isclose(our_pvalue, expected.pvalue)


@pytest.mark.parametrize(
"df, res",
[
Expand Down

0 comments on commit 6a0ce2e

Please sign in to comment.