Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Refactoring grouping

  • Loading branch information...
commit 94e140b4702ee78b658e34e06dbf37fbda187f22 1 parent 80de825
Alexander authored

Showing 2 changed files with 84 additions and 43 deletions. Show diff stats Hide diff stats

  1. +52 25 solar/result.py
  2. +32 18 tests/test_query.py
77 solar/result.py
@@ -4,7 +4,7 @@ class SolrResults(object):
4 4 def __init__(self, query, hits, db_query=None, db_query_filters=[]):
5 5 self.query = query
6 6 self.searcher = self.query.searcher
7   - self.hits = hits
  7 + self.ndocs = self.hits = hits
8 8 self.docs = []
9 9 self._all_docs = []
10 10 self.facet_fields = []
@@ -15,8 +15,10 @@ def __init__(self, query, hits, db_query=None, db_query_filters=[]):
15 15 self._db_query = db_query
16 16 self._db_query_filters = db_query_filters
17 17
  18 + self.groupeds = {}
  19 +
18 20 def __len__(self):
19   - return self.hits
  21 + return self.ndocs
20 22
21 23 def __iter__(self):
22 24 return iter(self.docs)
@@ -28,28 +30,36 @@ def add_docs(self, docs):
28 30 self.docs = [Document(d, results=self) for d in docs]
29 31 self._all_docs = self.docs[:]
30 32
31   - def add_grouped_docs(self, grouped):
  33 + def add_grouped_docs(self, raw_grouped):
32 34 self.docs = []
33 35 self._all_docs = []
34   - for group_field, group_data in grouped.items():
  36 + for key, grouped_data in raw_grouped.items():
  37 + grouped = Grouped(key,
  38 + grouped_data.get('ngroups'),
  39 + grouped_data.get('matches'))
  40 + self.groupeds[key] = grouped
35 41 # grouped format
36   - if 'groups' in group_data:
37   - groups = group_data['groups']
38   - for group in groups:
39   - group_value = group['groupValue']
40   - group_doclist = group['doclist']
41   - group_docs = group_doclist['docs']
42   - group_hits = group_doclist['numFound']
43   - g_docs = [Document(d, results=self) for d in group_docs]
44   - doc = g_docs[0]
45   - self.docs.append(doc)
46   - self._all_docs.extend(g_docs)
47   - doc.grouped_docs = g_docs[1:]
48   - doc.grouped_count = group_hits - 1
  42 + if 'groups' in grouped_data:
  43 + groups = grouped_data['groups']
  44 + for group_data in groups:
  45 + group = Group(group_data['groupValue'],
  46 + group_data['doclist']['numFound'])
  47 + grouped.groups.append(group)
  48 + for doc_data in group_data['doclist']['docs']:
  49 + # TODO: make document cache
  50 + # documents with identical ids should map to one object
  51 + doc = Document(doc_data, results=self)
  52 + group.add_doc(doc)
  53 + self._all_docs.append(doc)
49 54 # simple format
50 55 else:
51   - for doc in group_data['doclist']['docs']:
52   - self.docs.append(Document(doc, results=self))
  56 + for doc_data in grouped_data['doclist']['docs']:
  57 + doc = Document(doc_data, results=self)
  58 + grouped.add_doc(doc)
  59 + self._all_docs.append(doc)
  60 +
  61 + def get_grouped(self, key):
  62 + return self.groupeds.get(key)
53 63
54 64 def add_stats_fields(self, stats_fields):
55 65 if stats_fields:
@@ -78,25 +88,42 @@ def _populate_instances(self):
78 88 ids = []
79 89 for doc in self._all_docs:
80 90 ids.append(self.searcher.get_id(doc.id))
81   - ids.extend([self.searcher.get_id(g_doc.id) for g_doc in doc.grouped_docs])
82 91 instances = self.searcher.get_instances(ids, self._db_query,
83 92 self._db_query_filters)
84 93
85   - for doc in self:
  94 + for doc in self._all_docs:
86 95 doc._instance = instances.get(self.searcher.get_id(doc.id))
87   - for g_doc in doc.grouped_docs:
88   - g_doc._instance = instances.get(self.searcher.get_id(g_doc.id))
89 96
90 97 @property
91 98 def instances(self):
92 99 return [doc.instance for doc in self if doc.instance]
93 100
  101 +class Grouped(object):
  102 + def __init__(self, key, ngroups, ndocs):
  103 + self.key = key
  104 + self.ngroups = ngroups # present if group.ngroups=true else None
  105 + self.ndocs = ndocs
  106 + self.groups = [] # grouped format
  107 + self.docs = [] # simple format
  108 +
  109 + def add_group(self, group):
  110 + self.groups.append(group)
  111 +
  112 + def add_doc(self, doc):
  113 + self.docs.append(doc)
  114 +
  115 +class Group(object):
  116 + def __init__(self, value, ndocs):
  117 + self.value = value
  118 + self.ndocs = ndocs
  119 + self.docs = []
  120 +
  121 + def add_doc(self, doc):
  122 + self.docs.append(doc)
94 123
95 124 class Document(object):
96 125 def __init__(self, doc, results=None):
97 126 self.results = results
98   - self.grouped_docs = []
99   - self.grouped_count = 0
100 127 for key in doc:
101 128 setattr(self, key, doc[key])
102 129
50 tests/test_query.py
@@ -97,8 +97,8 @@ def test_search_grouped_main(self):
97 97 s.solrs_read[0]._send_request.return_value = '''{
98 98 "grouped":{
99 99 "company":{
100   - "matches":0,
101   - "ngroups":0,
  100 + "matches":281,
  101 + "ngroups":109,
102 102 "groups":[{
103 103 "groupValue":"1",
104 104 "doclist":{"numFound":9,"start":0,"docs":[
@@ -182,16 +182,21 @@ def category_mapper(ids):
182 182 self.assertTrue('rows=24' in raw_query)
183 183
184 184 r = q.results
185   - self.assertEqual(len(r.docs), 2)
186   - self.assertEqual(r.docs[0].id, '111')
187   - self.assertEqual(r.docs[0].name, 'Test 1')
188   - self.assertEqual(r.docs[0].grouped_count, 8)
189   - self.assertEqual(r.docs[0].grouped_docs[-1].id, '333')
190   - self.assertEqual(r.docs[0].grouped_docs[-1].name, 'Test 3')
191   - self.assertEqual(r.docs[1].id, '555')
192   - self.assertEqual(r.docs[1].name, 'Test 5')
193   - self.assertEqual(len(r.docs[0].grouped_docs), 2)
  185 + grouped = r.get_grouped('company')
  186 + self.assertEqual(grouped.ngroups, 109)
  187 + self.assertEqual(grouped.ndocs, 281)
  188 + self.assertEqual(grouped.groups[0].ndocs, 9)
  189 + self.assertEqual(grouped.groups[0].docs[0].id, '111')
  190 + self.assertEqual(grouped.groups[0].docs[0].name, 'Test 1')
  191 + self.assertEqual(grouped.groups[0].docs[-1].id, '333')
  192 + self.assertEqual(grouped.groups[0].docs[-1].name, 'Test 3')
  193 + self.assertEqual(grouped.groups[1].ndocs, 1)
  194 + self.assertEqual(grouped.groups[1].docs[0].id, '555')
  195 + self.assertEqual(grouped.groups[1].docs[0].name, 'Test 5')
  196 + self.assertEqual(len(grouped.docs), 0)
  197 +
194 198 self.assertEqual(len(r.facet_fields), 2)
  199 +
195 200 category_facet = r.get_facet_field('category')
196 201 self.assertEqual(len(category_facet.values), 2)
197 202 self.assertEqual(category_facet.values[0].value, '1')
@@ -200,11 +205,13 @@ def category_mapper(ids):
200 205 self.assertEqual(category_facet.values[1].value, '2')
201 206 self.assertEqual(category_facet.values[1].count, 2)
202 207 self.assertEqual(category_facet.values[1].instance, {'id': 2, 'name': '2'})
  208 +
203 209 tag_facet = r.get_facet_field('tag')
204 210 self.assertEqual(len(tag_facet.values), 3)
205 211 self.assertEqual(tag_facet.values[-1].value, '1000')
206 212 self.assertEqual(tag_facet.values[-1].count, 30)
207 213 self.assertEqual(len(r.facet_queries), 1)
  214 +
208 215 price_stats = r.get_stats_field('price')
209 216 self.assertEqual(len(r.stats_fields), 1)
210 217 self.assertEqual(price_stats.min, 3.5)
@@ -251,14 +258,21 @@ def test_search_grouped_simple(self):
251 258 self.assertTrue('group.field=company' in raw_query)
252 259
253 260 r = q.results
254   - self.assertEqual(len(r.docs), 4)
255   - self.assertEqual(r.docs[0].id, '111')
256   - self.assertEqual(r.docs[0].name, 'Test 1')
257   - self.assertEqual(r.docs[2].id, '333')
258   - self.assertEqual(r.docs[2].name, 'Test 3')
259   - self.assertEqual(r.docs[3].id, '555')
260   - self.assertEqual(r.docs[3].name, 'Test 5')
  261 + grouped = r.get_grouped('company')
  262 + self.assertEqual(grouped.ngroups, 216036)
  263 + self.assertEqual(grouped.ndocs, 3657093)
  264 + self.assertEqual(len(grouped.docs), 4)
  265 + self.assertEqual(grouped.docs[0].id, '111')
  266 + self.assertEqual(grouped.docs[0].name, 'Test 1')
  267 + self.assertEqual(grouped.docs[2].id, '333')
  268 + self.assertEqual(grouped.docs[2].name, 'Test 3')
  269 + self.assertEqual(grouped.docs[3].id, '555')
  270 + self.assertEqual(grouped.docs[3].name, 'Test 5')
261 271
  272 + def test_instance_mapper(self):
  273 + # TODO
  274 + pass
  275 +
262 276
263 277 if __name__ == '__main__':
264 278 from unittest import main

0 comments on commit 94e140b

Please sign in to comment.
Something went wrong with that request. Please try again.