Skip to content

Commit

Permalink
Fix typing
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 committed May 30, 2024
1 parent 488be6d commit 7b78453
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions kornia/augmentation/container/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
_MSK_OPTIONS = {DataKey.MASK}
_CLS_OPTIONS = {DataKey.CLASS, DataKey.LABEL}

MaskDataType = Union[Tensor, List[Tensor]]


class AugmentationSequential(TransformMatrixMinIn, ImageSequential):
r"""AugmentationSequential for handling multiple input types like inputs, masks, keypoints at once.
Expand Down Expand Up @@ -340,10 +342,16 @@ def _arguments_preproc(self, *args: DataType, data_keys: List[DataKey]) -> List[
inp: List[DataType] = []
for arg, dcate in zip(args, data_keys):
if DataKey.get(dcate) in _IMG_OPTIONS:
arg = cast(Tensor, arg)
self.input_dtype = arg.dtype
inp.append(arg)
elif DataKey.get(dcate) in _MSK_OPTIONS:
self.mask_dtype = arg[0].dtype if isinstance(inp, list) else arg.dtype
if isinstance(inp, list):
arg = cast(List[Tensor], arg)
self.mask_dtype = arg[0].dtype
else:
arg = cast(Tensor, arg)
self.mask_dtype = arg.dtype
inp.append(self._preproc_mask(arg))
elif DataKey.get(dcate) in _KEYPOINTS_OPTIONS:
inp.append(self._preproc_keypoints(arg, dcate))
Expand All @@ -365,7 +373,7 @@ def _arguments_postproc(
out.append(out_arg)
# TODO: may add the float to integer (for masks), etc.
elif DataKey.get(dcate) in _MSK_OPTIONS:
_out_k = self._postproc_mask(out_arg)
_out_k = self._postproc_mask(cast(MaskDataType, out_arg))
out.append(_out_k)

elif DataKey.get(dcate) in _KEYPOINTS_OPTIONS:
Expand Down Expand Up @@ -486,7 +494,7 @@ def retrieve_key(key: str) -> DataKey:

return [DataKey.get(retrieve_key(k)) for k in keys]

def _preproc_mask(self, arg: DataType) -> Tensor:
def _preproc_mask(self, arg: MaskDataType) -> MaskDataType:
if isinstance(arg, list):
new_arg = []
for a in arg:
Expand All @@ -498,7 +506,7 @@ def _preproc_mask(self, arg: DataType) -> Tensor:
arg = arg.to(self.input_dtype) if self.input_dtype else arg.to(torch.float)
return arg

def _postproc_mask(self, arg: DataType) -> Tensor:
def _postproc_mask(self, arg: MaskDataType) -> MaskDataType:
if isinstance(arg, list):
new_arg = []
for a in arg:
Expand Down

0 comments on commit 7b78453

Please sign in to comment.