Skip to content

Commit

Permalink
Store number of features for now.
Browse files Browse the repository at this point in the history
  • Loading branch information
vnmabus committed Jun 27, 2023
1 parent c53f0d3 commit 706db42
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions skfda/preprocessing/dim_reduction/_neighbor_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@

InputBound = Union[NDArrayFloat, FData]
Input = TypeVar("Input", contravariant=True, bound=InputBound)
Target = TypeVar("Target")
SelfType = TypeVar("SelfType", bound="KNeighborsTransformer[Any, Any]")


class KNeighborsTransformer(
KNeighborsMixin[Input, Any],
InductiveTransformerMixin[Input, NDArrayFloat, Any],
):

n_neighbors: int

@overload
def __init__(
self: KNeighborsTransformer[NDArrayFloat],
Expand Down Expand Up @@ -91,6 +95,17 @@ def _init_estimator(self) -> _KNeighborsTransformer:
n_jobs=self.n_jobs,
)

def _fit(
self: SelfType,
X: Input,
y: Target,
fit_with_zeros: bool = True,
) -> SelfType:
ret = super()._fit(X, y)
self.n_features_in_ = self._estimator.n_features_in_

return ret

def transform(
self,
X: Input,
Expand Down

0 comments on commit 706db42

Please sign in to comment.