diff --git a/pyproject.toml b/pyproject.toml index f65f05b..6e4f67d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,7 +31,6 @@ dependencies = [ "psycopg[binary]==3.2.10", "psycopg2-binary==2.9.10", "sqlalchemy~=2.0", - "sshtunnel==0.4.0", "singer-sdk[sql]~=0.52.0", ] @@ -83,10 +82,6 @@ warn_redundant_casts = true warn_unused_configs = true warn_unused_ignores = true -[[tool.mypy.overrides]] -module = ["sshtunnel"] -ignore_missing_imports = true - [build-system] requires = [ "hatchling==1.27.0", diff --git a/target_postgres/connector.py b/target_postgres/connector.py index 3ff4921..a427cb9 100644 --- a/target_postgres/connector.py +++ b/target_postgres/connector.py @@ -7,9 +7,11 @@ import itertools import math import signal +import socket import sys +import threading import typing as t -from contextlib import contextmanager +from contextlib import contextmanager, suppress from functools import cached_property from os import chmod, path from typing import cast @@ -40,12 +42,169 @@ TIMESTAMP, TypeDecorator, ) -from sshtunnel import SSHTunnelForwarder if t.TYPE_CHECKING: from singer_sdk.sql.connector import FullyQualifiedName +class SSHTunnelForwarder: + """SSH Tunnel forwarder using paramiko. + + This class provides SSH tunnel functionality similar to sshtunnel package, + but implemented directly with paramiko. + """ + + def __init__( + self, + ssh_address_or_host: tuple[str, int], + ssh_username: str, + ssh_pkey: paramiko.PKey, + ssh_private_key_password: str | None, + remote_bind_address: tuple[str, int], + ) -> None: + """Initialize SSH tunnel forwarder. + + Args: + ssh_address_or_host: Tuple of (ssh_host, ssh_port) + ssh_username: SSH username + ssh_pkey: Paramiko private key object + ssh_private_key_password: Private key password (optional) + remote_bind_address: Tuple of (remote_host, remote_port) + """ + self.ssh_host, self.ssh_port = ssh_address_or_host + self.ssh_username = ssh_username + self.ssh_pkey = ssh_pkey + self.ssh_private_key_password = ssh_private_key_password + self.remote_bind_host, self.remote_bind_port = remote_bind_address + + self.ssh_client: paramiko.SSHClient | None = None + self.local_bind_host = "127.0.0.1" + self.local_bind_port: int | None = None + self._server_socket: socket.socket | None = None + self._thread: threading.Thread | None = None + self._stop_event = threading.Event() + + def start(self) -> None: + """Start the SSH tunnel.""" + # Create SSH client + self.ssh_client = paramiko.SSHClient() + self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + # Connect to SSH server + self.ssh_client.connect( + hostname=self.ssh_host, + port=self.ssh_port, + username=self.ssh_username, + pkey=self.ssh_pkey, + passphrase=self.ssh_private_key_password, + ) + + # Create local socket for port forwarding + self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + self._server_socket.bind((self.local_bind_host, 0)) + self._server_socket.listen(5) + + # Get the dynamically assigned local port + self.local_bind_port = self._server_socket.getsockname()[1] + + # Start forwarding thread + self._thread = threading.Thread(target=self._forward_tunnel, daemon=True) + self._thread.start() + + def _forward_tunnel(self) -> None: + """Forward connections through the SSH tunnel.""" + if self._server_socket is None or self.ssh_client is None: + return + + while not self._stop_event.is_set(): + try: + # Set timeout so we can check stop event periodically + self._server_socket.settimeout(1.0) + try: + local_socket, _ = self._server_socket.accept() + except TimeoutError: + continue + + # Create channel through SSH tunnel + transport = self.ssh_client.get_transport() + if transport is None: + local_socket.close() + continue + + channel = transport.open_channel( + "direct-tcpip", + (self.remote_bind_host, self.remote_bind_port), + local_socket.getpeername(), + ) + + # Start forwarding data between local socket and channel + threading.Thread( + target=self._forward_data, + args=(local_socket, channel), + daemon=True, + ).start() + except OSError: + if not self._stop_event.is_set(): + break + + def _forward_data( + self, local_socket: socket.socket, channel: paramiko.Channel + ) -> None: + """Forward data between local socket and SSH channel. + + Args: + local_socket: Local socket + channel: SSH channel + """ + try: + + def forward_local_to_remote(): + while True: + data = local_socket.recv(4096) + if len(data) == 0: + break + channel.send(data) + channel.close() + + def forward_remote_to_local(): + while True: + data = channel.recv(4096) + if len(data) == 0: + break + local_socket.send(data) + local_socket.close() + + # Start both forwarding directions + t1 = threading.Thread(target=forward_local_to_remote, daemon=True) + t2 = threading.Thread(target=forward_remote_to_local, daemon=True) + t1.start() + t2.start() + t1.join() + t2.join() + except OSError: + pass + finally: + with suppress(OSError): + local_socket.close() + with suppress(OSError): + channel.close() + + def stop(self) -> None: + """Stop the SSH tunnel.""" + self._stop_event.set() + + if self._server_socket: + with suppress(OSError): + self._server_socket.close() + + if self._thread and self._thread.is_alive(): + self._thread.join(timeout=2.0) + + if self.ssh_client: + self.ssh_client.close() + + class JSONSchemaToPostgres(JSONSchemaToSQL): """Convert JSON Schema types to Postgres types.""" @@ -88,10 +247,13 @@ def __init__(self, config: dict) -> None: """ url: URL = make_url(self.get_sqlalchemy_url(config=config)) ssh_config = config.get("ssh_tunnel", {}) - self.ssh_tunnel: SSHTunnelForwarder + self.ssh_tunnel: SSHTunnelForwarder | None = None if ssh_config.get("enable", False): # Return a new URL with SSH tunnel parameters + if url.host is None or url.port is None: + msg = "Database host and port must be specified for SSH tunnel" + raise ValueError(msg) self.ssh_tunnel = SSHTunnelForwarder( ssh_address_or_host=(ssh_config["host"], ssh_config["port"]), ssh_username=ssh_config["username"], diff --git a/uv.lock b/uv.lock index d2e7cea..06a616e 100644 --- a/uv.lock +++ b/uv.lock @@ -157,7 +157,7 @@ name = "cffi" version = "2.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "pycparser", marker = "implementation_name != 'PyPy'" }, + { name = "pycparser", marker = "implementation_name != 'PyPy' and platform_python_implementation != 'PyPy'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/eb/56/b1ba7935a17738ae8453301356628e8147c79dbb825bcbc73dc7401f9846/cffi-2.0.0.tar.gz", hash = "sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529", size = 523588, upload-time = "2025-09-08T23:24:04.541Z" } wheels = [ @@ -507,7 +507,7 @@ name = "importlib-metadata" version = "8.7.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "zipp" }, + { name = "zipp", marker = "python_full_version < '3.14' or platform_python_implementation == 'PyPy'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/76/66/650a33bd90f786193e4de4b3ad86ea60b53c89b669a5c7be931fac31cdb0/importlib_metadata-8.7.0.tar.gz", hash = "sha256:d13b81ad223b890aa16c5471f2ac3056cf76c5f10f82d6f9292f0b415f389000", size = 56641, upload-time = "2025-04-27T15:29:01.736Z" } wheels = [ @@ -589,7 +589,6 @@ dependencies = [ { name = "psycopg2-binary" }, { name = "singer-sdk", extra = ["sql"] }, { name = "sqlalchemy" }, - { name = "sshtunnel" }, ] [package.optional-dependencies] @@ -633,7 +632,6 @@ requires-dist = [ { name = "psycopg2-binary", specifier = "==2.9.10" }, { name = "singer-sdk", extras = ["sql"], specifier = "~=0.52.0" }, { name = "sqlalchemy", specifier = "~=2.0" }, - { name = "sshtunnel", specifier = "==0.4.0" }, ] provides-extras = ["faker"] @@ -1441,18 +1439,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9c/5e/6a29fa884d9fb7ddadf6b69490a9d45fded3b38541713010dad16b77d015/sqlalchemy-2.0.44-py3-none-any.whl", hash = "sha256:19de7ca1246fbef9f9d1bff8f1ab25641569df226364a0e40457dc5457c54b05", size = 1928718, upload-time = "2025-10-10T15:29:45.32Z" }, ] -[[package]] -name = "sshtunnel" -version = "0.4.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "paramiko" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/8d/ad/4c587adf79865be268ee0b6bd52cfaa7a75d827a23ced072dc5ab554b4af/sshtunnel-0.4.0.tar.gz", hash = "sha256:e7cb0ea774db81bf91844db22de72a40aae8f7b0f9bb9ba0f666d474ef6bf9fc", size = 62716, upload-time = "2021-01-11T13:26:32.975Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/58/13/8476c4328dcadfe26f8bd7f3a1a03bf9ddb890a7e7b692f54a179bc525bf/sshtunnel-0.4.0-py2.py3-none-any.whl", hash = "sha256:98e54c26f726ab8bd42b47a3a21fca5c3e60f58956f0f70de2fb8ab0046d0606", size = 24729, upload-time = "2021-01-11T13:26:29.969Z" }, -] - [[package]] name = "tap-countries" version = "0.1.0"