diff --git a/pyproject.toml b/pyproject.toml index f55466c9..8b33f772 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,9 @@ build-backend = "hatchling.build" [tool.hatch.version] source = "vcs" +[tool.ruff] +line-length = 100 + [tool.ruff.lint] select = [ "F", # Pyflakes diff --git a/tap_postgres/client.py b/tap_postgres/client.py index 16d04ab4..f6a216fc 100644 --- a/tap_postgres/client.py +++ b/tap_postgres/client.py @@ -130,9 +130,7 @@ def patched_conform(elem: t.Any, property_schema: dict) -> t.Any: epoch = datetime.datetime.fromtimestamp(0, datetime.timezone.utc) timedelta_from_epoch = epoch + elem if timedelta_from_epoch.tzinfo is None: - timedelta_from_epoch = timedelta_from_epoch.replace( - tzinfo=datetime.timezone.utc - ) + timedelta_from_epoch = timedelta_from_epoch.replace(tzinfo=datetime.timezone.utc) return timedelta_from_epoch.isoformat() if isinstance(elem, datetime.time): # copied return str(elem) @@ -336,17 +334,14 @@ def get_records(self, context: Context | None) -> Iterable[dict[str, t.Any]]: timeout = ( status_interval - ( - datetime.datetime.now() - - logical_replication_cursor.feedback_timestamp + datetime.datetime.now() - logical_replication_cursor.feedback_timestamp ).total_seconds() ) try: # If the timeout has passed and the cursor still has no new # messages, the sync has completed. if ( - select.select( - [logical_replication_cursor], [], [], max(0, timeout) - )[0] + select.select([logical_replication_cursor], [], [], max(0, timeout))[0] == [] ): break @@ -383,11 +378,7 @@ def consume(self, message, cursor) -> dict | None: for column in message_payload["identity"]: row.update({column["name"]: self._parse_column_value(column, cursor)}) row.update( - { - "_sdc_deleted_at": datetime.datetime.utcnow().strftime( - r"%Y-%m-%dT%H:%M:%SZ" - ) - } + {"_sdc_deleted_at": datetime.datetime.utcnow().strftime(r"%Y-%m-%dT%H:%M:%SZ")} ) row.update({"_sdc_lsn": message.data_start}) elif message_payload["action"] in truncate_actions: diff --git a/tap_postgres/tap.py b/tap_postgres/tap.py index 4dea1188..c89ddd91 100644 --- a/tap_postgres/tap.py +++ b/tap_postgres/tap.py @@ -71,9 +71,7 @@ def __init__( assert ( (self.config.get("sqlalchemy_url") is not None) or (self.config.get("ssl_enable") is False) - or ( - self.config.get("ssl_mode") in {"disable", "allow", "prefer", "require"} - ) + or (self.config.get("ssl_mode") in {"disable", "allow", "prefer", "require"}) or ( self.config.get("ssl_mode") in {"verify-ca", "verify-full"} and self.config.get("ssl_certificate_authority") is not None @@ -141,33 +139,25 @@ def __init__( th.StringType, secret=True, description=( - "Password used to authenticate. " - "Note if sqlalchemy_url is set this will be ignored." + "Password used to authenticate. Note if sqlalchemy_url is set this will be ignored." ), ), th.Property( "database", th.StringType, - description=( - "Database name. " - + "Note if sqlalchemy_url is set this will be ignored." - ), + description=("Database name. " + "Note if sqlalchemy_url is set this will be ignored."), ), th.Property( "max_record_count", th.IntegerType, default=None, - description=( - "Optional. The maximum number of records to return in a single stream." - ), + description=("Optional. The maximum number of records to return in a single stream."), ), th.Property( "sqlalchemy_url", th.StringType, secret=True, - description=( - "Example postgresql://[username]:[password]@localhost:5432/[db_name]" - ), + description=("Example postgresql://[username]:[password]@localhost:5432/[db_name]"), ), th.Property( "filter_schemas", @@ -191,9 +181,7 @@ def __init__( th.Property( "json_as_object", th.BooleanType, - description=( - "Defaults to false, if true, json and jsonb fields will be Objects." - ), + description=("Defaults to false, if true, json and jsonb fields will be Objects."), default=False, ), th.Property( @@ -232,8 +220,7 @@ def __init__( th.StringType, required=False, description=( - "Host of the bastion server, this is the host " - "we'll connect to via ssh" + "Host of the bastion server, this is the host we'll connect to via ssh" ), ), th.Property( @@ -262,9 +249,7 @@ def __init__( required=False, secret=True, default=None, - description=( - "Private Key Password, leave None if no password is set" - ), + description=("Private Key Password, leave None if no password is set"), ), ), required=False, @@ -396,9 +381,7 @@ def get_sqlalchemy_query(self, config: Mapping[str, Any]) -> dict: if config["ssl_enable"]: ssl_mode = config["ssl_mode"] query.update({"sslmode": ssl_mode}) - if ssl_mode in ("verify-ca", "verify-full") and config.get( - "ssl_certificate_authority" - ): + if ssl_mode in ("verify-ca", "verify-full") and config.get("ssl_certificate_authority"): query["sslrootcert"] = self.filepath_or_certificate( value=config["ssl_certificate_authority"], alternative_name=config["ssl_storage_directory"] + "/root.crt", @@ -583,10 +566,7 @@ def catalog(self) -> Catalog: for stream in super().catalog.streams: stream_modified = False new_stream = copy.deepcopy(stream) - if ( - new_stream.replication_method == "LOG_BASED" - and new_stream.schema.properties - ): + if new_stream.replication_method == "LOG_BASED" and new_stream.schema.properties: for property in new_stream.schema.properties.values(): if "null" not in property.type: if isinstance(property.type, list): @@ -641,12 +621,8 @@ def discover_streams(self) -> Sequence[Stream]: for catalog_entry in self.catalog_dict["streams"]: if catalog_entry["replication_method"] == "LOG_BASED": streams.append( - PostgresLogBasedStream( - self, catalog_entry, connector=self.connector - ) + PostgresLogBasedStream(self, catalog_entry, connector=self.connector) ) else: - streams.append( - PostgresStream(self, catalog_entry, connector=self.connector) - ) + streams.append(PostgresStream(self, catalog_entry, connector=self.connector)) return streams diff --git a/tests/test_core.py b/tests/test_core.py index e07eca65..5be84938 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,6 +1,8 @@ import copy import datetime import decimal +import json +from pathlib import Path import pytest import sqlalchemy as sa @@ -42,6 +44,15 @@ } +def _load_catalog(path: Path) -> Catalog: + return Catalog.from_dict(json.loads(path.read_text())) + + +DATADIR = Path("tests/resources") +FULL_CATALOG = _load_catalog(DATADIR / "data.json") +SELECTED_COLUMNS_ONLY_CATALOG = _load_catalog(DATADIR / "data_selected_columns_only.json") + + def setup_test_table(table_name, sqlalchemy_url): """setup any state specific to the execution of the given module.""" engine = sa.create_engine(sqlalchemy_url, future=True) @@ -73,9 +84,7 @@ def teardown_test_table(table_name, sqlalchemy_url): conn.execute(sa.text(f"DROP TABLE {table_name}")) -custom_test_replication_key = suites.SingerTestSuite( - kind="tap", tests=[TapTestReplicationKey] -) +custom_test_replication_key = suites.SingerTestSuite(kind="tap", tests=[TapTestReplicationKey]) custom_test_selected_columns_only = suites.SingerTestSuite( kind="tap", tests=[TapTestSelectedColumnsOnly] @@ -84,14 +93,14 @@ def teardown_test_table(table_name, sqlalchemy_url): TapPostgresTest = get_tap_test_class( tap_class=TapPostgres, config=SAMPLE_CONFIG, - catalog="tests/resources/data.json", + catalog=FULL_CATALOG, custom_suites=[custom_test_replication_key], ) TapPostgresTestNOSQLALCHEMY = get_tap_test_class( tap_class=TapPostgres, config=NO_SQLALCHEMY_CONFIG, - catalog="tests/resources/data.json", + catalog=FULL_CATALOG, custom_suites=[custom_test_replication_key], ) @@ -100,7 +109,7 @@ def teardown_test_table(table_name, sqlalchemy_url): TapPostgresTestSelectedColumnsOnly = get_tap_test_class( tap_class=TapPostgres, config=SAMPLE_CONFIG, - catalog="tests/resources/data_selected_columns_only.json", + catalog=SELECTED_COLUMNS_ONLY_CATALOG, custom_suites=[custom_test_selected_columns_only], ) @@ -182,17 +191,10 @@ def test_temporal_datatypes(): ) test_runner.sync_all() for schema_message in test_runner.schema_messages: - if ( - "stream" in schema_message - and schema_message["stream"] == altered_table_name - ): - assert ( - schema_message["schema"]["properties"]["column_date"]["format"] - == "date" - ) + if "stream" in schema_message and schema_message["stream"] == altered_table_name: + assert schema_message["schema"]["properties"]["column_date"]["format"] == "date" assert ( - schema_message["schema"]["properties"]["column_timestamp"]["format"] - == "date-time" + schema_message["schema"]["properties"]["column_timestamp"]["format"] == "date-time" ) assert test_runner.records[altered_table_name][0] == { "column_date": "2022-03-19", @@ -246,10 +248,7 @@ def test_jsonb_json(): ) test_runner.sync_all() for schema_message in test_runner.schema_messages: - if ( - "stream" in schema_message - and schema_message["stream"] == altered_table_name - ): + if "stream" in schema_message and schema_message["stream"] == altered_table_name: assert schema_message["schema"]["properties"]["column_jsonb"] == { "type": [ "string", @@ -320,10 +319,7 @@ def test_jsonb_array(): ) test_runner.sync_all() for schema_message in test_runner.schema_messages: - if ( - "stream" in schema_message - and schema_message["stream"] == altered_table_name - ): + if "stream" in schema_message and schema_message["stream"] == altered_table_name: assert schema_message["schema"]["properties"]["column_jsonb_array"] == { "items": { "type": [ @@ -393,10 +389,7 @@ def test_json_as_object(): ) test_runner.sync_all() for schema_message in test_runner.schema_messages: - if ( - "stream" in schema_message - and schema_message["stream"] == altered_table_name - ): + if "stream" in schema_message and schema_message["stream"] == altered_table_name: assert schema_message["schema"]["properties"]["column_jsonb"] == { "type": [ "object", @@ -467,10 +460,7 @@ def test_numeric_types(): test_runner.sync_all() for schema_message in test_runner.schema_messages: - if ( - "stream" in schema_message - and schema_message["stream"] == altered_table_name - ): + if "stream" in schema_message and schema_message["stream"] == altered_table_name: props = schema_message["schema"]["properties"] assert "number" in props["my_numeric"]["type"] assert "number" in props["my_real"]["type"] @@ -523,10 +513,7 @@ def test_hstore(): test_runner.sync_all() for schema_message in test_runner.schema_messages: - if ( - "stream" in schema_message - and schema_message["stream"] == altered_table_name - ): + if "stream" in schema_message and schema_message["stream"] == altered_table_name: assert schema_message["schema"]["properties"]["hstore_column"] == { "type": ["object", "null"], "additionalProperties": True, @@ -640,10 +627,7 @@ def test_invalid_python_dates(): # noqa: PLR0912 test_runner.sync_all() for schema_message in test_runner.schema_messages: - if ( - "stream" in schema_message - and schema_message["stream"] == altered_table_name - ): + if "stream" in schema_message and schema_message["stream"] == altered_table_name: assert schema_message["schema"]["properties"]["date"]["type"] == [ "string", "null", diff --git a/tests/test_log_based.py b/tests/test_log_based.py index a16ee347..0af7d995 100644 --- a/tests/test_log_based.py +++ b/tests/test_log_based.py @@ -70,9 +70,7 @@ def test_string_array_column(): as arrays (ex: "text[]") LOG_BASED replication can properly decode their value. """ table_name = "test_array_column" - engine = sa.create_engine( - "postgresql://postgres:postgres@localhost:5434/postgres", future=True - ) + engine = sa.create_engine("postgresql://postgres:postgres@localhost:5434/postgres", future=True) metadata_obj = sa.MetaData() table = sa.Table( diff --git a/tests/test_slot_name.py b/tests/test_slot_name.py index 8ef7f5df..b1909a5b 100644 --- a/tests/test_slot_name.py +++ b/tests/test_slot_name.py @@ -37,9 +37,7 @@ def test_multiple_slots(default_config: dict): tap_1 = TapPostgres(config=config_1, setup_mapper=False) tap_2 = TapPostgres(config=config_2, setup_mapper=False) - assert ( - tap_1.config["replication_slot_name"] != tap_2.config["replication_slot_name"] - ) + assert tap_1.config["replication_slot_name"] != tap_2.config["replication_slot_name"] assert tap_1.config["replication_slot_name"] == "slot_1" assert tap_2.config["replication_slot_name"] == "slot_2"