Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: load_method overwrite. refactor: prepare_table() is only called upon Sink setup #321

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
215 changes: 64 additions & 151 deletions target_postgres/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
VARCHAR,
TypeDecorator,
)
from singer_sdk.helpers.capabilities import TargetLoadMethods
from sshtunnel import SSHTunnelForwarder


Expand All @@ -41,6 +42,7 @@ class PostgresConnector(SQLConnector):
allow_column_rename: bool = True # Whether RENAME COLUMN is supported.
allow_column_alter: bool = False # Whether altering column types is supported.
allow_merge_upsert: bool = True # Whether MERGE UPSERT is supported.
allow_overwrite: bool = True # Whether overwrite load method is supported.
allow_temp_tables: bool = True # Whether temp tables are supported.

def __init__(self, config: dict) -> None:
Expand Down Expand Up @@ -92,6 +94,24 @@ def interpret_content_encoding(self) -> bool:
"""
return self.config.get("interpret_content_encoding", False)

def get_table_from_metadata(
self,
full_table_name: str,
connection: sa.engine.Connection
) -> sa.Table:
"""Returns an existing table object from the database

Args:
full_table_name: the fully qualified table name.

Returns:
The table object.
"""
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
meta = sa.MetaData(schema=schema_name)
meta.reflect(connection, only=[table_name])
return meta.tables[full_table_name]

def prepare_table( # type: ignore[override]
self,
full_table_name: str,
Expand All @@ -100,7 +120,7 @@ def prepare_table( # type: ignore[override]
connection: sa.engine.Connection,
partition_keys: list[str] | None = None,
as_temp_table: bool = False,
) -> sa.Table:
) -> None:
"""Adapt target table to provided schema if possible.

Args:
Expand All @@ -117,26 +137,39 @@ def prepare_table( # type: ignore[override]
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
meta = sa.MetaData(schema=schema_name)
table: sa.Table

if not self.table_exists(full_table_name=full_table_name):
table = self.create_empty_table(
table_name=table_name,
self.create_empty_table(
full_table_name=full_table_name,
meta=meta,
schema=schema,
primary_keys=primary_keys,
partition_keys=partition_keys,
as_temp_table=as_temp_table,
connection=connection,
)
return

if self.config["load_method"] == TargetLoadMethods.OVERWRITE:
self.get_table(full_table_name=full_table_name).drop(self._engine)
self.create_empty_table(
full_table_name=full_table_name,
meta=meta,
schema=schema,
primary_keys=primary_keys,
partition_keys=partition_keys,
as_temp_table=as_temp_table,
connection=connection,
)
return table
return

meta.reflect(connection, only=[table_name])
table = meta.tables[
full_table_name
] # So we don't mess up the casing of the Table reference

columns = self.get_table_columns(
schema_name=cast(str, schema_name),
table_name=table_name,
connection=connection,
full_table_name=full_table_name,
)

for property_name, property_def in schema["properties"].items():
Expand All @@ -151,8 +184,6 @@ def prepare_table( # type: ignore[override]
column_object=column_object,
)

return meta.tables[full_table_name]

def copy_table_structure(
self,
full_table_name: str,
Expand Down Expand Up @@ -331,7 +362,7 @@ def pick_best_sql_type(sql_type_array: list):

def create_empty_table( # type: ignore[override]
self,
table_name: str,
full_table_name: str,
meta: sa.MetaData,
schema: dict,
connection: sa.engine.Connection,
Expand All @@ -357,6 +388,9 @@ def create_empty_table( # type: ignore[override]
NotImplementedError: if temp tables are unsupported and as_temp_table=True.
RuntimeError: if a variant schema is passed with no properties defined.
"""

_, schema_name, table_name = self.parse_full_table_name(full_table_name)

columns: list[sa.Column] = []
primary_keys = primary_keys or []
try:
Expand Down Expand Up @@ -410,66 +444,31 @@ def prepare_column(
_, schema_name, table_name = self.parse_full_table_name(full_table_name)

column_exists = column_object is not None or self.column_exists(
full_table_name, column_name, connection=connection
full_table_name, column_name,
)

if not column_exists:
self._create_empty_column(
# We should migrate every function to use sa.Table
# instead of having to know what the function wants
table_name=table_name,
full_table_name=full_table_name,
column_name=column_name,
sql_type=sql_type,
schema_name=cast(str, schema_name),
connection=connection,
)
return

self._adapt_column_type(
schema_name=cast(str, schema_name),
table_name=table_name,
full_table_name=full_table_name,
column_name=column_name,
sql_type=sql_type,
connection=connection,
column_object=column_object,
)

def _create_empty_column( # type: ignore[override]
self,
schema_name: str,
table_name: str,
column_name: str,
sql_type: sa.types.TypeEngine,
connection: sa.engine.Connection,
) -> None:
"""Create a new column.

Args:
schema_name: The schema name.
table_name: The table name.
column_name: The name of the new column.
sql_type: SQLAlchemy type engine to be used in creating the new column.
connection: The database connection.

Raises:
NotImplementedError: if adding columns is not supported.
"""
if not self.allow_column_add:
msg = "Adding columns is not supported."
raise NotImplementedError(msg)

column_add_ddl = self.get_column_add_ddl(
schema_name=schema_name,
table_name=table_name,
column_name=column_name,
column_type=sql_type,
)
connection.execute(column_add_ddl)

def get_column_add_ddl( # type: ignore[override]
self,
table_name: str,
schema_name: str,
column_name: str,
column_type: sa.types.TypeEngine,
) -> sa.DDL:
Expand All @@ -484,6 +483,8 @@ def get_column_add_ddl( # type: ignore[override]
Returns:
A sqlalchemy DDL instance.
"""
_, schema_name, table_name = self.parse_full_table_name(table_name)

column = sa.Column(column_name, column_type)

return sa.DDL(
Expand All @@ -501,12 +502,11 @@ def get_column_add_ddl( # type: ignore[override]

def _adapt_column_type( # type: ignore[override]
self,
schema_name: str,
table_name: str,
full_table_name: str,
column_name: str,
sql_type: sa.types.TypeEngine,
connection: sa.engine.Connection,
column_object: sa.Column | None,
connection: sa.engine.Connection | None = None,
column_object: sa.Column | None = None,
) -> None:
"""Adapt table column type to support the new JSON schema type.

Expand All @@ -521,15 +521,21 @@ def _adapt_column_type( # type: ignore[override]
Raises:
NotImplementedError: if altering columns is not supported.
"""
if connection is None:
super()._adapt_column_type(
full_table_name=full_table_name,
column_name=column_name,
sql_type=sql_type,
)
return

current_type: sa.types.TypeEngine
if column_object is not None:
current_type = t.cast(sa.types.TypeEngine, column_object.type)
else:
current_type = self._get_column_type(
schema_name=schema_name,
table_name=table_name,
full_table_name=full_table_name,
column_name=column_name,
connection=connection,
)

# remove collation if present and save it
Expand All @@ -556,22 +562,20 @@ def _adapt_column_type( # type: ignore[override]
if not self.allow_column_alter:
msg = (
"Altering columns is not supported. Could not convert column "
f"'{schema_name}.{table_name}.{column_name}' from '{current_type}' to "
f"'{full_table_name}.{column_name}' from '{current_type}' to "
f"'{compatible_sql_type}'."
)
raise NotImplementedError(msg)

alter_column_ddl = self.get_column_alter_ddl(
schema_name=schema_name,
table_name=table_name,
table_name=full_table_name,
column_name=column_name,
column_type=compatible_sql_type,
)
connection.execute(alter_column_ddl)

def get_column_alter_ddl( # type: ignore[override]
self,
schema_name: str,
table_name: str,
column_name: str,
column_type: sa.types.TypeEngine,
Expand All @@ -589,6 +593,7 @@ def get_column_alter_ddl( # type: ignore[override]
Returns:
A sqlalchemy DDL instance.
"""
_, schema_name, _ = self.parse_full_table_name(table_name)
column = sa.Column(column_name, column_type)
return sa.DDL(
(
Expand Down Expand Up @@ -736,98 +741,6 @@ def catch_signal(self, signum, frame) -> None:
"""
exit(1) # Calling this to be sure atexit is called, so clean_up gets called

def _get_column_type( # type: ignore[override]
self,
schema_name: str,
table_name: str,
column_name: str,
connection: sa.engine.Connection,
) -> sa.types.TypeEngine:
"""Get the SQL type of the declared column.

Args:
schema_name: The schema name.
table_name: The table name.
column_name: The name of the column.
connection: The database connection.

Returns:
The type of the column.

Raises:
KeyError: If the provided column name does not exist.
"""
try:
column = self.get_table_columns(
schema_name=schema_name,
table_name=table_name,
connection=connection,
)[column_name]
except KeyError as ex:
msg = (
f"Column `{column_name}` does not exist in table"
"`{schema_name}.{table_name}`."
)
raise KeyError(msg) from ex

return t.cast(sa.types.TypeEngine, column.type)

def get_table_columns( # type: ignore[override]
self,
schema_name: str,
table_name: str,
connection: sa.engine.Connection,
column_names: list[str] | None = None,
) -> dict[str, sa.Column]:
"""Return a list of table columns.

Overrode to support schema_name

Args:
schema_name: schema name.
table_name: table name to get columns for.
connection: database connection.
column_names: A list of column names to filter to.

Returns:
An ordered list of column objects.
"""
inspector = sa.inspect(connection)
columns = inspector.get_columns(table_name, schema_name)

return {
col_meta["name"]: sa.Column(
col_meta["name"],
col_meta["type"],
nullable=col_meta.get("nullable", False),
)
for col_meta in columns
if not column_names
or col_meta["name"].casefold() in {col.casefold() for col in column_names}
}

def column_exists( # type: ignore[override]
self,
full_table_name: str,
column_name: str,
connection: sa.engine.Connection,
) -> bool:
"""Determine if the target column already exists.

Args:
full_table_name: the target table name.
column_name: the target column name.
connection: the database connection.

Returns:
True if table exists, False if not.
"""
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
assert schema_name is not None
assert table_name is not None
return column_name in self.get_table_columns(
schema_name=schema_name, table_name=table_name, connection=connection
)


class NOTYPE(TypeDecorator):
Expand Down