Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Refactor BlockScope outside schedule #15034

Merged
merged 1 commit into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
* under the License.
*/
/*!
* \file tvm/tir/schedule/block_scope.h
* \file tvm/tir/block_scope.h
* \brief Definition of the two pillar data structures for TensorIR scheduling: StmtSRef and BlockScope.
* \sa StmtSRefNode
* \sa BlockScopeNode
*/
#ifndef TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
#define TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
#ifndef TVM_TIR_BLOCK_SCOPE_H_
#define TVM_TIR_BLOCK_SCOPE_H_

#include <tvm/tir/stmt.h>

Expand Down Expand Up @@ -216,16 +216,6 @@ class BlockScopeNode : public Object {
std::unordered_map<StmtSRef, Array<Dependency>, ObjectPtrHash, ObjectPtrEqual> dst2deps;
/*! \brief The mapping from the buffer to the blocks who write it */
std::unordered_map<Buffer, Array<StmtSRef>, ObjectPtrHash, ObjectPtrEqual> buffer_writers;
/*!
* \brief This property indicates that the block scope (rooted at its corresponding block) is
* equivalent to of a stage pipeline. Under the following conditions:
*
* 1) The region cover property holds for every of its child blocks
* 2) No write-after-read dependency or opaque dependency, only read-after-write and
* write-after-write are allowed
* 3) All the statements in the scope are schedulable statements, i.e. Block and For
*/
bool stage_pipeline{false};

void VisitAttrs(AttrVisitor* v) {}

Expand Down Expand Up @@ -270,4 +260,4 @@ class BlockScope : public ObjectRef {
} // namespace tir
} // namespace tvm

#endif // TVM_TIR_SCHEDULE_BLOCK_SCOPE_H_
#endif // TVM_TIR_BLOCK_SCOPE_H_
20 changes: 16 additions & 4 deletions include/tvm/tir/schedule/state.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
#define TVM_TIR_SCHEDULE_STATE_H_

#include <tvm/ir/module.h>
#include <tvm/tir/block_scope.h>
#include <tvm/tir/function.h>
#include <tvm/tir/schedule/block_scope.h>

#include <unordered_map>
#include <utility>
Expand All @@ -51,13 +51,25 @@ struct BlockInfo {
* produced by its producers
*/
bool region_cover{false};
/*!
* \brief This property indicates that the block scope (rooted at its corresponding block) is
* equivalent to a stage pipeline, under the following conditions:
*
* 1) The region cover property holds for every one of its child blocks
* 2) No write-after-read dependency or opaque dependency, only read-after-write and
* write-after-write are allowed
* 3) All the statements in the scope are schedulable statements, i.e. Block and For
*/
bool stage_pipeline{false};

BlockInfo() = default;

explicit BlockInfo(BlockScope scope, bool affine_binding = false, bool region_cover = false)
explicit BlockInfo(BlockScope scope, bool affine_binding = false, bool region_cover = false,
bool stage_pipeline = false)
: scope(std::move(scope)), //
affine_binding(affine_binding), //
region_cover(region_cover) {}
region_cover(region_cover),
stage_pipeline(stage_pipeline) {}
};

/*!
Expand Down Expand Up @@ -185,7 +197,7 @@ class ScheduleStateNode : public Object {
* \return Whether the block scope rooted at the given sref is a stage pipeline
*/
bool IsStagePipeline(const StmtSRef& scope_root) const {
return GetBlockScope(scope_root)->stage_pipeline;
return GetBlockInfo(scope_root).stage_pipeline;
}
};

Expand Down
96 changes: 96 additions & 0 deletions include/tvm/tir/utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_TIR_UTILS_H_
#define TVM_TIR_UTILS_H_

namespace tvm {
namespace tir {

/*!
* \brief A helper macro to convert an sref to the statement it points to,
* then check if the downcasting succeeded.
* \param Result The result variable, used for checking
* \param SRef The SRef to be cast
* \param Type The type to be cast to, can be Block or For
*/
#define TVM_SREF_AS_OR_ERR(Result, SRef, Type) \
SRef->StmtAs<Type>(); \
ICHECK(Result)

/*!
* \brief A helper macro to convert an sref to the block it points to,
*
* Throws an internal error if downcasting fails. The variable name
* in the parent scope is used for the error message.
*
* \param SRef The SRef to be cast
*/
#define TVM_SREF_TO_BLOCK(SRef) \
[&]() { \
auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::BlockNode) \
<< "TypeError: Expects StmtSRef `" << #SRef << "` points to `Block`, but gets: " \
<< ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None"); \
return result; \
}()

/*!
* \brief A helper macro to convert an sref to the for-loop it points to
*
* Throws an internal error if downcasting fails. The variable name
* in the parent scope is used for the error message.
*
* \param SRef The SRef to be cast
*/
#define TVM_SREF_TO_FOR(SRef) \
[&]() { \
auto result = TVM_SREF_AS_OR_ERR(result, (SRef), ::tvm::tir::ForNode) \
<< "TypeError: Expects StmtSRef `" << #SRef << "` points to `Loop`, but gets: " \
<< ((SRef)->stmt ? (SRef)->stmt->GetTypeKey() : "None"); \
return result; \
}()

/*!
* \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as<Type>`,
* then check if the downcasting succeeded.
* \param Result The result variable, used for checking
* \param From The ObjectRef to be downcast
* \param Type The type to be downcast to
*/
#define TVM_TYPE_AS_OR_ERR(Result, From, Type) \
From.as<Type>(); \
ICHECK(Result)

/*!
* \brief Downcast a TVM ObjectRef to its corresponding container using `ObjectRef::as<Type>`,
* throwing an internal error if downcast fails.
* \param From The ObjectRef to be downcast
* \param Type The type to be downcast to
*/
#define TVM_TYPE_AS(From, Type) \
[&]() { \
auto result = TVM_TYPE_AS_OR_ERR(result, (From), Type) \
<< "TypeError: Expects `" << #From << "` to have type `" << Type::_type_key \
<< "`, but gets: " << ((From).defined() ? (From)->GetTypeKey() : "None"); \
return result; \
}()

} // namespace tir
} // namespace tvm

#endif // TVM_TIR_UTILS_H_
File renamed without changes.
2 changes: 1 addition & 1 deletion python/tvm/tir/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# pylint: disable=unused-import
"""Namespace for the TensorIR schedule API."""

from .block_scope import BlockScope, Dependency, DepKind, StmtSRef
from ..block_scope import BlockScope, Dependency, DepKind, StmtSRef
from .instruction import Instruction, InstructionKind
from .schedule import BlockRV, ExprRV, LoopRV, Schedule, ScheduleError
from .state import ScheduleDebugMask, ScheduleState
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/schedule/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from tvm.tir import Block, BlockRealize, For, PrimFunc

from . import _ffi_api
from .block_scope import BlockScope, StmtSRef
from ..block_scope import BlockScope, StmtSRef

CachedFlags = namedtuple("CachedFlags", ["affine_binding", "region_cover", "stage_pipeline"])

Expand Down
25 changes: 12 additions & 13 deletions src/tir/schedule/block_scope.cc → src/tir/ir/block_scope.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
#include "./utils.h"
#include <tvm/tir/block_scope.h>
#include <tvm/tir/utils.h>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -141,21 +142,19 @@ TVM_REGISTER_NODE_TYPE(StmtSRefNode);
TVM_REGISTER_NODE_TYPE(DependencyNode);
TVM_REGISTER_NODE_TYPE(BlockScopeNode);

TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefStmt")
.set_body_typed([](StmtSRef sref) -> Optional<Stmt> {
return GetRef<Optional<Stmt>>(sref->stmt);
});
TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefParent")
.set_body_typed([](StmtSRef sref) -> Optional<StmtSRef> {
return GetRef<Optional<StmtSRef>>(sref->parent);
});
TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefRootMark") //
TVM_REGISTER_GLOBAL("tir.StmtSRefStmt").set_body_typed([](StmtSRef sref) -> Optional<Stmt> {
return GetRef<Optional<Stmt>>(sref->stmt);
});
TVM_REGISTER_GLOBAL("tir.StmtSRefParent").set_body_typed([](StmtSRef sref) -> Optional<StmtSRef> {
return GetRef<Optional<StmtSRef>>(sref->parent);
});
TVM_REGISTER_GLOBAL("tir.StmtSRefRootMark") //
.set_body_typed(StmtSRef::RootMark);
TVM_REGISTER_GLOBAL("tir.schedule.StmtSRefInlineMark") //
TVM_REGISTER_GLOBAL("tir.StmtSRefInlineMark") //
.set_body_typed(StmtSRef::InlineMark);
TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsBySrc")
TVM_REGISTER_GLOBAL("tir.BlockScopeGetDepsBySrc")
.set_body_method<BlockScope>(&BlockScopeNode::GetDepsBySrc);
TVM_REGISTER_GLOBAL("tir.schedule.BlockScopeGetDepsByDst")
TVM_REGISTER_GLOBAL("tir.BlockScopeGetDepsByDst")
.set_body_method<BlockScope>(&BlockScopeNode::GetDepsByDst);

} // namespace tir
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ Definition of a scope that is a stage pipeline:
}
// Step 2. Handle `require_stage_pipeline`
if (require_stage_pipeline && self->enable_check) {
bool stage_pipeline = self->GetBlockInfo(scope_root_sref).scope->stage_pipeline;
bool stage_pipeline = self->GetBlockInfo(scope_root_sref).stage_pipeline;
if (stage_pipeline == false) {
const BlockNode* block = TVM_SREF_TO_BLOCK(scope_root_sref);
throw NotStagePipelineError(self->mod, GetRef<Block>(block));
Expand Down
6 changes: 3 additions & 3 deletions src/tir/schedule/analysis/verify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,10 @@ void VerifyCachedFlags(const ScheduleState& self) {
new_block_info.region_cover,
old_block_info.region_cover);
}
if (new_block_info.scope->stage_pipeline != old_block_info.scope->stage_pipeline) {
if (new_block_info.stage_pipeline != old_block_info.stage_pipeline) {
block_info_wrong_stage_pipeline.emplace_back(new_sref, //
new_block_info.scope->stage_pipeline,
old_block_info.scope->stage_pipeline);
new_block_info.stage_pipeline,
old_block_info.stage_pipeline);
}
}

Expand Down
1 change: 0 additions & 1 deletion src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ class ScheduleCopier {
scope->src2deps = Copy(old_info.scope->src2deps);
scope->dst2deps = Copy(old_info.scope->dst2deps);
scope->buffer_writers = Copy(old_info.scope->buffer_writers);
scope->stage_pipeline = old_info.scope->stage_pipeline;
new_info.scope = BlockScope(std::move(scope));
result[Copy(old_sref)] = std::move(new_info);
}
Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/primitive/cache_index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref,
Array<Block> cache_stages = MakeIndexCacheStage(&info, storage_scope);
Stmt new_scope = CacheIndexRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);

bool old_stage_pipeline = self->block_info[block_sref].scope->stage_pipeline;
bool old_stage_pipeline = self->block_info[block_sref].stage_pipeline;

// Step 3. Replacing and updating flags.
self->Replace(scope_sref, new_scope, info.block_reuse);
Expand All @@ -486,7 +486,7 @@ Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref,

block_info.affine_binding = affine_binding;
block_info.region_cover = true;
block_info.scope->stage_pipeline = old_stage_pipeline;
block_info.stage_pipeline = old_stage_pipeline;
}

return result_block_srefs;
Expand Down
14 changes: 7 additions & 7 deletions src/tir/schedule/primitive/cache_read_write.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1526,7 +1526,7 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff
BlockInfo& block_info = self->block_info[result_block_sref];
block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
block_info.region_cover = true;
block_info.scope->stage_pipeline = true;
block_info.stage_pipeline = true;
return result_block_sref;
}

Expand Down Expand Up @@ -1591,7 +1591,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
BlockInfo& block_info = self->block_info[result_block_sref];
block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
block_info.region_cover = true;
block_info.scope->stage_pipeline = true;
block_info.stage_pipeline = true;
return result_block_sref;
}

Expand Down Expand Up @@ -1812,7 +1812,7 @@ StmtSRef ReindexCacheRead(ScheduleState self, const StmtSRef& block_sref, int re
BlockInfo& block_info = self->block_info[result_block_sref];
block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
block_info.region_cover = true;
block_info.scope->stage_pipeline = true;
block_info.stage_pipeline = true;
return result_block_sref;
}

Expand Down Expand Up @@ -1876,7 +1876,7 @@ StmtSRef ReindexCacheWrite(ScheduleState self, const StmtSRef& block_sref, int w
BlockInfo& block_info = self->block_info[result_block_sref];
block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
block_info.region_cover = true;
block_info.scope->stage_pipeline = true;
block_info.stage_pipeline = true;
return result_block_sref;
}

Expand Down Expand Up @@ -1954,7 +1954,7 @@ Array<StmtSRef> CacheInplace(ScheduleState self, const StmtSRef& block_sref, int
BlockInfo& block_info_read = self->block_info[result_block_sref];
block_info_read.affine_binding = CalculateAffineFlag(self, result_block_sref);
block_info_read.region_cover = true;
block_info_read.scope->stage_pipeline = false;
block_info_read.stage_pipeline = false;
results_block_sref.push_back(result_block_sref);

// Do cache write
Expand Down Expand Up @@ -1983,7 +1983,7 @@ Array<StmtSRef> CacheInplace(ScheduleState self, const StmtSRef& block_sref, int
BlockInfo& block_info_write = self->block_info[result_block_sref];
block_info_write.affine_binding = CalculateAffineFlag(self, result_block_sref);
block_info_write.region_cover = true;
block_info_write.scope->stage_pipeline = false;
block_info_write.stage_pipeline = false;
results_block_sref.push_back(result_block_sref);

return results_block_sref;
Expand Down Expand Up @@ -2058,7 +2058,7 @@ StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buffer_inde
BlockInfo& block_info = self->block_info[result_block_sref];
block_info.affine_binding = CalculateAffineFlag(self, result_block_sref);
block_info.region_cover = true;
block_info.scope->stage_pipeline = true;
block_info.stage_pipeline = true;
return result_block_sref;
}

Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/primitive/decompose_padding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref,
BlockInfo& block_info = self->block_info[new_block_sref];
block_info.affine_binding = true;
block_info.region_cover = true;
block_info.scope->stage_pipeline = true;
block_info.stage_pipeline = true;

// If the const pad value filling block is lifted out of the original subtree,
// set the region_cover flag as false since region_cover is the property under the subtree.
Expand All @@ -518,7 +518,7 @@ StmtSRef DecomposePaddingImpl(ScheduleState self, const StmtSRef& block_sref,
}
}
if (!preserve_stage_pipeline) {
self->block_info[scope_root_sref].scope->stage_pipeline = false;
self->block_info[scope_root_sref].stage_pipeline = false;
}
return new_block_sref;
}
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive/read_write_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ struct ReadWriteAtImpl {
BlockInfo& block_info = self_->block_info[new_block_sref];
block_info.affine_binding = affine_binding;
block_info.region_cover = true;
block_info.scope->stage_pipeline = true;
block_info.stage_pipeline = true;
}

template <bool is_read>
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1272,7 +1272,7 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& rf_loop_sref, int factor_ax
BlockInfo& info = self->block_info[new_block_sref];
info.affine_binding = true;
info.region_cover = true;
info.scope->stage_pipeline = true;
info.stage_pipeline = true;
}
return new_block_srefs[0];
}
Expand Down