Infer 2020-03-19 00:00:00 as timestamp not Date64 in CSV (#3744) #3746

Merged · 2 commits · Feb 23, 2023
208 changes: 143 additions & 65 deletions arrow-csv/src/reader/mod.rs
@@ -51,7 +51,6 @@ use arrow_cast::parse::Parser;
use arrow_schema::*;
use lazy_static::lazy_static;
use regex::{Regex, RegexSet};
use std::collections::HashSet;
use std::fmt;
use std::fs::File;
use std::io::{BufRead, BufReader as StdBufReader, Read, Seek, SeekFrom};
@@ -62,44 +61,68 @@ use crate::reader::records::{RecordDecoder, StringRecords};
use csv::StringRecord;

lazy_static! {
/// Order should match [`InferredDataType`]
static ref REGEX_SET: RegexSet = RegexSet::new([
r"(?i)^(true)$|^(false)$(?-i)", //BOOLEAN
r"^-?((\d*\.\d+|\d+\.\d*)([eE]-?\d+)?|\d+([eE]-?\d+))$", //DECIMAL
r"^-?(\d+)$", //INTEGER
r"^-?((\d*\.\d+|\d+\.\d*)([eE]-?\d+)?|\d+([eE]-?\d+))$", //DECIMAL
r"^\d{4}-\d\d-\d\d$", //DATE32
r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d$", //DATE64
r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d$", //Timestamp(Second)
r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d.\d{1,3}$", //Timestamp(Millisecond)
r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d.\d{1,6}$", //Timestamp(Microsecond)
r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d.\d{1,9}$", //Timestamp(Nanosecond)
]).unwrap();
//The order should match with REGEX_SET
static ref MATCH_DATA_TYPE: Vec<DataType> = vec![
DataType::Boolean,
DataType::Float64,
DataType::Int64,
DataType::Date32,
DataType::Date64,
];
static ref PARSE_DECIMAL_RE: Regex =
Regex::new(r"^-?(\d+\.?\d*|\d*\.?\d+)$").unwrap();
static ref DATETIME_RE: Regex =
Regex::new(r"^\d{4}-\d\d-\d\d[T ]\d\d:\d\d:\d\d\.\d{1,9}$").unwrap();
}
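
As an aside (illustrative, not part of the diff; it assumes the statics above are in scope): a fractional timestamp matches several of the timestamp patterns at once, and the inference code below records the first, i.e. lowest, matching index.

// "2021-12-19 13:12:30.921" matches the Millisecond, Microsecond and Nanosecond
// patterns (indices 5, 6 and 7). RegexSet matches iterate in ascending order,
// so the first index is 5 and bit 5 (Millisecond) is the one recorded.
let first = REGEX_SET.matches("2021-12-19 13:12:30.921").into_iter().next();
assert_eq!(first, Some(5));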

/// Infer the data type of a record
fn infer_field_schema(string: &str, datetime_re: Option<Regex>) -> DataType {
// when quoting is enabled in the reader, these quotes aren't escaped, we default to
// Utf8 for them
if string.starts_with('"') {
return DataType::Utf8;
}
let matches = REGEX_SET.matches(string).into_iter().next();
// match regex in a particular order
match matches {
Some(ix) => MATCH_DATA_TYPE[ix].clone(),
None => {
let datetime_re = datetime_re.unwrap_or_else(|| DATETIME_RE.clone());
if datetime_re.is_match(string) {
DataType::Timestamp(TimeUnit::Nanosecond, None)
} else {
DataType::Utf8
#[derive(Default, Copy, Clone)]
struct InferredDataType {
Contributor Author: This logic should be significantly faster as an added bonus

Member: 👍

/// Packed booleans indicating type
///
/// 0 - Boolean
/// 1 - Integer
/// 2 - Float64
/// 3 - Date32
/// 4 - Timestamp(Second)
/// 5 - Timestamp(Millisecond)
/// 6 - Timestamp(Microsecond)
/// 7 - Timestamp(Nanosecond)
/// 8 - Utf8
packed: u16,
}

impl InferredDataType {
/// Returns the inferred data type
fn get(&self) -> DataType {
match self.packed {
1 => DataType::Boolean,
2 => DataType::Int64,
4 | 6 => DataType::Float64, // Promote Int64 to Float64
b if b != 0 && (b & !0b11111000) == 0 => match b.leading_zeros() {
// Promote to highest precision temporal type
8 => DataType::Timestamp(TimeUnit::Nanosecond, None),
9 => DataType::Timestamp(TimeUnit::Microsecond, None),
10 => DataType::Timestamp(TimeUnit::Millisecond, None),
11 => DataType::Timestamp(TimeUnit::Second, None),
12 => DataType::Date32,
_ => unreachable!(),
},
_ => DataType::Utf8,
}
}

/// Updates the [`InferredDataType`] with the given string
fn update(&mut self, string: &str, datetime_re: Option<&Regex>) {
self.packed |= if string.starts_with('"') {
1 << 8 // Utf8
} else if let Some(m) = REGEX_SET.matches(string).into_iter().next() {
1 << m
} else {
match datetime_re {
// Timestamp(Nanosecond)
Some(d) if d.is_match(string) => 1 << 7,
_ => 1 << 8, // Utf8
}
}
}
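
As a worked example of the bit packing (illustrative, not part of the diff; it assumes `InferredDataType`, `DataType` and `TimeUnit` are in scope as in this module):

// "2020-03-19" hits the DATE32 pattern (index 3) and "2020-03-19 00:00:00"
// hits the Timestamp(Second) pattern (index 4), so the packed bits become
// (1 << 3) | (1 << 4) = 0b0001_1000.
let mut t = InferredDataType::default();
t.update("2020-03-19", None);
t.update("2020-03-19 00:00:00", None);
assert_eq!(t.packed, 0b0001_1000);
// leading_zeros() of 0b0001_1000u16 is 11, which `get` maps to the finest
// temporal type seen so far: Timestamp(Second) rather than Date32.
assert_eq!(t.packed.leading_zeros(), 11);
assert_eq!(t.get(), DataType::Timestamp(TimeUnit::Second, None));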
@@ -230,10 +253,9 @@ fn infer_reader_schema_with_csv_options<R: Read>(

let header_length = headers.len();
// keep track of inferred field types
let mut column_types: Vec<HashSet<DataType>> = vec![HashSet::new(); header_length];
let mut column_types: Vec<InferredDataType> = vec![Default::default(); header_length];

let mut records_count = 0;
let mut fields = vec![];

let mut record = StringRecord::new();
let max_records = roptions.max_read_records.unwrap_or(usize::MAX);
@@ -248,40 +270,18 @@
for (i, column_type) in column_types.iter_mut().enumerate().take(header_length) {
if let Some(string) = record.get(i) {
if !string.is_empty() {
column_type
.insert(infer_field_schema(string, roptions.datetime_re.clone()));
column_type.update(string, roptions.datetime_re.as_ref())
}
}
}
}

// build schema from inference results
for i in 0..header_length {
let possibilities = &column_types[i];
let field_name = &headers[i];

// determine data type based on possible types
// if there are incompatible types, use DataType::Utf8
match possibilities.len() {
1 => {
for dtype in possibilities.iter() {
fields.push(Field::new(field_name, dtype.clone(), true));
}
}
2 => {
if possibilities.contains(&DataType::Int64)
&& possibilities.contains(&DataType::Float64)
{
// we have an integer and double, fall down to double
fields.push(Field::new(field_name, DataType::Float64, true));
} else {
// default to Utf8 for conflicting datatypes (e.g bool and int)
fields.push(Field::new(field_name, DataType::Utf8, true));
}
}
_ => fields.push(Field::new(field_name, DataType::Utf8, true)),
}
}
let fields = column_types
.iter()
.zip(&headers)
.map(|(inferred, field_name)| Field::new(field_name, inferred.get(), true))
.collect();

Ok((Schema::new(fields), records_count))
}
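
For context, a usage sketch of the user-visible effect (illustrative only; it assumes the crate's existing `infer_reader_schema` helper, whose exact signature may vary between releases):

// A column of second-resolution timestamps is now inferred as
// Timestamp(Second, None) rather than Date64.
let csv = "c1,c2\n1,2020-03-19 00:00:00\n2,2020-03-19 02:00:00\n";
let (schema, _n_records) =
    infer_reader_schema(std::io::Cursor::new(csv), b',', None, true).unwrap();
assert_eq!(
    schema.field(1).data_type(),
    &DataType::Timestamp(TimeUnit::Second, None)
);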
@@ -681,6 +681,19 @@ fn parse(
>(
line_number, rows, i, None
),
DataType::Timestamp(TimeUnit::Second, _) => build_primitive_array::<
TimestampSecondType,
>(
line_number, rows, i, None
),
DataType::Timestamp(TimeUnit::Millisecond, _) => {
build_primitive_array::<TimestampMillisecondType>(
line_number,
rows,
i,
None,
)
}
DataType::Timestamp(TimeUnit::Microsecond, _) => {
build_primitive_array::<TimestampMicrosecondType>(
line_number,
@@ -1637,7 +1650,10 @@ mod tests {
assert_eq!(&DataType::Float64, schema.field(2).data_type());
assert_eq!(&DataType::Boolean, schema.field(3).data_type());
assert_eq!(&DataType::Date32, schema.field(4).data_type());
assert_eq!(&DataType::Date64, schema.field(5).data_type());
assert_eq!(
&DataType::Timestamp(TimeUnit::Second, None),
schema.field(5).data_type()
);

let names: Vec<&str> =
schema.fields().iter().map(|x| x.name().as_str()).collect();
@@ -1698,6 +1714,13 @@ mod tests {
}
}

/// Infer the data type of a record
fn infer_field_schema(string: &str, datetime_re: Option<Regex>) -> DataType {
let mut v = InferredDataType::default();
v.update(string, datetime_re.as_ref());
v.get()
}

#[test]
fn test_infer_field_schema() {
assert_eq!(infer_field_schema("A", None), DataType::Utf8);
@@ -1712,22 +1735,22 @@ mod tests {
assert_eq!(infer_field_schema("2020-11-08", None), DataType::Date32);
assert_eq!(
infer_field_schema("2020-11-08T14:20:01", None),
DataType::Date64
DataType::Timestamp(TimeUnit::Second, None)
);
assert_eq!(
infer_field_schema("2020-11-08 14:20:01", None),
DataType::Date64
DataType::Timestamp(TimeUnit::Second, None)
);
let reg = Regex::new(r"^\d{4}-\d\d-\d\d \d\d:\d\d:\d\d$").ok();
assert_eq!(
infer_field_schema("2020-11-08 14:20:01", reg),
DataType::Date64
DataType::Timestamp(TimeUnit::Second, None)
);
assert_eq!(infer_field_schema("-5.13", None), DataType::Float64);
assert_eq!(infer_field_schema("0.1300", None), DataType::Float64);
assert_eq!(
infer_field_schema("2021-12-19 13:12:30.921", None),
DataType::Timestamp(TimeUnit::Nanosecond, None)
DataType::Timestamp(TimeUnit::Millisecond, None)
);
assert_eq!(
infer_field_schema("2021-12-19T13:12:30.123456789", None),
@@ -2407,4 +2430,59 @@ mod tests {
assert_eq!(&read.fill_sizes, &[23, 3, 0, 0]);
assert_eq!(read.fill_count, 4);
}

#[test]
fn test_inference() {
let cases: &[(&[&str], DataType)] = &[
(&[], DataType::Utf8),
(&["false", "12"], DataType::Utf8),
(&["12", "cupcakes"], DataType::Utf8),
(&["12", "12.4"], DataType::Float64),
(&["14050", "24332"], DataType::Int64),
(&["14050.0", "true"], DataType::Utf8),
(&["14050", "2020-03-19 00:00:00"], DataType::Utf8),
(&["14050", "2340.0", "2020-03-19 00:00:00"], DataType::Utf8),
(
&["2020-03-19 02:00:00", "2020-03-19 00:00:00"],
DataType::Timestamp(TimeUnit::Second, None),
),
(&["2020-03-19", "2020-03-20"], DataType::Date32),
(
&["2020-03-19", "2020-03-19 02:00:00", "2020-03-19 00:00:00"],
DataType::Timestamp(TimeUnit::Second, None),
),
(
&[
"2020-03-19",
"2020-03-19 02:00:00",
"2020-03-19 00:00:00.000",
],
DataType::Timestamp(TimeUnit::Millisecond, None),
),
(
&[
"2020-03-19",
"2020-03-19 02:00:00",
"2020-03-19 00:00:00.000000",
],
DataType::Timestamp(TimeUnit::Microsecond, None),
),
(
&[
"2020-03-19",
"2020-03-19 02:00:00.000000000",
"2020-03-19 00:00:00.000000",
],
DataType::Timestamp(TimeUnit::Nanosecond, None),
),
];

for (values, expected) in cases {
let mut t = InferredDataType::default();
for v in *values {
t.update(v, None)
}
assert_eq!(&t.get(), expected, "{:?}", values)
}
}
}