Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions include/tvm/relax/struct_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,9 @@ class TensorStructInfo : public StructInfo {
*
* \note shape must already be normalized.
*/
TVM_DLL TensorStructInfo(Expr shape, DataType dtype, VDevice vdevice = VDevice(),
TVM_DLL TensorStructInfo(Expr shape, DataType dtype,
VDevice vdevice = VDevice(/*tgt*/ {}, /*dev_id*/ 0,
/*mem_scope*/ "global"),
Span span = Span());

/*!
Expand All @@ -230,7 +232,9 @@ class TensorStructInfo : public StructInfo {
* \param vdevice The virtual device.
* \param span The span of the AST.
*/
TVM_DLL TensorStructInfo(DataType dtype, int ndim, VDevice vdevice = VDevice(),
TVM_DLL TensorStructInfo(DataType dtype, int ndim,
VDevice vdevice = VDevice(/*tgt*/ {}, /*dev_id*/ 0,
/*mem_scope*/ "global"),
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo, TensorStructInfoNode);
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relax/struct_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,8 @@ def __init__(
) -> None:
if isinstance(shape, (list, tuple, Array)):
shape = ShapeExpr(shape)
if vdevice is None:
vdevice = VDevice(None, 0, "global")
self.__init_handle_by_constructor__(
_ffi_api.TensorStructInfo, shape, dtype, ndim, vdevice, span # type: ignore
)
Expand Down
2 changes: 1 addition & 1 deletion src/ir/global_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ TVM_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() {
return n;
});

VDevice::VDevice(Target tgt = {}, int dev_id = -1, MemoryScope mem_scope = {}) {
VDevice::VDevice(Target tgt, int dev_id, MemoryScope mem_scope) {
ObjectPtr<VDeviceNode> n = make_object<VDeviceNode>();
n->target = std::move(tgt);
n->vdevice_id = std::move(dev_id);
Expand Down
4 changes: 2 additions & 2 deletions src/relax/analysis/struct_info_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class WellDefinedEraser : public StructInfoMutator,
std::swap(has_undefined_, has_undefined);
}

VDevice vdev = VDevice();
VDevice vdev = VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
if (op->vdevice.defined()) {
vdev = op->vdevice.value();
}
Expand Down Expand Up @@ -772,7 +772,7 @@ class StructInfoLCAFinder
// find the target dtype and ndim.
DataType dtype = lhs->dtype == rhs->dtype ? lhs->dtype : DataType::Void();
int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownNDim;
VDevice vdev = VDevice();
VDevice vdev = VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
if (lhs->vdevice.defined() && rhs->vdevice.defined()) {
if (lhs->vdevice.value().same_as(lhs->vdevice.value())) {
vdev = lhs->vdevice.value();
Expand Down
3 changes: 2 additions & 1 deletion src/relax/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ Constant::Constant(runtime::NDArray data, Optional<StructInfo> struct_info_annot
n->struct_info_ = struct_info_annotation.value();
n->checked_type_ = GetStaticType(struct_info_annotation.value());
} else {
TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(), VDevice(), span);
TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(),
VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global"), span);
n->struct_info_ = tinfo;
n->checked_type_ = DynTensorType(tinfo->ndim, tinfo->dtype);
}
Expand Down
2 changes: 1 addition & 1 deletion src/relax/ir/struct_info_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ StructInfo StructInfoMutator::VisitStructInfo_(const TensorStructInfoNode* op) {
shape = this->VisitStructInfoExprField(op->shape.value());
}

VDevice vdev = VDevice();
VDevice vdev = VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
if (op->vdevice.defined()) {
vdev = op->vdevice.value();
}
Expand Down
2 changes: 1 addition & 1 deletion src/relax/transform/convert_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ class LayoutConvertMutator : public ExprMutator {
new_shape.push_back(
shape->values[from.LeafValue()->layout.IndexOf(to.LeafValue()->layout[i])]);
}
VDevice vdev = VDevice();
VDevice vdev = VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
if (tsinfo->vdevice.defined()) {
vdev = tsinfo->vdevice.value();
}
Expand Down
2 changes: 1 addition & 1 deletion src/relax/transform/to_mixed_precision.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ class ToMixedPrecisionRewriter : public ExprMutator {
if (fp16_input_names_.count(var->name_hint())) {
auto sinfo = GetStructInfo(var);
if (auto tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
VDevice vdev = VDevice();
VDevice vdev = VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
if (tensor_sinfo->vdevice.defined()) {
vdev = tensor_sinfo->vdevice.value();
}
Expand Down
2 changes: 1 addition & 1 deletion src/script/ir_builder/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ VDevice LookupVDevice(String target_kind, int device_index) {
}
}
LOG(WARNING) << "The annotated device was not found, please check your vdevice list.";
return VDevice();
return VDevice(/*tgt*/ {}, /*dev_id*/ 0, /*mem_scope*/ "global");
}

TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule);
Expand Down
2 changes: 1 addition & 1 deletion src/script/printer/relax/struct_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
kwargs_keys.push_back("ndim");
kwargs_values.push_back(LiteralDoc::Int(n->ndim, n_p->Attr("ndim")));
}
if (n->vdevice.defined()) {
if (n->vdevice.defined() && n->vdevice.value()->target.defined()) {
kwargs_keys.push_back("vdevice");
std::string dev_kind = n->vdevice.value()->target->kind->name;
int dev_index = FindVDeviceIndexByTargetKind(n->vdevice.value(), d);
Expand Down