diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index aeece612ded..b64534e9835 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -23,8 +23,8 @@ use crate::temporal_conversions::{ }; use crate::timezone::Tz; use crate::trusted_len::trusted_len_unzip; -use crate::types::*; use crate::{print_long_array, Array, ArrayAccessor}; +use crate::{types::*, ArrowNativeTypeOp}; use arrow_buffer::{i256, ArrowNativeType, Buffer}; use arrow_data::bit_iterator::try_for_each_valid_idx; use arrow_data::ArrayData; @@ -233,7 +233,7 @@ pub type Decimal256Array = PrimitiveArray; /// static-typed nature of rust types ([`ArrowNativeType`]) for all types that implement [`ArrowNativeType`]. pub trait ArrowPrimitiveType: 'static { /// Corresponding Rust native type for the primitive type. - type Native: ArrowNativeType; + type Native: ArrowNativeTypeOp; /// the corresponding Arrow data type of this primitive type. const DATA_TYPE: DataType; diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs index 610f05155b5..c5fe20e9d91 100644 --- a/arrow-csv/src/reader/mod.rs +++ b/arrow-csv/src/reader/mod.rs @@ -42,6 +42,13 @@ mod records; +use arrow_array::builder::PrimitiveBuilder; +use arrow_array::types::*; +use arrow_array::ArrowNativeTypeOp; +use arrow_array::*; +use arrow_buffer::ArrowNativeType; +use arrow_cast::parse::Parser; +use arrow_schema::*; use lazy_static::lazy_static; use regex::{Regex, RegexSet}; use std::collections::HashSet; @@ -50,17 +57,9 @@ use std::fs::File; use std::io::{BufRead, BufReader as StdBufReader, Read, Seek, SeekFrom}; use std::sync::Arc; -use arrow_array::builder::Decimal128Builder; -use arrow_array::types::*; -use arrow_array::*; -use arrow_cast::parse::Parser; -use arrow_schema::*; - use crate::map_csv_error; use crate::reader::records::{RecordDecoder, StringRecords}; -use arrow_data::decimal::validate_decimal_precision; use csv::StringRecord; -use std::ops::Neg; lazy_static! { static ref REGEX_SET: RegexSet = RegexSet::new([ @@ -608,7 +607,22 @@ fn parse( match field.data_type() { DataType::Boolean => build_boolean_array(line_number, rows, i), DataType::Decimal128(precision, scale) => { - build_decimal_array(line_number, rows, i, *precision, *scale) + build_decimal_array::( + line_number, + rows, + i, + *precision, + *scale, + ) + } + DataType::Decimal256(precision, scale) => { + build_decimal_array::( + line_number, + rows, + i, + *precision, + *scale, + ) } DataType::Int8 => { build_primitive_array::(line_number, rows, i, None) @@ -781,22 +795,22 @@ fn parse_bool(string: &str) -> Option { } // parse the column string to an Arrow Array -fn build_decimal_array( +fn build_decimal_array( _line_number: usize, rows: &StringRecords<'_>, col_idx: usize, precision: u8, scale: i8, ) -> Result { - let mut decimal_builder = Decimal128Builder::with_capacity(rows.len()); + let mut decimal_builder = PrimitiveBuilder::::with_capacity(rows.len()); for row in rows.iter() { let s = row.get(col_idx); if s.is_empty() { // append null decimal_builder.append_null(); } else { - let decimal_value: Result = - parse_decimal_with_parameter(s, precision, scale); + let decimal_value: Result = + parse_decimal_with_parameter::(s, precision, scale); match decimal_value { Ok(v) => { decimal_builder.append_value(v); @@ -814,17 +828,17 @@ fn build_decimal_array( )) } -// Parse the string format decimal value to i128 format and checking the precision and scale. -// The result i128 value can't be out of bounds. -fn parse_decimal_with_parameter( +// Parse the string format decimal value to i128/i256 format and checking the precision and scale. +// The result value can't be out of bounds. +fn parse_decimal_with_parameter( s: &str, precision: u8, scale: i8, -) -> Result { +) -> Result { if PARSE_DECIMAL_RE.is_match(s) { let mut offset = s.len(); let len = s.len(); - let mut base = 1; + let mut base = T::Native::usize_as(1); let scale_usize = usize::from(scale as u8); // handle the value after the '.' and meet the scale @@ -832,7 +846,7 @@ fn parse_decimal_with_parameter( match delimiter_position { None => { // there is no '.' - base = 10_i128.pow(scale as u32); + base = T::Native::usize_as(10).pow_checked(scale as u32)?; } Some(mid) => { // there is the '.' @@ -841,7 +855,8 @@ fn parse_decimal_with_parameter( offset -= len - mid - 1 - scale_usize; } else { // If the string value is "123.12" and the scale is 4, we should append '00' to the tail. - base = 10_i128.pow((scale_usize + 1 + mid - len) as u32); + base = T::Native::usize_as(10) + .pow_checked((scale_usize + 1 + mid - len) as u32)?; } } }; @@ -849,25 +864,29 @@ fn parse_decimal_with_parameter( // each byte is digit、'-' or '.' let bytes = s.as_bytes(); let mut negative = false; - let mut result: i128 = 0; + let mut result = T::Native::usize_as(0); - bytes[0..offset].iter().rev().for_each(|&byte| match byte { - b'-' => { - negative = true; - } - b'0'..=b'9' => { - result += i128::from(byte - b'0') * base; - base *= 10; + for byte in bytes[0..offset].iter().rev() { + match byte { + b'-' => { + negative = true; + } + b'0'..=b'9' => { + let add = + T::Native::usize_as((byte - b'0') as usize).mul_checked(base)?; + result = result.add_checked(add)?; + base = base.mul_checked(T::Native::usize_as(10))?; + } + // because of the PARSE_DECIMAL_RE, bytes just contains digit、'-' and '.'. + _ => {} } - // because of the PARSE_DECIMAL_RE, bytes just contains digit、'-' and '.'. - _ => {} - }); + } if negative { - result = result.neg(); + result = result.neg_checked()?; } - match validate_decimal_precision(result, precision) { + match T::validate_decimal_precision(result, precision) { Ok(_) => Ok(result), Err(e) => Err(ArrowError::ParseError(format!( "parse decimal overflow: {e}" @@ -884,6 +903,8 @@ fn parse_decimal_with_parameter( // Like "125.12" to 12512_i128. #[cfg(test)] fn parse_decimal(s: &str) -> Result { + use std::ops::Neg; + if PARSE_DECIMAL_RE.is_match(s) { let mut offset = s.len(); // each byte is digit、'-' or '.' @@ -1230,6 +1251,7 @@ impl ReaderBuilder { mod tests { use super::*; + use arrow_buffer::i256; use std::io::{Cursor, Write}; use tempfile::NamedTempFile; @@ -1318,7 +1340,7 @@ mod tests { let schema = Schema::new(vec![ Field::new("city", DataType::Utf8, false), Field::new("lat", DataType::Decimal128(38, 6), false), - Field::new("lng", DataType::Decimal128(38, 6), false), + Field::new("lng", DataType::Decimal256(76, 6), false), ]); let file = File::open("test/data/decimal_test.csv").unwrap(); @@ -1343,6 +1365,23 @@ mod tests { assert_eq!("123.000000", lat.value_as_string(7)); assert_eq!("123.000000", lat.value_as_string(8)); assert_eq!("-50.760000", lat.value_as_string(9)); + + let lng = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!("-3.335724", lng.value_as_string(0)); + assert_eq!("-2.179404", lng.value_as_string(1)); + assert_eq!("-1.778197", lng.value_as_string(2)); + assert_eq!("-3.179090", lng.value_as_string(3)); + assert_eq!("-3.179090", lng.value_as_string(4)); + assert_eq!("0.290472", lng.value_as_string(5)); + assert_eq!("0.290472", lng.value_as_string(6)); + assert_eq!("0.290472", lng.value_as_string(7)); + assert_eq!("0.290472", lng.value_as_string(8)); + assert_eq!("0.290472", lng.value_as_string(9)); } #[test] @@ -1788,26 +1827,42 @@ mod tests { ("-123.", -123000i128), ]; for (s, i) in tests { - let result = parse_decimal_with_parameter(s, 20, 3); - assert_eq!(i, result.unwrap()) + let result_128 = parse_decimal_with_parameter::(s, 20, 3); + assert_eq!(i, result_128.unwrap()); + let result_256 = parse_decimal_with_parameter::(s, 20, 3); + assert_eq!(i256::from_i128(i), result_256.unwrap()); } let can_not_parse_tests = ["123,123", ".", "123.123.123"]; for s in can_not_parse_tests { - let result = parse_decimal_with_parameter(s, 20, 3); + let result_128 = parse_decimal_with_parameter::(s, 20, 3); + assert_eq!( + format!("Parser error: can't parse the string value {s} to decimal"), + result_128.unwrap_err().to_string() + ); + let result_256 = parse_decimal_with_parameter::(s, 20, 3); assert_eq!( format!("Parser error: can't parse the string value {s} to decimal"), - result.unwrap_err().to_string() + result_256.unwrap_err().to_string() ); } let overflow_parse_tests = ["12345678", "12345678.9", "99999999.99"]; for s in overflow_parse_tests { - let result = parse_decimal_with_parameter(s, 10, 3); - let expected = "Parser error: parse decimal overflow"; - let actual = result.unwrap_err().to_string(); + let result_128 = parse_decimal_with_parameter::(s, 10, 3); + let expected_128 = "Parser error: parse decimal overflow"; + let actual_128 = result_128.unwrap_err().to_string(); + + assert!( + actual_128.contains(expected_128), + "actual: '{actual_128}', expected: '{expected_128}'" + ); + + let result_256 = parse_decimal_with_parameter::(s, 10, 3); + let expected_256 = "Parser error: parse decimal overflow"; + let actual_256 = result_256.unwrap_err().to_string(); assert!( - actual.contains(expected), - "actual: '{actual}', expected: '{expected}'" + actual_256.contains(expected_256), + "actual: '{actual_256}', expected: '{expected_256}'" ); } } diff --git a/arrow-csv/src/writer.rs b/arrow-csv/src/writer.rs index e0734a15fd4..d9331053f3d 100644 --- a/arrow-csv/src/writer.rs +++ b/arrow-csv/src/writer.rs @@ -326,7 +326,9 @@ mod tests { use super::*; use crate::Reader; + use arrow_array::builder::{Decimal128Builder, Decimal256Builder}; use arrow_array::types::*; + use arrow_buffer::i256; use std::io::{Cursor, Read, Seek}; use std::sync::Arc; @@ -406,6 +408,59 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo assert_eq!(expected.to_string(), String::from_utf8(buffer).unwrap()); } + #[test] + fn test_write_csv_decimal() { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Decimal128(38, 6), true), + Field::new("c2", DataType::Decimal256(76, 6), true), + ]); + + let mut c1_builder = + Decimal128Builder::new().with_data_type(DataType::Decimal128(38, 6)); + c1_builder.extend(vec![Some(-3335724), Some(2179404), None, Some(290472)]); + let c1 = c1_builder.finish(); + + let mut c2_builder = + Decimal256Builder::new().with_data_type(DataType::Decimal256(76, 6)); + c2_builder.extend(vec![ + Some(i256::from_i128(-3335724)), + Some(i256::from_i128(2179404)), + None, + Some(i256::from_i128(290472)), + ]); + let c2 = c2_builder.finish(); + + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]) + .unwrap(); + + let mut file = tempfile::tempfile().unwrap(); + + let mut writer = Writer::new(&mut file); + let batches = vec![&batch, &batch]; + for batch in batches { + writer.write(batch).unwrap(); + } + drop(writer); + + // check that file was written successfully + file.rewind().unwrap(); + let mut buffer: Vec = vec![]; + file.read_to_end(&mut buffer).unwrap(); + + let expected = r#"c1,c2 +-3.335724,-3.335724 +2.179404,2.179404 +, +0.290472,0.290472 +-3.335724,-3.335724 +2.179404,2.179404 +, +0.290472,0.290472 +"#; + assert_eq!(expected.to_string(), String::from_utf8(buffer).unwrap()); + } + #[test] fn test_write_csv_custom_options() { let schema = Schema::new(vec![