Add CSV options to the CSV parser (#28491)
* remove invalid legacy option

* remove unused option

* the tests pass but this is quite messy

* very slight clean up

* Add skip options to csv format

* fix some of the typing issues

* fixme comment

* remove extra log message

* fix typing issues

* skip before header

* skip after header

* format

* add another test

* Automated Commit - Formatting Changes

* auto generate column names

* delete dead code

* update title and description

* true and false values

* Update the tests

* Add comment

* missing test

* rename

* update expected spec

* move to method

* Update comment

* fix typo

* remove unused import

* Add a comment

* None records do not pass the WaitForDiscoverPolicy

* format

* remove second branch to ensure we always go through the same processing

* Raise an exception if the record is None

* reset

* Update tests

* handle unquoted newlines

* Automated Commit - Formatting Changes

* Update test case so the quoting is explicit

* Update comment

* Automated Commit - Formatting Changes

* Fail validation if skipping rows before header and header is autogenerated

* always fail if a record cannot be parsed

* format

* set write line_no in error message

* remove none check

* Automated Commit - Formatting Changes

* enable autogenerate test

* remove duplicate test

* missing unit tests

* Update

* remove branching

* remove unused none check

* Update tests

* remove branching

* format

* extract to function

* comment

* missing type

* type annotation

* use set

* Document that the strings are case-sensitive

* public -> private

* add unit test

* newline

---------

Co-authored-by: girarda <girarda@users.noreply.github.com>
girarda and girarda committed Aug 3, 2023
1 parent 0a4be6e commit 641a65a
Showing 8 changed files with 1,533 additions and 130 deletions.
@@ -4,9 +4,9 @@

import codecs
from enum import Enum
from typing import Optional
from typing import Any, Mapping, Optional, Set

from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, root_validator, validator
from typing_extensions import Literal


@@ -17,6 +17,10 @@ class QuotingBehavior(Enum):
QUOTE_NONE = "Quote None"


DEFAULT_TRUE_VALUES = ["y", "yes", "t", "true", "on", "1"]
DEFAULT_FALSE_VALUES = ["n", "no", "f", "false", "off", "0"]


class CsvFormat(BaseModel):
filetype: Literal["csv"] = "csv"
delimiter: str = Field(
@@ -46,10 +50,34 @@ class CsvFormat(BaseModel):
default=QuotingBehavior.QUOTE_SPECIAL_CHARACTERS,
description="The quoting behavior determines when a value in a row should have quote marks added around it. For example, if Quote Non-numeric is specified, while reading, quotes are expected for row values that do not contain numbers. Or for Quote All, every row value will be expecting quotes.",
)

# Noting that the existing S3 connector had a config option newlines_in_values. This was only supported by pyarrow and not
# the Python csv package. It has little adoption, but long term we should ideally phase this out because of the drawbacks
# of using pyarrow
null_values: Set[str] = Field(
title="Null Values",
default=[],
description="A set of case-sensitive strings that should be interpreted as null values. For example, if the value 'NA' should be interpreted as null, enter 'NA' in this field.",
)
skip_rows_before_header: int = Field(
title="Skip Rows Before Header",
default=0,
description="The number of rows to skip before the header row. For example, if the header row is on the 3rd row, enter 2 in this field.",
)
skip_rows_after_header: int = Field(
title="Skip Rows After Header", default=0, description="The number of rows to skip after the header row."
)
autogenerate_column_names: bool = Field(
title="Autogenerate Column Names",
default=False,
description="Whether to autogenerate column names if column_names is empty. If true, column names will be of the form “f0”, “f1”… If false, column names will be read from the first CSV row after skip_rows_before_header.",
)
true_values: Set[str] = Field(
title="True Values",
default=DEFAULT_TRUE_VALUES,
description="A set of case-sensitive strings that should be interpreted as true values.",
)
false_values: Set[str] = Field(
title="False Values",
default=DEFAULT_FALSE_VALUES,
description="A set of case-sensitive strings that should be interpreted as false values.",
)

@validator("delimiter")
def validate_delimiter(cls, v: str) -> str:
@@ -78,3 +106,11 @@ def validate_encoding(cls, v: str) -> str:
except LookupError:
raise ValueError(f"invalid encoding format: {v}")
return v

@root_validator
def validate_option_combinations(cls, values: Mapping[str, Any]) -> Mapping[str, Any]:
skip_rows_before_header = values.get("skip_rows_before_header", 0)
auto_generate_column_names = values.get("autogenerate_column_names", False)
if skip_rows_before_header > 0 and auto_generate_column_names:
raise ValueError("Cannot skip rows before header and autogenerate column names at the same time.")
return values
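
As a quick illustration of the new options and the root validator, here is a minimal sketch (not part of the commit; it assumes pydantic v1 behavior, where the ValueError raised inside the validator surfaces as a ValidationError, itself a ValueError subclass):

```python
from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat

# Accepted: skipping rows is fine when the header is read from the file.
fmt = CsvFormat(skip_rows_before_header=2, null_values={"NA", "N/A"})
assert fmt.skip_rows_before_header == 2
assert "true" in fmt.true_values and "0" in fmt.false_values  # defaults defined above

# Rejected by the new root validator: skipping rows before an autogenerated header.
try:
    CsvFormat(skip_rows_before_header=1, autogenerate_column_names=True)
except ValueError as exc:  # pydantic wraps the validator's ValueError in a ValidationError
    print(f"rejected: {exc}")
```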
@@ -5,12 +5,13 @@
import csv
import json
import logging
from distutils.util import strtobool
from typing import Any, Dict, Iterable, Mapping, Optional
from functools import partial
from io import IOBase
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set

from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat, QuotingBehavior
from airbyte_cdk.sources.file_based.config.file_based_stream_config import FileBasedStreamConfig
from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError
from airbyte_cdk.sources.file_based.exceptions import FileBasedSourceError, RecordParseError
from airbyte_cdk.sources.file_based.file_based_stream_reader import AbstractFileBasedStreamReader, FileReadMode
from airbyte_cdk.sources.file_based.file_types.file_type_parser import FileTypeParser
from airbyte_cdk.sources.file_based.remote_file import RemoteFile
@@ -34,30 +35,25 @@ async def infer_schema(
stream_reader: AbstractFileBasedStreamReader,
logger: logging.Logger,
) -> Dict[str, Any]:
config_format = config.format.get(config.file_type) if config.format else None
if config_format:
if not isinstance(config_format, CsvFormat):
raise ValueError(f"Invalid format config: {config_format}")
dialect_name = config.name + DIALECT_NAME
csv.register_dialect(
dialect_name,
delimiter=config_format.delimiter,
quotechar=config_format.quote_char,
escapechar=config_format.escape_char,
doublequote=config_format.double_quote,
quoting=config_to_quoting.get(config_format.quoting_behavior, csv.QUOTE_MINIMAL),
)
with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
# todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
# sources will likely require one. Rather than modify the interface now we can wait until the real use case
reader = csv.DictReader(fp, dialect=dialect_name) # type: ignore
schema = {field.strip(): {"type": "string"} for field in next(reader)}
csv.unregister_dialect(dialect_name)
return schema
else:
with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
reader = csv.DictReader(fp) # type: ignore
return {field.strip(): {"type": "string"} for field in next(reader)}
config_format = config.format.get(config.file_type) if config.format else CsvFormat()
if not isinstance(config_format, CsvFormat):
raise ValueError(f"Invalid format config: {config_format}")
dialect_name = config.name + DIALECT_NAME
csv.register_dialect(
dialect_name,
delimiter=config_format.delimiter,
quotechar=config_format.quote_char,
escapechar=config_format.escape_char,
doublequote=config_format.double_quote,
quoting=config_to_quoting.get(config_format.quoting_behavior, csv.QUOTE_MINIMAL),
)
with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
# todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
# sources will likely require one. Rather than modify the interface now we can wait until the real use case
headers = self._get_headers(fp, config_format, dialect_name)
schema = {field.strip(): {"type": "string"} for field in headers}
csv.unregister_dialect(dialect_name)
return schema

def parse_records(
self,
@@ -67,38 +63,36 @@ def parse_records(
logger: logging.Logger,
) -> Iterable[Dict[str, Any]]:
schema: Mapping[str, Any] = config.input_schema # type: ignore
config_format = config.format.get(config.file_type) if config.format else None
if config_format:
if not isinstance(config_format, CsvFormat):
raise ValueError(f"Invalid format config: {config_format}")
# Formats are configured individually per-stream so a unique dialect should be registered for each stream.
# Wwe don't unregister the dialect because we are lazily parsing each csv file to generate records
dialect_name = config.name + DIALECT_NAME
csv.register_dialect(
dialect_name,
delimiter=config_format.delimiter,
quotechar=config_format.quote_char,
escapechar=config_format.escape_char,
doublequote=config_format.double_quote,
quoting=config_to_quoting.get(config_format.quoting_behavior, csv.QUOTE_MINIMAL),
)
with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
# todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
# sources will likely require one. Rather than modify the interface now we can wait until the real use case
reader = csv.DictReader(fp, dialect=dialect_name) # type: ignore
yield from self._read_and_cast_types(reader, schema, logger)
else:
with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
reader = csv.DictReader(fp) # type: ignore
yield from self._read_and_cast_types(reader, schema, logger)
config_format = config.format.get(config.file_type) if config.format else CsvFormat()
if not isinstance(config_format, CsvFormat):
raise ValueError(f"Invalid format config: {config_format}")
# Formats are configured individually per-stream so a unique dialect should be registered for each stream.
# We don't unregister the dialect because we are lazily parsing each csv file to generate records
# This will potentially be a problem if we ever process multiple streams concurrently
dialect_name = config.name + DIALECT_NAME
csv.register_dialect(
dialect_name,
delimiter=config_format.delimiter,
quotechar=config_format.quote_char,
escapechar=config_format.escape_char,
doublequote=config_format.double_quote,
quoting=config_to_quoting.get(config_format.quoting_behavior, csv.QUOTE_MINIMAL),
)
with stream_reader.open_file(file, self.file_read_mode, logger) as fp:
# todo: the existing InMemoryFilesSource.open_file() test source doesn't currently require an encoding, but actual
# sources will likely require one. Rather than modify the interface now we can wait until the real use case
self._skip_rows_before_header(fp, config_format.skip_rows_before_header)
field_names = self._auto_generate_headers(fp, config_format) if config_format.autogenerate_column_names else None
reader = csv.DictReader(fp, dialect=dialect_name, fieldnames=field_names) # type: ignore
yield from self._read_and_cast_types(reader, schema, config_format, logger)

@property
def file_read_mode(self) -> FileReadMode:
return FileReadMode.READ

@staticmethod
def _read_and_cast_types(
reader: csv.DictReader, schema: Optional[Mapping[str, Any]], logger: logging.Logger # type: ignore
reader: csv.DictReader, schema: Optional[Mapping[str, Any]], config_format: CsvFormat, logger: logging.Logger # type: ignore
) -> Iterable[Dict[str, Any]]:
"""
If the user provided a schema, attempt to cast the record values to the associated type.
@@ -107,16 +101,65 @@ def _read_and_cast_types(
cast it to a string. Downstream, the user's validation policy will determine whether the
record should be emitted.
"""
if not schema:
yield from reader
cast_fn = CsvParser._get_cast_function(schema, config_format, logger)
for i, row in enumerate(reader):
if i < config_format.skip_rows_after_header:
continue
# The row was not properly parsed if any of the values are None
if any(val is None for val in row.values()):
raise RecordParseError(FileBasedSourceError.ERROR_PARSING_RECORD)
else:
yield CsvParser._to_nullable(cast_fn(row), config_format.null_values)

else:
@staticmethod
def _get_cast_function(
schema: Optional[Mapping[str, Any]], config_format: CsvFormat, logger: logging.Logger
) -> Callable[[Mapping[str, str]], Mapping[str, str]]:
# Only cast values if the schema is provided
if schema:
property_types = {col: prop["type"] for col, prop in schema["properties"].items()}
for row in reader:
yield cast_types(row, property_types, logger)
return partial(_cast_types, property_types=property_types, config_format=config_format, logger=logger)
else:
# If no schema is provided, yield the rows as they are
return _no_cast

@staticmethod
def _to_nullable(row: Mapping[str, str], null_values: Set[str]) -> Dict[str, Optional[str]]:
nullable = row | {k: None if v in null_values else v for k, v in row.items()}
return nullable

@staticmethod
def _skip_rows_before_header(fp: IOBase, rows_to_skip: int) -> None:
"""
Skip rows before the header. This has to be done on the file object itself, not the reader
"""
for _ in range(rows_to_skip):
fp.readline()

def _get_headers(self, fp: IOBase, config_format: CsvFormat, dialect_name: str) -> List[str]:
# Note that this method assumes the dialect has already been registered if we're parsing the headers
if config_format.autogenerate_column_names:
return self._auto_generate_headers(fp, config_format)
else:
# If we're not autogenerating column names, we need to skip the rows before the header
self._skip_rows_before_header(fp, config_format.skip_rows_before_header)
# Then read the header
reader = csv.DictReader(fp, dialect=dialect_name) # type: ignore
return next(reader) # type: ignore

def _auto_generate_headers(self, fp: IOBase, config_format: CsvFormat) -> List[str]:
"""
Generates field names as [f0, f1, ...] in the same way as pyarrow's csv reader with autogenerate_column_names=True.
See https://arrow.apache.org/docs/python/generated/pyarrow.csv.ReadOptions.html
"""
next_line = next(fp).strip()
number_of_columns = len(next_line.split(config_format.delimiter)) # type: ignore
# Reset the file pointer to the beginning of the file so that the first row is not skipped
fp.seek(0)
return [f"f{i}" for i in range(number_of_columns)]

def cast_types(row: Dict[str, str], property_types: Dict[str, Any], logger: logging.Logger) -> Dict[str, Any]:

def _cast_types(row: Dict[str, str], property_types: Dict[str, Any], config_format: CsvFormat, logger: logging.Logger) -> Dict[str, Any]:
"""
Casts the values in the input 'row' dictionary according to the types defined in the JSON schema.
@@ -142,7 +185,7 @@ def cast_types(row: Dict[str, str], property_types: Dict[str, Any], logger: logg

elif python_type == bool:
try:
cast_value = strtobool(value)
cast_value = _value_to_bool(value, config_format.true_values, config_format.false_values)
except ValueError:
warnings.append(_format_warning(key, value, prop_type))

@@ -178,5 +221,17 @@ def cast_types(row: Dict[str, str], property_types: Dict[str, Any], logger: logg
return result


def _value_to_bool(value: str, true_values: Set[str], false_values: Set[str]) -> bool:
if value in true_values:
return True
if value in false_values:
return False
raise ValueError(f"Value {value} is not a valid boolean value")


def _format_warning(key: str, value: str, expected_type: Optional[Any]) -> str:
return f"{key}: value={value},expected_type={expected_type}"


def _no_cast(row: Mapping[str, str]) -> Mapping[str, str]:
return row
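
The header handling is easier to see outside the diff. Below is a minimal standalone sketch of the autogenerate_column_names path using only the standard library (illustrative; the real parser routes this through _get_headers/_auto_generate_headers and a registered dialect):

```python
import csv
from io import StringIO

# Peek at the first line to count columns, rewind, then hand pyarrow-style
# names ("f0", "f1", ...) to DictReader so the first row is read as data.
fp = StringIO("1,2,3\na,b,c\n")
first_line = next(fp).strip()
fieldnames = [f"f{i}" for i in range(len(first_line.split(",")))]
fp.seek(0)  # reset so the peeked row is not skipped

reader = csv.DictReader(fp, fieldnames=fieldnames)
print(list(reader))
# [{'f0': '1', 'f1': '2', 'f2': '3'}, {'f0': 'a', 'f1': 'b', 'f2': 'c'}]
```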
@@ -15,6 +15,7 @@
FileBasedSourceError,
InvalidSchemaError,
MissingSchemaError,
RecordParseError,
SchemaInferenceError,
StopSyncPerValidationPolicy,
)
@@ -105,6 +106,18 @@ def read_records_from_slice(self, stream_slice: StreamSlice) -> Iterable[Mapping
)
break

except RecordParseError:
# Increment line_no because the exception was raised before we could increment it
line_no += 1
yield AirbyteMessage(
type=MessageType.LOG,
log=AirbyteLogMessage(
level=Level.ERROR,
message=f"{FileBasedSourceError.ERROR_PARSING_RECORD.value} stream={self.name} file={file.uri} line_no={line_no} n_skipped={n_skipped}",
stack_trace=traceback.format_exc(),
),
)

except Exception:
yield AirbyteMessage(
type=MessageType.LOG,
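
For context on what reaches the RecordParseError branch above: when a row has fewer fields than the header, csv.DictReader fills the missing values with None, which _read_and_cast_types now treats as a parse failure. A small sketch of that condition (standard library only, illustrative):

```python
import csv
from io import StringIO

reader = csv.DictReader(StringIO("col1,col2,col3\n1,2\n"))
row = next(reader)
print(row)  # {'col1': '1', 'col2': '2', 'col3': None}

# The condition checked before raising RecordParseError:
print(any(value is None for value in row.values()))  # True
```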
@@ -0,0 +1,23 @@
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import pytest as pytest
from airbyte_cdk.sources.file_based.config.csv_format import CsvFormat


@pytest.mark.parametrize(
"skip_rows_before_header, autogenerate_column_names, expected_error",
[
pytest.param(1, True, ValueError, id="test_skip_rows_before_header_and_autogenerate_column_names"),
pytest.param(1, False, None, id="test_skip_rows_before_header_and_no_autogenerate_column_names"),
pytest.param(0, True, None, id="test_no_skip_rows_before_header_and_autogenerate_column_names"),
pytest.param(0, False, None, id="test_no_skip_rows_before_header_and_no_autogenerate_column_names"),
]
)
def test_csv_format(skip_rows_before_header, autogenerate_column_names, expected_error):
if expected_error:
with pytest.raises(expected_error):
CsvFormat(skip_rows_before_header=skip_rows_before_header, autogenerate_column_names=autogenerate_column_names)
else:
CsvFormat(skip_rows_before_header=skip_rows_before_header, autogenerate_column_names=autogenerate_column_names)
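
A possible companion test, sketched in the same pytest style (hypothetical, not part of this commit), exercising the case-sensitive default true values; it assumes the imports already present in this test module:

```python
@pytest.mark.parametrize(
    "value, expected_true",
    [
        pytest.param("yes", True, id="test_default_true_value"),
        pytest.param("off", False, id="test_default_false_value"),
        pytest.param("YES", False, id="test_true_values_are_case_sensitive"),
    ],
)
def test_default_true_values(value, expected_true):
    assert (value in CsvFormat().true_values) is expected_true
```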
