Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -455,12 +455,13 @@ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(),
return is_size_var ? tvm::tir::SizeVar("", type_annotation) : tvm::tir::Var("", type_annotation);
}

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt, bool is_size_var = false) { \
DataType dtype = DType; \
return expr.defined() \
? tvm::cast(dtype, expr.value()) \
: (is_size_var ? tvm::tir::SizeVar("", dtype) : tvm::tir::Var("", dtype)); \
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \
inline PrimExpr FuncName(Optional<PrimExpr> expr = NullOpt, bool is_size_var = false, \
int64_t min_value = 0) { \
DataType dtype = DType; \
return expr.defined() ? tvm::cast(dtype, expr.value()) \
: (is_size_var ? tvm::tir::SizeVar("", dtype, min_value) \
: tvm::tir::Var("", dtype)); \
}

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \
Expand Down
11 changes: 11 additions & 0 deletions include/tvm/target/tag.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ class TargetTagRegEntry {
* \param config The config dict for target creation
*/
inline TargetTagRegEntry& set_config(Map<String, ObjectRef> config);
/*!
* \brief Add a key-value pair to the config dict
* \param key The attribute name
* \param value The attribute value
*/
inline TargetTagRegEntry& with_config(String key, ObjectRef value);
/*! \brief Set name of the TargetTag to be the same as registry if it is empty */
inline TargetTagRegEntry& set_name();
/*!
Expand Down Expand Up @@ -131,6 +137,11 @@ inline TargetTagRegEntry& TargetTagRegEntry::set_config(Map<String, ObjectRef> c
return *this;
}

inline TargetTagRegEntry& TargetTagRegEntry::with_config(String key, ObjectRef value) {
tag_->config.Set(key, value);
return *this;
}

inline TargetTagRegEntry& TargetTagRegEntry::set_name() {
if (tag_->name.empty()) {
tag_->name = name;
Expand Down
11 changes: 9 additions & 2 deletions include/tvm/tir/var.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,12 @@ class Var : public PrimExpr {
*/
class SizeVarNode : public VarNode {
public:
int64_t min_value;
void VisitAttrs(tvm::AttrVisitor* v) {
VarNode::VisitAttrs(v);
v->Visit("min_value", &min_value);
}

static constexpr const char* _type_key = "tir.SizeVar";
TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode);
};
Expand All @@ -157,14 +163,15 @@ class SizeVar : public Var {
* \param span The location of this object in the source code.
*/
TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32),
Span span = Span());
int64_t min_value = 0, Span span = Span());
/*!
* \brief Constructor which provides a more detailed type annotation.
* \param name_hint variable name.
* \param type_annotation The type annotation.
* \param span The location of this object in the source code.
*/
TVM_DLL explicit SizeVar(String name_hint, Type type_annotation, Span span = Span());
TVM_DLL explicit SizeVar(String name_hint, Type type_annotation, int64_t min_value = 0,
Span span = Span());
/*!
* \brief Get pointer to the internal value.
* \return the corresponding Variable.
Expand Down
7 changes: 4 additions & 3 deletions python/tvm/arith/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,16 +195,17 @@ def canonical_simplify(self, expr):
"""
return self._canonical_simplify(expr)

def int_set(self, expr, dom_map):
def int_set(self, expr, dom_map=None):
"""Compute a symbolic IntSet that covers expr for all values in dom_map.

Parameters
----------
expr : PrimExpr
The expression.

dom_map : Dict[Var, tvm.arith.IntSet]
The domain for variables to be relaxed.
dom_map : Optional[Dict[Var, tvm.arith.IntSet]]
The domain for variables to be relaxed. If None, use the domain map defined by bound
variables.

Returns
-------
Expand Down
11 changes: 8 additions & 3 deletions python/tvm/dlight/gpu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,14 @@
GPU-generic schedule rules.
For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/metal` instead
"""
from .gemv import GEMV
from .fallback import Fallback
from .matmul import Matmul
from .gemv import GEMV
from .general_reduction import GeneralReduction
from .matmul import (
Matmul,
MatmulTensorizationMMA,
MatmulTensorizationWMMA,
MatmulTensorizationLegacy,
)
from .reduction import Reduction
from .transpose import Transpose
from .general_reduction import GeneralReduction
Loading