-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TIR] Block dependence analysis without schedules (#15146)
* [TIR] Block dependence analysis without schedules This work introduces a new object called `BlockDependenceInfo` that computes and returns block dependences. The idea is to be able to expose block level dependences to TIR passes without having to create an explicit schedules. The patch introduces 2 main classes: 1. `SRefTreeCreator` - This creates and returns a new SRefTree and returns a map from original statements to corresponding srefs 2. `BlockDependenceInfo` - This object computes the actual dependences between blocks within a block scope and returns it for access in TIR passes This is a continuation to [PR #15034](#15034) and completes the work started there to expose block level dependences to TIR passes Note: One major difference between the SRef Tree created for dependence analysis here versus the one already present in schedules is that this SRef tree only contains block nodes and not loops. This makes it easier to find the parent blocks (by just accessing `parent` member) * Fix lint
- Loading branch information
1 parent
1257f43
commit d26dc44
Showing
9 changed files
with
609 additions
and
131 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
/*! | ||
* \file tvm/tir/block_dependence_info.h | ||
* \brief Define BlockDependenceInfoNode that uses the BlockScope and StmtSRef objects to | ||
* store the block level dependences | ||
* \sa BlockDependenceInfoNode | ||
*/ | ||
|
||
/** | ||
* @brief An object that builds and maintains block scope and StmtSref mapping for Dependence | ||
* analysis | ||
*/ | ||
|
||
#ifndef TVM_TIR_BLOCK_DEPENDENCE_INFO_H_ | ||
#define TVM_TIR_BLOCK_DEPENDENCE_INFO_H_ | ||
|
||
#include <tvm/tir/block_scope.h> | ||
|
||
#include <unordered_map> | ||
|
||
namespace tvm { | ||
namespace tir { | ||
|
||
/** | ||
* @brief An object that helps build and query block level dependences using the 2 core objects | ||
* BlockScope and StmtSRef | ||
* | ||
* The data structures exposed are: | ||
* 1) sref2scope: Mapping from the srefs to its corresponding BlockScope | ||
* 2) stmt2ref: Mapping from blocks to corresponding StmtSRefs | ||
* | ||
* Note that this object does not store SRefs to loops as the purpose is only to expose block level | ||
* dependences. This provides the advantage that the scope block (parent block) for a given block | ||
* sref can be directly accessed using the sref->parent member | ||
*/ | ||
class BlockDependenceInfoNode : public Object { | ||
public: | ||
/*! | ||
* \brief Mapping from a block sref to its correpsonding BlockScope, | ||
* tracking the dependency inside the block scope, | ||
*/ | ||
std::unordered_map<StmtSRef, BlockScope, ObjectPtrHash, ObjectPtrEqual> sref2scope; | ||
/*! \brief The reverse mapping from block/for-loop to their corresponding srefs */ | ||
std::unordered_map<const StmtNode*, StmtSRef> stmt2ref; | ||
|
||
void VisitAttrs(AttrVisitor* v) {} | ||
|
||
static constexpr const char* _type_key = "tir.BlockDependenceInfo"; | ||
TVM_DECLARE_FINAL_OBJECT_INFO(BlockDependenceInfoNode, Object); | ||
|
||
/*! | ||
* \brief Get the BlockScope correpsonding to the sref of scope root block | ||
* \param scope_root The block sref to be retrieved | ||
* \return The corresponding BlockScope | ||
*/ | ||
BlockScope GetBlockScope(const StmtSRef& scope_root) const { | ||
auto it = sref2scope.find(scope_root); | ||
CHECK(it != sref2scope.end()) | ||
<< "IndexError: Cannot find the corresponding BlockScope to the block sref:\n" | ||
<< GetRef<Stmt>(scope_root->stmt); | ||
return it->second; | ||
} | ||
}; | ||
|
||
/*! | ||
* \brief Managed reference to BlockDependenceInfoNode | ||
* \sa BlockDependenceInfo | ||
*/ | ||
class BlockDependenceInfo : public ObjectRef { | ||
/*! \brief Construct an empty BlockDependenceInfo | ||
*/ | ||
TVM_DLL BlockDependenceInfo(); | ||
|
||
public: | ||
/*! \brief Construct a BlockDependenceInfo from IRModule | ||
*/ | ||
TVM_DLL BlockDependenceInfo(IRModule mod); | ||
|
||
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockDependenceInfo, ObjectRef, | ||
BlockDependenceInfoNode); | ||
}; | ||
|
||
} // namespace tir | ||
} // namespace tvm | ||
#endif // TVM_TIR_BLOCK_DEPENDENCE_INFO_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Define BlockDependenceInfoNode that uses the BlockScope and StmtSRef objects | ||
to store the block level dependences""" | ||
|
||
from typing import Union, Optional | ||
from tvm._ffi import register_object | ||
from tvm.ir.module import IRModule | ||
from tvm.runtime import Object | ||
from tvm.tir import Block, PrimFunc | ||
|
||
from .block_scope import BlockScope, StmtSRef | ||
from . import _ffi_api | ||
|
||
|
||
@register_object("tir.BlockDependenceInfo") | ||
class BlockDependenceInfo(Object): | ||
""" | ||
BlockDependenceInfo | ||
An object that helps build and query block level dependences using the 2 core objects | ||
BlockScope and StmtSRef | ||
The data structures exposed are: | ||
1) sref2scope: Mapping from the srefs to its corresponding BlockScope | ||
2) stmt2ref: Mapping from blocks to corresponding StmtSRefs | ||
Note that this object does not store SRefs to loops as the purpose is only to expose block level | ||
dependences. This provides the advantage that the scope block (parent block) for a given block | ||
sref can be directly accessed as sref->parent | ||
""" | ||
|
||
mod: IRModule | ||
|
||
def __init__(self, mod: Union[IRModule, PrimFunc]): | ||
if isinstance(mod, PrimFunc): | ||
mod = IRModule({"main": mod}) | ||
if not isinstance(mod, IRModule): | ||
raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}") | ||
self.__init_handle_by_constructor__( | ||
_ffi_api.BlockDependenceInfo, # type: ignore # pylint: disable=no-member | ||
mod, | ||
) | ||
|
||
def get_sref(self, block: Block) -> Optional[StmtSRef]: | ||
"""Return the corresponding sref that points to the block | ||
Parameters | ||
---------- | ||
stmt : Block | ||
The block for which the sref is to be retrived | ||
Returns | ||
------- | ||
sref : StmtSRef | ||
The corresponding sref | ||
""" | ||
return _ffi_api.BlockDependenceInfoGetSRef(self, block) # type: ignore # pylint: disable=no-member | ||
|
||
def get_block_scope(self, block_sref: StmtSRef) -> BlockScope: | ||
"""Get the BlockScope correpsonding to the block sref | ||
Parameters | ||
---------- | ||
block_sref : StmtSRef | ||
The block sref to be retrieved | ||
Returns | ||
------- | ||
scope : StmtSRef | ||
The corresponding BlockScope | ||
""" | ||
return _ffi_api.BlockDependenceInfoGetBlockScope( # type: ignore # pylint: disable=no-member | ||
self, block_sref | ||
) |
Oops, something went wrong.