Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: (de)serialize using Pydantic #4973

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ packages = find:
python_requires = >=3.8
install_requires =
mmh3
pydantic
[options.extras_require]
arrow =
pyarrow==8.0.0
Expand Down
11 changes: 11 additions & 0 deletions python/spellcheck-dictionary.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ Timestamptz
Timestamptzs
unscaled
URI
json
py
conftest
pytest
parametrize
uri
URI
InputFile
OutputFile
bytestream
deserialize
UnboundPredicate
BoundPredicate
BooleanExpression
Expand Down
76 changes: 40 additions & 36 deletions python/src/iceberg/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
TypeVar,
)

from pydantic import Field, PrivateAttr

from iceberg.files import StructProtocol
from iceberg.types import (
IcebergType,
Expand All @@ -37,27 +39,38 @@
PrimitiveType,
StructType,
)
from iceberg.utils.iceberg_base_model import IcebergBaseModel

T = TypeVar("T")


class Schema:
class Schema(IcebergBaseModel):
"""A table Schema

Example:
>>> from iceberg import schema
>>> from iceberg import types
"""

def __init__(self, *columns: NestedField, schema_id: int, identifier_field_ids: list[int] | None = None):
self._struct = StructType(*columns)
self._schema_id = schema_id
self._identifier_field_ids = identifier_field_ids or []
self._name_to_id: dict[str, int] = index_by_name(self)
self._name_to_id_lower: dict[str, int] = {} # Should be accessed through self._lazy_name_to_id_lower()
self._id_to_field: dict[int, NestedField] = {} # Should be accessed through self._lazy_id_to_field()
self._id_to_name: dict[int, str] = {} # Should be accessed through self._lazy_id_to_name()
self._id_to_accessor: dict[int, Accessor] = {} # Should be accessed through self._lazy_id_to_accessor()
fields: tuple[NestedField, ...] = Field()
schema_id: int = Field(alias="schema-id")
identifier_field_ids: list[int] = Field(alias="identifier-field-ids", default_factory=list)

_name_to_id: dict[str, int] = PrivateAttr()
# Should be accessed through self._lazy_name_to_id_lower()
_name_to_id_lower: dict[str, int] = PrivateAttr(default_factory=dict)
# Should be accessed through self._lazy_id_to_field()
_id_to_field: dict[int, NestedField] = PrivateAttr(default_factory=dict)
# Should be accessed through self._lazy_id_to_name()
_id_to_name: dict[int, str] = PrivateAttr(default_factory=dict)
# Should be accessed through self._lazy_id_to_accessor()
_id_to_accessor: dict[int, Accessor] = PrivateAttr(default_factory=dict)

def __init__(self, *fields: NestedField, **data):
if fields:
data["fields"] = fields
super().__init__(**data)
self._name_to_id = index_by_name(self)

def __str__(self):
return "table {\n" + "\n".join([" " + str(field) for field in self.columns]) + "\n}"
Expand Down Expand Up @@ -85,16 +98,7 @@ def __eq__(self, other) -> bool:
@property
def columns(self) -> tuple[NestedField, ...]:
"""A list of the top-level fields in the underlying struct"""
return self._struct.fields

@property
def schema_id(self) -> int:
"""The ID of this Schema"""
return self._schema_id

@property
def identifier_field_ids(self) -> list[int]:
return self._identifier_field_ids
return self.fields

def _lazy_id_to_field(self) -> dict[int, NestedField]:
"""Returns an index of field ID to NestedField instance
Expand Down Expand Up @@ -134,7 +138,7 @@ def _lazy_id_to_accessor(self) -> dict[int, Accessor]:

def as_struct(self) -> StructType:
"""Returns the underlying struct"""
return self._struct
return StructType(*self.fields)

def find_field(self, name_or_id: str | int, case_sensitive: bool = True) -> NestedField:
"""Find a field using a field name or field ID
Expand Down Expand Up @@ -343,23 +347,23 @@ def _(obj: StructType, visitor: SchemaVisitor[T]) -> T:
def _(obj: ListType, visitor: SchemaVisitor[T]) -> T:
"""Visit a ListType with a concrete SchemaVisitor"""

visitor.before_list_element(obj.element)
result = visit(obj.element.field_type, visitor)
visitor.after_list_element(obj.element)
visitor.before_list_element(obj.element_field)
result = visit(obj.element_field.field_type, visitor)
visitor.after_list_element(obj.element_field)

return visitor.list(obj, result)


@visit.register(MapType)
def _(obj: MapType, visitor: SchemaVisitor[T]) -> T:
"""Visit a MapType with a concrete SchemaVisitor"""
visitor.before_map_key(obj.key)
key_result = visit(obj.key.field_type, visitor)
visitor.after_map_key(obj.key)
visitor.before_map_key(obj.key_field)
key_result = visit(obj.key, visitor)
visitor.after_map_key(obj.key_field)

visitor.before_map_value(obj.value)
value_result = visit(obj.value.field_type, visitor)
visitor.after_list_element(obj.value)
visitor.before_map_value(obj.value_field)
value_result = visit(obj.value, visitor)
visitor.after_list_element(obj.value_field)

return visitor.map(obj, key_result, value_result)

Expand Down Expand Up @@ -389,13 +393,13 @@ def field(self, field: NestedField, field_result) -> dict[int, NestedField]:

def list(self, list_type: ListType, element_result) -> dict[int, NestedField]:
"""Add the list element ID to the index"""
self._index[list_type.element.field_id] = list_type.element
self._index[list_type.element_field.field_id] = list_type.element_field
return self._index

def map(self, map_type: MapType, key_result, value_result) -> dict[int, NestedField]:
"""Add the key ID and value ID as individual items in the index"""
self._index[map_type.key.field_id] = map_type.key
self._index[map_type.value.field_id] = map_type.value
self._index[map_type.key_field.field_id] = map_type.key_field
self._index[map_type.value_field.field_id] = map_type.value_field
return self._index

def primitive(self, primitive) -> dict[int, NestedField]:
Expand Down Expand Up @@ -458,13 +462,13 @@ def field(self, field: NestedField, field_result: dict[str, int]) -> dict[str, i

def list(self, list_type: ListType, element_result: dict[str, int]) -> dict[str, int]:
"""Add the list element name to the index"""
self._add_field(list_type.element.name, list_type.element.field_id)
self._add_field(list_type.element_field.name, list_type.element_field.field_id)
return self._index

def map(self, map_type: MapType, key_result: dict[str, int], value_result: dict[str, int]) -> dict[str, int]:
"""Add the key name and value name as individual items in the index"""
self._add_field(map_type.key.name, map_type.key.field_id)
self._add_field(map_type.value.name, map_type.value.field_id)
self._add_field(map_type.key_field.name, map_type.key_field.field_id)
self._add_field(map_type.value_field.name, map_type.value_field.field_id)
return self._index

def _add_field(self, name: str, field_id: int):
Expand Down
75 changes: 75 additions & 0 deletions python/src/iceberg/serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import codecs
import json
from typing import Union

from iceberg.io.base import InputFile, InputStream, OutputFile
from iceberg.table.metadata import TableMetadata, TableMetadataV1, TableMetadataV2


class FromByteStream:
    """A collection of methods that deserialize byte streams into Iceberg objects"""

    @staticmethod
    def table_metadata(byte_stream: InputStream, encoding: str = "utf-8") -> TableMetadata:
        """Parse a JSON document read from a byte stream into a TableMetadata object

        Args:
            byte_stream: A file-like byte stream object that holds the JSON-encoded metadata
            encoding (default "utf-8"): The codec used to decode the bytes before JSON parsing

        Returns:
            TableMetadata: The result of `TableMetadata.parse_obj` applied to the decoded JSON document
        """
        # Wrap the raw stream in a decoding reader so json.load consumes text instead of bytes
        text_stream = codecs.getreader(encoding)(byte_stream)  # type: ignore
        parsed_document = json.load(text_stream)
        return TableMetadata.parse_obj(parsed_document)  # type: ignore


class FromInputFile:
    """A collection of methods that deserialize InputFiles into Iceberg objects"""

    @staticmethod
    def table_metadata(input_file: InputFile, encoding: str = "utf-8") -> TableMetadata:
        """Create a TableMetadata instance from an input file

        Args:
            input_file (InputFile): A custom implementation of the iceberg.io.base.InputFile abstract base class
            encoding (str): Encoding to use when loading bytestream

        Returns:
            TableMetadata: A table metadata instance

        """
        byte_stream = input_file.open()
        try:
            return FromByteStream.table_metadata(byte_stream=byte_stream, encoding=encoding)
        finally:
            # Release the underlying handle whether or not parsing succeeded
            byte_stream.close()


class ToOutputFile:
    """A collection of methods that serialize Iceberg objects into files given an OutputFile instance"""

    @staticmethod
    def table_metadata(
        metadata: Union[TableMetadataV1, TableMetadataV2], output_file: OutputFile, overwrite: bool = False
    ) -> None:
        """Write a TableMetadata instance to an output file

        The metadata is serialized with its `json()` method and written as UTF-8 encoded bytes.

        Args:
            metadata (TableMetadataV1 | TableMetadataV2): The table metadata instance to serialize
            output_file (OutputFile): A custom implementation of the iceberg.io.base.OutputFile abstract base class
            overwrite (bool): Whether to overwrite the file if it already exists. Defaults to `False`.
        """
        f = output_file.create(overwrite=overwrite)
        try:
            f.write(metadata.json().encode("utf-8"))
        finally:
            # Close the stream even when serialization or the write fails, so the handle is not leaked
            f.close()
Loading