diff --git a/nzpy/__init__.py b/nzpy/__init__.py index 683fb94..cf7de12 100644 --- a/nzpy/__init__.py +++ b/nzpy/__init__.py @@ -43,11 +43,11 @@ def connect( user, host='localhost', unix_sock=None, port=5432, database=None, password=None, ssl=None, securityLevel= 0, timeout=None, application_name=None, - max_prepared_statements=1000, datestyle = 'ISO', logLevel = 0, tcp_keepalive=True, char_varchar_encoding='latin', logOptions=LogOptions.Inherit): + max_prepared_statements=1000, datestyle = 'ISO', logLevel = 0, tcp_keepalive=True, char_varchar_encoding='latin', logOptions=LogOptions.Inherit, pgOptions=None): return Connection( user, host, unix_sock, port, database, password, ssl, securityLevel, timeout, - application_name, max_prepared_statements, datestyle, logLevel, tcp_keepalive,char_varchar_encoding, logOptions) + application_name, max_prepared_statements, datestyle, logLevel, tcp_keepalive,char_varchar_encoding, logOptions, pgOptions) apilevel = "2.0" diff --git a/nzpy/core.py b/nzpy/core.py index d725153..f85ce0c 100644 --- a/nzpy/core.py +++ b/nzpy/core.py @@ -1168,7 +1168,7 @@ def _getError(self, error): def __init__( self, user, host, unix_sock, port, database, password, ssl, - securityLevel, timeout, application_name, max_prepared_statements, datestyle, logLevel, tcp_keepalive, char_varchar_encoding, logOptions=LogOptions.Inherit): + securityLevel, timeout, application_name, max_prepared_statements, datestyle, logLevel, tcp_keepalive, char_varchar_encoding, logOptions=LogOptions.Inherit, pgOptions=None): self._char_varchar_encoding = char_varchar_encoding self._client_encoding = "utf8" self._commands_with_count = ( @@ -1528,7 +1528,7 @@ def conn_send_query(): COPY_OUT_RESPONSE: self.handle_COPY_OUT_RESPONSE} hs = handshake.Handshake(self._usock, self._sock, ssl, self.log) - response = hs.startup(database, securityLevel, user, password) + response = hs.startup(database, securityLevel, user, password, pgOptions) if response is not False: self._flush = response.flush diff --git a/nzpy/handshake.py b/nzpy/handshake.py index b28d903..51243e7 100644 --- a/nzpy/handshake.py +++ b/nzpy/handshake.py @@ -77,7 +77,7 @@ def __init__(self, _usock, _sock, ssl, log): self.guardium_clientHostName = gethostname() self.guardium_applName = path.basename(argv[0]) - def startup(self, database, securityLevel, user, password): + def startup(self, database, securityLevel, user, password, pgOptions): #Negotiate the handshake version (connection protocol) if not self.conn_handshake_negotiate(self._sock.write, self._sock.read, self._sock.flush, self._hsVersion, self._protocol2): @@ -85,7 +85,7 @@ def startup(self, database, securityLevel, user, password): return False self.log.info("Sending handshake information to server") - if not self.conn_send_handshake_info(self._sock.write, self._sock.read, self._sock.flush, database, securityLevel, self._hsVersion, self._protocol1, self._protocol2, user): + if not self.conn_send_handshake_info(self._sock.write, self._sock.read, self._sock.flush, database, securityLevel, self._hsVersion, self._protocol1, self._protocol2, user, pgOptions): self.log.warning("Error in conn_send_handshake_info") return False @@ -154,7 +154,7 @@ def conn_handshake_negotiate(self, _write, _read, _flush, _hsVersion, _protocol2 self.log.warning("Bad protocol error") return False - def conn_send_handshake_info(self, _write, _read, _flush, _database, securityLevel, _hsVersion, _protocol1, _protocol2, user): + def conn_send_handshake_info(self, _write, _read, _flush, _database, securityLevel, _hsVersion, _protocol1, _protocol2, user, pgOptions): #We need database information at the backend in order to #select security restrictions. So always send the database first @@ -170,9 +170,9 @@ def conn_send_handshake_info(self, _write, _read, _flush, _database, securityLev return False if self._hsVersion == CP_VERSION_6 or self._hsVersion == CP_VERSION_4: - return self.conn_send_handshake_version4(self._sock.write, self._sock.read, self._sock.flush, self._protocol1, self._protocol2, self._hsVersion, user) + return self.conn_send_handshake_version4(self._sock.write, self._sock.read, self._sock.flush, self._protocol1, self._protocol2, self._hsVersion, user, pgOptions) elif self._hsVersion == CP_VERSION_5 or self._hsVersion == CP_VERSION_3 or self._hsVersion == CP_VERSION_2: - return self.conn_send_handshake_version2(self._sock.write, self._sock.read, self._sock.flush, self._protocol1, self._protocol2, self._hsVersion, user) + return self.conn_send_handshake_version2(self._sock.write, self._sock.read, self._sock.flush, self._protocol1, self._protocol2, self._hsVersion, user, pgOptions) return True @@ -309,7 +309,7 @@ def conn_secure_session(self, securityLevel): self.log.warning("Error: connection failed") return False - def conn_send_handshake_version2(self, _write, _read, _flush, _protocol1, _protocol2, _hsVersion, user): + def conn_send_handshake_version2(self, _write, _read, _flush, _protocol1, _protocol2, _hsVersion, user, pgOptions): if isinstance(user, str): user = user.encode('utf8') @@ -336,6 +336,13 @@ def conn_send_handshake_version2(self, _write, _read, _flush, _protocol1, _proto if information == HSV2_REMOTE_PID: val = bytearray( core.h_pack(information) + core.i_pack(getpid())) + information = HSV2_OPTIONS + continue + + if information == HSV2_OPTIONS: + if pgOptions is not None: + val = bytearray( core.h_pack(information)) + val.extend( pgOptions.encode('utf8') + core.NULL_BYTE) information = HSV2_CLIENT_TYPE continue @@ -364,7 +371,7 @@ def conn_send_handshake_version2(self, _write, _read, _flush, _protocol1, _proto self.log.warning("ERROR_CONN_FAIL") return False - def conn_send_handshake_version4(self, _write, _read, _flush, _protocol1, _protocol2, _hsVersion, user): + def conn_send_handshake_version4(self, _write, _read, _flush, _protocol1, _protocol2, _hsVersion, user, pgOptions): if isinstance(user, str): user = user.encode('utf8') @@ -419,6 +426,13 @@ def conn_send_handshake_version4(self, _write, _read, _flush, _protocol1, _proto if information == HSV2_REMOTE_PID: val = bytearray( core.h_pack(information) + core.i_pack(getpid())) + information = HSV2_OPTIONS + continue + + if information == HSV2_OPTIONS: + if pgOptions is not None: + val = bytearray( core.h_pack(information)) + val.extend( pgOptions.encode('utf8') + core.NULL_BYTE) information = HSV2_CLIENT_TYPE continue