diff --git a/.gitignore b/.gitignore index c57cbc3a..c8b0d574 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ # Custom .vscode .idea +.DS_Store # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/.pylintrc b/.pylintrc index 40746f3e..89698404 100644 --- a/.pylintrc +++ b/.pylintrc @@ -387,7 +387,9 @@ good-names=i, k, ex, Run, - _ + _, + e, + f # Good variable names regexes, separated by a comma. If names match any regex, # they will always be accepted good-names-rgxs= diff --git a/sqlsynthgen/create.py b/sqlsynthgen/create.py index ca8cb37e..50b3aa81 100644 --- a/sqlsynthgen/create.py +++ b/sqlsynthgen/create.py @@ -13,7 +13,7 @@ def create_db_tables(metadata: Any) -> Any: metadata.create_all(engine) -def generate(sorted_tables: list, sorted_generators: list) -> Any: +def create_db_data(sorted_tables: list, sorted_generators: list) -> None: """Connect to a database and populate it with data.""" settings = get_settings() engine = create_engine(settings.dst_postgres_dsn) diff --git a/sqlsynthgen/main.py b/sqlsynthgen/main.py index 1e5e9f82..42c84bb9 100644 --- a/sqlsynthgen/main.py +++ b/sqlsynthgen/main.py @@ -1,26 +1,51 @@ -"""Entrypoint for the sqlsynthgen package.""" -from subprocess import run +"""Entrypoint for the SQLSynthGen package.""" +import sys +from importlib import import_module +from pathlib import Path +from subprocess import CalledProcessError, run +from sys import stderr +from types import ModuleType import typer +from sqlsynthgen.create import create_db_data, create_db_tables +from sqlsynthgen.make import make_generators_from_tables from sqlsynthgen.settings import get_settings app = typer.Typer() +def import_file(file_path: str) -> ModuleType: + """Import a file given a relative path.""" + file_path_path = Path(file_path) + module_path = ".".join(file_path_path.parts[:-1] + (file_path_path.stem,)) + return import_module(module_path) + + @app.command() -def create_data() -> None: +def create_data( + orm_file: str = typer.Argument(...), + ssg_file: str = typer.Argument(...), +) -> None: """Fill tables with synthetic data.""" + orm_module = import_file(orm_file) + ssg_module = import_file(ssg_file) + create_db_data(orm_module.metadata.sorted_tables, ssg_module.sorted_generators) @app.command() -def create_tables() -> None: +def create_tables(orm_file: str = typer.Argument(...)) -> None: """Create tables using the SQLAlchemy file.""" + orm_module = import_file(orm_file) + create_db_tables(orm_module.metadata) @app.command() -def make_generators() -> None: - """Make a SQLSynthGun file of generator classes.""" +def make_generators(orm_file: str = typer.Argument(...)) -> None: + """Make a SQLSynthGen file of generator classes.""" + orm_module = import_file(orm_file) + result = make_generators_from_tables(orm_module) + print(result) @app.command() @@ -35,7 +60,14 @@ def make_tables() -> None: command.append(str(get_settings().src_postgres_dsn)) - completed_process = run(command, capture_output=True, encoding="utf-8", check=True) + try: + completed_process = run( + command, capture_output=True, encoding="utf-8", check=True + ) + except CalledProcessError as e: + print(e.stderr, file=stderr) + sys.exit(e.returncode) + print(completed_process.stdout) diff --git a/sqlsynthgen/create_generators.py b/sqlsynthgen/make.py similarity index 89% rename from sqlsynthgen/create_generators.py rename to sqlsynthgen/make.py index 5bdfff74..66a86530 100644 --- a/sqlsynthgen/create_generators.py +++ b/sqlsynthgen/make.py @@ -1,5 +1,5 @@ -"""Functions to create a module of generator classes.""" -import importlib +"""Functions to make a module of generator classes.""" +from types import ModuleType from typing import Final from sqlalchemy.sql import sqltypes @@ -21,12 +21,11 @@ INDENTATION: Final[str] = " " * 4 -def create_generators_from_tables(tables_module_name: str) -> str: +def make_generators_from_tables(tables_module: ModuleType) -> str: """Creates sqlsynthgen generator classes from a sqlacodegen-generated file. Args: - tables_module_name: The name of a sqlacodegen-generated module - as you would provide to importlib.import_module. + tables_module: A sqlacodegen-generated module. Returns: A string that is a valid Python module, once written to file. @@ -47,7 +46,6 @@ def create_generators_from_tables(tables_module_name: str) -> str: sqltypes.LargeBinary: "generic.binary_provider.bytes()", } - tables_module = importlib.import_module(tables_module_name) for table in tables_module.metadata.sorted_tables: new_class_name = table.name + "Generator" sorted_generators += INDENTATION + new_class_name + ",\n" diff --git a/sqlsynthgen/settings.py b/sqlsynthgen/settings.py index 6a482ef2..8e8918d3 100644 --- a/sqlsynthgen/settings.py +++ b/sqlsynthgen/settings.py @@ -16,7 +16,7 @@ class Settings(BaseSettings): src_port: int = 5432 src_user_name: str # e.g. "postgres" or "myuser@mydb" src_password: str - src_db_name: str = "" # leave empty to get the user's default db + src_db_name: str src_ssl_required: bool = False # whether the db requires SSL src_schema: Optional[str] @@ -25,7 +25,7 @@ class Settings(BaseSettings): dst_port: int = 5432 dst_user_name: str # e.g. "postgres" or "myuser@mydb" dst_password: str - dst_db_name: str = "" # leave empty to get the user's default db + dst_db_name: str dst_ssl_required: bool = False # whether the db requires SSL # These are calculated so do not provide them explicitly diff --git a/tests/test_create.py b/tests/test_create.py index 852f2afe..788d7a96 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -2,14 +2,14 @@ from unittest import TestCase from unittest.mock import MagicMock, patch -from sqlsynthgen.create import create_db_tables, generate +from sqlsynthgen.create import create_db_data, create_db_tables from tests.utils import get_test_settings class MyTestCase(TestCase): """Module test case.""" - def test_generate(self) -> None: + def test_create_db_data(self) -> None: """Test the generate function.""" with patch("sqlsynthgen.create.populate") as mock_populate, patch( "sqlsynthgen.create.get_settings" @@ -18,12 +18,12 @@ def test_generate(self) -> None: ) as mock_create_engine: mock_get_settings.return_value = get_test_settings() - generate([], []) + create_db_data([], []) mock_populate.assert_called_once() mock_create_engine.assert_called_once() - def test_create_tables(self) -> None: + def test_create_db_tables(self) -> None: """Test the create_tables function.""" mock_meta = MagicMock() diff --git a/tests/test_main.py b/tests/test_main.py index 62f09469..eadc391d 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,4 +1,5 @@ """Tests for the main module.""" +from subprocess import CalledProcessError from unittest import TestCase from unittest.mock import call, patch @@ -6,6 +7,7 @@ from typer.testing import CliRunner from sqlsynthgen.main import app +from tests.examples import example_tables, expected_output from tests.utils import get_test_settings runner = CliRunner() @@ -90,38 +92,70 @@ def test_make_tables_with_schema(self) -> None: ) self.assertNotEqual("", result.stdout) + def test_make_tables_handles_errors(self) -> None: + """Test the make-tables sub-command handles sqlacodegen errors.""" + + with patch("sqlsynthgen.main.run") as mock_run, patch( + "sqlsynthgen.main.get_settings" + ) as mock_get_settings, patch("sqlsynthgen.main.stderr") as mock_stderr: + mock_run.side_effect = CalledProcessError( + returncode=99, cmd="some-cmd", stderr="some-error-output" + ) + mock_get_settings.return_value = get_test_settings() + + result = runner.invoke( + app, + [ + "make-tables", + ], + catch_exceptions=False, + ) + + self.assertEqual(99, result.exit_code) + mock_stderr.assert_has_calls( + [call.write("some-error-output"), call.write("\n")] + ) + def test_make_generators(self) -> None: """Test the make-generators sub-command.""" - result = runner.invoke( - app, - [ - "make-generators", - ], - catch_exceptions=False, - ) + with patch("sqlsynthgen.main.make_generators_from_tables") as mock_make: + result = runner.invoke( + app, + ["make-generators", "tests/examples/example_tables.py"], + catch_exceptions=False, + ) self.assertSuccess(result) + mock_make.assert_called_once_with(example_tables) def test_create_tables(self) -> None: """Test the create-tables sub-command.""" - result = runner.invoke( - app, - [ - "create-tables", - ], - catch_exceptions=False, - ) + + with patch("sqlsynthgen.main.create_db_tables") as mock_create: + result = runner.invoke( + app, + ["create-tables", "tests/examples/example_tables.py"], + catch_exceptions=False, + ) self.assertSuccess(result) + mock_create.assert_called_once_with(example_tables.metadata) def test_create_data(self) -> None: """Test the create-data sub-command.""" - result = runner.invoke( - app, - [ - "create-data", - ], - catch_exceptions=False, - ) + + with patch("sqlsynthgen.main.create_db_data") as mock_create_db_data: + result = runner.invoke( + app, + [ + "create-data", + "tests/examples/example_tables.py", + "tests/examples/expected_output.py", + ], + catch_exceptions=False, + ) self.assertSuccess(result) + mock_create_db_data.assert_called_once_with( + example_tables.metadata.sorted_tables, expected_output.sorted_generators + ) diff --git a/tests/test_create_generators.py b/tests/test_make.py similarity index 53% rename from tests/test_create_generators.py rename to tests/test_make.py index 7456a08d..08f5cca7 100644 --- a/tests/test_create_generators.py +++ b/tests/test_make.py @@ -1,20 +1,20 @@ """Tests for the main module.""" from unittest import TestCase -from sqlsynthgen import create_generators +from sqlsynthgen import make +from tests.examples import example_tables class MyTestCase(TestCase): """Module test case.""" - def test_generators_from_tables(self) -> None: - """Check that we can create a generators file from a tables file.""" + def test_make_generators_from_tables(self) -> None: + """Check that we can make a generators file from a tables module.""" + with open( "tests/examples/expected_output.py", encoding="utf-8" ) as expected_output: expected = expected_output.read() - actual = create_generators.create_generators_from_tables( - "tests.examples.example_tables" - ) + actual = make.make_generators_from_tables(example_tables) self.assertEqual(expected, actual) diff --git a/tests/test_settings.py b/tests/test_settings.py index 8fc70796..43608a15 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -13,20 +13,24 @@ def test_default_settings(self) -> None: src_host_name="shost", src_user_name="suser", src_password="spassword", + src_db_name="sdbname", dst_host_name="dhost", dst_user_name="duser", dst_password="dpassword", + dst_db_name="ddbname", # To stop any local .env files influencing the test _env_file=None, ) self.assertEqual( - "postgresql://suser:spassword@shost:5432/", str(settings.src_postgres_dsn) + "postgresql://suser:spassword@shost:5432/sdbname", + str(settings.src_postgres_dsn), ) self.assertIsNone(settings.src_schema) self.assertEqual( - "postgresql://duser:dpassword@dhost:5432/", str(settings.dst_postgres_dsn) + "postgresql://duser:dpassword@dhost:5432/ddbname", + str(settings.dst_postgres_dsn), ) def test_maximal_settings(self) -> None: diff --git a/tests/utils.py b/tests/utils.py index c40fd523..8635d78b 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -12,9 +12,11 @@ def get_test_settings() -> settings.Settings: src_host_name="shost", src_user_name="suser", src_password="spassword", + src_db_name="sdbname", dst_host_name="dhost", dst_user_name="duser", dst_password="dpassword", + dst_db_name="ddbname", # To stop any local .env files influencing the test _env_file=None, )