diff --git a/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc index a2172f7be0c7f..cb5afb491fee5 100644 --- a/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/one_hot_op.cc @@ -36,6 +36,7 @@ class OneHotOpConverter : public OpConverter { void operator()(const framework::proto::OpDesc& op, const framework::Scope& scope, bool test_mode) override { +#if IS_TRT_VERSION_GE(8510) VLOG(3) << "convert a fluid one_hot op to tensorrt one_hot layer"; framework::OpDesc op_desc(op, nullptr); @@ -64,19 +65,22 @@ class OneHotOpConverter : public OpConverter { } else { nvinfer1::Dims depth_dims; depth_dims.nbDims = 0; - const nvinfer1::ITensor* depth_tensor_paddle = + nvinfer1::ITensor* depth_tensor_paddle = engine_->GetITensor(depth_name.front()); auto shuffle_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *depth_tensor_paddle); shuffle_layer->setReshapeDimensions(depth_dims); + shuffle_layer->getOutput(0)->setName(depth_tensor_paddle->getName()); depth_tensor = shuffle_layer->getOutput(0); - depth_tensor->setName(depth_tensor_paddle->getName()); } auto layer = TRT_ENGINE_ADD_LAYER( engine_, OneHot, *indices_tensor, *values_tensor, *depth_tensor, -1); auto output_name = op_desc.Output("Out").front(); RreplenishLayerAndOutput(layer, "one_hot", {output_name}, test_mode); +#else + VLOG(3) << "one_hot is not supported when TensorRT < 8.5.1"; +#endif } };