4 changes: 1 addition & 3 deletions cpp/src/arrow/pretty_print.cc
@@ -45,9 +45,7 @@ class PrettyPrinter {
void OpenArray();
void CloseArray();

void Flush() {
(*sink_) << std::flush;
}
void Flush() { (*sink_) << std::flush; }

protected:
int indent_;
7 changes: 7 additions & 0 deletions python/doc/source/parquet.rst
@@ -217,6 +217,13 @@ such as those produced by Hive:
dataset = pq.ParquetDataset('dataset_name/')
table = dataset.read()

Using with Spark
----------------

Spark places some constraints on the types of Parquet files it will read. Passing
``flavor='spark'`` sets the writer options needed for Spark compatibility
automatically and also sanitizes field names that contain characters unsupported
by Spark SQL.
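
For example, a minimal sketch (assuming ``table`` is an existing
``pyarrow.Table``; the file name is a placeholder):

.. code-block:: python

   import pyarrow.parquet as pq

   pq.write_table(table, 'example_spark.parquet', flavor='spark')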

Multithreaded Reads
-------------------

96 changes: 88 additions & 8 deletions python/pyarrow/parquet.py
@@ -18,17 +18,17 @@
import os
import inspect
import json

import re
import six

import numpy as np

from pyarrow.filesystem import FileSystem, LocalFileSystem, S3FSWrapper
from pyarrow._parquet import (ParquetReader, FileMetaData, # noqa
RowGroupMetaData, ParquetSchema,
ParquetWriter)
RowGroupMetaData, ParquetSchema)
import pyarrow._parquet as _parquet # noqa
import pyarrow.lib as lib
import pyarrow as pa


# ----------------------------------------------------------------------
@@ -164,6 +164,73 @@ def _get_column_indices(self, column_names, use_pandas_metadata=False):
return indices


# Field name characters not accepted by Spark SQL; each occurrence is
# replaced with '_' (for example, 'prohib; ,\t{}' becomes 'prohib______')
_SPARK_DISALLOWED_CHARS = re.compile('[ ,;{}()\n\t=]')


def _sanitized_spark_field_name(name):
return _SPARK_DISALLOWED_CHARS.sub('_', name)


def _sanitize_schema(schema, flavor):
if 'spark' in flavor:
sanitized_fields = []

schema_changed = False

for field in schema:
name = field.name
sanitized_name = _sanitized_spark_field_name(name)

if sanitized_name != name:
schema_changed = True
sanitized_field = pa.field(sanitized_name, field.type,
field.nullable, field.metadata)
sanitized_fields.append(sanitized_field)
else:
sanitized_fields.append(field)
return pa.schema(sanitized_fields), schema_changed
else:
return schema, False


def _sanitize_table(table, new_schema, flavor):
# TODO: This will not handle prohibited characters in nested field names
if 'spark' in flavor:
column_data = [table[i].data for i in range(table.num_columns)]
return pa.Table.from_arrays(column_data, schema=new_schema)
else:
return table


class ParquetWriter(object):
"""

Parameters
----------
where
schema
flavor : {'spark', ...}
Set options for compatibility with a particular reader
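
Examples
--------
A minimal sketch; assumes ``table`` is an existing pyarrow.Table and the
output path is a placeholder.

>>> import pyarrow.parquet as pq
>>> writer = pq.ParquetWriter('/tmp/data.parquet', table.schema,
...                           flavor='spark')
>>> writer.write_table(table)
>>> writer.close()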
"""
def __init__(self, where, schema, flavor=None, **options):
self.flavor = flavor
if flavor is not None:
schema, self.schema_changed = _sanitize_schema(schema, flavor)
else:
self.schema_changed = False

self.schema = schema
self.writer = _parquet.ParquetWriter(where, schema, **options)

def write_table(self, table, row_group_size=None):
if self.schema_changed:
table = _sanitize_table(table, self.schema, self.flavor)
self.writer.write_table(table, row_group_size=row_group_size)

def close(self):
self.writer.close()


def _get_pandas_index_columns(keyvalues):
return (json.loads(keyvalues[b'pandas'].decode('utf8'))
['index_columns'])
@@ -787,8 +854,9 @@ def read_pandas(source, columns=None, nthreads=1, metadata=None):

def write_table(table, where, row_group_size=None, version='1.0',
use_dictionary=True, compression='snappy',
use_deprecated_int96_timestamps=False,
coerce_timestamps=None, **kwargs):
use_deprecated_int96_timestamps=None,
coerce_timestamps=None,
flavor=None, **kwargs):
"""
Write a Table to Parquet format

@@ -804,15 +872,26 @@ def write_table(table, where, row_group_size=None, version='1.0',
use_dictionary : bool or list
Specify if we should use dictionary encoding in general or only for
some columns.
use_deprecated_int96_timestamps : boolean, default False
Write nanosecond resolution timestamps to INT96 Parquet format
use_deprecated_int96_timestamps : boolean, default None
Write nanosecond resolution timestamps to INT96 Parquet
format. Defaults to False unless enabled by the flavor argument
coerce_timestamps : string, default None
Cast timestamps to a particular resolution.
Valid values: {None, 'ms', 'us'}
compression : str or dict
Specify the compression codec, either on a general basis or per-column.
flavor : {'spark'}, default None
Sanitize schema or set other compatibility options for a particular reader
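
Examples
--------
An illustrative sketch; ``table`` is an existing pyarrow.Table and the
path is a placeholder. With ``flavor='spark'``, INT96 timestamps are
enabled unless ``use_deprecated_int96_timestamps`` is passed explicitly.

>>> import pyarrow.parquet as pq
>>> pq.write_table(table, '/tmp/spark_compatible.parquet', flavor='spark')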
"""
row_group_size = kwargs.get('chunk_size', row_group_size)

if use_deprecated_int96_timestamps is None:
# Use int96 timestamps for Spark
if flavor is not None and 'spark' in flavor:
use_deprecated_int96_timestamps = True
else:
use_deprecated_int96_timestamps = False

options = dict(
use_dictionary=use_dictionary,
compression=compression,
@@ -822,7 +901,8 @@ def write_table(table, where, row_group_size=None, version='1.0',

writer = None
try:
writer = ParquetWriter(where, table.schema, **options)
writer = ParquetWriter(where, table.schema, flavor=flavor,
**options)
writer.write_table(table, row_group_size=row_group_size)
except:
if writer is not None:
31 changes: 24 additions & 7 deletions python/pyarrow/table.pxi
@@ -758,7 +758,7 @@ cdef class Table:
return cls.from_arrays(arrays, names=names, metadata=metadata)

@staticmethod
def from_arrays(arrays, names=None, dict metadata=None):
def from_arrays(arrays, names=None, schema=None, dict metadata=None):
"""
Construct a Table from Arrow arrays or columns

@@ -777,35 +777,52 @@
"""
cdef:
vector[shared_ptr[CColumn]] columns
shared_ptr[CSchema] schema
Schema cy_schema
shared_ptr[CSchema] c_schema
shared_ptr[CTable] table
int i, K = <int> len(arrays)

_schema_from_arrays(arrays, names, metadata, &schema)
if schema is None:
_schema_from_arrays(arrays, names, metadata, &c_schema)
elif schema is not None:
if names is not None:
raise ValueError('Cannot pass both schema and names')
cy_schema = schema

if len(schema) != len(arrays):
raise ValueError('Schema and number of arrays unequal')

c_schema = cy_schema.sp_schema

columns.reserve(K)

for i in range(K):
if isinstance(arrays[i], Array):
columns.push_back(
make_shared[CColumn](
schema.get().field(i),
c_schema.get().field(i),
(<Array> arrays[i]).sp_array
)
)
elif isinstance(arrays[i], ChunkedArray):
columns.push_back(
make_shared[CColumn](
schema.get().field(i),
c_schema.get().field(i),
(<ChunkedArray> arrays[i]).sp_chunked_array
)
)
elif isinstance(arrays[i], Column):
columns.push_back((<Column> arrays[i]).sp_column)
# Make sure schema field and column are consistent
columns.push_back(
make_shared[CColumn](
c_schema.get().field(i),
(<Column> arrays[i]).sp_column.get().data()
)
)
else:
raise ValueError(type(arrays[i]))

table.reset(new CTable(schema, columns))
table.reset(new CTable(c_schema, columns))
return pyarrow_wrap_table(table)

@staticmethod
28 changes: 26 additions & 2 deletions python/pyarrow/tests/test_parquet.py
@@ -17,6 +17,7 @@

from os.path import join as pjoin
import datetime
import gc
import io
import os
import json
@@ -562,6 +563,10 @@ def test_date_time_types():
_check_roundtrip(table, expected=expected, version='2.0',
use_deprecated_int96_timestamps=True)

# Check that setting flavor to 'spark' uses int96 timestamps
_check_roundtrip(table, expected=expected, version='2.0',
flavor='spark')

# Unsupported stuff
def _assert_unsupported(array):
table = pa.Table.from_arrays([array], ['unsupported'])
@@ -576,6 +581,18 @@ def _assert_unsupported(array):
_assert_unsupported(a7)


@parquet
def test_sanitized_spark_field_names():
a0 = pa.array([0, 1, 2, 3, 4])
name = 'prohib; ,\t{}'
table = pa.Table.from_arrays([a0], [name])

result = _roundtrip_table(table, flavor='spark')

expected_name = 'prohib______'
assert result.schema[0].name == expected_name


@parquet
def test_fixed_size_binary():
t0 = pa.binary(10)
@@ -587,15 +604,19 @@ def test_fixed_size_binary():
_check_roundtrip(table)


def _check_roundtrip(table, expected=None, **params):
def _roundtrip_table(table, **params):
buf = io.BytesIO()
_write_table(table, buf, **params)
buf.seek(0)

return _read_table(buf)


def _check_roundtrip(table, expected=None, **params):
if expected is None:
expected = table

result = _read_table(buf)
result = _roundtrip_table(table, **params)
assert result.equals(expected)


@@ -1181,6 +1202,9 @@ def test_write_error_deletes_incomplete_file(tmpdir):
except pa.ArrowException:
pass

# Ensure that the writer object has been destructed; otherwise this
# test fails on Windows
gc.collect()
assert not os.path.exists(filename)


5 changes: 4 additions & 1 deletion python/pyarrow/types.pxi
@@ -299,7 +299,6 @@ cdef class Schema:
return self.schema.num_fields()

def __getitem__(self, int i):

cdef:
Field result = Field()
int num_fields = self.schema.num_fields()
@@ -318,6 +317,10 @@

return result

def __iter__(self):
for i in range(len(self)):
yield self[i]

def _check_null(self):
if self.schema == NULL:
raise ReferenceError(