From 9292dc65da1755d192f93e8dfdce2039ae24e442 Mon Sep 17 00:00:00 2001 From: ramnes Date: Wed, 21 Dec 2022 20:20:33 +0100 Subject: [PATCH] Implement Thingy.save(refresh=True) to ensure data consistency MongoDB doesn't always store exactly the data we give it. For example, Python's datetime objects have microsecond precision, while dates stored in MongoDB have millisecond precision. Let's add a refresh option to Thingy.save() that runs an extra find query to return the data as it is stored. --- mongo_thingy/__init__.py | 22 ++++++++++++++++------ setup.cfg | 1 + tests/__init__.py | 37 +++++++++++++++++++++++++++++++++++++ tests/conftest.py | 4 ++++ 4 files changed, 58 insertions(+), 6 deletions(-) diff --git a/mongo_thingy/__init__.py b/mongo_thingy/__init__.py index e7c2fe4..21ffca3 100644 --- a/mongo_thingy/__init__.py +++ b/mongo_thingy/__init__.py @@ -225,13 +225,18 @@ def find_one_and_replace(cls, filter, replacement, *args, **kwargs): if result is not None: return cls(result) - def save(self, force_insert=False): + def save(self, force_insert=False, refresh=False): data = self.__dict__ + collection = self.get_collection() + if self.id is not None and not force_insert: filter = {"_id": self.id} - self.get_collection().replace_one(filter, data, upsert=True) + collection.replace_one(filter, data, upsert=True) else: - self.get_collection().insert_one(data) + collection.insert_one(data) + + if refresh: + self.__dict__ = collection.find_one(self.id) return self @@ -262,13 +267,18 @@ async def find_one_and_replace(cls, filter, replacement, *args, **kwargs): if result is not None: return cls(result) - async def save(self, force_insert=False): + async def save(self, force_insert=False, refresh=False): data = self.__dict__ + collection = self.get_collection() + if self.id is not None and not force_insert: filter = {"_id": self.id} - await self.get_collection().replace_one(filter, data, upsert=True) + await collection.replace_one(filter, data, upsert=True) else: - await self.get_collection().insert_one(data) + await collection.insert_one(data) + + if refresh: + self.__dict__ = await collection.find_one(self.id) return self diff --git a/setup.cfg b/setup.cfg index 7abb3d9..84aa1d2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,6 +4,7 @@ python_files = *.py addopts = -vv --showlocals --cov mongo_thingy --cov-report term-missing --cov-fail-under 100 markers = all_backends: mark a test as testable against all backends + only_backends: mark a test as testable only against these backends ignore_backends: mark a test as not testable against these backends asyncio_mode = auto diff --git a/tests/__init__.py b/tests/__init__.py index aaef2e4..b456de1 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,4 +1,5 @@ import asyncio +from datetime import datetime, timezone import pytest from bson import ObjectId @@ -508,6 +509,24 @@ def test_thingy_save(TestThingy, collection): assert thingy._id == "bar" +@pytest.mark.only_backends("pymongo") +def test_thingy_save_refresh(TestThingy): + created_at = datetime.now(timezone.utc) + + thingy = TestThingy(created_at=created_at).save() + assert thingy.created_at == created_at + + thingy = thingy.save(refresh=True) + assert thingy.created_at != created_at + + approx = created_at.replace(microsecond=0, tzinfo=None) + saved_approx = thingy.created_at.replace(microsecond=0, tzinfo=None) + assert approx == saved_approx + + assert TestThingy.find_one().created_at != created_at + assert TestThingy.find_one().created_at == thingy.created_at + + async def test_async_thingy_save(TestThingy, collection): thingy = TestThingy(bar="baz") assert await TestThingy.count_documents() == 0 @@ -521,6 +540,24 @@ async def test_async_thingy_save(TestThingy, collection): assert thingy._id == "bar" +@pytest.mark.only_backends("motor_asyncio", "motor_tornado") +async def test_async_thingy_save_refresh(TestThingy): + created_at = datetime.now(timezone.utc) + + thingy = await TestThingy(created_at=created_at).save() + assert thingy.created_at == created_at + + thingy = await thingy.save(refresh=True) + assert thingy.created_at != created_at + + approx = created_at.replace(microsecond=0, tzinfo=None) + saved_approx = thingy.created_at.replace(microsecond=0, tzinfo=None) + assert approx == saved_approx + + assert (await TestThingy.find_one()).created_at != created_at + assert (await TestThingy.find_one()).created_at == thingy.created_at + + def test_thingy_save_force_insert(TestThingy, collection): thingy = TestThingy().save(force_insert=True) diff --git a/tests/conftest.py b/tests/conftest.py index c588a1e..549697c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -59,6 +59,10 @@ def pytest_generate_tests(metafunc): if option: _backends = [b for b in _backends if b == option] + marker = metafunc.definition.get_closest_marker("only_backends") + if marker: + _backends = [b for b in _backends if b in marker.args] + marker = metafunc.definition.get_closest_marker("ignore_backends") if marker: _backends = [b for b in _backends if b not in marker.args]