Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN][New Hardware Update] standardize CINN_WITH_CUDA #64506

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 73 additions & 71 deletions paddle/cinn/common/cas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1072,77 +1072,79 @@ bool CasSimplifyMutator::SimplifySpecificSumMod(Expr* result, Expr a, Expr b) {
}
}
}
#ifdef CINN_WITH_CUDA
return false;
#else

int const_value = 0;
Expr lower_bound;
Expr upper_bound;
Expr rest_oper;
bool can_simplify = true;
bool has_int = false;
// fold only the expr bound(may contains the var) and try to simplify the var
Expr unfolded_lower_bound, unfolded_upper_bound;
for (Expr& v : a_sum->operands()) {
auto* v_int = v.As<IntImm>();
if (v_int) {
const_value += v_int->value;
has_int = true;
} else if (GetVarBound(&lower_bound, &upper_bound, v, false)) {
AddBaseAndSimplify(&rest_oper, v);
} else {
can_simplify = false;
break;
}
}
can_simplify = can_simplify && has_int &&
std::abs(const_value) % b_i->value == b_i->value - 1 &&
lower_bound.defined() && upper_bound.defined() &&
rest_oper.defined();
// further infer the vars' bound by the intervals infos, try to get the
// constant
if (can_simplify) {
std::vector<Expr> bounds = {lower_bound, upper_bound};
for (int i = 0; i < bounds.size(); ++i) {
Expr bound = bounds[i];
Expr bound_l, bound_r;
GetExprBound(&bound_l, &bound_r, bound);
if (i == 0 && bound_l.defined()) {
lower_bound = bound_l;
}
if (i == 1 && bound_r.defined()) {
upper_bound = bound_r;
}
}
} else {
return false;
}
// case1: (32+(-x))%33 = 32-x%33 (0<=x<=32)
// case2: (x-32)%33 = x%33 - 32%33 (0<=x<=32)
can_simplify = can_simplify && lower_bound.is_constant();
bool case1 = can_simplify && const_value >= 0 &&
lower_bound.get_constant() >= -const_value &&
upper_bound.is_constant() && upper_bound.get_constant() <= 0;
bool case2 = can_simplify && const_value <= 0 &&
lower_bound.get_constant() >= 0 && upper_bound.is_constant() &&
upper_bound.get_constant() <= -const_value;
can_simplify = can_simplify && (case1 || case2);
if (can_simplify) {
Expr const_expr;
if (const_value < 0) {
const_expr = make_const(b->type(), const_value % b_i->value);
} else {
const_expr = make_const(b->type(), const_value % b_i->value);
}
*result = CasSimplify(
Sum::Make(
{const_expr, CasSimplify(Mod::Make(rest_oper, b), var_intervals)}),
var_intervals);
return true;
}
return false;
#endif
return cinn::common::DefaultDeviceTarget().arch.Match(
[&](common::NVGPUArch) { return false; },
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
int const_value = 0;
Expr lower_bound;
Expr upper_bound;
Expr rest_oper;
bool can_simplify = true;
bool has_int = false;
// fold only the expr bound(may contains the var) and try to simplify
// the var
Expr unfolded_lower_bound, unfolded_upper_bound;
for (Expr& v : a_sum->operands()) {
auto* v_int = v.As<IntImm>();
if (v_int) {
const_value += v_int->value;
has_int = true;
} else if (GetVarBound(&lower_bound, &upper_bound, v, false)) {
AddBaseAndSimplify(&rest_oper, v);
} else {
can_simplify = false;
break;
}
}
can_simplify = can_simplify && has_int &&
std::abs(const_value) % b_i->value == b_i->value - 1 &&
lower_bound.defined() && upper_bound.defined() &&
rest_oper.defined();
// further infer the vars' bound by the intervals infos, try to get the
// constant
if (can_simplify) {
std::vector<Expr> bounds = {lower_bound, upper_bound};
for (int i = 0; i < bounds.size(); ++i) {
Expr bound = bounds[i];
Expr bound_l, bound_r;
GetExprBound(&bound_l, &bound_r, bound);
if (i == 0 && bound_l.defined()) {
lower_bound = bound_l;
}
if (i == 1 && bound_r.defined()) {
upper_bound = bound_r;
}
}
} else {
return false;
}
// case1: (32+(-x))%33 = 32-x%33 (0<=x<=32)
// case2: (x-32)%33 = x%33 - 32%33 (0<=x<=32)
can_simplify = can_simplify && lower_bound.is_constant();
bool case1 = can_simplify && const_value >= 0 &&
lower_bound.get_constant() >= -const_value &&
upper_bound.is_constant() &&
upper_bound.get_constant() <= 0;
bool case2 = can_simplify && const_value <= 0 &&
lower_bound.get_constant() >= 0 &&
upper_bound.is_constant() &&
upper_bound.get_constant() <= -const_value;
can_simplify = can_simplify && (case1 || case2);
if (can_simplify) {
Expr const_expr;
if (const_value < 0) {
const_expr = make_const(b->type(), const_value % b_i->value);
} else {
const_expr = make_const(b->type(), const_value % b_i->value);
}
*result = CasSimplify(
Sum::Make({const_expr,
CasSimplify(Mod::Make(rest_oper, b), var_intervals)}),
var_intervals);
return true;
}
return false;
});
}

// Return if the var's interval is nonnegative.
Expand Down
8 changes: 7 additions & 1 deletion paddle/cinn/hlir/framework/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,15 @@ void Graph::VisualizeGroupedGraph(
for (int idx = 0; idx < groups.size(); ++idx) {
// Create fusion_group_x folder
int device_id = 0;
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) {},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
cudaGetDevice(&device_id);
cudaGetDevice(&device_id);
#endif
});
auto group_path =
utils::StringFormat("%s/device_%d/fusion_group_%d",
FLAGS_cinn_fusion_groups_graphviz_dir.c_str(),
Expand Down
8 changes: 7 additions & 1 deletion paddle/cinn/hlir/framework/graph_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -428,9 +428,15 @@ std::vector<ir::LoweredFunc> GetFuncFromImpl(
CHECK_EQ(funcs_after_schedule.size(), expr_pack.size());
std::vector<ir::LoweredFunc> res;
for (int i = 0; i < funcs_after_schedule.size(); i++) {
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) {},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
optim::OptimizeExprGPU(&(funcs_after_schedule[i]->body));
optim::OptimizeExprGPU(&(funcs_after_schedule[i]->body));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

既然 `#ifdef CINN_WITH_CUDA` 都已经放到 `common::NVGPUArch` 这个 alternative 下了,最好把 `#ifdef` 的 `#else` 分支也写全。

#ifdef CINN_WITH_CUDA
  ...
#else
  CINN_NOT_IMPLEMENTED();
#endif

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

下同。

#endif
});
auto temp_buffers = lang::GetTempBuffers(
all_arg_tensors, tensor_group, funcs_after_schedule[i]->body);

Expand Down
7 changes: 6 additions & 1 deletion paddle/cinn/hlir/framework/instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -438,9 +438,14 @@ std::string Instruction::DumpInstruction() const {

void Instruction::CheckResults(
const std::map<std::string, cinn_pod_value_t>* name2podargs, void* stream) {
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
cudaStreamSynchronize(static_cast<cudaStream_t>(stream));
cudaStreamSynchronize(static_cast<cudaStream_t>(stream));
#endif
});

if (fn_names_.size() == 1) {
std::unordered_set<std::string> skipped_instr_set = {
Expand Down
8 changes: 7 additions & 1 deletion paddle/cinn/hlir/framework/instruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,15 @@ class Instruction {
((lower_func_ptr_t)fn_ptrs_[idx])(static_cast<void*>(pod_args.data()),
pod_args.size());
}
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) {},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
CUDA_CALL(cudaDeviceSynchronize());
CUDA_CALL(cudaDeviceSynchronize());
#endif
});
}
}
if (flag >= 0) {
Expand Down
11 changes: 8 additions & 3 deletions paddle/cinn/hlir/framework/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,16 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
}

auto func_body = ir_sch->GetModule().GetExprs().at(0);
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
if (apply_pass) {
optim::OptimizeExprGPU(&(func_body));
}
if (apply_pass) {
optim::OptimizeExprGPU(&(func_body));
}
#endif
});
// 2.Prepare temp buffers
poly::StageMap stages;
auto temp_buffers =
Expand Down
14 changes: 12 additions & 2 deletions paddle/cinn/hlir/framework/parallel_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,14 @@ void ParallelCompiler::SplitTask() {
context_->graph->fusion_groups.size() ==
context_->lowered_funcs.size());
int device_id = 0;
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
CUDA_CALL(cudaGetDevice(&device_id));
CUDA_CALL(cudaGetDevice(&device_id));
#endif
});
for (int group_id = 0; group_id < context_->graph->fusion_groups.size();
++group_id) {
tasks_.emplace_back(device_id, group_id, this, context_);
Expand Down Expand Up @@ -132,9 +137,14 @@ void ParallelCompiler::RunTask() {

void ParallelCompiler::LaunchTask() {
int device_id = 0;
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
CUDA_CALL(cudaGetDevice(&device_id));
CUDA_CALL(cudaGetDevice(&device_id));
#endif
});
int num_threads = FLAGS_cinn_parallel_compile_thread;
#if defined(PADDLE_WITH_DISTRIBUTE)
if (device_id > 0) {
Expand Down
10 changes: 8 additions & 2 deletions paddle/cinn/hlir/framework/pir/op_lowering_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -816,10 +816,16 @@ std::vector<ir::LoweredFunc> OpLowererImpl::PostProcess(
std::vector<ir::LoweredFunc> lowered_funcs;
for (ir::Expr func_body : func_bodies) {
optim::EliminateDeadScheduleBlock(&(func_body), group->output_names());
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch,
common::X86Arch,
common::ARMArch>) {},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
optim::EliminateCommonGlobalMemoryRead(&(func_body));
optim::OptimizeExprGPU(&(func_body));
optim::EliminateCommonGlobalMemoryRead(&(func_body));
optim::OptimizeExprGPU(&(func_body));
#endif
});

// 2.Prepare temp buffers
auto temp_buffers =
Expand Down
23 changes: 19 additions & 4 deletions paddle/cinn/ir/group_schedule/st_shape_group_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,22 +114,37 @@ void StaticShapeGroupScheduler::Schedule() {
&StaticShapeGroupScheduler::IsKeepGraphDependency);
DoLoopAlignment();
DoComputeInline();
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
OptimizeReduction();
OptimizeReduction();
#endif
});
DoHorizontalLoopFusion();
DoVerticalLoopFusion();
cinn::common::DefaultDeviceTarget().arch.Match(
[&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
},
[&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
BindCudaAxis();
AllocateStorage();
BindCudaAxis();
AllocateStorage();
#endif
});
}

// Schedule entry point for the MapExpr path.
//
// Calls DoComputeInline() unconditionally, then dispatches on the architecture
// of the default device target:
//   - UnknownArch / X86Arch / ARMArch: nothing further is done.
//   - NVGPUArch: calls AllocateStorage(), but only when the build defines
//     CINN_WITH_CUDA; otherwise the branch is an empty no-op.
void StaticShapeGroupScheduler::MapExprSchedule() {
  DoComputeInline();
  cinn::common::DefaultDeviceTarget().arch.Match(
      // Non-GPU targets need no extra step here.
      [&](std::variant<common::UnknownArch, common::X86Arch, common::ARMArch>) {
      },
      [&](common::NVGPUArch) {
#ifdef CINN_WITH_CUDA
        // Invoked exactly once per call of MapExprSchedule().
        AllocateStorage();
#endif
      });
}

std::vector<std::pair<SymbolicPredicate, ir::Expr>>
Expand Down
Loading