Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
Julien-Sahli committed Apr 25, 2023
1 parent 51c3aa5 commit e2849d3
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 29 deletions.
14 changes: 7 additions & 7 deletions lensless/admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
self._tau = tau

#3D ADMM is not supported yet
if len(psf.shape[0]) > 1:
if psf.shape[0] > 1:
raise NotImplementedError("3D ADMM is not supported yet, use gradient descent or APGD instead.")

# call reset() to initialize matrices
Expand Down Expand Up @@ -201,11 +201,11 @@ def _image_update(self):
# rk = self._convolver._pad(rk)

if self.is_torch:
freq_space_result = self._R_divmat * torch.fft.rfft2(rk, dim=(0, 1))
self._image_est = torch.fft.irfft2(freq_space_result, dim=(0, 1))
freq_space_result = self._R_divmat * torch.fft.rfft2(rk, dim=(1, 2))
self._image_est = torch.fft.irfft2(freq_space_result, dim=(1, 2))
else:
freq_space_result = self._R_divmat * fft.rfft2(rk, axes=(0, 1))
self._image_est = fft.irfft2(freq_space_result, axes=(0, 1))
freq_space_result = self._R_divmat * fft.rfft2(rk, axes=(1, 2))
self._image_est = fft.irfft2(freq_space_result, axes=(1, 2))

# self._image_est = self._convolver._crop(res)

Expand Down Expand Up @@ -308,6 +308,6 @@ def finite_diff_gram(shape, dtype=None, is_torch=False):
] = -1

if is_torch:
return torch.fft.rfft2(gram, dim=(0, 1))
return torch.fft.rfft2(gram, dim=(1, 2))
else:
return fft.rfft2(gram, axes=(0, 1))
return fft.rfft2(gram, axes=(1, 2))
41 changes: 21 additions & 20 deletions lensless/apgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,55 +48,55 @@ def __init__(self, filter, dtype: Optional[type] = None):
dtype : float32 or float64
Data type to use for optimization.
"""

assert len(filter.shape) == 3
raise NotImplementedError("Needs to be updated to work with 4D data")
assert len(filter.shape) == 4
self._filter_shape = np.array(filter.shape)
self._n_channels = filter.shape[2]
self._n_channels = filter.shape[3]

# cropping / padding indices
self._padded_shape = 2 * self._filter_shape[:2] - 1
self._padded_shape = 2 * self._filter_shape[1:3] - 1
self._padded_shape = np.array([next_fast_len(i) for i in self._padded_shape])
self._padded_shape = np.r_[self._padded_shape, [self._n_channels]]
self._start_idx = (self._padded_shape[:2] - self._filter_shape[:2]) // 2
self._end_idx = self._start_idx + self._filter_shape[:2]
self._start_idx = (self._padded_shape[1:3] - self._filter_shape[1:3]) // 2
self._end_idx = self._start_idx + self._filter_shape[1:3]

# precompute filter in frequency domain
self._H = fft.rfft2(self._pad(filter), axes=(0, 1))
self._H = fft.rfft2(self._pad(filter), axes=(1, 2))
self._Hadj = np.conj(self._H)
self._padded_data = np.zeros(self._padded_shape).astype(dtype)

shape = (int(np.prod(self._filter_shape)), int(np.prod(self._filter_shape)))
super(RealFFTConvolve2D, self).__init__(shape=shape, dtype=dtype)

def _crop(self, x):
return x[self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1]]
return x[:, self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1]]

def _pad(self, v):
vpad = np.zeros(self._padded_shape).astype(v.dtype)
vpad[self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1]] = v
vpad[:, self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1]] = v
return vpad

def __call__(self, x: Union[Number, np.ndarray]) -> Union[Number, np.ndarray]:
# like here: https://github.com/PyLops/pylops/blob/3e7eb22a62ec60e868ccdd03bc4b54806851cb26/pylops/signalprocessing/ConvolveND.py#L103
self._padded_data[
self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1]
:, self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1]
] = np.reshape(x, self._filter_shape)
y = self._crop(
fft.ifftshift(
fft.irfft2(fft.rfft2(self._padded_data, axes=(0, 1)) * self._H, axes=(0, 1)),
axes=(0, 1),
fft.irfft2(fft.rfft2(self._padded_data, axes=(1, 2)) * self._H, axes=(1, 2)),
axes=(1, 2),
)
)
return y.ravel()

def adjoint(self, y: Union[Number, np.ndarray]) -> Union[Number, np.ndarray]:
self._padded_data[
self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1]
:, self._start_idx[0] : self._end_idx[0], self._start_idx[1] : self._end_idx[1]
] = np.reshape(y, self._filter_shape)
x = self._crop(
fft.ifftshift(
fft.irfft2(fft.rfft2(self._padded_data, axes=(0, 1)) * self._Hadj, axes=(0, 1)),
axes=(0, 1),
fft.irfft2(fft.rfft2(self._padded_data, axes=(1, 2)) * self._Hadj, axes=(1, 2)),
axes=(1, 2),
)
)
return x.ravel()
Expand Down Expand Up @@ -164,21 +164,21 @@ def __init__(
# initialize solvers which will be created when data is set
if diff_penalty is not None:
if diff_penalty == APGDPriors.L2:
self._diff_penalty = diff_lambda * SquaredL2Norm(dim=self._H.shape[1])
self._diff_penalty = diff_lambda * SquaredL2Norm(dim=self._H.shape[2])
else:
assert hasattr(diff_penalty, "jacobianT")
self._diff_penalty = diff_lambda * diff_penalty(dim=self._H.shape[1])
self._diff_penalty = diff_lambda * diff_penalty(dim=self._H.shape[2])
else:
self._diff_penalty = None

if prox_penalty is not None:
if prox_penalty == APGDPriors.L1:
self._prox_penalty = prox_lambda * L1Norm(dim=self._H.shape[1])
self._prox_penalty = prox_lambda * L1Norm(dim=self._H.shape[2])
elif prox_penalty == APGDPriors.NONNEG:
self._prox_penalty = prox_lambda * NonNegativeOrthant(dim=self._H.shape[1])
self._prox_penalty = prox_lambda * NonNegativeOrthant(dim=self._H.shape[2])
else:
try:
self._prox_penalty = prox_lambda * prox_penalty(dim=self._H.shape[1])
self._prox_penalty = prox_lambda * prox_penalty(dim=self._H.shape[2])
except ValueError:
print("Unexpected prior.")
else:
Expand All @@ -198,6 +198,7 @@ def set_data(self, data):
3D (RGB).
"""
print("setting data,", data.shape)
if not self._is_rgb:
assert len(data.shape) == 2
data = data[:, :, np.newaxis]
Expand Down
2 changes: 1 addition & 1 deletion lensless/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def resize(img, factor=None, shape=None, interpolation=cv2.INTER_CUBIC):

if np.array_equal(img_shape, new_shape):
return img

# TODO : Use pytorch.resize if available
resized = np.array([cv2.resize(img[i], dsize=new_shape[::-1], interpolation=interpolation) for i in range(img.shape[0])])

# OpenCV discards channel dimension if it is 1, put it back
Expand Down
1 change: 0 additions & 1 deletion scripts/recon/gradient_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def gradient_descent(
)

recon.set_data(data)
print("Torch : ", config.torch, "Device : ", config.torch_device)
print(f"Setup time : {time.time() - start_time} s")

start_time = time.time()
Expand Down

0 comments on commit e2849d3

Please sign in to comment.