Skip to content

Commit

Permalink
copr: Add non-vectorized repeat (tikv#5885)
Browse files Browse the repository at this point in the history
Signed-off-by: Yuning Zhang <codeworm96@outlook.com>

Co-authored-by: NingLin-P <linningde25@gmail.com>
  • Loading branch information
codeworm96 and NingLin-P committed Mar 9, 2020
1 parent cc42437 commit 10b9209
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 2 deletions.
79 changes: 79 additions & 0 deletions components/tidb_query/src/expr/builtin_string.rs
Expand Up @@ -235,6 +235,32 @@ impl ScalarFunc {
}
}

// see https://dev.mysql.com/doc/refman/5.7/en/string-functions.html#function_repeat
#[inline]
pub fn repeat<'a, 'b: 'a>(
&'b self,
ctx: &mut EvalContext,
row: &'a [Datum],
) -> Result<Option<Cow<'a, [u8]>>> {
let val = try_opt!(self.children[0].eval_string(ctx, row));
let num = try_opt!(self.children[1].eval_int(ctx, row));
let count = if num > std::i32::MAX.into() {
std::i32::MAX.into()
} else {
num
};

match count.cmp(&1) {
Ordering::Less => Ok(Some(Cow::Borrowed(b""))),

// return the input string when count is 1 to save one copy
Ordering::Equal => Ok(Some(val)),

// here count > 1, so convert it into usize should be ok
Ordering::Greater => Ok(Some(Cow::Owned(val.repeat(count as usize)))),
}
}

pub fn replace<'a, 'b: 'a>(
&'b self,
ctx: &mut EvalContext,
Expand Down Expand Up @@ -1552,6 +1578,59 @@ mod tests {
assert_eq!(got, exp, "rtrim(NULL)");
}

#[test]
fn test_repeat() {
let cases = vec![
("hello, world!", -1, ""),
("hello, world!", 0, ""),
("hello, world!", 1, "hello, world!"),
(
"hello, world!",
3,
"hello, world!hello, world!hello, world!",
),
("你好世界", 3, "你好世界你好世界你好世界"),
("こんにちは", 2, "こんにちはこんにちは"),
("\x2f\x35", 5, "\x2f\x35\x2f\x35\x2f\x35\x2f\x35\x2f\x35"),
];

let mut ctx = EvalContext::default();
for (input_str, input_count, expected) in cases {
let s = datum_expr(Datum::Bytes(input_str.as_bytes().to_vec()));
let count = datum_expr(Datum::I64(input_count));
let op = scalar_func_expr(ScalarFuncSig::Repeat, &[s, count]);
let op = Expression::build(&mut ctx, op).unwrap();
let got = op.eval(&mut ctx, &[]).unwrap();
let expected = Datum::Bytes(expected.as_bytes().to_vec());
assert_eq!(got, expected, "repeat({:?}, {:?})", input_str, input_count);
}

// test NULL case
let null = datum_expr(Datum::Null);
let count = datum_expr(Datum::I64(42));
let op = scalar_func_expr(ScalarFuncSig::Repeat, &[null, count]);
let op = Expression::build(&mut ctx, op).unwrap();
let got = op.eval(&mut ctx, &[]).unwrap();
let expected = Datum::Null;
assert_eq!(got, expected, "repeat(NULL, count)");

let null = datum_expr(Datum::Null);
let s = datum_expr(Datum::Bytes(b"hi".to_vec()));
let op = scalar_func_expr(ScalarFuncSig::Repeat, &[s, null]);
let op = Expression::build(&mut ctx, op).unwrap();
let got = op.eval(&mut ctx, &[]).unwrap();
let expected = Datum::Null;
assert_eq!(got, expected, "repeat(s, NULL)");

let null1 = datum_expr(Datum::Null);
let null2 = datum_expr(Datum::Null);
let op = scalar_func_expr(ScalarFuncSig::Repeat, &[null1, null2]);
let op = Expression::build(&mut ctx, op).unwrap();
let got = op.eval(&mut ctx, &[]).unwrap();
let expected = Datum::Null;
assert_eq!(got, expected, "repeat(NULL, NULL)");
}

#[test]
fn test_reverse_utf8() {
let cases = vec![
Expand Down
4 changes: 2 additions & 2 deletions components/tidb_query/src/expr/scalar_function.rs
Expand Up @@ -119,6 +119,7 @@ impl ScalarFunc {
| ScalarFuncSig::Trim2Args
| ScalarFuncSig::Substring2ArgsUtf8
| ScalarFuncSig::Substring2Args
| ScalarFuncSig::Repeat
| ScalarFuncSig::DateDiff
| ScalarFuncSig::AddDatetimeAndDuration
| ScalarFuncSig::AddDatetimeAndString
Expand Down Expand Up @@ -469,7 +470,6 @@ impl ScalarFunc {
| ScalarFuncSig::Password
| ScalarFuncSig::Quarter
| ScalarFuncSig::ReleaseLock
| ScalarFuncSig::Repeat
| ScalarFuncSig::RowCount
| ScalarFuncSig::RowSig
| ScalarFuncSig::SecToTime
Expand Down Expand Up @@ -1005,6 +1005,7 @@ dispatch_call! {
LTrim => ltrim,
RTrim => rtrim,
ReverseUtf8 => reverse_utf8,
Repeat => repeat,
Reverse => reverse,
HexIntArg => hex_int_arg,
HexStrArg => hex_str_arg,
Expand Down Expand Up @@ -1627,7 +1628,6 @@ mod tests {
ScalarFuncSig::Password,
ScalarFuncSig::Quarter,
ScalarFuncSig::ReleaseLock,
ScalarFuncSig::Repeat,
ScalarFuncSig::RowCount,
ScalarFuncSig::RowSig,
ScalarFuncSig::SecToTime,
Expand Down

0 comments on commit 10b9209

Please sign in to comment.