Skip to content

Commit

Permalink
Added KNeighborsTransformer, without docs for the moment.
Browse files Browse the repository at this point in the history
  • Loading branch information
vnmabus committed Jun 27, 2023
1 parent 3663905 commit 98273fa
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 16 deletions.
File renamed without changes.
2 changes: 1 addition & 1 deletion skfda/exploratory/outliers/neighbors_outlier.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from sklearn.neighbors import LocalOutlierFactor as _LocalOutlierFactor
from typing_extensions import Literal

from ..._utils._neighbors_base import AlgorithmType, KNeighborsMixin
from ...misc.metrics import PairwiseMetric, l2_distance
from ...ml._neighbors_base import AlgorithmType, KNeighborsMixin
from ...representation import FData
from ...typing._metric import Metric
from ...typing._numpy import NDArrayFloat, NDArrayInt
Expand Down
10 changes: 5 additions & 5 deletions skfda/ml/classification/_neighbors_classifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@
)
from typing_extensions import Literal

from ...misc.metrics import l2_distance
from ...representation import FData
from ...typing._metric import Metric
from ...typing._numpy import NDArrayFloat, NDArrayInt
from .._neighbors_base import (
from ..._utils._neighbors_base import (
AlgorithmType,
KNeighborsMixin,
NeighborsClassifierMixin,
RadiusNeighborsMixin,
WeightsType,
)
from ...misc.metrics import l2_distance
from ...representation import FData
from ...typing._metric import Metric
from ...typing._numpy import NDArrayFloat, NDArrayInt

InputBound = Union[NDArrayFloat, FData]
Input = TypeVar("Input", contravariant=True, bound=InputBound)
Expand Down
10 changes: 5 additions & 5 deletions skfda/ml/clustering/_neighbors_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@

from typing_extensions import Literal

from ...misc.metrics import l2_distance
from ...representation import FData
from ...typing._metric import Metric
from ...typing._numpy import NDArrayFloat
from .._neighbors_base import (
from ..._utils._neighbors_base import (
AlgorithmType,
KNeighborsMixin,
RadiusNeighborsMixin,
)
from ...misc.metrics import l2_distance
from ...representation import FData
from ...typing._metric import Metric
from ...typing._numpy import NDArrayFloat

InputBound = Union[NDArrayFloat, FData]
Input = TypeVar("Input", contravariant=True, bound=InputBound)
Expand Down
10 changes: 5 additions & 5 deletions skfda/ml/regression/_neighbors_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@
)
from typing_extensions import Literal

from ...misc.metrics import l2_distance
from ...representation import FData
from ...typing._metric import Metric
from ...typing._numpy import NDArrayFloat, NDArrayInt
from .._neighbors_base import (
from ..._utils._neighbors_base import (
AlgorithmType,
KNeighborsMixin,
NeighborsRegressorMixin,
RadiusNeighborsMixin,
WeightsType,
)
from ...misc.metrics import l2_distance
from ...representation import FData
from ...typing._metric import Metric
from ...typing._numpy import NDArrayFloat, NDArrayInt

InputBound = Union[NDArrayFloat, FData]
Input = TypeVar("Input", contravariant=True, bound=InputBound)
Expand Down
2 changes: 2 additions & 0 deletions skfda/preprocessing/dim_reduction/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
],
submod_attrs={
"_fpca": ["FPCA"],
"_neighbor_transforms": ["KNeighborsTransformer"]
},
)

if TYPE_CHECKING:
from ._fpca import FPCA as FPCA
from ._neighbor_transforms import KNeighborsTransformer as KNeighborsTransformer


def __getattr__(name: str) -> Any:
Expand Down
104 changes: 104 additions & 0 deletions skfda/preprocessing/dim_reduction/_neighbor_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from __future__ import annotations

from typing import Any, Literal, TypeVar, Union, overload

from scipy.sparse import csr_matrix
from sklearn.neighbors import KNeighborsTransformer as _KNeighborsTransformer

from skfda._utils._sklearn_adapter import InductiveTransformerMixin

from ..._utils._neighbors_base import AlgorithmType, KNeighborsMixin
from ...misc.metrics import l2_distance
from ...representation import FData
from ...typing._metric import Metric
from ...typing._numpy import NDArrayFloat

InputBound = Union[NDArrayFloat, FData]
Input = TypeVar("Input", contravariant=True, bound=InputBound)


class KNeighborsTransformer(
    KNeighborsMixin[Input, Any],
    InductiveTransformerMixin[Input, NDArrayFloat, Any],
):
    """
    Transform data into a sparse k-nearest-neighbors graph.

    Wrapper around :class:`sklearn.neighbors.KNeighborsTransformer`. The
    wrapped scikit-learn estimator is always built with
    ``metric="precomputed"``; the ``metric`` given here is handed to the
    base mixin (presumably it is the one that evaluates the distances
    before delegating to scikit-learn -- confirm in ``KNeighborsMixin``).

    Parameters:
        mode: Either ``"connectivity"`` or ``"distance"``. Forwarded to
            the scikit-learn transformer and to ``kneighbors_graph``. In
            ``"distance"`` mode, :meth:`transform` asks for one neighbor
            more than ``n_neighbors``.
        n_neighbors: Number of neighbors of each sample in the graph.
        algorithm: Forwarded unchanged to the base mixin and to the
            scikit-learn transformer.
        leaf_size: Forwarded unchanged to the base mixin and to the
            scikit-learn transformer.
        metric: Either the literal ``"precomputed"`` or a metric callable
            over the input type. Defaults to ``l2_distance``.
        n_jobs: Forwarded unchanged to the base mixin and to the
            scikit-learn transformer.

    """

    @overload
    def __init__(
        self: KNeighborsTransformer[NDArrayFloat],
        *,
        mode: Literal["connectivity", "distance"] = "distance",
        n_neighbors: int = 5,
        algorithm: AlgorithmType = 'auto',
        leaf_size: int = 30,
        metric: Literal["precomputed"],
        n_jobs: int | None = None,
    ) -> None:
        ...

    @overload
    def __init__(
        self: KNeighborsTransformer[InputBound],
        *,
        mode: Literal["connectivity", "distance"] = "distance",
        n_neighbors: int = 5,
        algorithm: AlgorithmType = 'auto',
        leaf_size: int = 30,
        n_jobs: int | None = None,
    ) -> None:
        ...

    @overload
    def __init__(
        self,
        *,
        mode: Literal["connectivity", "distance"] = "distance",
        n_neighbors: int = 5,
        algorithm: AlgorithmType = 'auto',
        leaf_size: int = 30,
        metric: Metric[Input] = l2_distance,
        n_jobs: int | None = None,
    ) -> None:
        ...

    # This constructor is kept on purpose even though it mostly forwards
    # its arguments: it restricts the parameters that can be passed.
    def __init__(  # noqa: WPS612
        self,
        *,
        mode: Literal["connectivity", "distance"] = "distance",
        n_neighbors: int = 5,
        algorithm: AlgorithmType = 'auto',
        leaf_size: int = 30,
        metric: Literal["precomputed"] | Metric[Input] = l2_distance,
        n_jobs: int | None = None,
    ) -> None:
        """Store ``mode`` and hand the remaining options to the base."""
        self.mode = mode
        super().__init__(
            n_neighbors=n_neighbors,
            algorithm=algorithm,
            leaf_size=leaf_size,
            metric=metric,
            n_jobs=n_jobs,
        )

    def _init_estimator(self) -> _KNeighborsTransformer:
        """
        Build the underlying scikit-learn transformer.

        Returns:
            A new scikit-learn ``KNeighborsTransformer`` configured from
            this object's parameters, with ``metric`` fixed to
            ``"precomputed"``.

        """
        sklearn_transformer = _KNeighborsTransformer(
            mode=self.mode,
            n_neighbors=self.n_neighbors,
            algorithm=self.algorithm,
            leaf_size=self.leaf_size,
            # Always "precomputed", whatever ``self.metric`` is.
            metric="precomputed",
            n_jobs=self.n_jobs,
        )
        return sklearn_transformer

    def transform(
        self,
        X: Input,
    ) -> csr_matrix:
        """
        Compute the neighbors graph of ``X``.

        Args:
            X: Data whose neighbors are searched.

        Returns:
            The result of ``kneighbors_graph`` for ``X`` in the
            configured mode.

        """
        self._check_is_fitted()

        # One additional neighbor is requested in "distance" mode only.
        # NOTE(review): this appears to mirror scikit-learn's own
        # KNeighborsTransformer.transform -- confirm.
        extra_neighbor = int(self.mode == "distance")

        return self.kneighbors_graph(
            X,
            mode=self.mode,
            n_neighbors=self.n_neighbors + extra_neighbor,
        )
34 changes: 34 additions & 0 deletions skfda/tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
from sklearn.neighbors._base import KNeighborsMixin, RadiusNeighborsMixin
from sklearn.pipeline import Pipeline

from skfda.datasets import make_multimodal_samples, make_sinusoidal_process
from skfda.exploratory.outliers import LocalOutlierFactor # Pending theory
Expand All @@ -16,6 +17,7 @@
)
from skfda.ml.clustering import NearestNeighbors
from skfda.ml.regression import KNeighborsRegressor, RadiusNeighborsRegressor
from skfda.preprocessing.dim_reduction import KNeighborsTransformer
from skfda.representation import FDataBasis, FDataGrid
from skfda.representation.basis import FourierBasis

Expand Down Expand Up @@ -87,6 +89,38 @@ def test_predict_classifier(self) -> None:
err_msg=f'fail in {type(neigh)}',
)

def test_predict_classifier_transformer_knn(self) -> None:
    """Check that a transformer + precomputed pipeline predicts the same.

    For several neighbor counts, a plain ``KNeighborsClassifier`` is
    compared against a pipeline that first builds a distance graph with
    ``KNeighborsTransformer`` and then classifies on it with
    ``metric="precomputed"``.
    """
    neighbor_counts = range(1, 11, 2)
    # The graph step always uses the largest count under test, while the
    # classifier step uses the count of the current iteration.
    largest_count = max(neighbor_counts)

    for n_neighbors in neighbor_counts:
        direct = KNeighborsClassifier(n_neighbors=n_neighbors)

        graph_step = KNeighborsTransformer(
            n_neighbors=largest_count,
            mode="distance",
        )
        precomputed_step = KNeighborsClassifier(
            n_neighbors=n_neighbors,
            metric="precomputed",
        )
        piped = Pipeline([
            ("transformer", graph_step),
            ("classifier", precomputed_step),
        ])

        direct.fit(self.X, self.y)
        piped.fit(self.X, self.y)

        # Both predictions are made on the training data itself.
        np.testing.assert_allclose(
            direct.predict(self.X),
            piped.predict(self.X),
        )

def test_predict_proba_classifier(self) -> None:
"""Tests predict proba for k neighbors classifier."""
neigh = KNeighborsClassifier(metric=l2_distance)
Expand Down

0 comments on commit 98273fa

Please sign in to comment.