diff --git a/poetry.lock b/poetry.lock index 308ea3a1..d0d28d93 100644 --- a/poetry.lock +++ b/poetry.lock @@ -40,7 +40,7 @@ uvloop = ["uvloop (>=0.15.2)"] name = "click" version = "8.1.3" description = "Composable command line interface toolkit" -category = "dev" +category = "main" optional = false python-versions = ">=3.7" @@ -51,7 +51,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} name = "colorama" version = "0.4.6" description = "Cross-platform colored terminal text." -category = "dev" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" @@ -324,6 +324,23 @@ category = "dev" optional = false python-versions = ">=3.6" +[[package]] +name = "typer" +version = "0.7.0" +description = "Typer, build great CLIs. Easy to code. Based on Python type hints." +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +click = ">=7.1.1,<9.0.0" + +[package.extras] +all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] +dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"] +doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"] +test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.910)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<13.0.0)", "shellingham (>=1.3.0,<2.0.0)"] + [[package]] name = "typing-extensions" version = "4.4.0" @@ -343,7 +360,7 @@ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,>=2.7" [metadata] lock-version = "1.1" python-versions = "^3.9" -content-hash = "4f9a87e5fd49a9594c3d6dba5ac06bfb953a95c91014d8662e2d85af8982dad3" +content-hash = "37881e345d4aa2d0100a7c161c1daa7e1a476125d3b6a36a3f278813881d3a46" [metadata.files] astroid = [ @@ -697,6 +714,10 @@ tomlkit = [ {file = "tomlkit-0.11.6-py3-none-any.whl", hash = "sha256:07de26b0d8cfc18f871aec595fda24d95b08fef89d147caa861939f37230bf4b"}, {file = "tomlkit-0.11.6.tar.gz", hash = "sha256:71b952e5721688937fb02cf9d354dbcf0785066149d2855e44531ebdd2b65d73"}, ] +typer = [ + {file = "typer-0.7.0-py3-none-any.whl", hash = "sha256:b5e704f4e48ec263de1c0b3a2387cd405a13767d2f907f44c1a08cbad96f606d"}, + {file = "typer-0.7.0.tar.gz", hash = "sha256:ff797846578a9f2a201b53442aedeb543319466870fbe1c701eab66dd7681165"}, +] typing-extensions = [ {file = "typing_extensions-4.4.0-py3-none-any.whl", hash = "sha256:16fa4864408f655d35ec496218b85f79b3437c829e93320c7c9215ccfd92489e"}, {file = "typing_extensions-4.4.0.tar.gz", hash = "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa"}, diff --git a/pyproject.toml b/pyproject.toml index 7c5b6d2c..4c7bf1c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ pydantic = {extras = ["dotenv"], version = "^1.10.2"} psycopg2-binary = "^2.9.5" sqlalchemy-utils = "^0.38.3" mimesis = "^6.1.1" +typer = "^0.7.0" [tool.poetry.group.dev.dependencies] @@ -26,7 +27,7 @@ requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" [tool.poetry.scripts] -sqlsynthgen = "sqlsynthgen.main:main" +sqlsynthgen = "sqlsynthgen.main:app" [tool.isort] profile = "black" diff --git a/sqlsynthgen/create.py b/sqlsynthgen/create.py new file mode 100644 index 00000000..ca8cb37e --- /dev/null +++ b/sqlsynthgen/create.py @@ -0,0 +1,30 @@ +"""Functions and classes to create and populate the target database.""" +from typing import Any + +from sqlalchemy import create_engine, insert + +from sqlsynthgen.settings import get_settings + + +def create_db_tables(metadata: Any) -> Any: + """Create tables described by the sqlalchemy metadata object.""" + settings = get_settings() + engine = create_engine(settings.dst_postgres_dsn) + metadata.create_all(engine) + + +def generate(sorted_tables: list, sorted_generators: list) -> Any: + """Connect to a database and populate it with data.""" + settings = get_settings() + engine = create_engine(settings.dst_postgres_dsn) + + with engine.connect() as conn: + populate(conn, sorted_tables, sorted_generators) + + +def populate(conn: Any, tables: list, generators: list) -> None: + """Populate a database schema with dummy data.""" + + for table, generator in zip(tables, generators): + stmt = insert(table).values(generator(conn).__dict__) + conn.execute(stmt) diff --git a/sqlsynthgen/main.py b/sqlsynthgen/main.py index 391ce867..1e5e9f82 100644 --- a/sqlsynthgen/main.py +++ b/sqlsynthgen/main.py @@ -1,39 +1,43 @@ """Entrypoint for the sqlsynthgen package.""" -from typing import Any +from subprocess import run -from sqlalchemy import create_engine, insert +import typer from sqlsynthgen.settings import get_settings +app = typer.Typer() -def create_tables(metadata: Any) -> Any: - """Create tables described by the sqlalchemy metadata object.""" - settings = get_settings() - engine = create_engine(settings.postgres_dsn) - metadata.create_all(engine) + +@app.command() +def create_data() -> None: + """Fill tables with synthetic data.""" + + +@app.command() +def create_tables() -> None: + """Create tables using the SQLAlchemy file.""" -def main() -> None: - """Not implemented yet.""" - raise NotImplementedError +@app.command() +def make_generators() -> None: + """Make a SQLSynthGun file of generator classes.""" -def generate(sorted_tables: list, sorted_generators: list) -> Any: - """Connect to a database and populate it with data.""" +@app.command() +def make_tables() -> None: + """Make a SQLAlchemy file of Table classes.""" settings = get_settings() - engine = create_engine(settings.postgres_dsn) - with engine.connect() as conn: - populate(conn, sorted_tables, sorted_generators) + command = ["sqlacodegen"] + if settings.src_schema: + command.append(f"--schema={settings.src_schema}") -def populate(conn: Any, tables: list, generators: list) -> None: - """Populate a database schema with dummy data.""" + command.append(str(get_settings().src_postgres_dsn)) - for table, generator in zip(tables, generators): - stmt = insert(table).values(generator(conn).__dict__) - conn.execute(stmt) + completed_process = run(command, capture_output=True, encoding="utf-8", check=True) + print(completed_process.stdout) if __name__ == "__main__": - main() + app() diff --git a/sqlsynthgen/settings.py b/sqlsynthgen/settings.py index fae3c5e4..6a482ef2 100644 --- a/sqlsynthgen/settings.py +++ b/sqlsynthgen/settings.py @@ -10,35 +10,55 @@ class Settings(BaseSettings): """A Pydantic settings class with optional and mandatory settings.""" - # Connection parameters for a PostgreSQL database. See also, + # Connection parameters for the source PostgreSQL database. See also # https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-PARAMKEYWORDS - db_host_name: str # e.g. "mydb.mydomain.com" or "0.0.0.0" - db_port: int = 5432 - db_user_name: str # e.g. "postgres" or "myuser@mydb" - db_password: str - db_name: str = "" # leave empty to get the user's default db - ssl_required: bool = False # whether the db requires SSL - - # postgres_dsn is calculated so do not provide it explicitly - postgres_dsn: Optional[PostgresDsn] - - @validator("postgres_dsn", pre=True) - def validate_postgres_dsn(cls, _: Optional[PostgresDsn], values: Any) -> str: + src_host_name: str # e.g. "mydb.mydomain.com" or "0.0.0.0" + src_port: int = 5432 + src_user_name: str # e.g. "postgres" or "myuser@mydb" + src_password: str + src_db_name: str = "" # leave empty to get the user's default db + src_ssl_required: bool = False # whether the db requires SSL + src_schema: Optional[str] + + # Connection parameters for the destination PostgreSQL database. + dst_host_name: str # e.g. "mydb.mydomain.com" or "0.0.0.0" + dst_port: int = 5432 + dst_user_name: str # e.g. "postgres" or "myuser@mydb" + dst_password: str + dst_db_name: str = "" # leave empty to get the user's default db + dst_ssl_required: bool = False # whether the db requires SSL + + # These are calculated so do not provide them explicitly + src_postgres_dsn: Optional[PostgresDsn] + dst_postgres_dsn: Optional[PostgresDsn] + + @validator("src_postgres_dsn", pre=True) + def validate_src_postgres_dsn(cls, _: Optional[PostgresDsn], values: Any) -> str: + """Create and validate the source database DSN.""" + return cls.check_postgres_dsn(_, values, "src") + + @validator("dst_postgres_dsn", pre=True) + def validate_dst_postgres_dsn(cls, _: Optional[PostgresDsn], values: Any) -> str: + """Create and validate the destination database DSN.""" + return cls.check_postgres_dsn(_, values, "dst") + + @staticmethod + def check_postgres_dsn(_: Optional[PostgresDsn], values: Any, prefix: str) -> str: """Build a DSN string from the host, db name, port, username and password.""" # We want to build the Data Source Name ourselves so none should be provided if _: raise ValueError("postgres_dsn should not be provided") - user = values["db_user_name"] - password = values["db_password"] - host = values["db_host_name"] - port = values["db_port"] - db_name = values["db_name"] + user = values[f"{prefix}_user_name"] + password = values[f"{prefix}_password"] + host = values[f"{prefix}_host_name"] + port = values[f"{prefix}_port"] + db_name = values[f"{prefix}_db_name"] dsn = f"postgresql://{user}:{password}@{host}:{port}/{db_name}" - if values["ssl_required"]: + if values[f"{prefix}_ssl_required"]: return dsn + "?sslmode=require" return dsn diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_create.py b/tests/test_create.py new file mode 100644 index 00000000..852f2afe --- /dev/null +++ b/tests/test_create.py @@ -0,0 +1,38 @@ +"""Tests for the main module.""" +from unittest import TestCase +from unittest.mock import MagicMock, patch + +from sqlsynthgen.create import create_db_tables, generate +from tests.utils import get_test_settings + + +class MyTestCase(TestCase): + """Module test case.""" + + def test_generate(self) -> None: + """Test the generate function.""" + with patch("sqlsynthgen.create.populate") as mock_populate, patch( + "sqlsynthgen.create.get_settings" + ) as mock_get_settings, patch( + "sqlsynthgen.create.create_engine" + ) as mock_create_engine: + mock_get_settings.return_value = get_test_settings() + + generate([], []) + + mock_populate.assert_called_once() + mock_create_engine.assert_called_once() + + def test_create_tables(self) -> None: + """Test the create_tables function.""" + mock_meta = MagicMock() + + with patch("sqlsynthgen.create.create_engine") as mock_create_engine, patch( + "sqlsynthgen.create.get_settings" + ) as mock_get_settings: + + create_db_tables(mock_meta) + mock_get_settings.assert_called_once() + mock_create_engine.assert_called_once_with( + mock_get_settings.return_value.dst_postgres_dsn + ) diff --git a/tests/test_main.py b/tests/test_main.py index 464e55b5..62f09469 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,60 +1,127 @@ """Tests for the main module.""" -from functools import lru_cache from unittest import TestCase -from unittest.mock import MagicMock, patch +from unittest.mock import call, patch -from sqlsynthgen import main, settings -from sqlsynthgen.main import create_tables +from click.testing import Result +from typer.testing import CliRunner +from sqlsynthgen.main import app +from tests.utils import get_test_settings -@lru_cache(1) -def get_test_settings() -> settings.Settings: - """Get a Settings object that ignores .env files and environment variables.""" - return settings.Settings( - db_host_name="db_host_name", - db_user_name="db_user_name", - db_password="db_password", - db_name="db_name", - _env_file=None, - ) +runner = CliRunner() -class MyTestCase(TestCase): - """Module test case.""" +class TestCLI(TestCase): + """Tests for the command-line interface.""" - def test_main(self) -> None: - """Test the main function.""" - with patch("sqlsynthgen.main.populate"), patch( + def assertSuccess(self, result: Result) -> None: + """Give details and raise if the result isn't good.""" + # pylint: disable=invalid-name + if result.exit_code != 0: + print(result.stdout) + self.assertEqual(0, result.exit_code) + + def test_make_tables(self) -> None: + """Test the make-tables sub-command.""" + + with patch("sqlsynthgen.main.run") as mock_run, patch( "sqlsynthgen.main.get_settings" ) as mock_get_settings: mock_get_settings.return_value = get_test_settings() - with self.assertRaises(NotImplementedError): - main.main() + mock_run.return_value.returncode = 0 - def test_generate(self) -> None: - """Test the generate function.""" - with patch("sqlsynthgen.main.populate") as mock_populate, patch( - "sqlsynthgen.main.get_settings" - ) as mock_get_settings, patch( - "sqlsynthgen.main.create_engine" - ) as mock_create_engine: - mock_get_settings.return_value = get_test_settings() + result = runner.invoke( + app, + [ + "make-tables", + ], + catch_exceptions=False, + ) - main.generate([], []) + self.assertSuccess(result) - mock_populate.assert_called_once() - mock_create_engine.assert_called_once() + mock_run.assert_has_calls( + [ + call( + [ + "sqlacodegen", + get_test_settings().src_postgres_dsn, + ], + capture_output=True, + encoding="utf-8", + check=True, + ), + ] + ) + self.assertNotEqual("", result.stdout) - def test_create_tables(self) -> None: - """Test the create_tables function.""" - mock_meta = MagicMock() + def test_make_tables_with_schema(self) -> None: + """Test the make-tables sub-command handles the schema setting.""" - with patch("sqlsynthgen.main.create_engine") as mock_create_engine, patch( + with patch("sqlsynthgen.main.run") as mock_run, patch( "sqlsynthgen.main.get_settings" ) as mock_get_settings: + mock_get_settings.return_value = get_test_settings() + mock_get_settings.return_value.src_schema = "sschema" - create_tables(mock_meta) - mock_get_settings.assert_called_once() - mock_create_engine.assert_called_once_with( - mock_get_settings.return_value.postgres_dsn + result = runner.invoke( + app, + [ + "make-tables", + ], + catch_exceptions=False, ) + + self.assertSuccess(result) + + mock_run.assert_has_calls( + [ + call( + [ + "sqlacodegen", + "--schema=sschema", + get_test_settings().src_postgres_dsn, + ], + capture_output=True, + encoding="utf-8", + check=True, + ), + ] + ) + self.assertNotEqual("", result.stdout) + + def test_make_generators(self) -> None: + """Test the make-generators sub-command.""" + result = runner.invoke( + app, + [ + "make-generators", + ], + catch_exceptions=False, + ) + + self.assertSuccess(result) + + def test_create_tables(self) -> None: + """Test the create-tables sub-command.""" + result = runner.invoke( + app, + [ + "create-tables", + ], + catch_exceptions=False, + ) + + self.assertSuccess(result) + + def test_create_data(self) -> None: + """Test the create-data sub-command.""" + result = runner.invoke( + app, + [ + "create-data", + ], + catch_exceptions=False, + ) + + self.assertSuccess(result) diff --git a/tests/test_settings.py b/tests/test_settings.py new file mode 100644 index 00000000..8fc70796 --- /dev/null +++ b/tests/test_settings.py @@ -0,0 +1,59 @@ +"""Tests for the settings module.""" +from unittest import TestCase + +from sqlsynthgen.settings import Settings + + +class TestSettings(TestCase): + """Tests for the Settings class.""" + + def test_default_settings(self) -> None: + """Test the minimal settings.""" + settings = Settings( + src_host_name="shost", + src_user_name="suser", + src_password="spassword", + dst_host_name="dhost", + dst_user_name="duser", + dst_password="dpassword", + # To stop any local .env files influencing the test + _env_file=None, + ) + + self.assertEqual( + "postgresql://suser:spassword@shost:5432/", str(settings.src_postgres_dsn) + ) + self.assertIsNone(settings.src_schema) + + self.assertEqual( + "postgresql://duser:dpassword@dhost:5432/", str(settings.dst_postgres_dsn) + ) + + def test_maximal_settings(self) -> None: + """Test the full settings.""" + settings = Settings( + src_host_name="shost", + src_port=1234, + src_user_name="suser", + src_password="spassword", + src_db_name="sdbname", + src_ssl_required=True, + dst_host_name="dhost", + dst_port=4321, + dst_user_name="duser", + dst_password="dpassword", + dst_db_name="ddbname", + dst_ssl_required=True, + # To stop any local .env files influencing the test + _env_file=None, + ) + + self.assertEqual( + "postgresql://suser:spassword@shost:1234/sdbname?sslmode=require", + str(settings.src_postgres_dsn), + ) + + self.assertEqual( + "postgresql://duser:dpassword@dhost:4321/ddbname?sslmode=require", + str(settings.dst_postgres_dsn), + ) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..c40fd523 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,20 @@ +"""Utilities for testing.""" +from functools import lru_cache + +from sqlsynthgen import settings + + +@lru_cache(1) +def get_test_settings() -> settings.Settings: + """Get a Settings object that ignores .env files and environment variables.""" + + return settings.Settings( + src_host_name="shost", + src_user_name="suser", + src_password="spassword", + dst_host_name="dhost", + dst_user_name="duser", + dst_password="dpassword", + # To stop any local .env files influencing the test + _env_file=None, + )