Support CreateTableTransaction for HiveCatalog #683

Merged · 15 commits · May 31, 2024
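For orientation before the diff: this PR lets HiveCatalog commit a staged (not-yet-created) table, which is what backs Catalog.create_table_transaction. A minimal sketch of the user-facing flow, mirroring the integration test at the bottom of this page; the catalog name and thrift URI are placeholder assumptions, not values from the PR:

import pyarrow as pa

from pyiceberg.catalog.hive import HiveCatalog

# Placeholder connection details; substitute a real metastore URI.
catalog = HiveCatalog("hive", uri="thrift://localhost:9083")

df = pa.table({"n": pa.array([1, 2, 3], type=pa.int64())})

# The table only becomes visible in the metastore when the transaction
# commits; before this PR, HiveCatalog could not commit such a staged table.
with catalog.create_table_transaction(
    identifier="default.example", schema=df.schema, properties={"format-version": "2"}
) as txn:
    txn.append(df)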
2 changes: 1 addition & 1 deletion pyiceberg/catalog/__init__.py
@@ -720,7 +720,7 @@ def _create_staged_table(
metadata = new_table_metadata(
location=location, schema=schema, partition_spec=partition_spec, sort_order=sort_order, properties=properties
)
-        io = load_file_io(properties=self.properties, location=metadata_location)
+        io = self._load_file_io(properties=properties, location=metadata_location)
return StagedTable(
identifier=(self.name, database_name, table_name),
metadata=metadata,
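The one-line fix above is easy to miss: _create_staged_table previously built its FileIO from self.properties alone (via the module-level load_file_io), silently ignoring the properties of the table being created. The instance helper merges both, with table-level entries taking precedence. A sketch of that precedence, under the assumption that _load_file_io layers table properties over catalog properties (the helper's body does not appear in this diff):

from pyiceberg.io import FileIO, load_file_io
from pyiceberg.typedef import Properties


# Illustrative stand-in for Catalog._load_file_io; this free function just
# makes the assumed merge order visible.
def resolve_file_io(catalog_properties: Properties, table_properties: Properties, location: str) -> FileIO:
    # Catalog-level defaults first, table-level overrides on top, so a
    # per-table setting such as a custom "s3.endpoint" takes effect.
    return load_file_io({**catalog_properties, **table_properties}, location)

The same merge appears verbatim in the old hive.py code below, load_file_io({**self.properties, **properties}, ...), which is what the refactor centralizes.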
168 changes: 101 additions & 67 deletions pyiceberg/catalog/hive.py
@@ -67,15 +67,21 @@
NamespaceNotEmptyError,
NoSuchIcebergTableError,
NoSuchNamespaceError,
+    NoSuchPropertyException,
NoSuchTableError,
TableAlreadyExistsError,
)
from pyiceberg.io import FileIO, load_file_io
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
from pyiceberg.schema import Schema, SchemaVisitor, visit
from pyiceberg.serializers import FromInputFile
-from pyiceberg.table import CommitTableRequest, CommitTableResponse, PropertyUtil, Table, TableProperties, update_table_metadata
-from pyiceberg.table.metadata import new_table_metadata
+from pyiceberg.table import (
+    CommitTableRequest,
+    CommitTableResponse,
+    PropertyUtil,
+    StagedTable,
+    Table,
+    TableProperties,
+)
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.typedef import EMPTY_DICT, Identifier, Properties
from pyiceberg.types import (
Expand Down Expand Up @@ -240,10 +246,12 @@ def __init__(self, name: str, **properties: str):
super().__init__(name, **properties)
self._client = _HiveClient(properties["uri"], properties.get("ugi"))

-    def _convert_hive_into_iceberg(self, table: HiveTable, io: FileIO) -> Table:
+    def _convert_hive_into_iceberg(self, table: HiveTable) -> Table:
properties: Dict[str, str] = table.parameters
if TABLE_TYPE not in properties:
raise NoSuchTableError(f"Property table_type missing, could not determine type: {table.dbName}.{table.tableName}")
raise NoSuchPropertyException(
f"Property table_type missing, could not determine type: {table.dbName}.{table.tableName}"
)

table_type = properties[TABLE_TYPE]
if table_type.lower() != ICEBERG:
@@ -254,8 +262,9 @@ def _convert_hive_into_iceberg(self, table: HiveTable, io: FileIO) -> Table:
if prop_metadata_location := properties.get(METADATA_LOCATION):
metadata_location = prop_metadata_location
else:
raise NoSuchTableError(f"Table property {METADATA_LOCATION} is missing")
raise NoSuchPropertyException(f"Table property {METADATA_LOCATION} is missing")

+        io = self._load_file_io(location=metadata_location)
file = io.new_input(metadata_location)
metadata = FromInputFile.table_metadata(file)
return Table(
@@ -266,6 +275,38 @@ def _convert_hive_into_iceberg(self, table: HiveTable, io: FileIO) -> Table:
catalog=self,
)

+    def _convert_iceberg_into_hive(self, table: Table) -> HiveTable:
+        identifier_tuple = self.identifier_to_tuple_without_catalog(table.identifier)
+        database_name, table_name = self.identifier_to_database_and_table(identifier_tuple, NoSuchTableError)
+        current_time_millis = int(time.time() * 1000)
+
+        return HiveTable(
+            dbName=database_name,
+            tableName=table_name,
+            owner=table.properties[OWNER] if table.properties and OWNER in table.properties else getpass.getuser(),
+            createTime=current_time_millis // 1000,
+            lastAccessTime=current_time_millis // 1000,
+            sd=_construct_hive_storage_descriptor(
+                table.schema(),
+                table.location(),
+                PropertyUtil.property_as_bool(self.properties, HIVE2_COMPATIBLE, HIVE2_COMPATIBLE_DEFAULT),
+            ),
+            tableType=EXTERNAL_TABLE,
+            parameters=_construct_parameters(table.metadata_location),
+        )
+
+    def _create_hive_table(self, open_client: Client, hive_table: HiveTable) -> None:
+        try:
+            open_client.create_table(hive_table)
+        except AlreadyExistsException as e:
+            raise TableAlreadyExistsError(f"Table {hive_table.dbName}.{hive_table.tableName} already exists") from e
+
+    def _get_hive_table(self, open_client: Client, database_name: str, table_name: str) -> HiveTable:
+        try:
+            return open_client.get_table(dbname=database_name, tbl_name=table_name)
+        except NoSuchObjectException as e:
+            raise NoSuchTableError(f"Table does not exist: {table_name}") from e

def create_table(
self,
identifier: Union[str, Identifier],
@@ -292,45 +333,25 @@ def create_table(
AlreadyExistsError: If a table with the name already exists.
ValueError: If the identifier is invalid.
"""
-        schema: Schema = self._convert_schema_if_needed(schema)  # type: ignore
-
        properties = {**DEFAULT_PROPERTIES, **properties}
-        database_name, table_name = self.identifier_to_database_and_table(identifier)
-        current_time_millis = int(time.time() * 1000)
-
-        location = self._resolve_table_location(location, database_name, table_name)
-
-        metadata_location = self._get_metadata_location(location=location)
-        metadata = new_table_metadata(
-            location=location,
+        staged_table = self._create_staged_table(
+            identifier=identifier,
            schema=schema,
+            location=location,
            partition_spec=partition_spec,
            sort_order=sort_order,
            properties=properties,
        )
-        io = load_file_io({**self.properties, **properties}, location=location)
-        self._write_metadata(metadata, io, metadata_location)
+        database_name, table_name = self.identifier_to_database_and_table(identifier)

-        tbl = HiveTable(
-            dbName=database_name,
-            tableName=table_name,
-            owner=properties[OWNER] if properties and OWNER in properties else getpass.getuser(),
-            createTime=current_time_millis // 1000,
-            lastAccessTime=current_time_millis // 1000,
-            sd=_construct_hive_storage_descriptor(
-                schema, location, PropertyUtil.property_as_bool(self.properties, HIVE2_COMPATIBLE, HIVE2_COMPATIBLE_DEFAULT)
-            ),
-            tableType=EXTERNAL_TABLE,
-            parameters=_construct_parameters(metadata_location),
-        )
-        try:
-            with self._client as open_client:
-                open_client.create_table(tbl)
-                hive_table = open_client.get_table(dbname=database_name, tbl_name=table_name)
-        except AlreadyExistsException as e:
-            raise TableAlreadyExistsError(f"Table {database_name}.{table_name} already exists") from e
+        self._write_metadata(staged_table.metadata, staged_table.io, staged_table.metadata_location)
+        tbl = self._convert_iceberg_into_hive(staged_table)

-        return self._convert_hive_into_iceberg(hive_table, io)
+        with self._client as open_client:
+            self._create_hive_table(open_client, tbl)
+            hive_table = open_client.get_table(dbname=database_name, tbl_name=table_name)
+
+        return self._convert_hive_into_iceberg(hive_table)

def register_table(self, identifier: Union[str, Identifier], metadata_location: str) -> Table:
"""Register a new table using existing metadata.
@@ -382,34 +403,50 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableResponse:
if lock.state != LockState.ACQUIRED:
raise CommitFailedException(f"Failed to acquire lock for {table_request.identifier}, state: {lock.state}")

-                hive_table = open_client.get_table(dbname=database_name, tbl_name=table_name)
-                io = load_file_io({**self.properties, **hive_table.parameters}, hive_table.sd.location)
-                current_table = self._convert_hive_into_iceberg(hive_table, io)
-
-                base_metadata = current_table.metadata
-                for requirement in table_request.requirements:
-                    requirement.validate(base_metadata)
-
-                updated_metadata = update_table_metadata(base_metadata, table_request.updates)
-                if updated_metadata == base_metadata:
+                hive_table: Optional[HiveTable]
+                current_table: Optional[Table]
+                try:
+                    hive_table = self._get_hive_table(open_client, database_name, table_name)
+                    current_table = self._convert_hive_into_iceberg(hive_table)
+                except NoSuchTableError:
+                    hive_table = None
+                    current_table = None
+
+                updated_staged_table = self._update_and_stage_table(current_table, table_request)
+                if current_table and updated_staged_table.metadata == current_table.metadata:
                    # no changes, do nothing
-                    return CommitTableResponse(metadata=base_metadata, metadata_location=current_table.metadata_location)
-
-                # write new metadata
-                new_metadata_version = self._parse_metadata_version(current_table.metadata_location) + 1
-                new_metadata_location = self._get_metadata_location(current_table.metadata.location, new_metadata_version)
-                self._write_metadata(updated_metadata, current_table.io, new_metadata_location)
-
-                hive_table.parameters = _construct_parameters(
-                    metadata_location=new_metadata_location, previous_metadata_location=current_table.metadata_location
-                )
-                open_client.alter_table(dbname=database_name, tbl_name=table_name, new_tbl=hive_table)
-            except NoSuchObjectException as e:
-                raise NoSuchTableError(f"Table does not exist: {table_name}") from e
+                    return CommitTableResponse(metadata=current_table.metadata, metadata_location=current_table.metadata_location)
+                self._write_metadata(
+                    metadata=updated_staged_table.metadata,
+                    io=updated_staged_table.io,
+                    metadata_path=updated_staged_table.metadata_location,
+                )
+
+                if hive_table and current_table:
+                    # Table exists, update it.
+                    hive_table.parameters = _construct_parameters(
+                        metadata_location=updated_staged_table.metadata_location,
+                        previous_metadata_location=current_table.metadata_location,
+                    )
+                    open_client.alter_table(dbname=database_name, tbl_name=table_name, new_tbl=hive_table)
+                else:
+                    # Table does not exist, create it.
+                    hive_table = self._convert_iceberg_into_hive(
+                        StagedTable(
+                            identifier=(self.name, database_name, table_name),
+                            metadata=updated_staged_table.metadata,
+                            metadata_location=updated_staged_table.metadata_location,
+                            io=updated_staged_table.io,
+                            catalog=self,
+                        )
+                    )
+                    self._create_hive_table(open_client, hive_table)
            finally:
                open_client.unlock(UnlockRequest(lockid=lock.lockid))

-        return CommitTableResponse(metadata=updated_metadata, metadata_location=new_metadata_location)
+        return CommitTableResponse(
+            metadata=updated_staged_table.metadata, metadata_location=updated_staged_table.metadata_location
+        )

def load_table(self, identifier: Union[str, Identifier]) -> Table:
"""Load the table's metadata and return the table instance.
@@ -428,14 +465,11 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table:
"""
identifier_tuple = self.identifier_to_tuple_without_catalog(identifier)
database_name, table_name = self.identifier_to_database_and_table(identifier_tuple, NoSuchTableError)
-        try:
-            with self._client as open_client:
-                hive_table = open_client.get_table(dbname=database_name, tbl_name=table_name)
-        except NoSuchObjectException as e:
-            raise NoSuchTableError(f"Table does not exists: {table_name}") from e
-
-        io = load_file_io({**self.properties, **hive_table.parameters}, hive_table.sd.location)
-        return self._convert_hive_into_iceberg(hive_table, io)
+        with self._client as open_client:
+            hive_table = self._get_hive_table(open_client, database_name, table_name)
+
+        return self._convert_hive_into_iceberg(hive_table)

def drop_table(self, identifier: Union[str, Identifier]) -> None:
"""Drop a table.
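The rewritten _commit_table above now has two exit paths: alter the existing Hive table, or, when the load raised NoSuchTableError, create it from a StagedTable, which is the case a CreateTableTransaction commit lands in. Both paths rely on the Hive table's parameters carrying the Iceberg pointers that _convert_hive_into_iceberg later reads back (and raises NoSuchPropertyException over when absent). A hedged sketch of what _construct_parameters plausibly returns, inferred from how this diff consumes it; the real helper may set additional keys:

from typing import Dict, Optional


# Hypothetical reconstruction for illustration; only the call sites appear
# in this diff.
def construct_parameters(metadata_location: str, previous_metadata_location: Optional[str] = None) -> Dict[str, str]:
    # "table_type" and "metadata_location" are exactly the properties that
    # _convert_hive_into_iceberg validates when loading the table back.
    parameters = {
        "EXTERNAL": "TRUE",  # assumption: pairs with tableType=EXTERNAL_TABLE
        "table_type": "ICEBERG",
        "metadata_location": metadata_location,
    }
    if previous_metadata_location:
        parameters["previous_metadata_location"] = previous_metadata_location
    return parameters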
16 changes: 9 additions & 7 deletions tests/integration/test_writes/test_writes.py
@@ -34,6 +34,7 @@

from pyiceberg.catalog import Catalog
from pyiceberg.catalog.hive import HiveCatalog
+from pyiceberg.catalog.rest import RestCatalog
from pyiceberg.catalog.sql import SqlCatalog
from pyiceberg.exceptions import NoSuchTableError
from pyiceberg.table import TableProperties, _dataframe_to_data_files
@@ -609,17 +610,18 @@ def test_write_and_evolve(session_catalog: Catalog, format_version: int) -> None:


@pytest.mark.integration
@pytest.mark.parametrize("format_version", [2])
def test_create_table_transaction(session_catalog: Catalog, format_version: int) -> None:
if format_version == 1:
@pytest.mark.parametrize("format_version", [1, 2])
@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('session_catalog_hive'), pytest.lazy_fixture('session_catalog')])
(Inline review comment on the parametrize line above.)

Contributor:
We're mixing double and single quotes 😱 Shouldn't Ruff fix this?

Contributor Author (@HonahX, May 30, 2024):
Our current setting is

[format]
quote-style = "preserve"

which leaves quotes unchanged.

I changed it to "double" and found that 39 files need to be re-formatted. I've created a separate PR to give it a try: #781

I will re-visit this issue tomorrow if the test fails :)

Contributor:
Now that #781 is in, we need to rebase every PR anyway :D
+def test_create_table_transaction(catalog: Catalog, format_version: int) -> None:
+    if format_version == 1 and isinstance(catalog, RestCatalog):
pytest.skip(
"There is a bug in the REST catalog (maybe server side) that prevents create and commit a staged version 1 table"
)

identifier = f"default.arrow_create_table_transaction{format_version}"
identifier = f"default.arrow_create_table_transaction_{catalog.name}_{format_version}"

try:
-        session_catalog.drop_table(identifier=identifier)
+        catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

@@ -641,7 +643,7 @@ def test_create_table_transaction(session_catalog: Catalog, format_version: int)
]),
)

-    with session_catalog.create_table_transaction(
+    with catalog.create_table_transaction(
identifier=identifier, schema=pa_table.schema, properties={"format-version": str(format_version)}
) as txn:
with txn.update_snapshot().fast_append() as snapshot_update:
@@ -657,7 +659,7 @@ def test_create_table_transaction(session_catalog: Catalog, format_version: int)
):
snapshot_update.append_data_file(data_file)

-    tbl = session_catalog.load_table(identifier=identifier)
+    tbl = catalog.load_table(identifier=identifier)
assert tbl.format_version == format_version
assert len(tbl.scan().to_arrow()) == 6

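As a closing note on the test changes: stacking the two parametrize decorators runs the transaction test against both the Hive and REST session catalogs for both format versions. A self-contained illustration of the pytest.lazy_fixture pattern used above (requires the pytest-lazy-fixture plugin; the fixtures here are invented for the demo, not the PR's real conftest fixtures):

import pytest


@pytest.fixture
def hive_like() -> str:
    return "hive"


@pytest.fixture
def rest_like() -> str:
    return "rest"


@pytest.mark.parametrize("catalog_kind", [pytest.lazy_fixture("hive_like"), pytest.lazy_fixture("rest_like")])
@pytest.mark.parametrize("format_version", [1, 2])
def test_expands_to_four_cases(catalog_kind: str, format_version: int) -> None:
    # Each case resolves the named fixture lazily at collection time, so the
    # parametrize list can reference fixtures defined elsewhere in conftest.
    assert catalog_kind in {"hive", "rest"} and format_version in {1, 2}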