Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ dependencies = [
"psycopg[binary]==3.2.10",
"psycopg2-binary==2.9.10",
"sqlalchemy~=2.0",
"sshtunnel==0.4.0",
"singer-sdk[sql]~=0.52.0",
]

Expand Down Expand Up @@ -83,10 +82,6 @@ warn_redundant_casts = true
warn_unused_configs = true
warn_unused_ignores = true

[[tool.mypy.overrides]]
module = ["sshtunnel"]
ignore_missing_imports = true

[build-system]
requires = [
"hatchling==1.27.0",
Expand Down
168 changes: 165 additions & 3 deletions target_postgres/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import itertools
import math
import signal
import socket
import sys
import threading
import typing as t
from contextlib import contextmanager
from contextlib import contextmanager, suppress
from functools import cached_property
from os import chmod, path
from typing import cast
Expand Down Expand Up @@ -40,12 +42,169 @@
TIMESTAMP,
TypeDecorator,
)
from sshtunnel import SSHTunnelForwarder

if t.TYPE_CHECKING:
from singer_sdk.sql.connector import FullyQualifiedName


class SSHTunnelForwarder:
"""SSH Tunnel forwarder using paramiko.

This class provides SSH tunnel functionality similar to sshtunnel package,
but implemented directly with paramiko.
"""

def __init__(
self,
ssh_address_or_host: tuple[str, int],
ssh_username: str,
ssh_pkey: paramiko.PKey,
ssh_private_key_password: str | None,
remote_bind_address: tuple[str, int],
) -> None:
"""Initialize SSH tunnel forwarder.

Args:
ssh_address_or_host: Tuple of (ssh_host, ssh_port)
ssh_username: SSH username
ssh_pkey: Paramiko private key object
ssh_private_key_password: Private key password (optional)
remote_bind_address: Tuple of (remote_host, remote_port)
"""
self.ssh_host, self.ssh_port = ssh_address_or_host
self.ssh_username = ssh_username
self.ssh_pkey = ssh_pkey
self.ssh_private_key_password = ssh_private_key_password
self.remote_bind_host, self.remote_bind_port = remote_bind_address

self.ssh_client: paramiko.SSHClient | None = None
self.local_bind_host = "127.0.0.1"
self.local_bind_port: int | None = None
self._server_socket: socket.socket | None = None
self._thread: threading.Thread | None = None
self._stop_event = threading.Event()

def start(self) -> None:
"""Start the SSH tunnel."""
# Create SSH client
self.ssh_client = paramiko.SSHClient()
self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

# Connect to SSH server
self.ssh_client.connect(
hostname=self.ssh_host,
port=self.ssh_port,
username=self.ssh_username,
pkey=self.ssh_pkey,
passphrase=self.ssh_private_key_password,
)

# Create local socket for port forwarding
self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self._server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self._server_socket.bind((self.local_bind_host, 0))
self._server_socket.listen(5)

# Get the dynamically assigned local port
self.local_bind_port = self._server_socket.getsockname()[1]

# Start forwarding thread
self._thread = threading.Thread(target=self._forward_tunnel, daemon=True)
self._thread.start()

def _forward_tunnel(self) -> None:
"""Forward connections through the SSH tunnel."""
if self._server_socket is None or self.ssh_client is None:
return

while not self._stop_event.is_set():
try:
# Set timeout so we can check stop event periodically
self._server_socket.settimeout(1.0)
try:
local_socket, _ = self._server_socket.accept()
except TimeoutError:
continue

# Create channel through SSH tunnel
transport = self.ssh_client.get_transport()
if transport is None:
local_socket.close()
continue

channel = transport.open_channel(
"direct-tcpip",
(self.remote_bind_host, self.remote_bind_port),
local_socket.getpeername(),
)

# Start forwarding data between local socket and channel
threading.Thread(
target=self._forward_data,
args=(local_socket, channel),
daemon=True,
).start()
except OSError:
if not self._stop_event.is_set():
break

def _forward_data(
self, local_socket: socket.socket, channel: paramiko.Channel
) -> None:
"""Forward data between local socket and SSH channel.

Args:
local_socket: Local socket
channel: SSH channel
"""
try:

def forward_local_to_remote():
while True:
data = local_socket.recv(4096)
if len(data) == 0:
break
channel.send(data)
channel.close()

def forward_remote_to_local():
while True:
data = channel.recv(4096)
if len(data) == 0:
break
local_socket.send(data)
local_socket.close()

# Start both forwarding directions
t1 = threading.Thread(target=forward_local_to_remote, daemon=True)
t2 = threading.Thread(target=forward_remote_to_local, daemon=True)
t1.start()
t2.start()
t1.join()
t2.join()
except OSError:
pass
finally:
with suppress(OSError):
local_socket.close()
with suppress(OSError):
channel.close()

def stop(self) -> None:
"""Stop the SSH tunnel."""
self._stop_event.set()

if self._server_socket:
with suppress(OSError):
self._server_socket.close()

if self._thread and self._thread.is_alive():
self._thread.join(timeout=2.0)

if self.ssh_client:
self.ssh_client.close()


class JSONSchemaToPostgres(JSONSchemaToSQL):
"""Convert JSON Schema types to Postgres types."""

Expand Down Expand Up @@ -88,10 +247,13 @@ def __init__(self, config: dict) -> None:
"""
url: URL = make_url(self.get_sqlalchemy_url(config=config))
ssh_config = config.get("ssh_tunnel", {})
self.ssh_tunnel: SSHTunnelForwarder
self.ssh_tunnel: SSHTunnelForwarder | None = None

if ssh_config.get("enable", False):
# Return a new URL with SSH tunnel parameters
if url.host is None or url.port is None:
msg = "Database host and port must be specified for SSH tunnel"
raise ValueError(msg)
self.ssh_tunnel = SSHTunnelForwarder(
ssh_address_or_host=(ssh_config["host"], ssh_config["port"]),
ssh_username=ssh_config["username"],
Expand Down
18 changes: 2 additions & 16 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.