Skip to content

Commit

Permalink
Increase test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
anti-social committed Feb 13, 2015
1 parent 327dec8 commit fae9859
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 32 deletions.
4 changes: 2 additions & 2 deletions elasticmagic/agg.py
Expand Up @@ -37,8 +37,8 @@ def clone(self):
return self.__class__(aggs=self._aggs, **self.params)

@_with_clone
def aggregations(self, *args, **aggs):
if args == (None,):
def aggregations(self, aggs):
if aggs is None:
self._aggs = Params()
else:
self._aggs = Params(dict(self._aggs), **aggs)
Expand Down
5 changes: 1 addition & 4 deletions elasticmagic/attribute.py
Expand Up @@ -73,10 +73,7 @@ def __get__(self, obj, type=None):
if obj is None:
return self

dict_ = obj.__dict__
if self._attr in obj.__dict__:
return dict_[self._attr]
dict_[self._attr] = None
obj.__dict__[self._attr] = None
return None

def _collect_doc_classes(self):
Expand Down
1 change: 1 addition & 0 deletions elasticmagic/cluster.py
@@ -1,5 +1,6 @@
from .util import clean_params
from .index import Index
from .search import SearchQuery
from .result import Result, BulkResult
from .document import DynamicDocument
from .expression import Params
Expand Down
24 changes: 14 additions & 10 deletions elasticmagic/result.py
@@ -1,28 +1,32 @@
import collections

from .agg import BucketAgg
from .document import Document
from .document import DynamicDocument


class Result(object):
def __init__(self, raw_result, aggregations=None,
doc_cls=None, instance_mapper=None):
self.raw = raw_result
self._query_aggs = aggregations or {}
if not doc_cls:
self._doc_classes = (Document,)
elif isinstance(doc_cls, tuple):
self._doc_classes = doc_cls

if doc_cls is None:
doc_classes = ()
elif not isinstance(doc_cls, collections.Iterable):
doc_classes = (doc_cls,)
else:
self._doc_classes = (doc_cls,)
self._doc_cls_map = {doc_cls.__doc_type__: doc_cls for doc_cls in self._doc_classes}
doc_classes = doc_cls
self._doc_cls_map = {doc_cls.__doc_type__: doc_cls for doc_cls in doc_classes}

if isinstance(instance_mapper, dict):
self._instance_mappers = instance_mapper
else:
self._instance_mappers = {doc_cls: instance_mapper for doc_cls in self._doc_classes}
self._instance_mappers = {doc_cls: instance_mapper for doc_cls in doc_classes}

self.total = raw_result['hits']['total']
self.hits = []
for hit in raw_result['hits']['hits']:
doc_cls = self._doc_cls_map[hit['_type']]
doc_cls = self._doc_cls_map.get(hit['_type'], DynamicDocument)
self.hits.append(doc_cls(_hit=hit, _result=self))

self.aggregations = {}
Expand All @@ -40,7 +44,7 @@ def get_aggregation(self, name):

def _populate_instances(self, doc_cls):
docs = [doc for doc in self.hits if isinstance(doc, doc_cls)]
instances = self._instance_mappers[doc_cls]([doc._id for doc in docs])
instances = self._instance_mappers.get(doc_cls)([doc._id for doc in docs])
for doc in docs:
doc.__dict__['instance'] = instances.get(doc._id)

Expand Down
21 changes: 13 additions & 8 deletions elasticmagic/search.py
@@ -1,3 +1,5 @@
import warnings
import collections
from itertools import chain

from .util import _with_clone, cached_property
Expand Down Expand Up @@ -195,21 +197,24 @@ def with_search_type(self, search_type):

def _get_doc_cls(self):
if self._doc_cls:
doc_classes = [self._doc_cls]
doc_cls = self._doc_cls
else:
doc_classes = self._collect_doc_classes()
doc_cls = self._collect_doc_classes()

if len(doc_classes) != 1:
raise ValueError('Cannot determine document class')
if not doc_cls:
warnings.warn('Cannot determine document class')
return None

return next(iter(doc_classes))
return doc_cls

def _get_doc_type(self, doc_cls=None):
doc_cls = doc_cls or self._get_doc_cls()
if isinstance(doc_cls, tuple):
if isinstance(doc_cls, collections.Iterable):
return ','.join(d.__doc_type__ for d in doc_cls)
else:
return self._doc_type or doc_cls.__doc_type__
elif self._doc_type:
return self._doc_type
elif doc_cls:
return doc_cls.__doc_type__

def get_query(self, wrap_function_score=True):
if wrap_function_score and self._boost_functions:
Expand Down
13 changes: 9 additions & 4 deletions tests/test_agg.py
Expand Up @@ -18,10 +18,15 @@ def test_aggs(self):
"avg": {"field": "price"}
}
)
a = a.build_agg_result({
res = a.build_agg_result({
'value': 75.3
})
self.assertAlmostEqual(a.value, 75.3)
self.assertAlmostEqual(res.value, 75.3)

aa = a.clone()
self.assertIsNot(a, aa)
self.assertEqual(a.__visit_name__, aa.__visit_name__)
self.assertEqual(a.params, aa.params)

a = agg.Stats(f.grade)
self.assert_expression(
Expand Down Expand Up @@ -416,8 +421,8 @@ class ProductDocument(Document):
self.assertEqual(a.get_aggregation('min_price').value, 350)

# complex aggregation with sub aggregations
a = agg.Global(
aggs={
a = agg.Global()
a = a.aggs({
'selling_type': agg.Terms(
f.selling_type,
aggs={
Expand Down
52 changes: 50 additions & 2 deletions tests/test_cluster.py
Expand Up @@ -6,6 +6,53 @@


class ClusterTest(BaseTestCase):
def test_search_query(self):
self.client.search = MagicMock(
return_value={
'hits': {
'hits': [
{
'_id': '381',
'_type': 'product',
'_index': 'test1',
'_score': 4.675524,
'_source': {
'name': 'LG',
},
},
{
'_id': '921',
'_type': 'opinion',
'_index': 'test2',
'_score': 3.654321,
'_source': {
'rank': 1.2,
},
}
],
'max_score': 4.675524,
'total': 6234
},
'timed_out': False,
'took': 57
}
)
sq = self.cluster.search_query()
result = sq.result
self.client.search.assert_called_with(body={})

self.assertEqual(len(result.hits), 2)
self.assertEqual(result.hits[0]._id, '381')
self.assertEqual(result.hits[0]._type, 'product')
self.assertEqual(result.hits[0]._index, 'test1')
self.assertAlmostEqual(result.hits[0]._score, 4.675524)
self.assertEqual(result.hits[0].name, 'LG')
self.assertEqual(result.hits[1]._id, '921')
self.assertEqual(result.hits[1]._type, 'opinion')
self.assertEqual(result.hits[1]._index, 'test2')
self.assertAlmostEqual(result.hits[1]._score, 3.654321)
self.assertAlmostEqual(result.hits[1].rank, 1.2)

def test_multi_index_search(self):
es_log_index = self.cluster[('log_2014-11-19', 'log_2014-11-20', 'log_2014-11-20')]
self.assertIs(
Expand Down Expand Up @@ -83,7 +130,7 @@ def test_multi_search(self):
}
)
ProductDoc = self.index.product
sq1 = SearchQuery(doc_cls=ProductDoc, search_type='count')
sq1 = SearchQuery(doc_cls=ProductDoc, search_type='count', routing=123)
sq2 = (
SearchQuery(index=self.cluster['us'], doc_cls=ProductDoc)
.filter(ProductDoc.status == 0)
Expand All @@ -92,7 +139,7 @@ def test_multi_search(self):
results = self.cluster.multi_search([sq1, sq2])
self.client.msearch.assert_called_with(
body=[
{'doc_type': 'product', 'search_type': 'count'},
{'doc_type': 'product', 'search_type': 'count', 'routing': 123},
{},
{'index': 'us', 'doc_type': 'product'},
{'query': {'filtered': {'filter': {'term': {'status': 0}}}}, 'size': 1}
Expand Down Expand Up @@ -329,6 +376,7 @@ def test_bulk(self):
self.assertEqual(result.took, 2)
self.assertEqual(result.errors, True)
self.assertEqual(len(result.items), 5)
self.assertIs(next(iter(result)), result.items[0])
self.assertEqual(result.items[0].name, 'index')
self.assertEqual(result.items[0]._id, '1')
self.assertEqual(result.items[0]._type, 'car')
Expand Down
18 changes: 16 additions & 2 deletions tests/test_document.py
Expand Up @@ -101,21 +101,29 @@ class TestDocument(Document):
self.assertIsInstance(TestDocument._id, AttributedField)
self.assertIsInstance(TestDocument._id.get_field().get_type(), String)
self.assertEqual(TestDocument._id.get_field().get_name(), '_id')
self.assertEqual(TestDocument._id.get_attr(), '_id')
self.assertIs(TestDocument._id.get_parent(), TestDocument)
self.assert_expression(TestDocument._id, '_id')
self.assertIsInstance(TestDocument._score, AttributedField)
self.assertIsInstance(TestDocument._score.get_field().get_type(), Float)
self.assertEqual(TestDocument._score.get_field().get_name(), '_score')
self.assertEqual(TestDocument._score.get_attr(), '_score')
self.assertIs(TestDocument._score.get_parent(), TestDocument)
self.assert_expression(TestDocument._score, '_score')
self.assertIsInstance(TestDocument.name, AttributedField)
self.assertIsInstance(TestDocument.name.get_field().get_type(), String)
self.assertEqual(TestDocument.name.get_field().get_name(), 'test_name')
self.assertEqual(TestDocument.name.get_attr(), 'name')
self.assertIs(TestDocument.name.get_parent(), TestDocument)
self.assert_expression(TestDocument.name, 'test_name')
self.assertEqual(list(TestDocument.name.fields), [TestDocument.name.raw])
self.assertEqual(TestDocument.name._collect_doc_classes(), [TestDocument])
self.assertIsInstance(TestDocument.name.raw, AttributedField)
self.assertIsInstance(TestDocument.name.raw.get_field().get_type(), String)
self.assert_expression(TestDocument.name.raw, 'test_name.raw')
self.assertEqual(TestDocument.name.raw.get_field().get_name(), 'test_name.raw')
self.assertEqual(TestDocument.name.raw.get_attr(), 'raw')
self.assertIsInstance(TestDocument.name.raw.get_parent(), AttributedField)
self.assertEqual(TestDocument.name.raw._collect_doc_classes(), [TestDocument])
self.assertIsInstance(TestDocument.status, AttributedField)
self.assertIsInstance(TestDocument.status.get_field().get_type(), Integer)
Expand All @@ -127,10 +135,13 @@ class TestDocument(Document):
self.assertIsInstance(TestDocument.group.name, AttributedField)
self.assertEqual(list(TestDocument.group.name.fields), [TestDocument.group.name.raw])
self.assertEqual(TestDocument.group.name.get_field().get_name(), 'group.test_name')
self.assertEqual(TestDocument.group.name.raw.get_attr(), 'raw')
self.assertIsInstance(TestDocument.group.name.get_field().get_type(), String)
self.assertIs(TestDocument.group.name.get_parent(), TestDocument)
self.assertEqual(TestDocument.group.name._collect_doc_classes(), [TestDocument])
self.assertEqual(TestDocument.group.name.raw.get_field().get_name(), 'group.test_name.raw')
self.assertIsInstance(TestDocument.group.name.raw.get_field().get_type(), String)
self.assertIsInstance(TestDocument.group.name.raw.get_parent(), AttributedField)
self.assertEqual(TestDocument.group.name.raw._collect_doc_classes(), [TestDocument])
self.assertIsInstance(TestDocument.tags, AttributedField)
self.assertIsInstance(TestDocument.tags.get_field().get_type(), List)
Expand Down Expand Up @@ -280,7 +291,10 @@ class TestDocument(Document):
price=101.5,
tags=[TagDocument(id=1, name='Test tag'),
TagDocument(id=2, name='Just tag')],
i_attr_3=45)
i_attr_1=None,
i_attr_2='',
i_attr_3=[],
i_attr_4=45)
self.assertEqual(
doc.to_source(),
{
Expand All @@ -294,7 +308,7 @@ class TestDocument(Document):
{'id': 1, 'name': 'Test tag'},
{'id': 2, 'name': 'Just tag'},
],
'i_attr_3': 45
'i_attr_4': 45
}
)

Expand Down

0 comments on commit fae9859

Please sign in to comment.