-
Notifications
You must be signed in to change notification settings - Fork 56
/
Copy pathholography.py
436 lines (340 loc) · 11.6 KB
/
holography.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
"""Core features for manipulating optical fields using Fourier transforms and
propagation matrices.
This module includes operations to simulate optical field propagation and
perform transformations in the frequency domain. These features can be combined
in processing pipelines for optical simulations and holographic
reconstructions.
Key Features
------------
- **Optical Field Processing**
Provides Fourier transforms, rescaling, and wavefront propagation for
complex-valued optical fields, handling both real and imaginary components.
- **Fourier Optics and Wave Propagation**
Implements Fourier transforms with optional padding for accurate
frequency-domain analysis and propagation matrices to simulate free-space
wavefront propagation with spatial and frequency domain shifts.
- **Phase & Amplitude Manipulation**
Enables scaling, normalization, and modulation of phase and amplitude to
preserve intensity distribution and enhance wavefront reconstruction.
Module Structure
----------------
Classes:
- `Rescale`:
Rescales an optical field by scaling its imaginary part and the
deviation of its real part from 1 by a common factor.
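- `FourierTransform`:
Computes the 2D Fourier transform of a two-channel (real, imaginary)
optical field after symmetric padding.
- `InverseFourierTransform`:
Computes the inverse 2D Fourier transform, removes the padding, and
returns a two-channel (real, imaginary) field.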
- `FourierTransformTransformation`:
Applies a power of the forward or inverse
propagation matrix to an optical field.
Functions:
- `get_propagation_matrix`
def get_propagation_matrix(
shape: tuple[int, int],
to_z: float,
pixel_size: float,
wavelength: float,
dx: float = 0,
dy: float = 0
) -> np.ndarray
Computes the propagation matrix.
Examples
--------
Simulate optical field propagation with Fourier transforms:
>>> import deeptrack as dt
>>> import numpy as np
Define a random optical field:
>>> field = np.random.rand(128, 128, 2)
Rescale the field and compute the Fourier transform:
>>> rescale_op = dt.holography.Rescale(0.5)
>>> scaled_field = rescale_op(field)
>>> ft_op = dt.holography.FourierTransform()
>>> transformed_field = ft_op(scaled_field)
Reconstruct the field using the inverse Fourier transform:
>>> ift_op = dt.holography.InverseFourierTransform()
>>> reconstructed_field = ift_op(transformed_field)
"""
from __future__ import annotations
from typing import Any
from deeptrack.image import maybe_cupy, Image
from deeptrack import Feature
import numpy as np
def get_propagation_matrix(
    shape: tuple[int, int],
    to_z: float,
    pixel_size: float,
    wavelength: float,
    dx: float = 0,
    dy: float = 0
) -> np.ndarray:
    """Computes the propagation matrix for simulating the propagation of an
    optical field.

    The propagation matrix models free-space wavefront propagation using the
    angular spectrum method. It is meant to be multiplied element-wise with
    the output of `np.fft.fft2`, so it is returned in standard (unshifted)
    FFT frequency ordering, with the zero-frequency component at index
    ``[0, 0]``.

    Parameters
    ----------
    shape: tuple[int, int]
        The dimensions of the optical field (height, width). Any trailing
        entries are ignored.
    to_z: float
        Propagation distance along the z-axis.
    pixel_size: float
        The physical size of each pixel in the optical field.
    wavelength: float
        The wavelength of the optical field, in the same length unit as
        `pixel_size`, `to_z`, `dx` and `dy`.
    dx: float, optional
        Lateral shift in the x-direction (default: 0).
    dy: float, optional
        Lateral shift in the y-direction (default: 0).

    Returns
    -------
    np.ndarray
        A complex-valued 2D array of shape (height, width) representing the
        propagation matrix (possibly a CuPy array, depending on
        `maybe_cupy`).

    Notes
    -----
    - Frequencies are built on a centered grid and then moved to FFT
      ordering with `np.fft.ifftshift`. For even sizes this is identical to
      `np.fft.fftshift`; for odd sizes only `ifftshift` places the zero
      frequency at index 0.
    - Evanescent components (``kx**2 + ky**2 >= k**2``) are set to zero.
    - The phase is ``exp(i k (z (K - 1) - dx kx / k - dy ky / k))`` with
      ``K = sqrt(1 - (kx / k)**2 - (ky / k)**2)``; the constant
      ``exp(i k z)`` factor is removed by the ``- 1`` term.
    """
    k = 2 * np.pi / wavelength
    yr, xr, *_ = shape

    # Centered integer frequency indices; for odd sizes the `(n % 2) / 2`
    # term makes the grid symmetric around 0.
    x = np.arange(0, xr, 1) - xr / 2 + (xr % 2) / 2
    y = np.arange(0, yr, 1) - yr / 2 + (yr % 2) / 2

    # Convert indices to angular spatial frequencies.
    x = 2 * np.pi / pixel_size * x / xr
    y = 2 * np.pi / pixel_size * y / yr
    KX, KY = np.meshgrid(x, y)

    # Pass-band mask of propagating (non-evanescent) waves. Computed on the
    # real-valued grids so no ordering comparison of complex numbers is
    # needed.
    C = maybe_cupy(((KX / k) ** 2 + (KY / k) ** 2 < 1) * 1.0)

    KXk = maybe_cupy(KX.astype(complex))
    KYk = maybe_cupy(KY.astype(complex))

    # The complex sqrt keeps the evanescent region finite (its real part is
    # 0 there); those entries are zeroed by C anyway.
    K = np.real(np.sqrt(1 - (KXk / k) ** 2 - (KYk / k) ** 2))

    phase = np.exp(k * 1j * (to_z * (K - 1) - dx * KXk / k - dy * KYk / k))

    # Move from centered ordering to FFT ordering (zero frequency at [0, 0]).
    # ifftshift is the correct inverse of centering for odd sizes as well.
    return np.fft.ifftshift(C * phase)
class Rescale(Feature):
    """Rescales an optical field by modifying its real and imaginary
    components.

    The field is given as an array whose last axis holds the real part
    (index 0) and the imaginary part (index 1). The transformation is:

    - The real part is shifted and scaled: `(real - 1) * rescale + 1`
    - The imaginary part is scaled by `rescale`

    Taken together, this maps the complex field ``E`` to
    ``1 + rescale * (E - 1)``, i.e. it scales the field's deviation from 1.

    Parameters
    ----------
    rescale: float
        The scaling factor applied to the imaginary part and to the
        deviation of the real part from 1 (default: 1, i.e. no change).

    Methods
    -------
    `get(image: Image | np.ndarray, rescale: float, **kwargs: dict[str, Any]) -> Image | np.ndarray`
        Returns a rescaled copy of the image.

    Examples
    --------
    >>> import deeptrack as dt
    >>> import numpy as np
    >>> field = np.random.rand(128, 128, 2)
    >>> rescaled_field = dt.holography.Rescale(0.5)(field)

    """

    def __init__(self, rescale: float = 1, **kwargs: Any):
        super().__init__(rescale=rescale, **kwargs)

    def get(
        self: Rescale,
        image: Image | np.ndarray,
        rescale: float,
        **kwargs: dict[str, Any],
    ) -> Image | np.ndarray:
        """Rescales the field's deviation from 1 by `rescale`.

        Parameters
        ----------
        image: Image or ndarray
            The field to rescale; the last axis holds the real part at
            index 0 and the imaginary part at index 1.
        rescale: float
            The rescaling factor.
        **kwargs: dict of str to Any
            Additional keyword arguments (unused).

        Returns
        -------
        Image or ndarray
            A rescaled copy of the image; the input is not modified.

        """
        # np.array makes a copy, so the caller's array is left untouched.
        image = np.array(image)

        # Integer/bool arrays cannot hold the scaled values (the in-place
        # multiplication below would raise a casting error), so promote them.
        if not np.issubdtype(image.dtype, np.inexact):
            image = image.astype(float)

        image[..., 0] = (image[..., 0] - 1) * rescale + 1
        image[..., 1] *= rescale
        return image
class FourierTransform(Feature):
    """Fourier transform of a two-channel optical field, with symmetric
    padding.

    The input's last axis holds the real part (index 0) and the imaginary
    part (index 1). These are combined into one complex field, which is
    padded by mirroring its borders and then transformed with a 2D FFT.

    Parameters
    ----------
    padding: int, optional
        Number of mirrored pixels added on every side of the field
        (default: 32).

    Methods
    -------
    `get(image: Image | np.ndarray, padding: int, **kwargs: dict[str, Any]) -> np.ndarray`
        Returns the 2D FFT of the padded complex field.

    Returns
    -------
    np.ndarray
        A complex array of shape (height + 2 * padding, width + 2 * padding).

    Notes
    -----
    - The transform is computed with `np.fft.fft2`, so the zero frequency is
      at index ``[0, 0]`` (no shift is applied).
    - Mirror ("symmetric") padding reduces edge discontinuities.

    Examples
    --------
    >>> import deeptrack as dt
    >>> import numpy as np
    >>> field = np.random.rand(128, 128, 2)
    >>> spectrum = dt.holography.FourierTransform()(field)
    >>> # with the default padding of 32, spectrum has shape (192, 192)

    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get(
        self: FourierTransform,
        image: Image | np.ndarray,
        padding: int = 32,
        **kwargs: dict[str, Any],
    ) -> np.ndarray:
        """Pads the complex field symmetrically and returns its 2D FFT.

        Parameters
        ----------
        image: Image or ndarray
            Field with real part at ``[..., 0]`` and imaginary part at
            ``[..., 1]``.
        padding: int, optional
            Mirrored pixels added on each side of both axes (default: 32).
        **kwargs: dict of str to Any
            Additional keyword arguments (unused).

        Returns
        -------
        np.ndarray
            The Fourier transform of the padded field.

        """
        real_part = image[..., 0]
        imag_part = image[..., 1]

        # np.copy also converts wrapper types (e.g. Image) into a plain array.
        complex_field = np.copy(real_part + 1j * imag_part)

        border = (padding, padding)
        padded = np.pad(complex_field, (border, border), mode="symmetric")

        return np.fft.fft2(padded)
class InverseFourierTransform(Feature):
    """Computes the inverse Fourier transform of a padded spectrum and
    returns the cropped field as two real channels.

    This is the counterpart of `FourierTransform`. It applies `np.fft.ifft2`,
    removes `padding` pixels from each side of both axes, and stores the
    real and imaginary parts in the last axis (indices 0 and 1).

    Parameters
    ----------
    padding: int, optional
        Number of pixels removed from each side after the inverse transform
        (default: 32). This should match the padding used by
        `FourierTransform`.

    Methods
    -------
    `get(image: Image | np.ndarray, padding: int, **kwargs: dict[str, Any]) -> np.ndarray`
        Inverse-transforms the spectrum and removes the padding.

    Returns
    -------
    np.ndarray
        A real array of shape (height - 2 * padding, width - 2 * padding, 2).

    Examples
    --------
    >>> import deeptrack as dt
    >>> import numpy as np
    >>> field = np.random.rand(128, 128, 2)
    >>> spectrum = dt.holography.FourierTransform()(field)
    >>> reconstructed = dt.holography.InverseFourierTransform()(spectrum)
    >>> # with the default padding of 32, reconstructed has shape (128, 128, 2)

    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get(
        self: InverseFourierTransform,
        image: Image | np.ndarray,
        padding: int = 32,
        **kwargs: dict[str, Any],
    ) -> Image | np.ndarray:
        """Computes the inverse Fourier transform and removes padding.

        Parameters
        ----------
        image: Image or ndarray
            The 2D complex spectrum to transform.
        padding: int, optional
            Number of pixels removed from each side after the inverse
            transformation (default is 32). A value of 0 keeps the full
            field.
        **kwargs: dict of str to Any
            Additional keyword arguments (unused).

        Returns
        -------
        np.ndarray
            The cropped field, with the real part at ``[..., 0]`` and the
            imaginary part at ``[..., 1]``.

        """
        im = np.fft.ifft2(image)

        # Explicit stop indices: `im[0:-0]` is empty, so a negative-index
        # slice would break for padding == 0.
        rows = slice(padding, im.shape[0] - padding)
        cols = slice(padding, im.shape[1] - padding)
        cropped = im[rows, cols]

        imnew = np.zeros(
            (image.shape[0] - padding * 2, image.shape[1] - padding * 2, 2)
        )
        imnew[..., 0] = np.real(cropped)
        imnew[..., 1] = np.imag(cropped)
        return imnew
class FourierTransformTransformation(Feature):
    """Applies a power of the forward or inverse propagation matrix to an
    optical field.

    The field (typically a Fourier-domain spectrum) is multiplied
    element-wise by ``Tz ** i`` for ``i >= 0``, or by ``Tzinv ** abs(i)``
    for ``i < 0``. This simulates repeated propagation steps.

    Parameters
    ----------
    Tz: ndarray
        Forward propagation matrix.
    Tzinv: ndarray
        Inverse propagation matrix.
    i: int
        Power of the propagation matrix to apply. Negative values apply the
        inverse.

    Methods
    -------
    `get(image: Image | np.ndarray, Tz: np.ndarray, Tzinv: np.ndarray, i: int, **kwargs: dict[str, Any]) -> Image | np.ndarray`
        Applies the power of the propagation matrix to the image.

    Returns
    -------
    Image | np.ndarray
        The transformed image (a new array; the input is not modified).

    Examples
    --------
    >>> import deeptrack as dt
    >>> import numpy as np
    >>> Tz = np.exp(1j * np.random.rand(128, 128))
    >>> Tzinv = 1 / Tz
    >>> field = np.fft.fft2(np.random.rand(128, 128))
    >>> transformed_field = dt.holography.FourierTransformTransformation(
    ...     Tz, Tzinv, i=2,
    ... )(field)

    """

    def __init__(self, Tz, Tzinv, i, **kwargs):
        super().__init__(Tz=Tz, Tzinv=Tzinv, i=i, **kwargs)

    def get(
        self: FourierTransformTransformation,
        image: Image | np.ndarray,
        Tz: np.ndarray,
        Tzinv: np.ndarray,
        i: int,
        **kwargs: dict[str, Any],
    ) -> Image | np.ndarray:
        """Applies the power of the propagation matrix to the image.

        Parameters
        ----------
        image: Image or ndarray
            The image to transform; must broadcast against `Tz`/`Tzinv`.
        Tz: np.ndarray
            Forward propagation matrix.
        Tzinv: np.ndarray
            Inverse propagation matrix.
        i: int
            Power of the propagation matrix to apply. Negative values apply
            the inverse.
        **kwargs: dict of str to Any
            Additional keyword arguments (unused).

        Returns
        -------
        Image or ndarray
            ``image * Tz ** i`` (or ``image * Tzinv ** abs(i)`` if i < 0).

        """
        if i < 0:
            operator = Tzinv ** abs(i)
        else:
            operator = Tz ** i

        # Out-of-place multiply: does not mutate the caller's array, and lets
        # a real-valued input be promoted to complex instead of raising.
        return image * operator