Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast Singular Value Thresholding #209

Merged
merged 13 commits into from
Feb 7, 2022
30 changes: 26 additions & 4 deletions modopt/opt/proximity.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from modopt.math.matrix import nuclear_norm
from modopt.signal.noise import thresh
from modopt.signal.positivity import positive
from modopt.signal.svd import svd_thresh, svd_thresh_coef
from modopt.signal.svd import svd_thresh, svd_thresh_coef, svd_thresh_coef_fast


class ProximityParent(object):
Expand Down Expand Up @@ -237,6 +237,9 @@ class LowRankMatrix(ProximityParent):
lowr_type : {'standard', 'ngole'}
Low-rank implementation (options are 'standard' or 'ngole', default is
'standard')
initial_rank: int, optional
Initial guess of the rank of future input_data.
If provided this will save computation time.
operator : class
Operator class ('ngole' only)

Expand Down Expand Up @@ -268,6 +271,7 @@ def __init__(
threshold,
thresh_type='soft',
lowr_type='standard',
initial_rank=None,
paquiteau marked this conversation as resolved.
Show resolved Hide resolved
operator=None,
):

Expand All @@ -277,8 +281,9 @@ def __init__(
self.operator = operator
self.op = self._op_method
self.cost = self._cost_method
self.rank = initial_rank

def _op_method(self, input_data, extra_factor=1.0):
def _op_method(self, input_data, extra_factor=1.0, rank=None):
"""Operator.

This method returns the input data after the singular values have been
Expand All @@ -290,22 +295,37 @@ def _op_method(self, input_data, extra_factor=1.0):
Input data array
extra_factor : float
Additional multiplication factor (default is ``1.0``)
rank: int, optional
Estimation of the rank to save computation time in standard mode,
if not set an internal estimation is used.

Returns
-------
numpy.ndarray
SVD thresholded data

Raises
------
ValueError
if lowr_type is not in ``{'standard', 'ngole'}``
"""
# Update threshold with extra factor.
threshold = self.thresh * extra_factor

if self.lowr_type == 'standard':
if self.lowr_type == 'standard' and self.rank is None and rank is None:
data_matrix = svd_thresh(
cube2matrix(input_data),
threshold,
thresh_type=self.thresh_type,
)
elif self.lowr_type == 'standard':
data_matrix, update_rank = svd_thresh_coef_fast(
cube2matrix(input_data),
threshold,
n_vals=rank or self.rank,
extra_vals=5,
thresh_type=self.thresh_type,
)
self.rank = update_rank # save for future use

elif self.lowr_type == 'ngole':
data_matrix = svd_thresh_coef(
Expand All @@ -314,6 +334,8 @@ def _op_method(self, input_data, extra_factor=1.0):
threshold,
thresh_type=self.thresh_type,
)
else:
raise ValueError('lowr_type should be standard or ngole')

# Return updated data.
return matrix2cube(data_matrix, input_data.shape[1:])
Expand Down
59 changes: 59 additions & 0 deletions modopt/signal/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np
from scipy.linalg import svd
from scipy.sparse.linalg import svds

from modopt.base.transform import matrix2cube
from modopt.interface.errors import warn
Expand Down Expand Up @@ -200,6 +201,64 @@ def svd_thresh(input_data, threshold=None, n_pc=None, thresh_type='hard'):
return np.dot(u_vec, np.dot(s_new, v_vec))


def svd_thresh_coef_fast(
    input_data,
    threshold,
    n_vals=-1,
    extra_vals=5,
    thresh_type='hard',
):
    """Threshold the singular values coefficients.

    This method thresholds the input data by using a truncated singular
    value decomposition, computing only the greatest ``n_vals`` singular
    values. If the smallest computed value is still above the threshold,
    the decomposition is recomputed with ``extra_vals`` additional values
    until either a value at or below the threshold is found or the maximum
    number of values, ``min(input_data.shape) - 1``, has been computed.

    Parameters
    ----------
    input_data : numpy.ndarray
        Input data array, 2D matrix
    threshold : float or numpy.ndarray
        Threshold value(s)
    n_vals : int, optional
        Number of singular values to compute. If ``-1`` (default) or
        ``None``, compute ``min(input_data.shape) - 1`` singular values.
        Other values are clipped to ``[1, min(input_data.shape) - 1]``.
    extra_vals : int, optional
        If the number of values computed is not enough to perform
        thresholding, recompute by using ``n_vals + extra_vals``
        (default is ``5``). Values lower than ``1`` are treated as ``1``
        so that the search always progresses.
    thresh_type : {'hard', 'soft'}
        Type of thresholding (default is ``'hard'``)

    Returns
    -------
    tuple
        The thresholded data (numpy.ndarray) and the estimated rank after
        thresholding (int)

    """
    # Upper bound on the number of singular values requested from ``svds``;
    # this is also the value at which the search below gives up growing.
    max_vals = min(input_data.shape) - 1
    if n_vals is None or n_vals == -1:
        n_vals = max_vals
    # A previously estimated rank of 0 (or an over-sized guess) would be
    # rejected by ``svds``; clip the request to the usable range instead.
    n_vals = max(1, min(n_vals, max_vals))
    # A non-positive increment would never enlarge the decomposition.
    step = max(extra_vals, 1)

    while True:
        u_vec, s_values, v_vec = svds(input_data, k=n_vals)
        # Stop once a computed value falls under the threshold (every
        # value that survives thresholding has been found) or when no
        # more values can be requested.
        if np.min(s_values) <= threshold or n_vals >= max_vals:
            break
        # Grow towards max_vals without overshooting it, so the stopping
        # condition above is always reachable with a valid ``k``.
        n_vals = min(n_vals + step, max_vals)

    # Sort in ascending order explicitly so that the surviving values are
    # the trailing ones, without relying on the ordering of ``svds`` output.
    order = np.argsort(s_values)
    u_vec = u_vec[:, order]
    v_vec = v_vec[order, :]
    s_values = thresh(
        s_values[order],
        threshold,
        threshold_type=thresh_type,
    )

    rank = np.count_nonzero(s_values)
    # Use an explicit start index: ``-rank:`` would select *all* components
    # when rank is 0, whereas this yields an empty selection (zero matrix).
    keep = slice(s_values.size - rank, None)
    return (
        np.dot(
            u_vec[:, keep] * s_values[keep],
            v_vec[keep, :],
        ),
        rank,
    )


def svd_thresh_coef(input_data, operator, threshold, thresh_type='hard'):
"""Threshold the singular values coefficients.

Expand Down
12 changes: 12 additions & 0 deletions modopt/tests/test_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,11 @@ def setUp(self):
weights,
)
self.lowrank = proximity.LowRankMatrix(10.0, thresh_type='hard')
self.lowrank_rank = proximity.LowRankMatrix(
10.0,
initial_rank=1,
thresh_type='hard',
)
self.lowrank_ngole = proximity.LowRankMatrix(
10.0,
lowr_type='ngole',
Expand Down Expand Up @@ -763,6 +768,8 @@ def tearDown(self):
self.positivity = None
self.sparsethresh = None
self.lowrank = None
self.lowrank_rank = None
self.lowrank_ngole = None
self.combo = None
self.data1 = None
self.data2 = None
Expand Down Expand Up @@ -841,6 +848,11 @@ def test_low_rank_matrix(self):
err_msg='Incorrect low rank operation: standard',
)

npt.assert_almost_equal(
self.lowrank_rank.op(self.data3),
self.data4,
err_msg='Incorrect low rank operation: standard with rank',
)
npt.assert_almost_equal(
self.lowrank_ngole.op(self.data3),
self.data5,
Expand Down
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ per-file-ignores =
#Justification: Needed to import matplotlib.pyplot
modopt/plot/cost_plot.py: N802,WPS301
#Todo: Investigate possible bug in find_n_pc function
modopt/signal/svd.py: WPS345
#Todo: Investigate darglint error
modopt/signal/svd.py: WPS345, DAR000
#Todo: Check security of using system executable call
modopt/signal/wavelet.py: S404,S603
#Todo: Clean up tests
Expand Down