diff --git a/tests/test_base.py b/tests/test_base.py index 5b712ff2..39743da3 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,5 +1,6 @@ """Tests for the base module.""" import os +from pathlib import Path from sqlalchemy import Column, Integer, create_engine, select from sqlalchemy.orm import declarative_base @@ -26,19 +27,22 @@ class BaseTable(Base): # type: ignore class VocabTests(RequiresDBTestCase): """Module test case.""" + test_dir = Path("tests/examples") + start_dir = os.getcwd() + def setUp(self) -> None: """Pre-test setup.""" - run_psql("providers.dump") + run_psql(Path("tests/examples/providers.dump")) self.engine = create_engine( "postgresql://postgres:password@localhost:5432/providers" ) metadata.create_all(self.engine) - os.chdir("tests/examples") + os.chdir(self.test_dir) def tearDown(self) -> None: - os.chdir("../..") + os.chdir(self.start_dir) def test_load(self) -> None: """Test the load method.""" diff --git a/tests/test_functional.py b/tests/test_functional.py index 9d882472..23037ad3 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -19,6 +19,9 @@ class FunctionalTestCase(RequiresDBTestCase): concept_file_path = Path("concept.csv") + test_dir = Path("tests/workspace") + start_dir = os.getcwd() + env = os.environ.copy() env = { **env, @@ -37,9 +40,9 @@ def setUp(self) -> None: """Pre-test setup.""" # Create a blank destination database - run_psql("dst.dump") + run_psql(Path("tests/examples/dst.dump")) - os.chdir("tests/workspace") + os.chdir(self.test_dir) for file_path in ( self.orm_file_path, @@ -51,7 +54,7 @@ def setUp(self) -> None: file_path.unlink(missing_ok=True) def tearDown(self) -> None: - os.chdir("../../") + os.chdir(self.start_dir) def test_workflow_minimal_args(self) -> None: """Test the recommended CLI workflow runs without errors.""" diff --git a/tests/test_make.py b/tests/test_make.py index 561baeca..ad5ab958 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -1,6 +1,7 @@ """Tests for the main module.""" import os from io import StringIO +from pathlib import Path from subprocess import CalledProcessError from unittest import TestCase from unittest.mock import MagicMock, call, patch @@ -16,14 +17,16 @@ class TestMake(TestCase): """Tests that don't require a database.""" + test_dir = Path("tests/examples") + start_dir = os.getcwd() + def setUp(self) -> None: """Pre-test setup.""" - - os.chdir("tests/examples") + os.chdir(self.test_dir) def tearDown(self) -> None: """Post-test cleanup.""" - os.chdir("../..") + os.chdir(self.start_dir) @patch("sqlsynthgen.make.get_settings") @patch("sqlsynthgen.make.create_engine") diff --git a/tests/test_providers.py b/tests/test_providers.py index 140f67a8..b00aec03 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -1,5 +1,6 @@ """Tests for the providers module.""" import datetime as dt +from pathlib import Path from unittest import TestCase from sqlalchemy import Column, Integer, Text, create_engine, insert @@ -40,7 +41,7 @@ class ColumnValueProviderTestCase(RequiresDBTestCase): def setUp(self) -> None: """Pre-test setup.""" - run_psql("providers.dump") + run_psql(Path("tests/examples/providers.dump")) self.engine = create_engine( "postgresql://postgres:password@localhost:5432/providers", diff --git a/tests/test_utils.py b/tests/test_utils.py index 2aef72d3..209410ab 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -31,13 +31,16 @@ class MyTable(Base): # type: ignore class TestImport(TestCase): """Tests for the import_file function.""" + test_dir = Path("tests/examples") + start_dir = os.getcwd() + def setUp(self) -> None: """Pre-test setup.""" - - os.chdir("tests/examples") + os.chdir(self.test_dir) def tearDown(self) -> None: - os.chdir("../../") + """Post-test cleanup.""" + os.chdir(self.start_dir) def test_import_file(self) -> None: """Test that we can import an example module.""" @@ -53,10 +56,13 @@ class TestDownload(RequiresDBTestCase): mytable_file_path = Path("mytable.csv") + test_dir = Path("tests/workspace") + start_dir = os.getcwd() + def setUp(self) -> None: """Pre-test setup.""" - run_psql("providers.dump") + run_psql(Path("tests/examples/providers.dump")) self.engine = create_engine( "postgresql://postgres:password@localhost:5432/providers", @@ -64,12 +70,12 @@ def setUp(self) -> None: ) metadata.create_all(self.engine) - os.chdir("tests/workspace") + os.chdir(self.test_dir) self.mytable_file_path.unlink(missing_ok=True) def tearDown(self) -> None: """Post-test cleanup.""" - os.chdir("../..") + os.chdir(self.start_dir) def test_download_table(self) -> None: """Test the download_table function.""" diff --git a/tests/utils.py b/tests/utils.py index c0613698..9fb6df48 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -30,7 +30,7 @@ def get_test_settings() -> settings.Settings: ) -def run_psql(dump_file_name: str) -> None: +def run_psql(dump_file: Path) -> None: """Run psql and pass dump_file_name as the --file option.""" # If you need to update a .dump file, use @@ -41,12 +41,7 @@ def run_psql(dump_file_name: str) -> None: # Clear and re-create the test database completed_process = run( - [ - "psql", - "--host=localhost", - "--username=postgres", - "--file=" + str(Path(f"tests/examples/{dump_file_name}")), - ], + ["psql", "--host=localhost", "--username=postgres", f"--file={dump_file}"], capture_output=True, env=env, check=True, @@ -57,7 +52,7 @@ def run_psql(dump_file_name: str) -> None: @skipUnless(os.environ.get("REQUIRES_DB") == "1", "Set 'REQUIRES_DB=1' to enable.") class RequiresDBTestCase(TestCase): - """A test case that only runs if REQUIRES_DB has been set to true.""" + """A test case that only runs if REQUIRES_DB has been set to 1.""" def setUp(self) -> None: pass