Skip to content

Commit

Permalink
Implement jit for comparisons, except for (double, int).
Browse files Browse the repository at this point in the history
That one has some edge cases which I can't be bothered to code.
  • Loading branch information
pyos committed May 10, 2018
1 parent bd332b9 commit 6d2259f
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 17 deletions.
44 changes: 27 additions & 17 deletions dbms/src/DataTypes/Native.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ static inline bool typeIsEither(const IDataType & type)
return (typeid_cast<const Ts *>(&type) || ...);
}

static inline bool typeIsSigned(const IDataType & type)
{
return typeIsEither<
DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64,
DataTypeFloat32, DataTypeFloat64,
DataTypeDate, DataTypeDateTime, DataTypeInterval
>(type);
}

static inline llvm::Type * toNativeType(llvm::IRBuilderBase & builder, const IDataType & type)
{
if (auto * nullable = typeid_cast<const DataTypeNullable *>(&type))
Expand Down Expand Up @@ -77,11 +86,26 @@ static inline llvm::Value * nativeBoolCast(llvm::IRBuilder<> & b, const DataType
throw Exception("Cannot cast non-number " + from->getName() + " to bool", ErrorCodes::NOT_IMPLEMENTED);
}

static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr & from, llvm::Value * value, const DataTypePtr & to)
static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr & from, llvm::Value * value, llvm::Type * to)
{
auto * n_from = value->getType();
if (n_from == to)
return value;
if (n_from->isIntegerTy() && to->isFloatingPointTy())
return typeIsSigned(*from) ? b.CreateSIToFP(value, to) : b.CreateUIToFP(value, to);
if (n_from->isFloatingPointTy() && to->isIntegerTy())
return typeIsSigned(*from) ? b.CreateFPToSI(value, to) : b.CreateFPToUI(value, to);
if (n_from->isIntegerTy() && to->isIntegerTy())
return b.CreateIntCast(value, to, typeIsSigned(*from));
if (n_from->isFloatingPointTy() && to->isFloatingPointTy())
return b.CreateFPCast(value, to);
throw Exception("Cannot cast " + from->getName() + " to requested type", ErrorCodes::NOT_IMPLEMENTED);
}

static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr & from, llvm::Value * value, const DataTypePtr & to)
{
auto * n_to = toNativeType(b, to);
if (n_from == n_to)
if (value->getType() == n_to)
return value;
if (from->isNullable() && to->isNullable())
{
Expand All @@ -95,21 +119,7 @@ static inline llvm::Value * nativeCast(llvm::IRBuilder<> & b, const DataTypePtr
auto * inner = nativeCast(b, from, value, removeNullable(to));
return b.CreateInsertValue(llvm::Constant::getNullValue(n_to), inner, {0});
}

bool is_signed = typeIsEither<
DataTypeInt8, DataTypeInt16, DataTypeInt32, DataTypeInt64,
DataTypeFloat32, DataTypeFloat64,
DataTypeDate, DataTypeDateTime, DataTypeInterval
>(*from);
if (n_from->isIntegerTy() && n_to->isFloatingPointTy())
return is_signed ? b.CreateSIToFP(value, n_to) : b.CreateUIToFP(value, n_to);
if (n_from->isFloatingPointTy() && n_to->isIntegerTy())
return is_signed ? b.CreateFPToSI(value, n_to) : b.CreateFPToUI(value, n_to);
if (n_from->isIntegerTy() && n_to->isIntegerTy())
return b.CreateIntCast(value, n_to, is_signed);
if (n_from->isFloatingPointTy() && n_to->isFloatingPointTy())
return b.CreateFPCast(value, n_to);
throw Exception("Cannot cast " + from->getName() + " to " + to->getName(), ErrorCodes::NOT_IMPLEMENTED);
return nativeCast(b, from, value, n_to);
}

}
Expand Down
77 changes: 77 additions & 0 deletions dbms/src/Functions/FunctionsComparison.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,26 @@ template <typename A, typename B> struct EqualsOp
using SymmetricOp = EqualsOp<B, A>;

static UInt8 apply(A a, B b) { return accurate::equalsOp(a, b); }

#if USE_EMBEDDED_COMPILER
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool /*is_signed*/)
{
return x->getType()->isIntegerTy() ? b.CreateICmpEQ(x, y) : b.CreateFCmpOEQ(x, y); /// qNaNs always compare false
}
#endif
};

template <typename A, typename B> struct NotEqualsOp
{
using SymmetricOp = NotEqualsOp<B, A>;
static UInt8 apply(A a, B b) { return accurate::notEqualsOp(a, b); }

#if USE_EMBEDDED_COMPILER
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool /*is_signed*/)
{
return x->getType()->isIntegerTy() ? b.CreateICmpNE(x, y) : b.CreateFCmpONE(x, y);
}
#endif
};

template <typename A, typename B> struct GreaterOp;
Expand All @@ -67,12 +81,26 @@ template <typename A, typename B> struct LessOp
{
using SymmetricOp = GreaterOp<B, A>;
static UInt8 apply(A a, B b) { return accurate::lessOp(a, b); }

#if USE_EMBEDDED_COMPILER
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool is_signed)
{
return x->getType()->isIntegerTy() ? (is_signed ? b.CreateICmpSLT(x, y) : b.CreateICmpULT(x, y)) : b.CreateFCmpOLT(x, y);
}
#endif
};

template <typename A, typename B> struct GreaterOp
{
using SymmetricOp = LessOp<B, A>;
static UInt8 apply(A a, B b) { return accurate::greaterOp(a, b); }

#if USE_EMBEDDED_COMPILER
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool is_signed)
{
return x->getType()->isIntegerTy() ? (is_signed ? b.CreateICmpSGT(x, y) : b.CreateICmpUGT(x, y)) : b.CreateFCmpOGT(x, y);
}
#endif
};

template <typename A, typename B> struct GreaterOrEqualsOp;
Expand All @@ -81,12 +109,26 @@ template <typename A, typename B> struct LessOrEqualsOp
{
using SymmetricOp = GreaterOrEqualsOp<B, A>;
static UInt8 apply(A a, B b) { return accurate::lessOrEqualsOp(a, b); }

#if USE_EMBEDDED_COMPILER
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool is_signed)
{
return x->getType()->isIntegerTy() ? (is_signed ? b.CreateICmpSLE(x, y) : b.CreateICmpULE(x, y)) : b.CreateFCmpOLE(x, y);
}
#endif
};

template <typename A, typename B> struct GreaterOrEqualsOp
{
using SymmetricOp = LessOrEqualsOp<B, A>;
static UInt8 apply(A a, B b) { return accurate::greaterOrEqualsOp(a, b); }

#if USE_EMBEDDED_COMPILER
static llvm::Value * compile(llvm::IRBuilder<> & b, llvm::Value * x, llvm::Value * y, bool is_signed)
{
return x->getType()->isIntegerTy() ? (is_signed ? b.CreateICmpSGE(x, y) : b.CreateICmpUGE(x, y)) : b.CreateFCmpOGE(x, y);
}
#endif
};


Expand Down Expand Up @@ -1136,6 +1178,41 @@ class FunctionComparison : public IFunction
col_with_type_and_name_left.type, col_with_type_and_name_right.type,
left_is_num, input_rows_count);
}

#if USE_EMBEDDED_COMPILER
bool isCompilableImpl(const DataTypes & types) const override
{
auto isBigInteger = &typeIsEither<DataTypeInt64, DataTypeUInt64, DataTypeUUID>;
auto isFloatingPoint = &typeIsEither<DataTypeFloat32, DataTypeFloat64>;
if ((isBigInteger(*types[0]) && isFloatingPoint(*types[1])) || (isBigInteger(*types[1]) && isFloatingPoint(*types[0])))
return false; /// TODO: implement (double, int_N where N > double's mantissa width)
return types[0]->isValueRepresentedByNumber() && types[1]->isValueRepresentedByNumber();
}

llvm::Value * compileImpl(llvm::IRBuilderBase & builder, const DataTypes & types, ValuePlaceholders values) const override
{
auto & b = static_cast<llvm::IRBuilder<> &>(builder);
auto * x = values[0]();
auto * y = values[1]();
if (!types[0]->equals(*types[1]))
{
llvm::Type * common;
if (x->getType()->isIntegerTy() && y->getType()->isIntegerTy())
common = b.getIntNTy(std::max(
/// if one integer has a sign bit, make sure the other does as well. llvm generates optimal code
/// (e.g. uses overflow flag on x86) for (word size + 1)-bit integer operations.
x->getType()->getIntegerBitWidth() + (!typeIsSigned(*types[0]) && typeIsSigned(*types[1])),
y->getType()->getIntegerBitWidth() + (!typeIsSigned(*types[1]) && typeIsSigned(*types[0]))));
else
/// (double, float) or (double, int_N where N <= double's mantissa width) -> double
common = b.getDoubleTy();
x = nativeCast(b, types[0], x, common);
y = nativeCast(b, types[1], y, common);
}
auto * result = Op<int, int>::compile(b, x, y, typeIsSigned(*types[0]) || typeIsSigned(*types[1]));
return b.CreateSelect(result, b.getInt8(1), b.getInt8(0));
}
#endif
};


Expand Down

0 comments on commit 6d2259f

Please sign in to comment.