Skip to content

Commit

Permalink
Initial commit, add way to pass pgOptions when establishing connection
Browse files Browse the repository at this point in the history
  • Loading branch information
Michael DeRoy committed Mar 30, 2021
1 parent 1a148fe commit 74f5089
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
4 changes: 2 additions & 2 deletions nzpy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,7 @@ def _getError(self, error):

def __init__(
self, user, host, unix_sock, port, database, password, ssl,
securityLevel, timeout, application_name, max_prepared_statements, datestyle, logLevel, tcp_keepalive, char_varchar_encoding, logOptions=LogOptions.Inherit):
securityLevel, timeout, application_name, max_prepared_statements, datestyle, logLevel, tcp_keepalive, char_varchar_encoding, logOptions=LogOptions.Inherit, pgOptions=None):
self._char_varchar_encoding = char_varchar_encoding
self._client_encoding = "utf8"
self._commands_with_count = (
Expand Down Expand Up @@ -1528,7 +1528,7 @@ def conn_send_query():
COPY_OUT_RESPONSE: self.handle_COPY_OUT_RESPONSE}

hs = handshake.Handshake(self._usock, self._sock, ssl, self.log)
response = hs.startup(database, securityLevel, user, password)
response = hs.startup(database, securityLevel, user, password, pgOptions)

if response is not False:
self._flush = response.flush
Expand Down
28 changes: 21 additions & 7 deletions nzpy/handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,15 @@ def __init__(self, _usock, _sock, ssl, log):
self.guardium_clientHostName = gethostname()
self.guardium_applName = path.basename(argv[0])

def startup(self, database, securityLevel, user, password):
def startup(self, database, securityLevel, user, password, pgOptions):

#Negotiate the handshake version (connection protocol)
if not self.conn_handshake_negotiate(self._sock.write, self._sock.read, self._sock.flush, self._hsVersion, self._protocol2):
self.log.info("Handshake negotiation unsuccessful")
return False

self.log.info("Sending handshake information to server")
if not self.conn_send_handshake_info(self._sock.write, self._sock.read, self._sock.flush, database, securityLevel, self._hsVersion, self._protocol1, self._protocol2, user):
if not self.conn_send_handshake_info(self._sock.write, self._sock.read, self._sock.flush, database, securityLevel, self._hsVersion, self._protocol1, self._protocol2, user, pgOptions):
self.log.warning("Error in conn_send_handshake_info")
return False

Expand Down Expand Up @@ -154,7 +154,7 @@ def conn_handshake_negotiate(self, _write, _read, _flush, _hsVersion, _protocol2
self.log.warning("Bad protocol error")
return False

def conn_send_handshake_info(self, _write, _read, _flush, _database, securityLevel, _hsVersion, _protocol1, _protocol2, user):
def conn_send_handshake_info(self, _write, _read, _flush, _database, securityLevel, _hsVersion, _protocol1, _protocol2, user, pgOptions):

#We need database information at the backend in order to
#select security restrictions. So always send the database first
Expand All @@ -170,9 +170,9 @@ def conn_send_handshake_info(self, _write, _read, _flush, _database, securityLev
return False

if self._hsVersion == CP_VERSION_6 or self._hsVersion == CP_VERSION_4:
return self.conn_send_handshake_version4(self._sock.write, self._sock.read, self._sock.flush, self._protocol1, self._protocol2, self._hsVersion, user)
return self.conn_send_handshake_version4(self._sock.write, self._sock.read, self._sock.flush, self._protocol1, self._protocol2, self._hsVersion, user, pgOptions)
elif self._hsVersion == CP_VERSION_5 or self._hsVersion == CP_VERSION_3 or self._hsVersion == CP_VERSION_2:
return self.conn_send_handshake_version2(self._sock.write, self._sock.read, self._sock.flush, self._protocol1, self._protocol2, self._hsVersion, user)
return self.conn_send_handshake_version2(self._sock.write, self._sock.read, self._sock.flush, self._protocol1, self._protocol2, self._hsVersion, user, pgOptions)

return True

Expand Down Expand Up @@ -309,7 +309,7 @@ def conn_secure_session(self, securityLevel):
self.log.warning("Error: connection failed")
return False

def conn_send_handshake_version2(self, _write, _read, _flush, _protocol1, _protocol2, _hsVersion, user):
def conn_send_handshake_version2(self, _write, _read, _flush, _protocol1, _protocol2, _hsVersion, user, pgOptions):

if isinstance(user, str):
user = user.encode('utf8')
Expand All @@ -336,6 +336,13 @@ def conn_send_handshake_version2(self, _write, _read, _flush, _protocol1, _proto

if information == HSV2_REMOTE_PID:
val = bytearray( core.h_pack(information) + core.i_pack(getpid()))
information = HSV2_OPTIONS
continue

if information == HSV2_OPTIONS:
if pgOptions is not None:
val = bytearray( core.h_pack(information))
val.extend( pgOptions.encode('utf8') + core.NULL_BYTE)
information = HSV2_CLIENT_TYPE
continue

Expand Down Expand Up @@ -364,7 +371,7 @@ def conn_send_handshake_version2(self, _write, _read, _flush, _protocol1, _proto
self.log.warning("ERROR_CONN_FAIL")
return False

def conn_send_handshake_version4(self, _write, _read, _flush, _protocol1, _protocol2, _hsVersion, user):
def conn_send_handshake_version4(self, _write, _read, _flush, _protocol1, _protocol2, _hsVersion, user, pgOptions):

if isinstance(user, str):
user = user.encode('utf8')
Expand Down Expand Up @@ -419,6 +426,13 @@ def conn_send_handshake_version4(self, _write, _read, _flush, _protocol1, _proto

if information == HSV2_REMOTE_PID:
val = bytearray( core.h_pack(information) + core.i_pack(getpid()))
information = HSV2_OPTIONS
continue

if information == HSV2_OPTIONS:
if pgOptions is not None:
val = bytearray( core.h_pack(information))
val.extend( pgOptions.encode('utf8') + core.NULL_BYTE)
information = HSV2_CLIENT_TYPE
continue

Expand Down

0 comments on commit 74f5089

Please sign in to comment.