Extend shape support in DALI Dataset for TF #1723

Merged: 8 commits, Mar 5, 2020

@klecki klecki commented Feb 7, 2020

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>

Why we need this PR?

Improve the DALIDataset for TF.

What happened in this PR?

  • What solution was applied:
    Replace ShapeFn with ScalarShape (as other datasets do).
    Implement matching between the user-set shape and
    the one returned by DALI in the C++ API of the Dataset.
  • Affected modules and functionalities:
    DALIDataset for TF
  • Key points relevant for the review:
    Can the API be changed? I can revert or add a wrapper.
  • Validation and testing:
    A test was added, plus some additional local testing.
  • Documentation (including examples):
    Docs were adjusted; they now match the from_generator approach more closely.

JIRA TASK: [DALI-1259]

klecki commented Feb 7, 2020

!build

@dali-automaton

CI MESSAGE: [1115663]: BUILD STARTED

@dali-automaton

CI MESSAGE: [1115663]: BUILD FAILED

@@ -94,10 +94,6 @@ PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
RawStringFormats:
Contributor

?

Contributor Author

They changed something in clang 9 and I hoped to sneak it in.


#include "tensorflow/core/framework/tensor_shape.h"

// namespace dali-tf-impl {
Contributor

?

Contributor Author

Cleaned it up.

@@ -227,6 +229,8 @@ class DALIDatasetOp : public DatasetOpKernel {

Status GetNextInternal(IteratorContext *context, std::vector<Tensor> *out_tensors,
bool *end_of_sequence) override {
// static int called_times = 0;
Contributor

?

Contributor Author

Debug leftover

else:
return True

def _handle_deprecation(self, tuple_arg, list_arg, arg_name):
Contributor

?

Contributor Author

Removed.

(BATCH_SIZE)]
dtypes = [
shapes = (
(64, 24, 24),
Contributor

So you can remove IMAGE_SIZE, but I would stick to usage of BATCH_SIZE here as it is used in other places and needs to match.

Contributor Author

This is leftover from testing, will revert.

file_list_path = os.path.join(data_path, 'image_list.txt')

def dali_pipe_batch_1(shapes, types, as_single_tuple = False):
target_sizes = [(853, 1280, 3),
Contributor

I don't like this tight coupling between the content of DALI extra and the test itself.
I would run a standalone DALI pipeline next to it, get the actual dimensions from it, and then compare them with the ones returned by DALIDataset.
Otherwise any change to 'db/single/jpeg/' will break this test.

Contributor Author

Done.
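
The idea suggested in this thread (take the expected shapes from a reference run instead of hard-coding them in the test) could look roughly like the sketch below. It is only an illustration: `check_shapes_against_reference` and its arguments are hypothetical names, and the shapes are plain tuples or lists rather than outputs of a real DALI pipeline or DALIDataset.

```python
def check_shapes_against_reference(reference_shapes, dataset_shapes):
    """Compare shapes from two runs over the same data, sample by sample.

    Hypothetical helper: `reference_shapes` would come from a standalone
    pipeline run and `dataset_shapes` from the dataset under test.
    """
    reference_shapes = list(reference_shapes)
    dataset_shapes = list(dataset_shapes)
    assert len(reference_shapes) == len(dataset_shapes), \
        "the two runs produced a different number of outputs"
    for i, (expected, actual) in enumerate(zip(reference_shapes, dataset_shapes)):
        # Normalize to tuples so that lists and tuples compare equal.
        assert tuple(expected) == tuple(actual), \
            "sample {}: expected {}, got {}".format(i, tuple(expected), tuple(actual))
```

With this layout the test no longer encodes image sizes, so changing the test data does not require touching the test.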

@awolant awolant left a comment

Maybe it would be good to apply the same shape inference rules to daliop? If so, probably in the second PR though.

JanuszL commented Feb 12, 2020

> Maybe it would be good to apply the same shape inference rules to daliop? If so, probably in the second PR though.

Cannot agree more.

Replace ShapeFn with ScalarShape (as do other datasets).
Implement matching between user set shape and
the one returned by DALI.
Adjust docs.
Adjust order of arguments and naming
to fit with from_generator dataset arguments in TF.

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>

klecki commented Feb 24, 2020

!build

@dali-automaton

CI MESSAGE: [1147110]: BUILD STARTED

lgtm-com bot commented Feb 24, 2020

This pull request fixes 1 alert when merging c61cd08 into 9e08b4d - view on LGTM.com

fixed alerts:

  • 1 for Asserting a tuple

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>

klecki commented Feb 24, 2020

!build

@dali-automaton

CI MESSAGE: [1147136]: BUILD STARTED

lgtm-com bot commented Feb 24, 2020

This pull request fixes 1 alert when merging 6a8b2d8 into 9e08b4d - view on LGTM.com

fixed alerts:

  • 1 for Asserting a tuple

@dali-automaton

CI MESSAGE: [1147136]: BUILD PASSED

klecki commented Feb 26, 2020

!build

@dali-automaton

CI MESSAGE: [1152603]: BUILD STARTED

lgtm-com bot commented Feb 26, 2020

This pull request fixes 1 alert when merging 785e4d0 into e5ea7ef - view on LGTM.com

fixed alerts:

  • 1 for Asserting a tuple

@dali-automaton

CI MESSAGE: [1152603]: BUILD FAILED

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>

klecki commented Feb 27, 2020

!build

@dali-automaton

CI MESSAGE: [1154903]: BUILD STARTED

lgtm-com bot commented Feb 27, 2020

This pull request fixes 1 alert when merging 0a2e14e into 3b2b02f - view on LGTM.com

fixed alerts:

  • 1 for Asserting a tuple

* stored at position `n` in the pipeline.
* This function may only be called after
* calling Output function.
* @remarks Same as calling daliShapeAt(pipe_handle, n)
Contributor

what do you mean with same as calling... ?

Contributor Author

That it will give you the same result, because samples have the same type as TensorList

Contributor Author

Eh. It was supposed to be daliTypeAt. I'm blind.
Janusz wants to leave only one.

Contributor

Just an ask, not a strong blocker if you think that this makes the API consistent.

return False

def _handle_deprecation(self, supported_arg, deprecated_arg, name):
if deprecated_arg is not None:
Contributor

Suggested change
if deprecated_arg is not None:
if deprecated_arg is not None and supported_arg is not None:
raise ...

}

dali_data_type_t daliTypeAtSample(daliPipelineHandle* pipe_handle, int n) {
return daliTypeAt(pipe_handle, n);
Contributor

What is the value in having daliTypeAt?

Contributor Author

To be able to access the type. And for the API to be complete

Contributor

But it is only passing through arguments to daliTypeAtSample. So I don't see much value in having two functions.

Contributor Author

Damn, forgot to remove it from here.

return False

def _handle_deprecation(self, supported_arg, deprecated_arg, name):
if deprecated_arg is not None:
Contributor

Suggested change
if deprecated_arg is not None:
if deprecated_arg is not None and supported_arg is not None:
raise ...

Contributor Author

if deprecated_arg is not None and supported_arg is not None:
    raise ...
elif deprecated_arg is not None:
    warn ...

I'm not sure that it's any better.
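
For reference, the two-branch variant discussed in this thread could be written out as below. This is a minimal sketch, not the PR's code: it is a standalone function rather than a method, and the exception type, warning category, and message wording are assumptions made for illustration.

```python
import warnings


def handle_deprecation(supported_arg, deprecated_arg, name):
    """Return the value to use, preferring the supported argument.

    Hypothetical standalone version of `_handle_deprecation`; the error and
    warning types are assumptions, not taken from the PR.
    """
    if deprecated_arg is not None and supported_arg is not None:
        # Both spellings were passed: refuse to guess which one was meant.
        raise ValueError(
            "Pass `{}` only once, not together with its deprecated alias.".format(name))
    if deprecated_arg is not None:
        # Only the deprecated spelling was passed: accept it, but warn.
        warnings.warn(
            "A deprecated alias of `{}` was used.".format(name),
            DeprecationWarning, stacklevel=2)
        return deprecated_arg
    return supported_arg
```

The first branch turns a silent conflict into an error, while the second keeps old call sites working and only emits a warning.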

}
return Status::OK();
})
.SetShapeFn(shape_inference::ScalarShape)
Contributor

a comment here would help to understand (why are we setting to ScalarShape?)

Contributor Author

Because all the TF datasets are doing the same. I can add a comment: it's because the dataset basically returns an opaque pointer to a Tensor instead of the Tensor itself.

for (int i = 0; i < required_shape.dims(); i++) {
result.AddDim(0);
}
// Non-trivial batch size case, diffrent dims
Contributor

Suggested change
// Non-trivial batch size case, diffrent dims
// Non-trivial batch size case, different dims

Contributor Author

Done

}
}
int matches = CountShapeMatches(result, required_shape, dali_shape);
if (matches != 1) {
Contributor

Suggested change
if (matches != 1) {
if (matches > 1) {

don't you mean that?

Contributor Author

When there are 0 matches we still don't have a valid match.
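
To illustrate why the check is `!= 1` rather than `> 1`, here is a small Python sketch of counting alignments between a required (partially known) shape and a concrete shape. It is not the PR's C++ implementation. It assumes, for illustration only, that a negative required extent means "unknown" and that a dimension of extent 1 in the DALI shape may be skipped; the function names are hypothetical, and unlike `CountShapeMatches` the sketch only counts and does not build the resulting shape.

```python
def dim_size_match(required, dali):
    # A negative `required` extent is treated as unknown and matches anything.
    return required < 0 or required == dali


def count_shape_matches(required_shape, dali_shape, req_pos=0, dali_pos=0):
    """Count the ways `dali_shape` can be aligned with `required_shape`.

    Assumption for this sketch: a DALI dimension of extent 1 may be skipped.
    """
    if req_pos == len(required_shape) and dali_pos == len(dali_shape):
        return 1  # both shapes fully consumed: one valid alignment
    if dali_pos == len(dali_shape):
        return 0  # required dimensions left over with nothing to match
    total = 0
    if (req_pos < len(required_shape)
            and dim_size_match(required_shape[req_pos], dali_shape[dali_pos])):
        # Pair the current dimensions and continue with the rest.
        total += count_shape_matches(required_shape, dali_shape,
                                     req_pos + 1, dali_pos + 1)
    if dali_shape[dali_pos] == 1:
        # Alternatively, skip an extent-1 DALI dimension.
        total += count_shape_matches(required_shape, dali_shape,
                                     req_pos, dali_pos + 1)
    return total


print(count_shape_matches((-1, 24, 24), (64, 24, 24)))  # 1: unique match
print(count_shape_matches((-1,), (1, 1)))                # 2: ambiguous
print(count_shape_matches((3,), (4,)))                   # 0: no match
```

Under these assumptions both 0 (no alignment) and 2 or more (ambiguous alignment) have to be rejected, which is what a `matches != 1` test expresses.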



bool DimSizeMatch(int64_t required, int64_t dali) {
if (required < 0 || required == dali) {
Contributor

so basically:

return required < 0 || required == dali;

Contributor Author

I keep doing that. Done

return false;
}

int CountShapeMatches(TensorShape &result, const PartialTensorShape &required_shape, const TensorShape &dali_shape,
Contributor

seems convoluted enough that I'd like to see a bunch of unit tests covering this function

Contributor Author

I tried covering most of the cases with the dataset test itself. We don't have any non-Python unit tests for the TF plugin, and adding that functionality is out of scope for this PR; we would probably need to adjust the build system and the test jobs for it.

return false;
}

int CountShapeMatches(TensorShape &result, const PartialTensorShape &required_shape, const TensorShape &dali_shape,
Contributor

The name is a bit misleading, you are calculating a tensor shape, not only counting matches

@dali-automaton

CI MESSAGE: [1154903]: BUILD FAILED

Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>

klecki commented Feb 27, 2020

!build

@dali-automaton

CI MESSAGE: [1155927]: BUILD STARTED

lgtm-com bot commented Feb 27, 2020

This pull request fixes 1 alert when merging c2c5043 into b25864e - view on LGTM.com

fixed alerts:

  • 1 for Asserting a tuple

@dali-automaton

CI MESSAGE: [1155927]: BUILD PASSED

klecki commented Feb 28, 2020

!build

@dali-automaton

CI MESSAGE: [1158153]: BUILD STARTED

lgtm-com bot commented Feb 28, 2020

This pull request fixes 1 alert when merging 7983f23 into b48e757 - view on LGTM.com

fixed alerts:

  • 1 for Asserting a tuple

@dali-automaton

CI MESSAGE: [1158153]: BUILD PASSED

@klecki klecki merged commit 58a8fd8 into NVIDIA:master Mar 5, 2020
5 participants