Skip to content

Commit

Permalink
test: add test for impure function correlation behavior
Browse files Browse the repository at this point in the history
Need to fix the UDF test case.

Related to ibis-project#8921,
trying to write down exactly what
the expected behavior is.
  • Loading branch information
NickCrews committed Apr 19, 2024
1 parent 88de65d commit 21b180c
Showing 1 changed file with 83 additions and 0 deletions.
83 changes: 83 additions & 0 deletions ibis/backends/tests/test_impure.py
@@ -0,0 +1,83 @@
from __future__ import annotations

import random

import pandas.testing as tm
import pytest

import ibis
from ibis import _


@ibis.udf.scalar.python(side_effects=True)
def my_random(x: float) -> float:
return random.random()


mark_impures = pytest.mark.parametrize(
"impure",
[
pytest.param(
lambda _: ibis.random(),
id="random",
),
pytest.param(
lambda _: ibis.uuid().cast(str).contains("a").cast(float),
id="uuid",
),
pytest.param(
lambda table: my_random(table.float_col),
id="udf",
),
],
)


@mark_impures
def test_impure_correlated(alltypes, impure):
df = (
alltypes.select(common=impure(alltypes))
.select(x=_.common, y=_.common)
.execute()
)
tm.assert_series_equal(df.x, df.y, check_names=False)


@mark_impures
def test_chained_selections(alltypes, impure):
# https://github.com/ibis-project/ibis/issues/8921#issue-2234327722
t = alltypes.mutate(num=impure(alltypes))
t = t.mutate(isbig=(t.num > 0.5))
df = t.select("num", "isbig").execute()
df["expected"] = df.num > 0.5
tm.assert_series_equal(df.isbig, df.expected, check_names=False)


@pytest.mark.parametrize(
"impure",
[
pytest.param(
lambda _: ibis.random(),
id="random",
),
pytest.param(
# make this a float so we can compare to .5
lambda _: ibis.uuid().cast(str).contains("a").cast(float),
id="uuid",
),
pytest.param(
lambda table: my_random(table.float_col),
id="udf",
# once this is fixed, can we unify these params with the params below?
marks=pytest.mark.xfail(reason="executed only once"),
),
],
)
def test_impure_uncorrelated(alltypes, impure):
df = alltypes.select(x=impure(alltypes), y=impure(alltypes)).execute()
assert (df.x == df.y).mean() < 1
# Even if the two expressions have the exact same ID, they should still be
# uncorrelated
common = impure(alltypes)
df = alltypes.select(x=common, y=common).execute()
assert (df.x == df.y).mean() < 1

0 comments on commit 21b180c

Please sign in to comment.