diff --git a/src/aspire/numeric/complex_pca/complex_pca.py b/src/aspire/numeric/complex_pca/complex_pca.py index 5c8cd1e2d6..bd270ef64b 100644 --- a/src/aspire/numeric/complex_pca/complex_pca.py +++ b/src/aspire/numeric/complex_pca/complex_pca.py @@ -13,7 +13,10 @@ import numpy as np import scipy.sparse as sp +import sklearn +from packaging.version import Version from sklearn.decomposition import PCA +from sklearn.utils._array_api import get_namespace from .validation import check_array @@ -45,6 +48,8 @@ def _fit(self, X): allow_complex=True, ) + xp, is_array_api_compliant = get_namespace(X) + # Handle n_components==None if self.n_components is None: if self.svd_solver != "arpack": @@ -66,11 +71,22 @@ def _fit(self, X): else: self._fit_svd_solver = "full" + # sci-kit changed `_fit_*()` API in latest release v1.5.0 + # which supports Python 3.9 - 3.12. This can be removed after + # our minimal support is Python 3.9. + API_dep = Version(sklearn.__version__) < Version("1.5.0") + # Call different fits for either full or truncated SVD if self._fit_svd_solver == "full": - return self._fit_full(X, n_components) + if API_dep: + return self._fit_full(X, n_components) + else: + return self._fit_full(X, n_components, xp, is_array_api_compliant) elif self._fit_svd_solver in ["arpack", "randomized"]: - return self._fit_truncated(X, n_components, self._fit_svd_solver) + if API_dep: + return self._fit_truncated(X, n_components, self._fit_svd_solver) + else: + return self._fit_truncated(X, n_components, xp) else: raise ValueError( "Unrecognized svd_solver='{0}'" "".format(self._fit_svd_solver)