Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cherry pick Support read decimal data from csv reader if user provide the schema with decimal data type to active_release #974

Merged
merged 1 commit into from
Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions arrow/src/array/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,7 @@ pub struct FixedSizeBinaryBuilder {
builder: FixedSizeListBuilder<UInt8Builder>,
}

const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
pub const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
9,
99,
999,
Expand Down Expand Up @@ -1158,7 +1158,7 @@ const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
9999999999999999999999999999999999999,
170141183460469231731687303715884105727,
];
const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [
-9,
-99,
-999,
Expand Down
2 changes: 2 additions & 0 deletions arrow/src/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,8 @@ pub use self::builder::StringBuilder;
pub use self::builder::StringDictionaryBuilder;
pub use self::builder::StructBuilder;
pub use self::builder::UnionBuilder;
pub use self::builder::MAX_DECIMAL_FOR_EACH_PRECISION;
pub use self::builder::MIN_DECIMAL_FOR_EACH_PRECISION;

pub type Int8Builder = PrimitiveBuilder<Int8Type>;
pub type Int16Builder = PrimitiveBuilder<Int16Type>;
Expand Down
263 changes: 261 additions & 2 deletions arrow/src/csv/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,20 @@ use std::io::{Read, Seek, SeekFrom};
use std::sync::Arc;

use crate::array::{
ArrayRef, BooleanArray, DictionaryArray, PrimitiveArray, StringArray,
ArrayRef, BooleanArray, DecimalBuilder, DictionaryArray, PrimitiveArray, StringArray,
MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
};
use crate::compute::kernels::cast_utils::string_to_timestamp_nanos;
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
use crate::record_batch::RecordBatch;

use csv_crate::{ByteRecord, StringRecord};
use std::ops::Neg;

lazy_static! {
static ref PARSE_DECIMAL_RE: Regex =
Regex::new(r"^-?(\d+\.?\d*|\d*\.?\d+)$").unwrap();
static ref DECIMAL_RE: Regex = Regex::new(r"^-?(\d*\.\d+|\d+\.\d*)$").unwrap();
static ref INTEGER_RE: Regex = Regex::new(r"^-?(\d+)$").unwrap();
static ref BOOLEAN_RE: Regex = RegexBuilder::new(r"^(true)$|^(false)$")
Expand Down Expand Up @@ -99,7 +103,7 @@ fn infer_field_schema(string: &str) -> DataType {
///
/// If `max_read_records` is not set, the whole file is read to infer its schema.
///
/// Return infered schema and number of records used for inference. This function does not change
/// Return inferred schema and number of records used for inference. This function does not change
/// reader cursor offset.
pub fn infer_file_schema<R: Read + Seek>(
reader: &mut R,
Expand Down Expand Up @@ -513,6 +517,9 @@ fn parse(
let field = &fields[i];
match field.data_type() {
DataType::Boolean => build_boolean_array(line_number, rows, i),
DataType::Decimal(precision, scale) => {
build_decimal_array(line_number, rows, i, *precision, *scale)
}
DataType::Int8 => build_primitive_array::<Int8Type>(line_number, rows, i),
DataType::Int16 => {
build_primitive_array::<Int16Type>(line_number, rows, i)
Expand Down Expand Up @@ -728,6 +735,161 @@ fn parse_bool(string: &str) -> Option<bool> {
}
}

// parse the column string to an Arrow Array
fn build_decimal_array(
_line_number: usize,
rows: &[StringRecord],
col_idx: usize,
precision: usize,
scale: usize,
) -> Result<ArrayRef> {
let mut decimal_builder = DecimalBuilder::new(rows.len(), precision, scale);
for row in rows {
let col_s = row.get(col_idx);
match col_s {
None => {
// No data for this row
decimal_builder.append_null()?;
}
Some(s) => {
if s.is_empty() {
// append null
decimal_builder.append_null()?;
} else {
let decimal_value: Result<i128> =
parse_decimal_with_parameter(s, precision, scale);
match decimal_value {
Ok(v) => {
decimal_builder.append_value(v)?;
}
Err(e) => {
return Err(e);
}
}
}
}
}
}
Ok(Arc::new(decimal_builder.finish()))
}

// Parse the string format decimal value to i128 format and checking the precision and scale.
// The result i128 value can't be out of bounds.
fn parse_decimal_with_parameter(s: &str, precision: usize, scale: usize) -> Result<i128> {
if PARSE_DECIMAL_RE.is_match(s) {
let mut offset = s.len();
let len = s.len();
// each byte is digit、'-' or '.'
let mut base = 1;

// handle the value after the '.' and meet the scale
let delimiter_position = s.find('.');
match delimiter_position {
None => {
// there is no '.'
base = 10_i128.pow(scale as u32);
}
Some(mid) => {
// there is the '.'
if len - mid >= scale + 1 {
// If the string value is "123.12345" and the scale is 2, we should just remain '.12' and drop the '345' value.
offset -= len - mid - 1 - scale;
} else {
// If the string value is "123.12" and the scale is 4, we should append '00' to the tail.
base = 10_i128.pow((scale + 1 + mid - len) as u32);
}
}
};

let bytes = s.as_bytes();
let mut negative = false;
let mut result: i128 = 0;

while offset > 0 {
match bytes[offset - 1] {
b'-' => {
negative = true;
}
b'.' => {
// do nothing
}
b'0'..=b'9' => {
result += i128::from(bytes[offset - 1] - b'0') * base;
base *= 10;
}
_ => {
return Err(ArrowError::ParseError(format!(
"can't match byte {}",
bytes[offset - 1]
)));
}
}
offset -= 1;
}
if negative {
result = result.neg();
}
if result > MAX_DECIMAL_FOR_EACH_PRECISION[precision - 1]
|| result < MIN_DECIMAL_FOR_EACH_PRECISION[precision - 1]
{
return Err(ArrowError::ParseError(format!(
"parse decimal overflow, the precision {}, the scale {}, the value {}",
precision, scale, s
)));
}
Ok(result)
} else {
Err(ArrowError::ParseError(format!(
"can't parse the string value {} to decimal",
s
)))
}
}

// Parse the string format decimal value to i128 format without checking the precision and scale.
// Like "125.12" to 12512_i128.
fn parse_decimal(s: &str) -> Result<i128> {
if PARSE_DECIMAL_RE.is_match(s) {
let mut offset = s.len();
// each byte is digit、'-' or '.'
let bytes = s.as_bytes();
let mut negative = false;
let mut result: i128 = 0;
let mut base = 1;
while offset > 0 {
match bytes[offset - 1] {
b'-' => {
negative = true;
}
b'.' => {
// do nothing
}
b'0'..=b'9' => {
result += i128::from(bytes[offset - 1] - b'0') * base;
base *= 10;
}
_ => {
return Err(ArrowError::ParseError(format!(
"can't match byte {}",
bytes[offset - 1]
)));
}
}
offset -= 1;
}
if negative {
Ok(result.neg())
} else {
Ok(result)
}
} else {
Err(ArrowError::ParseError(format!(
"can't parse the string value {} to decimal",
s
)))
}
}

// parses a specific column (col_idx) into an Arrow Array.
fn build_primitive_array<T: ArrowPrimitiveType + Parser>(
line_number: usize,
Expand Down Expand Up @@ -1055,6 +1217,37 @@ mod tests {
assert_eq!(&metadata, batch.schema().metadata());
}

#[test]
fn test_csv_reader_with_decimal() {
let schema = Schema::new(vec![
Field::new("city", DataType::Utf8, false),
Field::new("lat", DataType::Decimal(26, 6), false),
Field::new("lng", DataType::Decimal(26, 6), false),
]);

let file = File::open("test/data/decimal_test.csv").unwrap();

let mut csv = Reader::new(file, Arc::new(schema), false, None, 1024, None, None);
let batch = csv.next().unwrap().unwrap();
// access data from a primitive array
let lat = batch
.column(1)
.as_any()
.downcast_ref::<DecimalArray>()
.unwrap();

assert_eq!("57.653484", lat.value_as_string(0));
assert_eq!("53.002666", lat.value_as_string(1));
assert_eq!("52.412811", lat.value_as_string(2));
assert_eq!("51.481583", lat.value_as_string(3));
assert_eq!("12.123456", lat.value_as_string(4));
assert_eq!("50.760000", lat.value_as_string(5));
assert_eq!("0.123000", lat.value_as_string(6));
assert_eq!("123.000000", lat.value_as_string(7));
assert_eq!("123.000000", lat.value_as_string(8));
assert_eq!("-50.760000", lat.value_as_string(9));
}

#[test]
fn test_csv_from_buf_reader() {
let schema = Schema::new(vec![
Expand Down Expand Up @@ -1348,6 +1541,8 @@ mod tests {
assert_eq!(infer_field_schema("false"), DataType::Boolean);
assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32);
assert_eq!(infer_field_schema("2020-11-08T14:20:01"), DataType::Date64);
assert_eq!(infer_field_schema("-5.13"), DataType::Float64);
assert_eq!(infer_field_schema("0.1300"), DataType::Float64);
}

#[test]
Expand All @@ -1374,6 +1569,70 @@ mod tests {
);
}

#[test]
fn test_parse_decimal() {
let tests = [
("123.00", 12300i128),
("123.123", 123123i128),
("0.0123", 123i128),
("0.12300", 12300i128),
("-5.123", -5123i128),
("-45.432432", -45432432i128),
];
for (s, i) in tests {
let result = parse_decimal(s);
assert_eq!(i, result.unwrap());
}
}

#[test]
fn test_parse_decimal_with_parameter() {
let tests = [
("123.123", 123123i128),
("123.1234", 123123i128),
("123.1", 123100i128),
("123", 123000i128),
("-123.123", -123123i128),
("-123.1234", -123123i128),
("-123.1", -123100i128),
("-123", -123000i128),
("0.0000123", 0i128),
("12.", 12000i128),
("-12.", -12000i128),
("00.1", 100i128),
("-00.1", -100i128),
("12345678912345678.1234", 12345678912345678123i128),
("-12345678912345678.1234", -12345678912345678123i128),
("99999999999999999.999", 99999999999999999999i128),
("-99999999999999999.999", -99999999999999999999i128),
(".123", 123i128),
("-.123", -123i128),
("123.", 123000i128),
("-123.", -123000i128),
];
for (s, i) in tests {
let result = parse_decimal_with_parameter(s, 20, 3);
assert_eq!(i, result.unwrap())
}
let can_not_parse_tests = ["123,123", "."];
for s in can_not_parse_tests {
let result = parse_decimal_with_parameter(s, 20, 3);
assert_eq!(
format!(
"Parser error: can't parse the string value {} to decimal",
s
),
result.unwrap_err().to_string()
);
}
let overflow_parse_tests = ["12345678", "12345678.9", "99999999.99"];
for s in overflow_parse_tests {
let result = parse_decimal_with_parameter(s, 10, 3);
assert_eq!(format!(
"Parser error: parse decimal overflow, the precision {}, the scale {}, the value {}", 10,3, s),result.unwrap_err().to_string());
}
}

/// Interprets a naive_datetime (with no explicit timezone offset)
/// using the local timezone and returns the timestamp in UTC (0
/// offset)
Expand Down
10 changes: 10 additions & 0 deletions arrow/test/data/decimal_test.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"Elgin, Scotland, the UK",57.653484,-3.335724
"Stoke-on-Trent, Staffordshire, the UK",53.002666,-2.179404
"Solihull, Birmingham, UK",52.412811,-1.778197
"Cardiff, Cardiff county, UK",51.481583,-3.179090
"Cardiff, Cardiff county, UK",12.12345678,-3.179090
"Eastbourne, East Sussex, UK",50.76,0.290472
"Eastbourne, East Sussex, UK",.123,0.290472
"Eastbourne, East Sussex, UK",123.,0.290472
"Eastbourne, East Sussex, UK",123,0.290472
"Eastbourne, East Sussex, UK",-50.76,0.290472