1919 Flips images diagonally.
2020"""
2121
22- from deeptrack .features import Feature
23- from deeptrack .image import Image
22+ from .features import Feature
23+ from .image import Image
24+ from . import utils
25+
2426import numpy as np
25- from typing import Callable
27+ import scipy . ndimage as ndimage
2628from scipy .ndimage .interpolation import map_coordinates
2729from scipy .ndimage .filters import gaussian_filter
30+
31+ from typing import Callable
2832import warnings
2933
3034
@@ -72,7 +76,7 @@ def __init__(
7276 ** kwargs
7377 ):
7478
75- if load_size is not 1 :
79+ if load_size != 1 :
7680 warnings .warn (
7781 "Using an augmentation with a load size other than one is no longer supported" ,
7882 DeprecationWarning ,
@@ -115,6 +119,7 @@ def _process_and_get(self, *args, update_properties=None, **kwargs):
115119 not hasattr (self , "cache" )
116120 or kwargs ["update_tally" ] - self .last_update >= kwargs ["updates_per_reload" ]
117121 ):
122+
118123 if isinstance (self .feature , list ):
119124 self .cache = [feature .resolve () for feature in self .feature ]
120125 else :
@@ -151,11 +156,15 @@ def _process_and_get(self, *args, update_properties=None, **kwargs):
151156 ]
152157 )
153158 else :
154- new_list_of_lists .append (
155- Image (self .get (Image (image_list ), ** kwargs )).merge_properties_from (
156- image_list
157- )
158- )
159+ # DANGEROUS
160+ # if not isinstance(image_list, Image):
161+ image_list = Image (image_list )
162+
163+ output = self .get (image_list , ** kwargs )
164+
165+ if not isinstance (output , Image ):
166+ output = Image (output )
167+ new_list_of_lists .append (output .merge_properties_from (image_list ))
159168
160169 if update_properties :
161170 if not isinstance (new_list_of_lists , list ):
@@ -252,7 +261,10 @@ def update_properties(self, image, number_of_updates, **kwargs):
252261 for prop in image .properties :
253262 if "position" in prop :
254263 position = prop ["position" ]
255- new_position = (image .shape [0 ] - position [0 ] - 1 , * position [1 :])
264+ new_position = (
265+ image .shape [0 ] - position [0 ] - 1 ,
266+ * position [1 :],
267+ )
256268 prop ["position" ] = new_position
257269
258270
@@ -279,13 +291,6 @@ def update_properties(self, image, number_of_updates, **kwargs):
279291 prop ["position" ] = new_position
280292
281293
282- from deeptrack .utils import get_kwarg_names
283- import warnings
284-
285- import scipy .ndimage as ndimage
286- import deeptrack .utils as utils
287-
288-
289294class Affine (Augmentation ):
290295 """
291296 Augmenter to apply affine transformations to images.
@@ -386,7 +391,9 @@ def get(self, image, scale, translate, rotate, shear, **kwargs):
386391
387392 assert (
388393 image .ndim == 2 or image .ndim == 3
389- ), "Affine only supports 2-dimensional or 3-dimension inputs."
394+ ), "Affine only supports 2-dimensional or 3-dimension inputs, got {0}" .format (
395+ image .ndim
396+ )
390397
391398 dx , dy = translate
392399 fx , fy = scale
@@ -551,7 +558,10 @@ def get(self, image, sigma, alpha, ignore_last_dim, **kwargs):
551558 for dim in shape :
552559 deltas .append (
553560 gaussian_filter (
554- (np .random .rand (* shape ) * 2 - 1 ), sigma , mode = "constant" , cval = 0
561+ (np .random .rand (* shape ) * 2 - 1 ),
562+ sigma ,
563+ mode = "constant" ,
564+ cval = 0 ,
555565 )
556566 * alpha
557567 )
@@ -619,6 +629,8 @@ def get(self, image, corner, crop, crop_mode, **kwargs):
619629 if isinstance (crop , int ):
620630 crop = (crop ,) * image .ndim
621631
632+ crop = [c if c is not None else image .shape [i ] for i , c in enumerate (crop )]
633+
622634 # Get amount to crop from image
623635 if crop_mode == "retain" :
624636 crop_amount = np .array (image .shape ) - np .array (crop )
@@ -631,12 +643,9 @@ def get(self, image, corner, crop, crop_mode, **kwargs):
631643 crop_amount = np .amax ((np .array (crop_amount ), [0 ] * image .ndim ), axis = 0 )
632644 crop_amount = np .amin ((np .array (image .shape ) - 1 , crop_amount ), axis = 0 )
633645 # Get corner of crop
634- if corner == "random" :
646+ if isinstance ( corner , str ) and corner == "random" :
635647 # Ensure seed is consistent
636- slice_start = np .random .randint (
637- [0 ] * crop_amount .size ,
638- crop_amount + 1 ,
639- )
648+ slice_start = [np .random .randint (m + 1 ) for m in crop_amount ]
640649 elif callable (corner ):
641650 slice_start = corner (image )
642651 else :
@@ -654,6 +663,7 @@ def get(self, image, corner, crop, crop_mode, **kwargs):
654663 for slice_start_i , slice_end_i in zip (slice_start , slice_end )
655664 ]
656665 )
666+
657667 cropped_image = image [slices ]
658668
659669 # Update positions
@@ -729,15 +739,74 @@ def __init__(self, px=(0, 0, 0, 0), mode="constant", cval=0, **kwargs):
729739 def get (self , image , px , ** kwargs ):
730740
731741 padding = []
732- if isinstance (px , int ):
742+ if callable (px ):
743+ px = px (image )
744+ elif isinstance (px , int ):
733745 padding = [(px , px )] * image .ndom
746+
734747 for idx in range (0 , len (px ), 2 ):
735748 padding .append ((px [idx ], px [idx + 1 ]))
736749
737750 while len (padding ) < image .ndim :
738751 padding .append ((0 , 0 ))
739752
740- return utils .safe_call (np .pad , positional_args = (image , padding ), ** kwargs )
753+ return (
754+ utils .safe_call (np .pad , positional_args = (image , padding ), ** kwargs ),
755+ padding ,
756+ )
757+
758+ def _process_and_get (self , images , ** kwargs ):
759+ results = [self .get (image , ** kwargs ) for image in images ]
760+ for idx , result in enumerate (results ):
761+ if isinstance (result , tuple ):
762+ shape = result [0 ].shape
763+ padding = result [1 ]
764+ de_pad = tuple (
765+ slice (p [0 ], shape [dim ] - p [1 ]) for dim , p in enumerate (padding )
766+ )
767+ results [idx ] = (
768+ Image (result [0 ]).merge_properties_from (images [idx ]),
769+ {"undo_padding" : de_pad },
770+ )
771+ else :
772+ Image (results [idx ]).merge_properties_from (images [idx ])
773+ return results
774+
775+
776+ class PadToMultiplesOf (Pad ):
777+ """Pad images until their height/width is a multiple of a value.
778+
779+ Parameters
780+ ----------
781+ multiple : int or tuple of (int or None)
782+ Images will be padded until their width is a multiple of
783+ this value. If a tuple, it is assumed to be a multiple per axis.
784+ A value of None or -1 indicates to skip that axis.
785+
786+ """
787+
788+ def __init__ (self , multiple = 1 , ** kwargs ):
789+ def amount_to_pad (image ):
790+ shape = image .shape
791+ multiple = self .multiple .current_value
792+
793+ if not isinstance (multiple , (list , tuple , np .ndarray )):
794+ multiple = (multiple ,) * image .ndim
795+ new_shape = [0 ] * (image .ndim * 2 )
796+ idx = 0
797+ for dim , mul in zip (shape , multiple ):
798+ if mul is not None and mul is not - 1 :
799+ to_add = - dim % mul
800+ to_add_first = to_add // 2
801+ to_add_after = to_add - to_add_first
802+ new_shape [idx * 2 ] = to_add_first
803+ new_shape [idx * 2 + 1 ] = to_add_after
804+
805+ idx += 1
806+
807+ return new_shape
808+
809+ super ().__init__ (multiple = multiple , px = lambda : amount_to_pad , ** kwargs )
741810
742811
743- # TODO: add resizing by rescaling
812+ # TODO: add resizing by rescaling
0 commit comments