Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
012c8e4
support async copy for if_then_else
flow-matic-brewster Feb 12, 2023
d700b21
add comment and trigger PrintPredicatedCpAsyncAssembly in codegen_cud…
flow-matic-brewster Feb 12, 2023
f055ae5
reformat code
flow-matic-brewster Feb 12, 2023
936e7ad
Merge branch 'apache:main' into main
cblmemo Feb 12, 2023
96224f2
add zfill support & comment
flow-matic-brewster Feb 13, 2023
d22a8ca
add unittest
flow-matic-brewster Feb 13, 2023
587d12d
Merge branch 'apache:main' into main
cblmemo Feb 14, 2023
516af3b
reformat unittest
flow-matic-brewster Feb 14, 2023
bca2c63
Update src/target/source/codegen_cuda.cc
cblmemo Feb 14, 2023
6d6f072
Update src/tir/transforms/inject_ptx_async_copy.cc
cblmemo Feb 14, 2023
bfee9f8
add compute version check
flow-matic-brewster Feb 14, 2023
7de4d88
fix & reformat inject_ptx_async_copy.cc
flow-matic-brewster Feb 14, 2023
0552891
update unittest using a small example
flow-matic-brewster Feb 14, 2023
494503f
add gemm integration test
flow-matic-brewster Feb 14, 2023
59c87dd
add correctness test for integration test
flow-matic-brewster Feb 14, 2023
8cf7c79
update test script to support device < sm80
flow-matic-brewster Feb 16, 2023
07fcc4f
fix unittest
flow-matic-brewster Feb 16, 2023
1c29863
fix unittest dummy code function name
flow-matic-brewster Feb 16, 2023
f5f606f
reformat unittest
flow-matic-brewster Feb 16, 2023
28b866f
add license and doc string
flow-matic-brewster Feb 16, 2023
1de7625
reformat inject_ptx_async_copy.cc
flow-matic-brewster Feb 16, 2023
57f0193
reformat codegen_cuda.cc
flow-matic-brewster Feb 16, 2023
202fead
fix global syntax error in unittest
flow-matic-brewster Feb 16, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -914,7 +914,13 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
std::string src = this->PrintExpr(op->args[2]);
std::string src_offset = this->PrintExpr(op->args[3]);
std::string size = this->PrintExpr(op->args[4]);
this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
// use size of argument list to indicate whether or not to use predicated cp.async
if (op->args.size() == 5) {
this->stream << PrintCpAsyncAssembly(dst, dst_offset, src, src_offset, size);
} else {
this->stream << PrintPredicatedCpAsyncAssembly(dst, dst_offset, src, src_offset, size,
this->PrintExpr(op->args[5]));
}
} else if (op->op.same_as(builtin::ptx_commit_group())) {
this->stream << "__asm__ __volatile__(\"cp.async.commit_group;\");\n\n";
} else if (op->op.same_as(builtin::ptx_wait_group())) {
Expand Down
31 changes: 31 additions & 0 deletions src/target/source/ptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -659,5 +659,36 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
return asm_code;
}

std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
const std::string& shared_elem_offset,
const std::string& global_ptr,
const std::string& global_elem_offset,
const std::string& bytes,
const std::string& predicate_value) {
std::string predicated_asm_code = R"(
{
unsigned int addr;
__asm__ __volatile__(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
: "=r"(addr)
: "l"((void *)({smem_addr}))
);
int src_bytes = {pred_guard} ? {bytes} : 0;
__asm__ __volatile__(
"cp.async.{cg_or_ca}.shared.global [%0], [%1], %2, %3;"
:: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes}), "r"(src_bytes)
);
}
)";
Replacer replacer;
replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
replacer.register_rule("{bytes}", bytes);
replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca");
replacer.register_rule("{pred_guard}", predicate_value);
predicated_asm_code = replacer.rewrite(predicated_asm_code);
return predicated_asm_code;
}

} // namespace codegen
} // namespace tvm
16 changes: 16 additions & 0 deletions src/target/source/ptx.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,22 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
const std::string& global_ptr,
const std::string& global_elem_offset, const std::string& bytes);

/*!
* \brief Print predicated ptx cp.async assembly string given parameters.
* \param shared_ptr: The pointer to the destination shared memory.
* \param shared_elem_offset: The offset into the shared memory.
* \param global_ptr: The pointer to the global memory.
* \param global_elem_offset: The offset into the global memory.
* \param bytes: The number of bytes to copy, valid values are 4, 8, and 16.
* \param predicate_value: The value of predicate `@p`.
*/
std::string PrintPredicatedCpAsyncAssembly(const std::string& shared_ptr,
const std::string& shared_elem_offset,
const std::string& global_ptr,
const std::string& global_elem_offset,
const std::string& bytes,
const std::string& predicate_value);

} // namespace codegen
} // namespace tvm

Expand Down
156 changes: 94 additions & 62 deletions src/tir/transforms/inject_ptx_async_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,73 +47,105 @@ class PTXAsyncCopyInjector : public StmtMutator {
return StmtMutator::VisitStmt_(attr);
}

Stmt InjectPTX(const BufferLoadNode* load, const BufferStoreNode* store, bool predicated = false,
PrimExpr predicate_value = PrimExpr()) {
if (load->buffer.scope() == "global") {
ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes());

const int indices_lanes = load->indices[0]->dtype.lanes();
const int bytes = indices_lanes * load->buffer->dtype.bytes();

if (bytes == 4 || bytes == 8 || bytes == 16) {
auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation);
auto src_elem_type = GetPointerType(load->buffer->data->type_annotation);
ICHECK(dst_elem_type.has_value() && src_elem_type.has_value())
<< "Both store and load buffer should have a pointer type annotation.";

int index_factor = 1;
if (dst_elem_type.value() != src_elem_type.value()) {
// The only case where src and dst have different dtypes is when the dst shared memory
// is a byte buffer generated by merging dynamic shared memory.
ICHECK(store->buffer.scope() == "shared.dyn");
ICHECK(dst_elem_type.value() == DataType::UInt(8));
// BufferStore/Load have the "pointer reinterpret" semantics according to their
// "value" dtype. Their "indices" are supposed to be applied after such pointer cast,
// for example: ((*float16)(byte_buffer))[buffer->indices] = fp16_value;
// To replace BufferStore/Load with cp.async, we need to multiply the store index by
// the byte size of the "value" dtype, to get the correct offset into the byte buffer.
index_factor = src_elem_type->bytes();
}

if (indices_lanes == 1) {
auto src_offset = load->indices[0];
auto dst_offset = store->indices[0];
Array<PrimExpr> args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)};
// use arguments size to indicate whether or not to use predicated cp.async
if (predicated) {
args.push_back(predicate_value);
}
return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(), args));
}

// Predicated load don't support vectorized indexing.
if (!predicated) {
// Only some vectorized indexing patterns are supported for now.
auto src_offset = [=]() -> PrimExpr {
if (load->indices[0]->IsInstance<RampNode>()) {
return load->indices[0].as<RampNode>()->base;
}
return PrimExpr();
}();

auto dst_offset = [=]() -> PrimExpr {
if (store->indices[0].as<RampNode>()) {
return store->indices[0].as<RampNode>()->base;
} else if (store->indices[0].as<AddNode>()) {
// The case where the dst buffer is a byte buffer generated by merging dynamic
// shared memory.
// A_shared.dyn[(ramp(...), 1, 8) + x8(17408))] = A_global[ramp(...),1, 8)]
auto* add = store->indices[0].as<AddNode>();
if (!add->a->IsInstance<RampNode>()) return PrimExpr();
if (!add->b->IsInstance<BroadcastNode>()) return PrimExpr();
return tir::Add(add->a.as<RampNode>()->base, add->b.as<BroadcastNode>()->value);
}
return PrimExpr();
}();

if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)}));
}
}
}
}
return StmtMutator::VisitStmt_(store);
}

Stmt VisitStmt_(const BufferStoreNode* store) {
if (in_async && (store->buffer.scope() == "shared" || store->buffer.scope() == "shared.dyn")) {
if (auto* load = store->value.as<BufferLoadNode>()) {
if (load->buffer.scope() == "global") {
ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes());

const int indices_lanes = load->indices[0]->dtype.lanes();
const int bytes = indices_lanes * load->buffer->dtype.bytes();

if (bytes == 4 || bytes == 8 || bytes == 16) {
auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation);
auto src_elem_type = GetPointerType(load->buffer->data->type_annotation);
ICHECK(dst_elem_type.has_value() && src_elem_type.has_value())
<< "Both store and load buffer should have a pointer type annotation.";

int index_factor = 1;
if (dst_elem_type.value() != src_elem_type.value()) {
// The only case where src and dst have different dtypes is when the dst shared memory
// is a byte buffer generated by merging dynamic shared memory.
ICHECK(store->buffer.scope() == "shared.dyn");
ICHECK(dst_elem_type.value() == DataType::UInt(8));
// BufferStore/Load have the "pointer reinterpret" semantics according to their
// "value" dtype. Their "indices" are supposed to be applied after such pointer cast,
// for example: ((*float16)(byte_buffer))[buffer->indices] = fp16_value;
// To replace BufferStore/Load with cp.async, we need to multiply the store index by
// the byte size of the "value" dtype, to get the correct offset into the byte buffer.
index_factor = src_elem_type->bytes();
return InjectPTX(load, store);
} else if (auto* call = store->value.as<CallNode>()) {
// tir.if_then_else is a call to tir::builtin::if_then_else()
if (call->op.same_as(builtin::if_then_else()) && call->args.size() == 3) {
if (auto* load = call->args[1].as<BufferLoadNode>()) {
// Only default value of 0 is supported since 0 is the default value used by cp.async
// ptx. @see section 9.7.8.22.3. of
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-memory-operations
bool else_value_is_zero = false;
if (auto* b = call->args[2].as<BroadcastNode>()) {
if (auto* f = b->value.as<FloatImmNode>()) {
else_value_is_zero = f->value == 0.0f;
}
}

if (indices_lanes == 1) {
auto src_offset = load->indices[0];
auto dst_offset = store->indices[0];
return Evaluate(
Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)}));
if (auto* f = call->args[2].as<FloatImmNode>()) {
else_value_is_zero = f->value == 0.0f;
}

// Only some vectorized indexing patterns are supported for now.
auto src_offset = [=]() -> PrimExpr {
if (load->indices[0]->IsInstance<RampNode>()) {
return load->indices[0].as<RampNode>()->base;
}
return PrimExpr();
}();

auto dst_offset = [=]() -> PrimExpr {
if (store->indices[0].as<RampNode>()) {
return store->indices[0].as<RampNode>()->base;
} else if (store->indices[0].as<AddNode>()) {
// The case where the dst buffer is a byte buffer generated by merging dynamic
// shared memory.
// A_shared.dyn[(ramp(...), 1, 8) + x8(17408))] = A_global[ramp(...),1, 8)]
auto* add = store->indices[0].as<AddNode>();
if (!add->a->IsInstance<RampNode>()) return PrimExpr();
if (!add->b->IsInstance<BroadcastNode>()) return PrimExpr();
return tir::Add(add->a.as<RampNode>()->base, add->b.as<BroadcastNode>()->value);
}
return PrimExpr();
}();

if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(
Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)}));
if (else_value_is_zero) {
return InjectPTX(load, store, true, call->args[0]);
}
}
}
Expand Down
Loading