
3D support #41

Closed
wants to merge 223 commits into from

Conversation

@Julien-Sahli (Collaborator) commented Apr 25, 2023

3D support for Gradient Descent (and soon APGD), also with PyTorch.

Benchmarks done with an AMD Ryzen 7 5800H CPU and an NVIDIA GeForce RTX 3050 Laptop GPU.

Data used: see data.rst

Gradient Descent, 2D

Default: ~140 s
python scripts/recon/gradient_descent.py

Torch, CPU: ~70 s
python scripts/recon/gradient_descent.py torch=True

Torch, cuda:0: ~22 s
python scripts/recon/gradient_descent.py -cn pytorch

Gradient Descent, 3D

Default: ~135 s
python scripts/recon/gradient_descent.py input.psf="data/psf/diffuser_cam.npy" input.data="data/raw_data/diffuser_cam.tiff" preprocess.downsample=1

Torch, CPU: ~105 s
python scripts/recon/gradient_descent.py torch=True input.psf="data/psf/diffuser_cam.npy" input.data="data/raw_data/diffuser_cam.tiff" preprocess.downsample=1

Torch, cuda:0: ~27 s
python scripts/recon/gradient_descent.py -cn pytorch input.psf="data/psf/diffuser_cam.npy" input.data="data/raw_data/diffuser_cam.tiff" preprocess.downsample=1

Update utilities for loading and visualizing.
Add reconstruction template.
Update capture script and add display script.
Add MIR Flicker scripts and update README.
Add support for original DiffuserCam dataset.
lensless/io.py Outdated

if data.shape[3] == 1 and psf.shape[3] > 1:
print("Warning : loaded a RGB PSF with grayscale data. Repeating data across channels.")
print("This may be an error as the PSF and the data are likely from different datasets.")
combine both message and use warnings.warn
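The reviewer's suggestion could look like the following minimal sketch (the `check_channels` helper and the shapes are hypothetical; only the combined message mirrors the diff):

```python
import warnings

def check_channels(data_shape, psf_shape):
    # Combine the two print() messages into a single warnings.warn call,
    # as suggested in the review.
    if data_shape[3] == 1 and psf_shape[3] > 1:
        warnings.warn(
            "Loaded a RGB PSF with grayscale data; repeating data across channels. "
            "This may be an error, as the PSF and the data are likely from different datasets."
        )
```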

3D example
----------

It is also possible to reconstruct 3D scenes using Gradient Descent or APGD. ADMM doesn't support 3D reconstruction yet.
render algorithms with hyperlinks as

:py:class:`~lensless.GradientDescent`
:py:class:`~lensless.ADMM`
:py:class:`~lensless.APGD`

as in here

This requires to use a 3D PSF as an input in the form of a .npy file, which actually is a set of 2D PSFs corresponding to the same diffuser sampeled with light sources from different depths.

"of an .npy file..."

"sampled"


The input data for 3D reconstructions is still a 2D image, as collected by the camera. The reconstruction will be able to separate which part of the lensless data corresponds to which 2D PSF,
remove extra space "The input data for..."

and therefore to which depth, effectively generating a 3D reconstruction, which will be outputed in the form of a .npy file as well as a 2D projection on the depth axis to be displayed to the
"an .npy file. A 2D projection on the depth axis is also displayed to the user."

Outside of the input data and PSF, no special argument has to be given to the script for it to operate a 3D reconstruction, as actually, the 2D reconstruction is internally
viewed as a 3D reconstruction which has only one depth level. It is also the case for ADMM although for now, the reconstructions are wrong when more than one depth level is used.
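This internal view can be sketched in NumPy (shapes here are made up for illustration): a 2D array is promoted to the 4D [depth, width, height, channel] layout with singleton axes, so the 2D case is simply a 3D reconstruction with one depth level.

```python
import numpy as np

# Hypothetical shapes: promote a 2D grayscale PSF to the 4D
# [depth, width, height, channel] layout, so 2D reconstruction
# is just the 3D case with a single depth level.
psf_2d = np.ones((480, 640), dtype=np.float32)
psf_4d = psf_2d[np.newaxis, :, :, np.newaxis]
print(psf_4d.shape)  # (1, 480, 640, 1)
```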

3D data is not directly provided in the LenslessPiCam, but some can be :doc:`imported <data>` from the Waller Lab dataset. For this data, it is best to set the downsample to 1 :
"provided in LenslessPiCam"

", but example data can be obtained from Waller Lab."

"... best to set downsample=1:

user as an image.

As for the 2D ADMM reconstuction, scripts for 3D reconstruction can be found in ``scripts/recon/gradient_descent.py`` and ``scripts/recon/apgd_pycsou.py``.
@ebezzam (Member) commented May 10, 2023

"As for the 2D ADMM....one depth level is used"

All this can be replaced with:

The same scripts for 2D reconstruction can be used for 3D reconstruction, namely scripts/recon/gradient_descent.py and scripts/recon/apgd_pycsou.py.
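The 3D PSF format described in the docs above — a stack of 2D PSFs, one per depth plane, stored in a single .npy file — can be sketched as follows (shapes and file name are made up for illustration):

```python
import numpy as np

# Minimal sketch of the 3D PSF format: one 2D PSF per depth plane,
# stacked into a single array and stored as an .npy file.
depth, width, height = 4, 32, 48
psf_stack = np.random.rand(depth, width, height).astype(np.float32)
np.save("psf_stack_example.npy", psf_stack)

loaded = np.load("psf_stack_example.npy")
print(loaded.shape)  # (4, 32, 48)
```

The measurement fed to the reconstruction stays a single 2D image; only the PSF carries the depth dimension.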

@@ -193,16 +213,11 @@ def __init__(self, psf, dtype=None, pad=True, n_iter=100, **kwargs):
if torch_available:
self.is_torch = isinstance(psf, torch.Tensor)

assert len(psf.shape) == 4 # depth, width, height, channel
can add an error message: "PSF must be 4D: [depth, width, height, channel]."
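A sketch of the suggested self-documenting assertion (the PSF arrays are dummies):

```python
import numpy as np

# Attach an explanatory message so a shape failure explains itself.
psf = np.zeros((2, 32, 48, 3))
assert len(psf.shape) == 4, "PSF must be 4D: [depth, width, height, channel]."

# A 3D array now fails with the explanatory message:
try:
    bad_psf = np.zeros((32, 48, 3))
    assert len(bad_psf.shape) == 4, "PSF must be 4D: [depth, width, height, channel]."
except AssertionError as e:
    print(e)  # PSF must be 4D: [depth, width, height, channel].
```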

data = data[:, :, None]
assert len(self._psf_shape) == len(data.shape)
assert len(data.shape) == 4
assert len(self._psf_shape) == 4
you already have check for PSF shape in constructor

assert len(data.shape) == 2
data = data[:, :, None]
assert len(self._psf_shape) == len(data.shape)
assert len(data.shape) == 4
Can add similar error message as for PSF

self._n_channels = self._psf.shape[2]
self._psf_shape = np.array(self._psf.shape)

assert len(psf.shape) == 4
Can give error message with expected shape

self._start_idx = (self._padded_shape[:2] - self._psf_shape[:2]) // 2
self._end_idx = self._start_idx + self._psf_shape[:2]
self._padded_shape = list(np.r_[self._psf_shape[0], self._padded_shape, self._psf_shape[3]])
self._start_idx = (self._padded_shape[1:3] - self._psf_shape[1:3]) // 2
can you replace with [-3:-1]?

self._padded_shape = list(np.r_[self._padded_shape, [self._n_channels]])
self._start_idx = (self._padded_shape[:2] - self._psf_shape[:2]) // 2
self._end_idx = self._start_idx + self._psf_shape[:2]
self._padded_shape = list(np.r_[self._psf_shape[0], self._padded_shape, self._psf_shape[3]])
can you replace PSF shape indices with negative values? e.g. here

self._padded_shape = list(np.r_[self._psf_shape[-4], self._padded_shape, self._psf_shape[-1]])
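The equivalence the reviewer is pointing at can be checked directly (shape values here are arbitrary):

```python
import numpy as np

# For a [depth, width, height, channel] shape, negative indices select the
# same entries as the positive ones, and stay valid if a leading batch
# axis is ever added.
psf_shape = np.array([2, 32, 48, 3])
assert psf_shape[-4] == psf_shape[0]                     # depth
assert psf_shape[-1] == psf_shape[3]                     # channel
assert np.array_equal(psf_shape[-3:-1], psf_shape[1:3])  # width, height
```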

self._end_idx = self._start_idx + self._psf_shape[:2]
self._padded_shape = list(np.r_[self._psf_shape[0], self._padded_shape, self._psf_shape[3]])
self._start_idx = (self._padded_shape[1:3] - self._psf_shape[1:3]) // 2
self._end_idx = self._start_idx + self._psf_shape[1:3]
[-3:-1]

self.pad = pad # Whether necessary to pad provided data

# precompute filter in frequency domain
if self.is_torch:
self._H = torch.fft.rfft2(
self._pad(self._psf), norm=norm, dim=(0, 1), s=self._padded_shape[:2]
self._pad(self._psf), norm=norm, dim=(1, 2), s=self._padded_shape[1:3]
self._pad(self._psf), norm=norm, dim=(-3, -2), s=self._padded_shape[-3:-1]
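The same negative-axis convention carries over to the FFT calls. Shown here with NumPy on a dummy 4D array (`torch.fft.rfft2` takes an analogous `dim` argument):

```python
import numpy as np

# axes=(-3, -2) and axes=(1, 2) address the same width/height axes of a
# 4D [depth, width, height, channel] array, so the transforms agree.
x = np.random.rand(2, 16, 16, 3)
out_pos = np.fft.rfft2(x, axes=(1, 2))
out_neg = np.fft.rfft2(x, axes=(-3, -2))
assert np.allclose(out_pos, out_neg)
```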

)
self._Hadj = torch.conj(self._H)
self._padded_data = torch.zeros(size=self._padded_shape, dtype=dtype, device=psf.device)

else:
self._H = fft.rfft2(self._pad(self._psf), axes=(0, 1), norm=norm)
self._H = fft.rfft2(self._pad(self._psf), axes=(1, 2), norm=norm)
same as above

conv_output = torch.fft.ifftshift(
torch.fft.irfft2(
torch.fft.rfft2(self._padded_data, dim=(0, 1)) * self._H, dim=(0, 1)
torch.fft.rfft2(self._padded_data, dim=(1, 2)) * self._H, dim=(1, 2)
same as above

),
dim=(0, 1),
dim=(1, 2),
same as above

conv_output = fft.ifftshift(
fft.irfft2(fft.rfft2(self._padded_data, axes=(0, 1)) * self._H, axes=(0, 1)),
axes=(0, 1),
fft.irfft2(fft.rfft2(self._padded_data, axes=(1, 2)) * self._H, axes=(1, 2)),
same as above

fft.irfft2(fft.rfft2(self._padded_data, axes=(0, 1)) * self._H, axes=(0, 1)),
axes=(0, 1),
fft.irfft2(fft.rfft2(self._padded_data, axes=(1, 2)) * self._H, axes=(1, 2)),
axes=(1, 2),
same as above

deconv_output = torch.fft.ifftshift(
torch.fft.irfft2(
torch.fft.rfft2(self._padded_data, dim=(0, 1)) * self._Hadj, dim=(0, 1)
torch.fft.rfft2(self._padded_data, dim=(1, 2)) * self._Hadj, dim=(1, 2)
same as above

),
dim=(0, 1),
dim=(1, 2),
same as above

deconv_output = fft.ifftshift(
fft.irfft2(fft.rfft2(self._padded_data, axes=(0, 1)) * self._Hadj, axes=(0, 1)),
axes=(0, 1),
fft.irfft2(fft.rfft2(self._padded_data, axes=(1, 2)) * self._Hadj, axes=(1, 2)),
same as above

fft.irfft2(fft.rfft2(self._padded_data, axes=(0, 1)) * self._Hadj, axes=(0, 1)),
axes=(0, 1),
fft.irfft2(fft.rfft2(self._padded_data, axes=(1, 2)) * self._Hadj, axes=(1, 2)),
axes=(1, 2),
same as above
