Skip to content

Commit

Permalink
Fix CHECK-fail due to passing invalid tensors in SobolOp.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 460794378
  • Loading branch information
mihaimaruseac authored and tensorflower-gardener committed Jul 13, 2022
1 parent 7d3ed0f commit c65c67f
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions tensorflow/core/kernels/sobol_op.cc
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "sobol_data.h" // from @sobol_data
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/platform_strings.h"

Expand Down Expand Up @@ -134,8 +135,14 @@ class SobolSampleOp : public OpKernel {
: OpKernel(context) {}

void Compute(OpKernelContext* context) override {
OP_REQUIRES(context, TensorShapeUtils::IsScalar(context->input(0).shape()),
errors::InvalidArgument("dim must be a scalar"));
int32_t dim = context->input(0).scalar<int32_t>()();
OP_REQUIRES(context, TensorShapeUtils::IsScalar(context->input(0).shape()),
errors::InvalidArgument("num_results must be a scalar"));
int32_t num_results = context->input(1).scalar<int32_t>()();
OP_REQUIRES(context, TensorShapeUtils::IsScalar(context->input(0).shape()),
errors::InvalidArgument("skip must be a scalar"));
int32_t skip = context->input(2).scalar<int32_t>()();

OP_REQUIRES(context, dim >= 1,
Expand Down

0 comments on commit c65c67f

Please sign in to comment.