Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 154 additions & 12 deletions deeptrack/optical/optics.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@
- `_pad_volume(volume, limits, padding, output_region, **kwargs)`
Pads a volume with zeros to avoid edge effects during imaging.

- `_merge_placed_volumes(contrast_volumes, contrast_limits)`
Merges multiple placed volumes into a single volume based on their positions.

Examples
--------
>>> import deeptrack as dt
Expand Down Expand Up @@ -373,19 +376,37 @@ def get(
if isinstance(scatterer, ScatteredField)
]

# Merge all volumes into a single volume.
sample_volume, limits = _create_volume(
volume_samples,
**additional_sample_kwargs,
)

if volume_samples:
# Interpret the merged volume semantically
sample_volume = self._extract_contrast_volume(
ScatteredVolume(
array=sample_volume,
properties=volume_samples[0].properties,
),
contrast_volumes = []
contrast_limits = []

for scatterer in volume_samples:
placed, limits_i = _create_volume(
[scatterer],
**additional_sample_kwargs,
)

if limits_i is None:
continue

contrast_i = self._extract_contrast_volume(
ScatteredVolume(
array=placed,
properties=scatterer.properties,
)
)

contrast_volumes.append(contrast_i)
contrast_limits.append(limits_i)

sample_volume, limits = _merge_placed_volumes(
contrast_volumes,
contrast_limits,
)
else:
sample_volume, limits = _create_volume(
volume_samples,
**additional_sample_kwargs,
)

# Let the objective know about the limits of the volume and all the fields.
Expand Down Expand Up @@ -1296,6 +1317,7 @@ def extract_contrast_volume(
itself.

"""

scale = np.asarray(get_active_scale(), float)
scale_volume = np.prod(scale)

Expand Down Expand Up @@ -3791,3 +3813,123 @@ def _create_volume(
if limits is not None:
limits = torch.as_tensor(limits, dtype=torch.int32, device=device)
return volume, limits


# This could be refactored into _create_volume, but it is cleaner to keep it
# separate for now.
def _merge_placed_volumes(
    volumes: list[np.ndarray | torch.Tensor],
    limits_list: list[np.ndarray | torch.Tensor]
) -> tuple[np.ndarray | torch.Tensor, np.ndarray | torch.Tensor | None]:

    """Sum individually placed volumes into one global volume.

    Each input volume has already been positioned by
    ``_create_volume([scatterer])`` and carries its own spatial limits. This
    function computes the bounding box covering every volume, allocates a
    zero-filled array of that size, and adds each volume into its slot. The
    per-scatterer merge (rather than merging during initial volume creation)
    lets each scatterer retain its own autograd graph under the torch
    backend.

    Parameters
    ----------
    volumes : list[np.ndarray | torch.Tensor]
        Volumes already placed by _create_volume([scatterer]).
    limits_list : list[np.ndarray | torch.Tensor]
        Corresponding limits for each volume.

    Returns
    -------
    merged : np.ndarray | torch.Tensor
        The merged volume containing all input volumes positioned according
        to their limits.
    global_limits : np.ndarray | torch.Tensor | None
        Array of shape (3, 2) with the global bounds as
        [[x_min, x_max], [y_min, y_max], [z_min, z_max]], or `None` when no
        volumes were given.

    """

    if not volumes:
        return np.zeros((1, 1, 1)), None

    backend = config.get_backend()

    # Limits are integer geometry; they never need gradients, so strip any
    # torch wrapping up front.
    def _limits_as_numpy(lim):
        if TORCH_AVAILABLE and isinstance(lim, torch.Tensor):
            return lim.detach().cpu().numpy()
        return np.asarray(lim)

    limits_np = [_limits_as_numpy(lim) for lim in limits_list]

    # Bounding box that encloses every placed volume.
    global_limits = np.zeros((3, 2), dtype=np.int32)
    global_limits[:, 0] = np.min([lim[:, 0] for lim in limits_np], axis=0)
    global_limits[:, 1] = np.max([lim[:, 1] for lim in limits_np], axis=0)

    shape = (global_limits[:, 1] - global_limits[:, 0]).astype(int)

    if backend == "torch":
        # Borrow device/dtype from the first torch volume, if any.
        reference = next(
            (
                v
                for v in volumes
                if TORCH_AVAILABLE and isinstance(v, torch.Tensor)
            ),
            None,
        )
        if reference is not None:
            device, dtype = reference.device, reference.dtype
        else:
            device, dtype = torch.device("cpu"), torch.float32

        merged = torch.zeros(tuple(shape), dtype=dtype, device=device)

        for volume, lim in zip(volumes, limits_np):
            if not isinstance(volume, torch.Tensor):
                volume = torch.as_tensor(volume, dtype=dtype, device=device)

            start = lim[:, 0] - global_limits[:, 0]
            stop = start + np.asarray(volume.shape)
            region = (
                slice(start[0], stop[0]),
                slice(start[1], stop[1]),
                slice(start[2], stop[2]),
            )
            # Out-of-place addition keeps each scatterer's autograd graph
            # intact.
            merged[region] = merged[region] + volume

        return merged, torch.as_tensor(
            global_limits,
            dtype=torch.int32,
            device=device,
        )

    # NumPy backend: plain in-place accumulation, promoted to the common
    # dtype of all inputs.
    dtype = np.result_type(*(np.asarray(v).dtype for v in volumes))
    merged = np.zeros(tuple(shape), dtype=dtype)

    for volume, lim in zip(volumes, limits_np):
        volume = np.asarray(volume)
        start = lim[:, 0] - global_limits[:, 0]
        stop = start + np.asarray(volume.shape)

        merged[
            start[0] : stop[0],
            start[1] : stop[1],
            start[2] : stop[2],
        ] += volume

    return merged, global_limits
90 changes: 45 additions & 45 deletions tests/test_dlcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,31 +761,31 @@ def test_5_A(self):
>> normalization
)

expected_1 = np.array(
[[[-0.00362469], [ 0.11237531], [ 0.23637531], [ 0.32037531], [ 0.38037531], [ 0.34437531], [ 0.17637531], [ 0.16837531]],
[[ 0.07237531], [ 0.15237531], [ 0.28837531], [ 0.48837531], [ 0.46837531], [ 0.48437531], [ 0.46037531], [ 0.30837531]],
[[ 0.04837531], [ 0.12437531], [ 0.32837531], [ 0.43237531], [ 0.63637531], [ 0.65237531], [ 0.56437531], [ 0.37237531]],
[[ 0.02037531], [ 0.07637531], [ 0.16437531], [ 0.26837531], [ 0.57237531], [ 0.60837531], [ 0.67637531], [ 0.45237531]],
[[ 0.04437531], [ 0.02437531], [ 0.20437531], [ 0.30437531], [ 0.47237531], [ 0.54037531], [ 0.63237531], [ 0.40037531]],
[[ 0.07237531], [ 0.12437531], [ 0.24837531], [ 0.22037531], [ 0.38037531], [ 0.39237531], [ 0.42037531], [ 0.31237531]],
[[ 0.02837531], [ 0.12037531], [ 0.22837531], [ 0.33237531], [ 0.31637531], [ 0.22037531], [ 0.19637531], [ 0.19237531]],
[[-0.01962469], [ 0.08037531], [ 0.16037531], [ 0.22437531], [ 0.27237531], [ 0.14037531], [ 0.10437531], [ 0.03637531]]]
)
expected_1 = np.array([
[[-0.04551237], [ 0.03848763], [ 0.11448763], [ 0.17048763], [ 0.23448763], [ 0.23448763], [ 0.11048763], [ 0.11848763]],
[[ 0.02248763], [ 0.06648763], [ 0.14648763], [ 0.29448763], [ 0.29448763], [ 0.35048763], [ 0.36648763], [ 0.24248763]],
[[ 0.00248763], [ 0.04648763], [ 0.19048763], [ 0.26248763], [ 0.45848763], [ 0.50648763], [ 0.45848763], [ 0.29448763]],
[[-0.02151237], [ 0.00648763], [ 0.08248763], [ 0.19048763], [ 0.30248763], [ 0.58648763], [ 0.50648763], [ 0.43048763]],
[[-0.01751237], [ 0.04248763], [ 0.04248763], [ 0.26248763], [ 0.37048763], [ 0.50248763], [ 0.46648763], [ 0.41848763]],
[[-0.01751237], [ 0.09848763], [ 0.18248763], [ 0.31448763], [ 0.23848763], [ 0.35048763], [ 0.29448763], [ 0.25048763]],
[[ 0.01048763], [ 0.05448763], [ 0.19448763], [ 0.27848763], [ 0.30648763], [ 0.22648763], [ 0.10248763], [ 0.07448763]],
[[ 0.01848763], [-0.01351237], [ 0.12648763], [ 0.18648763], [ 0.17848763], [ 0.16648763], [ 0.03448763], [ 0.02248763]]
])

with self._suppress_expected_optics_warnings():
np.testing.assert_allclose(sim_im_pip(), expected_1,
rtol=1e-7, atol=1e-7)

expected_2 = np.array(
[[[ 0.05024189], [ 0.03024189], [ 0.05824189], [ 0.13024189], [ 0.07824189], [ 0.12224189], [ 0.13424189], [ 0.11824189]],
[[ 0.08224189], [ 0.02624189], [ 0.08624189], [ 0.11024189], [ 0.11024189], [ 0.13824189], [ 0.11824189], [ 0.16224189]],
[[ 0.09024189], [ 0.05424189], [ 0.04224189], [ 0.04224189], [ 0.06624189], [ 0.15824189], [ 0.11424189], [ 0.04224189]],
[[ 0.09024189], [ 0.00624189], [ 0.05424189], [ 0.05424189], [ 0.05424189], [ 0.05024189], [ 0.01424189], [ 0.02624189]],
[[-0.00575811], [ 0.02224189], [ 0.03424189], [ 0.04224189], [ 0.07424189], [ 0.00624189], [ 0.03424189], [ 0.01824189]],
[[-0.02175811], [-0.00575811], [ 0.01024189], [ 0.03024189], [ 0.05024189], [ 0.05424189], [ 0.08224189], [ 0.07024189]],
[[ 0.04624189], [-0.04575811], [-0.00175811], [ 0.02624189], [ 0.05424189], [ 0.12224189], [ 0.15024189], [ 0.11424189]],
[[-0.02175811], [-0.01775811], [-0.01375811], [-0.02175811], [ 0.04624189], [ 0.18624189], [ 0.22624189], [ 0.19024189]]]
)
expected_2 = np.array([
[[0.05257224], [0.05257224], [0.08457224], [0.05657224], [0.12057224], [0.12057224], [0.10857224], [0.12057224]],
[[0.13257224], [0.13657224], [0.12857224], [0.12857224], [0.06457224], [0.06057224], [0.15657224], [0.19657224]],
[[0.19257224], [0.23657224], [0.18057224], [0.12057224], [0.08857224], [0.09657224], [0.20057224], [0.16057224]],
[[0.26857224], [0.25257224], [0.30457224], [0.17657224], [0.13257224], [0.14057224], [0.25257224], [0.26057224]],
[[0.46457224], [0.57657224], [0.36857224], [0.20857224], [0.11657224], [0.12057224], [0.20457224], [0.20457224]],
[[0.51257224], [0.50857224], [0.40457224], [0.28457224], [0.18457224], [0.20457224], [0.18457224], [0.20457224]],
[[0.53657224], [0.52457224], [0.37257224], [0.23657224], [0.08057224], [0.16057224], [0.12457224], [0.08457224]],
[[0.29657224], [0.28457224], [0.29657224], [0.21657224], [0.11657224], [0.11657224], [0.08457224], [0.08857224]]
])

with self._suppress_expected_optics_warnings():
np.testing.assert_allclose(sim_im_pip.update()(), expected_2,
Expand Down Expand Up @@ -976,14 +976,14 @@ def random_ellipse_axes():
sim_im_pip = optics(ellipse)

# Checks
expected_image = np.array(
[[[0.60265415], [0.94844141], [1.14489087], [1.16483931], [1.13598992], [0.90247759]],
[[1.199768 ], [1.51251191], [1.74839492], [1.77029627], [1.72956925], [1.42921194]],
[[1.73096144], [1.825617 ], [1.87179117], [1.87245093], [1.84518863], [1.74890568]],
[[1.77330325], [1.85308512], [1.87854141], [1.87606849], [1.82860277], [1.74103692]],
[[1.53892305], [1.76151488], [1.79291875], [1.77124261], [1.55013829], [1.30663407]],
[[1.02576262], [1.2719972 ], [1.3016064 ], [1.27185945], [0.99481222], [0.63890969]]]
)
expected_image = np.array([
[[0.60265415], [0.94844141], [1.14489087], [1.16483931], [1.13598992], [0.90247759]],
[[1.199768 ], [1.51251191], [1.74839492], [1.77029627], [1.72956925], [1.42921194]],
[[1.73096144], [1.825617 ], [1.87179117], [1.87245093], [1.84518863], [1.74890568]],
[[1.77330325], [1.85308512], [1.87854141], [1.87606849], [1.82860277], [1.74103692]],
[[1.53892305], [1.76151488], [1.79291875], [1.77124261], [1.55013829], [1.30663407]],
[[1.02576262], [1.2719972 ], [1.3016064 ], [1.27185945], [0.99481222], [0.63890969]]
])
with self._suppress_expected_optics_warnings():
image = sim_im_pip()
try: # Occasional error in Ubuntu system
Expand Down Expand Up @@ -1028,15 +1028,15 @@ def random_ellipse_axes():
sim_im_pip = optics(synthetic_nuclei)

# Checks
expected_image = np.array(
[[[2.40686748], [3.57908632], [4.82880076], [5.75153091], [6.06963462], [5.57287094]],
[[3.3805992 ], [4.90299411], [6.20417148], [6.90634308], [6.83577577], [6.16933536]],
[[4.18260712], [5.97708425], [7.23448674], [7.48806701], [6.93908065], [5.96343732]],
[[4.27119652], [6.1665758 ], [7.29078817], [7.50901958], [6.86948897], [5.63460567]],
[[3.87612061], [5.88381024], [6.76433577], [7.00694866], [6.62352318], [5.28112149]],
[[3.07807345], [5.21008639], [6.18438896], [6.43448107], [6.07102741], [4.76105099]]]
)

expected_image = np.array([
[[2.33252681], [3.65196645], [5.13174265], [6.27625839], [6.73811592], [6.26292382]],
[[3.16183944], [4.93678219], [6.52890183], [7.46090439], [7.4992717 ], [6.85308454]],
[[3.86326879], [5.90470511], [7.44244168], [7.90248938], [7.47137925], [6.53303733]],
[[3.98699975], [6.06383981], [7.33834636], [7.70781358], [7.17165326], [5.98941416]],
[[3.67619447], [5.79844907], [6.74164257], [7.04195223], [6.70675367], [5.38089255]],
[[2.97939729], [5.17502092], [6.15118462], [6.3836291 ], [5.94495374], [4.62389848]]
])
with self._suppress_expected_optics_warnings():
image = sim_im_pip()
try: # Occasional error in Ubuntu system
Expand Down Expand Up @@ -1106,14 +1106,14 @@ def random_ellipse_axes():
sim_im_pip = optics(noisy_synthetic_nuclei)

# Checks
expected_image = np.array(
[[[1.93167944], [2.69410402], [3.66954369], [4.37636897], [4.48323595], [4.12289828]],
[[2.53519759], [3.60325565], [4.58956314], [5.15477629], [5.08360439], [4.53870126]],
[[3.21864851], [4.44624058], [5.32791401], [5.62321336], [5.41770116], [4.70481539]],
[[3.46683641], [4.70335513], [5.51074196], [5.77536735], [5.50595722], [4.68637212]],
[[3.34183827], [4.5430821 ], [5.33049864], [5.58676063], [5.30614662], [4.38580553]],
[[2.96852351], [4.1349709 ], [4.83801129], [4.96868391], [4.6222409 ], [3.84192146]]]
)
expected_image = np.array([
[[1.90183535], [2.764915 ], [3.92393155], [4.77404766], [4.93428988], [4.57357807]],
[[2.4628794 ], [3.68992591], [4.84933298], [5.54432313], [5.53135723], [4.98969865]],
[[3.13080826], [4.5246192 ], [5.5582536 ], [5.95485969], [5.80014392], [5.09092007]],
[[3.38026124], [4.74919891], [5.67728695], [6.0135335 ], [5.77598116], [4.95524912]],
[[3.24843652], [4.54458406], [5.4266058 ], [5.73338516], [5.44874544], [4.50070778]],
[[2.87934957], [4.11712322], [4.88513823], [5.04821763], [4.65348891], [3.82045886]]
])

with self._suppress_expected_optics_warnings():
image = sim_im_pip()
Expand Down