diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs index 29bdeb4e289..e78f2d0ba71 100644 --- a/arrow-csv/src/reader/mod.rs +++ b/arrow-csv/src/reader/mod.rs @@ -51,7 +51,6 @@ use arrow_cast::parse::Parser; use arrow_schema::*; use lazy_static::lazy_static; use regex::{Regex, RegexSet}; -use std::collections::HashSet; use std::fmt; use std::fs::File; use std::io::{BufRead, BufReader as StdBufReader, Read, Seek, SeekFrom}; @@ -62,44 +61,68 @@ use crate::reader::records::{RecordDecoder, StringRecords}; use csv::StringRecord; lazy_static! { + /// Order should match [`InferredDataType`] static ref REGEX_SET: RegexSet = RegexSet::new([ r"(?i)^(true)$|^(false)$(?-i)", //BOOLEAN - r"^-?((\d*\.\d+|\d+\.\d*)([eE]-?\d+)?|\d+([eE]-?\d+))$", //DECIMAL r"^-?(\d+)$", //INTEGER + r"^-?((\d*\.\d+|\d+\.\d*)([eE]-?\d+)?|\d+([eE]-?\d+))$", //DECIMAL r"^\d{4}-\d\d-\d\d$", //DATE32 - r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d$", //DATE64 + r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d$", //Timestamp(Second) + r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d.\d{1,3}$", //Timestamp(Millisecond) + r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d.\d{1,6}$", //Timestamp(Microsecond) + r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d.\d{1,9}$", //Timestamp(Nanosecond) ]).unwrap(); - //The order should match with REGEX_SET - static ref MATCH_DATA_TYPE: Vec = vec![ - DataType::Boolean, - DataType::Float64, - DataType::Int64, - DataType::Date32, - DataType::Date64, - ]; static ref PARSE_DECIMAL_RE: Regex = Regex::new(r"^-?(\d+\.?\d*|\d*\.?\d+)$").unwrap(); - static ref DATETIME_RE: Regex = - Regex::new(r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,9}$").unwrap(); } -/// Infer the data type of a record -fn infer_field_schema(string: &str, datetime_re: Option) -> DataType { - // when quoting is enabled in the reader, these quotes aren't escaped, we default to - // Utf8 for them - if string.starts_with('"') { - return DataType::Utf8; - } - let matches = REGEX_SET.matches(string).into_iter().next(); - // match regex in a particular order - match matches { - Some(ix) => MATCH_DATA_TYPE[ix].clone(), - None => { - let datetime_re = datetime_re.unwrap_or_else(|| DATETIME_RE.clone()); - if datetime_re.is_match(string) { - DataType::Timestamp(TimeUnit::Nanosecond, None) - } else { - DataType::Utf8 +#[derive(Default, Copy, Clone)] +struct InferredDataType { + /// Packed booleans indicating type + /// + /// 0 - Boolean + /// 1 - Integer + /// 2 - Float64 + /// 3 - Date32 + /// 4 - Timestamp(Second) + /// 5 - Timestamp(Millisecond) + /// 6 - Timestamp(Microsecond) + /// 7 - Timestamp(Nanosecond) + /// 8 - Utf8 + packed: u16, +} + +impl InferredDataType { + /// Returns the inferred data type + fn get(&self) -> DataType { + match self.packed { + 1 => DataType::Boolean, + 2 => DataType::Int64, + 4 | 6 => DataType::Float64, // Promote Int64 to Float64 + b if b != 0 && (b & !0b11111000) == 0 => match b.leading_zeros() { + // Promote to highest precision temporal type + 8 => DataType::Timestamp(TimeUnit::Nanosecond, None), + 9 => DataType::Timestamp(TimeUnit::Microsecond, None), + 10 => DataType::Timestamp(TimeUnit::Millisecond, None), + 11 => DataType::Timestamp(TimeUnit::Second, None), + 12 => DataType::Date32, + _ => unreachable!(), + }, + _ => DataType::Utf8, + } + } + + /// Updates the [`InferredDataType`] with the given string + fn update(&mut self, string: &str, datetime_re: Option<&Regex>) { + self.packed |= if string.starts_with('"') { + 1 << 8 // Utf8 + } else if let Some(m) = REGEX_SET.matches(string).into_iter().next() { + 1 << m + } else { + match datetime_re { + // Timestamp(Nanosecond) + Some(d) if d.is_match(string) => 1 << 7, + _ => 1 << 8, // Utf8 } } } @@ -230,10 +253,9 @@ fn infer_reader_schema_with_csv_options( let header_length = headers.len(); // keep track of inferred field types - let mut column_types: Vec> = vec![HashSet::new(); header_length]; + let mut column_types: Vec = vec![Default::default(); header_length]; let mut records_count = 0; - let mut fields = vec![]; let mut record = StringRecord::new(); let max_records = roptions.max_read_records.unwrap_or(usize::MAX); @@ -248,40 +270,18 @@ fn infer_reader_schema_with_csv_options( for (i, column_type) in column_types.iter_mut().enumerate().take(header_length) { if let Some(string) = record.get(i) { if !string.is_empty() { - column_type - .insert(infer_field_schema(string, roptions.datetime_re.clone())); + column_type.update(string, roptions.datetime_re.as_ref()) } } } } // build schema from inference results - for i in 0..header_length { - let possibilities = &column_types[i]; - let field_name = &headers[i]; - - // determine data type based on possible types - // if there are incompatible types, use DataType::Utf8 - match possibilities.len() { - 1 => { - for dtype in possibilities.iter() { - fields.push(Field::new(field_name, dtype.clone(), true)); - } - } - 2 => { - if possibilities.contains(&DataType::Int64) - && possibilities.contains(&DataType::Float64) - { - // we have an integer and double, fall down to double - fields.push(Field::new(field_name, DataType::Float64, true)); - } else { - // default to Utf8 for conflicting datatypes (e.g bool and int) - fields.push(Field::new(field_name, DataType::Utf8, true)); - } - } - _ => fields.push(Field::new(field_name, DataType::Utf8, true)), - } - } + let fields = column_types + .iter() + .zip(&headers) + .map(|(inferred, field_name)| Field::new(field_name, inferred.get(), true)) + .collect(); Ok((Schema::new(fields), records_count)) } @@ -681,6 +681,19 @@ fn parse( >( line_number, rows, i, None ), + DataType::Timestamp(TimeUnit::Second, _) => build_primitive_array::< + TimestampSecondType, + >( + line_number, rows, i, None + ), + DataType::Timestamp(TimeUnit::Millisecond, _) => { + build_primitive_array::( + line_number, + rows, + i, + None, + ) + } DataType::Timestamp(TimeUnit::Microsecond, _) => { build_primitive_array::( line_number, @@ -1637,7 +1650,10 @@ mod tests { assert_eq!(&DataType::Float64, schema.field(2).data_type()); assert_eq!(&DataType::Boolean, schema.field(3).data_type()); assert_eq!(&DataType::Date32, schema.field(4).data_type()); - assert_eq!(&DataType::Date64, schema.field(5).data_type()); + assert_eq!( + &DataType::Timestamp(TimeUnit::Second, None), + schema.field(5).data_type() + ); let names: Vec<&str> = schema.fields().iter().map(|x| x.name().as_str()).collect(); @@ -1698,6 +1714,13 @@ mod tests { } } + /// Infer the data type of a record + fn infer_field_schema(string: &str, datetime_re: Option) -> DataType { + let mut v = InferredDataType::default(); + v.update(string, datetime_re.as_ref()); + v.get() + } + #[test] fn test_infer_field_schema() { assert_eq!(infer_field_schema("A", None), DataType::Utf8); @@ -1712,22 +1735,22 @@ mod tests { assert_eq!(infer_field_schema("2020-11-08", None), DataType::Date32); assert_eq!( infer_field_schema("2020-11-08T14:20:01", None), - DataType::Date64 + DataType::Timestamp(TimeUnit::Second, None) ); assert_eq!( infer_field_schema("2020-11-08 14:20:01", None), - DataType::Date64 + DataType::Timestamp(TimeUnit::Second, None) ); let reg = Regex::new(r"^\d{4}-\d\d-\d\d \d\d:\d\d:\d\d$").ok(); assert_eq!( infer_field_schema("2020-11-08 14:20:01", reg), - DataType::Date64 + DataType::Timestamp(TimeUnit::Second, None) ); assert_eq!(infer_field_schema("-5.13", None), DataType::Float64); assert_eq!(infer_field_schema("0.1300", None), DataType::Float64); assert_eq!( infer_field_schema("2021-12-19 13:12:30.921", None), - DataType::Timestamp(TimeUnit::Nanosecond, None) + DataType::Timestamp(TimeUnit::Millisecond, None) ); assert_eq!( infer_field_schema("2021-12-19T13:12:30.123456789", None), @@ -2407,4 +2430,59 @@ mod tests { assert_eq!(&read.fill_sizes, &[23, 3, 0, 0]); assert_eq!(read.fill_count, 4); } + + #[test] + fn test_inference() { + let cases: &[(&[&str], DataType)] = &[ + (&[], DataType::Utf8), + (&["false", "12"], DataType::Utf8), + (&["12", "cupcakes"], DataType::Utf8), + (&["12", "12.4"], DataType::Float64), + (&["14050", "24332"], DataType::Int64), + (&["14050.0", "true"], DataType::Utf8), + (&["14050", "2020-03-19 00:00:00"], DataType::Utf8), + (&["14050", "2340.0", "2020-03-19 00:00:00"], DataType::Utf8), + ( + &["2020-03-19 02:00:00", "2020-03-19 00:00:00"], + DataType::Timestamp(TimeUnit::Second, None), + ), + (&["2020-03-19", "2020-03-20"], DataType::Date32), + ( + &["2020-03-19", "2020-03-19 02:00:00", "2020-03-19 00:00:00"], + DataType::Timestamp(TimeUnit::Second, None), + ), + ( + &[ + "2020-03-19", + "2020-03-19 02:00:00", + "2020-03-19 00:00:00.000", + ], + DataType::Timestamp(TimeUnit::Millisecond, None), + ), + ( + &[ + "2020-03-19", + "2020-03-19 02:00:00", + "2020-03-19 00:00:00.000000", + ], + DataType::Timestamp(TimeUnit::Microsecond, None), + ), + ( + &[ + "2020-03-19", + "2020-03-19 02:00:00.000000000", + "2020-03-19 00:00:00.000000", + ], + DataType::Timestamp(TimeUnit::Nanosecond, None), + ), + ]; + + for (values, expected) in cases { + let mut t = InferredDataType::default(); + for v in *values { + t.update(v, None) + } + assert_eq!(&t.get(), expected, "{:?}", values) + } + } }