From 3f915edc013085433adbd4e9836d041be45b1009 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 18 Jan 2023 14:52:05 +0000 Subject: [PATCH 1/2] Create a new schema is necessary --- sqlsynthgen/create.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/sqlsynthgen/create.py b/sqlsynthgen/create.py index 50b3aa81..fe16c405 100644 --- a/sqlsynthgen/create.py +++ b/sqlsynthgen/create.py @@ -2,6 +2,7 @@ from typing import Any from sqlalchemy import create_engine, insert +from sqlalchemy.schema import CreateSchema from sqlsynthgen.settings import get_settings @@ -10,6 +11,15 @@ def create_db_tables(metadata: Any) -> Any: """Create tables described by the sqlalchemy metadata object.""" settings = get_settings() engine = create_engine(settings.dst_postgres_dsn) + # Create schemas, if necessary. + for table in metadata.sorted_tables: + try: + schema = table.schema + if not engine.dialect.has_schema(engine, schema=schema): + engine.execute(CreateSchema(schema, if_not_exists=True)) + except AttributeError: + # This table didn't have a schema field + pass metadata.create_all(engine) From ac8c0034f54bfacde27ed6e56596c500c4f8a83c Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 18 Jan 2023 14:53:00 +0000 Subject: [PATCH 2/2] Add a new argument for how many rows to add --- sqlsynthgen/create.py | 13 ++++++++----- sqlsynthgen/main.py | 5 ++++- tests/test_create.py | 2 +- tests/test_main.py | 3 ++- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/sqlsynthgen/create.py b/sqlsynthgen/create.py index fe16c405..6e7d9c2c 100644 --- a/sqlsynthgen/create.py +++ b/sqlsynthgen/create.py @@ -23,18 +23,21 @@ def create_db_tables(metadata: Any) -> Any: metadata.create_all(engine) -def create_db_data(sorted_tables: list, sorted_generators: list) -> None: +def create_db_data(sorted_tables: list, sorted_generators: list, num_rows: int) -> None: """Connect to a database and populate it with data.""" settings = get_settings() engine = create_engine(settings.dst_postgres_dsn) with engine.connect() as conn: - populate(conn, sorted_tables, sorted_generators) + populate(conn, sorted_tables, sorted_generators, num_rows) -def populate(conn: Any, tables: list, generators: list) -> None: +def populate(conn: Any, tables: list, generators: list, num_rows: int) -> None: """Populate a database schema with dummy data.""" for table, generator in zip(tables, generators): - stmt = insert(table).values(generator(conn).__dict__) - conn.execute(stmt) + # Run all the inserts for one table in a transaction + with conn.begin(): + for _ in range(num_rows): + stmt = insert(table).values(generator(conn).__dict__) + conn.execute(stmt) diff --git a/sqlsynthgen/main.py b/sqlsynthgen/main.py index 42c84bb9..1904617f 100644 --- a/sqlsynthgen/main.py +++ b/sqlsynthgen/main.py @@ -26,11 +26,14 @@ def import_file(file_path: str) -> ModuleType: def create_data( orm_file: str = typer.Argument(...), ssg_file: str = typer.Argument(...), + num_rows: int = typer.Argument(...), ) -> None: """Fill tables with synthetic data.""" orm_module = import_file(orm_file) ssg_module = import_file(ssg_file) - create_db_data(orm_module.metadata.sorted_tables, ssg_module.sorted_generators) + create_db_data( + orm_module.metadata.sorted_tables, ssg_module.sorted_generators, num_rows + ) @app.command() diff --git a/tests/test_create.py b/tests/test_create.py index 788d7a96..de1b163a 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -18,7 +18,7 @@ def test_create_db_data(self) -> None: ) as mock_create_engine: mock_get_settings.return_value = get_test_settings() - create_db_data([], []) + create_db_data([], [], 0) mock_populate.assert_called_once() mock_create_engine.assert_called_once() diff --git a/tests/test_main.py b/tests/test_main.py index 855ff0ec..04d6f3ae 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -151,11 +151,12 @@ def test_create_data(self) -> None: "create-data", "tests/examples/example_orm.py", "tests/examples/expected_ssg.py", + "10", ], catch_exceptions=False, ) self.assertSuccess(result) mock_create_db_data.assert_called_once_with( - example_orm.metadata.sorted_tables, expected_ssg.sorted_generators + example_orm.metadata.sorted_tables, expected_ssg.sorted_generators, 10 )