19 changes: 19 additions & 0 deletions mkdocs/docs/api.md
@@ -146,6 +146,25 @@ catalog.create_table(
)
```

To create a table using a pyarrow schema:

```python
import pyarrow as pa

schema = pa.schema(
[
pa.field("foo", pa.string(), nullable=True),
pa.field("bar", pa.int32(), nullable=False),
pa.field("baz", pa.bool_(), nullable=True),
]
)

catalog.create_table(
identifier="docs_example.bids",
schema=schema,
)
```
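
When a pyarrow schema is passed, the Iceberg field IDs are assigned automatically during table creation rather than taken from the pyarrow metadata. A minimal follow-up sketch, assuming the table created above exists, to inspect the resulting schema:

```python
table = catalog.load_table("docs_example.bids")
print(table.schema())  # the Iceberg schema now carries catalog-assigned field IDs
```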

## Load a table

### Catalog table
22 changes: 21 additions & 1 deletion pyiceberg/catalog/__init__.py
@@ -24,6 +24,7 @@
from dataclasses import dataclass
from enum import Enum
from typing import (
TYPE_CHECKING,
Callable,
Dict,
List,
@@ -56,6 +57,9 @@
)
from pyiceberg.utils.config import Config, merge_config

if TYPE_CHECKING:
import pyarrow as pa

logger = logging.getLogger(__name__)

_ENV_CONFIG = Config()
@@ -288,7 +292,7 @@ def _load_file_io(self, properties: Properties = EMPTY_DICT, location: Optional[
def create_table(
self,
identifier: Union[str, Identifier],
schema: Schema,
schema: Union[Schema, "pa.Schema"],
location: Optional[str] = None,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -512,6 +516,22 @@ def _check_for_overlap(removals: Optional[Set[str]], updates: Properties) -> Non
if overlap:
raise ValueError(f"Updates and deletes have an overlap: {overlap}")

@staticmethod
def _convert_schema_if_needed(schema: Union[Schema, "pa.Schema"]) -> Schema:
if isinstance(schema, Schema):
return schema
try:
import pyarrow as pa

from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs, visit_pyarrow

if isinstance(schema, pa.Schema):
schema: Schema = visit_pyarrow(schema, _ConvertToIcebergWithoutIDs()) # type: ignore
return schema
except ModuleNotFoundError:
pass
raise ValueError(f"{type(schema)=}, but it must be pyiceberg.schema.Schema or pyarrow.Schema")
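
For illustration only (not part of this diff), a minimal sketch of how the new helper behaves for both accepted input types, assuming pyarrow is installed:

```python
import pyarrow as pa

from pyiceberg.catalog import Catalog
from pyiceberg.schema import Schema

arrow_schema = pa.schema([pa.field("foo", pa.string(), nullable=True)])

# A pyarrow schema is converted via _ConvertToIcebergWithoutIDs; the resulting
# Iceberg Schema carries placeholder (-1) field IDs until new_table_metadata
# assigns real ones during table creation.
converted = Catalog._convert_schema_if_needed(arrow_schema)
assert isinstance(converted, Schema)

# An Iceberg Schema is returned unchanged.
assert Catalog._convert_schema_if_needed(converted) is converted
```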

def _resolve_table_location(self, location: Optional[str], database_name: str, table_name: str) -> str:
if not location:
return self._get_default_warehouse_location(database_name, table_name)
8 changes: 7 additions & 1 deletion pyiceberg/catalog/dynamodb.py
@@ -17,6 +17,7 @@
import uuid
from time import time
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
@@ -57,6 +58,9 @@
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.typedef import EMPTY_DICT

if TYPE_CHECKING:
import pyarrow as pa

DYNAMODB_CLIENT = "dynamodb"

DYNAMODB_COL_IDENTIFIER = "identifier"
@@ -127,7 +131,7 @@ def _dynamodb_table_exists(self) -> bool:
def create_table(
self,
identifier: Union[str, Identifier],
schema: Schema,
schema: Union[Schema, "pa.Schema"],
location: Optional[str] = None,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -152,6 +156,8 @@ def create_table(
ValueError: If the identifier is invalid, or no path is given to store metadata.

"""
schema: Schema = self._convert_schema_if_needed(schema) # type: ignore

database_name, table_name = self.identifier_to_database_and_table(identifier)

location = self._resolve_table_location(location, database_name, table_name)
8 changes: 7 additions & 1 deletion pyiceberg/catalog/glue.py
@@ -17,6 +17,7 @@


from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
@@ -88,6 +89,9 @@
UUIDType,
)

if TYPE_CHECKING:
import pyarrow as pa

# If Glue should skip archiving an old table version when creating a new version in a commit. By
# default, Glue archives all old table versions after an UpdateTable call, but Glue has a default
# max number of archived table versions (can be increased). So for streaming use case with lots
@@ -329,7 +333,7 @@ def _get_glue_table(self, database_name: str, table_name: str) -> TableTypeDef:
def create_table(
self,
identifier: Union[str, Identifier],
schema: Schema,
schema: Union[Schema, "pa.Schema"],
location: Optional[str] = None,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -354,6 +358,8 @@ def create_table(
ValueError: If the identifier is invalid, or no path is given to store metadata.

"""
schema: Schema = self._convert_schema_if_needed(schema) # type: ignore

database_name, table_name = self.identifier_to_database_and_table(identifier)

location = self._resolve_table_location(location, database_name, table_name)
9 changes: 8 additions & 1 deletion pyiceberg/catalog/hive.py
@@ -18,6 +18,7 @@
import time
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
@@ -91,6 +92,10 @@
UUIDType,
)

if TYPE_CHECKING:
import pyarrow as pa


# Replace by visitor
hive_types = {
BooleanType: "boolean",
@@ -250,7 +255,7 @@ def _convert_hive_into_iceberg(self, table: HiveTable, io: FileIO) -> Table:
def create_table(
self,
identifier: Union[str, Identifier],
schema: Schema,
schema: Union[Schema, "pa.Schema"],
location: Optional[str] = None,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -273,6 +278,8 @@ def create_table(
AlreadyExistsError: If a table with the name already exists.
ValueError: If the identifier is invalid.
"""
schema: Schema = self._convert_schema_if_needed(schema) # type: ignore

properties = {**DEFAULT_PROPERTIES, **properties}
database_name, table_name = self.identifier_to_database_and_table(identifier)
current_time_millis = int(time.time() * 1000)
6 changes: 5 additions & 1 deletion pyiceberg/catalog/noop.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from typing import (
TYPE_CHECKING,
List,
Optional,
Set,
@@ -33,12 +34,15 @@
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER
from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties

if TYPE_CHECKING:
import pyarrow as pa


class NoopCatalog(Catalog):
def create_table(
self,
identifier: Union[str, Identifier],
schema: Schema,
schema: Union[Schema, "pa.Schema"],
location: Optional[str] = None,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: SortOrder = UNSORTED_SORT_ORDER,
8 changes: 7 additions & 1 deletion pyiceberg/catalog/rest.py
@@ -16,6 +16,7 @@
# under the License.
from json import JSONDecodeError
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
@@ -68,6 +69,9 @@
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.typedef import EMPTY_DICT, UTF8, IcebergBaseModel

if TYPE_CHECKING:
import pyarrow as pa

ICEBERG_REST_SPEC_VERSION = "0.14.1"


@@ -437,12 +441,14 @@ def _response_to_table(self, identifier_tuple: Tuple[str, ...], table_response:
def create_table(
self,
identifier: Union[str, Identifier],
schema: Schema,
schema: Union[Schema, "pa.Schema"],
location: Optional[str] = None,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: SortOrder = UNSORTED_SORT_ORDER,
properties: Properties = EMPTY_DICT,
) -> Table:
schema: Schema = self._convert_schema_if_needed(schema) # type: ignore

namespace_and_table = self._split_identifier_for_path(identifier)
request = CreateTableRequest(
name=namespace_and_table["table"],
8 changes: 7 additions & 1 deletion pyiceberg/catalog/sql.py
@@ -16,6 +16,7 @@
# under the License.

from typing import (
TYPE_CHECKING,
List,
Optional,
Set,
@@ -65,6 +66,9 @@
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.typedef import EMPTY_DICT

if TYPE_CHECKING:
import pyarrow as pa


class SqlCatalogBaseTable(MappedAsDataclass, DeclarativeBase):
pass
@@ -140,7 +144,7 @@ def _convert_orm_to_iceberg(self, orm_table: IcebergTables) -> Table:
def create_table(
self,
identifier: Union[str, Identifier],
schema: Schema,
schema: Union[Schema, "pa.Schema"],
location: Optional[str] = None,
partition_spec: PartitionSpec = UNPARTITIONED_PARTITION_SPEC,
sort_order: SortOrder = UNSORTED_SORT_ORDER,
@@ -165,6 +169,8 @@ def create_table(
ValueError: If the identifier is invalid, or no path is given to store metadata.

"""
schema: Schema = self._convert_schema_if_needed(schema) # type: ignore

database_name, table_name = self.identifier_to_database_and_table(identifier)
if not self._namespace_exists(database_name):
raise NoSuchNamespaceError(f"Namespace does not exist: {database_name}")
25 changes: 20 additions & 5 deletions pyiceberg/io/pyarrow.py
@@ -26,6 +26,7 @@
from __future__ import annotations

import concurrent.futures
import itertools
import logging
import os
import re
@@ -34,7 +35,6 @@
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache, singledispatch
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
@@ -631,7 +631,7 @@ def _combine_positional_deletes(positional_deletes: List[pa.ChunkedArray], rows:
if len(positional_deletes) == 1:
all_chunks = positional_deletes[0]
else:
all_chunks = pa.chunked_array(chain(*[arr.chunks for arr in positional_deletes]))
all_chunks = pa.chunked_array(itertools.chain(*[arr.chunks for arr in positional_deletes]))
return np.setdiff1d(np.arange(rows), all_chunks, assume_unique=False)


@@ -906,6 +906,21 @@ def after_map_value(self, element: pa.Field) -> None:
self._field_names.pop()


class _ConvertToIcebergWithoutIDs(_ConvertToIceberg):
"""
Converts a PyArrow schema to an Iceberg Schema with all field IDs set to -1.

The schema generated through this visitor should always be used in
conjunction with the `new_table_metadata` function, which assigns fresh
field IDs in order. It is currently used only when creating a new
Iceberg table from a PyArrow schema.
"""

def _field_id(self, field: pa.Field) -> int:
return -1
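
A hedged sketch of the flow the docstring describes, not code from this diff: the visitor produces placeholder IDs, and `new_table_metadata` (already used by the catalog implementations when writing metadata) assigns the real ones. The keyword names are assumed from the current `new_table_metadata` signature, and the warehouse location is hypothetical.

```python
import pyarrow as pa

from pyiceberg.io.pyarrow import _ConvertToIcebergWithoutIDs, visit_pyarrow
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC
from pyiceberg.table.metadata import new_table_metadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER

arrow_schema = pa.schema(
    [
        pa.field("foo", pa.string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=False),
    ]
)

# Every field ID in this intermediate schema is -1.
schema_without_ids = visit_pyarrow(arrow_schema, _ConvertToIcebergWithoutIDs())

# new_table_metadata assigns fresh, sequential field IDs while building the metadata.
metadata = new_table_metadata(
    schema=schema_without_ids,
    partition_spec=UNPARTITIONED_PARTITION_SPEC,
    sort_order=UNSORTED_SORT_ORDER,
    location="s3://warehouse/docs_example/bids",  # hypothetical location
)
print(metadata.schemas[0])  # field IDs are now 1, 2, ...
```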


def _task_to_table(
fs: FileSystem,
task: FileScanTask,
@@ -993,7 +1008,7 @@ def _task_to_table(

def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dict[str, List[ChunkedArray]]:
deletes_per_file: Dict[str, List[ChunkedArray]] = {}
unique_deletes = set(chain.from_iterable([task.delete_files for task in tasks]))
unique_deletes = set(itertools.chain.from_iterable([task.delete_files for task in tasks]))
if len(unique_deletes) > 0:
executor = ExecutorFactory.get_or_create()
deletes_per_files: Iterator[Dict[str, ChunkedArray]] = executor.map(
@@ -1399,7 +1414,7 @@ def schema(self, schema: Schema, struct_result: Callable[[], List[StatisticsColl
def struct(
self, struct: StructType, field_results: List[Callable[[], List[StatisticsCollector]]]
) -> List[StatisticsCollector]:
return list(chain(*[result() for result in field_results]))
return list(itertools.chain(*[result() for result in field_results]))

def field(self, field: NestedField, field_result: Callable[[], List[StatisticsCollector]]) -> List[StatisticsCollector]:
self._field_id = field.field_id
@@ -1491,7 +1506,7 @@ def schema(self, schema: Schema, struct_result: Callable[[], List[ID2ParquetPath
return struct_result()

def struct(self, struct: StructType, field_results: List[Callable[[], List[ID2ParquetPath]]]) -> List[ID2ParquetPath]:
return list(chain(*[result() for result in field_results]))
return list(itertools.chain(*[result() for result in field_results]))

def field(self, field: NestedField, field_result: Callable[[], List[ID2ParquetPath]]) -> List[ID2ParquetPath]:
self._field_id = field.field_id