Skip to content

Commit

Permalink
Fix xla test case fail (#5203)
Browse files Browse the repository at this point in the history
* refactor SbpXXX to cfg::SbpXXX

* modify ParallelDistributionHint4InputArgNameAndIndex to be const function

* fix sbp to cfg::sbp in job_pass

* fix bug ToProto, InitFromProto and pb passed to cfg

* auto format by CI

* fix gpt segment fault

* fix xla

* tmp commit

* fix xla compile error

* fix test case fail

* auto format by CI

* refine

Co-authored-by: lixinqi <lixinqi0703106@163.com>
Co-authored-by: liufengwei <2472937968@qq.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
  • Loading branch information
5 people committed Jun 16, 2021
1 parent 3af1b27 commit 90d3277
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
5 changes: 4 additions & 1 deletion oneflow/xrt/launch_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,11 @@ xrt::Executable* XrtLaunchKernel<device_type>::BuildExecutable(
auto options = xrt::CreateDefaultXrtPassOptions();
xrt::util::PbMap<std::string, cfg::SbpSignature> cfg_sbp_signatures;
for (auto& pair : sbp_signatures) { cfg_sbp_signatures.insert({pair.first, pair.second}); }
const xrt::util::PbMap<std::string, cfg::SbpSignature>* const_cfg_sbp_signatures_ptr =
&cfg_sbp_signatures;
xrt::RunXrtPass("InferShape", graph.get(), options, &this->job_desc(), &parallel_ctx,
&parallel_desc, &sbp_signatures, &lbn2logical_blob_desc, &entry_blob_descs);
&parallel_desc, const_cfg_sbp_signatures_ptr, &lbn2logical_blob_desc,
&entry_blob_descs);
// Update argument meta data
// xrt::RunXrtPass("UpdateArgMetaData", graph.get(), options,
// &this->job_desc());
Expand Down
5 changes: 4 additions & 1 deletion oneflow/xrt/launch_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,11 @@ Maybe<void> XrtLaunchOp::InferOutBlobDescs(
DeviceType device_type = JUST(DeviceType4DeviceTag(op_conf().device_tag()));
auto graph = xrt::BuildXrtGraph(launch_conf.function(), device_type, GlobalJobDesc());
const ParallelDesc& op_parallel_desc = *JUST(GetOpParallelDesc());
const xrt::util::PbMap<std::string, cfg::SbpSignature>* const_cfg_sbp_signatures_ptr =
&cfg_sbp_signatures;
xrt::RunXrtPass("InferShape", graph.get(), options, &GlobalJobDesc(), parallel_ctx,
&op_parallel_desc, &cfg_sbp_signatures, &lbn2logical_blob_desc, &blob_descs);
&op_parallel_desc, const_cfg_sbp_signatures_ptr, &lbn2logical_blob_desc,
&blob_descs);
}

// Fetch output blob descs
Expand Down

0 comments on commit 90d3277

Please sign in to comment.