use sync threads (#785)
* fix sync threads

Co-authored-by: haozech <chenhaoze94@gmail.com>
2 people authored and zhhsplendid committed Jun 9, 2022
1 parent aab8c55 commit 1d034d4
Showing 4 changed files with 25 additions and 17 deletions.
35 changes: 21 additions & 14 deletions cinn/hlir/framework/op_lowering.cc
@@ -292,20 +292,30 @@ void OpLowerer::ReduceCompute(poly::StageMap& stages,

CHECK_GE(value_pack.size(), 2UL);
CHECK_LE(value_pack.size(), 5UL);
Expr out = value_pack[0];
poly::StageMap tmp_stages = value_pack.back();

std::string post = "";
for (int idx = 0; idx < value_pack.size() - 1; ++idx) {
Expr expr = value_pack[idx];
stages->InsertLazily(expr.as_tensor_ref(), tmp_stages[expr.as_tensor_ref()]);
tensor_map[node_data->id() + post] = expr.as_tensor_ref();
// As op may has more than 1 output tensor, using id + "_0"/"_1" as key.
post = "_" + std::to_string(idx);
}
value_pack.back() = CINNValue(stages);

// node is kCommReduce
if (op_pattern_dict[node->op()] == framework::kCommReduce) {
reducer = node;
// do schedule
value_pack = impl->fschedule(value_pack);
} else if (group->master_nodes.count(node)) {
Expr out = value_pack[0];
// node is master node, copy schedule from reduce node
if (reducer) {
auto reducer_data = GetNodeData(reducer);
tmp_stages[out.as_tensor_ref()]->CopyTransform(stages[tensor_map[reducer_data->id()]]);
tmp_stages[out.as_tensor_ref()]->CopyLoopInfo(stages[tensor_map[reducer_data->id()]]);
stages[out.as_tensor_ref()]->CopyTransform(stages[tensor_map[reducer_data->id()]]);
stages[out.as_tensor_ref()]->CopyLoopInfo(stages[tensor_map[reducer_data->id()]]);
} else {
bool copied_transform = false;
for (auto rnode : group->master_nodes) {
@@ -314,24 +324,15 @@ void OpLowerer::ReduceCompute(poly::StageMap& stages,
if (!tensor_map.count(rnode_data->id())) {
continue;
}
tmp_stages[out.as_tensor_ref()]->CopyTransform(stages[tensor_map[rnode_data->id()]]);
tmp_stages[out.as_tensor_ref()]->CopyLoopInfo(stages[tensor_map[rnode_data->id()]]);
stages[out.as_tensor_ref()]->CopyTransform(stages[tensor_map[rnode_data->id()]]);
stages[out.as_tensor_ref()]->CopyLoopInfo(stages[tensor_map[rnode_data->id()]]);
copied_transform = true;
break;
}
}
CHECK(copied_transform) << "master node fail to copy transfrom from reduce node!";
}
}

std::string post = "";
for (int idx = 0; idx < value_pack.size() - 1; ++idx) {
Expr expr = value_pack[idx];
stages->InsertLazily(expr.as_tensor_ref(), tmp_stages[expr.as_tensor_ref()]);
tensor_map[node_data->id() + post] = expr.as_tensor_ref();
// As op may has more than 1 output tensor, using id + "_0"/"_1" as key.
post = "_" + std::to_string(idx);
}
}
}
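
In the hunks above, the loop that registers each output tensor in tensor_map (and lazily inserts its stage into stages) now runs before the reduce/master branch, and value_pack.back() is replaced with the merged stages before impl->fschedule is invoked; that is why the schedule copies can target stages[...] directly rather than the temporary tmp_stages. The keying convention appends an index suffix for extra outputs. A minimal standalone sketch of that keying behavior, with a plain std::map and int standing in for CINN's tensor types (all names here are illustrative):

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Mirrors the suffixing loop: the first output is stored under "<id>",
// the second under "<id>_0", the third under "<id>_1", and so on.
std::map<std::string, int> RegisterOutputs(const std::string& id, const std::vector<int>& outputs) {
  std::map<std::string, int> tensor_map;
  std::string post = "";
  for (size_t idx = 0; idx < outputs.size(); ++idx) {
    tensor_map[id + post] = outputs[idx];
    post = "_" + std::to_string(idx);  // suffix is updated after use, so output 0 keeps the bare id
  }
  return tensor_map;
}

int main() {
  for (const auto& kv : RegisterOutputs("reduce_sum_1", {10, 11, 12})) {
    std::cout << kv.first << " -> " << kv.second << "\n";
  }
  return 0;
}

Running the sketch prints the keys reduce_sum_1, reduce_sum_1_0, and reduce_sum_1_1, matching the id + "_0"/"_1" convention described in the code comment.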

@@ -540,6 +541,12 @@ void OpLowerer::ReduceSchedule(poly::StageMap& stages,
auto master_reducer_stage = stages[tensor_map[master_reducer_data->id()]];
auto master_reducer_axes = absl::get<std::vector<int>>(master_reducer->attrs.attr_store.at("dim"));
auto master_reducer_shape = this->shape_dict_.at(master_reducer->inlinks_in_order()[0]->source()->id());
// update sync thread depend.
for (auto stage : stages) {
if (stage.first.find("syncthreads") != std::string::npos) {
stage.second->CtrlDepend(tensor_map[master_reducer_data->id() + "_0"]);
}
}

VLOG(3) << "master node : " << master_node->id() << " ,reducer node : " << master_reducer->id();
for (auto& node : sub_group->nodes) {
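
The loop added in the hunk above walks every stage in the StageMap and, for any stage whose name contains "syncthreads", registers a control dependence on the reducer tensor stored under master_reducer_data->id() + "_0", so the synchronization point is ordered after that tensor has been produced. A minimal self-contained model of the same update, using plain containers in place of CINN's StageMap and CtrlDepend and hypothetical node ids:

#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Stage name -> tensors the stage is control-dependent on (a toy stand-in for StageMap).
  std::map<std::string, std::vector<std::string>> ctrl_depends = {
      {"elementwise_add_2", {}},
      {"syncthreads_0", {}},
  };
  const std::string reducer_id = "reduce_sum_1";  // plays the role of master_reducer_data->id()
  for (auto& stage : ctrl_depends) {
    // Same test as the real loop: match any stage whose name contains "syncthreads".
    if (stage.first.find("syncthreads") != std::string::npos) {
      stage.second.push_back(reducer_id + "_0");  // the barrier now waits on that reducer tensor
    }
  }
  for (const auto& stage : ctrl_depends) {
    std::cout << stage.first << ": " << stage.second.size() << " control dependence(s)\n";
  }
  return 0;
}
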
7 changes: 4 additions & 3 deletions cinn/hlir/pe/schedule.cc
File mode changed 100644 → 100755
@@ -489,16 +489,17 @@ void CudaBlockShuffleReduceSchedule(
stages[out]->Split(0, stages[out]->GetDimRange(0));
}

stages[reshape]->ComputeInline();
stages[internal]->SetBuffer("shared");

stages[internal]->Bind(0, "blockIdx.x");
stages[internal]->Bind(1, "threadIdx.x");

stages[out]->Bind(0, "blockIdx.x");
stages[out]->Bind(1, "threadIdx.x");

stages[internal]->SimpleComputeAt(stages[out], 0);

stages[reshape]->ComputeInline();
stages[internal]->SetBuffer("shared");
stages[out]->SyncThreads(0, {internal}, stages);
}

void CudaTwoStepReduceSchedule(poly::StageMap stages,
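
In CudaBlockShuffleReduceSchedule above, ComputeInline on the reshape stage and SetBuffer("shared") on the internal stage are moved ahead of the blockIdx.x/threadIdx.x bindings, and a SyncThreads(0, {internal}, stages) barrier is added on the output stage, so the generated kernel synchronizes between writing the shared internal buffer and reading it back. The device-side pattern this schedules is sketched below as a hand-written CUDA kernel (shape and names are illustrative only, not CINN's generated code):

// Assumes a launch such as block_shuffle_reduce<<<num_blocks, 256>>>(d_in, d_out).
__global__ void block_shuffle_reduce(const float* in, float* out) {
  __shared__ float internal[256];  // the "internal" stage, placed in shared memory by SetBuffer("shared")
  internal[threadIdx.x] = in[blockIdx.x * blockDim.x + threadIdx.x];
  __syncthreads();  // the barrier SyncThreads schedules: writes above are visible before any read below
  if (threadIdx.x == 0) {  // one thread per block combines the staged values
    float acc = 0.f;
    for (int i = 0; i < blockDim.x; ++i) acc += internal[i];
    out[blockIdx.x] = acc;
  }
}

Without the barrier, threads could read slots of the shared buffer before other threads in the block have written them.
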
cinn/lang/lower.cc: no content changes, file mode changed 100644 → 100755
cinn/poly/stage.cc: no content changes, file mode changed 100644 → 100755
