Skip to content

Commit

Permalink
[TIR] Block dependence analysis without schedules (#15146)
Browse files Browse the repository at this point in the history
* [TIR] Block dependence analysis without schedules

This work introduces a new object called `BlockDependenceInfo` that
computes and returns block dependences. The idea is to be able to expose
block level dependences to TIR passes without having to create an
explicit schedules.

The patch introduces 2 main classes:
1. `SRefTreeCreator` - This creates and returns a new SRefTree and
   returns a map from original statements to corresponding srefs
2. `BlockDependenceInfo` - This object computes the actual dependences
   between blocks within a block scope and returns it for access in TIR
passes

This is a continuation to
[PR #15034](#15034) and completes the work
started there to expose block level dependences to TIR passes

Note: One major difference between the SRef Tree created for dependence
analysis here versus the one already present in schedules is that this
SRef tree only contains block nodes and not loops. This makes it easier
to find the parent blocks (by just accessing `parent` member)

* Fix lint
  • Loading branch information
quic-sanirudh committed Jun 24, 2023
1 parent 1257f43 commit d26dc44
Show file tree
Hide file tree
Showing 9 changed files with 609 additions and 131 deletions.
102 changes: 102 additions & 0 deletions include/tvm/tir/block_dependence_info.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/tir/block_dependence_info.h
* \brief Define BlockDependenceInfoNode that uses the BlockScope and StmtSRef objects to
* store the block level dependences
* \sa BlockDependenceInfoNode
*/

/**
* @brief An object that builds and maintains block scope and StmtSref mapping for Dependence
* analysis
*/

#ifndef TVM_TIR_BLOCK_DEPENDENCE_INFO_H_
#define TVM_TIR_BLOCK_DEPENDENCE_INFO_H_

#include <tvm/tir/block_scope.h>

#include <unordered_map>

namespace tvm {
namespace tir {

/**
* @brief An object that helps build and query block level dependences using the 2 core objects
* BlockScope and StmtSRef
*
* The data structures exposed are:
* 1) sref2scope: Mapping from the srefs to its corresponding BlockScope
* 2) stmt2ref: Mapping from blocks to corresponding StmtSRefs
*
* Note that this object does not store SRefs to loops as the purpose is only to expose block level
* dependences. This provides the advantage that the scope block (parent block) for a given block
* sref can be directly accessed using the sref->parent member
*/
class BlockDependenceInfoNode : public Object {
public:
/*!
* \brief Mapping from a block sref to its correpsonding BlockScope,
* tracking the dependency inside the block scope,
*/
std::unordered_map<StmtSRef, BlockScope, ObjectPtrHash, ObjectPtrEqual> sref2scope;
/*! \brief The reverse mapping from block/for-loop to their corresponding srefs */
std::unordered_map<const StmtNode*, StmtSRef> stmt2ref;

void VisitAttrs(AttrVisitor* v) {}

static constexpr const char* _type_key = "tir.BlockDependenceInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(BlockDependenceInfoNode, Object);

/*!
* \brief Get the BlockScope correpsonding to the sref of scope root block
* \param scope_root The block sref to be retrieved
* \return The corresponding BlockScope
*/
BlockScope GetBlockScope(const StmtSRef& scope_root) const {
auto it = sref2scope.find(scope_root);
CHECK(it != sref2scope.end())
<< "IndexError: Cannot find the corresponding BlockScope to the block sref:\n"
<< GetRef<Stmt>(scope_root->stmt);
return it->second;
}
};

/*!
* \brief Managed reference to BlockDependenceInfoNode
* \sa BlockDependenceInfo
*/
class BlockDependenceInfo : public ObjectRef {
/*! \brief Construct an empty BlockDependenceInfo
*/
TVM_DLL BlockDependenceInfo();

public:
/*! \brief Construct a BlockDependenceInfo from IRModule
*/
TVM_DLL BlockDependenceInfo(IRModule mod);

TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockDependenceInfo, ObjectRef,
BlockDependenceInfoNode);
};

} // namespace tir
} // namespace tvm
#endif // TVM_TIR_BLOCK_DEPENDENCE_INFO_H_
50 changes: 50 additions & 0 deletions include/tvm/tir/block_scope.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,14 @@
#ifndef TVM_TIR_BLOCK_SCOPE_H_
#define TVM_TIR_BLOCK_SCOPE_H_

#include <tvm/ir/module.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>

#include <unordered_map>
#include <utility>
#include <vector>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -142,6 +147,51 @@ class StmtSRef : public ObjectRef {
TVM_DLL static StmtSRef RootMark();
};

class SRefTreeCreator : private StmtVisitor {
public:
/*!
* \brief StmtSRef Tree Creator
* \param mod The module being scheduled.
* \param include_loops Ignore ForNodes if this value is false
*/
static std::unordered_map<const StmtNode*, StmtSRef> Create(IRModule mod,
bool include_loops = true) {
SRefTreeCreator creator(include_loops);
for (const auto& kv : mod->functions) {
const BaseFunc& base_func = kv.second;
if (auto opt = base_func.as<PrimFunc>()) {
auto func = opt.value();
creator.VisitStmt(func->body);
}
}
return std::move(creator.stmt2ref_);
}

private:
explicit SRefTreeCreator(bool include_loops) : include_loops_(include_loops) {}

/*!
* \brief Add a new statement to the stack, which becomes the current scope
* \param stmt A for-loop statement or a block statement
*/
void PushSRef(const StmtNode* stmt);

/*! \brief Pop the top of the scope and record it in stmt2ref map */
void PopAndRecordSRef();

void VisitStmt_(const ForNode* loop) final;

void VisitStmt_(const BlockRealizeNode* realize) final;

void VisitStmt_(const SeqStmtNode* seq_stmt) final;

bool include_loops_;
/*! \brief The result ScheduleStateNode */
std::unordered_map<const StmtNode*, StmtSRef> stmt2ref_;
/*! \brief The stack frame used to indicate the current scope */
std::vector<StmtSRef> srefs_;
};

/*!
* \brief Type of dependency. Right now we have 4 types of dependencies
* 1) Read-after-write (kRAW)
Expand Down
45 changes: 45 additions & 0 deletions include/tvm/tir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
#ifndef TVM_TIR_UTILS_H_
#define TVM_TIR_UTILS_H_

#include <tvm/tir/block_scope.h>
#include <tvm/tir/stmt.h>

#include <unordered_map>

namespace tvm {
namespace tir {

Expand Down Expand Up @@ -90,6 +95,46 @@ namespace tir {
return result; \
}()

/*!
* \brief Set the `StmtSRefNode::seq_index` field for stmt
* \param stmt2ref The stmt2ref map to be updated with seq_index
* \param stmt The statement, or the realize node of the statement whose sref to be set
* \param seq_index The seq_index to be set
* \param include_loops Ignore ForNodes if this value is false
* \note The method is NOP for statements that are not schedulable, i.e. not For or Block
*/
inline void SetSeqIndex(std::unordered_map<const StmtNode*, StmtSRef>& stmt2ref, // NOLINT(*)
const Stmt& stmt, int seq_index, bool include_loops = true) {
if (const auto* realize = stmt.as<BlockRealizeNode>()) {
const BlockNode* block = realize->block.get();
ICHECK(stmt2ref.count(block));
stmt2ref.at(block)->seq_index = seq_index;
} else if (const auto* block = stmt.as<BlockNode>()) {
ICHECK(stmt2ref.count(block));
stmt2ref.at(block)->seq_index = seq_index;
} else if (const auto* loop = stmt.as<ForNode>()) {
if (!include_loops) return;
ICHECK(stmt2ref.count(loop));
stmt2ref.at(loop)->seq_index = seq_index;
}
}

/*!
* \brief Update seq_index of the children of a SeqStmt
* \param stmt2ref The stmt2ref map to be updated with indices
* \param seq_stmt The SeqStmt whose children need updating
* \param include_loops Ignore ForNodes if this value is false
*/
inline void SetSeqIndexInChildren(
std::unordered_map<const StmtNode*, StmtSRef>& stmt2ref, // NOLINT(*)
const SeqStmtNode* seq_stmt, bool include_loops = true) {
int i = 0;
for (const Stmt& stmt : seq_stmt->seq) {
SetSeqIndex(stmt2ref, stmt, i, include_loops);
++i;
}
}

} // namespace tir
} // namespace tvm

Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from .generic import add, subtract, multiply

from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError
from .block_dependence_info import BlockDependenceInfo

from . import schedule
from . import ir_builder
Expand Down
88 changes: 88 additions & 0 deletions python/tvm/tir/block_dependence_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Define BlockDependenceInfoNode that uses the BlockScope and StmtSRef objects
to store the block level dependences"""

from typing import Union, Optional
from tvm._ffi import register_object
from tvm.ir.module import IRModule
from tvm.runtime import Object
from tvm.tir import Block, PrimFunc

from .block_scope import BlockScope, StmtSRef
from . import _ffi_api


@register_object("tir.BlockDependenceInfo")
class BlockDependenceInfo(Object):
"""
BlockDependenceInfo
An object that helps build and query block level dependences using the 2 core objects
BlockScope and StmtSRef
The data structures exposed are:
1) sref2scope: Mapping from the srefs to its corresponding BlockScope
2) stmt2ref: Mapping from blocks to corresponding StmtSRefs
Note that this object does not store SRefs to loops as the purpose is only to expose block level
dependences. This provides the advantage that the scope block (parent block) for a given block
sref can be directly accessed as sref->parent
"""

mod: IRModule

def __init__(self, mod: Union[IRModule, PrimFunc]):
if isinstance(mod, PrimFunc):
mod = IRModule({"main": mod})
if not isinstance(mod, IRModule):
raise TypeError(f"Expected `mod` to be PrimFunc or IRModule, but gets: {mod}")
self.__init_handle_by_constructor__(
_ffi_api.BlockDependenceInfo, # type: ignore # pylint: disable=no-member
mod,
)

def get_sref(self, block: Block) -> Optional[StmtSRef]:
"""Return the corresponding sref that points to the block
Parameters
----------
stmt : Block
The block for which the sref is to be retrived
Returns
-------
sref : StmtSRef
The corresponding sref
"""
return _ffi_api.BlockDependenceInfoGetSRef(self, block) # type: ignore # pylint: disable=no-member

def get_block_scope(self, block_sref: StmtSRef) -> BlockScope:
"""Get the BlockScope correpsonding to the block sref
Parameters
----------
block_sref : StmtSRef
The block sref to be retrieved
Returns
-------
scope : StmtSRef
The corresponding BlockScope
"""
return _ffi_api.BlockDependenceInfoGetBlockScope( # type: ignore # pylint: disable=no-member
self, block_sref
)

0 comments on commit d26dc44

Please sign in to comment.