Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ build-backend = "hatchling.build"
[tool.hatch.version]
source = "vcs"

[tool.ruff]
line-length = 100

[tool.ruff.lint]
select = [
"F", # Pyflakes
Expand Down
17 changes: 4 additions & 13 deletions tap_postgres/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,7 @@ def patched_conform(elem: t.Any, property_schema: dict) -> t.Any:
epoch = datetime.datetime.fromtimestamp(0, datetime.timezone.utc)
timedelta_from_epoch = epoch + elem
if timedelta_from_epoch.tzinfo is None:
timedelta_from_epoch = timedelta_from_epoch.replace(
tzinfo=datetime.timezone.utc
)
timedelta_from_epoch = timedelta_from_epoch.replace(tzinfo=datetime.timezone.utc)
return timedelta_from_epoch.isoformat()
if isinstance(elem, datetime.time): # copied
return str(elem)
Expand Down Expand Up @@ -336,17 +334,14 @@ def get_records(self, context: Context | None) -> Iterable[dict[str, t.Any]]:
timeout = (
status_interval
- (
datetime.datetime.now()
- logical_replication_cursor.feedback_timestamp
datetime.datetime.now() - logical_replication_cursor.feedback_timestamp
).total_seconds()
)
try:
# If the timeout has passed and the cursor still has no new
# messages, the sync has completed.
if (
select.select(
[logical_replication_cursor], [], [], max(0, timeout)
)[0]
select.select([logical_replication_cursor], [], [], max(0, timeout))[0]
== []
):
break
Expand Down Expand Up @@ -383,11 +378,7 @@ def consume(self, message, cursor) -> dict | None:
for column in message_payload["identity"]:
row.update({column["name"]: self._parse_column_value(column, cursor)})
row.update(
{
"_sdc_deleted_at": datetime.datetime.utcnow().strftime(
r"%Y-%m-%dT%H:%M:%SZ"
)
}
{"_sdc_deleted_at": datetime.datetime.utcnow().strftime(r"%Y-%m-%dT%H:%M:%SZ")}
)
row.update({"_sdc_lsn": message.data_start})
elif message_payload["action"] in truncate_actions:
Expand Down
48 changes: 12 additions & 36 deletions tap_postgres/tap.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,7 @@ def __init__(
assert (
(self.config.get("sqlalchemy_url") is not None)
or (self.config.get("ssl_enable") is False)
or (
self.config.get("ssl_mode") in {"disable", "allow", "prefer", "require"}
)
or (self.config.get("ssl_mode") in {"disable", "allow", "prefer", "require"})
or (
self.config.get("ssl_mode") in {"verify-ca", "verify-full"}
and self.config.get("ssl_certificate_authority") is not None
Expand Down Expand Up @@ -141,33 +139,25 @@ def __init__(
th.StringType,
secret=True,
description=(
"Password used to authenticate. "
"Note if sqlalchemy_url is set this will be ignored."
"Password used to authenticate. Note if sqlalchemy_url is set this will be ignored."
),
),
th.Property(
"database",
th.StringType,
description=(
"Database name. "
+ "Note if sqlalchemy_url is set this will be ignored."
),
description=("Database name. " + "Note if sqlalchemy_url is set this will be ignored."),
),
th.Property(
"max_record_count",
th.IntegerType,
default=None,
description=(
"Optional. The maximum number of records to return in a single stream."
),
description=("Optional. The maximum number of records to return in a single stream."),
),
th.Property(
"sqlalchemy_url",
th.StringType,
secret=True,
description=(
"Example postgresql://[username]:[password]@localhost:5432/[db_name]"
),
description=("Example postgresql://[username]:[password]@localhost:5432/[db_name]"),
),
th.Property(
"filter_schemas",
Expand All @@ -191,9 +181,7 @@ def __init__(
th.Property(
"json_as_object",
th.BooleanType,
description=(
"Defaults to false, if true, json and jsonb fields will be Objects."
),
description=("Defaults to false, if true, json and jsonb fields will be Objects."),
default=False,
),
th.Property(
Expand Down Expand Up @@ -232,8 +220,7 @@ def __init__(
th.StringType,
required=False,
description=(
"Host of the bastion server, this is the host "
"we'll connect to via ssh"
"Host of the bastion server, this is the host we'll connect to via ssh"
),
),
th.Property(
Expand Down Expand Up @@ -262,9 +249,7 @@ def __init__(
required=False,
secret=True,
default=None,
description=(
"Private Key Password, leave None if no password is set"
),
description=("Private Key Password, leave None if no password is set"),
),
),
required=False,
Expand Down Expand Up @@ -396,9 +381,7 @@ def get_sqlalchemy_query(self, config: Mapping[str, Any]) -> dict:
if config["ssl_enable"]:
ssl_mode = config["ssl_mode"]
query.update({"sslmode": ssl_mode})
if ssl_mode in ("verify-ca", "verify-full") and config.get(
"ssl_certificate_authority"
):
if ssl_mode in ("verify-ca", "verify-full") and config.get("ssl_certificate_authority"):
query["sslrootcert"] = self.filepath_or_certificate(
value=config["ssl_certificate_authority"],
alternative_name=config["ssl_storage_directory"] + "/root.crt",
Expand Down Expand Up @@ -583,10 +566,7 @@ def catalog(self) -> Catalog:
for stream in super().catalog.streams:
stream_modified = False
new_stream = copy.deepcopy(stream)
if (
new_stream.replication_method == "LOG_BASED"
and new_stream.schema.properties
):
if new_stream.replication_method == "LOG_BASED" and new_stream.schema.properties:
for property in new_stream.schema.properties.values():
if "null" not in property.type:
if isinstance(property.type, list):
Expand Down Expand Up @@ -641,12 +621,8 @@ def discover_streams(self) -> Sequence[Stream]:
for catalog_entry in self.catalog_dict["streams"]:
if catalog_entry["replication_method"] == "LOG_BASED":
streams.append(
PostgresLogBasedStream(
self, catalog_entry, connector=self.connector
)
PostgresLogBasedStream(self, catalog_entry, connector=self.connector)
)
else:
streams.append(
PostgresStream(self, catalog_entry, connector=self.connector)
)
streams.append(PostgresStream(self, catalog_entry, connector=self.connector))
return streams
64 changes: 24 additions & 40 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import copy
import datetime
import decimal
import json
from pathlib import Path

import pytest
import sqlalchemy as sa
Expand Down Expand Up @@ -42,6 +44,15 @@
}


def _load_catalog(path: Path) -> Catalog:
    """Load a catalog from a JSON file on disk.

    Args:
        path: Location of the JSON catalog file.

    Returns:
        The ``Catalog`` produced by ``Catalog.from_dict`` from the parsed JSON.
    """
    # Decode explicitly as UTF-8: without an encoding argument, read_text()
    # falls back to the platform locale encoding, which can mis-decode
    # non-ASCII JSON content on some systems.
    return Catalog.from_dict(json.loads(path.read_text(encoding="utf-8")))


# Resolve the fixture directory relative to this module instead of the current
# working directory, so the catalogs load no matter where pytest is invoked from.
# NOTE(review): assumes this module sits next to the "resources" directory
# (i.e. tests/test_core.py alongside tests/resources) - confirm.
DATADIR = Path(__file__).resolve().parent / "resources"
FULL_CATALOG = _load_catalog(DATADIR / "data.json")
SELECTED_COLUMNS_ONLY_CATALOG = _load_catalog(DATADIR / "data_selected_columns_only.json")


def setup_test_table(table_name, sqlalchemy_url):
"""setup any state specific to the execution of the given module."""
engine = sa.create_engine(sqlalchemy_url, future=True)
Expand Down Expand Up @@ -73,9 +84,7 @@ def teardown_test_table(table_name, sqlalchemy_url):
conn.execute(sa.text(f"DROP TABLE {table_name}"))


custom_test_replication_key = suites.SingerTestSuite(
kind="tap", tests=[TapTestReplicationKey]
)
custom_test_replication_key = suites.SingerTestSuite(kind="tap", tests=[TapTestReplicationKey])

custom_test_selected_columns_only = suites.SingerTestSuite(
kind="tap", tests=[TapTestSelectedColumnsOnly]
Expand All @@ -84,14 +93,14 @@ def teardown_test_table(table_name, sqlalchemy_url):
TapPostgresTest = get_tap_test_class(
tap_class=TapPostgres,
config=SAMPLE_CONFIG,
catalog="tests/resources/data.json",
catalog=FULL_CATALOG,
custom_suites=[custom_test_replication_key],
)

TapPostgresTestNOSQLALCHEMY = get_tap_test_class(
tap_class=TapPostgres,
config=NO_SQLALCHEMY_CONFIG,
catalog="tests/resources/data.json",
catalog=FULL_CATALOG,
custom_suites=[custom_test_replication_key],
)

Expand All @@ -100,7 +109,7 @@ def teardown_test_table(table_name, sqlalchemy_url):
TapPostgresTestSelectedColumnsOnly = get_tap_test_class(
tap_class=TapPostgres,
config=SAMPLE_CONFIG,
catalog="tests/resources/data_selected_columns_only.json",
catalog=SELECTED_COLUMNS_ONLY_CATALOG,
custom_suites=[custom_test_selected_columns_only],
)

Expand Down Expand Up @@ -182,17 +191,10 @@ def test_temporal_datatypes():
)
test_runner.sync_all()
for schema_message in test_runner.schema_messages:
if (
"stream" in schema_message
and schema_message["stream"] == altered_table_name
):
assert (
schema_message["schema"]["properties"]["column_date"]["format"]
== "date"
)
if "stream" in schema_message and schema_message["stream"] == altered_table_name:
assert schema_message["schema"]["properties"]["column_date"]["format"] == "date"
assert (
schema_message["schema"]["properties"]["column_timestamp"]["format"]
== "date-time"
schema_message["schema"]["properties"]["column_timestamp"]["format"] == "date-time"
)
assert test_runner.records[altered_table_name][0] == {
"column_date": "2022-03-19",
Expand Down Expand Up @@ -246,10 +248,7 @@ def test_jsonb_json():
)
test_runner.sync_all()
for schema_message in test_runner.schema_messages:
if (
"stream" in schema_message
and schema_message["stream"] == altered_table_name
):
if "stream" in schema_message and schema_message["stream"] == altered_table_name:
assert schema_message["schema"]["properties"]["column_jsonb"] == {
"type": [
"string",
Expand Down Expand Up @@ -320,10 +319,7 @@ def test_jsonb_array():
)
test_runner.sync_all()
for schema_message in test_runner.schema_messages:
if (
"stream" in schema_message
and schema_message["stream"] == altered_table_name
):
if "stream" in schema_message and schema_message["stream"] == altered_table_name:
assert schema_message["schema"]["properties"]["column_jsonb_array"] == {
"items": {
"type": [
Expand Down Expand Up @@ -393,10 +389,7 @@ def test_json_as_object():
)
test_runner.sync_all()
for schema_message in test_runner.schema_messages:
if (
"stream" in schema_message
and schema_message["stream"] == altered_table_name
):
if "stream" in schema_message and schema_message["stream"] == altered_table_name:
assert schema_message["schema"]["properties"]["column_jsonb"] == {
"type": [
"object",
Expand Down Expand Up @@ -467,10 +460,7 @@ def test_numeric_types():
test_runner.sync_all()

for schema_message in test_runner.schema_messages:
if (
"stream" in schema_message
and schema_message["stream"] == altered_table_name
):
if "stream" in schema_message and schema_message["stream"] == altered_table_name:
props = schema_message["schema"]["properties"]
assert "number" in props["my_numeric"]["type"]
assert "number" in props["my_real"]["type"]
Expand Down Expand Up @@ -523,10 +513,7 @@ def test_hstore():
test_runner.sync_all()

for schema_message in test_runner.schema_messages:
if (
"stream" in schema_message
and schema_message["stream"] == altered_table_name
):
if "stream" in schema_message and schema_message["stream"] == altered_table_name:
assert schema_message["schema"]["properties"]["hstore_column"] == {
"type": ["object", "null"],
"additionalProperties": True,
Expand Down Expand Up @@ -640,10 +627,7 @@ def test_invalid_python_dates(): # noqa: PLR0912
test_runner.sync_all()

for schema_message in test_runner.schema_messages:
if (
"stream" in schema_message
and schema_message["stream"] == altered_table_name
):
if "stream" in schema_message and schema_message["stream"] == altered_table_name:
assert schema_message["schema"]["properties"]["date"]["type"] == [
"string",
"null",
Expand Down
4 changes: 1 addition & 3 deletions tests/test_log_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,7 @@ def test_string_array_column():
as arrays (ex: "text[]") LOG_BASED replication can properly decode their value.
"""
table_name = "test_array_column"
engine = sa.create_engine(
"postgresql://postgres:postgres@localhost:5434/postgres", future=True
)
engine = sa.create_engine("postgresql://postgres:postgres@localhost:5434/postgres", future=True)

metadata_obj = sa.MetaData()
table = sa.Table(
Expand Down
4 changes: 1 addition & 3 deletions tests/test_slot_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ def test_multiple_slots(default_config: dict):
tap_1 = TapPostgres(config=config_1, setup_mapper=False)
tap_2 = TapPostgres(config=config_2, setup_mapper=False)

assert (
tap_1.config["replication_slot_name"] != tap_2.config["replication_slot_name"]
)
assert tap_1.config["replication_slot_name"] != tap_2.config["replication_slot_name"]
assert tap_1.config["replication_slot_name"] == "slot_1"
assert tap_2.config["replication_slot_name"] == "slot_2"

Expand Down