Skip to content

Commit c76f57e

Browse files
merged changes from the apdiocyte challenge
1 parent ddb29ea commit c76f57e

File tree

15 files changed

+488
-267
lines changed

15 files changed

+488
-267
lines changed

deeptrack/aberrations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@
3333
"""
3434

3535
import numpy as np
36-
from deeptrack.features import Feature
37-
from deeptrack.utils import as_list
36+
from .features import Feature
37+
from .utils import as_list
3838

3939

4040
class Aberration(Feature):

deeptrack/augmentations.py

Lines changed: 96 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,16 @@
1919
Flips images diagonally.
2020
"""
2121

22-
from deeptrack.features import Feature
23-
from deeptrack.image import Image
22+
from .features import Feature
23+
from .image import Image
24+
from . import utils
25+
2426
import numpy as np
25-
from typing import Callable
27+
import scipy.ndimage as ndimage
2628
from scipy.ndimage.interpolation import map_coordinates
2729
from scipy.ndimage.filters import gaussian_filter
30+
31+
from typing import Callable
2832
import warnings
2933

3034

@@ -72,7 +76,7 @@ def __init__(
7276
**kwargs
7377
):
7478

75-
if load_size is not 1:
79+
if load_size != 1:
7680
warnings.warn(
7781
"Using an augmentation with a load size other than one is no longer supported",
7882
DeprecationWarning,
@@ -115,6 +119,7 @@ def _process_and_get(self, *args, update_properties=None, **kwargs):
115119
not hasattr(self, "cache")
116120
or kwargs["update_tally"] - self.last_update >= kwargs["updates_per_reload"]
117121
):
122+
118123
if isinstance(self.feature, list):
119124
self.cache = [feature.resolve() for feature in self.feature]
120125
else:
@@ -151,11 +156,15 @@ def _process_and_get(self, *args, update_properties=None, **kwargs):
151156
]
152157
)
153158
else:
154-
new_list_of_lists.append(
155-
Image(self.get(Image(image_list), **kwargs)).merge_properties_from(
156-
image_list
157-
)
158-
)
159+
# DANGEROUS
160+
# if not isinstance(image_list, Image):
161+
image_list = Image(image_list)
162+
163+
output = self.get(image_list, **kwargs)
164+
165+
if not isinstance(output, Image):
166+
output = Image(output)
167+
new_list_of_lists.append(output.merge_properties_from(image_list))
159168

160169
if update_properties:
161170
if not isinstance(new_list_of_lists, list):
@@ -252,7 +261,10 @@ def update_properties(self, image, number_of_updates, **kwargs):
252261
for prop in image.properties:
253262
if "position" in prop:
254263
position = prop["position"]
255-
new_position = (image.shape[0] - position[0] - 1, *position[1:])
264+
new_position = (
265+
image.shape[0] - position[0] - 1,
266+
*position[1:],
267+
)
256268
prop["position"] = new_position
257269

258270

@@ -279,13 +291,6 @@ def update_properties(self, image, number_of_updates, **kwargs):
279291
prop["position"] = new_position
280292

281293

282-
from deeptrack.utils import get_kwarg_names
283-
import warnings
284-
285-
import scipy.ndimage as ndimage
286-
import deeptrack.utils as utils
287-
288-
289294
class Affine(Augmentation):
290295
"""
291296
Augmenter to apply affine transformations to images.
@@ -386,7 +391,9 @@ def get(self, image, scale, translate, rotate, shear, **kwargs):
386391

387392
assert (
388393
image.ndim == 2 or image.ndim == 3
389-
), "Affine only supports 2-dimensional or 3-dimension inputs."
394+
), "Affine only supports 2-dimensional or 3-dimension inputs, got {0}".format(
395+
image.ndim
396+
)
390397

391398
dx, dy = translate
392399
fx, fy = scale
@@ -551,7 +558,10 @@ def get(self, image, sigma, alpha, ignore_last_dim, **kwargs):
551558
for dim in shape:
552559
deltas.append(
553560
gaussian_filter(
554-
(np.random.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0
561+
(np.random.rand(*shape) * 2 - 1),
562+
sigma,
563+
mode="constant",
564+
cval=0,
555565
)
556566
* alpha
557567
)
@@ -619,6 +629,8 @@ def get(self, image, corner, crop, crop_mode, **kwargs):
619629
if isinstance(crop, int):
620630
crop = (crop,) * image.ndim
621631

632+
crop = [c if c is not None else image.shape[i] for i, c in enumerate(crop)]
633+
622634
# Get amount to crop from image
623635
if crop_mode == "retain":
624636
crop_amount = np.array(image.shape) - np.array(crop)
@@ -631,12 +643,9 @@ def get(self, image, corner, crop, crop_mode, **kwargs):
631643
crop_amount = np.amax((np.array(crop_amount), [0] * image.ndim), axis=0)
632644
crop_amount = np.amin((np.array(image.shape) - 1, crop_amount), axis=0)
633645
# Get corner of crop
634-
if corner == "random":
646+
if isinstance(corner, str) and corner == "random":
635647
# Ensure seed is consistent
636-
slice_start = np.random.randint(
637-
[0] * crop_amount.size,
638-
crop_amount + 1,
639-
)
648+
slice_start = [np.random.randint(m + 1) for m in crop_amount]
640649
elif callable(corner):
641650
slice_start = corner(image)
642651
else:
@@ -654,6 +663,7 @@ def get(self, image, corner, crop, crop_mode, **kwargs):
654663
for slice_start_i, slice_end_i in zip(slice_start, slice_end)
655664
]
656665
)
666+
657667
cropped_image = image[slices]
658668

659669
# Update positions
@@ -729,15 +739,74 @@ def __init__(self, px=(0, 0, 0, 0), mode="constant", cval=0, **kwargs):
729739
def get(self, image, px, **kwargs):
730740

731741
padding = []
732-
if isinstance(px, int):
742+
if callable(px):
743+
px = px(image)
744+
elif isinstance(px, int):
733745
padding = [(px, px)] * image.ndom
746+
734747
for idx in range(0, len(px), 2):
735748
padding.append((px[idx], px[idx + 1]))
736749

737750
while len(padding) < image.ndim:
738751
padding.append((0, 0))
739752

740-
return utils.safe_call(np.pad, positional_args=(image, padding), **kwargs)
753+
return (
754+
utils.safe_call(np.pad, positional_args=(image, padding), **kwargs),
755+
padding,
756+
)
757+
758+
def _process_and_get(self, images, **kwargs):
759+
results = [self.get(image, **kwargs) for image in images]
760+
for idx, result in enumerate(results):
761+
if isinstance(result, tuple):
762+
shape = result[0].shape
763+
padding = result[1]
764+
de_pad = tuple(
765+
slice(p[0], shape[dim] - p[1]) for dim, p in enumerate(padding)
766+
)
767+
results[idx] = (
768+
Image(result[0]).merge_properties_from(images[idx]),
769+
{"undo_padding": de_pad},
770+
)
771+
else:
772+
Image(results[idx]).merge_properties_from(images[idx])
773+
return results
774+
775+
776+
class PadToMultiplesOf(Pad):
777+
"""Pad images until their height/width is a multiple of a value.
778+
779+
Parameters
780+
----------
781+
multiple : int or tuple of (int or None)
782+
Images will be padded until their width is a multiple of
783+
this value. If a tuple, it is assumed to be a multiple per axis.
784+
A value of None or -1 indicates to skip that axis.
785+
786+
"""
787+
788+
def __init__(self, multiple=1, **kwargs):
789+
def amount_to_pad(image):
790+
shape = image.shape
791+
multiple = self.multiple.current_value
792+
793+
if not isinstance(multiple, (list, tuple, np.ndarray)):
794+
multiple = (multiple,) * image.ndim
795+
new_shape = [0] * (image.ndim * 2)
796+
idx = 0
797+
for dim, mul in zip(shape, multiple):
798+
if mul is not None and mul is not -1:
799+
to_add = -dim % mul
800+
to_add_first = to_add // 2
801+
to_add_after = to_add - to_add_first
802+
new_shape[idx * 2] = to_add_first
803+
new_shape[idx * 2 + 1] = to_add_after
804+
805+
idx += 1
806+
807+
return new_shape
808+
809+
super().__init__(multiple=multiple, px=lambda: amount_to_pad, **kwargs)
741810

742811

743-
# TODO: add resizing by rescaling
812+
# TODO: add resizing by rescaling

0 commit comments

Comments
 (0)