From 4dc7d10f7ba30bc046162eabbe3e50af892afaf8 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Sun, 4 Feb 2024 21:22:01 +0100 Subject: [PATCH] feat: add support for cupy in SparseThreshold. Ideally we want to have such support everywhere. --- modopt/opt/proximity.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modopt/opt/proximity.py b/modopt/opt/proximity.py index e8492367..fc81a753 100644 --- a/modopt/opt/proximity.py +++ b/modopt/opt/proximity.py @@ -22,6 +22,7 @@ else: import_sklearn = True +from modopt.base.backend import get_array_module from modopt.base.transform import cube2matrix, matrix2cube from modopt.base.types import check_callable from modopt.interface.errors import warn @@ -215,7 +216,10 @@ def _cost_method(self, *args, **kwargs): Sparsity cost component """ - cost_val = np.sum(np.abs(self.weights * self._linear.op(args[0]))) + xp = get_array_module(args[0]) + cost_val = xp.sum(xp.abs(self.weights * self._linear.op(args[0]))) + if isinstance(cost_val, xp.ndarray): + cost_val = cost_val.item() if 'verbose' in kwargs and kwargs['verbose']: print(' - L1 NORM (X):', cost_val)