Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/unicode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async fn test_unicode_expressions() -> Result<()> {
test_expression!("lpad(NULL, 0)", "NULL");
test_expression!("lpad(NULL, 5, 'xy')", "NULL");
test_expression!("reverse('abcde')", "edcba");
test_expression!("reverse('loẅks')", "skẅol");
test_expression!("reverse('loẅks')", "sk̈wol"); // Compatible with PostgreSQL
test_expression!("reverse(NULL)", "NULL");
test_expression!("right('abcde', -2)", "cde");
test_expression!("right('abcde', -200)", "");
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1644,7 +1644,7 @@ mod tests {
test_function!(
Reverse,
&[lit("loẅks")],
Ok(Some("skẅol")),
Ok(Some("sk̈wol")),
&str,
Utf8,
StringArray
Expand All @@ -1653,7 +1653,7 @@ mod tests {
test_function!(
Reverse,
&[lit("loẅks")],
Ok(Some("skẅol")),
Ok(Some("sk̈wol")),
&str,
Utf8,
StringArray
Expand Down
136 changes: 38 additions & 98 deletions datafusion/physical-expr/src/unicode_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ use arrow::{
};
use datafusion_common::{DataFusionError, Result};
use hashbrown::HashMap;
use std::any::type_name;
use std::cmp::Ordering;
use std::sync::Arc;
use std::{any::type_name, cmp::max};
use unicode_segmentation::UnicodeSegmentation;

macro_rules! downcast_string_arg {
Expand Down Expand Up @@ -60,6 +60,7 @@ macro_rules! downcast_arg {

/// Returns number of characters in the string.
/// character_length('josé') = 4
/// Characters are counted as Unicode code points (Rust `char` values), not grapheme clusters
pub fn character_length<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef>
where
T::Native: OffsetSizeTrait,
Expand All @@ -75,9 +76,8 @@ where
.iter()
.map(|string| {
string.map(|string: &str| {
T::Native::from_usize(string.graphemes(true).count()).expect(
"should not fail as graphemes.count will always return integer",
)
T::Native::from_usize(string.chars().count())
.expect("should not fail as string.chars will always return integer")
})
})
.collect::<PrimitiveArray<T>>();
Expand All @@ -87,6 +87,7 @@ where

/// Returns first n characters in the string, or when n is negative, returns all but last |n| characters.
/// left('abcde', 2) = 'ab'
/// The implementation treats each Unicode code point (Rust `char`) as one character
pub fn left<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let string_array = downcast_string_arg!(args[0], "string", T);
let n_array = downcast_arg!(args[1], "n", Int64Array);
Expand All @@ -96,19 +97,16 @@ pub fn left<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
.map(|(string, n)| match (string, n) {
(Some(string), Some(n)) => match n.cmp(&0) {
Ordering::Less => {
let graphemes = string.graphemes(true);
let len = graphemes.clone().count() as i64;
match n.abs().cmp(&len) {
Ordering::Less => {
Some(graphemes.take((len + n) as usize).collect::<String>())
}
Ordering::Equal => Some("".to_string()),
Ordering::Greater => Some("".to_string()),
}
let len = string.chars().count() as i64;
Some(if n.abs() < len {
string.chars().take((len + n) as usize).collect::<String>()
} else {
"".to_string()
})
}
Ordering::Equal => Some("".to_string()),
Ordering::Greater => {
Some(string.graphemes(true).take(n as usize).collect::<String>())
Some(string.chars().take(n as usize).collect::<String>())
}
},
_ => None,
Expand Down Expand Up @@ -139,11 +137,8 @@ pub fn lpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
if length < graphemes.len() {
Some(graphemes[..length].concat())
} else {
let mut s = string.to_string();
s.insert_str(
0,
" ".repeat(length - graphemes.len()).as_str(),
);
let mut s: String = " ".repeat(length - graphemes.len());
s.push_str(string);
Some(s)
}
}
Expand Down Expand Up @@ -209,21 +204,21 @@ pub fn lpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Reverses the order of the characters in the string.
/// reverse('abcde') = 'edcba'
/// The implementation treats each Unicode code point (Rust `char`) as one character
pub fn reverse<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let string_array = downcast_string_arg!(args[0], "string", T);

let result = string_array
.iter()
.map(|string| {
string.map(|string: &str| string.graphemes(true).rev().collect::<String>())
})
.map(|string| string.map(|string: &str| string.chars().rev().collect::<String>()))
.collect::<GenericStringArray<T>>();

Ok(Arc::new(result) as ArrayRef)
}

/// Returns last n characters in the string, or when n is negative, returns all but first |n| characters.
/// right('abcde', 2) = 'de'
/// The implementation treats each Unicode code point (Rust `char`) as one character
pub fn right<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let string_array = downcast_string_arg!(args[0], "string", T);
let n_array = downcast_arg!(args[1], "n", Int64Array);
Expand All @@ -233,33 +228,17 @@ pub fn right<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
.zip(n_array.iter())
.map(|(string, n)| match (string, n) {
(Some(string), Some(n)) => match n.cmp(&0) {
Ordering::Less => {
let graphemes = string.graphemes(true).rev();
let len = graphemes.clone().count() as i64;
match n.abs().cmp(&len) {
Ordering::Less => Some(
graphemes
.take((len + n) as usize)
.collect::<Vec<&str>>()
.iter()
.rev()
.copied()
.collect::<String>(),
),
Ordering::Equal => Some("".to_string()),
Ordering::Greater => Some("".to_string()),
}
}
Ordering::Less => Some(
string
.chars()
.skip(n.unsigned_abs() as usize)
.collect::<String>(),
),
Ordering::Equal => Some("".to_string()),
Ordering::Greater => Some(
string
.graphemes(true)
.rev()
.take(n as usize)
.collect::<Vec<&str>>()
.iter()
.rev()
.copied()
.chars()
.skip(max(string.chars().count() as i64 - n, 0) as usize)
.collect::<String>(),
),
},
Expand Down Expand Up @@ -349,6 +328,7 @@ pub fn rpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {

/// Returns starting index of specified substring within string, or zero if it's not present. (Same as position(substring in string), but note the reversed argument order.)
/// strpos('high', 'ig') = 2
/// The implementation treats each Unicode code point (Rust `char`) as one character
pub fn strpos<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef>
where
T::Native: OffsetSizeTrait,
Expand All @@ -374,28 +354,13 @@ where
.zip(substring_array.iter())
.map(|(string, substring)| match (string, substring) {
(Some(string), Some(substring)) => {
// the rfind method returns the byte index of the substring which may or may not be the same as the character index due to UTF8 encoding
// this method first finds the matching byte using rfind
// then maps that to the character index by matching on the grapheme_index of the byte_index
Some(
T::Native::from_usize(string.to_string().rfind(substring).map_or(
0,
|byte_offset| {
string
.grapheme_indices(true)
.collect::<Vec<(usize, &str)>>()
.iter()
.enumerate()
.filter(|(_, (offset, _))| *offset == byte_offset)
.map(|(index, _)| index)
.collect::<Vec<usize>>()
.first()
.expect("should not fail as grapheme_indices and byte offsets are tightly coupled")
.to_owned()
+ 1
},
))
.expect("should not fail due to map_or default value")
// `find` returns the byte offset of the first occurrence of the substring;
// count the chars that precede that offset and add 1 to get the 1-based character position
T::Native::from_usize(
string
.find(substring)
.map(|x| string[..x].chars().count() + 1)
.unwrap_or(0),
)
}
_ => None,
Expand All @@ -408,6 +373,7 @@ where
/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
/// substr('alphabet', 3) = 'phabet'
/// substr('alphabet', 3, 2) = 'ph'
/// The implementation treats each Unicode code point (Rust `char`) as one character
pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
match args.len() {
2 => {
Expand All @@ -422,13 +388,7 @@ pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
if start <= 0 {
Some(string.to_string())
} else {
let graphemes = string.graphemes(true).collect::<Vec<&str>>();
let start_pos = start as usize - 1;
if graphemes.len() < start_pos {
Some("".to_string())
} else {
Some(graphemes[start_pos..].concat())
}
Some(string.chars().skip(start as usize - 1).collect())
}
}
_ => None,
Expand All @@ -455,29 +415,9 @@ pub fn substr<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
count
)))
} else {
let graphemes = string.graphemes(true).collect::<Vec<&str>>();
let (start_pos, end_pos) = if start <= 0 {
let end_pos = start + count - 1;
(
0_usize,
if end_pos < 0 {
// we use 0 as workaround for usize to return empty string
0
} else {
end_pos as usize
},
)
} else {
((start - 1) as usize, (start + count - 1) as usize)
};

if end_pos == 0 || graphemes.len() < start_pos {
Ok(Some("".to_string()))
} else if graphemes.len() < end_pos {
Ok(Some(graphemes[start_pos..].concat()))
} else {
Ok(Some(graphemes[start_pos..end_pos].concat()))
}
let skip = max(0, start - 1);
let count = max(0, count + (if start < 1 {start - 1} else {0}));
Ok(Some(string.chars().skip(skip as usize).take(count as usize).collect::<String>()))
}
}
_ => Ok(None),
Expand Down
17 changes: 17 additions & 0 deletions integration-tests/sqls/character_length.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
-- Licensed to the Apache Software Foundation (ASF) under one
-- or more contributor license agreements. See the NOTICE file
-- distributed with this work for additional information
-- regarding copyright ownership. The ASF licenses this file
-- to you under the Apache License, Version 2.0 (the
-- "License"); you may not use this file except in compliance
-- with the License. You may obtain a copy of the License at

-- http://www.apache.org/licenses/LICENSE-2.0

-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.

select length('ä');
2 changes: 1 addition & 1 deletion integration-tests/test_psql_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def generate_csv_from_psql(fname: str):

class TestPsqlParity:
def test_tests_count(self):
assert len(test_files) == 21, "tests are missed"
assert len(test_files) == 22, "tests are missed"

@pytest.mark.parametrize("fname", test_files)
def test_sql_file(self, fname):
Expand Down