Permalink
Browse files

Updated django-elasticsearch

  • Loading branch information...
1 parent bb6fbf4 commit 30031ae58fa52a3894eb5fff9657af1e0a9b0791 Alberto Paro committed Feb 3, 2011
View
28 LICENSE
@@ -0,0 +1,28 @@
+Copyright (c) 2009, Ask Solem
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice,
+ this list of conditions and the following disclaimer.
+ * Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+Neither the name of Ask Solem nor the names of its contributors may be used
+to endorse or promote products derived from this software without specific
+prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS
+BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
+
@@ -57,13 +57,19 @@ def _ensure_is_connected(self):
except ValueError:
raise ImproperlyConfigured("PORT must be an integer")
-
- self._connection = ES("%s:%s"%(self.settings_dict['HOST'], port),
- decoder = Decoder,
- encoder=Encoder)
-
self.db_name = self.settings_dict['NAME']
- self._db_connection = self._connection
+ self._connection = ES("%s:%s" % (self.settings_dict['HOST'], port),
+ decoder=Decoder,
+ encoder=Encoder,
+ autorefresh=True,
+ default_indexes=[self.db_name])
+
+ self._db_connection = self._connection
+ #auto index creation: check if to remove
+ try:
+ self._connection.create_index(self.db_name)
+ except:
+ pass
# We're done!
self._is_connected = True
@@ -15,19 +15,19 @@
from django.db.models.sql.where import WhereNode
from django.db.models.fields import NOT_PROVIDED
from django.utils.tree import Node
-from pyes import MatchAllQuery, FilteredQuery, BoolQuery, StringQuery, ObjectId, WildcardQuery, RegexTermQuery, RangeQuery, ESRange
+from pyes import MatchAllQuery, FilteredQuery, BoolQuery, StringQuery, \
+ WildcardQuery, RegexTermQuery, RangeQuery, ESRange, \
+ TermQuery, ConstantScoreQuery, TermFilter, TermsFilter, NotFilter, RegexTermFilter
from djangotoolbox.db.basecompiler import NonrelQuery, NonrelCompiler, \
NonrelInsertCompiler, NonrelUpdateCompiler, NonrelDeleteCompiler
-from brainaetic.es.query import WildcardQuery
from django.db.models.fields import AutoField
-
+import logging
TYPE_MAPPING_FROM_DB = {
'unicode': lambda val: unicode(val),
'int': lambda val: int(val),
'float': lambda val: float(val),
'bool': lambda val: bool(val),
- 'objectid': lambda val: unicode(val),
}
TYPE_MAPPING_TO_DB = {
@@ -42,23 +42,23 @@
OPERATORS_MAP = {
'exact': lambda val: val,
- 'iexact': lambda val: val,
- 'startswith': lambda val: '%s*'%val,
- 'istartswith': lambda val: '%s*'%val.lower(),
- 'endswith': lambda val: '*%s'%val,
- 'iendswith': lambda val: '*%s'%val.lower(),
- 'contains': lambda val: '*%s*'%val,
- 'icontains': lambda val: '*%s*'%val.lower(),
+ 'iexact': lambda val: val, #tofix
+ 'startswith': lambda val: r'^%s' % re.escape(val),
+ 'istartswith': lambda val: r'^%s' % re.escape(val),
+ 'endswith': lambda val: r'%s$' % re.escape(val),
+ 'iendswith': lambda val: r'%s$' % re.escape(val),
+ 'contains': lambda val: r'%s' % re.escape(val),
+ 'icontains': lambda val: r'%s' % re.escape(val),
'regex': lambda val: val,
- 'iregex': lambda val: val.lower(),
+ 'iregex': lambda val: re.compile(val, re.IGNORECASE),
'gt': lambda val: {"_from" : val, "include_lower" : False},
'gte': lambda val: {"_from" : val, "include_lower" : True},
'lt': lambda val: {"_to" : val, "include_upper": False},
'lte': lambda val: {"_to" : val, "include_upper": True},
'range': lambda val: {"_from" : val[0], "_to" : val[1], "include_lower" : True, "include_upper": True},
'year': lambda val: {"_from" : val[0], "_to" : val[1], "include_lower" : True, "include_upper": False},
'isnull': lambda val: None if val else {'$ne': None},
- 'in': lambda val: {'$in': val},
+ 'in': lambda val: val,
}
NEGATED_OPERATORS_MAP = {
@@ -87,7 +87,7 @@ def _get_mapping(db_type, value, mapping):
# TODO - what if the data is represented as list on the python side?
if isinstance(value, list):
return map(_func, value)
-
+
return _func(value)
def python2db(db_type, value):
@@ -114,9 +114,8 @@ class DBQuery(NonrelQuery):
def __init__(self, compiler, fields):
super(DBQuery, self).__init__(compiler, fields)
self._connection = self.connection.db_connection
- self.indexname = self.query.get_meta().db_table
self._ordering = []
- self.db_query = BoolQuery()
+ self.db_query = ConstantScoreQuery()
# This is needed for debugging
def __repr__(self):
@@ -125,8 +124,9 @@ def __repr__(self):
@safe_call
def fetch(self, low_mark, high_mark):
results = self._get_results()
- hits = results['hits']['hits']
-
+ #print "results:", results
+ hits = results['hits']['hits']
+
if low_mark > 0:
hits = hits[low_mark:]
if high_mark is not None:
@@ -142,8 +142,8 @@ def count(self, limit=None):
query = self.db_query
if self.db_query.is_empty():
query = MatchAllQuery()
-
- res = self._connection.count(query, indexes=[self.connection.db_name], doc_types=[self.indexname])
+
+ res = self._connection.count(query, doc_types=self.query.model._meta.db_table)
return res["count"]
@safe_call
@@ -162,6 +162,8 @@ def order_by(self, ordering):
# This function is used by the default add_filters() implementation
@safe_call
def add_filter(self, column, lookup_type, negated, db_type, value):
+ if column == self.query.get_meta().pk.column:
+ column = '_id'
# Emulated/converted lookups
if negated and lookup_type in NEGATED_OPERATORS_MAP:
@@ -172,35 +174,40 @@ def add_filter(self, column, lookup_type, negated, db_type, value):
value = op(self.convert_value_for_db(db_type, value))
queryf = self._get_query_type(column, lookup_type, db_type, value)
-
+
if negated:
- self.db_query.add_must_not(queryf)
+ self.db_query.add(NotFilter(queryf))
else:
- self.db_query.add_must(queryf)
+ self.db_query.add(queryf)
def _get_query_type(self, column, lookup_type, db_type, value):
if db_type == "unicode":
if (lookup_type == "exact" or lookup_type == "iexact"):
- q = StringQuery('"%s"'%value, default_field=column)
- q.text = '"%s"'%value
+ q = TermQuery(column, value)
return q
if (lookup_type == "startswith" or lookup_type == "istartswith"):
- return WildcardQuery(column, value)
+ return RegexTermFilter(column, value)
if (lookup_type == "endswith" or lookup_type == "iendswith"):
- return WildcardQuery(column, value)
+ return RegexTermFilter(column, value)
if (lookup_type == "contains" or lookup_type == "icontains"):
- return WildcardQuery(column, value)
+ return RegexTermFilter(column, value)
if (lookup_type == "regex" or lookup_type == "iregex"):
- return RegexTermQuery(column, value)
+ return RegexTermFilter(column, value)
if db_type == "datetime" or db_type == "date":
if (lookup_type == "exact" or lookup_type == "iexact"):
- return TermQuery(column, value)
-
+ return TermFilter(column, value)
+
+ #TermFilter, TermsFilter
if lookup_type in ["gt", "gte", "lt", "lte", "range", "year"]:
value['field'] = column
return RangeQuery(ESRange(**value))
-
+ if lookup_type == "in":
+# terms = [TermQuery(column, val) for val in value]
+# if len(terms) == 1:
+# return terms[0]
+# return BoolQuery(should=terms)
+ return TermsFilter(field=column, values=value)
raise NotImplemented
def _get_results(self):
@@ -213,7 +220,8 @@ def _get_results(self):
query = MatchAllQuery()
if self._ordering:
query.sort = self._ordering
- return self._connection.search(query, indexes=[self.connection.db_name], doc_types=[self.indexname])
+ #print "query", self.query.tables, query
+ return self._connection.search(query, indexes=[self.connection.db_name], doc_types=self.query.model._meta.db_table)
class SQLCompiler(NonrelCompiler):
"""
@@ -248,7 +256,7 @@ def convert_value_for_db(self, db_type, value):
def insert_params(self):
conn = self.connection
-
+
params = {
'safe': conn.safe_inserts,
}
@@ -258,6 +266,37 @@ def insert_params(self):
return params
+ def _get_ordering(self):
+ if not self.query.default_ordering:
+ ordering = self.query.order_by
+ else:
+ ordering = self.query.order_by or self.query.get_meta().ordering
+ result = []
+ for order in ordering:
+ if LOOKUP_SEP in order:
+ #raise DatabaseError("Ordering can't span tables on non-relational backends (%s)" % order)
+ print "Ordering can't span tables on non-relational backends (%s):skipping" % order
+ continue
+ if order == '?':
+ raise DatabaseError("Randomized ordering isn't supported by the backend")
+
+ order = order.lstrip('+')
+
+ descending = order.startswith('-')
+ name = order.lstrip('-')
+ if name == 'pk':
+ name = self.query.get_meta().pk.name
+ order = '-' + name if descending else name
+
+ if self.query.standard_ordering:
+ result.append(order)
+ else:
+ if descending:
+ result.append(name)
+ else:
+ result.append('-' + name)
+ return result
+
class SQLInsertCompiler(NonrelInsertCompiler, SQLCompiler):
@safe_call
@@ -266,14 +305,11 @@ def insert(self, data, return_id=False):
pk = None
if pk_column in data:
pk = data[pk_column]
- else:
- pk = unicode(ObjectId())
- data[pk_column] = pk
db_table = self.query.get_meta().db_table
- res = self.connection.db_connection.index(data, self.connection.db_name, db_table, pk)
- #TODO: remove or timeout the refresh
- self.connection.db_connection.refresh([self.connection.db_name])
-
+ logging.debug("Insert data %s: %s" % (db_table, data))
+ #print("Insert data %s: %s" % (db_table, data))
+ res = self.connection.db_connection.index(data, self.connection.db_name, db_table, id=pk)
+ #print "Insert result", res
return res['_id']
# TODO: Define a common nonrel API for updates and add it to the nonrel
@@ -291,10 +327,8 @@ def execute_sql(self, return_id=False):
pk_name = pk_field.attname
db_table = self.query.get_meta().db_table
- res = self.connection.db_connection.index(data, self.connection.db_name, db_table, pk)
+ res = self.connection.db_connection.index(data, self.connection.db_name, db_table, id=pk)
- #TODO: remove or timeout the refresh
- self.connection.db_connection.refresh([self.connection.db_name])
return res['_id']
class SQLDeleteCompiler(NonrelDeleteCompiler, SQLCompiler):
@@ -303,9 +337,7 @@ def execute_sql(self, return_id=False):
self.query - the data that should be inserted
"""
db_table = self.query.get_meta().db_table
- if len(self.query.where.children)==1 and isinstance(self.query.where.children[0][0].field, AutoField) and self.query.where.children[0][1]=="in":
+ if len(self.query.where.children) == 1 and isinstance(self.query.where.children[0][0].field, AutoField) and self.query.where.children[0][1] == "in":
for pk in self.query.where.children[0][3]:
- res = self.connection.db_connection.delete(self.connection.db_name, db_table, pk)
- #TODO: remove or timeout the refresh
- self.connection.db_connection.refresh([self.connection.db_name])
- return
+ self.connection.db_connection.delete(self.connection.db_name, db_table, pk)
+ return
Oops, something went wrong.

0 comments on commit 30031ae

Please sign in to comment.