Skip to content

Commit

Permalink
Implement Thingy.save(refresh=True) to ensure data consistency
Browse files Browse the repository at this point in the history
MongoDB doesn't always store exactly the data we give it. For example, Python's
datetime objects have microsecond precision, while dates stored in MongoDB have
millisecond precision. Let's add a refresh option to Thingy.save() that runs an
extra find query to return the data as it is stored.
  • Loading branch information
ramnes committed Dec 21, 2022
1 parent 5825e22 commit d62b61b
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 2 deletions.
10 changes: 8 additions & 2 deletions mongo_thingy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,13 +225,16 @@ def find_one_and_replace(cls, filter, replacement, *args, **kwargs):
if result is not None:
return cls(result)

def save(self, force_insert=False):
def save(self, force_insert=False, refresh=False):
data = self.__dict__
if self.id is not None and not force_insert:
filter = {"_id": self.id}
self.get_collection().replace_one(filter, data, upsert=True)
else:
self.get_collection().insert_one(data)

if refresh:
return self.find_one(self.id)
return self


Expand Down Expand Up @@ -262,13 +265,16 @@ async def find_one_and_replace(cls, filter, replacement, *args, **kwargs):
if result is not None:
return cls(result)

async def save(self, force_insert=False):
async def save(self, force_insert=False, refresh=False):
data = self.__dict__
if self.id is not None and not force_insert:
filter = {"_id": self.id}
await self.get_collection().replace_one(filter, data, upsert=True)
else:
await self.get_collection().insert_one(data)

if refresh:
return await self.find_one(self.id)
return self


Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ python_files = *.py
addopts = -vv --showlocals --cov mongo_thingy --cov-report term-missing --cov-fail-under 100
markers =
all_backends: mark a test as testable against all backends
only_backends: mark a test as testable only against these backends
ignore_backends: mark a test as not testable against these backends
asyncio_mode = auto

Expand Down
37 changes: 37 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from datetime import datetime, timezone

import pytest
from bson import ObjectId
Expand Down Expand Up @@ -508,6 +509,24 @@ def test_thingy_save(TestThingy, collection):
assert thingy._id == "bar"


@pytest.mark.only_backends("pymongo")
def test_thingy_save_refresh(TestThingy):
thingy = TestThingy(created_at=datetime.now(timezone.utc))
saved_thingy = thingy.save()
assert thingy.created_at == saved_thingy.created_at

fresh_thingy = thingy.save(refresh=True)
assert thingy.created_at != fresh_thingy.created_at

approx = thingy.created_at.replace(microsecond=0, tzinfo=None)
fresh_approx = fresh_thingy.created_at.replace(microsecond=0, tzinfo=None)
assert approx == fresh_approx

queried_thingy = TestThingy.find_one()
assert queried_thingy.created_at != thingy.created_at
assert queried_thingy.created_at == fresh_thingy.created_at


async def test_async_thingy_save(TestThingy, collection):
thingy = TestThingy(bar="baz")
assert await TestThingy.count_documents() == 0
Expand All @@ -521,6 +540,24 @@ async def test_async_thingy_save(TestThingy, collection):
assert thingy._id == "bar"


@pytest.mark.only_backends("motor_asyncio", "motor_tornado")
async def test_async_thingy_save_refresh(TestThingy):
thingy = TestThingy(created_at=datetime.now(timezone.utc))
saved_thingy = await thingy.save()
assert thingy.created_at == saved_thingy.created_at

fresh_thingy = await thingy.save(refresh=True)
assert thingy.created_at != fresh_thingy.created_at

approx = thingy.created_at.replace(microsecond=0, tzinfo=None)
fresh_approx = fresh_thingy.created_at.replace(microsecond=0, tzinfo=None)
assert approx == fresh_approx

queried_thingy = await TestThingy.find_one()
assert queried_thingy.created_at != thingy.created_at
assert queried_thingy.created_at == fresh_thingy.created_at


def test_thingy_save_force_insert(TestThingy, collection):
thingy = TestThingy().save(force_insert=True)

Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ def pytest_generate_tests(metafunc):
if option:
_backends = [b for b in _backends if b == option]

marker = metafunc.definition.get_closest_marker("only_backends")
if marker:
_backends = [b for b in _backends if b in marker.args]

marker = metafunc.definition.get_closest_marker("ignore_backends")
if marker:
_backends = [b for b in _backends if b not in marker.args]
Expand Down

0 comments on commit d62b61b

Please sign in to comment.