Skip to content

Commit

Permalink
ARROW-8209: [Python] Improve error message when trying to access dupl…
Browse files Browse the repository at this point in the history
…icate Table column

Also adds small binding for `Schema::GetAllFieldIndices`

Closes #6831 from wesm/ARROW-8209

Authored-by: Wes McKinney <wesm+git@apache.org>
Signed-off-by: Wes McKinney <wesm+git@apache.org>
  • Loading branch information
wesm committed Apr 4, 2020
1 parent a2a2475 commit ccb9b84
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 16 deletions.
6 changes: 6 additions & 0 deletions cpp/src/arrow/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,9 @@ std::vector<int> StructType::GetAllFieldIndices(const std::string& name) const {
for (auto it = p.first; it != p.second; ++it) {
result.push_back(it->second);
}
if (result.size() > 1) {
std::sort(result.begin(), result.end());
}
return result;
}

Expand Down Expand Up @@ -1104,6 +1107,9 @@ std::vector<int> Schema::GetAllFieldIndices(const std::string& name) const {
for (auto it = p.first; it != p.second; ++it) {
result.push_back(it->second);
}
if (result.size() > 1) {
std::sort(result.begin(), result.end());
}
return result;
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,7 @@ class ARROW_EXPORT StructType : public NestedType {
/// same name
int GetFieldIndex(const std::string& name) const;

/// Return the indices of all fields having this name
/// \brief Return the indices of all fields having this name in sorted order
std::vector<int> GetAllFieldIndices(const std::string& name) const;

private:
Expand Down Expand Up @@ -1656,7 +1656,7 @@ class ARROW_EXPORT Schema : public detail::Fingerprintable,
/// Returns null if name not found
std::shared_ptr<Field> GetFieldByName(const std::string& name) const;

/// Return all fields having this name
/// \brief Return the indices of all fields having this name in sorted order
std::vector<std::shared_ptr<Field>> GetAllFieldsByName(const std::string& name) const;

/// Returns -1 if name not found
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/type_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ TEST_F(TestSchema, GetFieldDuplicates) {
ASSERT_EQ(2, schema->GetFieldIndex(f2->name()));
ASSERT_EQ(-1, schema->GetFieldIndex("not-found"));
ASSERT_EQ(std::vector<int>{0}, schema->GetAllFieldIndices(f0->name()));
AssertSortedEquals(std::vector<int>{1, 3}, schema->GetAllFieldIndices(f1->name()));
ASSERT_EQ(std::vector<int>({1, 3}), schema->GetAllFieldIndices(f1->name()));

ASSERT_TRUE(::arrow::schema({f0, f1, f2})->HasDistinctFieldNames());
ASSERT_FALSE(schema->HasDistinctFieldNames());
Expand Down Expand Up @@ -1439,7 +1439,7 @@ TEST(TestStructType, GetFieldDuplicates) {
ASSERT_EQ(0, struct_type.GetFieldIndex("f0"));
ASSERT_EQ(-1, struct_type.GetFieldIndex("f1"));
ASSERT_EQ(std::vector<int>{0}, struct_type.GetAllFieldIndices(f0->name()));
AssertSortedEquals(std::vector<int>{1, 2}, struct_type.GetAllFieldIndices(f1->name()));
ASSERT_EQ(std::vector<int>({1, 2}), struct_type.GetAllFieldIndices(f1->name()));

std::vector<std::shared_ptr<Field>> results;

Expand Down
5 changes: 3 additions & 2 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
shared_ptr[CField] GetFieldByName(const c_string& name)
vector[shared_ptr[CField]] GetAllFieldsByName(const c_string& name)
int GetFieldIndex(const c_string& name)
vector[int] GetAllFieldIndices(const c_string& name)

cdef cppclass CUnionType" arrow::UnionType"(CDataType):
CUnionType(const vector[shared_ptr[CField]]& fields,
Expand All @@ -393,8 +394,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:
shared_ptr[const CKeyValueMetadata] metadata()
shared_ptr[CField] GetFieldByName(const c_string& name)
vector[shared_ptr[CField]] GetAllFieldsByName(const c_string& name)
int64_t GetFieldIndex(const c_string& name)
vector[int64_t] GetAllFieldIndice(const c_string& name)
int GetFieldIndex(const c_string& name)
vector[int] GetAllFieldIndices(const c_string& name)
int num_fields()
c_string ToString()

Expand Down
13 changes: 9 additions & 4 deletions python/pyarrow/table.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -1559,11 +1559,16 @@ cdef class Table(_PandasConvertible):
pyarrow.ChunkedArray
"""
if isinstance(i, (bytes, str)):
field_index = self.schema.get_field_index(i)
if field_index < 0:
raise KeyError("Column {} does not exist in table".format(i))
field_indices = self.schema.get_all_field_indices(i)

if len(field_indices) == 0:
raise KeyError("Field \"{}\" does not exist in table schema"
.format(i))
elif len(field_indices) > 1:
raise KeyError("Field \"{}\" exists {} times in table schema"
.format(i, len(field_indices)))
else:
return self._column(field_index)
return self._column(field_indices[0])
elif isinstance(i, int):
return self._column(i)
else:
Expand Down
6 changes: 6 additions & 0 deletions python/pyarrow/tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,12 @@ def test_schema_duplicate_fields():
with pytest.warns((UserWarning, FutureWarning)):
assert sch.field_by_name('foo') is None

# Schema::GetFieldIndex
assert sch.get_field_index('foo') == -1

# Schema::GetAllFieldIndices
assert sch.get_all_field_indices('foo') == [0, 2]


def test_field_flatten():
f0 = pa.field('foo', pa.int32()).with_metadata({b'foo': b'bar'})
Expand Down
14 changes: 13 additions & 1 deletion python/pyarrow/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,7 +725,8 @@ def test_table_select_column():

assert table.column('a').equals(table.column(0))

with pytest.raises(KeyError):
with pytest.raises(KeyError,
match='Field "d" does not exist in table schema'):
table.column('d')

with pytest.raises(TypeError):
Expand All @@ -735,6 +736,17 @@ def test_table_select_column():
table.column(4)


def test_table_column_with_duplicates():
# ARROW-8209
table = pa.table([pa.array([1, 2, 3]),
pa.array([4, 5, 6]),
pa.array([7, 8, 9])], names=['a', 'b', 'a'])

with pytest.raises(KeyError,
match='Field "a" exists 2 times in table schema'):
table.column('a')


def test_table_add_column():
data = [
pa.array(range(5)),
Expand Down
25 changes: 20 additions & 5 deletions python/pyarrow/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,11 +347,6 @@ def test_struct_type():

assert ty['b'] == ty[2]

# Duplicate
with pytest.warns(UserWarning):
with pytest.raises(KeyError):
ty['a']

# Not found
with pytest.raises(KeyError):
ty['c']
Expand Down Expand Up @@ -385,6 +380,26 @@ def test_struct_type():
pa.struct([('a', None)])


def test_struct_duplicate_field_names():
fields = [
pa.field('a', pa.int64()),
pa.field('b', pa.int32()),
pa.field('a', pa.int32())
]
ty = pa.struct(fields)

# Duplicate
with pytest.warns(UserWarning):
with pytest.raises(KeyError):
ty['a']

# StructType::GetFieldIndex
assert ty.get_field_index('a') == -1

# StructType::GetAllFieldIndices
assert ty.get_all_field_indices('a') == [0, 2]


def test_union_type():
def check_fields(ty, fields):
assert ty.num_children == len(fields)
Expand Down
23 changes: 23 additions & 0 deletions python/pyarrow/types.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,19 @@ cdef class StructType(DataType):
else:
return pyarrow_wrap_field(fields[0])

def get_field_index(self, name):
"""
Return index of field with given unique name. Returns -1 if not found
or if duplicated
"""
return self.struct_type.GetFieldIndex(tobytes(name))

def get_all_field_indices(self, name):
"""
Return sorted list of indices for fields with the given name
"""
return self.struct_type.GetAllFieldIndices(tobytes(name))

def __len__(self):
"""
Like num_children().
Expand Down Expand Up @@ -1322,8 +1335,18 @@ cdef class Schema:
return pyarrow_wrap_field(results[0])

def get_field_index(self, name):
"""
Return index of field with given unique name. Returns -1 if not found
or if duplicated
"""
return self.schema.GetFieldIndex(tobytes(name))

def get_all_field_indices(self, name):
"""
Return sorted list of indices for fields with the given name
"""
return self.schema.GetAllFieldIndices(tobytes(name))

def append(self, Field field):
"""
Append a field at the end of the schema.
Expand Down

0 comments on commit ccb9b84

Please sign in to comment.