Skip to content

Commit

Permalink
[TIR] Handle DeclBuffer in LowerThreadAllreduce (#15078)
Browse files Browse the repository at this point in the history
This commit is part of the changes being split out from
#14778 into independent portions.
This commit allows the `LowerThreadAllreduce` pass to handle
`DeclBuffer` nodes that occur within its input.
  • Loading branch information
Lunderberg committed Jun 16, 2023
1 parent 64f9b12 commit 3f2aa68
Show file tree
Hide file tree
Showing 4 changed files with 287 additions and 66 deletions.
108 changes: 42 additions & 66 deletions src/tir/transforms/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,92 +93,70 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
}
}
Stmt VisitStmt_(const AllocateNode* op) final {
Stmt stmt = StmtExprMutator::VisitStmt_(op);
op = stmt.as<AllocateNode>();
auto it = alloc_remap_.find(op->buffer_var.get());
if (it != alloc_remap_.end()) {
auto node = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));

if (auto it = alloc_remap_.find(node->buffer_var.get()); it != alloc_remap_.end()) {
const AllocateNode* repl = it->second.as<AllocateNode>();
if (warp_allocs_.count(repl)) {
new_storage_scopes_[repl->buffer_var.get()] = "local";
} else {
new_storage_scopes_[repl->buffer_var.get()] = "shared";
}
return Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
} else {
return stmt;
auto write_ptr = node.CopyOnWrite();
write_ptr->buffer_var = repl->buffer_var;
write_ptr->dtype = repl->dtype;
write_ptr->extents = repl->extents;
write_ptr->condition = repl->condition;
}
return std::move(node);
}

PrimExpr VisitExpr_(const BufferLoadNode* op) final {
{
auto it = load_remap_.find(op->buffer->data.get());
if (it != load_remap_.end()) {
for (const auto& index : op->indices) {
ICHECK(is_zero(index));
}
return it->second;
}
Optional<Buffer> GetRemappedBuffer(const Buffer& buf) {
if (auto it = buf_remap_.find(buf.get()); it != buf_remap_.end()) {
return it->second;
}

BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
op = load.get();

{
auto it = buf_remap_.find(op->buffer.get());
if (it != buf_remap_.end()) {
return BufferLoad(it->second, op->indices, op->span);
}
if (auto it = var_remap_.find(buf->data.get()); it != var_remap_.end()) {
Buffer new_buf = buf;
new_buf.CopyOnWrite()->data = it->second;
buf_remap_[buf.get()] = new_buf;
return new_buf;
}

{
auto it = var_remap_.find(op->buffer->data.get());
if (it != var_remap_.end()) {
Buffer remapped_buffer(it->second, op->buffer->dtype, op->buffer->shape,
op->buffer->strides, op->buffer->elem_offset, op->buffer->name,
op->buffer->data_alignment, op->buffer->offset_factor,
op->buffer->buffer_type, op->buffer->axis_separators,
op->buffer->span);
buf_remap_[op->buffer.get()] = remapped_buffer;
return BufferLoad(remapped_buffer, op->indices, op->span);
}
}
return StmtExprMutator::VisitExpr_(op);
return NullOpt;
}

Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
// Mutates a DeclBuffer through the base visitor, then swaps in the buffer
// returned by GetRemappedBuffer when that lookup yields a replacement.
Stmt VisitStmt_(const DeclBufferNode* op) final {
  DeclBuffer decl = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));

  // The lookup is done on the post-mutation node's buffer.
  Optional<Buffer> replacement = GetRemappedBuffer(decl->buffer);
  if (!replacement) {
    // No replacement recorded: hand back the visited node unchanged.
    return std::move(decl);
  }

  decl.CopyOnWrite()->buffer = replacement.value();
  return std::move(decl);
}

auto it = store_remap_.find(store->buffer.get());
if (it != store_remap_.end()) {
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
if (auto it = load_remap_.find(op->buffer->data.get()); it != load_remap_.end()) {
for (const auto& index : op->indices) {
ICHECK(is_zero(index));
}

auto writer = store.CopyOnWrite();
writer->buffer = it->second;
return std::move(store);
return it->second;
}

{
auto it = buf_remap_.find(store->buffer.get());
if (it != buf_remap_.end()) {
return BufferStore(it->second, store->value, store->indices, store->span);
}
}
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
op = load.get();

{
auto it = var_remap_.find(store->buffer->data.get());
if (it != var_remap_.end()) {
Buffer remapped_buffer(it->second, store->buffer->dtype, store->buffer->shape,
store->buffer->strides, store->buffer->elem_offset,
store->buffer->name, store->buffer->data_alignment,
store->buffer->offset_factor, store->buffer->buffer_type,
store->buffer->axis_separators, store->buffer->span);
buf_remap_[store->buffer.get()] = remapped_buffer;
return BufferStore(remapped_buffer, store->value, store->indices, store->span);
}
if (auto opt = GetRemappedBuffer(load->buffer)) {
load.CopyOnWrite()->buffer = opt.value();
}
return std::move(load);
}

// Mutates a BufferStore through the base visitor, then redirects the store
// to the buffer returned by GetRemappedBuffer when one is available.
Stmt VisitStmt_(const BufferStoreNode* op) final {
  BufferStore visited = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));

  // The lookup is done on the post-mutation store's buffer.
  Optional<Buffer> replacement = GetRemappedBuffer(visited->buffer);
  if (!replacement) {
    // No replacement recorded: hand back the visited store unchanged.
    return std::move(visited);
  }

  visited.CopyOnWrite()->buffer = replacement.value();
  return std::move(visited);
}

Expand Down Expand Up @@ -446,11 +424,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
PrimExpr val = BufferLoad(buf, zero_indices);
ICHECK_EQ(val->dtype, types[i]);
load_remap_[buffers[i]->data.get()] = val;
store_remap_[buffers[i].get()] = buf;
Array<PrimExpr> extents{PrimExpr(1)};
auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0));
alloc_remap_[buffers[i]->data.get()] = node;
var_remap_[buffers[i]->data.get()] = buf->data;
buf_remap_[buffers[i].get()] = buf;
warp_allocs_.insert(node.get());
}
} else {
Expand Down Expand Up @@ -489,7 +467,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
Allocate(shared_bufs[idx]->data, types[idx],
{PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0));
var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data;
store_remap_[buffers[idx].get()] = shared_bufs[idx];
buf_remap_[buffers[idx].get()] = shared_bufs[idx];
}
}

Expand Down Expand Up @@ -718,8 +696,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
std::vector<const CommReducerNode*> reduce_combiner_;
// The load remap
std::unordered_map<const VarNode*, PrimExpr> load_remap_;
// The store remap
std::unordered_map<const BufferNode*, Buffer> store_remap_;
// Allocate remap
std::unordered_map<const VarNode*, Stmt> alloc_remap_;
// BufferVar remap
Expand Down
5 changes: 5 additions & 0 deletions src/tir/transforms/update_pointer_storage_scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ Buffer UpdatePointerStorageScope::GetUpdatedBuffer(Buffer buf) {
return buf;
}

// Visits a DeclBuffer with the base mutator and forwards the result to
// UpdateBufferAccess, the same helper the BufferLoad visitor below uses.
Stmt UpdatePointerStorageScope::VisitStmt_(const DeclBufferNode* op) {
  DeclBuffer visited = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
  // NOTE(review): UpdateBufferAccess is defined outside this view; presumably
  // it rewrites visited->buffer to its updated form -- confirm.
  return UpdateBufferAccess(visited);
}

PrimExpr UpdatePointerStorageScope::VisitExpr_(const BufferLoadNode* op) {
auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
return UpdateBufferAccess(node);
Expand Down
1 change: 1 addition & 0 deletions src/tir/transforms/update_pointer_storage_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class UpdatePointerStorageScope : public StmtExprMutator {
virtual PrimExpr VisitExpr_(const VarNode*);
virtual PrimExpr VisitExpr_(const BufferLoadNode*);
virtual Stmt VisitStmt_(const AllocateNode*);
virtual Stmt VisitStmt_(const DeclBufferNode*);
virtual Stmt VisitStmt_(const BufferStoreNode*);

private:
Expand Down
Loading

0 comments on commit 3f2aa68

Please sign in to comment.