Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MNT, BUG] Modify njit func generator to support python functions #1044

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CODEOWNERS
Validating CODEOWNERS rules …
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ aeon/transformations/time_since.py @KishManani
aeon/similarity_search/ @baraline

aeon/utils/mlflow_aeon.py @benjaminbluhm
aeon/utils/numba/ @baraline

docs/get_involved/code_of_conduct.rst @aeon/aeon-code-of-conduct-committee
docs/get_involved/governance.rst @aeon/aeon-core-developers
19 changes: 10 additions & 9 deletions aeon/similarity_search/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ class BaseSimiliaritySearch(BaseEstimator, ABC):
distance : str, default="euclidean"
Name of the distance function to use. A list of valid strings can be found in
the documentation for :func:`aeon.distances.get_distance_function`.
If a callable is passed it must be a numba function with nopython=True, that
takes two 1d numpy arrays as input and returns a float.
If a callable is passed it must either be a python function or numba function
with nopython=True, that takes two 1d numpy arrays as input and returns a float.
distance_args : dict, default=None
Optional keyword arguments for the distance function.
normalize : bool, default=False
Expand Down Expand Up @@ -299,8 +299,8 @@ def _get_distance_profile_function(self):
------
ValueError
If the distance parameter given at initialization is not a string nor a
numba function, or if the speedup parameter is unknow or unsupported, raise
a ValueError..
numba function or a callable, or if the speedup parameter is unknow or
unsupported, raisea ValueError.

Returns
-------
Expand All @@ -324,12 +324,13 @@ def _get_distance_profile_function(self):
)
return speed_up_profile
else:
if isinstance(self.distance, CPUDispatcher):
if isinstance(self.distance, CPUDispatcher) or callable(self.distance):
self.distance_function_ = self.distance
else:
raise ValueError(
"If distance argument is not a string, it is expected to be a "
f"numba function (CPUDispatcher), but got {type(self.distance)}."
"If distance argument is not a string, it is expected to be either "
"a callable or a numba function (CPUDispatcher), but got "
f"{type(self.distance)}."
)
if self.normalize:
return normalized_naive_distance_profile
Expand Down Expand Up @@ -393,15 +394,15 @@ def _call_distance_profile(self, q, mask):
self._q_means,
self._q_stds,
self.distance_function_,
numba_distance_args=self.distance_args,
distance_args=self.distance_args,
)
else:
distance_profile = self.distance_profile_function(
self._X,
q,
mask,
self.distance_function_,
numba_distance_args=self.distance_args,
distance_args=self.distance_args,
)
else:
if self.normalize:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,7 @@
)


def naive_distance_profile(
X, q, mask, numba_distance_function, numba_distance_args=None
):
def naive_distance_profile(X, q, mask, distance_function, distance_args=None):
r"""
Compute a distance profile in a brute force way.

Expand All @@ -39,10 +37,11 @@ def naive_distance_profile(
mask : array, shape (n_instances, n_channels, n_timepoints - query_length + 1)
Boolean mask of the shape of the distance profile indicating for which part
of it the distance should be computed.
numba_distance_function : func
A numba njit function used to compute the distance between two 1D vectors.
numba_distance_args : dict, default=None
Dictionary containing keywords arguments to use for the numba_distance_function
distance_function : func
A python function or a numba njit function used to compute the distance between
two 1D vectors.
distance_args : dict, default=None
Dictionary containing keywords arguments to use for the distance_function

Returns
-------
Expand All @@ -52,9 +51,7 @@ def naive_distance_profile(
for each channel.

"""
dist_func = generate_new_default_njit_func(
numba_distance_function, numba_distance_args
)
dist_func = generate_new_default_njit_func(distance_function, distance_args)
# This will compile the new function and check for errors outside the numba loops
dist_func(np.ones(3, dtype=X.dtype), np.zeros(3, dtype=X.dtype))
return _naive_distance_profile(X, q, mask, dist_func)
Expand All @@ -68,8 +65,8 @@ def normalized_naive_distance_profile(
X_stds,
q_means,
q_stds,
numba_distance_function,
numba_distance_args=None,
distance_function,
distance_args=None,
):
"""
Compute a distance profile in a brute force way.
Expand Down Expand Up @@ -101,10 +98,11 @@ def normalized_naive_distance_profile(
Means of the query q
q_stds : array, shape (n_channels)
Stds of the query q
numba_distance_function : func
A numba njit function used to compute the distance between two 1D vectors.
numba_distance_args : dict, default=None
Dictionary containing keywords arguments to use for the numba_distance_function
distance_function : func
A python function or a numba njit function used to compute the distance between
two 1D vectors.
distance_args : dict, default=None
Dictionary containing keywords arguments to use for the distance_function

Returns
-------
Expand All @@ -114,9 +112,7 @@ def normalized_naive_distance_profile(
for each channel.

"""
dist_func = generate_new_default_njit_func(
numba_distance_function, numba_distance_args
)
dist_func = generate_new_default_njit_func(distance_function, distance_args)
# This will compile the new function and check for errors outside the numba loops
dist_func(np.ones(3, dtype=X.dtype), np.zeros(3, dtype=X.dtype))
return _normalized_naive_distance_profile(
Expand Down
8 changes: 4 additions & 4 deletions aeon/similarity_search/tests/test_top_k_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,19 +49,19 @@ def test_TopKSimilaritySearch_euclidean(dtype):

@pytest.mark.parametrize("dtype", DATATYPES)
def test_TopKSimilaritySearch_custom_func(dtype):
@njit(fastmath=True)
def dist(x: np.ndarray, y: np.ndarray) -> float:
def _dist(x: np.ndarray, y: np.ndarray) -> float:
return np.sqrt(np.sum((x - y) ** 2))

dist = njit(_dist)
X = np.asarray(
[[[1, 2, 3, 4, 5, 6, 7, 8]], [[1, 2, 4, 4, 5, 6, 5, 4]]], dtype=dtype
)
q = np.asarray([[3, 4, 5]], dtype=dtype)

search = TopKSimilaritySearch(k=1, distance=dist)
search = TopKSimilaritySearch(k=3, distance=_dist)
search.fit(X)
idx = search.predict(q)
assert_array_equal(idx, [(0, 2)])
assert_array_equal(idx, [(0, 2), (1, 2), (1, 1)])

search = TopKSimilaritySearch(k=3, distance=dist)
search.fit(X)
Expand Down
27 changes: 15 additions & 12 deletions aeon/similarity_search/top_k_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,21 @@ class TopKSimilaritySearch(BaseSimiliaritySearch):
----------
k : int, default=1
The number of nearest matches from Q to return.
distance : str, default ="euclidean"
Name of the distance function to use.
distance_args : dict, default=None
Optional keyword arguments for the distance function.
normalize : bool, default = False
Whether the distance function should be z-normalized.
store_distance_profile : bool, default = =False.
Whether to store the computed distance profile in the attribute
"_distance_profile" after calling the predict method.
speed_up : str, default=None
Which speed up technique to use with for the selected distance
function.
distance : str, default="euclidean"
Name of the distance function to use. A list of valid strings can be found in
the documentation for :func:`aeon.distances.get_distance_function`.
If a callable is passed it must either be a python function or numba function
with nopython=True, that takes two 1d numpy arrays as input and returns a
float.
distance_args : dict, default=None
Optional keyword arguments for the distance function.
normalize : bool, default=False
Whether the distance function should be z-normalized.
store_distance_profile : bool, default=False.
Whether to store the computed distance profile in the attribute
"_distance_profile" after calling the predict method.
speed_up : str, default=None
Which speed up technique to use with for the selected distance function.

Attributes
----------
Expand Down
103 changes: 69 additions & 34 deletions aeon/utils/numba/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,66 +38,101 @@
AEON_NUMBA_STD_THRESHOLD = 1e-8


def generate_new_default_njit_func(base_func, new_defaults_args):
def generate_new_default_njit_func(
base_func,
new_defaults_args,
use_fastmath_for_callable=True,
use_cache_for_callable=True,
):
"""
Return a function with same code, globals, defaults, closure, and name.

If the function is not a CPUDispatcher (numba) function, it will try to create
a numba function if the base function is callable.

Parameters
----------
base_func : CPUDispatcher
Numba function to copy.
base_func : function or CPUDispatcher
A Python or Numba function to modify.
new_defaults_args : dict
Dictionnary of new default keyword args. If new_defaults_args is None or empty,
directly return base_func.
use_fastmath_for_callable : bool
If base_func is a callable, add fastmath as numba option when compiling
the new function to numba.
use_cache_for_callable : bool
If base_func is a callable, add cache as numba option when compiling
the new function to numba.

Returns
-------
new_func_njit : CPUDispatcher
Created numba function with new default args.

"""
if not isinstance(base_func, CPUDispatcher):
raise TypeError(
"Expected base_func to be of CPUDispatcher type (numba function),"
f"but got {type(base_func)}"
)
elif new_defaults_args is None:
return base_func
# empty dict evaluate to false
elif isinstance(new_defaults_args, dict) and not new_defaults_args:
return base_func
if new_defaults_args is None or (
isinstance(new_defaults_args, dict) and not new_defaults_args
):
if isinstance(base_func, CPUDispatcher):
return base_func
else:
numba_args_for_callable = {}
if use_fastmath_for_callable:
numba_args_for_callable.update({"fastmath": True})
if use_cache_for_callable:
numba_args_for_callable.update({"cache": True})
return njit(base_func, **numba_args_for_callable)

elif not isinstance(new_defaults_args, dict):
raise TypeError(
f"Expected new_defaults_args to be a dict but got {type(new_defaults_args)}"
"Expected new_defaults_args to be a dict but got "
f"{type(new_defaults_args)}"
)
else:
if isinstance(base_func, CPUDispatcher):
base_func_py = base_func.py_func
signature = inspect.signature(base_func_py)

_new_defaults = []
for k, v in signature.parameters.items():
if v.default is not inspect.Parameter.empty:
if k in new_defaults_args.keys():
_new_defaults.append(new_defaults_args[k])
else:
_new_defaults.append(v.default)

_new_name = "_tmp_" + base_func_py.__name__
new_func = types.FunctionType(
base_func_py.__code__,
base_func_py.__globals__,
_new_name,
tuple(_new_defaults),
base_func_py.__closure__,
elif callable(base_func):
base_func_py = base_func
else:
raise TypeError(
"Expected base_func to be of callable or CPUDispatcher type (numba "
f"function), but got {type(base_func)}"
)
# If new_func was given attrs (this dict is a shallow copy but we don't modify)
new_func.__dict__.update(base_func_py.__dict__)
signature = inspect.signature(base_func_py)

_new_defaults = []
for k, v in signature.parameters.items():
if v.default is not inspect.Parameter.empty:
if k in new_defaults_args.keys():
_new_defaults.append(new_defaults_args[k])
else:
_new_defaults.append(v.default)

new_func = types.FunctionType(
base_func_py.__code__,
base_func_py.__globals__,
"_tmp_" + base_func_py.__name__,
tuple(_new_defaults),
base_func_py.__closure__,
)
# If new_func was given attrs (dict is a shallow copy we shouldn't modify)
new_func.__dict__.update(base_func_py.__dict__)
if isinstance(base_func, CPUDispatcher):
numba_options = deepcopy(base_func.targetoptions)
# remove nopython option as we already use njit to avoid a warning
numba_options.pop("nopython")
new_func_njit = njit(new_func, **numba_options)
return new_func_njit

elif callable(base_func):
# This should return a Python function when DISABLE_NJIT = True
numba_args_for_callable = {}
if use_fastmath_for_callable:
numba_args_for_callable.update({"fastmath": True})
if use_cache_for_callable:
numba_args_for_callable.update({"cache": True})
new_func_njit = njit(new_func, **numba_args_for_callable)

return new_func_njit


@njit(fastmath=True, cache=True)
Expand Down
25 changes: 17 additions & 8 deletions aeon/utils/numba/tests/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pytest
from numba import njit
from numba.core.registry import CPUDispatcher
from numpy.testing import assert_array_almost_equal, assert_array_equal

from aeon.utils.numba.general import (
Expand All @@ -22,22 +23,30 @@


def test_generate_new_default_njit_func():
@njit(fastmath=True)
def _dummy_func(x, arg1=0.0, arg2=1.0):
return x - arg1 + arg2

_new_dummy_func = generate_new_default_njit_func(_dummy_func, {"arg1": -1.0})
dummy_func = njit(_dummy_func, fastmath=True)

new_dummy_func = generate_new_default_njit_func(dummy_func, {"arg1": -1.0})

expected_targetoptions = {"fastmath": True, "nopython": True, "boundscheck": None}

assert _dummy_func.py_func.__defaults__ == (0.0, 1.0)
assert _new_dummy_func.py_func.__defaults__ == (-1.0, 1.0)
if isinstance(dummy_func, CPUDispatcher):
assert dummy_func.py_func.__defaults__ == (0.0, 1.0)
assert new_dummy_func.py_func.__defaults__ == (-1.0, 1.0)

assert dummy_func.targetoptions == expected_targetoptions
assert new_dummy_func.targetoptions == expected_targetoptions

assert _dummy_func.targetoptions == expected_targetoptions
assert _new_dummy_func.targetoptions == expected_targetoptions
assert dummy_func.__name__ != new_dummy_func.__name__
assert dummy_func.py_func.__code__ == new_dummy_func.py_func.__code__

assert _dummy_func.__name__ != _new_dummy_func.__name__
assert _dummy_func.py_func.__code__ == _new_dummy_func.py_func.__code__
elif callable(dummy_func):
assert dummy_func.__defaults__ == (0.0, 1.0)
assert new_dummy_func.__defaults__ == (-1.0, 1.0)
assert dummy_func.__name__ != new_dummy_func.__name__
assert dummy_func.__code__ == new_dummy_func.__code__


@pytest.mark.parametrize("type", DATATYPES)
Expand Down
4 changes: 2 additions & 2 deletions examples/similarity_search/code_speed.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -314,15 +314,15 @@
" normalized_naive_distance_profile(\n",
" X, q, mask, X_means, X_stds, q_means, q_stds, squared_distance\n",
" )\n",
" _times = %timeit -r 3 -n 7 -q -o normalized_naive_distance_profile(X,q,mask,X_means,X_stds,q_means,q_stds, squared_distance)\n",
" _times = %timeit -r 3 -n 7 -q -o normalized_naive_distance_profile(X, q, mask, X_means, X_stds, q_means, q_stds, squared_distance)\n",
" times.loc[\n",
" (size, _query_length), \"Naive Normalized Euclidean distance\"\n",
" ] = _times.average\n",
" # Used for numba compilation before timings\n",
" normalized_squared_distance_profile(\n",
" X, q, mask, X_means, X_stds, q_means, q_stds\n",
" )\n",
" _times = %timeit -r 3 -n 7 -q -o normalized_squared_distance_profile(X,q,mask,X_means,X_stds,q_means,q_stds)\n",
" _times = %timeit -r 3 -n 7 -q -o normalized_squared_distance_profile(X, q, mask, X_means, X_stds, q_means, q_stds)\n",
" times.loc[\n",
" (size, _query_length), \"Normalized Euclidean as dot product\"\n",
" ] = _times.average"
Expand Down