Extend shape support in DALI Dataset for TF #1723
!build
CI MESSAGE: [1115663]: BUILD STARTED
CI MESSAGE: [1115663]: BUILD FAILED
@@ -94,10 +94,6 @@ PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
RawStringFormats:
?
They changed something in clang 9 and I hoped to sneak it in.
dali_tf_plugin/dali_shape_helper.h
Outdated
#include "tensorflow/core/framework/tensor_shape.h"

// namespace dali-tf-impl {
?
Cleaned it up.
dali_tf_plugin/dali_dataset_op.cc
Outdated
@@ -227,6 +229,8 @@ class DALIDatasetOp : public DatasetOpKernel {

  Status GetNextInternal(IteratorContext *context, std::vector<Tensor> *out_tensors,
                         bool *end_of_sequence) override {
    // static int called_times = 0;
?
Debug leftover
dali/python/nvidia/dali/plugin/tf.py
Outdated
    else:
        return True

def _handle_deprecation(self, tuple_arg, list_arg, arg_name):
?
Removed.
(BATCH_SIZE)]
dtypes = [
shapes = (
    (64, 24, 24),
So you can remove IMAGE_SIZE, but I would stick to usage of BATCH_SIZE here as it is used in other places and needs to match.
This is leftover from testing, will revert.
file_list_path = os.path.join(data_path, 'image_list.txt')

def dali_pipe_batch_1(shapes, types, as_single_tuple = False):
    target_sizes = [(853, 1280, 3),
I don't like this tight coupling between the content of DALI extra and the test itself.
I would run a standalone DALI pipeline next to it, get the actual dimensions from it, and then compare them with the ones returned by DALIDataset.
Otherwise any change to 'db/single/jpeg/' will break this test.
Done.
Maybe it would be good to apply the same shape inference rules to daliop? If so, probably in the second PR though.
Cannot agree more.
Replace ShapeFn with ScalarShape (as do other datasets). Implement matching between user set shape and the one returned by DALI. Adjust docs. Adjust order of arguments and naming to fit with from_generator dataset arguments in TF. Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
902718f to c61cd08
!build
CI MESSAGE: [1147110]: BUILD STARTED
This pull request fixes 1 alert when merging c61cd08 into 9e08b4d - view on LGTM.com fixed alerts:
Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
35ce032 to 6a8b2d8
!build
CI MESSAGE: [1147136]: BUILD STARTED
This pull request fixes 1 alert when merging 6a8b2d8 into 9e08b4d - view on LGTM.com fixed alerts:
CI MESSAGE: [1147136]: BUILD PASSED
!build
CI MESSAGE: [1152603]: BUILD STARTED
This pull request fixes 1 alert when merging 785e4d0 into e5ea7ef - view on LGTM.com fixed alerts:
CI MESSAGE: [1152603]: BUILD FAILED
Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
!build
CI MESSAGE: [1154903]: BUILD STARTED
This pull request fixes 1 alert when merging 0a2e14e into 3b2b02f - view on LGTM.com fixed alerts:
dali/c_api/c_api.h
Outdated
 * stored at position `n` in the pipeline.
 * This function may only be called after
 * calling Output function.
 * @remarks Same as calling daliShapeAt(pipe_handle, n)
What do you mean by "same as calling..."?
That it will give you the same result, because samples have the same type as TensorList
Eh. It was supposed to be daliTypeAt. I'm blind.
Janusz wants to leave only one.
Just an ask, not a strong blocker if you think that this makes the API consistent.
dali/c_api/c_api.cc
Outdated
}

dali_data_type_t daliTypeAtSample(daliPipelineHandle* pipe_handle, int n) {
  return daliTypeAt(pipe_handle, n);
What is the value in having daliTypeAt?
To be able to access the type. And for the API to be complete
But it is only passing through arguments to daliTypeAtSample. So I don't see much value in having two functions.
Damn, forgot to remove it from here.
        return False

    def _handle_deprecation(self, supported_arg, deprecated_arg, name):
        if deprecated_arg is not None:
Suggested change:
- if deprecated_arg is not None:
+ if deprecated_arg is not None and supported_arg is not None:
+     raise ...
if deprecated_arg is not None and supported_arg is not None:
    raise ...
elif deprecated_arg is not None:
    warn ...
I'm not sure that it's any better.
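For reference, the two-branch variant discussed above could look roughly like the sketch below. The diff shows only the signature and the `raise ...` / `warn ...` placeholders, so the exception type, the warning category, the message texts, the return value, and writing it as a free function instead of a method are all assumptions made for illustration.

```python
import warnings


def handle_deprecation(supported_arg, deprecated_arg, name):
    """Resolve a supported/deprecated argument pair (hypothetical sketch)."""
    if deprecated_arg is not None and supported_arg is not None:
        # Both spellings were passed: refuse to guess which one wins.
        # ValueError is an assumed choice; the review only says `raise ...`.
        raise ValueError(
            "Use only the supported argument '{}', "
            "not its deprecated form.".format(name))
    if deprecated_arg is not None:
        # Only the old spelling was passed: accept it, but warn.
        # DeprecationWarning is likewise an assumed category.
        warnings.warn(
            "The deprecated form of argument '{}' was used.".format(name),
            DeprecationWarning, stacklevel=2)
        return deprecated_arg
    # Either the supported argument was passed, or neither was (None).
    return supported_arg
```

With this shape, passing both arguments fails loudly, passing only the deprecated one still works but warns, and the caller always gets back a single resolved value.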
      }
      return Status::OK();
    })
    .SetShapeFn(shape_inference::ScalarShape)
a comment here would help to understand (why are we setting to ScalarShape?)
Because all the TF datasets are doing the same. I can add a comment, it's because the dataset is basically returning an opaque pointer to a Tensor instead of the Tensor itself.
dali_tf_plugin/dali_dataset_op.cc
Outdated
for (int i = 0; i < required_shape.dims(); i++) {
  result.AddDim(0);
}
// Non-trivial batch size case, diffrent dims
Suggested change:
- // Non-trivial batch size case, diffrent dims
+ // Non-trivial batch size case, different dims
Done
  }
}
int matches = CountShapeMatches(result, required_shape, dali_shape);
if (matches != 1) {
Suggested change:
- if (matches != 1) {
+ if (matches > 1) {
don't you mean that?
When there are 0 matches we still don't have a valid match.
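To illustrate why both 0 matches and more than 1 match have to be rejected, here is a minimal Python sketch of one way such a count could work. It is not the code from this PR. It assumes that a negative required dimension means "unknown", and that the only freedom in the matching is skipping DALI dimensions of extent 1 (for example a batch of size 1). The names `dim_size_match` and `count_shape_matches` only mirror the C++ helpers; the recursion itself is an assumption.

```python
def dim_size_match(required, dali):
    """A required extent matches if it is unknown (< 0) or equal."""
    return required < 0 or required == dali


def count_shape_matches(required_shape, dali_shape, req_pos=0, dali_pos=0):
    """Count the alignments of dali_shape onto required_shape.

    Assumed rule (hypothetical): every DALI dimension is either matched
    against the next required dimension or, if its extent is 1, skipped.
    """
    if dali_pos == len(dali_shape):
        # All DALI dims consumed: valid only if all required dims were used.
        return 1 if req_pos == len(required_shape) else 0
    matches = 0
    if (req_pos < len(required_shape)
            and dim_size_match(required_shape[req_pos], dali_shape[dali_pos])):
        # Option 1: pair the current DALI dim with the current required dim.
        matches += count_shape_matches(
            required_shape, dali_shape, req_pos + 1, dali_pos + 1)
    if dali_shape[dali_pos] == 1:
        # Option 2: drop a DALI dim of extent 1.
        matches += count_shape_matches(
            required_shape, dali_shape, req_pos, dali_pos + 1)
    return matches


print(count_shape_matches((64, 24, 24), (64, 24, 24)))  # 1: exact match
print(count_shape_matches((24, 24), (1, 24, 24)))       # 1: batch dim dropped
print(count_shape_matches((32,), (64,)))                # 0: no valid match
print(count_shape_matches((-1, 24), (1, 1, 24)))        # 2: ambiguous
```

Under these assumptions only a count of exactly 1 identifies a single output shape: 0 means the shapes are incompatible, and 2 or more means the requested shape does not say which extent-1 dimension to keep.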
dali_tf_plugin/dali_dataset_op.cc
Outdated

bool DimSizeMatch(int64_t required, int64_t dali) {
  if (required < 0 || required == dali) {
so basically:
return required < 0 || required == dali;
I keep doing that. Done
  return false;
}

int CountShapeMatches(TensorShape &result, const PartialTensorShape &required_shape, const TensorShape &dali_shape,
seems convoluted enough that I'd like to see a bunch of unit tests covering this function
I tried covering most of the cases with the dataset test itself. We don't have any non-Python unit tests for the TF plugin, and adding that functionality is out of scope for this PR: we would probably need to adjust the build system and the test jobs for this.
  return false;
}

int CountShapeMatches(TensorShape &result, const PartialTensorShape &required_shape, const TensorShape &dali_shape,
The name is a bit misleading, you are calculating a tensor shape, not only counting matches
CI MESSAGE: [1154903]: BUILD FAILED
Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
!build
CI MESSAGE: [1155927]: BUILD STARTED
This pull request fixes 1 alert when merging c2c5043 into b25864e - view on LGTM.com fixed alerts:
CI MESSAGE: [1155927]: BUILD PASSED
!build
CI MESSAGE: [1158153]: BUILD STARTED
This pull request fixes 1 alert when merging 7983f23 into b48e757 - view on LGTM.com fixed alerts:
CI MESSAGE: [1158153]: BUILD PASSED
Signed-off-by: Krzysztof Lecki <klecki@nvidia.com>
Why we need this PR?
Improve the DALIDataset for TF.
What happened in this PR?
Fill relevant points, put NA otherwise. Replace anything inside []
[
Replace ShapeFn with ScalarShape (as do other datasets).
Implement matching between the user-set shape and the one returned by DALI in the C++ API of the Dataset.
]
[ DALIDataset for TF ]
[
Can the API be changed? I can revert or add a wrapper.
]
[ A test was added, plus some additional local testing. ]
[ Docs were adjusted, now it matches the from_generator approach more. ]
JIRA TASK: [DALI-1259]