Skip to content
This repository was archived by the owner on Sep 28, 2022. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
* **master**
* Added `field.rows` and `index.group_by` calls.
* Deprecated `field.range` method. Use `field.row` with `from_` and/or `to` keywords instead.
* Added `use_manual_address` client option.

* **v1.2.0** (2018-12-21)
* **Compatible with Pilosa 1.2**
Expand Down
64 changes: 63 additions & 1 deletion integration_tests/test_client_it.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
#
import threading
import unittest
from datetime import datetime
from wsgiref.simple_server import make_server
from wsgiref.util import setup_testing_defaults
from datetime import datetime

try:
from io import StringIO
Expand Down Expand Up @@ -321,6 +321,41 @@ def test_csv_import(self):
for result in response.results:
self.assertEqual([], result.row.columns)

def test_csv_import_manual_address(self):
client = self.get_client_manual_address()
text = u"""
10, 7
10, 5
2, 3
7, 1
"""
reader = csv_column_reader(StringIO(text))
field = self.index.field("importfield")
client.ensure_field(field)
client.import_field(field, reader)
bq = self.index.batch_query(
field.row(2),
field.row(7),
field.row(10),
)
response = client.query(bq)
target = [3, 1, 5]
self.assertEqual(3, len(response.results))
self.assertEqual(target, [result.row.columns[0] for result in response.results])

# test clear import
reader = csv_column_reader(StringIO(text))
client.import_field(field, reader, clear=True)
bq = self.index.batch_query(
field.row(2),
field.row(7),
field.row(10),
)
response = client.query(bq)
self.assertEqual(3, len(response.results))
for result in response.results:
self.assertEqual([], result.row.columns)

def test_csv_roaring_import(self):
client = self.get_client()
text = u"""
Expand Down Expand Up @@ -419,6 +454,28 @@ def test_csv_import_row_keys(self):
self.assertEqual(3, len(response.results))
self.assertEqual(target, [result.row.columns[0] for result in response.results])

def test_csv_import_row_keys_manual_address(self):
client = self.get_client_manual_address()
text = u"""
ten, 7
ten, 5
two, 3
seven, 1
"""
reader = csv_column_reader(StringIO(text), formatfunc=csv_row_key_column_id)
field = self.index.field("importfield-keys", keys=True)
client.ensure_field(field)
client.import_field(field, reader)
bq = self.index.batch_query(
field.row("two"),
field.row("seven"),
field.row("ten"),
)
response = client.query(bq)
target = [3, 1, 5]
self.assertEqual(3, len(response.results))
self.assertEqual(target, [result.row.columns[0] for result in response.results])

def test_csv_import_time_field(self):
text = u"""
1,10,683793200
Expand Down Expand Up @@ -686,6 +743,11 @@ def get_client(cls):
server_address = cls.get_server_address()
return Client(server_address, tls_skip_verify=True)

@classmethod
def get_client_manual_address(cls):
server_address = cls.get_server_address()
return Client(server_address, tls_skip_verify=True, use_manual_address=True)

@classmethod
def get_server_address(cls):
import os
Expand Down
62 changes: 40 additions & 22 deletions pilosa/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class Client(object):

def __init__(self, cluster_or_uri=None, connect_timeout=30000, socket_timeout=300000,
pool_size_per_route=10, pool_size_total=100, retry_count=3,
tls_skip_verify=False, tls_ca_certificate_path=""):
tls_skip_verify=False, tls_ca_certificate_path="", use_manual_address=False):
"""Creates a Client.

:param object cluster_or_uri: A ``pilosa.Cluster`` or ``pilosa.URI` instance
Expand All @@ -107,20 +107,11 @@ def __init__(self, cluster_or_uri=None, connect_timeout=30000, socket_timeout=30
:param int retry_count: Number of connection trials
:param bool tls_skip_verify: Do not verify the TLS certificate of the server (Not recommended for production)
:param str tls_ca_certificate_path: Server's TLS certificate (Useful when using self-signed certificates)
:param bool use_manual_address: Forces the client to use only the manual server address

* See `Pilosa Python Client/Server Interaction <https://github.com/pilosa/python-pilosa/blob/master/docs/server-interaction.md>`_.
"""
if cluster_or_uri is None:
self.cluster = Cluster(URI())
elif isinstance(cluster_or_uri, Cluster):
self.cluster = cluster_or_uri.copy()
elif isinstance(cluster_or_uri, URI):
self.cluster = Cluster(cluster_or_uri)
elif isinstance(cluster_or_uri, str):
self.cluster = Cluster(URI.address(cluster_or_uri))
else:
raise PilosaError("Invalid cluster_or_uri: %s" % cluster_or_uri)

self.use_manual_address = use_manual_address
self.connect_timeout = connect_timeout / 1000.0
self.socket_timeout = socket_timeout / 1000.0
self.pool_size_per_route = pool_size_per_route
Expand All @@ -134,6 +125,26 @@ def __init__(self, cluster_or_uri=None, connect_timeout=30000, socket_timeout=30
self.__coordinator_lock = threading.RLock()
self.__coordinator_uri = None

if cluster_or_uri is None:
self.cluster = Cluster(URI())
elif isinstance(cluster_or_uri, Cluster):
self.cluster = cluster_or_uri.copy()
elif isinstance(cluster_or_uri, URI):
if use_manual_address:
self.__coordinator_uri = cluster_or_uri
self.__current_host = cluster_or_uri
else:
self.cluster = Cluster(cluster_or_uri)
elif isinstance(cluster_or_uri, str):
uri = URI.address(cluster_or_uri)
if use_manual_address:
self.__coordinator_uri = uri
self.__current_host = uri
else:
self.cluster = Cluster(uri)
else:
raise PilosaError("Invalid cluster_or_uri: %s" % cluster_or_uri)

def query(self, query, column_attrs=False, exclude_columns=False, exclude_attrs=False, shards=None):
"""Runs the given query against the server with the given options.

Expand Down Expand Up @@ -320,10 +331,13 @@ def _import_data(self, field, shard, data, fast_import, clear):
# sort by row_id then by column_id
if not field.index.keys:
data.sort(key=lambda col: (col.row_id, col.column_id))
if field.index.keys or field.keys:
nodes = [self._fetch_coordinator_node()]
if self.use_manual_address:
nodes = [_Node.from_uri(self.__current_host)]
else:
nodes = self._fetch_fragment_nodes(field.index.name, shard)
if field.index.keys or field.keys:
nodes = [self._fetch_coordinator_node()]
else:
nodes = self._fetch_fragment_nodes(field.index.name, shard)
# copy client params
client_params = {}
for k,v in self.__dict__.items():
Expand Down Expand Up @@ -405,13 +419,14 @@ def __http_request(self, method, path, data=None, headers=None, use_coordinator=
response = self.__client.request(method, uri, body=data, headers=headers)
break
except urllib3.exceptions.MaxRetryError as e:
if use_coordinator:
self.__coordinator_uri = None
self.logger.warning("Removed coordinator %s due to %s", self.__coordinator_uri, str(e))
else:
self.cluster.remove_host(self.__current_host)
self.logger.warning("Removed %s from the cluster due to %s", self.__current_host, str(e))
self.__current_host = None
if not self.use_manual_address:
if use_coordinator:
self.__coordinator_uri = None
self.logger.warning("Removed coordinator %s due to %s", self.__coordinator_uri, str(e))
else:
self.cluster.remove_host(self.__current_host)
self.logger.warning("Removed %s from the cluster due to %s", self.__current_host, str(e))
self.__current_host = None
else:
raise PilosaError("Tried %s hosts, still failing" % _MAX_HOSTS)

Expand Down Expand Up @@ -775,6 +790,9 @@ def __init__(self, scheme, host, port):
self.host = host
self.port = port

@classmethod
def from_uri(cls, uri):
return cls(uri.scheme, uri.host, uri.port)

@property
def url(self):
Expand Down