Skip to content

Commit

Permalink
ARROW-8079: [Python] Implement a wrapper for KeyValueMetadata, duck-t…
Browse files Browse the repository at this point in the history
…yping dict where relevant

TODOs:
- [x] update the wrapper in public-api.pxi
- [x] update the binding objects to use the new KVM

Closes #6793 from kszucs/ARROW-8079

Lead-authored-by: Krisztián Szűcs <szucs.krisztian@gmail.com>
Co-authored-by: Antoine Pitrou <antoine@python.org>
Signed-off-by: Antoine Pitrou <antoine@python.org>
  • Loading branch information
kszucs and pitrou committed Apr 2, 2020
1 parent 04c467f commit d297a2f
Show file tree
Hide file tree
Showing 10 changed files with 270 additions and 64 deletions.
37 changes: 37 additions & 0 deletions cpp/src/arrow/util/key_value_metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
#include <utility>
#include <vector>

#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/util/key_value_metadata.h"
#include "arrow/util/logging.h"
#include "arrow/util/sort.h"
Expand Down Expand Up @@ -83,6 +85,41 @@ void KeyValueMetadata::Append(const std::string& key, const std::string& value)
values_.push_back(value);
}

Result<std::string> KeyValueMetadata::Get(const std::string& key) const {
auto index = FindKey(key);
if (index < 0) {
return Status::KeyError(key);
} else {
return value(index);
}
}

Status KeyValueMetadata::Delete(const std::string& key) {
auto index = FindKey(key);
if (index < 0) {
return Status::KeyError(key);
} else {
keys_.erase(keys_.begin() + index);
values_.erase(values_.begin() + index);
return Status::OK();
}
}

Status KeyValueMetadata::Set(const std::string& key, const std::string& value) {
auto index = FindKey(key);
if (index < 0) {
Append(key, value);
} else {
keys_[index] = key;
values_[index] = value;
}
return Status::OK();
}

bool KeyValueMetadata::Contains(const std::string& key) const {
return FindKey(key) >= 0;
}

void KeyValueMetadata::reserve(int64_t n) {
DCHECK_GE(n, 0);
const auto m = static_cast<size_t>(n);
Expand Down
8 changes: 7 additions & 1 deletion cpp/src/arrow/util/key_value_metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include <utility>
#include <vector>

#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/util/macros.h"
#include "arrow/util/visibility.h"

Expand All @@ -38,9 +40,13 @@ class ARROW_EXPORT KeyValueMetadata {
virtual ~KeyValueMetadata() = default;

void ToUnorderedMap(std::unordered_map<std::string, std::string>* out) const;

void Append(const std::string& key, const std::string& value);

Result<std::string> Get(const std::string& key) const;
Status Delete(const std::string& key);
Status Set(const std::string& key, const std::string& value);
bool Contains(const std::string& key) const;

void reserve(int64_t n);
int64_t size() const;

Expand Down
1 change: 1 addition & 0 deletions python/pyarrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def parse_git(root, **kwargs):
PyExtensionType, UnknownExtensionType,
register_extension_type, unregister_extension_type,
DictionaryMemo,
KeyValueMetadata,
Field,
Schema,
schema,
Expand Down
8 changes: 8 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,18 @@ cdef extern from "arrow/util/key_value_metadata.h" namespace "arrow" nogil:
int64_t size() const
c_string key(int64_t i) const
c_string value(int64_t i) const
int FindKey(const c_string& key) const

shared_ptr[CKeyValueMetadata] Copy() const
c_bool Equals(const CKeyValueMetadata& other)
void Append(const c_string& key, const c_string& value)
void ToUnorderedMap(unordered_map[c_string, c_string]*) const
c_string ToString() const

CResult[c_string] Get(const c_string& key) const
CStatus Delete(const c_string& key)
CStatus Set(const c_string& key, const c_string& value)
c_bool Contains(const c_string& key) const


cdef extern from "arrow/api.h" namespace "arrow" nogil:
Expand Down
17 changes: 17 additions & 0 deletions python/pyarrow/lib.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,23 @@ cdef class PyExtensionType(ExtensionType):
pass


cdef class _Metadata:
# required because KeyValueMetadata also extends collections.abc.Mapping
# and the first parent class must be an extension type
pass


cdef class KeyValueMetadata(_Metadata):
cdef:
shared_ptr[const CKeyValueMetadata] wrapped
const CKeyValueMetadata* metadata

cdef void init(self, const shared_ptr[const CKeyValueMetadata]& wrapped)
@staticmethod
cdef wrap(const shared_ptr[const CKeyValueMetadata]& sp)
cdef inline shared_ptr[const CKeyValueMetadata] unwrap(self) nogil


cdef class Field:
cdef:
shared_ptr[CField] sp_field
Expand Down
6 changes: 2 additions & 4 deletions python/pyarrow/pandas_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import operator
import re
import warnings
from copy import deepcopy

import numpy as np

Expand Down Expand Up @@ -235,7 +234,7 @@ def construct_metadata(df, column_names, index_levels, index_descriptors,
index_descriptors = index_column_metadata = column_indexes = []

return {
b'pandas': json.dumps({
'pandas': json.dumps({
'index_columns': index_descriptors,
'column_indexes': column_indexes,
'columns': column_metadata + index_column_metadata,
Expand Down Expand Up @@ -591,8 +590,7 @@ def convert_column(col, field):
pandas_metadata = construct_metadata(df, column_names, index_columns,
index_descriptors, preserve_index,
types)
metadata = deepcopy(schema.metadata) if schema.metadata else dict()
metadata.update(pandas_metadata)
metadata = pa.KeyValueMetadata(schema.metadata or {}, **pandas_metadata)
schema = schema.with_metadata(metadata)

return arrays, schema
Expand Down
27 changes: 10 additions & 17 deletions python/pyarrow/public-api.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -119,28 +119,21 @@ cdef api object pyarrow_wrap_data_type(

cdef object pyarrow_wrap_metadata(
const shared_ptr[const CKeyValueMetadata]& meta):
cdef const CKeyValueMetadata* cmeta = meta.get()

if cmeta == nullptr:
if meta.get() == nullptr:
return None

result = ordered_dict()
for i in range(cmeta.size()):
result[cmeta.key(i)] = cmeta.value(i)

return result
else:
return KeyValueMetadata.wrap(meta)


cdef shared_ptr[CKeyValueMetadata] pyarrow_unwrap_metadata(object meta) \
except *:
cdef vector[c_string] keys, values
cdef api bint pyarrow_is_metadata(object metadata):
return isinstance(metadata, KeyValueMetadata)

if isinstance(meta, dict):
keys = map(tobytes, meta.keys())
values = map(tobytes, meta.values())
return make_shared[CKeyValueMetadata](keys, values)

return shared_ptr[CKeyValueMetadata]()
cdef shared_ptr[const CKeyValueMetadata] pyarrow_unwrap_metadata(object meta):
cdef shared_ptr[const CKeyValueMetadata] c_meta
if pyarrow_is_metadata(meta):
c_meta = (<KeyValueMetadata>meta).unwrap()
return c_meta


cdef api bint pyarrow_is_field(object field):
Expand Down
24 changes: 8 additions & 16 deletions python/pyarrow/table.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -453,13 +453,11 @@ cdef _schema_from_arrays(arrays, names, metadata, shared_ptr[CSchema]* schema):
Py_ssize_t K = len(arrays)
c_string c_name
shared_ptr[CDataType] c_type
shared_ptr[CKeyValueMetadata] c_meta
shared_ptr[const CKeyValueMetadata] c_meta
vector[shared_ptr[CField]] c_fields

if metadata is not None:
if not isinstance(metadata, dict):
raise TypeError('Metadata must be an instance of dict')
c_meta = pyarrow_unwrap_metadata(metadata)
c_meta = KeyValueMetadata(metadata).unwrap()

if K == 0:
schema.reset(new CSchema(c_fields, c_meta))
Expand Down Expand Up @@ -600,14 +598,11 @@ cdef class RecordBatch(_PandasConvertible):
shallow_copy : RecordBatch
"""
cdef:
shared_ptr[CKeyValueMetadata] c_meta
shared_ptr[const CKeyValueMetadata] c_meta
shared_ptr[CRecordBatch] c_batch

if metadata is not None:
if not isinstance(metadata, dict):
raise TypeError('Metadata must be an instance of dict')
c_meta = pyarrow_unwrap_metadata(metadata)

metadata = ensure_metadata(metadata, allow_none=True)
c_meta = pyarrow_unwrap_metadata(metadata)
with nogil:
c_batch = self.batch.ReplaceSchemaMetadata(c_meta)

Expand Down Expand Up @@ -1126,14 +1121,11 @@ cdef class Table(_PandasConvertible):
shallow_copy : Table
"""
cdef:
shared_ptr[CKeyValueMetadata] c_meta
shared_ptr[const CKeyValueMetadata] c_meta
shared_ptr[CTable] c_table

if metadata is not None:
if not isinstance(metadata, dict):
raise TypeError('Metadata must be an instance of dict')
c_meta = pyarrow_unwrap_metadata(metadata)

metadata = ensure_metadata(metadata, allow_none=True)
c_meta = pyarrow_unwrap_metadata(metadata)
with nogil:
c_table = self.table.ReplaceSchemaMetadata(c_meta)

Expand Down
67 changes: 67 additions & 0 deletions python/pyarrow/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

from collections import OrderedDict
from collections.abc import Iterator

import pickle
import pytest
Expand Down Expand Up @@ -565,6 +566,72 @@ def test_type_equality_operators():
assert ty != other


def test_key_value_metadata():
m = pa.KeyValueMetadata({'a': 'A', 'b': 'B'})
assert len(m) == 2
assert m['a'] == b'A'
assert m[b'a'] == b'A'
assert m['b'] == b'B'
assert 'a' in m
assert b'a' in m
assert 'c' not in m

m1 = pa.KeyValueMetadata({'a': 'A', 'b': 'B'})
m2 = pa.KeyValueMetadata(a='A', b='B')
m3 = pa.KeyValueMetadata([('a', 'A'), ('b', 'B')])

assert m1 != 2
assert m1 == m2
assert m2 == m3
assert m1 == {'a': 'A', 'b': 'B'}
assert m1 != {'a': 'A', 'b': 'C'}

with pytest.raises(TypeError):
pa.KeyValueMetadata({'a': 1})
with pytest.raises(TypeError):
pa.KeyValueMetadata({1: 'a'})
with pytest.raises(TypeError):
pa.KeyValueMetadata(a=1)

expected = [(b'a', b'A'), (b'b', b'B')]
result = [(k, v) for k, v in m3.items()]
assert result == expected
assert list(m3.items()) == expected
assert list(m3.keys()) == [b'a', b'b']
assert list(m3.values()) == [b'A', b'B']
assert len(m3) == 2

# test duplicate key support
md = pa.KeyValueMetadata([
('a', 'alpha'),
('b', 'beta'),
('a', 'Alpha'),
('a', 'ALPHA'),
], b='BETA')

expected = [
(b'a', b'alpha'),
(b'b', b'beta'),
(b'a', b'Alpha'),
(b'a', b'ALPHA'),
(b'b', b'BETA')
]
assert len(md) == 5
assert isinstance(md.keys(), Iterator)
assert isinstance(md.values(), Iterator)
assert isinstance(md.items(), Iterator)
assert list(md.items()) == expected
assert list(md.keys()) == [k for k, _ in expected]
assert list(md.values()) == [v for _, v in expected]

# first occurence
assert md['a'] == b'alpha'
assert md['b'] == b'beta'
assert md.get_all('a') == [b'alpha', b'Alpha', b'ALPHA']
assert md.get_all('b') == [b'beta', b'BETA']
assert md.get_all('unkown') == []


def test_field_basic():
t = pa.string()
f = pa.field('foo', t)
Expand Down
Loading

0 comments on commit d297a2f

Please sign in to comment.