Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: arrow csv decimal256 #3711

Merged
merged 9 commits into from
Feb 14, 2023
4 changes: 2 additions & 2 deletions arrow-array/src/array/primitive_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ use crate::temporal_conversions::{
};
use crate::timezone::Tz;
use crate::trusted_len::trusted_len_unzip;
use crate::types::*;
use crate::{print_long_array, Array, ArrayAccessor};
use crate::{types::*, ArrowNativeTypeOp};
use arrow_buffer::{i256, ArrowNativeType, Buffer};
use arrow_data::bit_iterator::try_for_each_valid_idx;
use arrow_data::ArrayData;
Expand Down Expand Up @@ -233,7 +233,7 @@ pub type Decimal256Array = PrimitiveArray<Decimal256Type>;
/// static-typed nature of rust types ([`ArrowNativeType`]) for all types that implement [`ArrowNativeType`].
pub trait ArrowPrimitiveType: 'static {
/// Corresponding Rust native type for the primitive type.
type Native: ArrowNativeType;
type Native: ArrowNativeTypeOp;

/// the corresponding Arrow data type of this primitive type.
const DATA_TYPE: DataType;
Expand Down
141 changes: 98 additions & 43 deletions arrow-csv/src/reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@

mod records;

use arrow_array::builder::PrimitiveBuilder;
use arrow_array::types::*;
use arrow_array::ArrowNativeTypeOp;
use arrow_array::*;
use arrow_buffer::ArrowNativeType;
use arrow_cast::parse::Parser;
use arrow_schema::*;
use lazy_static::lazy_static;
use regex::{Regex, RegexSet};
use std::collections::HashSet;
Expand All @@ -50,17 +57,9 @@ use std::fs::File;
use std::io::{BufRead, BufReader as StdBufReader, Read, Seek, SeekFrom};
use std::sync::Arc;

use arrow_array::builder::Decimal128Builder;
use arrow_array::types::*;
use arrow_array::*;
use arrow_cast::parse::Parser;
use arrow_schema::*;

use crate::map_csv_error;
use crate::reader::records::{RecordDecoder, StringRecords};
use arrow_data::decimal::validate_decimal_precision;
use csv::StringRecord;
use std::ops::Neg;

lazy_static! {
static ref REGEX_SET: RegexSet = RegexSet::new([
Expand Down Expand Up @@ -608,7 +607,22 @@ fn parse(
match field.data_type() {
DataType::Boolean => build_boolean_array(line_number, rows, i),
DataType::Decimal128(precision, scale) => {
build_decimal_array(line_number, rows, i, *precision, *scale)
build_decimal_array::<Decimal128Type>(
line_number,
rows,
i,
*precision,
*scale,
)
}
DataType::Decimal256(precision, scale) => {
build_decimal_array::<Decimal256Type>(
line_number,
rows,
i,
*precision,
*scale,
)
}
DataType::Int8 => {
build_primitive_array::<Int8Type>(line_number, rows, i, None)
Expand Down Expand Up @@ -781,22 +795,22 @@ fn parse_bool(string: &str) -> Option<bool> {
}

// parse the column string to an Arrow Array
fn build_decimal_array(
fn build_decimal_array<T: DecimalType>(
suxiaogang223 marked this conversation as resolved.
Show resolved Hide resolved
_line_number: usize,
rows: &StringRecords<'_>,
col_idx: usize,
precision: u8,
scale: i8,
) -> Result<ArrayRef, ArrowError> {
let mut decimal_builder = Decimal128Builder::with_capacity(rows.len());
let mut decimal_builder = PrimitiveBuilder::<T>::with_capacity(rows.len());
for row in rows.iter() {
let s = row.get(col_idx);
if s.is_empty() {
// append null
decimal_builder.append_null();
} else {
let decimal_value: Result<i128, _> =
parse_decimal_with_parameter(s, precision, scale);
let decimal_value: Result<T::Native, _> =
parse_decimal_with_parameter::<T>(s, precision, scale);
match decimal_value {
Ok(v) => {
decimal_builder.append_value(v);
Expand All @@ -814,25 +828,25 @@ fn build_decimal_array(
))
}

// Parse the string format decimal value to i128 format and checking the precision and scale.
// The result i128 value can't be out of bounds.
fn parse_decimal_with_parameter(
// Parse the string format decimal value to i128/i256 format and checking the precision and scale.
// The result value can't be out of bounds.
fn parse_decimal_with_parameter<T: DecimalType>(
s: &str,
precision: u8,
scale: i8,
) -> Result<i128, ArrowError> {
) -> Result<T::Native, ArrowError> {
if PARSE_DECIMAL_RE.is_match(s) {
let mut offset = s.len();
let len = s.len();
let mut base = 1;
let mut base = T::Native::usize_as(1);
let scale_usize = usize::from(scale as u8);

// handle the value after the '.' and meet the scale
let delimiter_position = s.find('.');
match delimiter_position {
None => {
// there is no '.'
base = 10_i128.pow(scale as u32);
base = T::Native::usize_as(10).pow_checked(scale as u32)?;
}
Some(mid) => {
// there is the '.'
Expand All @@ -841,33 +855,38 @@ fn parse_decimal_with_parameter(
offset -= len - mid - 1 - scale_usize;
} else {
// If the string value is "123.12" and the scale is 4, we should append '00' to the tail.
base = 10_i128.pow((scale_usize + 1 + mid - len) as u32);
base = T::Native::usize_as(10)
.pow_checked((scale_usize + 1 + mid - len) as u32)?;
}
}
};

// each byte is digit、'-' or '.'
let bytes = s.as_bytes();
let mut negative = false;
let mut result: i128 = 0;
let mut result = T::Native::usize_as(0);

bytes[0..offset].iter().rev().for_each(|&byte| match byte {
b'-' => {
negative = true;
}
b'0'..=b'9' => {
result += i128::from(byte - b'0') * base;
base *= 10;
for byte in bytes[0..offset].iter().rev() {
match byte {
b'-' => {
negative = true;
}
b'0'..=b'9' => {
let add =
T::Native::usize_as((byte - b'0') as usize).mul_checked(base)?;
result = result.add_checked(add)?;
base = base.mul_checked(T::Native::usize_as(10))?;
}
// because of the PARSE_DECIMAL_RE, bytes just contains digit、'-' and '.'.
_ => {}
}
// because of the PARSE_DECIMAL_RE, bytes just contains digit、'-' and '.'.
_ => {}
});
}

if negative {
result = result.neg();
result = result.neg_checked()?;
}

match validate_decimal_precision(result, precision) {
match T::validate_decimal_precision(result, precision) {
Ok(_) => Ok(result),
Err(e) => Err(ArrowError::ParseError(format!(
"parse decimal overflow: {e}"
Expand All @@ -884,6 +903,8 @@ fn parse_decimal_with_parameter(
// Like "125.12" to 12512_i128.
#[cfg(test)]
fn parse_decimal(s: &str) -> Result<i128, ArrowError> {
use std::ops::Neg;

if PARSE_DECIMAL_RE.is_match(s) {
let mut offset = s.len();
// each byte is digit、'-' or '.'
Expand Down Expand Up @@ -1230,6 +1251,7 @@ impl ReaderBuilder {
mod tests {
use super::*;

use arrow_buffer::i256;
use std::io::{Cursor, Write};
use tempfile::NamedTempFile;

Expand Down Expand Up @@ -1318,7 +1340,7 @@ mod tests {
let schema = Schema::new(vec![
Field::new("city", DataType::Utf8, false),
Field::new("lat", DataType::Decimal128(38, 6), false),
Field::new("lng", DataType::Decimal128(38, 6), false),
Field::new("lng", DataType::Decimal256(76, 6), false),
]);

let file = File::open("test/data/decimal_test.csv").unwrap();
Expand All @@ -1343,6 +1365,23 @@ mod tests {
assert_eq!("123.000000", lat.value_as_string(7));
assert_eq!("123.000000", lat.value_as_string(8));
assert_eq!("-50.760000", lat.value_as_string(9));

let lng = batch
.column(2)
.as_any()
.downcast_ref::<Decimal256Array>()
.unwrap();

assert_eq!("-3.335724", lng.value_as_string(0));
assert_eq!("-2.179404", lng.value_as_string(1));
assert_eq!("-1.778197", lng.value_as_string(2));
assert_eq!("-3.179090", lng.value_as_string(3));
assert_eq!("-3.179090", lng.value_as_string(4));
assert_eq!("0.290472", lng.value_as_string(5));
assert_eq!("0.290472", lng.value_as_string(6));
assert_eq!("0.290472", lng.value_as_string(7));
assert_eq!("0.290472", lng.value_as_string(8));
assert_eq!("0.290472", lng.value_as_string(9));
}

#[test]
Expand Down Expand Up @@ -1788,26 +1827,42 @@ mod tests {
("-123.", -123000i128),
];
for (s, i) in tests {
let result = parse_decimal_with_parameter(s, 20, 3);
assert_eq!(i, result.unwrap())
let result_128 = parse_decimal_with_parameter::<Decimal128Type>(s, 20, 3);
assert_eq!(i, result_128.unwrap());
let result_256 = parse_decimal_with_parameter::<Decimal256Type>(s, 20, 3);
assert_eq!(i256::from_i128(i), result_256.unwrap());
}
let can_not_parse_tests = ["123,123", ".", "123.123.123"];
for s in can_not_parse_tests {
let result = parse_decimal_with_parameter(s, 20, 3);
let result_128 = parse_decimal_with_parameter::<Decimal128Type>(s, 20, 3);
assert_eq!(
format!("Parser error: can't parse the string value {s} to decimal"),
result_128.unwrap_err().to_string()
);
let result_256 = parse_decimal_with_parameter::<Decimal256Type>(s, 20, 3);
assert_eq!(
format!("Parser error: can't parse the string value {s} to decimal"),
result.unwrap_err().to_string()
result_256.unwrap_err().to_string()
);
}
let overflow_parse_tests = ["12345678", "12345678.9", "99999999.99"];
for s in overflow_parse_tests {
let result = parse_decimal_with_parameter(s, 10, 3);
let expected = "Parser error: parse decimal overflow";
let actual = result.unwrap_err().to_string();
let result_128 = parse_decimal_with_parameter::<Decimal128Type>(s, 10, 3);
let expected_128 = "Parser error: parse decimal overflow";
let actual_128 = result_128.unwrap_err().to_string();

assert!(
actual_128.contains(expected_128),
"actual: '{actual_128}', expected: '{expected_128}'"
);

let result_256 = parse_decimal_with_parameter::<Decimal256Type>(s, 10, 3);
let expected_256 = "Parser error: parse decimal overflow";
let actual_256 = result_256.unwrap_err().to_string();

assert!(
actual.contains(expected),
"actual: '{actual}', expected: '{expected}'"
actual_256.contains(expected_256),
"actual: '{actual_256}', expected: '{expected_256}'"
);
}
}
Expand Down
55 changes: 55 additions & 0 deletions arrow-csv/src/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,9 @@ mod tests {
use super::*;

use crate::Reader;
use arrow_array::builder::{Decimal128Builder, Decimal256Builder};
use arrow_array::types::*;
use arrow_buffer::i256;
use std::io::{Cursor, Read, Seek};
use std::sync::Arc;

Expand Down Expand Up @@ -406,6 +408,59 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo
assert_eq!(expected.to_string(), String::from_utf8(buffer).unwrap());
}

#[test]
fn test_write_csv_decimal() {
let schema = Schema::new(vec![
Field::new("c1", DataType::Decimal128(38, 6), true),
Field::new("c2", DataType::Decimal256(76, 6), true),
]);

let mut c1_builder =
Decimal128Builder::new().with_data_type(DataType::Decimal128(38, 6));
c1_builder.extend(vec![Some(-3335724), Some(2179404), None, Some(290472)]);
let c1 = c1_builder.finish();

let mut c2_builder =
Decimal256Builder::new().with_data_type(DataType::Decimal256(76, 6));
c2_builder.extend(vec![
Some(i256::from_i128(-3335724)),
Some(i256::from_i128(2179404)),
None,
Some(i256::from_i128(290472)),
]);
let c2 = c2_builder.finish();

let batch =
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)])
.unwrap();

let mut file = tempfile::tempfile().unwrap();

let mut writer = Writer::new(&mut file);
let batches = vec![&batch, &batch];
for batch in batches {
writer.write(batch).unwrap();
}
drop(writer);

// check that file was written successfully
file.rewind().unwrap();
let mut buffer: Vec<u8> = vec![];
file.read_to_end(&mut buffer).unwrap();

let expected = r#"c1,c2
-3.335724,-3.335724
2.179404,2.179404
,
0.290472,0.290472
-3.335724,-3.335724
2.179404,2.179404
,
0.290472,0.290472
"#;
assert_eq!(expected.to_string(), String::from_utf8(buffer).unwrap());
}

#[test]
fn test_write_csv_custom_options() {
let schema = Schema::new(vec![
Expand Down