Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 committed Dec 4, 2022
1 parent 48c0ae6 commit 9059fec
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions paddle/fluid/inference/tensorrt/convert/one_hot_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class OneHotOpConverter : public OpConverter {
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
#if IS_TRT_VERSION_GE(8510)
VLOG(3) << "convert a fluid one_hot op to tensorrt one_hot layer";
framework::OpDesc op_desc(op, nullptr);

Expand Down Expand Up @@ -64,19 +65,22 @@ class OneHotOpConverter : public OpConverter {
} else {
nvinfer1::Dims depth_dims;
depth_dims.nbDims = 0;
const nvinfer1::ITensor* depth_tensor_paddle =
nvinfer1::ITensor* depth_tensor_paddle =
engine_->GetITensor(depth_name.front());
auto shuffle_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *depth_tensor_paddle);
shuffle_layer->setReshapeDimensions(depth_dims);
shuffle_layer->getOutput(0)->setName(depth_tensor_paddle->getName());
depth_tensor = shuffle_layer->getOutput(0);
depth_tensor->setName(depth_tensor_paddle->getName());
}
auto layer = TRT_ENGINE_ADD_LAYER(
engine_, OneHot, *indices_tensor, *values_tensor, *depth_tensor, -1);

auto output_name = op_desc.Output("Out").front();
RreplenishLayerAndOutput(layer, "one_hot", {output_name}, test_mode);
#else
VLOG(3) << "one_hot is not supported when TensorRT < 8.5.1";
#endif
}
};

Expand Down

0 comments on commit 9059fec

Please sign in to comment.