AugmentationSequential: accept sample dict as input #2119

Closed
adamjstewart opened this issue Dec 28, 2022 · 11 comments · Fixed by #2799
Labels
help wanted Extra attention is needed

Comments

@adamjstewart
Contributor

adamjstewart commented Dec 28, 2022

🚀 Feature

I would like AugmentationSequential to support a dictionary as input.

Motivation

I'm a TorchGeo developer. In TorchGeo, every dataset returns a sample dictionary like so:

sample = {
    "input": torch.tensor(...),
    "mask": torch.tensor(...),
    "bbox": torch.tensor(...),
    ...
}

(the exact key names don't match at the moment, but we're working on standardizing those)

Pitch

With the feature I'm envisioning, the following would work:

augs = AugmentationSequential(...)
sample = augs(sample)

The exact implementation details would still need to be worked out, but *args would go from Tensor to Union[Tensor, Dict[str, Tensor]]. The dictionary may contain keys that Kornia doesn't know how to support, and these should be ignored. If a sample dictionary contains a known key that the user doesn't want to transform, they can simply pass data_keys to override the default detection. If the input is a dict, the output should also be a dict. If implemented correctly, this feature will be backwards compatible with the old behavior so people can still pass these inputs in manually if they want to.
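A minimal sketch of the dict-in/dict-out behavior described above, written as a standalone helper rather than as a change to AugmentationSequential itself. `apply_to_sample`, `KNOWN_KEYS`, and the `augment(*tensors, data_keys=...)` calling convention are hypothetical names chosen for illustration, and plain numbers stand in for tensors so the sketch runs without PyTorch or Kornia.

```python
from typing import Any, Callable, Dict, Optional, Sequence

# Hypothetical set of keys the augmentation knows how to transform
# (illustrative only; not Kornia's actual DataKey list).
KNOWN_KEYS = ("input", "mask", "bbox", "keypoints")


def apply_to_sample(
    augment: Callable[..., Any],
    sample: Dict[str, Any],
    data_keys: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """Augment the recognized entries of `sample` and return a new dict."""
    if data_keys is None:
        # Default detection: every known key present in the sample, in KNOWN_KEYS order.
        data_keys = [key for key in KNOWN_KEYS if key in sample]
    outputs = augment(*(sample[key] for key in data_keys), data_keys=list(data_keys))
    if not isinstance(outputs, (list, tuple)):
        # A single input may come back as a single (non-tuple) output.
        outputs = (outputs,)
    result = dict(sample)  # keys the augmentation doesn't know are passed through untouched
    result.update(zip(data_keys, outputs))
    return result


def double_all(*inputs: Any, data_keys: Sequence[str]) -> tuple:
    """Stand-in augmentation: doubles every input it receives."""
    return tuple(x * 2 for x in inputs)


sample = {"input": 1, "mask": 3, "filename": "tile.tif"}
augmented = apply_to_sample(double_all, sample)
# Passing data_keys overrides detection, so "mask" is left untouched here.
image_only = apply_to_sample(double_all, sample, data_keys=["input"])
```

Copying the dict before updating it leaves the caller's original sample unchanged, while unknown entries such as "filename" simply ride along.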

Alternatives

At the moment, to use Kornia augmentations, we have to use:

augs = AugmentationSequential(..., data_keys=["input", "mask", "bbox", ...])
sample["input"], sample["mask"], sample["bbox"], ... = augs(sample["input"], sample["mask"], sample["bbox"], ...)

As you can see, this is much more verbose than necessary. There's no reason we need to duplicate the list of keys so many times.

Additional context

If this is something you would be interested in, I would be happy to submit a PR to support this. Just wanted to gauge interest first before working on it.

@isaaccorley

@adamjstewart adamjstewart added the help wanted Extra attention is needed label Dec 28, 2022
@shijianjian
Member

It looks acceptable if you have a decent solution, like using override. If we accept dicts, I would like to support something like:

augs = AugmentationSequential(..., data_keys=None)  # so that users don't need to pass data_keys in the first place
sample = augs(sample)

Another question: if we have multiple masks/bboxes for the input, how should we organize the input dict?

@adamjstewart
Contributor Author

adamjstewart commented Dec 29, 2022

It looks acceptable if you have a decent solution, like using override.

Do you mean "overload" instead of "override"? Python doesn't support overloading functions, but you can overload their type hints. It should be possible for the function to be backwards compatible and still support dicts.
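For illustration, a sketch of what overloading the type hints can look like with `typing.overload`: type checkers see two signatures, while Python runs a single implementation that dispatches at runtime. The function `augment` and its value-doubling body are placeholders, and `float` stands in for `torch.Tensor` so the example runs on its own.

```python
from typing import Dict, Tuple, Union, overload

# Stand-in for torch.Tensor so this sketch runs without PyTorch (assumption).
Tensor = float


@overload
def augment(sample: Dict[str, Tensor]) -> Dict[str, Tensor]: ...
@overload
def augment(*inputs: Tensor) -> Union[Tensor, Tuple[Tensor, ...]]: ...


def augment(*inputs):
    """Hypothetical entry point whose placeholder "augmentation" doubles each value."""
    # Only this body runs; the @overload stubs above exist for type checkers.
    if len(inputs) == 1 and isinstance(inputs[0], dict):
        # New behavior: dict in, dict out.
        return {key: value * 2 for key, value in inputs[0].items()}
    # Old positional behavior is unchanged: one input gives one output.
    outputs = tuple(x * 2 for x in inputs)
    return outputs[0] if len(outputs) == 1 else outputs
```

Because the positional path is untouched, existing callers keep working, which is the backwards compatibility described in the original pitch.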

Another question: if we have multiple masks/bboxes for the input, how should we organize the input dict?

I would not bother supporting this for dict input. If a user needs multiple input/mask/bbox, they can simply use the old syntax.

@shijianjian
Member

We can proceed with dict support using overload as our first attempt. Ideally, though, I would like to see something like a dict with duplicated keys, or the same key with different values. XD

The latter looks more doable, but I do not want to use list values (since a user can already pass a list of tensors, list values would be ambiguous).

@johnnv1
Member

johnnv1 commented Jan 2, 2023

Since we probably wouldn't have many duplicate values, I don't think it would be a performance problem to match any key with a known prefix, for example mask-1, mask-2, maskAnything, ... (i.e., match any mask* key).
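A rough sketch of the "match any mask*" idea: each user-supplied key is mapped to the first known base key it starts with, and unrecognized keys are skipped. `BASE_KEYS`, `match_data_key`, and `read_data_keys` are hypothetical names, and the base-key list is illustrative rather than Kornia's.

```python
from typing import Dict, Optional

# Hypothetical base keys (illustrative; not Kornia's actual DataKey enum).
BASE_KEYS = ("input", "mask", "bbox", "keypoints")


def match_data_key(key: str) -> Optional[str]:
    """Return the base key that `key` starts with, e.g. "mask-2" -> "mask", else None."""
    for base in BASE_KEYS:
        if key.startswith(base):
            return base
    return None


def read_data_keys(sample: Dict[str, object]) -> Dict[str, str]:
    """Map every recognized user key to its base key, preserving dict order."""
    matched = {}
    for key in sample:
        base = match_data_key(key)
        if base is not None:
            matched[key] = base
    return matched


sample = {"input": 0, "mask-1": 0, "mask-2": 0, "maskAnything": 0, "label": 0}
mapping = read_data_keys(sample)
```

A linear scan over a handful of base keys is cheap, in line with the performance point above. One caveat of plain prefix matching is that a key like "maskless_input" would also be read as a mask, so a separator convention such as "mask-..." may be safer.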

@shijianjian
Member

The match-any strategy looks reasonable to me. A user could then label their inputs as mask-taskA, mask-taskB, etc. That would make us more flexible than other libraries.

I will try to finish the first refactor, #2117, this week, then start implementing this feature on top of the new base. What do you think, @adamjstewart?

@edgarriba
Member

I also envision something like this in the medium term:

aug = AugmentationPipeline(...)
gen = DataGenerator(...)
net = MyModel(...)
loss = MyGeometricLoss(...)

gen.output_dict >> aug.input
(gen.image | gen.mask) >> aug.input  # or this to be more selective
aug.output >> net.image
(net.output | aug.transform_matrix) >> loss.input

This is roughly the strategy behind https://github.com/kornia/limbus, which is now under heavy refactor to fully support asyncio.

@adamjstewart idk if something like that would be interesting for you guys
/cc @lferraz
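The `>>` / `|` wiring sketched above relies on Python operator overloading (`__rshift__` and `__or__`). A toy illustration of that mechanism follows; `Port` and `PortGroup` are hypothetical classes, not the limbus API, and "sending" a value here just records it on each connected port.

```python
from typing import Any, List, Tuple, Union


class Port:
    """A named endpoint that can be wired to other ports with `>>`."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.received: List[Tuple[str, Any]] = []  # (source port name, value)
        self._targets: List["Port"] = []

    def __or__(self, other: Union["Port", "PortGroup"]) -> "PortGroup":
        # `a | b` bundles ports so they can be connected together.
        return PortGroup([self]) | other

    def __rshift__(self, target: "Port") -> "Port":
        # `a >> b` records a connection from a to b.
        self._targets.append(target)
        return target

    def send(self, value: Any) -> None:
        # Push a value to every connected port.
        for target in self._targets:
            target.received.append((self.name, value))


class PortGroup:
    """A bundle of ports created with `|`."""

    def __init__(self, ports: List[Port]) -> None:
        self.ports = list(ports)

    def __or__(self, other: Union[Port, "PortGroup"]) -> "PortGroup":
        extra = other.ports if isinstance(other, PortGroup) else [other]
        return PortGroup(self.ports + extra)

    def __rshift__(self, target: Port) -> Port:
        # Connect every bundled port to the same target.
        for port in self.ports:
            port >> target
        return target


image, mask, aug_input = Port("gen.image"), Port("gen.mask"), Port("aug.input")
(image | mask) >> aug_input
image.send("image tensor")
mask.send("mask tensor")
```

This sketch only records connections and pushes values synchronously; scheduling the components (for example with asyncio, as mentioned for limbus) is out of scope here.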

@adamjstewart
Contributor Author

I tried to take a stab at implementing this but supporting data_keys = None and supporting dicts as input/output of inverse/forward is actually quite a large refactor. I can try to force it in there, but it might be better for the person who originally designed it to redesign it with dicts in mind. Do any core developers have any interest in this, or should I try to redesign it myself and minimize any unrequired API changes?

@edgarriba
Member

I tried to take a stab at implementing this but supporting data_keys = None and supporting dicts as input/output of inverse/forward is actually quite a large refactor. I can try to force it in there, but it might be better for the person who originally designed it to redesign it with dicts in mind. Do any core developers have any interest in this, or should I try to redesign it myself and minimize any unrequired API changes?

That person is @shijianjian. We just did a huge refactor of the augmentations module, so it makes sense to keep improving it to support more features.

@shijianjian
Member

I think we can simply ignore the data_keys argument here and override data_keys at runtime instead:

augs = AugmentationSequential(...)
sample = augs({"image-a": imagea, "image-b": imageb})

In the implementation, the forward signature would change to:

def forward(
    self,
    *input: Union[PREVIOUS_TYPE, Dict[str, Tensor]],
    data_keys: Optional[List[str]] = None,
) -> Union[PREVIOUS_TYPE, Dict[str, Tensor]]:
    # Dict input: derive data_keys from the dict's own keys.
    if len(input) == 1 and isinstance(input[0], dict):
        if data_keys is not None:
            raise ValueError("data_keys cannot be passed together with a dict input")
        data_keys = read_datakeys_from_dict(input[0])

    input = self._preproc_dict(...)
    RUN_ITERATIONS_HERE
    output = self._postproc_dict(...)

    return output

This should be straightforward to implement.

@adamjstewart
Contributor Author

@shijianjian do you want to submit a PR for this, or should I try to hack on it? I have something with 10x as many lines of code changed, and it still doesn't work because I'm mapping dicts to lists but not yet mapping them back. It's not just data_keys that needs to change; transform_op (which uses data_keys) will also need to be set dynamically. I can open a draft PR if you want to see my current solution, but it's pretty ugly.

@shijianjian
Member

@adamjstewart I probably won't work on this, since it is not a critical feature. The only thing you need to implement is the data key mapping. You do not need to handle transform_op; it should read the input data_keys and perform the augmentations automatically.

def forward(
    self,
    *input: Union[PREVIOUS_TYPE, Dict[str, Tensor]],
    data_keys: Optional[List[str]] = None,
) -> Union[PREVIOUS_TYPE, Dict[str, Tensor]]:
    # Dict input: derive data_keys from the dict's own keys.
    if len(input) == 1 and isinstance(input[0], dict):
        if data_keys is not None:
            raise ValueError("data_keys cannot be passed together with a dict input")
        data_keys = read_datakeys_from_dict(input[0])

    input = self._preproc_dict(...)
    RUN_ITERATIONS_HERE
    output = self._postproc_dict(...)

    return output

As shown here, implementing read_datakeys_from_dict, _preproc_dict, and _postproc_dict should be enough to make it work.
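One possible sketch of those three helpers as standalone functions, assuming the base-suffix key naming discussed in the thread (e.g. "mask-taskA"). The signatures, `BASE_KEYS`, and the `_base_key` helper are assumptions for illustration; in the proposal they would be methods on AugmentationSequential, and plain numbers stand in for tensors.

```python
from typing import Any, Dict, List, Tuple

# Hypothetical base data keys (illustrative; not Kornia's DataKey enum).
BASE_KEYS = {"input", "mask", "bbox", "keypoints"}


def _base_key(key: str) -> str:
    # "mask-taskA" -> "mask"; a key without "-" is its own base.
    return key.split("-", 1)[0]


def read_datakeys_from_dict(sample: Dict[str, Any]) -> List[str]:
    """Base data key of every transformable entry, in dict order."""
    return [_base_key(key) for key in sample if _base_key(key) in BASE_KEYS]


def _preproc_dict(sample: Dict[str, Any]) -> Tuple[List[Any], List[str]]:
    """Flatten the dict into positional inputs, remembering the user's key names."""
    user_keys = [key for key in sample if _base_key(key) in BASE_KEYS]
    return [sample[key] for key in user_keys], user_keys


def _postproc_dict(
    sample: Dict[str, Any], outputs: List[Any], user_keys: List[str]
) -> Dict[str, Any]:
    """Map positional outputs back onto the original keys; other keys pass through."""
    result = dict(sample)
    result.update(zip(user_keys, outputs))
    return result


sample = {"input": 1, "mask-taskA": 2, "mask-taskB": 3, "filename": "tile.tif"}
data_keys = read_datakeys_from_dict(sample)  # base keys the positional pipeline would use
inputs, user_keys = _preproc_dict(sample)
outputs = [x * 10 for x in inputs]  # stand-in for RUN_ITERATIONS_HERE
augmented = _postproc_dict(sample, outputs, user_keys)
```

Keeping `user_keys` alongside the outputs is what makes mapping the positional results back to a dict possible, including distinct keys like "mask-taskA" and "mask-taskB" that share the same base key.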


4 participants