Skip to content

Commit

Permalink
Refactor API-server database unit tests.
Browse files Browse the repository at this point in the history
Distinguish between database initialization and database migration unit tests.

Prepares for #3130.
  • Loading branch information
fniessink committed May 21, 2024
1 parent e554c82 commit d235328
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 75 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Unit tests for the database initialization."""

import pathlib
from unittest.mock import Mock, mock_open, patch

from initialization.database import init_database

from tests.base import DataModelTestCase


class DatabaseInitTest(DataModelTestCase):
"""Unit tests for database initialization."""

def setUp(self):
"""Extend to set up the Mongo client and database contents."""
super().setUp()
self.mongo_client = Mock()
self.database.reports.find.return_value = []
self.database.reports.distinct.return_value = []
self.database.datamodels.find_one.return_value = None
self.database.reports_overviews.find_one.return_value = None
self.database.reports_overviews.find.return_value = []
self.database.reports.count_documents.return_value = 0
self.database.sessions.find_one.return_value = {"user": "jodoe"}
self.database.measurements.count_documents.return_value = 0
self.database.measurements.index_information.return_value = {}
self.mongo_client().quality_time_db = self.database

def init_database(self, data_model_json: str, assert_glob_called: bool = True) -> None:
"""Initialize the database."""
with (
patch.object(pathlib.Path, "glob", Mock(return_value=[])) as glob_mock,
patch.object(
pathlib.Path,
"open",
mock_open(read_data=data_model_json),
),
patch("pymongo.MongoClient", self.mongo_client),
):
init_database()
if assert_glob_called:
glob_mock.assert_called()
else:
glob_mock.assert_not_called()

def test_init_empty_database(self):
"""Test the initialization of an empty database."""
self.init_database('{"change": "yes"}')
self.database.datamodels.insert_one.assert_called_once()
self.database.reports_overviews.insert_one.assert_called_once()

def test_init_initialized_database(self):
"""Test the initialization of an initialized database."""
self.database.datamodels.find_one.return_value = self.DATA_MODEL
self.database.reports_overviews.find_one.return_value = {"_id": "id"}
self.database.reports.count_documents.return_value = 10
self.database.measurements.count_documents.return_value = 20
self.init_database("{}")
self.database.datamodels.insert_one.assert_not_called()
self.database.reports_overviews.insert_one.assert_not_called()

def test_skip_loading_example_reports(self):
"""Test that loading example reports can be skipped."""
with patch("src.initialization.database.os.environ.get", Mock(return_value="False")):
self.init_database('{"change": "yes"}', False)
self.database.datamodels.insert_one.assert_called_once()
self.database.reports_overviews.insert_one.assert_called_once()
Original file line number Diff line number Diff line change
@@ -1,83 +1,12 @@
"""Unit tests for the database initialization."""
"""Unit tests for database migrations."""

import pathlib
from unittest.mock import Mock, mock_open, patch

from initialization.database import init_database, perform_migrations
from initialization.database import perform_migrations

from tests.base import DataModelTestCase
from tests.fixtures import REPORT_ID, SUBJECT_ID, METRIC_ID, METRIC_ID2, METRIC_ID3, SOURCE_ID, SOURCE_ID2


class DatabaseInitializationTestCase(DataModelTestCase):
"""Base class for database unittests."""

def setUp(self):
"""Extend to set up the database fixture."""
super().setUp()
self.database = Mock()


class DatabaseInitTest(DatabaseInitializationTestCase):
"""Unit tests for database initialization."""

def setUp(self):
"""Extend to set up the Mongo client and database contents."""
super().setUp()
self.mongo_client = Mock()
self.database.reports.find.return_value = []
self.database.reports.distinct.return_value = []
self.database.datamodels.find_one.return_value = None
self.database.reports_overviews.find_one.return_value = None
self.database.reports_overviews.find.return_value = []
self.database.reports.count_documents.return_value = 0
self.database.sessions.find_one.return_value = {"user": "jodoe"}
self.database.measurements.count_documents.return_value = 0
self.database.measurements.index_information.return_value = {}
self.mongo_client().quality_time_db = self.database

def init_database(self, data_model_json: str, assert_glob_called: bool = True) -> None:
"""Initialize the database."""
with (
patch.object(pathlib.Path, "glob", Mock(return_value=[])) as glob_mock,
patch.object(
pathlib.Path,
"open",
mock_open(read_data=data_model_json),
),
patch("pymongo.MongoClient", self.mongo_client),
):
init_database()
if assert_glob_called:
glob_mock.assert_called()
else:
glob_mock.assert_not_called()

def test_init_empty_database(self):
"""Test the initialization of an empty database."""
self.init_database('{"change": "yes"}')
self.database.datamodels.insert_one.assert_called_once()
self.database.reports_overviews.insert_one.assert_called_once()

def test_init_initialized_database(self):
"""Test the initialization of an initialized database."""
self.database.datamodels.find_one.return_value = self.DATA_MODEL
self.database.reports_overviews.find_one.return_value = {"_id": "id"}
self.database.reports.count_documents.return_value = 10
self.database.measurements.count_documents.return_value = 20
self.init_database("{}")
self.database.datamodels.insert_one.assert_not_called()
self.database.reports_overviews.insert_one.assert_not_called()

def test_skip_loading_example_reports(self):
"""Test that loading example reports can be skipped."""
with patch("src.initialization.database.os.environ.get", Mock(return_value="False")):
self.init_database('{"change": "yes"}', False)
self.database.datamodels.insert_one.assert_called_once()
self.database.reports_overviews.insert_one.assert_called_once()


class DatabaseMigrationsChangeAccessibilityViolationsTest(DatabaseInitializationTestCase):
class DatabaseMigrationsChangeAccessibilityViolationsTest(DataModelTestCase):
"""Unit tests for the accessibility violations database migration."""

def test_change_accessibility_violations_to_violations_without_reports(self):
Expand Down Expand Up @@ -195,7 +124,7 @@ def test_change_accessibility_violations_to_violations_when_report_has_accessibi
)


class DatabaseMigrationsBranchParameterTest(DatabaseInitializationTestCase):
class DatabaseMigrationsBranchParameterTest(DataModelTestCase):
"""Unit tests for the branch parameter database migration."""

def test_migration_without_reports(self):
Expand Down

0 comments on commit d235328

Please sign in to comment.