Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for SQL bitwise operators on integer types #8350

Merged
merged 1 commit into from Sep 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion doc/user/content/release-notes.md
Expand Up @@ -48,7 +48,9 @@ Wrap your release notes at the 80 character mark.

{{% version-header v0.9.5 %}}

- Timezone parsing is now case insensitive to be compatible with PostgreSQL
- Timezone parsing is now case insensitive to be compatible with PostgreSQL.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra new line here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I followed the formatting from the last few versions (line items are separated by a new line). I now realize this is not correct markdown, though (it will be interpreted as a sequence of singleton lists, so maybe we should omit the newlines as you are suggesting.

- Add support for [bitwise operators on integers](/sql/functions/#numbers).

{{% version-header v0.9.4 %}}

Expand Down
6 changes: 6 additions & 0 deletions doc/user/content/sql/functions/_index.md
Expand Up @@ -54,6 +54,12 @@ Operator | Computes
`*` | Multiplication
`/` | Division
`%` | Modulo
`&` | Bitwise AND
`|` | Bitwise OR
`#` | Bitwise XOR
`~` | Bitwise NOT
aalexandrov marked this conversation as resolved.
Show resolved Hide resolved
`<<`| Bitwise left shift
`>>`| Bitwise right shift

### String

Expand Down
191 changes: 186 additions & 5 deletions src/expr/src/scalar/func.rs
Expand Up @@ -1166,6 +1166,96 @@ fn add_interval<'a>(a: Datum<'a>, b: Datum<'a>) -> Result<Datum<'a>, EvalError>
.map(Datum::from)
}

fn bit_not_int16<'a>(a: Datum<'a>) -> Datum<'a> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could the unary funcs be expressed as a macro? Not sure tbh, but maybe worth checking!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably, but this issue is more general, as most functions come in families indexed by the possible combination of types (e.g. see neg_int{XX}, abs_int{XX}, ...). I agree that we should think how we can minimize code for this functions, but I prefer to not tackle this as part of my starter issue. We could use a macro, think about how to utilize generics in a better way, or choose another zero cost abstraction.

I'm happy to open a follow-up issue if more people agree that this is a problem.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about something like this:

sqlfunc!(
    fn bit_not_int16(a: i16) -> i16 {
        !a
    }
);

Elsewhere, there seems to be a TryFrom missing:

impl TryFrom<Datum<'_>> for i16 {
    type Error = ();
    fn try_from(from: Datum<'_>) -> Result<Self, Self::Error> {
        match from {
            Datum::Int16(f) => Ok(f),
            _ => Err(()),
        }
    }
}

impl TryFrom<Datum<'_>> for Option<i16> {
    type Error = ();
    fn try_from(from: Datum<'_>) -> Result<Self, Self::Error> {
        match from {
            Datum::Null => Ok(None),
            Datum::Int16(f) => Ok(Some(f)),
            _ => Err(()),
        }
    }
}

But, totally up to you :D

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using the sqlfunc! macro for this unary function is totally doable and totally optional since this is your first issue. We're in the process of transitioning the traditional declarations (like the one you provided for bit_not_int16) to use sqlfunc!. The plan is to transition all functions (even binary and variadic), but for now only unarfy functions are supported by the macro.

So, if you want to merge as is, that's totally fine, I will include this function in a future PR that will migrate all the int16 functions whole sale.

If you want to dip your toes in the new way of defining unary functions you can use @antiguru's snippets above and a commit like this one for reference: 88a1443

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So, if you want to merge as is, that's totally fine, I will include this function in a future PR that will migrate all the int16 functions whole sale.

@petrosagg: I like the approach outlined in #7704! Unfortunately, missed that and started working on a this feature shortly before the PR was merged, so I prefer to merge the current PR without further. I am happy take off some of the refactoring work (including the ~ operator) in the next days.

Elsewhere, there seems to be a TryFrom missing:

@antiguru Is the above somehow affecting the correctness of the current implementation and actionable on my end?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, TryFrom is only needed if you go down the sqlfunc! path. Don't worry, I'm slowly converting all the functions in batches, let's merge this as is :)

Datum::from(!a.unwrap_int16())
}

fn bit_not_int32<'a>(a: Datum<'a>) -> Datum<'a> {
Datum::from(!a.unwrap_int32())
}

fn bit_not_int64<'a>(a: Datum<'a>) -> Datum<'a> {
Datum::from(!a.unwrap_int64())
}

fn bit_and_int16<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a> {
Datum::from(a.unwrap_int16() & b.unwrap_int16())
}

fn bit_and_int32<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a> {
Datum::from(a.unwrap_int32() & b.unwrap_int32())
}

fn bit_and_int64<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a> {
Datum::from(a.unwrap_int64() & b.unwrap_int64())
}

fn bit_or_int16<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a> {
Datum::from(a.unwrap_int16() | b.unwrap_int16())
}

fn bit_or_int32<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a> {
Datum::from(a.unwrap_int32() | b.unwrap_int32())
}

fn bit_or_int64<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a> {
Datum::from(a.unwrap_int64() | b.unwrap_int64())
}

fn bit_xor_int16<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a> {
Datum::from(a.unwrap_int16() ^ b.unwrap_int16())
}

fn bit_xor_int32<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a> {
Datum::from(a.unwrap_int32() ^ b.unwrap_int32())
}

fn bit_xor_int64<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a> {
Datum::from(a.unwrap_int64() ^ b.unwrap_int64())
}

fn bit_shift_left_int16<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a> {
// widen to i32 and then cast back to i16 in order emulate the C promotion rules used in by Postgres
// when the rhs in the 16-31 range, e.g. (1 << 17 should evaluate to 0)
// see https://github.com/postgres/postgres/blob/REL_14_STABLE/src/backend/utils/adt/int.c#L1460-L1476
let lhs: i32 = a.unwrap_int16() as i32;
let rhs: u32 = b.unwrap_int32() as u32;
Datum::from(lhs.wrapping_shl(rhs) as i16)
}

fn bit_shift_left_int32<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a> {
let lhs = a.unwrap_int32();
let rhs = b.unwrap_int32() as u32;
Datum::from(lhs.wrapping_shl(rhs))
}

fn bit_shift_left_int64<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a> {
let lhs = a.unwrap_int64();
let rhs = b.unwrap_int32() as u32;
Datum::from(lhs.wrapping_shl(rhs))
}

fn bit_shift_right_int16<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a> {
// widen to i32 and then cast back to i16 in order emulate the C promotion rules used in by Postgres
// when the rhs in the 16-31 range, e.g. (-32767 >> 17 should evaluate to -1)
// see https://github.com/postgres/postgres/blob/REL_14_STABLE/src/backend/utils/adt/int.c#L1460-L1476
let lhs = a.unwrap_int16() as i32;
let rhs = b.unwrap_int32() as u32;
Datum::from(lhs.wrapping_shr(rhs) as i16)
}

fn bit_shift_right_int32<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a> {
let lhs = a.unwrap_int32();
let rhs = b.unwrap_int32() as u32;
Datum::from(lhs.wrapping_shr(rhs))
}

fn bit_shift_right_int64<'a>(a: Datum<'a>, b: Datum<'a>) -> Datum<'a> {
let lhs = a.unwrap_int64();
let rhs = b.unwrap_int32() as u32;
Datum::from(lhs.wrapping_shr(rhs))
}

fn sub_int16<'a>(a: Datum<'a>, b: Datum<'a>) -> Result<Datum<'a>, EvalError> {
a.unwrap_int16()
.checked_sub(b.unwrap_int16())
Expand Down Expand Up @@ -2595,6 +2685,21 @@ pub enum BinaryFunc {
AddDateTime,
AddTimeInterval,
AddNumeric,
BitAndInt16,
BitAndInt32,
BitAndInt64,
BitOrInt16,
BitOrInt32,
BitOrInt64,
BitXorInt16,
BitXorInt32,
BitXorInt64,
BitShiftLeftInt16,
BitShiftLeftInt32,
BitShiftLeftInt64,
BitShiftRightInt16,
BitShiftRightInt32,
BitShiftRightInt64,
SubInt16,
SubInt32,
SubInt64,
Expand Down Expand Up @@ -2730,6 +2835,21 @@ impl BinaryFunc {
BinaryFunc::AddTimeInterval => Ok(eager!(add_time_interval)),
BinaryFunc::AddNumeric => eager!(add_numeric),
BinaryFunc::AddInterval => eager!(add_interval),
BinaryFunc::BitAndInt16 => Ok(eager!(bit_and_int16)),
BinaryFunc::BitAndInt32 => Ok(eager!(bit_and_int32)),
BinaryFunc::BitAndInt64 => Ok(eager!(bit_and_int64)),
BinaryFunc::BitOrInt16 => Ok(eager!(bit_or_int16)),
BinaryFunc::BitOrInt32 => Ok(eager!(bit_or_int32)),
BinaryFunc::BitOrInt64 => Ok(eager!(bit_or_int64)),
BinaryFunc::BitXorInt16 => Ok(eager!(bit_xor_int16)),
BinaryFunc::BitXorInt32 => Ok(eager!(bit_xor_int32)),
BinaryFunc::BitXorInt64 => Ok(eager!(bit_xor_int64)),
BinaryFunc::BitShiftLeftInt16 => Ok(eager!(bit_shift_left_int16)),
BinaryFunc::BitShiftLeftInt32 => Ok(eager!(bit_shift_left_int32)),
BinaryFunc::BitShiftLeftInt64 => Ok(eager!(bit_shift_left_int64)),
BinaryFunc::BitShiftRightInt16 => Ok(eager!(bit_shift_right_int16)),
BinaryFunc::BitShiftRightInt32 => Ok(eager!(bit_shift_right_int32)),
BinaryFunc::BitShiftRightInt64 => Ok(eager!(bit_shift_right_int64)),
BinaryFunc::SubInt16 => eager!(sub_int16),
BinaryFunc::SubInt32 => eager!(sub_int32),
BinaryFunc::SubInt64 => eager!(sub_int64),
Expand Down Expand Up @@ -2897,7 +3017,8 @@ impl BinaryFunc {
ToCharTimestamp | ToCharTimestampTz | ConvertFrom | Left | Right | Trim
| TrimLeading | TrimTrailing => ScalarType::String.nullable(in_nullable),

AddInt16 | SubInt16 | MulInt16 | DivInt16 | ModInt16 => {
AddInt16 | SubInt16 | MulInt16 | DivInt16 | ModInt16 | BitAndInt16 | BitOrInt16
| BitXorInt16 | BitShiftLeftInt16 | BitShiftRightInt16 => {
ScalarType::Int16.nullable(in_nullable || is_div_mod)
}

Expand All @@ -2906,10 +3027,16 @@ impl BinaryFunc {
| MulInt32
| DivInt32
| ModInt32
| BitAndInt32
| BitOrInt32
| BitXorInt32
| BitShiftLeftInt32
| BitShiftRightInt32
| EncodedBytesCharLength
| SubDate => ScalarType::Int32.nullable(in_nullable || is_div_mod),

AddInt64 | SubInt64 | MulInt64 | DivInt64 | ModInt64 => {
AddInt64 | SubInt64 | MulInt64 | DivInt64 | ModInt64 | BitAndInt64 | BitOrInt64
| BitXorInt64 | BitShiftLeftInt64 | BitShiftRightInt64 => {
ScalarType::Int64.nullable(in_nullable || is_div_mod)
}

Expand Down Expand Up @@ -3059,6 +3186,21 @@ impl BinaryFunc {
| AddDateInterval
| AddTimeInterval
| AddInterval
| BitAndInt16
| BitAndInt32
| BitAndInt64
| BitOrInt16
| BitOrInt32
| BitOrInt64
| BitXorInt16
| BitXorInt32
| BitXorInt64
| BitShiftLeftInt16
| BitShiftLeftInt32
| BitShiftLeftInt64
| BitShiftRightInt16
| BitShiftRightInt32
| BitShiftRightInt64
| SubInterval
| MulInterval
| DivInterval
Expand Down Expand Up @@ -3113,6 +3255,21 @@ impl BinaryFunc {
| AddDateInterval
| AddTimeInterval
| AddInterval
| BitAndInt16
| BitAndInt32
| BitAndInt64
| BitOrInt16
| BitOrInt32
| BitOrInt64
| BitXorInt16
| BitXorInt32
| BitXorInt64
| BitShiftLeftInt16
| BitShiftLeftInt32
| BitShiftLeftInt64
| BitShiftRightInt16
| BitShiftRightInt32
| BitShiftRightInt64
| SubInterval
| MulInterval
| DivInterval
Expand Down Expand Up @@ -3247,6 +3404,21 @@ impl fmt::Display for BinaryFunc {
BinaryFunc::AddDateTime => f.write_str("+"),
BinaryFunc::AddDateInterval => f.write_str("+"),
BinaryFunc::AddTimeInterval => f.write_str("+"),
BinaryFunc::BitAndInt16 => f.write_str("&"),
BinaryFunc::BitAndInt32 => f.write_str("&"),
BinaryFunc::BitAndInt64 => f.write_str("&"),
BinaryFunc::BitOrInt16 => f.write_str("|"),
BinaryFunc::BitOrInt32 => f.write_str("|"),
BinaryFunc::BitOrInt64 => f.write_str("|"),
BinaryFunc::BitXorInt16 => f.write_str("#"),
BinaryFunc::BitXorInt32 => f.write_str("#"),
BinaryFunc::BitXorInt64 => f.write_str("#"),
BinaryFunc::BitShiftLeftInt16 => f.write_str("<<"),
BinaryFunc::BitShiftLeftInt32 => f.write_str("<<"),
BinaryFunc::BitShiftLeftInt64 => f.write_str("<<"),
BinaryFunc::BitShiftRightInt16 => f.write_str(">>"),
BinaryFunc::BitShiftRightInt32 => f.write_str(">>"),
BinaryFunc::BitShiftRightInt64 => f.write_str(">>"),
BinaryFunc::SubInt16 => f.write_str("-"),
BinaryFunc::SubInt32 => f.write_str("-"),
BinaryFunc::SubInt64 => f.write_str("-"),
Expand Down Expand Up @@ -3379,6 +3551,9 @@ trait UnaryFuncTrait {
pub enum UnaryFunc {
Not(Not),
IsNull(IsNull),
BitNotInt16,
BitNotInt32,
BitNotInt64,
NegInt16,
NegInt32,
NegInt64,
Expand Down Expand Up @@ -3658,6 +3833,9 @@ impl UnaryFunc {
| MzRowSize(_)
| IsNull(_)
| CastFloat32ToFloat64(_) => unreachable!(),
BitNotInt16 => Ok(bit_not_int16(a)),
BitNotInt32 => Ok(bit_not_int32(a)),
BitNotInt64 => Ok(bit_not_int64(a)),
NegInt16 => Ok(neg_int16(a)),
NegInt32 => Ok(neg_int32(a)),
NegInt64 => Ok(neg_int64(a)),
Expand Down Expand Up @@ -3971,9 +4149,8 @@ impl UnaryFunc {
ScalarType::VarChar { length: *length }.nullable(nullable)
}

NegInt16 | NegInt32 | NegInt64 | NegInterval | AbsInt16 | AbsInt32 | AbsInt64 => {
input_type
}
BitNotInt16 | BitNotInt32 | BitNotInt64 | NegInt16 | NegInt32 | NegInt64
| NegInterval | AbsInt16 | AbsInt32 | AbsInt64 => input_type,

DatePartInterval(_) | DatePartTimestamp(_) | DatePartTimestampTz(_) => {
ScalarType::Float64.nullable(nullable)
Expand Down Expand Up @@ -4145,6 +4322,7 @@ impl UnaryFunc {
DatePartInterval(_) | DatePartTimestamp(_) | DatePartTimestampTz(_) => false,
DateTruncTimestamp(_) | DateTruncTimestampTz(_) => false,
NegInt16 | NegInt32 | NegInt64 | NegInterval | AbsInt16 | AbsInt32 | AbsInt64 => false,
BitNotInt16 | BitNotInt32 | BitNotInt64 => false,
Log10 | Ln | Exp | Cos | Cosh | Sin | Sinh | Tan | Tanh | Cot | SqrtFloat64
| CbrtFloat64 => false,
AbsNumeric | CeilNumeric | ExpNumeric | FloorNumeric | LnNumeric | Log10Numeric
Expand Down Expand Up @@ -4233,6 +4411,9 @@ impl UnaryFunc {
| MzRowSize(_)
| IsNull(_)
| CastFloat32ToFloat64(_) => unreachable!(),
BitNotInt16 => f.write_str("~"),
BitNotInt32 => f.write_str("~"),
BitNotInt64 => f.write_str("~"),
NegInt16 => f.write_str("-"),
NegInt32 => f.write_str("-"),
NegInt64 => f.write_str("-"),
Expand Down
5 changes: 5 additions & 0 deletions src/sql-parser/src/parser.rs
Expand Up @@ -382,6 +382,11 @@ impl<'a> Parser<'a> {
expr1: Box::new(self.parse_subexpr(Precedence::PrefixPlusMinus)?),
expr2: None,
}),
Token::Op(op) if op == "~" => Ok(Expr::Op {
op,
expr1: Box::new(self.parse_subexpr(Precedence::Other)?),
expr2: None,
}),
Token::Number(_) | Token::String(_) | Token::HexString(_) => {
self.prev_token();
Ok(Expr::Value(self.parse_value()?))
Expand Down
50 changes: 44 additions & 6 deletions src/sql/src/func.rs
Expand Up @@ -2226,11 +2226,21 @@ lazy_static! {
builtins! {
// Literal OIDs collected from PG 13 using a version of this query
// ```sql
// SELECT oid, oprname, oprleft::regtype, oprright::regtype
// FROM pg_operator
// WHERE oprname IN (
// '+', '-', '*', '/', '%', '~~', '!~~', '~'
// );
// SELECT
// oid,
// oprname,
// oprleft::regtype,
// oprright::regtype
// FROM
// pg_operator
// WHERE
// oprname IN (
// '+', '-', '*', '/', '%',
// '|', '&', '#', '~', '<<', '>>',
// '~~', '!~~'
// )
// ORDER BY
// oprname;
// ```
// Values are also available through
// https://github.com/postgres/postgres/blob/master/src/include/catalog/pg_operator.dat
Expand Down Expand Up @@ -2336,6 +2346,31 @@ lazy_static! {
params!(Float64, Float64) => ModFloat64, oid::OP_MOD_F64_OID;
params!(Numeric, Numeric) => ModNumeric, 1762;
},
"&" => Scalar {
params!(Int16, Int16) => BitAndInt16, 1874;
params!(Int32, Int32) => BitAndInt32, 1880;
params!(Int64, Int64) => BitAndInt64, 1886;
},
"|" => Scalar {
params!(Int16, Int16) => BitOrInt16, 1875;
params!(Int32, Int32) => BitOrInt32, 1881;
params!(Int64, Int64) => BitOrInt64, 1887;
},
"#" => Scalar {
params!(Int16, Int16) => BitXorInt16, 1876;
params!(Int32, Int32) => BitXorInt32, 1882;
params!(Int64, Int64) => BitXorInt64, 1888;
},
"<<" => Scalar {
params!(Int16, Int32) => BitShiftLeftInt16, 1878;
params!(Int32, Int32) => BitShiftLeftInt32, 1884;
params!(Int64, Int32) => BitShiftLeftInt64, 1890;
},
">>" => Scalar {
params!(Int16, Int32) => BitShiftRightInt16, 1879;
params!(Int32, Int32) => BitShiftRightInt32, 1885;
params!(Int64, Int32) => BitShiftRightInt64, 1891;
},

// ILIKE
"~~*" => Scalar {
Expand Down Expand Up @@ -2377,6 +2412,9 @@ lazy_static! {

// REGEX
"~" => Scalar {
params!(Int16) => UnaryFunc::BitNotInt16, 1877;
params!(Int32) => UnaryFunc::BitNotInt32, 1883;
params!(Int64) => UnaryFunc::BitNotInt64, 1889;
params!(String, String) => IsRegexpMatch { case_insensitive: false }, 641;
params!(Char, String) => Operation::binary(|ecx, lhs, rhs| {
let length = ecx.scalar_type(&lhs).unwrap_char_varchar_length();
Expand Down Expand Up @@ -2665,6 +2703,6 @@ pub fn resolve_op(op: &str) -> Result<&'static [FuncImpl<HirScalarExpr>], anyhow
// JsonDeletePath
// JsonContainsPath
// JsonApplyPathPredicate
None => bail_unsupported!(op),
None => bail_unsupported!(format!("[{}]", op)),
}
}