diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c677d227..1d8a2833 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,8 +27,8 @@ repos: language: system types: ['python'] exclude: (?x)( - star| - tests/examples + tests/examples| + tests/workspace ) - id: isort name: isort @@ -36,8 +36,8 @@ repos: language: system types: ['python'] exclude: (?x)( - star| - tests/examples + tests/examples| + tests/workspace ) - id: pylint name: Pylint @@ -50,15 +50,15 @@ repos: language: system types: ['python'] exclude: (?x)( - tests/| - docs/ + docs/| + tests/ ) - id: mypy name: mypy entry: poetry run mypy --follow-imports=silent language: system exclude: (?x)( - star| - tests/examples + tests/examples| + tests/workspace ) types: ['python'] diff --git a/.pylintrc b/.pylintrc index 26dd2101..47b971bd 100644 --- a/.pylintrc +++ b/.pylintrc @@ -24,8 +24,8 @@ ignore=CVS # Add files or directories matching the regex patterns to the ignore-list. The # regex matches against paths. -ignore-paths=sqlsynthgen/star.py, - tests/examples +ignore-paths=tests/examples, + tests/workspace # Files or directories matching the regex patterns are skipped. The regex # matches against base names, not paths. 
diff --git a/sqlsynthgen/main.py b/sqlsynthgen/main.py index c780e7b3..7aebf3fe 100644 --- a/sqlsynthgen/main.py +++ b/sqlsynthgen/main.py @@ -1,51 +1,27 @@ """Entrypoint for the SQLSynthGen package.""" import sys -from importlib import import_module from pathlib import Path -from subprocess import CalledProcessError, run from sys import stderr -from types import ModuleType -from typing import Any, Optional +from typing import Final, Optional import typer -import yaml from sqlsynthgen.create import create_db_data, create_db_tables, create_db_vocab -from sqlsynthgen.make import make_generators_from_tables +from sqlsynthgen.make import make_generators_from_tables, make_tables_file from sqlsynthgen.settings import get_settings +from sqlsynthgen.utils import import_file, read_yaml_file -app = typer.Typer() - - -def import_file(file_path: str) -> ModuleType: - """Import a file. - - This utility function returns - the file at file_path as a module - - Args: - file_path (str): Path to file to be imported - - Returns: - ModuleType - """ - file_path_path = Path(file_path) - module_path = ".".join(file_path_path.parts[:-1] + (file_path_path.stem,)) - return import_module(module_path) +ORM_FILENAME: Final[str] = "orm.py" +SSG_FILENAME: Final[str] = "ssg.py" - -def read_yaml_file(path: str) -> Any: - """Read a yaml file in to dictionary, given a path.""" - with open(path, "r", encoding="utf8") as f: - config = yaml.safe_load(f) - return config +app = typer.Typer() @app.command() def create_data( - orm_file: str = typer.Argument(...), - ssg_file: str = typer.Argument(...), - num_passes: int = typer.Argument(...), + orm_file: str = typer.Option(ORM_FILENAME), + ssg_file: str = typer.Option(SSG_FILENAME), + num_passes: int = typer.Option(1), ) -> None: """Populate schema with synthetic data. @@ -63,15 +39,14 @@ def create_data( Final input is the number of rows required. 
Example: - $ sqlsynthgen create-data example_orm.py expected_ssg.py 100 + $ sqlsynthgen create-data Args: - orm_file (str): Path to object relational model. - ssg_file (str): Path to sqlsyngen output. + orm_file (str): Name of Python ORM file. + Must be in the current working directory. + ssg_file (str): Name of generators file. + Must be in the current working directory. num_passes (int): Number of passes to make. - - Returns: - None """ orm_module = import_file(orm_file) ssg_module = import_file(ssg_file) @@ -81,28 +56,33 @@ def create_data( @app.command() -def create_vocab(ssg_file: str = typer.Argument(...)) -> None: - """Create tables using the SQLAlchemy file.""" +def create_vocab(ssg_file: str = typer.Option(SSG_FILENAME)) -> None: + """Create tables using the SQLAlchemy file. + + Example: + $ sqlsynthgen create-vocab + + Args: + ssg_file (str): Name of generators file. + Must be in the current working directory. + """ ssg_module = import_file(ssg_file) create_db_vocab(ssg_module.sorted_vocab) @app.command() -def create_tables(orm_file: str = typer.Argument(...)) -> None: +def create_tables(orm_file: str = typer.Option(ORM_FILENAME)) -> None: """Create schema from Python classes. This CLI command creates Postgresql schema using object relational model declared as Python tables. (eg.) Example: - $ sqlsynthgen create-tables example_orm.py + $ sqlsynthgen create-tables Args: - orm_file (str): Path to Python tables file. - - Returns: - None - + orm_file (str): Name of Python ORM file. + Must be in the current working directory. """ orm_module = import_file(orm_file) create_db_tables(orm_module.Base.metadata) @@ -110,7 +90,8 @@ def create_tables(orm_file: str = typer.Argument(...)) -> None: @app.command() def make_generators( - orm_file: str = typer.Argument(...), + orm_file: str = typer.Option(ORM_FILENAME), + ssg_file: str = typer.Option(SSG_FILENAME), config_file: Optional[str] = typer.Argument(None), ) -> None: """Make a SQLSynthGen file of generator classes. 
@@ -119,55 +100,51 @@ def make_generators( returns a set of synthetic data generators for each attribute Example: - $ sqlsynthgen make-generators example_orm.py + $ sqlsynthgen make-generators Args: - orm_file (str): Path to Python tables file. + orm_file (str): Name of Python ORM file. + Must be in the current working directory. + ssg_file (str): Path to write the generators file to. config_file (str): Path to configuration file. """ + ssg_file_path = Path(ssg_file) + if ssg_file_path.exists(): + print(f"{ssg_file} should not already exist. Exiting...", file=stderr) + sys.exit(1) + orm_module = import_file(orm_file) generator_config = read_yaml_file(config_file) if config_file is not None else {} result = make_generators_from_tables(orm_module, generator_config) - print(result) + + ssg_file_path.write_text(result, encoding="utf-8") @app.command() -def make_tables() -> None: +def make_tables( + orm_file: str = typer.Option(ORM_FILENAME), +) -> None: """Make a SQLAlchemy file of Table classes. This CLI command deploys sqlacodegen to discover a - schema structure, and generates a object relational model declared + schema structure, and generates an object relational model declared as Python classes. Example: $ sqlsynthgen make_tables - """ - settings = get_settings() - - command = ["sqlacodegen"] - - if settings.src_schema: - command.append(f"--schema={settings.src_schema}") - - command.append(str(get_settings().src_postgres_dsn)) - try: - completed_process = run( - command, capture_output=True, encoding="utf-8", check=True - ) - except CalledProcessError as e: - print(e.stderr, file=stderr) - sys.exit(e.returncode) + Args: + orm_file (str): Path to write the Python ORM file. + """ + orm_file_path = Path(orm_file) + if orm_file_path.exists(): + print(f"{orm_file} should not already exist. Exiting...", file=stderr) + sys.exit(1) - # sqlacodegen falls back on Tables() for tables without PKs, - # but we don't explicitly support Tables and behaviour is unpredictable. 
- if " = Table(" in completed_process.stdout: - print( - "WARNING: Table without PK detected. sqlsynthgen may not be able to continue.", - file=stderr, - ) + settings = get_settings() - print(completed_process.stdout) + content = make_tables_file(str(settings.src_postgres_dsn), settings.src_schema) + orm_file_path.write_text(content, encoding="utf-8") if __name__ == "__main__": diff --git a/sqlsynthgen/make.py b/sqlsynthgen/make.py index 0f650852..d4335be5 100644 --- a/sqlsynthgen/make.py +++ b/sqlsynthgen/make.py @@ -1,16 +1,18 @@ """Functions to make a module of generator classes.""" -import csv import inspect -from pathlib import Path +import sys +from subprocess import CalledProcessError, run +from sys import stderr from types import ModuleType from typing import Any, Final, Optional from mimesis.providers.base import BaseProvider -from sqlalchemy import create_engine, select +from sqlalchemy import create_engine from sqlalchemy.sql import sqltypes from sqlsynthgen import providers from sqlsynthgen.settings import get_settings +from sqlsynthgen.utils import download_table HEADER_TEXT: str = "\n".join( ( @@ -132,20 +134,6 @@ def _add_generator_for_table( return content, new_class_name -def _download_table(table: Any, engine: Any) -> None: - """Download a table and store it as a .csv file.""" - stmt = select([table]) - with engine.connect() as conn: - result = list(conn.execute(stmt)) - with Path(table.fullname + ".csv").open( - "w", newline="", encoding="utf-8" - ) as csvfile: - writer = csv.writer(csvfile, delimiter=",") - writer.writerow([x.name for x in table.columns]) - for row in result: - writer.writerow(row) - - def make_generators_from_tables( tables_module: ModuleType, generator_config: dict ) -> str: @@ -162,7 +150,7 @@ def make_generators_from_tables( new_content += f"\nimport {tables_module.__name__}" generator_module_name = generator_config.get("custom_generators_module", None) if generator_module_name is not None: - new_content += f"\nfrom . 
import {generator_module_name}" + new_content += f"\nimport {generator_module_name}" sorted_generators = "[\n" sorted_vocab = "[\n" @@ -185,7 +173,7 @@ def make_generators_from_tables( ) sorted_vocab += f"{INDENTATION}{class_name.lower()}_vocab,\n" - _download_table(table, engine) + download_table(table, engine) else: new_content, new_generator_name = _add_generator_for_table( @@ -200,3 +188,34 @@ def make_generators_from_tables( new_content += "\n\n" + "sorted_vocab = " + sorted_vocab + "\n" return new_content + + +def make_tables_file(db_dsn: str, schema_name: Optional[str]) -> str: + """Write a file with the SQLAlchemy ORM classes. + + Exits with an error if sqlacodegen is unsuccessful. + """ + command = ["sqlacodegen"] + + if schema_name: + command.append(f"--schema={schema_name}") + + command.append(db_dsn) + + try: + completed_process = run( + command, capture_output=True, encoding="utf-8", check=True + ) + except CalledProcessError as e: + print(e.stderr, file=stderr) + sys.exit(e.returncode) + + # sqlacodegen falls back on Tables() for tables without PKs, + # but we don't explicitly support Tables and behaviour is unpredictable. + if " = Table(" in completed_process.stdout: + print( + "WARNING: Table without PK detected.
sqlsynthgen may not be able to continue.", + file=stderr, + ) + + return completed_process.stdout diff --git a/sqlsynthgen/utils.py b/sqlsynthgen/utils.py new file mode 100644 index 00000000..7752172e --- /dev/null +++ b/sqlsynthgen/utils.py @@ -0,0 +1,61 @@ +"""Utility functions.""" +import csv +import os +import sys +from importlib import import_module +from pathlib import Path +from sys import stderr +from types import ModuleType +from typing import Any + +import yaml +from sqlalchemy import select + + +def read_yaml_file(path: str) -> Any: + """Read a yaml file in to dictionary, given a path.""" + with open(path, "r", encoding="utf8") as f: + config = yaml.safe_load(f) + return config + + +def import_file(file_name: str) -> ModuleType: + """Import a file. + + This utility function returns file_name imported as a module. + + Args: + file_name (str): The name of a file in the current working directory. + + Returns: + ModuleType + """ + module_name = file_name[:-3] + + sys.path.append(os.getcwd()) + + try: + module = import_module(module_name) + finally: + sys.path.pop() + + return module + + +def download_table(table: Any, engine: Any) -> None: + """Download a Table and store it as a .csv file.""" + csv_file_name = table.fullname + ".csv" + csv_file_path = Path(csv_file_name) + if csv_file_path.exists(): + print(f"{str(csv_file_name)} already exists. 
Exiting...", file=stderr) + sys.exit(1) + + stmt = select([table]) + with engine.connect() as conn: + result = list(conn.execute(stmt)) + + with csv_file_path.open("w", newline="", encoding="utf-8") as csvfile: + writer = csv.writer(csvfile, delimiter=",") + writer.writerow([x.name for x in table.columns]) + for row in result: + writer.writerow(row) diff --git a/tests/examples/expected_ssg.py b/tests/examples/expected_ssg.py index 796c5524..ac0b02ce 100644 --- a/tests/examples/expected_ssg.py +++ b/tests/examples/expected_ssg.py @@ -15,7 +15,7 @@ generic.add_provider(TimespanProvider) import tests.examples.example_orm -from . import custom_generators +import custom_generators concept_vocab = FileUploader(tests.examples.example_orm.Concept.__table__) diff --git a/tests/examples/functional_conf.yaml b/tests/examples/functional_conf.yaml new file mode 100644 index 00000000..26ec28a2 --- /dev/null +++ b/tests/examples/functional_conf.yaml @@ -0,0 +1,30 @@ +custom_generators_module: custom_generators +tables: + person: + num_rows_per_pass: 2 + custom_generators: + - name: generic.person.full_name + args: null + columns_assigned: name + - name: generic.datetime.datetime + args: + start: 2022 + end: 2022 + columns_assigned: stored_from + + hospital_visit: + num_rows_per_pass: 3 + custom_generators: + - name: custom_generators.timespan_generator + args: + generic: generic + earliest_start_year: 2021 + last_start_year: 2022 + min_dt_days: 1 + max_dt_days: 30 + columns_assigned: + - visit_start + - visit_end + - visit_duration_seconds + concept: + vocabulary_table: true diff --git a/tests/examples/import_test.py b/tests/examples/import_test.py new file mode 100644 index 00000000..68829e12 --- /dev/null +++ b/tests/examples/import_test.py @@ -0,0 +1,2 @@ +"""For testing the import_file function.""" +x = 10 diff --git a/tests/examples/mytable.csv b/tests/examples/mytable.csv deleted file mode 100644 index 9c8175ee..00000000 --- a/tests/examples/mytable.csv +++ /dev/null @@ 
-1,2 +0,0 @@ -id -1 diff --git a/tests/examples/src.dump b/tests/examples/src.dump index 1229b105..5be992fb 100644 --- a/tests/examples/src.dump +++ b/tests/examples/src.dump @@ -77,6 +77,7 @@ CREATE TABLE public.hospital_visit ( hospital_visit_id bigint NOT NULL PRIMARY KEY, person_id integer NOT NULL references public.person(person_id), visit_start date NOT NULL, + visit_end date NOT NULL, visit_duration_seconds real NOT NULL, visit_image bytea NOT NULL, visit_type_concept_id integer NOT NULL references public.concept(concept_id) @@ -95,20 +96,12 @@ CREATE TABLE public.no_pk_test ( ALTER TABLE public.no_pk_test OWNER TO postgres; -- --- Data for Name: hospital_visit; Type: TABLE DATA; Schema: public; Owner: postgres +-- Data for Name: concept; Type: TABLE DATA; Schema: public; Owner: postgres -- -COPY public.hospital_visit (hospital_visit_id, person_id, visit_start, visit_duration_seconds, visit_image) FROM stdin; -\. +insert into public.concept values (1, 'some concept name'); --- --- Data for Name: person; Type: TABLE DATA; Schema: public; Owner: postgres --- - -COPY public.person (person_id, name, research_opt_out, stored_from) FROM stdin; -\. 
- -- -- PostgreSQL database dump complete -- diff --git a/tests/test_create.py b/tests/test_create.py index 75dde7ea..369e98a6 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1,4 +1,4 @@ -"""Tests for the main module.""" +"""Tests for the create module.""" from unittest import TestCase from unittest.mock import MagicMock, call, patch @@ -14,54 +14,58 @@ class MyTestCase(TestCase): """Module test case.""" - def test_create_db_data(self) -> None: + @patch("sqlsynthgen.create.create_engine") + @patch("sqlsynthgen.create.get_settings") + @patch("sqlsynthgen.create.populate") + def test_create_db_data( + self, + mock_populate: MagicMock, + mock_get_settings: MagicMock, + mock_create_engine: MagicMock, + ) -> None: """Test the generate function.""" - with patch("sqlsynthgen.create.populate") as mock_populate, patch( - "sqlsynthgen.create.get_settings" - ) as mock_get_settings, patch( - "sqlsynthgen.create.create_engine" - ) as mock_create_engine: - mock_get_settings.return_value = get_test_settings() + mock_get_settings.return_value = get_test_settings() - create_db_data([], [], 0) + create_db_data([], [], 0) - mock_populate.assert_called_once() - mock_create_engine.assert_called() + mock_populate.assert_called_once() + mock_create_engine.assert_called() - def test_create_db_tables(self) -> None: + @patch("sqlsynthgen.create.get_settings") + @patch("sqlsynthgen.create.create_engine") + def test_create_db_tables( + self, mock_create_engine: MagicMock, mock_get_settings: MagicMock + ) -> None: """Test the create_tables function.""" mock_meta = MagicMock() - with patch("sqlsynthgen.create.create_engine") as mock_create_engine, patch( - "sqlsynthgen.create.get_settings" - ) as mock_get_settings: + create_db_tables(mock_meta) + mock_get_settings.assert_called_once() + mock_create_engine.assert_called_once_with( + mock_get_settings.return_value.dst_postgres_dsn + ) - create_db_tables(mock_meta) - mock_get_settings.assert_called_once() - 
mock_create_engine.assert_called_once_with( - mock_get_settings.return_value.dst_postgres_dsn - ) - - def test_populate(self) -> None: + @patch("sqlsynthgen.create.insert") + def test_populate(self, mock_insert: MagicMock) -> None: """Test the populate function.""" - with patch("sqlsynthgen.create.insert") as mock_insert: - mock_src_conn = MagicMock() - mock_dst_conn = MagicMock() - mock_gen = MagicMock() - mock_gen.num_rows_per_pass = 2 - tables = [None] - generators = [mock_gen] - populate(mock_src_conn, mock_dst_conn, tables, generators, 1) + mock_src_conn = MagicMock() + mock_dst_conn = MagicMock() + mock_gen = MagicMock() + mock_gen.num_rows_per_pass = 2 + tables = [None] + generators = [mock_gen] + populate(mock_src_conn, mock_dst_conn, tables, generators, 1) - mock_gen.assert_has_calls([call(mock_src_conn, mock_dst_conn)] * 2) - mock_insert.return_value.values.assert_has_calls( - [call(mock_gen.return_value.__dict__)] * 2 - ) - mock_dst_conn.execute.assert_has_calls( - [call(mock_insert.return_value.values.return_value)] * 2 - ) + mock_gen.assert_has_calls([call(mock_src_conn, mock_dst_conn)] * 2) + mock_insert.return_value.values.assert_has_calls( + [call(mock_gen.return_value.__dict__)] * 2 + ) + mock_dst_conn.execute.assert_has_calls( + [call(mock_insert.return_value.values.return_value)] * 2 + ) - def test_populate_diff_length(self) -> None: + @patch("sqlsynthgen.create.insert") + def test_populate_diff_length(self, mock_insert: MagicMock) -> None: """Test when generators and tables differ in length.""" mock_dst_conn = MagicMock() mock_gen_two = MagicMock() @@ -69,23 +73,23 @@ def test_populate_diff_length(self) -> None: tables = [1, 2, 3] generators = [mock_gen_two, mock_gen_three] - with patch("sqlsynthgen.create.insert") as mock_insert: - populate(2, mock_dst_conn, tables, generators, 1) - self.assertListEqual([call(2), call(3)], mock_insert.call_args_list) + populate(2, mock_dst_conn, tables, generators, 1) + self.assertListEqual([call(2), call(3)], 
mock_insert.call_args_list) mock_gen_two.assert_called_once() mock_gen_three.assert_called_once() - def test_create_db_vocab(self) -> None: + @patch("sqlsynthgen.create.create_engine") + @patch("sqlsynthgen.create.get_settings") + def test_create_db_vocab( + self, mock_get_settings: MagicMock, mock_create_engine: MagicMock + ) -> None: """Test the create_db_vocab function.""" - with patch("sqlsynthgen.create.create_engine") as mock_create_engine, patch( - "sqlsynthgen.create.get_settings" - ) as mock_get_settings: - vocab_list = [MagicMock()] - create_db_vocab(vocab_list) - vocab_list[0].load.assert_called_once_with( - mock_create_engine.return_value.connect.return_value.__enter__.return_value - ) - mock_create_engine.assert_called_once_with( - mock_get_settings.return_value.dst_postgres_dsn - ) + vocab_list = [MagicMock()] + create_db_vocab(vocab_list) + vocab_list[0].load.assert_called_once_with( + mock_create_engine.return_value.connect.return_value.__enter__.return_value + ) + mock_create_engine.assert_called_once_with( + mock_get_settings.return_value.dst_postgres_dsn + ) diff --git a/tests/test_functional.py b/tests/test_functional.py index 9bbfd6d4..9d882472 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -5,56 +5,162 @@ from tests.utils import RequiresDBTestCase, run_psql +# pylint: disable=subprocess-run-check + class FunctionalTestCase(RequiresDBTestCase): """End-to-end tests.""" - orm_file_path = Path("tests/tmp/orm.py") - ssg_file_path = Path("tests/tmp/ssg.py") + orm_file_path = Path("orm.py") + ssg_file_path = Path("ssg.py") + + alt_orm_file_path = Path("my_orm.py") + alt_ssg_file_path = Path("my_ssg.py") + + concept_file_path = Path("concept.csv") + + env = os.environ.copy() + env = { + **env, + "src_host_name": "localhost", + "src_user_name": "postgres", + "src_password": "password", + "src_db_name": "src", + "src_schema": "", + "dst_host_name": "localhost", + "dst_user_name": "postgres", + "dst_password": "password", + 
"dst_db_name": "dst", + } def setUp(self) -> None: """Pre-test setup.""" - self.orm_file_path.unlink(missing_ok=True) - self.ssg_file_path.unlink(missing_ok=True) + # Create a blank destination database run_psql("dst.dump") - def test_workflow(self) -> None: + os.chdir("tests/workspace") + + for file_path in ( + self.orm_file_path, + self.ssg_file_path, + self.alt_orm_file_path, + self.alt_ssg_file_path, + self.concept_file_path, + ): + file_path.unlink(missing_ok=True) + + def tearDown(self) -> None: + os.chdir("../../") + + def test_workflow_minimal_args(self) -> None: """Test the recommended CLI workflow runs without errors.""" - env = os.environ.copy() - env = { - **env, - "src_host_name": "localhost", - "src_user_name": "postgres", - "src_password": "password", - "src_db_name": "src", - "src_schema": "", - "dst_host_name": "localhost", - "dst_user_name": "postgres", - "dst_password": "password", - "dst_db_name": "dst", - } - - with open(self.orm_file_path, "wb") as file: - run(["sqlsynthgen", "make-tables"], stdout=file, env=env, check=True) - - with open(self.ssg_file_path, "wb") as file: - run( - ["sqlsynthgen", "make-generators", self.orm_file_path], - stdout=file, - env=env, - check=True, - ) - - run(["sqlsynthgen", "create-tables", self.orm_file_path], env=env, check=True) - run( - ["sqlsynthgen", "create-vocab", self.ssg_file_path], - env=env, - check=True, - ) - run( - ["sqlsynthgen", "create-data", self.orm_file_path, self.ssg_file_path, "1"], - env=env, - check=True, + completed_process = run( + ["sqlsynthgen", "make-tables"], + capture_output=True, + env=self.env, + ) + self.assertEqual(0, completed_process.returncode) + + completed_process = run( + ["sqlsynthgen", "make-generators"], + capture_output=True, + env=self.env, + ) + self.assertEqual("", completed_process.stderr.decode("utf-8")) + self.assertEqual(0, completed_process.returncode) + + completed_process = run( + ["sqlsynthgen", "create-tables"], + capture_output=True, + env=self.env, + ) + 
self.assertEqual("", completed_process.stderr.decode("utf-8")) + self.assertEqual(0, completed_process.returncode) + + completed_process = run( + ["sqlsynthgen", "create-vocab"], + capture_output=True, + env=self.env, + ) + self.assertEqual("", completed_process.stderr.decode("utf-8")) + self.assertEqual(0, completed_process.returncode) + + completed_process = run( + ["sqlsynthgen", "create-data"], + capture_output=True, + env=self.env, + ) + self.assertEqual("", completed_process.stderr.decode("utf-8")) + self.assertEqual(0, completed_process.returncode) + + def test_workflow_maximal_args(self) -> None: + """Test the CLI workflow runs with optional arguments.""" + + completed_process = run( + [ + "sqlsynthgen", + "make-tables", + f"--orm-file={self.alt_orm_file_path}", + ], + capture_output=True, + env=self.env, + ) + self.assertEqual( + "WARNING: Table without PK detected. sqlsynthgen may not be able to continue.\n", + completed_process.stderr.decode("utf-8"), + ) + self.assertEqual(0, completed_process.returncode) + + completed_process = run( + [ + "sqlsynthgen", + "make-generators", + f"--orm-file={self.alt_orm_file_path}", + f"--ssg-file={self.alt_ssg_file_path}", + "../examples/functional_conf.yaml", + ], + capture_output=True, + env=self.env, + ) + self.assertEqual("", completed_process.stderr.decode("utf-8")) + self.assertEqual(0, completed_process.returncode) + + completed_process = run( + [ + "sqlsynthgen", + "create-tables", + f"--orm-file={self.alt_orm_file_path}", + ], + capture_output=True, + env=self.env, + ) + self.assertEqual("", completed_process.stderr.decode("utf-8")) + self.assertEqual(0, completed_process.returncode) + + completed_process = run( + [ + "sqlsynthgen", + "create-vocab", + f"--ssg-file={self.alt_ssg_file_path}", + ], + capture_output=True, + env=self.env, + ) + self.assertEqual("", completed_process.stderr.decode("utf-8")) + self.assertEqual(0, completed_process.returncode) + + completed_process = run( + [ + "sqlsynthgen", + 
"create-data", + f"--orm-file={self.alt_orm_file_path}", + f"--ssg-file={self.alt_ssg_file_path}", + "--num-passes=2", + ], + capture_output=True, + env=self.env, ) + self.assertEqual("", completed_process.stderr.decode("utf-8")) + self.assertEqual(0, completed_process.returncode) diff --git a/tests/test_main.py b/tests/test_main.py index 98289abb..15a68dda 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,17 +1,15 @@ """Tests for the main module.""" -from subprocess import CalledProcessError +from io import StringIO from unittest import TestCase -from unittest.mock import call, patch +from unittest.mock import MagicMock, call, patch -import yaml from click.testing import Result from typer.testing import CliRunner from sqlsynthgen.main import app -from tests.examples import example_orm, expected_ssg from tests.utils import get_test_settings -runner = CliRunner() +runner = CliRunner(mix_stderr=False) class TestCLI(TestCase): @@ -24,173 +22,159 @@ def assertSuccess(self, result: Result) -> None: print(result.stdout) self.assertEqual(0, result.exit_code) - def test_make_tables(self) -> None: - """Test the make-tables sub-command.""" - - with patch("sqlsynthgen.main.run") as mock_run, patch( - "sqlsynthgen.main.get_settings" - ) as mock_get_settings: - mock_get_settings.return_value = get_test_settings() - mock_run.return_value.returncode = 0 - - result = runner.invoke( - app, - [ - "make-tables", - ], - catch_exceptions=False, - ) + @patch("sqlsynthgen.main.import_file") + @patch("sqlsynthgen.main.create_db_vocab") + def test_create_vocab(self, mock_create: MagicMock, mock_import: MagicMock) -> None: + """Test the create-vocab sub-command.""" + result = runner.invoke( + app, + [ + "create-vocab", + ], + catch_exceptions=False, + ) + mock_create.assert_called_once_with(mock_import.return_value.sorted_vocab) self.assertSuccess(result) - mock_run.assert_has_calls( + @patch("sqlsynthgen.main.import_file") + @patch("sqlsynthgen.main.Path") + 
@patch("sqlsynthgen.main.make_generators_from_tables") + def test_make_generators( + self, mock_make: MagicMock, mock_path: MagicMock, mock_import: MagicMock + ) -> None: + """Test the make-generators sub-command.""" + mock_path.return_value.exists.return_value = False + mock_make.return_value = "some text" + + result = runner.invoke( + app, [ - call( - [ - "sqlacodegen", - get_test_settings().src_postgres_dsn, - ], - capture_output=True, - encoding="utf-8", - check=True, - ), - ] + "make-generators", + ], + catch_exceptions=False, ) - self.assertNotEqual("", result.stdout) - def test_make_tables_with_schema(self) -> None: - """Test the make-tables sub-command handles the schema setting.""" - - with patch("sqlsynthgen.main.run") as mock_run, patch( - "sqlsynthgen.main.get_settings" - ) as mock_get_settings: - mock_get_settings.return_value = get_test_settings() - mock_get_settings.return_value.src_schema = "sschema" + mock_make.assert_called_once_with(mock_import.return_value, {}) + mock_path.return_value.write_text.assert_called_once_with( + "some text", encoding="utf-8" + ) + self.assertSuccess(result) - result = runner.invoke( - app, - [ - "make-tables", - ], - catch_exceptions=False, - ) + @patch("sqlsynthgen.main.Path") + @patch("sqlsynthgen.main.stderr", new_callable=StringIO) + def test_make_generators_errors_if_file_exists( + self, mock_stderr: MagicMock, mock_path: MagicMock + ) -> None: + """Test the make-tables sub-command doesn't overwrite.""" - self.assertSuccess(result) + mock_path.return_value.exists.return_value = True - mock_run.assert_has_calls( + result = runner.invoke( + app, [ - call( - [ - "sqlacodegen", - "--schema=sschema", - get_test_settings().src_postgres_dsn, - ], - capture_output=True, - encoding="utf-8", - check=True, - ), - ] + "make-generators", + ], + catch_exceptions=False, ) - self.assertNotEqual("", result.stdout) - - def test_make_tables_handles_errors(self) -> None: - """Test the make-tables sub-command handles sqlacodegen 
errors.""" - - with patch("sqlsynthgen.main.run") as mock_run, patch( - "sqlsynthgen.main.get_settings" - ) as mock_get_settings, patch("sqlsynthgen.main.stderr") as mock_stderr: - mock_run.side_effect = CalledProcessError( - returncode=99, cmd="some-cmd", stderr="some-error-output" - ) - mock_get_settings.return_value = get_test_settings() - - result = runner.invoke( - app, - [ - "make-tables", - ], - catch_exceptions=False, - ) - - self.assertEqual(99, result.exit_code) - mock_stderr.assert_has_calls( - [call.write("some-error-output"), call.write("\n")] + self.assertEqual( + "ssg.py should not already exist. Exiting...\n", mock_stderr.getvalue() ) + self.assertEqual(1, result.exit_code) + + @patch("sqlsynthgen.main.create_db_tables") + @patch("sqlsynthgen.main.import_file") + def test_create_tables( + self, mock_import: MagicMock, mock_create: MagicMock + ) -> None: + """Test the create-tables sub-command.""" - def test_make_tables_warns_no_pk(self) -> None: - """Test the make-tables sub-command warns about Tables().""" - with patch("sqlsynthgen.main.run") as mock_run, patch( - "sqlsynthgen.main.get_settings" - ) as mock_get_settings, patch("sqlsynthgen.main.stderr") as mock_stderr: - mock_get_settings.return_value = get_test_settings() - mock_run.return_value.stdout = "t_nopk_table = Table(" - - result = runner.invoke( - app, - [ - "make-tables", - ], - catch_exceptions=False, - ) - - self.assertEqual(0, result.exit_code) - mock_stderr.assert_has_calls( + result = runner.invoke( + app, [ - call.write( - "WARNING: Table without PK detected. sqlsynthgen may not be able to continue." 
- ), - call.write("\n"), - ] + "create-tables", + ], + catch_exceptions=False, ) - def test_make_generators(self) -> None: - """Test the make-generators sub-command.""" - with patch("sqlsynthgen.main.make_generators_from_tables") as mock_make: - conf_path = "tests/examples/generator_conf.yaml" - with open(conf_path, "r", encoding="utf8") as f: - config = yaml.safe_load(f) - result = runner.invoke( - app, - [ - "make-generators", - "tests/examples/example_orm.py", - conf_path, - ], - catch_exceptions=False, - ) - + mock_create.assert_called_once_with(mock_import.return_value.Base.metadata) self.assertSuccess(result) - mock_make.assert_called_once_with(example_orm, config) - def test_create_tables(self) -> None: - """Test the create-tables sub-command.""" + @patch("sqlsynthgen.main.import_file") + @patch("sqlsynthgen.main.create_db_data") + def test_create_data(self, mock_create: MagicMock, mock_import: MagicMock) -> None: + """Test the create-data sub-command.""" - with patch("sqlsynthgen.main.create_db_tables") as mock_create: - result = runner.invoke( - app, - ["create-tables", "tests/examples/example_orm.py"], - catch_exceptions=False, - ) + result = runner.invoke( + app, + [ + "create-data", + ], + catch_exceptions=False, + ) + self.assertListEqual( + [ + call("orm.py"), + call("ssg.py"), + ], + mock_import.call_args_list, + ) + mock_create.assert_called_once_with( + mock_import.return_value.Base.metadata.sorted_tables, + mock_import.return_value.sorted_generators, + 1, + ) self.assertSuccess(result) - mock_create.assert_called_once_with(example_orm.metadata) - def test_create_data(self) -> None: - """Test the create-data sub-command.""" + @patch("sqlsynthgen.main.Path") + @patch("sqlsynthgen.main.make_tables_file") + @patch("sqlsynthgen.main.get_settings") + def test_make_tables( + self, + mock_get_settings: MagicMock, + mock_make_tables_file: MagicMock, + mock_path: MagicMock, + ) -> None: + """Test the make-tables sub-command.""" - with 
patch("sqlsynthgen.main.create_db_data") as mock_create_db_data: - result = runner.invoke( - app, - [ - "create-data", - "tests/examples/example_orm.py", - "tests/examples/expected_ssg.py", - "10", - ], - catch_exceptions=False, - ) + mock_path.return_value.exists.return_value = False + mock_get_settings.return_value = get_test_settings() + mock_make_tables_file.return_value = "some text" + result = runner.invoke( + app, + [ + "make-tables", + ], + catch_exceptions=False, + ) + + mock_make_tables_file.assert_called_once_with( + "postgresql://suser:spassword@shost:5432/sdbname", None + ) + mock_path.return_value.write_text.assert_called_once_with( + "some text", encoding="utf-8" + ) self.assertSuccess(result) - mock_create_db_data.assert_called_once_with( - example_orm.metadata.sorted_tables, expected_ssg.sorted_generators, 10 + + @patch("sqlsynthgen.main.stderr", new_callable=StringIO) + @patch("sqlsynthgen.main.Path") + def test_make_tables_errors_if_file_exists( + self, mock_path: MagicMock, mock_stderr: MagicMock + ) -> None: + """Test the make-tables sub-command doesn't overwrite.""" + + mock_path.return_value.exists.return_value = True + + result = runner.invoke( + app, + [ + "make-tables", + ], + catch_exceptions=False, + ) + self.assertEqual( + "orm.py should not already exist. 
Exiting...\n", mock_stderr.getvalue() ) + self.assertEqual(1, result.exit_code) diff --git a/tests/test_make.py b/tests/test_make.py index 95407fbc..561baeca 100644 --- a/tests/test_make.py +++ b/tests/test_make.py @@ -1,50 +1,36 @@ """Tests for the main module.""" import os -from pathlib import Path -from unittest.mock import patch +from io import StringIO +from subprocess import CalledProcessError +from unittest import TestCase +from unittest.mock import MagicMock, call, patch import yaml -from sqlalchemy import Column, Integer, create_engine, insert -from sqlalchemy.orm import declarative_base from sqlsynthgen import make +from sqlsynthgen.make import make_tables_file from tests.examples import example_orm -from tests.utils import RequiresDBTestCase, run_psql +from tests.utils import SysExit -# pylint: disable=invalid-name -Base = declarative_base() -# pylint: enable=invalid-name -metadata = Base.metadata - -class MakeTable(Base): # type: ignore - """A SQLAlchemy table.""" - - __tablename__ = "maketable" - id = Column( - Integer, - primary_key=True, - ) - - -class MyTestCase(RequiresDBTestCase): - """Module test case.""" +class TestMake(TestCase): + """Tests that don't require a database.""" def setUp(self) -> None: """Pre-test setup.""" - run_psql("providers.dump") - - self.engine = create_engine( - "postgresql://postgres:password@localhost:5432/providers", - ) - metadata.create_all(self.engine) os.chdir("tests/examples") def tearDown(self) -> None: + """Post-test cleanup.""" os.chdir("../..") - def test_make_generators_from_tables(self) -> None: + @patch("sqlsynthgen.make.get_settings") + @patch("sqlsynthgen.make.create_engine") + @patch("sqlsynthgen.make.download_table") + def test_make_generators_from_tables( + self, mock_download: MagicMock, mock_create: MagicMock, _: MagicMock + ) -> None: """Check that we can make a generators file from a tables module.""" self.maxDiff = None # pylint: disable=invalid-name with open("expected_ssg.py", encoding="utf-8") as 
expected_output: @@ -53,27 +39,85 @@ def test_make_generators_from_tables(self) -> None: with open(conf_path, "r", encoding="utf8") as f: config = yaml.safe_load(f) - with patch("sqlsynthgen.make._download_table",) as mock_download, patch( - "sqlsynthgen.make.create_engine" - ) as mock_create_engine, patch("sqlsynthgen.make.get_settings"): actual = make.make_generators_from_tables(example_orm, config) mock_download.assert_called_once() - mock_create_engine.assert_called_once() + mock_create.assert_called_once() self.assertEqual(expected, actual) - def test__download_table(self) -> None: - """Test the _download_table function.""" - # pylint: disable=protected-access - with self.engine.connect() as conn: - conn.execute(insert(MakeTable).values({"id": 1})) + @patch("sqlsynthgen.make.run") + def test_make_tables_file(self, mock_run: MagicMock) -> None: + """Test the make_tables_file function.""" + + mock_run.return_value.stdout = "some output" + + make_tables_file("my:postgres/db", None) + + self.assertEqual( + call( + [ + "sqlacodegen", + "my:postgres/db", + ], + capture_output=True, + encoding="utf-8", + check=True, + ), + mock_run.call_args_list[0], + ) - make._download_table(MakeTable.__table__, self.engine) + @patch("sqlsynthgen.make.run") + def test_make_tables_file_with_schema(self, mock_run: MagicMock) -> None: + """Check that the function handles the schema setting.""" + + make_tables_file("my:postgres/db", "my_schema") + + self.assertEqual( + call( + [ + "sqlacodegen", + "--schema=my_schema", + "my:postgres/db", + ], + capture_output=True, + encoding="utf-8", + check=True, + ), + mock_run.call_args_list[0], + ) - with Path("expected.csv").open(encoding="utf-8") as csvfile: - expected = csvfile.read() + @patch("sys.exit") + @patch("sqlsynthgen.make.stderr", new_callable=StringIO) + @patch("sqlsynthgen.make.run") + def test_make_tables_handles_errors( + self, mock_run: MagicMock, mock_stderr: MagicMock, mock_exit: MagicMock + ) -> None: + """Test the 
make-tables sub-command handles sqlacodegen errors.""" - with Path("maketable.csv").open(encoding="utf-8") as csvfile: - actual = csvfile.read() + mock_run.side_effect = CalledProcessError( + returncode=99, cmd="some-cmd", stderr="some-error-output" + ) + mock_exit.side_effect = SysExit - self.assertEqual(expected, actual) + try: + make_tables_file("my:postgres/db", None) + except SysExit: + pass + + mock_exit.assert_called_once_with(99) + self.assertEqual("some-error-output\n", mock_stderr.getvalue()) + + @patch("sqlsynthgen.make.stderr", new_callable=StringIO) + @patch("sqlsynthgen.make.run") + def test_make_tables_warns_no_pk( + self, mock_run: MagicMock, mock_stderr: MagicMock + ) -> None: + """Test the make-tables sub-command warns about Tables().""" + + mock_run.return_value.stdout = "t_nopk_table = Table(" + make_tables_file("my:postgres/db", None) + + self.assertEqual( + "WARNING: Table without PK detected. sqlsynthgen may not be able to continue.\n", + mock_stderr.getvalue(), + ) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..2aef72d3 --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,111 @@ +"""Tests for the utils module.""" +import os +import sys +from io import StringIO +from pathlib import Path +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from sqlalchemy import Column, Integer, create_engine, insert +from sqlalchemy.orm import declarative_base + +from sqlsynthgen.utils import download_table, import_file +from tests.utils import RequiresDBTestCase, SysExit, run_psql + +# pylint: disable=invalid-name +Base = declarative_base() +# pylint: enable=invalid-name +metadata = Base.metadata + + +class MyTable(Base): # type: ignore + """A SQLAlchemy model.""" + + __tablename__ = "mytable" + id = Column( + Integer, + primary_key=True, + ) + + +class TestImport(TestCase): + """Tests for the import_file function.""" + + def setUp(self) -> None: + """Pre-test setup.""" + + 
os.chdir("tests/examples") + + def tearDown(self) -> None: + os.chdir("../../") + + def test_import_file(self) -> None: + """Test that we can import an example module.""" + old_path = sys.path.copy() + module = import_file("import_test.py") + self.assertEqual(10, module.x) + + self.assertEqual(old_path, sys.path) + + +class TestDownload(RequiresDBTestCase): + """Tests for the download_table function.""" + + mytable_file_path = Path("mytable.csv") + + def setUp(self) -> None: + """Pre-test setup.""" + + run_psql("providers.dump") + + self.engine = create_engine( + "postgresql://postgres:password@localhost:5432/providers", + connect_args={"connect_timeout": 10}, + ) + metadata.create_all(self.engine) + + os.chdir("tests/workspace") + self.mytable_file_path.unlink(missing_ok=True) + + def tearDown(self) -> None: + """Post-test cleanup.""" + os.chdir("../..") + + def test_download_table(self) -> None: + """Test the download_table function.""" + # pylint: disable=protected-access + + with self.engine.connect() as conn: + conn.execute(insert(MyTable).values({"id": 1})) + + download_table(MyTable.__table__, self.engine) + + with Path("../examples/expected.csv").open(encoding="utf-8") as csvfile: + expected = csvfile.read() + + with self.mytable_file_path.open(encoding="utf-8") as csvfile: + actual = csvfile.read() + + self.assertEqual(expected, actual) + + @patch("sys.exit") + @patch("sqlsynthgen.utils.stderr", new_callable=StringIO) + @patch("sqlsynthgen.utils.Path") + def test_download_table_does_not_overwrite( + self, mock_path: MagicMock, mock_stderr: MagicMock, mock_exit: MagicMock + ) -> None: + """Test the download_table function.""" + # pylint: disable=protected-access + + mock_exit.side_effect = SysExit + mock_path.return_value.exists.return_value = True + + try: + download_table(MyTable.__table__, None) + except SysExit: + pass + + self.assertEqual( + "mytable.csv already exists. 
Exiting...\n", mock_stderr.getvalue() + ) + mock_exit.assert_called_once_with(1) diff --git a/tests/utils.py b/tests/utils.py index e7e9cefb..c0613698 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -8,6 +8,10 @@ from sqlsynthgen import settings +class SysExit(Exception): + """To force the function to exit as sys.exit() would.""" + + @lru_cache(1) def get_test_settings() -> settings.Settings: """Get a Settings object that ignores .env files and environment variables.""" @@ -27,7 +31,7 @@ def get_test_settings() -> settings.Settings: def run_psql(dump_file_name: str) -> None: - """Run psql and""" + """Run psql and pass dump_file_name as the --file option.""" # If you need to update a .dump file, use # pg_dump -d DBNAME -h localhost -U postgres -C -c > tests/examples/FILENAME.dump diff --git a/tests/tmp/.gitignore b/tests/workspace/.gitignore similarity index 100% rename from tests/tmp/.gitignore rename to tests/workspace/.gitignore diff --git a/tests/workspace/README.md b/tests/workspace/README.md new file mode 100644 index 00000000..8165a69a --- /dev/null +++ b/tests/workspace/README.md @@ -0,0 +1,3 @@ +# Test Workspace + +A workspace for the functional tests to run in. diff --git a/tests/examples/custom_generators.py b/tests/workspace/custom_generators.py similarity index 100% rename from tests/examples/custom_generators.py rename to tests/workspace/custom_generators.py