Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions mongoengine/connection.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from pymongo import MongoClient, ReadPreference, uri_parser
import six

from mongoengine.python_support import IS_PYMONGO_3
from mongoengine.pymongo_support import IS_PYMONGO_3

__all__ = ['MongoEngineConnectionError', 'connect', 'register_connection',
'DEFAULT_CONNECTION_NAME']
'DEFAULT_CONNECTION_NAME', 'get_db']


DEFAULT_CONNECTION_NAME = 'default'
Expand Down
3 changes: 2 additions & 1 deletion mongoengine/context_managers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from mongoengine.common import _import_class
from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
from mongoengine.pymongo_support import count_documents

__all__ = ('switch_db', 'switch_collection', 'no_dereference',
'no_sub_classes', 'query_counter', 'set_write_concern')
Expand Down Expand Up @@ -237,7 +238,7 @@ def _get_count(self):
and substracting the queries issued by this context. In fact everytime this is called, 1 query is
issued so we need to balance that
"""
count = self.db.system.profile.find(self._ignored_query).count() - self._ctx_query_counter
count = count_documents(self.db.system.profile, self._ignored_query) - self._ctx_query_counter
self._ctx_query_counter += 1 # Account for the query we just issued to gather the information
return count

Expand Down
4 changes: 2 additions & 2 deletions mongoengine/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
switch_db)
from mongoengine.errors import (InvalidDocumentError, InvalidQueryError,
SaveConditionError)
from mongoengine.python_support import IS_PYMONGO_3
from mongoengine.pymongo_support import IS_PYMONGO_3, list_collection_names
from mongoengine.queryset import (NotUniqueError, OperationError,
QuerySet, transform)

Expand Down Expand Up @@ -228,7 +228,7 @@ def _get_capped_collection(cls):

# If the collection already exists and has different options
# (i.e. isn't capped or has different max/size), raise an error.
if collection_name in db.collection_names():
if collection_name in list_collection_names(db, include_system_collections=True):
collection = db[collection_name]
options = collection.options()
if (
Expand Down
33 changes: 33 additions & 0 deletions mongoengine/pymongo_support.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""
Helper functions, constants, and types to aid with PyMongo v2.7 - v3.x support.
"""
import pymongo

_PYMONGO_37 = (3, 7)

PYMONGO_VERSION = tuple(pymongo.version_tuple[:2])

IS_PYMONGO_3 = PYMONGO_VERSION[0] >= 3
IS_PYMONGO_GTE_37 = PYMONGO_VERSION >= _PYMONGO_37


def count_documents(collection, filter):
"""Pymongo>3.7 deprecates count in favour of count_documents"""
if IS_PYMONGO_GTE_37:
return collection.count_documents(filter)
else:
count = collection.find(filter).count()
return count


def list_collection_names(db, include_system_collections=False):
"""Pymongo>3.7 deprecates collection_names in favour of list_collection_names"""
if IS_PYMONGO_GTE_37:
collections = db.list_collection_names()
else:
collections = db.collection_names()

if not include_system_collections:
collections = [c for c in collections if not c.startswith('system.')]

return collections
7 changes: 1 addition & 6 deletions mongoengine/python_support.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
"""
Helper functions, constants, and types to aid with Python v2.7 - v3.x and
PyMongo v2.7 - v3.x support.
Helper functions, constants, and types to aid with Python v2.7 - v3.x support
"""
import pymongo
import six


IS_PYMONGO_3 = pymongo.version_tuple[0] >= 3

# six.BytesIO resolves to StringIO.StringIO in Py2 and io.BytesIO in Py3.
StringIO = six.BytesIO

Expand Down
2 changes: 1 addition & 1 deletion mongoengine/queryset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from mongoengine.context_managers import set_write_concern, switch_db
from mongoengine.errors import (InvalidQueryError, LookUpError,
NotUniqueError, OperationError)
from mongoengine.python_support import IS_PYMONGO_3
from mongoengine.pymongo_support import IS_PYMONGO_3
from mongoengine.queryset import transform
from mongoengine.queryset.field_list import QueryFieldList
from mongoengine.queryset.visitor import Q, QNode
Expand Down
2 changes: 1 addition & 1 deletion mongoengine/queryset/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mongoengine.common import _import_class
from mongoengine.connection import get_connection
from mongoengine.errors import InvalidQueryError
from mongoengine.python_support import IS_PYMONGO_3
from mongoengine.pymongo_support import IS_PYMONGO_3

__all__ = ('query', 'update')

Expand Down
13 changes: 6 additions & 7 deletions tests/document/class_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import unittest

from mongoengine import *
from mongoengine.pymongo_support import list_collection_names

from mongoengine.queryset import NULLIFY, PULL
from mongoengine.connection import get_db
Expand All @@ -27,9 +28,7 @@ class Person(Document):
self.Person = Person

def tearDown(self):
for collection in self.db.collection_names():
if 'system.' in collection:
continue
for collection in list_collection_names(self.db):
self.db.drop_collection(collection)

def test_definition(self):
Expand Down Expand Up @@ -66,10 +65,10 @@ def test_drop_collection(self):
"""
collection_name = 'person'
self.Person(name='Test').save()
self.assertIn(collection_name, self.db.collection_names())
self.assertIn(collection_name, list_collection_names(self.db))

self.Person.drop_collection()
self.assertNotIn(collection_name, self.db.collection_names())
self.assertNotIn(collection_name, list_collection_names(self.db))

def test_register_delete_rule(self):
"""Ensure that register delete rule adds a delete rule to the document
Expand Down Expand Up @@ -340,7 +339,7 @@ class Person(Document):
meta = {'collection': collection_name}

Person(name="Test User").save()
self.assertIn(collection_name, self.db.collection_names())
self.assertIn(collection_name, list_collection_names(self.db))

user_obj = self.db[collection_name].find_one()
self.assertEqual(user_obj['name'], "Test User")
Expand All @@ -349,7 +348,7 @@ class Person(Document):
self.assertEqual(user_obj.name, "Test User")

Person.drop_collection()
self.assertNotIn(collection_name, self.db.collection_names())
self.assertNotIn(collection_name, list_collection_names(self.db))

def test_collection_name_and_primary(self):
"""Ensure that a collection with a specified name may be used.
Expand Down
14 changes: 5 additions & 9 deletions tests/document/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@

from bson import SON
from mongoengine import *
from mongoengine.connection import get_db
from mongoengine.pymongo_support import list_collection_names
from tests.utils import MongoDBTestCase

__all__ = ("DeltaTest",)


class DeltaTest(unittest.TestCase):
class DeltaTest(MongoDBTestCase):

def setUp(self):
connect(db='mongoenginetest')
self.db = get_db()
super(DeltaTest, self).setUp()

class Person(Document):
name = StringField()
Expand All @@ -25,9 +23,7 @@ class Person(Document):
self.Person = Person

def tearDown(self):
for collection in self.db.collection_names():
if 'system.' in collection:
continue
for collection in list_collection_names(self.db):
self.db.drop_collection(collection)

def test_delta(self):
Expand Down
15 changes: 5 additions & 10 deletions tests/document/inheritance.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,18 @@

from mongoengine import (BooleanField, Document, EmbeddedDocument,
EmbeddedDocumentField, GenericReferenceField,
IntField, ReferenceField, StringField, connect)
from mongoengine.connection import get_db
IntField, ReferenceField, StringField)
from mongoengine.pymongo_support import list_collection_names
from tests.utils import MongoDBTestCase
from tests.fixtures import Base

__all__ = ('InheritanceTest', )


class InheritanceTest(unittest.TestCase):

def setUp(self):
connect(db='mongoenginetest')
self.db = get_db()
class InheritanceTest(MongoDBTestCase):

def tearDown(self):
for collection in self.db.collection_names():
if 'system.' in collection:
continue
for collection in list_collection_names(self.db):
self.db.drop_collection(collection)

def test_constructor_cls(self):
Expand Down
10 changes: 4 additions & 6 deletions tests/document/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pymongo.errors import DuplicateKeyError
from six import iteritems

from mongoengine.pymongo_support import list_collection_names
from tests import fixtures
from tests.fixtures import (PickleEmbedded, PickleTest, PickleSignalsTest,
PickleDynamicEmbedded, PickleDynamicTest)
Expand Down Expand Up @@ -55,9 +56,7 @@ class Person(Document):
self.Job = Job

def tearDown(self):
for collection in self.db.collection_names():
if 'system.' in collection:
continue
for collection in list_collection_names(self.db):
self.db.drop_collection(collection)

def assertDbEqual(self, docs):
Expand Down Expand Up @@ -572,7 +571,7 @@ class Post(Document):

Post.drop_collection()

Post._get_collection().insert({
Post._get_collection().insert_one({
"title": "Items eclipse",
"items": ["more lorem", "even more ipsum"]
})
Expand Down Expand Up @@ -3217,8 +3216,7 @@ class Person(Document):
coll = Person._get_collection()
for person in Person.objects.as_pymongo():
if 'height' not in person:
person['height'] = 189
coll.save(person)
coll.update_one({'_id': person['_id']}, {'$set': {'height': 189}})

self.assertEquals(Person.objects(height=189).count(), 1)

Expand Down
44 changes: 26 additions & 18 deletions tests/fields/file_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@
TEST_IMAGE2_PATH = os.path.join(os.path.dirname(__file__), 'mongodb_leaf.png')


def get_file(path):
"""Use a BytesIO instead of a file to allow
to have a one-liner and avoid that the file remains opened"""
bytes_io = StringIO()
with open(path, 'rb') as f:
bytes_io.write(f.read())
bytes_io.seek(0)
return bytes_io


class FileTest(MongoDBTestCase):

def tearDown(self):
Expand Down Expand Up @@ -247,8 +257,8 @@ class Animal(Document):
Animal.drop_collection()
marmot = Animal(genus='Marmota', family='Sciuridae')

marmot_photo = open(TEST_IMAGE_PATH, 'rb') # Retrieve a photo from disk
marmot.photo.put(marmot_photo, content_type='image/jpeg', foo='bar')
marmot_photo_content = get_file(TEST_IMAGE_PATH) # Retrieve a photo from disk
marmot.photo.put(marmot_photo_content, content_type='image/jpeg', foo='bar')
marmot.photo.close()
marmot.save()

Expand All @@ -261,11 +271,11 @@ class TestFile(Document):
the_file = FileField()
TestFile.drop_collection()

test_file = TestFile(the_file=open(TEST_IMAGE_PATH, 'rb')).save()
test_file = TestFile(the_file=get_file(TEST_IMAGE_PATH)).save()
self.assertEqual(test_file.the_file.get().length, 8313)

test_file = TestFile.objects.first()
test_file.the_file = open(TEST_IMAGE2_PATH, 'rb')
test_file.the_file = get_file(TEST_IMAGE2_PATH)
test_file.save()
self.assertEqual(test_file.the_file.get().length, 4971)

Expand Down Expand Up @@ -379,7 +389,7 @@ class TestImage(Document):
self.assertEqual("%s" % e, "Invalid image: cannot identify image file %s" % f)

t = TestImage()
t.image.put(open(TEST_IMAGE_PATH, 'rb'))
t.image.put(get_file(TEST_IMAGE_PATH))
t.save()

t = TestImage.objects.first()
Expand All @@ -400,11 +410,11 @@ class TestFile(Document):
the_file = ImageField()
TestFile.drop_collection()

test_file = TestFile(the_file=open(TEST_IMAGE_PATH, 'rb')).save()
test_file = TestFile(the_file=get_file(TEST_IMAGE_PATH)).save()
self.assertEqual(test_file.the_file.size, (371, 76))

test_file = TestFile.objects.first()
test_file.the_file = open(TEST_IMAGE2_PATH, 'rb')
test_file.the_file = get_file(TEST_IMAGE2_PATH)
test_file.save()
self.assertEqual(test_file.the_file.size, (45, 101))

Expand All @@ -418,7 +428,7 @@ class TestImage(Document):
TestImage.drop_collection()

t = TestImage()
t.image.put(open(TEST_IMAGE_PATH, 'rb'))
t.image.put(get_file(TEST_IMAGE_PATH))
t.save()

t = TestImage.objects.first()
Expand All @@ -441,7 +451,7 @@ class TestImage(Document):
TestImage.drop_collection()

t = TestImage()
t.image.put(open(TEST_IMAGE_PATH, 'rb'))
t.image.put(get_file(TEST_IMAGE_PATH))
t.save()

t = TestImage.objects.first()
Expand All @@ -464,7 +474,7 @@ class TestImage(Document):
TestImage.drop_collection()

t = TestImage()
t.image.put(open(TEST_IMAGE_PATH, 'rb'))
t.image.put(get_file(TEST_IMAGE_PATH))
t.save()

t = TestImage.objects.first()
Expand Down Expand Up @@ -542,8 +552,8 @@ class TestImage(Document):
TestImage.drop_collection()

t = TestImage()
t.image1.put(open(TEST_IMAGE_PATH, 'rb'))
t.image2.put(open(TEST_IMAGE2_PATH, 'rb'))
t.image1.put(get_file(TEST_IMAGE_PATH))
t.image2.put(get_file(TEST_IMAGE2_PATH))
t.save()

test = TestImage.objects.first()
Expand All @@ -563,12 +573,10 @@ class Animal(Document):
Animal.drop_collection()
marmot = Animal(genus='Marmota', family='Sciuridae')

marmot_photo = open(TEST_IMAGE_PATH, 'rb') # Retrieve a photo from disk

photos_field = marmot._fields['photos'].field
new_proxy = photos_field.get_proxy_obj('photos', marmot)
new_proxy.put(marmot_photo, content_type='image/jpeg', foo='bar')
marmot_photo.close()
with open(TEST_IMAGE_PATH, 'rb') as marmot_photo: # Retrieve a photo from disk
photos_field = marmot._fields['photos'].field
new_proxy = photos_field.get_proxy_obj('photos', marmot)
new_proxy.put(marmot_photo, content_type='image/jpeg', foo='bar')

marmot.photos.append(new_proxy)
marmot.save()
Expand Down
Loading