Skip to content

Commit

Permalink
feat: add support for cupy in SparseThreshold.
Browse files Browse the repository at this point in the history
Ideally we want to have such support everywhere.
  • Loading branch information
paquiteau committed Feb 4, 2024
1 parent 2b1764c commit 4dc7d10
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion modopt/opt/proximity.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
else:
import_sklearn = True

from modopt.base.backend import get_array_module
from modopt.base.transform import cube2matrix, matrix2cube
from modopt.base.types import check_callable
from modopt.interface.errors import warn
Expand Down Expand Up @@ -215,7 +216,10 @@ def _cost_method(self, *args, **kwargs):
Sparsity cost component
"""
cost_val = np.sum(np.abs(self.weights * self._linear.op(args[0])))
xp = get_array_module(args[0])
cost_val = xp.sum(xp.abs(self.weights * self._linear.op(args[0])))
if isinstance(cost_val, xp.ndarray):
cost_val = cost_val.item()

if 'verbose' in kwargs and kwargs['verbose']:
print(' - L1 NORM (X):', cost_val)
Expand Down

0 comments on commit 4dc7d10

Please sign in to comment.