From 8cd0e520158894af7a67688748dfcd4f0bd1458c Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Thu, 25 Jan 2024 16:46:38 +0000 Subject: [PATCH 1/9] create_table from pyarrow schema --- pyiceberg/catalog/__init__.py | 6 +- pyiceberg/catalog/dynamodb.py | 14 ++- pyiceberg/catalog/glue.py | 14 ++- pyiceberg/catalog/hive.py | 15 ++- pyiceberg/catalog/noop.py | 6 +- pyiceberg/catalog/rest.py | 14 ++- pyiceberg/catalog/sql.py | 14 ++- pyiceberg/io/pyarrow.py | 162 +++++++++++++++++++++++++++++-- pyproject.toml | 2 +- tests/catalog/test_base.py | 26 ++++- tests/catalog/test_dynamodb.py | 18 ++++ tests/catalog/test_glue.py | 23 +++++ tests/catalog/test_sql.py | 21 ++++ tests/conftest.py | 115 ++++++++++++++++++++++ tests/io/test_pyarrow_visitor.py | 121 ++++------------------- 15 files changed, 454 insertions(+), 117 deletions(-) diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py index a39d0e915c..bec6b8cf06 100644 --- a/pyiceberg/catalog/__init__.py +++ b/pyiceberg/catalog/__init__.py @@ -24,6 +24,7 @@ from dataclasses import dataclass from enum import Enum from typing import ( + TYPE_CHECKING, Callable, Dict, List, @@ -56,6 +57,9 @@ ) from pyiceberg.utils.config import Config, merge_config +if TYPE_CHECKING: + import pyarrow as pa + logger = logging.getLogger(__name__) _ENV_CONFIG = Config() @@ -288,7 +292,7 @@ def _load_file_io(self, properties: Properties = EMPTY_DICT, location: Optional[ def create_table( self, identifier: Union[str, Identifier], - schema: Schema, + schema: Union[Schema, "pa.Schema"], location: Optional[str] = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, diff --git a/pyiceberg/catalog/dynamodb.py b/pyiceberg/catalog/dynamodb.py index 6c3f931bd8..f9bdb01470 100644 --- a/pyiceberg/catalog/dynamodb.py +++ b/pyiceberg/catalog/dynamodb.py @@ -17,6 +17,7 @@ import uuid from time import time from typing import ( + TYPE_CHECKING, Any, Dict, List, @@ -57,6 +58,9 @@ from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder from pyiceberg.typedef import EMPTY_DICT +if TYPE_CHECKING: + import pyarrow as pa + DYNAMODB_CLIENT = "dynamodb" DYNAMODB_COL_IDENTIFIER = "identifier" @@ -127,7 +131,7 @@ def _dynamodb_table_exists(self) -> bool: def create_table( self, identifier: Union[str, Identifier], - schema: Schema, + schema: Union[Schema, "pa.Schema"], location: Optional[str] = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, @@ -152,6 +156,14 @@ def create_table( ValueError: If the identifier is invalid, or no path is given to store metadata. """ + if not isinstance(schema, Schema): + import pyarrow as pa + + from pyiceberg.io.pyarrow import _ConvertToIcebergWithFreshIds, pre_order_visit_pyarrow + + if isinstance(schema, pa.Schema): + schema: Schema = pre_order_visit_pyarrow(schema, _ConvertToIcebergWithFreshIds()) # type: ignore + database_name, table_name = self.identifier_to_database_and_table(identifier) location = self._resolve_table_location(location, database_name, table_name) diff --git a/pyiceberg/catalog/glue.py b/pyiceberg/catalog/glue.py index 645568f80a..aad7c01083 100644 --- a/pyiceberg/catalog/glue.py +++ b/pyiceberg/catalog/glue.py @@ -17,6 +17,7 @@ from typing import ( + TYPE_CHECKING, Any, Dict, List, @@ -88,6 +89,9 @@ UUIDType, ) +if TYPE_CHECKING: + import pyarrow as pa + # If Glue should skip archiving an old table version when creating a new version in a commit. By # default, Glue archives all old table versions after an UpdateTable call, but Glue has a default # max number of archived table versions (can be increased). So for streaming use case with lots @@ -329,7 +333,7 @@ def _get_glue_table(self, database_name: str, table_name: str) -> TableTypeDef: def create_table( self, identifier: Union[str, Identifier], - schema: Schema, + schema: Union[Schema, "pa.Schema"], location: Optional[str] = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, @@ -354,6 +358,14 @@ def create_table( ValueError: If the identifier is invalid, or no path is given to store metadata. """ + if not isinstance(schema, Schema): + import pyarrow as pa + + from pyiceberg.io.pyarrow import _ConvertToIcebergWithFreshIds, pre_order_visit_pyarrow + + if isinstance(schema, pa.Schema): + schema: Schema = pre_order_visit_pyarrow(schema, _ConvertToIcebergWithFreshIds()) # type: ignore + database_name, table_name = self.identifier_to_database_and_table(identifier) location = self._resolve_table_location(location, database_name, table_name) diff --git a/pyiceberg/catalog/hive.py b/pyiceberg/catalog/hive.py index 331b9ca80d..82c5f2cd3d 100644 --- a/pyiceberg/catalog/hive.py +++ b/pyiceberg/catalog/hive.py @@ -18,6 +18,7 @@ import time from types import TracebackType from typing import ( + TYPE_CHECKING, Any, Dict, List, @@ -91,6 +92,10 @@ UUIDType, ) +if TYPE_CHECKING: + import pyarrow as pa + + # Replace by visitor hive_types = { BooleanType: "boolean", @@ -250,7 +255,7 @@ def _convert_hive_into_iceberg(self, table: HiveTable, io: FileIO) -> Table: def create_table( self, identifier: Union[str, Identifier], - schema: Schema, + schema: Union[Schema, "pa.Schema"], location: Optional[str] = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, @@ -273,6 +278,14 @@ def create_table( AlreadyExistsError: If a table with the name already exists. ValueError: If the identifier is invalid. """ + if not isinstance(schema, Schema): + import pyarrow as pa + + from pyiceberg.io.pyarrow import _ConvertToIcebergWithFreshIds, pre_order_visit_pyarrow + + if isinstance(schema, pa.Schema): + schema: Schema = pre_order_visit_pyarrow(schema, _ConvertToIcebergWithFreshIds()) # type: ignore + properties = {**DEFAULT_PROPERTIES, **properties} database_name, table_name = self.identifier_to_database_and_table(identifier) current_time_millis = int(time.time() * 1000) diff --git a/pyiceberg/catalog/noop.py b/pyiceberg/catalog/noop.py index 083f851d1c..a8b7154621 100644 --- a/pyiceberg/catalog/noop.py +++ b/pyiceberg/catalog/noop.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. from typing import ( + TYPE_CHECKING, List, Optional, Set, @@ -33,12 +34,15 @@ from pyiceberg.table.sorting import UNSORTED_SORT_ORDER from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties +if TYPE_CHECKING: + import pyarrow as pa + class NoopCatalog(Catalog): def create_table( self, identifier: Union[str, Identifier], - schema: Schema, + schema: Union[Schema, "pa.Schema"], location: Optional[str] = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, diff --git a/pyiceberg/catalog/rest.py b/pyiceberg/catalog/rest.py index de192a9e0b..bf15a68449 100644 --- a/pyiceberg/catalog/rest.py +++ b/pyiceberg/catalog/rest.py @@ -16,6 +16,7 @@ # under the License. from json import JSONDecodeError from typing import ( + TYPE_CHECKING, Any, Dict, List, @@ -68,6 +69,9 @@ from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder from pyiceberg.typedef import EMPTY_DICT, UTF8, IcebergBaseModel +if TYPE_CHECKING: + import pyarrow as pa + ICEBERG_REST_SPEC_VERSION = "0.14.1" @@ -437,12 +441,20 @@ def _response_to_table(self, identifier_tuple: Tuple[str, ...], table_response: def create_table( self, identifier: Union[str, Identifier], - schema: Schema, + schema: Union[Schema, "pa.Schema"], location: Optional[str] = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, ) -> Table: + if not isinstance(schema, Schema): + import pyarrow as pa + + from pyiceberg.io.pyarrow import _ConvertToIcebergWithFreshIds, pre_order_visit_pyarrow + + if isinstance(schema, pa.Schema): + schema: Schema = pre_order_visit_pyarrow(schema, _ConvertToIcebergWithFreshIds()) # type: ignore + namespace_and_table = self._split_identifier_for_path(identifier) request = CreateTableRequest( name=namespace_and_table["table"], diff --git a/pyiceberg/catalog/sql.py b/pyiceberg/catalog/sql.py index 593c6b54a1..593a885c73 100644 --- a/pyiceberg/catalog/sql.py +++ b/pyiceberg/catalog/sql.py @@ -16,6 +16,7 @@ # under the License. from typing import ( + TYPE_CHECKING, List, Optional, Set, @@ -65,6 +66,9 @@ from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder from pyiceberg.typedef import EMPTY_DICT +if TYPE_CHECKING: + import pyarrow as pa + class SqlCatalogBaseTable(MappedAsDataclass, DeclarativeBase): pass @@ -140,7 +144,7 @@ def _convert_orm_to_iceberg(self, orm_table: IcebergTables) -> Table: def create_table( self, identifier: Union[str, Identifier], - schema: Schema, + schema: Union[Schema, "pa.Schema"], location: Optional[str] = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, @@ -165,6 +169,14 @@ def create_table( ValueError: If the identifier is invalid, or no path is given to store metadata. """ + if not isinstance(schema, Schema): + import pyarrow as pa + + from pyiceberg.io.pyarrow import _ConvertToIcebergWithFreshIds, pre_order_visit_pyarrow + + if isinstance(schema, pa.Schema): + schema: Schema = pre_order_visit_pyarrow(schema, _ConvertToIcebergWithFreshIds()) # type: ignore + database_name, table_name = self.identifier_to_database_and_table(identifier) if not self._namespace_exists(database_name): raise NoSuchNamespaceError(f"Namespace does not exist: {database_name}") diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 035f5e8031..ca36c7b571 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -26,6 +26,7 @@ from __future__ import annotations import concurrent.futures +import itertools import logging import os import re @@ -33,8 +34,7 @@ from concurrent.futures import Future from dataclasses import dataclass from enum import Enum -from functools import lru_cache, singledispatch -from itertools import chain +from functools import lru_cache, partial, singledispatch from typing import ( TYPE_CHECKING, Any, @@ -631,7 +631,7 @@ def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows: if len(positional_deletes) == 1: all_chunks = positional_deletes[0] else: - all_chunks = pa.chunked_array(chain(*[arr.chunks for arr in positional_deletes])) + all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes])) return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False) @@ -711,6 +711,60 @@ def _(obj: pa.DataType, visitor: PyArrowSchemaVisitor[T]) -> T: return visitor.primitive(obj) +@singledispatch +def pre_order_visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor: PreOrderPyArrowSchemaVisitor[T]) -> T: + """Apply a pyarrow schema visitor to any point within a schema. + + The function traverses the schema in pre-order fashion. + + Args: + obj (Union[pa.DataType, pa.Schema]): An instance of a Schema or an IcebergType. + visitor (PyArrowSchemaVisitor[T]): An instance of an implementation of the generic PyarrowSchemaVisitor base class. + + Raises: + NotImplementedError: If attempting to visit an unrecognized object type. + """ + raise NotImplementedError(f"Cannot visit non-type: {obj}") + + +@pre_order_visit_pyarrow.register(pa.Schema) +def _(obj: pa.Schema, visitor: PreOrderPyArrowSchemaVisitor[T]) -> T: + return visitor.schema(obj, lambda: pre_order_visit_pyarrow(pa.struct(obj), visitor)) + + +@pre_order_visit_pyarrow.register(pa.StructType) +def _(obj: pa.StructType, visitor: PreOrderPyArrowSchemaVisitor[T]) -> T: + return visitor.struct( + obj, + [ + partial( + lambda field: visitor.field(field, partial(lambda field: pre_order_visit_pyarrow(field.type, visitor), field)), + field, + ) + for field in obj + ], + ) + + +@pre_order_visit_pyarrow.register(pa.ListType) +def _(obj: pa.ListType, visitor: PreOrderPyArrowSchemaVisitor[T]) -> T: + return visitor.list(obj, lambda: pre_order_visit_pyarrow(obj.value_type, visitor)) + + +@pre_order_visit_pyarrow.register(pa.MapType) +def _(obj: pa.MapType, visitor: PreOrderPyArrowSchemaVisitor[T]) -> T: + return visitor.map( + obj, lambda: pre_order_visit_pyarrow(obj.key_type, visitor), lambda: pre_order_visit_pyarrow(obj.item_type, visitor) + ) + + +@pre_order_visit_pyarrow.register(pa.DataType) +def _(obj: pa.DataType, visitor: PreOrderPyArrowSchemaVisitor[T]) -> T: + if pa.types.is_nested(obj): + raise TypeError(f"Expected primitive type, got: {type(obj)}") + return visitor.primitive(obj) + + class PyArrowSchemaVisitor(Generic[T], ABC): def before_field(self, field: pa.Field) -> None: """Override this method to perform an action immediately before visiting a field.""" @@ -761,6 +815,32 @@ def primitive(self, primitive: pa.DataType) -> T: """Visit a primitive type.""" +class PreOrderPyArrowSchemaVisitor(Generic[T], ABC): + @abstractmethod + def schema(self, schema: pa.Schema, struct_result: Callable[[], T]) -> T: + """Visit a schema.""" + + @abstractmethod + def struct(self, struct: pa.StructType, field_results: List[Callable[[], T]]) -> T: + """Visit a struct.""" + + @abstractmethod + def field(self, field: pa.Field, field_result: Callable[[], T]) -> T: + """Visit a field.""" + + @abstractmethod + def list(self, list_type: pa.ListType, element_result: Callable[[], T]) -> T: + """Visit a list.""" + + @abstractmethod + def map(self, map_type: pa.MapType, key_result: Callable[[], T], value_result: Callable[[], T]) -> T: + """Visit a map.""" + + @abstractmethod + def primitive(self, primitive: pa.DataType) -> T: + """Visit a primitive type.""" + + def _get_field_id(field: pa.Field) -> Optional[int]: return ( int(field_id_str.decode()) @@ -906,6 +986,76 @@ def after_map_value(self, element: pa.Field) -> None: self._field_names.pop() +class _ConvertToIcebergWithFreshIds(PreOrderPyArrowSchemaVisitor[Union[IcebergType, Schema]]): + """Converts PyArrowSchema to Iceberg Schema with fresh ids.""" + + def __init__(self) -> None: + self.counter = itertools.count(1) + + def _field_id(self) -> int: + return next(self.counter) + + def schema(self, schema: pa.Schema, struct_result: Callable[[], StructType]) -> Schema: + return Schema(*struct_result().fields) + + def struct(self, struct: pa.StructType, field_results: List[Callable[[], NestedField]]) -> StructType: + return StructType(*[field() for field in field_results]) + + def field(self, field: pa.Field, field_result: Callable[[], IcebergType]) -> NestedField: + field_id = self._field_id() + field_doc = doc_str.decode() if (field.metadata and (doc_str := field.metadata.get(PYARROW_FIELD_DOC_KEY))) else None + field_type = field_result() + return NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc) + + def list(self, list_type: pa.ListType, element_result: Callable[[], IcebergType]) -> ListType: + element_field = list_type.value_field + element_id = self._field_id() + return ListType(element_id, element_result(), element_required=not element_field.nullable) + + def map( + self, map_type: pa.MapType, key_result: Callable[[], IcebergType], value_result: Callable[[], IcebergType] + ) -> MapType: + key_id = self._field_id() + value_field = map_type.item_field + value_id = self._field_id() + return MapType(key_id, key_result(), value_id, value_result(), value_required=not value_field.nullable) + + def primitive(self, primitive: pa.DataType) -> PrimitiveType: + if pa.types.is_boolean(primitive): + return BooleanType() + elif pa.types.is_int32(primitive): + return IntegerType() + elif pa.types.is_int64(primitive): + return LongType() + elif pa.types.is_float32(primitive): + return FloatType() + elif pa.types.is_float64(primitive): + return DoubleType() + elif isinstance(primitive, pa.Decimal128Type): + primitive = cast(pa.Decimal128Type, primitive) + return DecimalType(primitive.precision, primitive.scale) + elif pa.types.is_string(primitive): + return StringType() + elif pa.types.is_date32(primitive): + return DateType() + elif isinstance(primitive, pa.Time64Type) and primitive.unit == "us": + return TimeType() + elif pa.types.is_timestamp(primitive): + primitive = cast(pa.TimestampType, primitive) + if primitive.unit == "us": + if primitive.tz == "UTC" or primitive.tz == "+00:00": + return TimestamptzType() + elif primitive.tz is None: + return TimestampType() + elif pa.types.is_binary(primitive): + return BinaryType() + elif pa.types.is_fixed_size_binary(primitive): + primitive = cast(pa.FixedSizeBinaryType, primitive) + return FixedType(primitive.byte_width) + + raise TypeError(f"Unsupported type: {primitive}") + + def _task_to_table( fs: FileSystem, task: FileScanTask, @@ -993,7 +1143,7 @@ def _task_to_table( def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]: deletes_per_file: Dict[str, List[ChunkedArray]] = {} - unique_deletes = set(chain.from_iterable([task.delete_files for task in tasks])) + unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks])) if len(unique_deletes) > 0: executor = ExecutorFactory.get_or_create() deletes_per_files: Iterator[Dict[str, ChunkedArray]] = executor.map( @@ -1399,7 +1549,7 @@ def schema(self, schema: Schema, struct_result: Callable[[], List[StatisticsColl def struct( self, struct: StructType, field_results: List[Callable[[], List[StatisticsCollector]]] ) -> List[StatisticsCollector]: - return list(chain(*[result() for result in field_results])) + return list(itertools.chain(*[result() for result in field_results])) def field(self, field: NestedField, field_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]: self._field_id = field.field_id @@ -1491,7 +1641,7 @@ def schema(self, schema: Schema, struct_result: Callable[[], List[ID2ParquetPath return struct_result() def struct(self, struct: StructType, field_results: List[Callable[[], List[ID2ParquetPath]]]) -> List[ID2ParquetPath]: - return list(chain(*[result() for result in field_results])) + return list(itertools.chain(*[result() for result in field_results])) def field(self, field: NestedField, field_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]: self._field_id = field.field_id diff --git a/pyproject.toml b/pyproject.toml index 505fedf813..03d55c956a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -311,7 +311,7 @@ select = [ "I", # isort "UP", # pyupgrade ] -ignore = ["E501","E203","B024","B028"] +ignore = ["E501","E203","B024","B028","UP037"] # Allow autofix for all enabled rules (when `--fix`) is provided. fixable = ["ALL"] diff --git a/tests/catalog/test_base.py b/tests/catalog/test_base.py index 911c06b27a..21091d6410 100644 --- a/tests/catalog/test_base.py +++ b/tests/catalog/test_base.py @@ -24,6 +24,7 @@ Union, ) +import pyarrow as pa import pytest from pyiceberg.catalog import ( @@ -72,12 +73,20 @@ def __init__(self, name: str, **properties: str) -> None: def create_table( self, identifier: Union[str, Identifier], - schema: Schema, + schema: Union[Schema, "pa.Schema"], location: Optional[str] = None, partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC, sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, ) -> Table: + if not isinstance(schema, Schema): + import pyarrow as pa + + from pyiceberg.io.pyarrow import _ConvertToIcebergWithFreshIds, pre_order_visit_pyarrow + + if isinstance(schema, pa.Schema): + schema: Schema = pre_order_visit_pyarrow(schema, _ConvertToIcebergWithFreshIds()) # type: ignore + identifier = Catalog.identifier_to_tuple(identifier) namespace = Catalog.namespace_from(identifier) @@ -255,6 +264,10 @@ def catalog() -> InMemoryCatalog: return InMemoryCatalog("test.in.memory.catalog", **{"test.key": "test.value"}) +TEST_PYARROW_SCHEMA = pa.schema([ + pa.field('some_int', pa.int32(), nullable=True), + pa.field('some_string', pa.string(), nullable=False), +]) TEST_TABLE_IDENTIFIER = ("com", "organization", "department", "my_table") TEST_TABLE_NAMESPACE = ("com", "organization", "department") TEST_TABLE_NAME = "my_table" @@ -330,6 +343,17 @@ def test_create_table(catalog: InMemoryCatalog) -> None: assert catalog.load_table(TEST_TABLE_IDENTIFIER) == table +def test_create_table_pyarrow_schema(catalog: InMemoryCatalog, pyarrow_schema_simple_without_ids: pa.Schema) -> None: + table = catalog.create_table( + identifier=TEST_TABLE_IDENTIFIER, + schema=pyarrow_schema_simple_without_ids, + location=TEST_TABLE_LOCATION, + partition_spec=TEST_TABLE_PARTITION_SPEC, + properties=TEST_TABLE_PROPERTIES, + ) + assert catalog.load_table(TEST_TABLE_IDENTIFIER) == table + + def test_create_table_raises_error_when_table_already_exists(catalog: InMemoryCatalog) -> None: # Given given_catalog_has_a_table(catalog) diff --git a/tests/catalog/test_dynamodb.py b/tests/catalog/test_dynamodb.py index 5af89ef3be..bc801463c5 100644 --- a/tests/catalog/test_dynamodb.py +++ b/tests/catalog/test_dynamodb.py @@ -18,6 +18,7 @@ from unittest import mock import boto3 +import pyarrow as pa import pytest from moto import mock_dynamodb @@ -71,6 +72,23 @@ def test_create_table_with_database_location( assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location) +@mock_dynamodb +def test_create_table_with_pyarrow_schema( + _bucket_initialize: None, + moto_endpoint_url: str, + pyarrow_schema_simple_without_ids: pa.Schema, + database_name: str, + table_name: str, +) -> None: + catalog_name = "test_ddb_catalog" + identifier = (database_name, table_name) + test_catalog = DynamoDbCatalog(catalog_name, **{"s3.endpoint": moto_endpoint_url}) + test_catalog.create_namespace(namespace=database_name, properties={"location": f"s3://{BUCKET_NAME}/{database_name}.db"}) + table = test_catalog.create_table(identifier, pyarrow_schema_simple_without_ids) + assert table.identifier == (catalog_name,) + identifier + assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location) + + @mock_dynamodb def test_create_table_with_default_warehouse( _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str diff --git a/tests/catalog/test_glue.py b/tests/catalog/test_glue.py index b1f1371a04..63a213f94f 100644 --- a/tests/catalog/test_glue.py +++ b/tests/catalog/test_glue.py @@ -18,6 +18,7 @@ from unittest import mock import boto3 +import pyarrow as pa import pytest from moto import mock_glue @@ -101,6 +102,28 @@ def test_create_table_with_given_location( assert test_catalog._parse_metadata_version(table.metadata_location) == 0 +@mock_glue +def test_create_table_with_pyarrow_schema( + _bucket_initialize: None, + moto_endpoint_url: str, + pyarrow_schema_simple_without_ids: pa.Schema, + database_name: str, + table_name: str, +) -> None: + catalog_name = "glue" + identifier = (database_name, table_name) + test_catalog = GlueCatalog(catalog_name, **{"s3.endpoint": moto_endpoint_url}) + test_catalog.create_namespace(namespace=database_name) + table = test_catalog.create_table( + identifier=identifier, + schema=pyarrow_schema_simple_without_ids, + location=f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}", + ) + assert table.identifier == (catalog_name,) + identifier + assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location) + assert test_catalog._parse_metadata_version(table.metadata_location) == 0 + + @mock_glue def test_create_table_with_no_location( _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py index 217ea8f535..19c03e218a 100644 --- a/tests/catalog/test_sql.py +++ b/tests/catalog/test_sql.py @@ -19,6 +19,7 @@ from pathlib import Path from typing import Generator, List +import pyarrow as pa import pytest from pytest import TempPathFactory from pytest_lazyfixture import lazy_fixture @@ -141,6 +142,26 @@ def test_create_table_default_sort_order(catalog: SqlCatalog, table_schema_neste catalog.drop_table(random_identifier) +@pytest.mark.parametrize( + 'catalog', + [ + lazy_fixture('catalog_memory'), + lazy_fixture('catalog_sqlite'), + ], +) +def test_create_table_with_pyarrow_schema( + catalog: SqlCatalog, + pyarrow_schema_simple_without_ids: pa.Schema, + iceberg_table_schema_simple: Schema, + random_identifier: Identifier, +) -> None: + database_name, _table_name = random_identifier + catalog.create_namespace(database_name) + table = catalog.create_table(random_identifier, pyarrow_schema_simple_without_ids) + assert table.schema() == iceberg_table_schema_simple + catalog.drop_table(random_identifier) + + @pytest.mark.parametrize( 'catalog', [ diff --git a/tests/conftest.py b/tests/conftest.py index 9c53301776..1b35f082bd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,6 +45,7 @@ from urllib.parse import urlparse import boto3 +import pyarrow as pa import pytest from moto import mock_dynamodb, mock_glue from moto.server import ThreadedMotoServer # type: ignore @@ -267,6 +268,120 @@ def table_schema_nested_with_struct_key_map() -> Schema: ) +@pytest.fixture(scope="session") +def pyarrow_schema_simple_without_ids() -> pa.Schema: + return pa.schema([ + pa.field('foo', pa.string(), nullable=True), + pa.field('bar', pa.int32(), nullable=False), + pa.field('baz', pa.bool_(), nullable=True), + ]) + + +@pytest.fixture(scope="session") +def pyarrow_schema_nested_without_ids() -> pa.Schema: + return pa.schema([ + pa.field('foo', pa.string(), nullable=False), + pa.field('bar', pa.int32(), nullable=False), + pa.field('baz', pa.bool_(), nullable=True), + pa.field('qux', pa.list_(pa.string()), nullable=False), + pa.field( + 'quux', + pa.map_( + pa.string(), + pa.map_(pa.string(), pa.int32()), + ), + nullable=False, + ), + pa.field( + 'location', + pa.list_( + pa.struct([ + pa.field('latitude', pa.float32(), nullable=False), + pa.field('longitude', pa.float32(), nullable=False), + ]), + ), + nullable=False, + ), + pa.field( + 'person', + pa.struct([ + pa.field('name', pa.string(), nullable=True), + pa.field('age', pa.int32(), nullable=False), + ]), + nullable=True, + ), + ]) + + +@pytest.fixture(scope="session") +def iceberg_schema_simple() -> Schema: + return Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + ) + + +@pytest.fixture(scope="session") +def iceberg_table_schema_simple() -> Schema: + return Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=0, + identifier_field_ids=[], + ) + + +@pytest.fixture(scope="session") +def iceberg_schema_nested() -> Schema: + return Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=True), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + NestedField( + field_id=4, + name="qux", + field_type=ListType(element_id=5, element_type=StringType(), element_required=False), + required=True, + ), + NestedField( + field_id=6, + name="quux", + field_type=MapType( + key_id=7, + key_type=StringType(), + value_id=8, + value_type=MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=False), + value_required=False, + ), + required=True, + ), + NestedField( + field_id=11, + name="location", + field_type=ListType( + element_id=12, + element_type=StructType( + NestedField(field_id=13, name="latitude", field_type=FloatType(), required=True), + NestedField(field_id=14, name="longitude", field_type=FloatType(), required=True), + ), + element_required=False, + ), + required=True, + ), + NestedField( + field_id=15, + name="person", + field_type=StructType( + NestedField(field_id=16, name="name", field_type=StringType(), required=False), + NestedField(field_id=17, name="age", field_type=IntegerType(), required=True), + ), + required=False, + ), + ) + + @pytest.fixture(scope="session") def all_avro_types() -> Dict[str, Any]: return { diff --git a/tests/io/test_pyarrow_visitor.py b/tests/io/test_pyarrow_visitor.py index 0986eac409..509f474a3d 100644 --- a/tests/io/test_pyarrow_visitor.py +++ b/tests/io/test_pyarrow_visitor.py @@ -23,7 +23,9 @@ from pyiceberg.io.pyarrow import ( _ConvertToArrowSchema, _ConvertToIceberg, + _ConvertToIcebergWithFreshIds, _HasIds, + pre_order_visit_pyarrow, pyarrow_to_schema, schema_to_pyarrow, visit_pyarrow, @@ -51,104 +53,6 @@ ) -@pytest.fixture(scope="module") -def pyarrow_schema_simple_without_ids() -> pa.Schema: - return pa.schema([pa.field('some_int', pa.int32(), nullable=True), pa.field('some_string', pa.string(), nullable=False)]) - - -@pytest.fixture(scope="module") -def pyarrow_schema_nested_without_ids() -> pa.Schema: - return pa.schema([ - pa.field('foo', pa.string(), nullable=False), - pa.field('bar', pa.int32(), nullable=False), - pa.field('baz', pa.bool_(), nullable=True), - pa.field('qux', pa.list_(pa.string()), nullable=False), - pa.field( - 'quux', - pa.map_( - pa.string(), - pa.map_(pa.string(), pa.int32()), - ), - nullable=False, - ), - pa.field( - 'location', - pa.list_( - pa.struct([ - pa.field('latitude', pa.float32(), nullable=False), - pa.field('longitude', pa.float32(), nullable=False), - ]), - ), - nullable=False, - ), - pa.field( - 'person', - pa.struct([ - pa.field('name', pa.string(), nullable=True), - pa.field('age', pa.int32(), nullable=False), - ]), - nullable=True, - ), - ]) - - -@pytest.fixture(scope="module") -def iceberg_schema_simple() -> Schema: - return Schema( - NestedField(field_id=1, name="some_int", field_type=IntegerType(), required=False), - NestedField(field_id=2, name="some_string", field_type=StringType(), required=True), - ) - - -@pytest.fixture(scope="module") -def iceberg_schema_nested() -> Schema: - return Schema( - NestedField(field_id=1, name="foo", field_type=StringType(), required=True), - NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), - NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), - NestedField( - field_id=4, - name="qux", - field_type=ListType(element_id=5, element_type=StringType(), element_required=False), - required=True, - ), - NestedField( - field_id=6, - name="quux", - field_type=MapType( - key_id=7, - key_type=StringType(), - value_id=8, - value_type=MapType(key_id=9, key_type=StringType(), value_id=10, value_type=IntegerType(), value_required=False), - value_required=False, - ), - required=True, - ), - NestedField( - field_id=11, - name="location", - field_type=ListType( - element_id=12, - element_type=StructType( - NestedField(field_id=13, name="latitude", field_type=FloatType(), required=True), - NestedField(field_id=14, name="longitude", field_type=FloatType(), required=True), - ), - element_required=False, - ), - required=True, - ), - NestedField( - field_id=15, - name="person", - field_type=StructType( - NestedField(field_id=16, name="name", field_type=StringType(), required=False), - NestedField(field_id=17, name="age", field_type=IntegerType(), required=True), - ), - required=False, - ), - ) - - def test_pyarrow_binary_to_iceberg() -> None: length = 23 pyarrow_type = pa.binary(length) @@ -468,8 +372,9 @@ def test_simple_pyarrow_schema_to_schema_missing_ids_using_name_mapping( ) -> None: schema = pyarrow_schema_simple_without_ids name_mapping = NameMapping([ - MappedField(field_id=1, names=['some_int']), - MappedField(field_id=2, names=['some_string']), + MappedField(field_id=1, names=['foo']), + MappedField(field_id=2, names=['bar']), + MappedField(field_id=3, names=['baz']), ]) assert pyarrow_to_schema(schema, name_mapping) == iceberg_schema_simple @@ -480,11 +385,11 @@ def test_simple_pyarrow_schema_to_schema_missing_ids_using_name_mapping_partial_ ) -> None: schema = pyarrow_schema_simple_without_ids name_mapping = NameMapping([ - MappedField(field_id=1, names=['some_string']), + MappedField(field_id=1, names=['foo']), ]) with pytest.raises(ValueError) as exc_info: _ = pyarrow_to_schema(schema, name_mapping) - assert "Could not find field with name: some_int" in str(exc_info.value) + assert "Could not find field with name: bar" in str(exc_info.value) def test_nested_pyarrow_schema_to_schema_missing_ids_using_name_mapping( @@ -572,3 +477,15 @@ def test_pyarrow_schema_to_schema_missing_ids_using_name_mapping_nested_missing_ with pytest.raises(ValueError) as exc_info: _ = pyarrow_to_schema(schema, name_mapping) assert "Could not find field with name: quux.value.key" in str(exc_info.value) + + +def test_pyarrow_schema_to_schema_fresh_ids_simple_schema( + pyarrow_schema_simple_without_ids: pa.Schema, iceberg_schema_simple: Schema +) -> None: + assert pre_order_visit_pyarrow(pyarrow_schema_simple_without_ids, _ConvertToIcebergWithFreshIds()) == iceberg_schema_simple + + +def test_pyarrow_schema_to_schema_fresh_ids_nested_schema( + pyarrow_schema_nested_without_ids: pa.Schema, iceberg_schema_nested: Schema +) -> None: + assert pre_order_visit_pyarrow(pyarrow_schema_nested_without_ids, _ConvertToIcebergWithFreshIds()) == iceberg_schema_nested From 3d0445b5cbaa045a9c22d54a58e6813e3d8a6513 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Thu, 25 Jan 2024 18:33:24 +0000 Subject: [PATCH 2/9] incorporate review comments --- pyiceberg/catalog/__init__.py | 16 ++++++ pyiceberg/catalog/dynamodb.py | 8 +-- pyiceberg/catalog/glue.py | 8 +-- pyiceberg/catalog/hive.py | 8 +-- pyiceberg/catalog/rest.py | 8 +-- pyiceberg/catalog/sql.py | 8 +-- pyiceberg/io/pyarrow.py | 104 +++++++++++++--------------------- tests/catalog/test_base.py | 25 +++++--- 8 files changed, 77 insertions(+), 108 deletions(-) diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py index bec6b8cf06..d40b38b566 100644 --- a/pyiceberg/catalog/__init__.py +++ b/pyiceberg/catalog/__init__.py @@ -516,6 +516,22 @@ def _check_for_overlap(removals: Optional[Set[str]], updates: Properties) -> Non if overlap: raise ValueError(f"Updates and deletes have an overlap: {overlap}") + @staticmethod + def _convert_schema_if_needed(schema: Union[Schema, "pa.Schema"]) -> Schema: + try: + import pyarrow as pa + + from pyiceberg.io.pyarrow import _ConvertToIcebergWithFreshIds, pre_order_visit_pyarrow + + if isinstance(schema, pa.Schema): + schema: Schema = pre_order_visit_pyarrow(schema, _ConvertToIcebergWithFreshIds()) # type: ignore + except ModuleNotFoundError: + pass + + if not isinstance(schema, Schema): + raise ValueError(f"{type(schema)=} must be pyiceberg.schema.Schema") + return schema + def _resolve_table_location(self, location: Optional[str], database_name: str, table_name: str) -> str: if not location: return self._get_default_warehouse_location(database_name, table_name) diff --git a/pyiceberg/catalog/dynamodb.py b/pyiceberg/catalog/dynamodb.py index f9bdb01470..d5f3b5e14c 100644 --- a/pyiceberg/catalog/dynamodb.py +++ b/pyiceberg/catalog/dynamodb.py @@ -156,13 +156,7 @@ def create_table( ValueError: If the identifier is invalid, or no path is given to store metadata. """ - if not isinstance(schema, Schema): - import pyarrow as pa - - from pyiceberg.io.pyarrow import _ConvertToIcebergWithFreshIds, pre_order_visit_pyarrow - - if isinstance(schema, pa.Schema): - schema: Schema = pre_order_visit_pyarrow(schema, _ConvertToIcebergWithFreshIds()) # type: ignore + schema: Schema = self._convert_schema_if_needed(schema) # type: ignore database_name, table_name = self.identifier_to_database_and_table(identifier) diff --git a/pyiceberg/catalog/glue.py b/pyiceberg/catalog/glue.py index aad7c01083..8f860fabba 100644 --- a/pyiceberg/catalog/glue.py +++ b/pyiceberg/catalog/glue.py @@ -358,13 +358,7 @@ def create_table( ValueError: If the identifier is invalid, or no path is given to store metadata. """ - if not isinstance(schema, Schema): - import pyarrow as pa - - from pyiceberg.io.pyarrow import _ConvertToIcebergWithFreshIds, pre_order_visit_pyarrow - - if isinstance(schema, pa.Schema): - schema: Schema = pre_order_visit_pyarrow(schema, _ConvertToIcebergWithFreshIds()) # type: ignore + schema: Schema = self._convert_schema_if_needed(schema) # type: ignore database_name, table_name = self.identifier_to_database_and_table(identifier) diff --git a/pyiceberg/catalog/hive.py b/pyiceberg/catalog/hive.py index 82c5f2cd3d..8069321095 100644 --- a/pyiceberg/catalog/hive.py +++ b/pyiceberg/catalog/hive.py @@ -278,13 +278,7 @@ def create_table( AlreadyExistsError: If a table with the name already exists. ValueError: If the identifier is invalid. """ - if not isinstance(schema, Schema): - import pyarrow as pa - - from pyiceberg.io.pyarrow import _ConvertToIcebergWithFreshIds, pre_order_visit_pyarrow - - if isinstance(schema, pa.Schema): - schema: Schema = pre_order_visit_pyarrow(schema, _ConvertToIcebergWithFreshIds()) # type: ignore + schema: Schema = self._convert_schema_if_needed(schema) # type: ignore properties = {**DEFAULT_PROPERTIES, **properties} database_name, table_name = self.identifier_to_database_and_table(identifier) diff --git a/pyiceberg/catalog/rest.py b/pyiceberg/catalog/rest.py index bf15a68449..34d75b5936 100644 --- a/pyiceberg/catalog/rest.py +++ b/pyiceberg/catalog/rest.py @@ -447,13 +447,7 @@ def create_table( sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, ) -> Table: - if not isinstance(schema, Schema): - import pyarrow as pa - - from pyiceberg.io.pyarrow import _ConvertToIcebergWithFreshIds, pre_order_visit_pyarrow - - if isinstance(schema, pa.Schema): - schema: Schema = pre_order_visit_pyarrow(schema, _ConvertToIcebergWithFreshIds()) # type: ignore + schema: Schema = self._convert_schema_if_needed(schema) # type: ignore namespace_and_table = self._split_identifier_for_path(identifier) request = CreateTableRequest( diff --git a/pyiceberg/catalog/sql.py b/pyiceberg/catalog/sql.py index 593a885c73..8a02b20dfc 100644 --- a/pyiceberg/catalog/sql.py +++ b/pyiceberg/catalog/sql.py @@ -169,13 +169,7 @@ def create_table( ValueError: If the identifier is invalid, or no path is given to store metadata. """ - if not isinstance(schema, Schema): - import pyarrow as pa - - from pyiceberg.io.pyarrow import _ConvertToIcebergWithFreshIds, pre_order_visit_pyarrow - - if isinstance(schema, pa.Schema): - schema: Schema = pre_order_visit_pyarrow(schema, _ConvertToIcebergWithFreshIds()) # type: ignore + schema: Schema = self._convert_schema_if_needed(schema) # type: ignore database_name, table_name = self.identifier_to_database_and_table(identifier) if not self._namespace_exists(database_name): diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index ca36c7b571..732a8f9134 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -875,6 +875,42 @@ def primitive(self, primitive: pa.DataType) -> bool: return True +def _pyarrow_datatype_to_iceberg_primitive_type(primitive: pa.DataType) -> PrimitiveType: + if pa.types.is_boolean(primitive): + return BooleanType() + elif pa.types.is_int32(primitive): + return IntegerType() + elif pa.types.is_int64(primitive): + return LongType() + elif pa.types.is_float32(primitive): + return FloatType() + elif pa.types.is_float64(primitive): + return DoubleType() + elif isinstance(primitive, pa.Decimal128Type): + primitive = cast(pa.Decimal128Type, primitive) + return DecimalType(primitive.precision, primitive.scale) + elif pa.types.is_string(primitive): + return StringType() + elif pa.types.is_date32(primitive): + return DateType() + elif isinstance(primitive, pa.Time64Type) and primitive.unit == "us": + return TimeType() + elif pa.types.is_timestamp(primitive): + primitive = cast(pa.TimestampType, primitive) + if primitive.unit == "us": + if primitive.tz == "UTC" or primitive.tz == "+00:00": + return TimestamptzType() + elif primitive.tz is None: + return TimestampType() + elif pa.types.is_binary(primitive): + return BinaryType() + elif pa.types.is_fixed_size_binary(primitive): + primitive = cast(pa.FixedSizeBinaryType, primitive) + return FixedType(primitive.byte_width) + + raise TypeError(f"Unsupported type: {primitive}") + + class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]): """Converts PyArrowSchema to Iceberg Schema. Applies the IDs from name_mapping if provided.""" @@ -927,39 +963,7 @@ def map(self, map_type: pa.MapType, key_result: IcebergType, value_result: Icebe return MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable) def primitive(self, primitive: pa.DataType) -> PrimitiveType: - if pa.types.is_boolean(primitive): - return BooleanType() - elif pa.types.is_int32(primitive): - return IntegerType() - elif pa.types.is_int64(primitive): - return LongType() - elif pa.types.is_float32(primitive): - return FloatType() - elif pa.types.is_float64(primitive): - return DoubleType() - elif isinstance(primitive, pa.Decimal128Type): - primitive = cast(pa.Decimal128Type, primitive) - return DecimalType(primitive.precision, primitive.scale) - elif pa.types.is_string(primitive): - return StringType() - elif pa.types.is_date32(primitive): - return DateType() - elif isinstance(primitive, pa.Time64Type) and primitive.unit == "us": - return TimeType() - elif pa.types.is_timestamp(primitive): - primitive = cast(pa.TimestampType, primitive) - if primitive.unit == "us": - if primitive.tz == "UTC" or primitive.tz == "+00:00": - return TimestamptzType() - elif primitive.tz is None: - return TimestampType() - elif pa.types.is_binary(primitive): - return BinaryType() - elif pa.types.is_fixed_size_binary(primitive): - primitive = cast(pa.FixedSizeBinaryType, primitive) - return FixedType(primitive.byte_width) - - raise TypeError(f"Unsupported type: {primitive}") + return _pyarrow_datatype_to_iceberg_primitive_type(primitive) def before_field(self, field: pa.Field) -> None: self._field_names.append(field.name) @@ -1021,39 +1025,7 @@ def map( return MapType(key_id, key_result(), value_id, value_result(), value_required=not value_field.nullable) def primitive(self, primitive: pa.DataType) -> PrimitiveType: - if pa.types.is_boolean(primitive): - return BooleanType() - elif pa.types.is_int32(primitive): - return IntegerType() - elif pa.types.is_int64(primitive): - return LongType() - elif pa.types.is_float32(primitive): - return FloatType() - elif pa.types.is_float64(primitive): - return DoubleType() - elif isinstance(primitive, pa.Decimal128Type): - primitive = cast(pa.Decimal128Type, primitive) - return DecimalType(primitive.precision, primitive.scale) - elif pa.types.is_string(primitive): - return StringType() - elif pa.types.is_date32(primitive): - return DateType() - elif isinstance(primitive, pa.Time64Type) and primitive.unit == "us": - return TimeType() - elif pa.types.is_timestamp(primitive): - primitive = cast(pa.TimestampType, primitive) - if primitive.unit == "us": - if primitive.tz == "UTC" or primitive.tz == "+00:00": - return TimestamptzType() - elif primitive.tz is None: - return TimestampType() - elif pa.types.is_binary(primitive): - return BinaryType() - elif pa.types.is_fixed_size_binary(primitive): - primitive = cast(pa.FixedSizeBinaryType, primitive) - return FixedType(primitive.byte_width) - - raise TypeError(f"Unsupported type: {primitive}") + return _pyarrow_datatype_to_iceberg_primitive_type(primitive) def _task_to_table( diff --git a/tests/catalog/test_base.py b/tests/catalog/test_base.py index 21091d6410..bebaf5a398 100644 --- a/tests/catalog/test_base.py +++ b/tests/catalog/test_base.py @@ -79,13 +79,7 @@ def create_table( sort_order: SortOrder = UNSORTED_SORT_ORDER, properties: Properties = EMPTY_DICT, ) -> Table: - if not isinstance(schema, Schema): - import pyarrow as pa - - from pyiceberg.io.pyarrow import _ConvertToIcebergWithFreshIds, pre_order_visit_pyarrow - - if isinstance(schema, pa.Schema): - schema: Schema = pre_order_visit_pyarrow(schema, _ConvertToIcebergWithFreshIds()) # type: ignore + schema: Schema = self._convert_schema_if_needed(schema) # type: ignore identifier = Catalog.identifier_to_tuple(identifier) namespace = Catalog.namespace_from(identifier) @@ -343,6 +337,23 @@ def test_create_table(catalog: InMemoryCatalog) -> None: assert catalog.load_table(TEST_TABLE_IDENTIFIER) == table +@pytest.mark.parametrize( + "schema_fixture,expected_fixture", + [ + ("pyarrow_schema_simple_without_ids", "iceberg_schema_simple"), + ("iceberg_schema_simple", "iceberg_schema_simple"), + ("iceberg_schema_nested", "iceberg_schema_nested"), + ("pyarrow_schema_nested_without_ids", "iceberg_schema_nested"), + ], +) +def test_convert_schema_if_needed( + schema_fixture: str, expected_fixture: str, catalog: InMemoryCatalog, request: pytest.FixtureRequest +) -> None: + schema = request.getfixturevalue(schema_fixture) + expected = request.getfixturevalue(expected_fixture) + assert expected == catalog._convert_schema_if_needed(schema) + + def test_create_table_pyarrow_schema(catalog: InMemoryCatalog, pyarrow_schema_simple_without_ids: pa.Schema) -> None: table = catalog.create_table( identifier=TEST_TABLE_IDENTIFIER, From 6ea089214867211711637219ae6bca4bc2fc790a Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Thu, 25 Jan 2024 18:48:25 +0000 Subject: [PATCH 3/9] docs --- mkdocs/docs/api.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index 9d97d4f676..fd4b9454e4 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -146,6 +146,26 @@ catalog.create_table( ) ``` +One can also create an Iceberg table using a pyarrow schema: + +```python +import pyarrow as pa + +pa.schema( + [ + pa.field("foo", pa.string(), nullable=True), + pa.field("bar", pa.int32(), nullable=False), + pa.field("baz", pa.bool_(), nullable=True), + ] +) + +catalog.create_table( + identifier="docs_example.bids", + schema=schema, + location="s3://pyiceberg", +) +``` + ## Load a table ### Catalog table From 7ba0fd12736086c27c2effc3bdad04f13ec9770a Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Sat, 27 Jan 2024 14:35:20 -0500 Subject: [PATCH 4/9] Take suggestion Co-authored-by: Kevin Liu --- mkdocs/docs/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index fd4b9454e4..c8d084584a 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -146,7 +146,7 @@ catalog.create_table( ) ``` -One can also create an Iceberg table using a pyarrow schema: +To create a table using a pyarrow schema: ```python import pyarrow as pa From 39a97a2c5295a857f3e42520088faa953583fb47 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Sat, 27 Jan 2024 14:35:49 -0500 Subject: [PATCH 5/9] Take suggestion Co-authored-by: Kevin Liu --- mkdocs/docs/api.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md index c8d084584a..9aba307b84 100644 --- a/mkdocs/docs/api.md +++ b/mkdocs/docs/api.md @@ -151,7 +151,7 @@ To create a table using a pyarrow schema: ```python import pyarrow as pa -pa.schema( +schema = pa.schema( [ pa.field("foo", pa.string(), nullable=True), pa.field("bar", pa.int32(), nullable=False), @@ -162,7 +162,6 @@ pa.schema( catalog.create_table( identifier="docs_example.bids", schema=schema, - location="s3://pyiceberg", ) ``` From 60ac8f838e632c5710ff2baf0e278b6dda956813 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Sat, 27 Jan 2024 19:52:16 +0000 Subject: [PATCH 6/9] take nit --- tests/catalog/test_base.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/catalog/test_base.py b/tests/catalog/test_base.py index bebaf5a398..8b80d67423 100644 --- a/tests/catalog/test_base.py +++ b/tests/catalog/test_base.py @@ -258,10 +258,6 @@ def catalog() -> InMemoryCatalog: return InMemoryCatalog("test.in.memory.catalog", **{"test.key": "test.value"}) -TEST_PYARROW_SCHEMA = pa.schema([ - pa.field('some_int', pa.int32(), nullable=True), - pa.field('some_string', pa.string(), nullable=False), -]) TEST_TABLE_IDENTIFIER = ("com", "organization", "department", "my_table") TEST_TABLE_NAMESPACE = ("com", "organization", "department") TEST_TABLE_NAME = "my_table" From b917afd2f30fac874c8284922eaea6f664e1c389 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Sat, 27 Jan 2024 20:14:49 +0000 Subject: [PATCH 7/9] more nit --- pyiceberg/catalog/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py index d40b38b566..f5c1079620 100644 --- a/pyiceberg/catalog/__init__.py +++ b/pyiceberg/catalog/__init__.py @@ -518,6 +518,8 @@ def _check_for_overlap(removals: Optional[Set[str]], updates: Properties) -> Non @staticmethod def _convert_schema_if_needed(schema: Union[Schema, "pa.Schema"]) -> Schema: + if isinstance(schema, Schema): + return schema try: import pyarrow as pa @@ -525,12 +527,10 @@ def _convert_schema_if_needed(schema: Union[Schema, "pa.Schema"]) -> Schema: if isinstance(schema, pa.Schema): schema: Schema = pre_order_visit_pyarrow(schema, _ConvertToIcebergWithFreshIds()) # type: ignore + return schema except ModuleNotFoundError: pass - - if not isinstance(schema, Schema): - raise ValueError(f"{type(schema)=} must be pyiceberg.schema.Schema") - return schema + raise ValueError(f"{type(schema)=}, but it must be pyiceberg.schema.Schema or pyarrow.Schema") def _resolve_table_location(self, location: Optional[str], database_name: str, table_name: str) -> str: if not location: From e4e9f9b2ce2f0a5fafa6043e85bf843b5ef289b6 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Mon, 29 Jan 2024 02:23:51 +0000 Subject: [PATCH 8/9] refactoring --- pyiceberg/catalog/__init__.py | 4 +- pyiceberg/io/pyarrow.py | 171 ++++++++----------------------- pyiceberg/schema.py | 53 +++++----- tests/catalog/test_base.py | 17 +-- tests/conftest.py | 58 +++++++++++ tests/io/test_pyarrow_visitor.py | 11 +- 6 files changed, 149 insertions(+), 165 deletions(-) diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py index f5c1079620..82b487a2bc 100644 --- a/pyiceberg/catalog/__init__.py +++ b/pyiceberg/catalog/__init__.py @@ -523,10 +523,10 @@ def _convert_schema_if_needed(schema: Union[Schema, "pa.Schema"]) -> Schema: try: import pyarrow as pa - from pyiceberg.io.pyarrow import _ConvertToIcebergWithFreshIds, pre_order_visit_pyarrow + from pyiceberg.io.pyarrow import _ConvertToIcebergWithNoIds, visit_pyarrow if isinstance(schema, pa.Schema): - schema: Schema = pre_order_visit_pyarrow(schema, _ConvertToIcebergWithFreshIds()) # type: ignore + schema: Schema = visit_pyarrow(schema, _ConvertToIcebergWithNoIds()) # type: ignore return schema except ModuleNotFoundError: pass diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 732a8f9134..530c0c85a6 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -34,7 +34,7 @@ from concurrent.futures import Future from dataclasses import dataclass from enum import Enum -from functools import lru_cache, partial, singledispatch +from functools import lru_cache, singledispatch from typing import ( TYPE_CHECKING, Any, @@ -711,60 +711,6 @@ def _(obj: pa.DataType, visitor: PyArrowSchemaVisitor[T]) -> T: return visitor.primitive(obj) -@singledispatch -def pre_order_visit_pyarrow(obj: Union[pa.DataType, pa.Schema], visitor: PreOrderPyArrowSchemaVisitor[T]) -> T: - """Apply a pyarrow schema visitor to any point within a schema. - - The function traverses the schema in pre-order fashion. - - Args: - obj (Union[pa.DataType, pa.Schema]): An instance of a Schema or an IcebergType. - visitor (PyArrowSchemaVisitor[T]): An instance of an implementation of the generic PyarrowSchemaVisitor base class. - - Raises: - NotImplementedError: If attempting to visit an unrecognized object type. - """ - raise NotImplementedError(f"Cannot visit non-type: {obj}") - - -@pre_order_visit_pyarrow.register(pa.Schema) -def _(obj: pa.Schema, visitor: PreOrderPyArrowSchemaVisitor[T]) -> T: - return visitor.schema(obj, lambda: pre_order_visit_pyarrow(pa.struct(obj), visitor)) - - -@pre_order_visit_pyarrow.register(pa.StructType) -def _(obj: pa.StructType, visitor: PreOrderPyArrowSchemaVisitor[T]) -> T: - return visitor.struct( - obj, - [ - partial( - lambda field: visitor.field(field, partial(lambda field: pre_order_visit_pyarrow(field.type, visitor), field)), - field, - ) - for field in obj - ], - ) - - -@pre_order_visit_pyarrow.register(pa.ListType) -def _(obj: pa.ListType, visitor: PreOrderPyArrowSchemaVisitor[T]) -> T: - return visitor.list(obj, lambda: pre_order_visit_pyarrow(obj.value_type, visitor)) - - -@pre_order_visit_pyarrow.register(pa.MapType) -def _(obj: pa.MapType, visitor: PreOrderPyArrowSchemaVisitor[T]) -> T: - return visitor.map( - obj, lambda: pre_order_visit_pyarrow(obj.key_type, visitor), lambda: pre_order_visit_pyarrow(obj.item_type, visitor) - ) - - -@pre_order_visit_pyarrow.register(pa.DataType) -def _(obj: pa.DataType, visitor: PreOrderPyArrowSchemaVisitor[T]) -> T: - if pa.types.is_nested(obj): - raise TypeError(f"Expected primitive type, got: {type(obj)}") - return visitor.primitive(obj) - - class PyArrowSchemaVisitor(Generic[T], ABC): def before_field(self, field: pa.Field) -> None: """Override this method to perform an action immediately before visiting a field.""" @@ -875,42 +821,6 @@ def primitive(self, primitive: pa.DataType) -> bool: return True -def _pyarrow_datatype_to_iceberg_primitive_type(primitive: pa.DataType) -> PrimitiveType: - if pa.types.is_boolean(primitive): - return BooleanType() - elif pa.types.is_int32(primitive): - return IntegerType() - elif pa.types.is_int64(primitive): - return LongType() - elif pa.types.is_float32(primitive): - return FloatType() - elif pa.types.is_float64(primitive): - return DoubleType() - elif isinstance(primitive, pa.Decimal128Type): - primitive = cast(pa.Decimal128Type, primitive) - return DecimalType(primitive.precision, primitive.scale) - elif pa.types.is_string(primitive): - return StringType() - elif pa.types.is_date32(primitive): - return DateType() - elif isinstance(primitive, pa.Time64Type) and primitive.unit == "us": - return TimeType() - elif pa.types.is_timestamp(primitive): - primitive = cast(pa.TimestampType, primitive) - if primitive.unit == "us": - if primitive.tz == "UTC" or primitive.tz == "+00:00": - return TimestamptzType() - elif primitive.tz is None: - return TimestampType() - elif pa.types.is_binary(primitive): - return BinaryType() - elif pa.types.is_fixed_size_binary(primitive): - primitive = cast(pa.FixedSizeBinaryType, primitive) - return FixedType(primitive.byte_width) - - raise TypeError(f"Unsupported type: {primitive}") - - class _ConvertToIceberg(PyArrowSchemaVisitor[Union[IcebergType, Schema]]): """Converts PyArrowSchema to Iceberg Schema. Applies the IDs from name_mapping if provided.""" @@ -963,7 +873,39 @@ def map(self, map_type: pa.MapType, key_result: IcebergType, value_result: Icebe return MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable) def primitive(self, primitive: pa.DataType) -> PrimitiveType: - return _pyarrow_datatype_to_iceberg_primitive_type(primitive) + if pa.types.is_boolean(primitive): + return BooleanType() + elif pa.types.is_int32(primitive): + return IntegerType() + elif pa.types.is_int64(primitive): + return LongType() + elif pa.types.is_float32(primitive): + return FloatType() + elif pa.types.is_float64(primitive): + return DoubleType() + elif isinstance(primitive, pa.Decimal128Type): + primitive = cast(pa.Decimal128Type, primitive) + return DecimalType(primitive.precision, primitive.scale) + elif pa.types.is_string(primitive): + return StringType() + elif pa.types.is_date32(primitive): + return DateType() + elif isinstance(primitive, pa.Time64Type) and primitive.unit == "us": + return TimeType() + elif pa.types.is_timestamp(primitive): + primitive = cast(pa.TimestampType, primitive) + if primitive.unit == "us": + if primitive.tz == "UTC" or primitive.tz == "+00:00": + return TimestamptzType() + elif primitive.tz is None: + return TimestampType() + elif pa.types.is_binary(primitive): + return BinaryType() + elif pa.types.is_fixed_size_binary(primitive): + primitive = cast(pa.FixedSizeBinaryType, primitive) + return FixedType(primitive.byte_width) + + raise TypeError(f"Unsupported type: {primitive}") def before_field(self, field: pa.Field) -> None: self._field_names.append(field.name) @@ -990,42 +932,19 @@ def after_map_value(self, element: pa.Field) -> None: self._field_names.pop() -class _ConvertToIcebergWithFreshIds(PreOrderPyArrowSchemaVisitor[Union[IcebergType, Schema]]): - """Converts PyArrowSchema to Iceberg Schema with fresh ids.""" - - def __init__(self) -> None: - self.counter = itertools.count(1) - - def _field_id(self) -> int: - return next(self.counter) - - def schema(self, schema: pa.Schema, struct_result: Callable[[], StructType]) -> Schema: - return Schema(*struct_result().fields) - - def struct(self, struct: pa.StructType, field_results: List[Callable[[], NestedField]]) -> StructType: - return StructType(*[field() for field in field_results]) - - def field(self, field: pa.Field, field_result: Callable[[], IcebergType]) -> NestedField: - field_id = self._field_id() - field_doc = doc_str.decode() if (field.metadata and (doc_str := field.metadata.get(PYARROW_FIELD_DOC_KEY))) else None - field_type = field_result() - return NestedField(field_id, field.name, field_type, required=not field.nullable, doc=field_doc) - - def list(self, list_type: pa.ListType, element_result: Callable[[], IcebergType]) -> ListType: - element_field = list_type.value_field - element_id = self._field_id() - return ListType(element_id, element_result(), element_required=not element_field.nullable) +class _ConvertToIcebergWithNoIds(_ConvertToIceberg): + """ + Converts PyArrowSchema to Iceberg Schema with all -1 ids. - def map( - self, map_type: pa.MapType, key_result: Callable[[], IcebergType], value_result: Callable[[], IcebergType] - ) -> MapType: - key_id = self._field_id() - value_field = map_type.item_field - value_id = self._field_id() - return MapType(key_id, key_result(), value_id, value_result(), value_required=not value_field.nullable) + The schema generated through this visitor should always be + used in conjunction with `new_table_metadata` function to + assign new field ids in order. This is currently used only + when creating an Iceberg Schema from a PyArrow schema when + creating a new Iceberg table. + """ - def primitive(self, primitive: pa.DataType) -> PrimitiveType: - return _pyarrow_datatype_to_iceberg_primitive_type(primitive) + def _field_id(self, field: pa.Field) -> int: + return -1 def _task_to_table( diff --git a/pyiceberg/schema.py b/pyiceberg/schema.py index b61e4678b9..6dd174f325 100644 --- a/pyiceberg/schema.py +++ b/pyiceberg/schema.py @@ -1221,50 +1221,57 @@ def assign_fresh_schema_ids(schema_or_type: Union[Schema, IcebergType], next_id: class _SetFreshIDs(PreOrderSchemaVisitor[IcebergType]): """Traverses the schema and assigns monotonically increasing ids.""" - reserved_ids: Dict[int, int] + old_id_to_new_id: Dict[int, int] def __init__(self, next_id_func: Optional[Callable[[], int]] = None) -> None: - self.reserved_ids = {} + self.old_id_to_new_id = {} counter = itertools.count(1) self.next_id_func = next_id_func if next_id_func is not None else lambda: next(counter) - def _get_and_increment(self) -> int: - return self.next_id_func() + def _get_and_increment(self, current_id: int) -> int: + new_id = self.next_id_func() + self.old_id_to_new_id[current_id] = new_id + return new_id def schema(self, schema: Schema, struct_result: Callable[[], StructType]) -> Schema: - # First we keep the original identifier_field_ids here, we remap afterwards - fields = struct_result().fields - return Schema(*fields, identifier_field_ids=[self.reserved_ids[field_id] for field_id in schema.identifier_field_ids]) + return Schema( + *struct_result().fields, + identifier_field_ids=[self.old_id_to_new_id[field_id] for field_id in schema.identifier_field_ids], + ) def struct(self, struct: StructType, field_results: List[Callable[[], IcebergType]]) -> StructType: - # assign IDs for this struct's fields first - self.reserved_ids.update({field.field_id: self._get_and_increment() for field in struct.fields}) - return StructType(*[field() for field in field_results]) + new_ids = [self._get_and_increment(field.field_id) for field in struct.fields] + new_fields = [] + for field_id, field, field_type in zip(new_ids, struct.fields, field_results): + new_fields.append( + NestedField( + field_id=field_id, + name=field.name, + field_type=field_type(), + required=field.required, + doc=field.doc, + ) + ) + return StructType(*new_fields) def field(self, field: NestedField, field_result: Callable[[], IcebergType]) -> IcebergType: - return NestedField( - field_id=self.reserved_ids[field.field_id], - name=field.name, - field_type=field_result(), - required=field.required, - doc=field.doc, - ) + return field_result() def list(self, list_type: ListType, element_result: Callable[[], IcebergType]) -> ListType: - self.reserved_ids[list_type.element_id] = self._get_and_increment() + element_id = self._get_and_increment(list_type.element_id) return ListType( - element_id=self.reserved_ids[list_type.element_id], + element_id=element_id, element=element_result(), element_required=list_type.element_required, ) def map(self, map_type: MapType, key_result: Callable[[], IcebergType], value_result: Callable[[], IcebergType]) -> MapType: - self.reserved_ids[map_type.key_id] = self._get_and_increment() - self.reserved_ids[map_type.value_id] = self._get_and_increment() + key_id = self._get_and_increment(map_type.key_id) + value_id = self._get_and_increment(map_type.value_id) return MapType( - key_id=self.reserved_ids[map_type.key_id], + key_id=key_id, key_type=key_result(), - value_id=self.reserved_ids[map_type.value_id], + value_id=value_id, value_type=value_result(), value_required=map_type.value_required, ) diff --git a/tests/catalog/test_base.py b/tests/catalog/test_base.py index 8b80d67423..d15c90fee3 100644 --- a/tests/catalog/test_base.py +++ b/tests/catalog/test_base.py @@ -26,6 +26,7 @@ import pyarrow as pa import pytest +from pytest_lazyfixture import lazy_fixture from pyiceberg.catalog import ( Catalog, @@ -334,19 +335,19 @@ def test_create_table(catalog: InMemoryCatalog) -> None: @pytest.mark.parametrize( - "schema_fixture,expected_fixture", + "schema,expected", [ - ("pyarrow_schema_simple_without_ids", "iceberg_schema_simple"), - ("iceberg_schema_simple", "iceberg_schema_simple"), - ("iceberg_schema_nested", "iceberg_schema_nested"), - ("pyarrow_schema_nested_without_ids", "iceberg_schema_nested"), + (lazy_fixture("pyarrow_schema_simple_without_ids"), lazy_fixture("iceberg_schema_simple_no_ids")), + (lazy_fixture("iceberg_schema_simple"), lazy_fixture("iceberg_schema_simple")), + (lazy_fixture("iceberg_schema_nested"), lazy_fixture("iceberg_schema_nested")), + (lazy_fixture("pyarrow_schema_nested_without_ids"), lazy_fixture("iceberg_schema_nested_no_ids")), ], ) def test_convert_schema_if_needed( - schema_fixture: str, expected_fixture: str, catalog: InMemoryCatalog, request: pytest.FixtureRequest + schema: Union[Schema, pa.Schema], + expected: Schema, + catalog: InMemoryCatalog, ) -> None: - schema = request.getfixturevalue(schema_fixture) - expected = request.getfixturevalue(expected_fixture) assert expected == catalog._convert_schema_if_needed(schema) diff --git a/tests/conftest.py b/tests/conftest.py index 1b35f082bd..d9a8dfdf07 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -322,6 +322,15 @@ def iceberg_schema_simple() -> Schema: ) +@pytest.fixture(scope="session") +def iceberg_schema_simple_no_ids() -> Schema: + return Schema( + NestedField(field_id=-1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=-1, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=-1, name="baz", field_type=BooleanType(), required=False), + ) + + @pytest.fixture(scope="session") def iceberg_table_schema_simple() -> Schema: return Schema( @@ -382,6 +391,55 @@ def iceberg_schema_nested() -> Schema: ) +@pytest.fixture(scope="session") +def iceberg_schema_nested_no_ids() -> Schema: + return Schema( + NestedField(field_id=-1, name="foo", field_type=StringType(), required=True), + NestedField(field_id=-1, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=-1, name="baz", field_type=BooleanType(), required=False), + NestedField( + field_id=-1, + name="qux", + field_type=ListType(element_id=-1, element_type=StringType(), element_required=False), + required=True, + ), + NestedField( + field_id=-1, + name="quux", + field_type=MapType( + key_id=-1, + key_type=StringType(), + value_id=-1, + value_type=MapType(key_id=-1, key_type=StringType(), value_id=-1, value_type=IntegerType(), value_required=False), + value_required=False, + ), + required=True, + ), + NestedField( + field_id=-1, + name="location", + field_type=ListType( + element_id=-1, + element_type=StructType( + NestedField(field_id=-1, name="latitude", field_type=FloatType(), required=True), + NestedField(field_id=-1, name="longitude", field_type=FloatType(), required=True), + ), + element_required=False, + ), + required=True, + ), + NestedField( + field_id=-1, + name="person", + field_type=StructType( + NestedField(field_id=-1, name="name", field_type=StringType(), required=False), + NestedField(field_id=-1, name="age", field_type=IntegerType(), required=True), + ), + required=False, + ), + ) + + @pytest.fixture(scope="session") def all_avro_types() -> Dict[str, Any]: return { diff --git a/tests/io/test_pyarrow_visitor.py b/tests/io/test_pyarrow_visitor.py index 509f474a3d..5600f67ccf 100644 --- a/tests/io/test_pyarrow_visitor.py +++ b/tests/io/test_pyarrow_visitor.py @@ -23,9 +23,8 @@ from pyiceberg.io.pyarrow import ( _ConvertToArrowSchema, _ConvertToIceberg, - _ConvertToIcebergWithFreshIds, + _ConvertToIcebergWithNoIds, _HasIds, - pre_order_visit_pyarrow, pyarrow_to_schema, schema_to_pyarrow, visit_pyarrow, @@ -480,12 +479,12 @@ def test_pyarrow_schema_to_schema_missing_ids_using_name_mapping_nested_missing_ def test_pyarrow_schema_to_schema_fresh_ids_simple_schema( - pyarrow_schema_simple_without_ids: pa.Schema, iceberg_schema_simple: Schema + pyarrow_schema_simple_without_ids: pa.Schema, iceberg_schema_simple_no_ids: Schema ) -> None: - assert pre_order_visit_pyarrow(pyarrow_schema_simple_without_ids, _ConvertToIcebergWithFreshIds()) == iceberg_schema_simple + assert visit_pyarrow(pyarrow_schema_simple_without_ids, _ConvertToIcebergWithNoIds()) == iceberg_schema_simple_no_ids def test_pyarrow_schema_to_schema_fresh_ids_nested_schema( - pyarrow_schema_nested_without_ids: pa.Schema, iceberg_schema_nested: Schema + pyarrow_schema_nested_without_ids: pa.Schema, iceberg_schema_nested_no_ids: Schema ) -> None: - assert pre_order_visit_pyarrow(pyarrow_schema_nested_without_ids, _ConvertToIcebergWithFreshIds()) == iceberg_schema_nested + assert visit_pyarrow(pyarrow_schema_nested_without_ids, _ConvertToIcebergWithNoIds()) == iceberg_schema_nested_no_ids From c40c553e2da406aeb7cddca17e75f65f1c2ab959 Mon Sep 17 00:00:00 2001 From: Sung Yun <107272191+syun64@users.noreply.github.com> Date: Mon, 29 Jan 2024 14:54:52 +0000 Subject: [PATCH 9/9] adopt review suggestions --- pyiceberg/catalog/__init__.py | 4 ++-- pyiceberg/io/pyarrow.py | 28 +--------------------------- tests/io/test_pyarrow_visitor.py | 6 +++--- 3 files changed, 6 insertions(+), 32 deletions(-) diff --git a/pyiceberg/catalog/__init__.py b/pyiceberg/catalog/__init__.py index 82b487a2bc..6e5dc2748f 100644 --- a/pyiceberg/catalog/__init__.py +++ b/pyiceberg/catalog/__init__.py @@ -523,10 +523,10 @@ def _convert_schema_if_needed(schema: Union[Schema, "pa.Schema"]) -> Schema: try: import pyarrow as pa - from pyiceberg.io.pyarrow import _ConvertToIcebergWithNoIds, visit_pyarrow + from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs, visit_pyarrow if isinstance(schema, pa.Schema): - schema: Schema = visit_pyarrow(schema, _ConvertToIcebergWithNoIds()) # type: ignore + schema: Schema = visit_pyarrow(schema, _ConvertToIcebergWithoutIDs()) # type: ignore return schema except ModuleNotFoundError: pass diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 530c0c85a6..af4bb3d083 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -761,32 +761,6 @@ def primitive(self, primitive: pa.DataType) -> T: """Visit a primitive type.""" -class PreOrderPyArrowSchemaVisitor(Generic[T], ABC): - @abstractmethod - def schema(self, schema: pa.Schema, struct_result: Callable[[], T]) -> T: - """Visit a schema.""" - - @abstractmethod - def struct(self, struct: pa.StructType, field_results: List[Callable[[], T]]) -> T: - """Visit a struct.""" - - @abstractmethod - def field(self, field: pa.Field, field_result: Callable[[], T]) -> T: - """Visit a field.""" - - @abstractmethod - def list(self, list_type: pa.ListType, element_result: Callable[[], T]) -> T: - """Visit a list.""" - - @abstractmethod - def map(self, map_type: pa.MapType, key_result: Callable[[], T], value_result: Callable[[], T]) -> T: - """Visit a map.""" - - @abstractmethod - def primitive(self, primitive: pa.DataType) -> T: - """Visit a primitive type.""" - - def _get_field_id(field: pa.Field) -> Optional[int]: return ( int(field_id_str.decode()) @@ -932,7 +906,7 @@ def after_map_value(self, element: pa.Field) -> None: self._field_names.pop() -class _ConvertToIcebergWithNoIds(_ConvertToIceberg): +class _ConvertToIcebergWithoutIDs(_ConvertToIceberg): """ Converts PyArrowSchema to Iceberg Schema with all -1 ids. diff --git a/tests/io/test_pyarrow_visitor.py b/tests/io/test_pyarrow_visitor.py index 5600f67ccf..c7f364b920 100644 --- a/tests/io/test_pyarrow_visitor.py +++ b/tests/io/test_pyarrow_visitor.py @@ -23,7 +23,7 @@ from pyiceberg.io.pyarrow import ( _ConvertToArrowSchema, _ConvertToIceberg, - _ConvertToIcebergWithNoIds, + _ConvertToIcebergWithoutIDs, _HasIds, pyarrow_to_schema, schema_to_pyarrow, @@ -481,10 +481,10 @@ def test_pyarrow_schema_to_schema_missing_ids_using_name_mapping_nested_missing_ def test_pyarrow_schema_to_schema_fresh_ids_simple_schema( pyarrow_schema_simple_without_ids: pa.Schema, iceberg_schema_simple_no_ids: Schema ) -> None: - assert visit_pyarrow(pyarrow_schema_simple_without_ids, _ConvertToIcebergWithNoIds()) == iceberg_schema_simple_no_ids + assert visit_pyarrow(pyarrow_schema_simple_without_ids, _ConvertToIcebergWithoutIDs()) == iceberg_schema_simple_no_ids def test_pyarrow_schema_to_schema_fresh_ids_nested_schema( pyarrow_schema_nested_without_ids: pa.Schema, iceberg_schema_nested_no_ids: Schema ) -> None: - assert visit_pyarrow(pyarrow_schema_nested_without_ids, _ConvertToIcebergWithNoIds()) == iceberg_schema_nested_no_ids + assert visit_pyarrow(pyarrow_schema_nested_without_ids, _ConvertToIcebergWithoutIDs()) == iceberg_schema_nested_no_ids