Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][DTYPE] Isolate dtype to runtime #4560

Merged
merged 1 commit into from
Dec 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions include/tvm/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ class AttrsEqual {
bool operator()(const std::string& lhs, const std::string& rhs) const {
return lhs == rhs;
}
bool operator()(const Type& lhs, const Type& rhs) const {
bool operator()(const DataType& lhs, const DataType& rhs) const {
return lhs == rhs;
}
// node comparator
Expand Down Expand Up @@ -506,8 +506,8 @@ inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
}
}
template<>
inline void SetValue(Type* ptr, const TVMArgValue& val) {
*ptr = val.operator Type();
inline void SetValue(DataType* ptr, const TVMArgValue& val) {
*ptr = val.operator DataType();
}
template<>
inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
Expand Down Expand Up @@ -611,7 +611,7 @@ struct TypeName<uint64_t> {
};

template<>
struct TypeName<Type> {
struct TypeName<DataType> {
static constexpr const char* value = "Type";
};

Expand Down
18 changes: 10 additions & 8 deletions include/tvm/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,16 @@ class Buffer : public NodeRef {
* \param content_lanes The number of lanes for the (data) type.
* \param offset The offset of ptr.
*/
TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(),
int content_lanes = 1, Expr offset = make_const(Int(32), 0)) const;
TVM_DLL Expr access_ptr(int access_mask,
DataType ptr_type = DataType::Handle(),
int content_lanes = 1,
Expr offset = make_const(DataType::Int(32), 0)) const;
/*!
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index
* \param dtype The data type to be loaded.
*/
TVM_DLL Expr vload(Array<Expr> begin, Type dtype) const;
TVM_DLL Expr vload(Array<Expr> begin, DataType dtype) const;
/*!
* \brief Create a Stmt that does a vector store at begin index.
* \param begin The beginning index
Expand All @@ -108,7 +110,7 @@ class BufferNode : public Node {
*/
Var data;
/*! \brief data type in the content of the tensor */
Type dtype;
DataType dtype;
/*! \brief The shape of the buffer */
Array<Expr> shape;
/*!
Expand Down Expand Up @@ -149,14 +151,14 @@ class BufferNode : public Node {
}

/*! \return preferred index type for this buffer node */
Type DefaultIndexType() const {
return shape.size() != 0 ? shape[0].type() : Int(32);
DataType DefaultIndexType() const {
return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
}

// User can specify data_alignment and offset_factor to be 0
// A default value will be picked.
TVM_DLL static Buffer make(Var ptr,
Type dtype,
DataType dtype,
Array<Expr> shape,
Array<Expr> strides,
Expr elem_offset,
Expand All @@ -183,7 +185,7 @@ inline const BufferNode* Buffer::operator->() const {
* \sa BufferNode::make for complete constructor.
*/
TVM_DLL Buffer decl_buffer(Array<Expr> shape,
Type dtype = Float(32),
DataType dtype = DataType::Float(32),
std::string name = "buffer");
} // namespace tvm
#endif // TVM_BUFFER_H_
4 changes: 2 additions & 2 deletions include/tvm/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ struct ChannelNode : public Node {
/*! \brief Variable to channel handle */
Var handle_var;
/*! \brief default data type in read/write */
Type dtype;
DataType dtype;
// visit all attributes
void VisitAttrs(AttrVisitor* v) {
v->Visit("handle_var", &handle_var);
v->Visit("dtype", &dtype);
}

static Channel make(Var handle_var, Type dtype);
static Channel make(Var handle_var, DataType dtype);
static constexpr const char* _type_key = "Channel";

TVM_DECLARE_NODE_TYPE_INFO(ChannelNode, Node);
Expand Down
18 changes: 9 additions & 9 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,19 @@
#include <unordered_map>
#include <iostream>
#include "base.h"
#include "dtype.h"
#include "node/node.h"
#include "node/container.h"
#include "node/functor.h"
#include "runtime/c_runtime_api.h"
#include "runtime/data_type.h"

namespace tvm {

/*! \brief Base node of all expressions. */
class ExprNode : public Node {
public:
/*! \brief The data type of the expression. */
DataType type;
DataType dtype;

static constexpr const char* _type_key = "Expr";
TVM_DECLARE_BASE_NODE_INFO(ExprNode, Node);
Expand Down Expand Up @@ -69,8 +69,8 @@ class Expr : public NodeRef {
TVM_DLL Expr(std::string str); // NOLINT(*)

/*! \return the data type of this expression. */
DataType type() const {
return static_cast<const ExprNode*>(get())->type;
DataType dtype() const {
return static_cast<const ExprNode*>(get())->dtype;
}

/*! \brief type indicate the container type */
Expand Down Expand Up @@ -113,7 +113,7 @@ class Variable : public ExprNode {
static Var make(DataType dtype, std::string name_hint);

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("dtype", &dtype);
v->Visit("name", &name_hint);
}

Expand All @@ -126,14 +126,14 @@ class Var : public Expr {
public:
explicit Var(ObjectPtr<Object> n) : Expr(n) {}
TVM_DLL explicit Var(std::string name_hint = "v",
Type t = Int(32));
DataType t = DataType::Int(32));
/*!
* \brief Make a new copy of var with same type, append suffix
* \param suffix The suffix to be appended.
* \return the new Var copy
*/
Var copy_with_suffix(const std::string& suffix) const {
return Var((*this)->name_hint + suffix, (*this)->type);
return Var((*this)->name_hint + suffix, (*this)->dtype);
}
/*!
* \brief Get pointer to the internal value.
Expand Down Expand Up @@ -167,7 +167,7 @@ class IntImm : public ExprNode {
int64_t value;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &type);
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}

Expand Down Expand Up @@ -452,7 +452,7 @@ inline const char* IterVarType2String(IterVarType t) {
* \param name_hint The name hint for the expression
* \param t The type of the expression
*/
TVM_DLL Var var(std::string name_hint, Type t = Int(32));
TVM_DLL Var var(std::string name_hint, DataType t = DataType::Int(32));

/*
* \brief Template function to convert Map to unordered_map
Expand Down
46 changes: 30 additions & 16 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,28 +44,28 @@ namespace tvm {
*/
template<typename ValueType,
typename = typename std::enable_if<std::is_pod<ValueType>::value>::type>
inline Expr make_const(Type t, ValueType value);
inline Expr make_const(DataType t, ValueType value);
/*!
* \brief Make a const zero expr.
* \param t The target type.
* \return the result expression.
*/
inline Expr make_zero(Type t);
inline Expr make_zero(DataType t);
/*!
* \brief Make a constant true expression.
* \param lanes The number of lanes in the bool
* \return The result expression.
*/
inline Expr const_true(int lanes = 1) {
return make_const(UInt(1, lanes), 1);
return make_const(DataType::UInt(1, lanes), 1);
}
/*!
* \brief Make a constant false expression.
* \param lanes The number of lanes in the bool
* \return The result expression.
*/
inline Expr const_false(int lanes = 1) {
return make_const(UInt(1, lanes), 0);
return make_const(DataType::UInt(1, lanes), 0);
}
/*!
* \brief Get x as constant int expression.
Expand Down Expand Up @@ -139,6 +139,20 @@ inline bool is_zero(const Expr& x) {
*/
inline bool is_const(const Expr& x);

/*!
* Query the maximum possible value of dtype.
* \param dtype The data type.
* \return the maximum possible value in this format.
*/
TVM_DLL Expr max_value(const DataType& dtype);

/*!
* Query the minimum possible value of dtype.
* \param dtype The data type.
* \return the minimum possible value in this format.
*/
TVM_DLL Expr min_value(const DataType& dtype);

/*!
* \brief Check whether x is a constant power of two
* If x is power of two, write the power to the shift.
Expand All @@ -157,7 +171,7 @@ TVM_DLL bool is_const_power_of_two_integer(const Expr& x, int* shift);
* \return The result expression.
* \note This function may return value if the type is the same.
*/
TVM_DLL Expr cast(const Type& t, Expr value);
TVM_DLL Expr cast(const DataType& t, Expr value);
/*!
* \brief perform reinterpret cast value to type.
*
Expand All @@ -166,7 +180,7 @@ TVM_DLL Expr cast(const Type& t, Expr value);
* \return The result expression.
* \note This function may return value if the type is the same.
*/
TVM_DLL Expr reinterpret(const Type& t, Expr value);
TVM_DLL Expr reinterpret(const DataType& t, Expr value);
/*!
* \brief add operator
*
Expand Down Expand Up @@ -586,7 +600,7 @@ TVM_DLL Expr trunc(Expr x);
// Intrinsic operators
#define TVM_DECLARE_INTRIN_UNARY(OpName) \
inline Expr OpName(Expr x) { \
return ir::Call::make(x.type(), #OpName, {x}, ir::Call::PureIntrinsic); \
return ir::Call::make(x.dtype(), #OpName, {x}, ir::Call::PureIntrinsic); \
} \

TVM_DECLARE_INTRIN_UNARY(exp);
Expand Down Expand Up @@ -657,7 +671,7 @@ inline bool is_no_op(const Stmt& stmt) {
}

template<typename ValueType>
inline Expr MakeConstScalar(Type t, ValueType value) {
inline Expr MakeConstScalar(DataType t, ValueType value) {
if (t.is_int()) return ir::IntImm::make(t, static_cast<int64_t>(value));
if (t.is_uint()) return ir::UIntImm::make(t, static_cast<uint64_t>(value));
if (t.is_float()) return ir::FloatImm::make(t, static_cast<double>(value));
Expand All @@ -672,7 +686,7 @@ inline Expr MakeConstScalar(Type t, ValueType value) {
}

template<typename ValueType, typename>
inline Expr make_const(Type t, ValueType value) {
inline Expr make_const(DataType t, ValueType value) {
if (t.lanes() == 1) {
return MakeConstScalar(t, value);
} else {
Expand All @@ -681,9 +695,9 @@ inline Expr make_const(Type t, ValueType value) {
}
}

inline Expr make_zero(Type t) {
inline Expr make_zero(DataType t) {
if (t.is_handle()) {
return reinterpret(t, make_const(UInt(64), 0));
return reinterpret(t, make_const(DataType::UInt(64), 0));
}
return make_const(t, 0);
}
Expand All @@ -703,13 +717,13 @@ inline Expr make_zero(Type t) {
return Name(Expr(a), b); \
} \
inline Expr Name(int a, const Expr& b) { \
return Name(make_const(b.type(), a), b); \
return Name(make_const(b.dtype(), a), b); \
} \
inline Expr Name(const Expr& a, int b) { \
return Name(a, make_const(a.type(), b)); \
return Name(a, make_const(a.dtype(), b)); \
} \
inline Expr Name(const Expr& a, double b) { \
return Name(a, make_const(Float(64), b)); \
return Name(a, make_const(DataType::Float(64), b)); \
}

#define TVM_DEFINE_LOGICAL_OP_CONST_VAL_OVERLOAD(Name) \
Expand All @@ -722,10 +736,10 @@ inline Expr make_zero(Type t) {

#define TVM_DEFINE_INT_OP_CONST_VAL_OVERLOAD(Name) \
inline Expr Name(const Expr& a, int b) { \
return Name(a, make_const(a.type(), b)); \
return Name(a, make_const(a.dtype(), b)); \
} \
inline Expr Name(int a, const Expr& b) { \
return Name(make_const(b.type(), a), b); \
return Name(make_const(b.dtype(), a), b); \
}


Expand Down