From 5d4f9d85b5606ca27b746f5300936a7026ef75f6 Mon Sep 17 00:00:00 2001 From: Kun Liu Date: Wed, 17 Nov 2021 00:43:29 +0800 Subject: [PATCH] Support reading decimal data from the csv reader if the user provides a schema with a decimal data type (#941) * support decimal data type for csv reader * format code and fix lint check * fix the clippy error * enhance the csv-to-decimal parsing and add more tests --- arrow/src/array/builder.rs | 4 +- arrow/src/array/mod.rs | 2 + arrow/src/csv/reader.rs | 263 ++++++++++++++++++++++++++++++- arrow/test/data/decimal_test.csv | 10 ++ 4 files changed, 275 insertions(+), 4 deletions(-) create mode 100644 arrow/test/data/decimal_test.csv diff --git a/arrow/src/array/builder.rs b/arrow/src/array/builder.rs index d08816c6276..af6f3c39a71 100644 --- a/arrow/src/array/builder.rs +++ b/arrow/src/array/builder.rs @@ -1118,7 +1118,7 @@ pub struct FixedSizeBinaryBuilder { builder: FixedSizeListBuilder, } -const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ +pub const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ 9, 99, 999, @@ -1158,7 +1158,7 @@ const MAX_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ 9999999999999999999999999999999999999, 170141183460469231731687303715884105727, ]; -const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ +pub const MIN_DECIMAL_FOR_EACH_PRECISION: [i128; 38] = [ -9, -99, -999, diff --git a/arrow/src/array/mod.rs b/arrow/src/array/mod.rs index 235d868be8e..26b410ee7c4 100644 --- a/arrow/src/array/mod.rs +++ b/arrow/src/array/mod.rs @@ -391,6 +391,8 @@ pub use self::builder::StringBuilder; pub use self::builder::StringDictionaryBuilder; pub use self::builder::StructBuilder; pub use self::builder::UnionBuilder; +pub use self::builder::MAX_DECIMAL_FOR_EACH_PRECISION; +pub use self::builder::MIN_DECIMAL_FOR_EACH_PRECISION; pub type Int8Builder = PrimitiveBuilder; pub type Int16Builder = PrimitiveBuilder; diff --git a/arrow/src/csv/reader.rs b/arrow/src/csv/reader.rs index 4940ea29a1b..ac72939e629 100644 --- 
a/arrow/src/csv/reader.rs +++ b/arrow/src/csv/reader.rs @@ -50,7 +50,8 @@ use std::io::{Read, Seek, SeekFrom}; use std::sync::Arc; use crate::array::{ - ArrayRef, BooleanArray, DictionaryArray, PrimitiveArray, StringArray, + ArrayRef, BooleanArray, DecimalBuilder, DictionaryArray, PrimitiveArray, StringArray, + MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, }; use crate::compute::kernels::cast_utils::string_to_timestamp_nanos; use crate::datatypes::*; @@ -58,8 +59,11 @@ use crate::error::{ArrowError, Result}; use crate::record_batch::RecordBatch; use csv_crate::{ByteRecord, StringRecord}; +use std::ops::Neg; lazy_static! { + static ref PARSE_DECIMAL_RE: Regex = + Regex::new(r"^-?(\d+\.?\d*|\d*\.?\d+)$").unwrap(); static ref DECIMAL_RE: Regex = Regex::new(r"^-?(\d*\.\d+|\d+\.\d*)$").unwrap(); static ref INTEGER_RE: Regex = Regex::new(r"^-?(\d+)$").unwrap(); static ref BOOLEAN_RE: Regex = RegexBuilder::new(r"^(true)$|^(false)$") @@ -99,7 +103,7 @@ fn infer_field_schema(string: &str) -> DataType { /// /// If `max_read_records` is not set, the whole file is read to infer its schema. /// -/// Return infered schema and number of records used for inference. This function does not change +/// Return inferred schema and number of records used for inference. This function does not change /// reader cursor offset. 
pub fn infer_file_schema( reader: &mut R, @@ -513,6 +517,9 @@ fn parse( let field = &fields[i]; match field.data_type() { DataType::Boolean => build_boolean_array(line_number, rows, i), + DataType::Decimal(precision, scale) => { + build_decimal_array(line_number, rows, i, *precision, *scale) + } DataType::Int8 => build_primitive_array::(line_number, rows, i), DataType::Int16 => { build_primitive_array::(line_number, rows, i) @@ -728,6 +735,161 @@ fn parse_bool(string: &str) -> Option { } } +// parse the column string to an Arrow Array +fn build_decimal_array( + _line_number: usize, + rows: &[StringRecord], + col_idx: usize, + precision: usize, + scale: usize, +) -> Result { + let mut decimal_builder = DecimalBuilder::new(rows.len(), precision, scale); + for row in rows { + let col_s = row.get(col_idx); + match col_s { + None => { + // No data for this row + decimal_builder.append_null()?; + } + Some(s) => { + if s.is_empty() { + // append null + decimal_builder.append_null()?; + } else { + let decimal_value: Result = + parse_decimal_with_parameter(s, precision, scale); + match decimal_value { + Ok(v) => { + decimal_builder.append_value(v)?; + } + Err(e) => { + return Err(e); + } + } + } + } + } + } + Ok(Arc::new(decimal_builder.finish())) +} + +// Parse the string format decimal value to i128 format and checking the precision and scale. +// The result i128 value can't be out of bounds. +fn parse_decimal_with_parameter(s: &str, precision: usize, scale: usize) -> Result { + if PARSE_DECIMAL_RE.is_match(s) { + let mut offset = s.len(); + let len = s.len(); + // each byte is digit、'-' or '.' + let mut base = 1; + + // handle the value after the '.' and meet the scale + let delimiter_position = s.find('.'); + match delimiter_position { + None => { + // there is no '.' + base = 10_i128.pow(scale as u32); + } + Some(mid) => { + // there is the '.' 
+ if len - mid >= scale + 1 { + // If the string value is "123.12345" and the scale is 2, we should just remain '.12' and drop the '345' value. + offset -= len - mid - 1 - scale; + } else { + // If the string value is "123.12" and the scale is 4, we should append '00' to the tail. + base = 10_i128.pow((scale + 1 + mid - len) as u32); + } + } + }; + + let bytes = s.as_bytes(); + let mut negative = false; + let mut result: i128 = 0; + + while offset > 0 { + match bytes[offset - 1] { + b'-' => { + negative = true; + } + b'.' => { + // do nothing + } + b'0'..=b'9' => { + result += i128::from(bytes[offset - 1] - b'0') * base; + base *= 10; + } + _ => { + return Err(ArrowError::ParseError(format!( + "can't match byte {}", + bytes[offset - 1] + ))); + } + } + offset -= 1; + } + if negative { + result = result.neg(); + } + if result > MAX_DECIMAL_FOR_EACH_PRECISION[precision - 1] + || result < MIN_DECIMAL_FOR_EACH_PRECISION[precision - 1] + { + return Err(ArrowError::ParseError(format!( + "parse decimal overflow, the precision {}, the scale {}, the value {}", + precision, scale, s + ))); + } + Ok(result) + } else { + Err(ArrowError::ParseError(format!( + "can't parse the string value {} to decimal", + s + ))) + } +} + +// Parse the string format decimal value to i128 format without checking the precision and scale. +// Like "125.12" to 12512_i128. +fn parse_decimal(s: &str) -> Result { + if PARSE_DECIMAL_RE.is_match(s) { + let mut offset = s.len(); + // each byte is digit、'-' or '.' + let bytes = s.as_bytes(); + let mut negative = false; + let mut result: i128 = 0; + let mut base = 1; + while offset > 0 { + match bytes[offset - 1] { + b'-' => { + negative = true; + } + b'.' 
=> { + // do nothing + } + b'0'..=b'9' => { + result += i128::from(bytes[offset - 1] - b'0') * base; + base *= 10; + } + _ => { + return Err(ArrowError::ParseError(format!( + "can't match byte {}", + bytes[offset - 1] + ))); + } + } + offset -= 1; + } + if negative { + Ok(result.neg()) + } else { + Ok(result) + } + } else { + Err(ArrowError::ParseError(format!( + "can't parse the string value {} to decimal", + s + ))) + } +} + // parses a specific column (col_idx) into an Arrow Array. fn build_primitive_array( line_number: usize, @@ -1055,6 +1217,37 @@ mod tests { assert_eq!(&metadata, batch.schema().metadata()); } + #[test] + fn test_csv_reader_with_decimal() { + let schema = Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Decimal(26, 6), false), + Field::new("lng", DataType::Decimal(26, 6), false), + ]); + + let file = File::open("test/data/decimal_test.csv").unwrap(); + + let mut csv = Reader::new(file, Arc::new(schema), false, None, 1024, None, None); + let batch = csv.next().unwrap().unwrap(); + // access data from a primitive array + let lat = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!("57.653484", lat.value_as_string(0)); + assert_eq!("53.002666", lat.value_as_string(1)); + assert_eq!("52.412811", lat.value_as_string(2)); + assert_eq!("51.481583", lat.value_as_string(3)); + assert_eq!("12.123456", lat.value_as_string(4)); + assert_eq!("50.760000", lat.value_as_string(5)); + assert_eq!("0.123000", lat.value_as_string(6)); + assert_eq!("123.000000", lat.value_as_string(7)); + assert_eq!("123.000000", lat.value_as_string(8)); + assert_eq!("-50.760000", lat.value_as_string(9)); + } + #[test] fn test_csv_from_buf_reader() { let schema = Schema::new(vec![ @@ -1348,6 +1541,8 @@ mod tests { assert_eq!(infer_field_schema("false"), DataType::Boolean); assert_eq!(infer_field_schema("2020-11-08"), DataType::Date32); assert_eq!(infer_field_schema("2020-11-08T14:20:01"), DataType::Date64); + 
assert_eq!(infer_field_schema("-5.13"), DataType::Float64); + assert_eq!(infer_field_schema("0.1300"), DataType::Float64); } #[test] @@ -1374,6 +1569,70 @@ mod tests { ); } + #[test] + fn test_parse_decimal() { + let tests = [ + ("123.00", 12300i128), + ("123.123", 123123i128), + ("0.0123", 123i128), + ("0.12300", 12300i128), + ("-5.123", -5123i128), + ("-45.432432", -45432432i128), + ]; + for (s, i) in tests { + let result = parse_decimal(s); + assert_eq!(i, result.unwrap()); + } + } + + #[test] + fn test_parse_decimal_with_parameter() { + let tests = [ + ("123.123", 123123i128), + ("123.1234", 123123i128), + ("123.1", 123100i128), + ("123", 123000i128), + ("-123.123", -123123i128), + ("-123.1234", -123123i128), + ("-123.1", -123100i128), + ("-123", -123000i128), + ("0.0000123", 0i128), + ("12.", 12000i128), + ("-12.", -12000i128), + ("00.1", 100i128), + ("-00.1", -100i128), + ("12345678912345678.1234", 12345678912345678123i128), + ("-12345678912345678.1234", -12345678912345678123i128), + ("99999999999999999.999", 99999999999999999999i128), + ("-99999999999999999.999", -99999999999999999999i128), + (".123", 123i128), + ("-.123", -123i128), + ("123.", 123000i128), + ("-123.", -123000i128), + ]; + for (s, i) in tests { + let result = parse_decimal_with_parameter(s, 20, 3); + assert_eq!(i, result.unwrap()) + } + let can_not_parse_tests = ["123,123", "."]; + for s in can_not_parse_tests { + let result = parse_decimal_with_parameter(s, 20, 3); + assert_eq!( + format!( + "Parser error: can't parse the string value {} to decimal", + s + ), + result.unwrap_err().to_string() + ); + } + let overflow_parse_tests = ["12345678", "12345678.9", "99999999.99"]; + for s in overflow_parse_tests { + let result = parse_decimal_with_parameter(s, 10, 3); + assert_eq!(format!( + "Parser error: parse decimal overflow, the precision {}, the scale {}, the value {}", 10,3, s),result.unwrap_err().to_string()); + } + } + /// Interprets a naive_datetime (with no explicit timezone offset) /// 
using the local timezone and returns the timestamp in UTC (0 /// offset) diff --git a/arrow/test/data/decimal_test.csv b/arrow/test/data/decimal_test.csv new file mode 100644 index 00000000000..460ed808c1a --- /dev/null +++ b/arrow/test/data/decimal_test.csv @@ -0,0 +1,10 @@ +"Elgin, Scotland, the UK",57.653484,-3.335724 +"Stoke-on-Trent, Staffordshire, the UK",53.002666,-2.179404 +"Solihull, Birmingham, UK",52.412811,-1.778197 +"Cardiff, Cardiff county, UK",51.481583,-3.179090 +"Cardiff, Cardiff county, UK",12.12345678,-3.179090 +"Eastbourne, East Sussex, UK",50.76,0.290472 +"Eastbourne, East Sussex, UK",.123,0.290472 +"Eastbourne, East Sussex, UK",123.,0.290472 +"Eastbourne, East Sussex, UK",123,0.290472 +"Eastbourne, East Sussex, UK",-50.76,0.290472 \ No newline at end of file