diff --git a/elasticmagic/agg.py b/elasticmagic/agg.py index cc02cfe..873739e 100644 --- a/elasticmagic/agg.py +++ b/elasticmagic/agg.py @@ -37,8 +37,8 @@ def clone(self): return self.__class__(aggs=self._aggs, **self.params) @_with_clone - def aggregations(self, *args, **aggs): - if args == (None,): + def aggregations(self, aggs): + if aggs is None: self._aggs = Params() else: self._aggs = Params(dict(self._aggs), **aggs) diff --git a/elasticmagic/attribute.py b/elasticmagic/attribute.py index 3d2f5e4..60f9982 100644 --- a/elasticmagic/attribute.py +++ b/elasticmagic/attribute.py @@ -73,10 +73,7 @@ def __get__(self, obj, type=None): if obj is None: return self - dict_ = obj.__dict__ - if self._attr in obj.__dict__: - return dict_[self._attr] - dict_[self._attr] = None + obj.__dict__[self._attr] = None return None def _collect_doc_classes(self): diff --git a/elasticmagic/cluster.py b/elasticmagic/cluster.py index cf3fc10..ddbe4b0 100644 --- a/elasticmagic/cluster.py +++ b/elasticmagic/cluster.py @@ -1,5 +1,6 @@ from .util import clean_params from .index import Index +from .search import SearchQuery from .result import Result, BulkResult from .document import DynamicDocument from .expression import Params diff --git a/elasticmagic/result.py b/elasticmagic/result.py index c216459..dcbeae5 100644 --- a/elasticmagic/result.py +++ b/elasticmagic/result.py @@ -1,5 +1,7 @@ +import collections + from .agg import BucketAgg -from .document import Document +from .document import DynamicDocument class Result(object): @@ -7,22 +9,24 @@ def __init__(self, raw_result, aggregations=None, doc_cls=None, instance_mapper=None): self.raw = raw_result self._query_aggs = aggregations or {} - if not doc_cls: - self._doc_classes = (Document,) - elif isinstance(doc_cls, tuple): - self._doc_classes = doc_cls + + if doc_cls is None: + doc_classes = () + elif not isinstance(doc_cls, collections.Iterable): + doc_classes = (doc_cls,) else: - self._doc_classes = (doc_cls,) - self._doc_cls_map = {doc_cls.__doc_type__: doc_cls for doc_cls in self._doc_classes} + doc_classes = doc_cls + self._doc_cls_map = {doc_cls.__doc_type__: doc_cls for doc_cls in doc_classes} + if isinstance(instance_mapper, dict): self._instance_mappers = instance_mapper else: - self._instance_mappers = {doc_cls: instance_mapper for doc_cls in self._doc_classes} + self._instance_mappers = {doc_cls: instance_mapper for doc_cls in doc_classes} self.total = raw_result['hits']['total'] self.hits = [] for hit in raw_result['hits']['hits']: - doc_cls = self._doc_cls_map[hit['_type']] + doc_cls = self._doc_cls_map.get(hit['_type'], DynamicDocument) self.hits.append(doc_cls(_hit=hit, _result=self)) self.aggregations = {} @@ -40,7 +44,7 @@ def get_aggregation(self, name): def _populate_instances(self, doc_cls): docs = [doc for doc in self.hits if isinstance(doc, doc_cls)] - instances = self._instance_mappers[doc_cls]([doc._id for doc in docs]) + instances = self._instance_mappers.get(doc_cls)([doc._id for doc in docs]) for doc in docs: doc.__dict__['instance'] = instances.get(doc._id) diff --git a/elasticmagic/search.py b/elasticmagic/search.py index 30c89df..49f6239 100644 --- a/elasticmagic/search.py +++ b/elasticmagic/search.py @@ -1,3 +1,5 @@ +import warnings +import collections from itertools import chain from .util import _with_clone, cached_property @@ -195,21 +197,24 @@ def with_search_type(self, search_type): def _get_doc_cls(self): if self._doc_cls: - doc_classes = [self._doc_cls] + doc_cls = self._doc_cls else: - doc_classes = self._collect_doc_classes() + doc_cls = self._collect_doc_classes() - if len(doc_classes) != 1: - raise ValueError('Cannot determine document class') + if not doc_cls: + warnings.warn('Cannot determine document class') + return None - return next(iter(doc_classes)) + return doc_cls def _get_doc_type(self, doc_cls=None): doc_cls = doc_cls or self._get_doc_cls() - if isinstance(doc_cls, tuple): + if isinstance(doc_cls, collections.Iterable): return ','.join(d.__doc_type__ for d in doc_cls) - else: - return self._doc_type or doc_cls.__doc_type__ + elif self._doc_type: + return self._doc_type + elif doc_cls: + return doc_cls.__doc_type__ def get_query(self, wrap_function_score=True): if wrap_function_score and self._boost_functions: diff --git a/tests/test_agg.py b/tests/test_agg.py index 0e9c3a7..0328b79 100644 --- a/tests/test_agg.py +++ b/tests/test_agg.py @@ -18,10 +18,15 @@ def test_aggs(self): "avg": {"field": "price"} } ) - a = a.build_agg_result({ + res = a.build_agg_result({ 'value': 75.3 }) - self.assertAlmostEqual(a.value, 75.3) + self.assertAlmostEqual(res.value, 75.3) + + aa = a.clone() + self.assertIsNot(a, aa) + self.assertEqual(a.__visit_name__, aa.__visit_name__) + self.assertEqual(a.params, aa.params) a = agg.Stats(f.grade) self.assert_expression( @@ -416,8 +421,8 @@ class ProductDocument(Document): self.assertEqual(a.get_aggregation('min_price').value, 350) # complex aggregation with sub aggregations - a = agg.Global( - aggs={ + a = agg.Global() + a = a.aggs({ 'selling_type': agg.Terms( f.selling_type, aggs={ diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 38d1636..e8ff9ed 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -6,6 +6,53 @@ class ClusterTest(BaseTestCase): + def test_search_query(self): + self.client.search = MagicMock( + return_value={ + 'hits': { + 'hits': [ + { + '_id': '381', + '_type': 'product', + '_index': 'test1', + '_score': 4.675524, + '_source': { + 'name': 'LG', + }, + }, + { + '_id': '921', + '_type': 'opinion', + '_index': 'test2', + '_score': 3.654321, + '_source': { + 'rank': 1.2, + }, + } + ], + 'max_score': 4.675524, + 'total': 6234 + }, + 'timed_out': False, + 'took': 57 + } + ) + sq = self.cluster.search_query() + result = sq.result + self.client.search.assert_called_with(body={}) + + self.assertEqual(len(result.hits), 2) + self.assertEqual(result.hits[0]._id, '381') + self.assertEqual(result.hits[0]._type, 'product') + self.assertEqual(result.hits[0]._index, 'test1') + self.assertAlmostEqual(result.hits[0]._score, 4.675524) + self.assertEqual(result.hits[0].name, 'LG') + self.assertEqual(result.hits[1]._id, '921') + self.assertEqual(result.hits[1]._type, 'opinion') + self.assertEqual(result.hits[1]._index, 'test2') + self.assertAlmostEqual(result.hits[1]._score, 3.654321) + self.assertAlmostEqual(result.hits[1].rank, 1.2) + def test_multi_index_search(self): es_log_index = self.cluster[('log_2014-11-19', 'log_2014-11-20', 'log_2014-11-20')] self.assertIs( @@ -83,7 +130,7 @@ def test_multi_search(self): } ) ProductDoc = self.index.product - sq1 = SearchQuery(doc_cls=ProductDoc, search_type='count') + sq1 = SearchQuery(doc_cls=ProductDoc, search_type='count', routing=123) sq2 = ( SearchQuery(index=self.cluster['us'], doc_cls=ProductDoc) .filter(ProductDoc.status == 0) @@ -92,7 +139,7 @@ def test_multi_search(self): results = self.cluster.multi_search([sq1, sq2]) self.client.msearch.assert_called_with( body=[ - {'doc_type': 'product', 'search_type': 'count'}, + {'doc_type': 'product', 'search_type': 'count', 'routing': 123}, {}, {'index': 'us', 'doc_type': 'product'}, {'query': {'filtered': {'filter': {'term': {'status': 0}}}}, 'size': 1} @@ -329,6 +376,7 @@ def test_bulk(self): self.assertEqual(result.took, 2) self.assertEqual(result.errors, True) self.assertEqual(len(result.items), 5) + self.assertIs(next(iter(result)), result.items[0]) self.assertEqual(result.items[0].name, 'index') self.assertEqual(result.items[0]._id, '1') self.assertEqual(result.items[0]._type, 'car') diff --git a/tests/test_document.py b/tests/test_document.py index ba28809..9e38726 100644 --- a/tests/test_document.py +++ b/tests/test_document.py @@ -101,14 +101,20 @@ class TestDocument(Document): self.assertIsInstance(TestDocument._id, AttributedField) self.assertIsInstance(TestDocument._id.get_field().get_type(), String) self.assertEqual(TestDocument._id.get_field().get_name(), '_id') + self.assertEqual(TestDocument._id.get_attr(), '_id') + self.assertIs(TestDocument._id.get_parent(), TestDocument) self.assert_expression(TestDocument._id, '_id') self.assertIsInstance(TestDocument._score, AttributedField) self.assertIsInstance(TestDocument._score.get_field().get_type(), Float) self.assertEqual(TestDocument._score.get_field().get_name(), '_score') + self.assertEqual(TestDocument._score.get_attr(), '_score') + self.assertIs(TestDocument._score.get_parent(), TestDocument) self.assert_expression(TestDocument._score, '_score') self.assertIsInstance(TestDocument.name, AttributedField) self.assertIsInstance(TestDocument.name.get_field().get_type(), String) self.assertEqual(TestDocument.name.get_field().get_name(), 'test_name') + self.assertEqual(TestDocument.name.get_attr(), 'name') + self.assertIs(TestDocument.name.get_parent(), TestDocument) self.assert_expression(TestDocument.name, 'test_name') self.assertEqual(list(TestDocument.name.fields), [TestDocument.name.raw]) self.assertEqual(TestDocument.name._collect_doc_classes(), [TestDocument]) @@ -116,6 +122,8 @@ class TestDocument(Document): self.assertIsInstance(TestDocument.name.raw.get_field().get_type(), String) self.assert_expression(TestDocument.name.raw, 'test_name.raw') self.assertEqual(TestDocument.name.raw.get_field().get_name(), 'test_name.raw') + self.assertEqual(TestDocument.name.raw.get_attr(), 'raw') + self.assertIsInstance(TestDocument.name.raw.get_parent(), AttributedField) self.assertEqual(TestDocument.name.raw._collect_doc_classes(), [TestDocument]) self.assertIsInstance(TestDocument.status, AttributedField) self.assertIsInstance(TestDocument.status.get_field().get_type(), Integer) @@ -127,10 +135,13 @@ class TestDocument(Document): self.assertIsInstance(TestDocument.group.name, AttributedField) self.assertEqual(list(TestDocument.group.name.fields), [TestDocument.group.name.raw]) self.assertEqual(TestDocument.group.name.get_field().get_name(), 'group.test_name') + self.assertEqual(TestDocument.group.name.raw.get_attr(), 'raw') self.assertIsInstance(TestDocument.group.name.get_field().get_type(), String) + self.assertIs(TestDocument.group.name.get_parent(), TestDocument) self.assertEqual(TestDocument.group.name._collect_doc_classes(), [TestDocument]) self.assertEqual(TestDocument.group.name.raw.get_field().get_name(), 'group.test_name.raw') self.assertIsInstance(TestDocument.group.name.raw.get_field().get_type(), String) + self.assertIsInstance(TestDocument.group.name.raw.get_parent(), AttributedField) self.assertEqual(TestDocument.group.name.raw._collect_doc_classes(), [TestDocument]) self.assertIsInstance(TestDocument.tags, AttributedField) self.assertIsInstance(TestDocument.tags.get_field().get_type(), List) @@ -280,7 +291,10 @@ class TestDocument(Document): price=101.5, tags=[TagDocument(id=1, name='Test tag'), TagDocument(id=2, name='Just tag')], - i_attr_3=45) + i_attr_1=None, + i_attr_2='', + i_attr_3=[], + i_attr_4=45) self.assertEqual( doc.to_source(), { @@ -294,7 +308,7 @@ class TestDocument(Document): {'id': 1, 'name': 'Test tag'}, {'id': 2, 'name': 'Just tag'}, ], - 'i_attr_3': 45 + 'i_attr_4': 45 } )