Skip to content

Commit

Permalink
Initial commit, add way to pass pgOptions when establishing connection (
Browse files Browse the repository at this point in the history
#44)

Co-authored-by: Michael DeRoy <mkderoy@ibm.com>
  • Loading branch information
mderoy and Michael DeRoy committed Mar 31, 2021
1 parent 1a148fe commit e62d781
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 11 deletions.
4 changes: 2 additions & 2 deletions nzpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@
def connect(
user, host='localhost', unix_sock=None, port=5432, database=None,
password=None, ssl=None, securityLevel= 0, timeout=None, application_name=None,
max_prepared_statements=1000, datestyle = 'ISO', logLevel = 0, tcp_keepalive=True, char_varchar_encoding='latin', logOptions=LogOptions.Inherit):
max_prepared_statements=1000, datestyle = 'ISO', logLevel = 0, tcp_keepalive=True, char_varchar_encoding='latin', logOptions=LogOptions.Inherit, pgOptions=None):

return Connection(
user, host, unix_sock, port, database, password, ssl, securityLevel, timeout,
application_name, max_prepared_statements, datestyle, logLevel, tcp_keepalive,char_varchar_encoding, logOptions)
application_name, max_prepared_statements, datestyle, logLevel, tcp_keepalive,char_varchar_encoding, logOptions, pgOptions)


apilevel = "2.0"
Expand Down
4 changes: 2 additions & 2 deletions nzpy/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,7 @@ def _getError(self, error):

def __init__(
self, user, host, unix_sock, port, database, password, ssl,
securityLevel, timeout, application_name, max_prepared_statements, datestyle, logLevel, tcp_keepalive, char_varchar_encoding, logOptions=LogOptions.Inherit):
securityLevel, timeout, application_name, max_prepared_statements, datestyle, logLevel, tcp_keepalive, char_varchar_encoding, logOptions=LogOptions.Inherit, pgOptions=None):
self._char_varchar_encoding = char_varchar_encoding
self._client_encoding = "utf8"
self._commands_with_count = (
Expand Down Expand Up @@ -1528,7 +1528,7 @@ def conn_send_query():
COPY_OUT_RESPONSE: self.handle_COPY_OUT_RESPONSE}

hs = handshake.Handshake(self._usock, self._sock, ssl, self.log)
response = hs.startup(database, securityLevel, user, password)
response = hs.startup(database, securityLevel, user, password, pgOptions)

if response is not False:
self._flush = response.flush
Expand Down
28 changes: 21 additions & 7 deletions nzpy/handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,15 @@ def __init__(self, _usock, _sock, ssl, log):
self.guardium_clientHostName = gethostname()
self.guardium_applName = path.basename(argv[0])

def startup(self, database, securityLevel, user, password):
def startup(self, database, securityLevel, user, password, pgOptions):

#Negotiate the handshake version (connection protocol)
if not self.conn_handshake_negotiate(self._sock.write, self._sock.read, self._sock.flush, self._hsVersion, self._protocol2):
self.log.info("Handshake negotiation unsuccessful")
return False

self.log.info("Sending handshake information to server")
if not self.conn_send_handshake_info(self._sock.write, self._sock.read, self._sock.flush, database, securityLevel, self._hsVersion, self._protocol1, self._protocol2, user):
if not self.conn_send_handshake_info(self._sock.write, self._sock.read, self._sock.flush, database, securityLevel, self._hsVersion, self._protocol1, self._protocol2, user, pgOptions):
self.log.warning("Error in conn_send_handshake_info")
return False

Expand Down Expand Up @@ -154,7 +154,7 @@ def conn_handshake_negotiate(self, _write, _read, _flush, _hsVersion, _protocol2
self.log.warning("Bad protocol error")
return False

def conn_send_handshake_info(self, _write, _read, _flush, _database, securityLevel, _hsVersion, _protocol1, _protocol2, user):
def conn_send_handshake_info(self, _write, _read, _flush, _database, securityLevel, _hsVersion, _protocol1, _protocol2, user, pgOptions):

#We need database information at the backend in order to
#select security restrictions. So always send the database first
Expand All @@ -170,9 +170,9 @@ def conn_send_handshake_info(self, _write, _read, _flush, _database, securityLev
return False

if self._hsVersion == CP_VERSION_6 or self._hsVersion == CP_VERSION_4:
return self.conn_send_handshake_version4(self._sock.write, self._sock.read, self._sock.flush, self._protocol1, self._protocol2, self._hsVersion, user)
return self.conn_send_handshake_version4(self._sock.write, self._sock.read, self._sock.flush, self._protocol1, self._protocol2, self._hsVersion, user, pgOptions)
elif self._hsVersion == CP_VERSION_5 or self._hsVersion == CP_VERSION_3 or self._hsVersion == CP_VERSION_2:
return self.conn_send_handshake_version2(self._sock.write, self._sock.read, self._sock.flush, self._protocol1, self._protocol2, self._hsVersion, user)
return self.conn_send_handshake_version2(self._sock.write, self._sock.read, self._sock.flush, self._protocol1, self._protocol2, self._hsVersion, user, pgOptions)

return True

Expand Down Expand Up @@ -309,7 +309,7 @@ def conn_secure_session(self, securityLevel):
self.log.warning("Error: connection failed")
return False

def conn_send_handshake_version2(self, _write, _read, _flush, _protocol1, _protocol2, _hsVersion, user):
def conn_send_handshake_version2(self, _write, _read, _flush, _protocol1, _protocol2, _hsVersion, user, pgOptions):

if isinstance(user, str):
user = user.encode('utf8')
Expand All @@ -336,6 +336,13 @@ def conn_send_handshake_version2(self, _write, _read, _flush, _protocol1, _proto

if information == HSV2_REMOTE_PID:
val = bytearray( core.h_pack(information) + core.i_pack(getpid()))
information = HSV2_OPTIONS
continue

if information == HSV2_OPTIONS:
if pgOptions is not None:
val = bytearray( core.h_pack(information))
val.extend( pgOptions.encode('utf8') + core.NULL_BYTE)
information = HSV2_CLIENT_TYPE
continue

Expand Down Expand Up @@ -364,7 +371,7 @@ def conn_send_handshake_version2(self, _write, _read, _flush, _protocol1, _proto
self.log.warning("ERROR_CONN_FAIL")
return False

def conn_send_handshake_version4(self, _write, _read, _flush, _protocol1, _protocol2, _hsVersion, user):
def conn_send_handshake_version4(self, _write, _read, _flush, _protocol1, _protocol2, _hsVersion, user, pgOptions):

if isinstance(user, str):
user = user.encode('utf8')
Expand Down Expand Up @@ -419,6 +426,13 @@ def conn_send_handshake_version4(self, _write, _read, _flush, _protocol1, _proto

if information == HSV2_REMOTE_PID:
val = bytearray( core.h_pack(information) + core.i_pack(getpid()))
information = HSV2_OPTIONS
continue

if information == HSV2_OPTIONS:
if pgOptions is not None:
val = bytearray( core.h_pack(information))
val.extend( pgOptions.encode('utf8') + core.NULL_BYTE)
information = HSV2_CLIENT_TYPE
continue

Expand Down

0 comments on commit e62d781

Please sign in to comment.