From fd1d79027b4c17c3a8297712afcba94167ffdd5f Mon Sep 17 00:00:00 2001 From: Yuce Tekol Date: Thu, 21 Feb 2019 18:41:22 +0300 Subject: [PATCH] Added use_manual_address client option --- CHANGELOG.md | 1 + integration_tests/test_client_it.py | 64 ++++++++++++++++++++++++++++- pilosa/client.py | 62 ++++++++++++++++++---------- 3 files changed, 104 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d69e9e1..adf5996 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,7 @@ * **master** * Added `field.rows` and `index.group_by` calls. * Deprecated `field.range` method. Use `field.row` with `from_` and/or `to` keywords instead. + * Added `use_manual_address` client option. * **v1.2.0** (2018-12-21) * **Compatible with Pilosa 1.2** diff --git a/integration_tests/test_client_it.py b/integration_tests/test_client_it.py index bed9319..b791e51 100644 --- a/integration_tests/test_client_it.py +++ b/integration_tests/test_client_it.py @@ -32,9 +32,9 @@ # import threading import unittest +from datetime import datetime from wsgiref.simple_server import make_server from wsgiref.util import setup_testing_defaults -from datetime import datetime try: from io import StringIO @@ -321,6 +321,41 @@ def test_csv_import(self): for result in response.results: self.assertEqual([], result.row.columns) + def test_csv_import_manual_address(self): + client = self.get_client_manual_address() + text = u""" + 10, 7 + 10, 5 + 2, 3 + 7, 1 + """ + reader = csv_column_reader(StringIO(text)) + field = self.index.field("importfield") + client.ensure_field(field) + client.import_field(field, reader) + bq = self.index.batch_query( + field.row(2), + field.row(7), + field.row(10), + ) + response = client.query(bq) + target = [3, 1, 5] + self.assertEqual(3, len(response.results)) + self.assertEqual(target, [result.row.columns[0] for result in response.results]) + + # test clear import + reader = csv_column_reader(StringIO(text)) + client.import_field(field, reader, clear=True) + bq = self.index.batch_query( + field.row(2), + field.row(7), + field.row(10), + ) + response = client.query(bq) + self.assertEqual(3, len(response.results)) + for result in response.results: + self.assertEqual([], result.row.columns) + def test_csv_roaring_import(self): client = self.get_client() text = u""" @@ -419,6 +454,28 @@ def test_csv_import_row_keys(self): self.assertEqual(3, len(response.results)) self.assertEqual(target, [result.row.columns[0] for result in response.results]) + def test_csv_import_row_keys_manual_address(self): + client = self.get_client_manual_address() + text = u""" + ten, 7 + ten, 5 + two, 3 + seven, 1 + """ + reader = csv_column_reader(StringIO(text), formatfunc=csv_row_key_column_id) + field = self.index.field("importfield-keys", keys=True) + client.ensure_field(field) + client.import_field(field, reader) + bq = self.index.batch_query( + field.row("two"), + field.row("seven"), + field.row("ten"), + ) + response = client.query(bq) + target = [3, 1, 5] + self.assertEqual(3, len(response.results)) + self.assertEqual(target, [result.row.columns[0] for result in response.results]) + def test_csv_import_time_field(self): text = u""" 1,10,683793200 @@ -686,6 +743,11 @@ def get_client(cls): server_address = cls.get_server_address() return Client(server_address, tls_skip_verify=True) + @classmethod + def get_client_manual_address(cls): + server_address = cls.get_server_address() + return Client(server_address, tls_skip_verify=True, use_manual_address=True) + @classmethod def get_server_address(cls): import os diff --git a/pilosa/client.py b/pilosa/client.py index d2a5b64..42580b9 100644 --- a/pilosa/client.py +++ b/pilosa/client.py @@ -94,7 +94,7 @@ class Client(object): def __init__(self, cluster_or_uri=None, connect_timeout=30000, socket_timeout=300000, pool_size_per_route=10, pool_size_total=100, retry_count=3, - tls_skip_verify=False, tls_ca_certificate_path=""): + tls_skip_verify=False, tls_ca_certificate_path="", use_manual_address=False): """Creates a Client. :param object cluster_or_uri: A ``pilosa.Cluster`` or ``pilosa.URI` instance @@ -107,20 +107,11 @@ def __init__(self, cluster_or_uri=None, connect_timeout=30000, socket_timeout=30 :param int retry_count: Number of connection trials :param bool tls_skip_verify: Do not verify the TLS certificate of the server (Not recommended for production) :param str tls_ca_certificate_path: Server's TLS certificate (Useful when using self-signed certificates) + :param bool use_manual_address: Forces the client to use only the manual server address * See `Pilosa Python Client/Server Interaction `_. """ - if cluster_or_uri is None: - self.cluster = Cluster(URI()) - elif isinstance(cluster_or_uri, Cluster): - self.cluster = cluster_or_uri.copy() - elif isinstance(cluster_or_uri, URI): - self.cluster = Cluster(cluster_or_uri) - elif isinstance(cluster_or_uri, str): - self.cluster = Cluster(URI.address(cluster_or_uri)) - else: - raise PilosaError("Invalid cluster_or_uri: %s" % cluster_or_uri) - + self.use_manual_address = use_manual_address self.connect_timeout = connect_timeout / 1000.0 self.socket_timeout = socket_timeout / 1000.0 self.pool_size_per_route = pool_size_per_route @@ -134,6 +125,26 @@ def __init__(self, cluster_or_uri=None, connect_timeout=30000, socket_timeout=30 self.__coordinator_lock = threading.RLock() self.__coordinator_uri = None + if cluster_or_uri is None: + self.cluster = Cluster(URI()) + elif isinstance(cluster_or_uri, Cluster): + self.cluster = cluster_or_uri.copy() + elif isinstance(cluster_or_uri, URI): + if use_manual_address: + self.__coordinator_uri = cluster_or_uri + self.__current_host = cluster_or_uri + else: + self.cluster = Cluster(cluster_or_uri) + elif isinstance(cluster_or_uri, str): + uri = URI.address(cluster_or_uri) + if use_manual_address: + self.__coordinator_uri = uri + self.__current_host = uri + else: + self.cluster = Cluster(uri) + else: + raise PilosaError("Invalid cluster_or_uri: %s" % cluster_or_uri) + def query(self, query, column_attrs=False, exclude_columns=False, exclude_attrs=False, shards=None): """Runs the given query against the server with the given options. @@ -320,10 +331,13 @@ def _import_data(self, field, shard, data, fast_import, clear): # sort by row_id then by column_id if not field.index.keys: data.sort(key=lambda col: (col.row_id, col.column_id)) - if field.index.keys or field.keys: - nodes = [self._fetch_coordinator_node()] + if self.use_manual_address: + nodes = [_Node.from_uri(self.__current_host)] else: - nodes = self._fetch_fragment_nodes(field.index.name, shard) + if field.index.keys or field.keys: + nodes = [self._fetch_coordinator_node()] + else: + nodes = self._fetch_fragment_nodes(field.index.name, shard) # copy client params client_params = {} for k,v in self.__dict__.items(): @@ -405,13 +419,14 @@ def __http_request(self, method, path, data=None, headers=None, use_coordinator= response = self.__client.request(method, uri, body=data, headers=headers) break except urllib3.exceptions.MaxRetryError as e: - if use_coordinator: - self.__coordinator_uri = None - self.logger.warning("Removed coordinator %s due to %s", self.__coordinator_uri, str(e)) - else: - self.cluster.remove_host(self.__current_host) - self.logger.warning("Removed %s from the cluster due to %s", self.__current_host, str(e)) - self.__current_host = None + if not self.use_manual_address: + if use_coordinator: + self.__coordinator_uri = None + self.logger.warning("Removed coordinator %s due to %s", self.__coordinator_uri, str(e)) + else: + self.cluster.remove_host(self.__current_host) + self.logger.warning("Removed %s from the cluster due to %s", self.__current_host, str(e)) + self.__current_host = None else: raise PilosaError("Tried %s hosts, still failing" % _MAX_HOSTS) @@ -775,6 +790,9 @@ def __init__(self, scheme, host, port): self.host = host self.port = port + @classmethod + def from_uri(cls, uri): + return cls(uri.scheme, uri.host, uri.port) @property def url(self):