Skip to content

Commit

Permalink
Implement Thingy.save(refresh=True) to ensure data consistency
Browse files Browse the repository at this point in the history
MongoDB doesn't always store exactly the data we give it. For example, Python's
datetime objects have microsecond precision, while dates stored in MongoDB have
millisecond precision. Let's add a refresh option to Thingy.save() that runs an
extra find query to return the data as it is stored.
  • Loading branch information
ramnes committed Dec 21, 2022
1 parent 5825e22 commit 9292dc6
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 6 deletions.
22 changes: 16 additions & 6 deletions mongo_thingy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,18 @@ def find_one_and_replace(cls, filter, replacement, *args, **kwargs):
if result is not None:
return cls(result)

def save(self, force_insert=False):
def save(self, force_insert=False, refresh=False):
data = self.__dict__
collection = self.get_collection()

if self.id is not None and not force_insert:
filter = {"_id": self.id}
self.get_collection().replace_one(filter, data, upsert=True)
collection.replace_one(filter, data, upsert=True)
else:
self.get_collection().insert_one(data)
collection.insert_one(data)

if refresh:
self.__dict__ = collection.find_one(self.id)
return self


Expand Down Expand Up @@ -262,13 +267,18 @@ async def find_one_and_replace(cls, filter, replacement, *args, **kwargs):
if result is not None:
return cls(result)

async def save(self, force_insert=False):
async def save(self, force_insert=False, refresh=False):
data = self.__dict__
collection = self.get_collection()

if self.id is not None and not force_insert:
filter = {"_id": self.id}
await self.get_collection().replace_one(filter, data, upsert=True)
await collection.replace_one(filter, data, upsert=True)
else:
await self.get_collection().insert_one(data)
await collection.insert_one(data)

if refresh:
self.__dict__ = await collection.find_one(self.id)
return self


Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ python_files = *.py
addopts = -vv --showlocals --cov mongo_thingy --cov-report term-missing --cov-fail-under 100
markers =
all_backends: mark a test as testable against all backends
only_backends: mark a test as testable only against these backends
ignore_backends: mark a test as not testable against these backends
asyncio_mode = auto

Expand Down
37 changes: 37 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from datetime import datetime, timezone

import pytest
from bson import ObjectId
Expand Down Expand Up @@ -508,6 +509,24 @@ def test_thingy_save(TestThingy, collection):
assert thingy._id == "bar"


@pytest.mark.only_backends("pymongo")
def test_thingy_save_refresh(TestThingy):
created_at = datetime.now(timezone.utc)

thingy = TestThingy(created_at=created_at).save()
assert thingy.created_at == created_at

thingy = thingy.save(refresh=True)
assert thingy.created_at != created_at

approx = created_at.replace(microsecond=0, tzinfo=None)
saved_approx = thingy.created_at.replace(microsecond=0, tzinfo=None)
assert approx == saved_approx

assert TestThingy.find_one().created_at != created_at
assert TestThingy.find_one().created_at == thingy.created_at


async def test_async_thingy_save(TestThingy, collection):
thingy = TestThingy(bar="baz")
assert await TestThingy.count_documents() == 0
Expand All @@ -521,6 +540,24 @@ async def test_async_thingy_save(TestThingy, collection):
assert thingy._id == "bar"


@pytest.mark.only_backends("motor_asyncio", "motor_tornado")
async def test_async_thingy_save_refresh(TestThingy):
created_at = datetime.now(timezone.utc)

thingy = await TestThingy(created_at=created_at).save()
assert thingy.created_at == created_at

thingy = await thingy.save(refresh=True)
assert thingy.created_at != created_at

approx = created_at.replace(microsecond=0, tzinfo=None)
saved_approx = thingy.created_at.replace(microsecond=0, tzinfo=None)
assert approx == saved_approx

assert (await TestThingy.find_one()).created_at != created_at
assert (await TestThingy.find_one()).created_at == thingy.created_at


def test_thingy_save_force_insert(TestThingy, collection):
thingy = TestThingy().save(force_insert=True)

Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def pytest_generate_tests(metafunc):
if option:
_backends = [b for b in _backends if b == option]

marker = metafunc.definition.get_closest_marker("only_backends")
if marker:
_backends = [b for b in _backends if b in marker.args]

marker = metafunc.definition.get_closest_marker("ignore_backends")
if marker:
_backends = [b for b in _backends if b not in marker.args]
Expand Down

0 comments on commit 9292dc6

Please sign in to comment.