Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions sqlsynthgen/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any

from sqlalchemy import create_engine, insert
from sqlalchemy.schema import CreateSchema

from sqlsynthgen.settings import get_settings

Expand All @@ -10,21 +11,33 @@ def create_db_tables(metadata: Any) -> Any:
"""Create tables described by the sqlalchemy metadata object."""
settings = get_settings()
engine = create_engine(settings.dst_postgres_dsn)
# Create schemas, if necessary.
for table in metadata.sorted_tables:
try:
schema = table.schema
if not engine.dialect.has_schema(engine, schema=schema):
engine.execute(CreateSchema(schema, if_not_exists=True))
except AttributeError:
# This table didn't have a schema field
pass
metadata.create_all(engine)


def create_db_data(sorted_tables: list, sorted_generators: list) -> None:
def create_db_data(sorted_tables: list, sorted_generators: list, num_rows: int) -> None:
"""Connect to a database and populate it with data."""
settings = get_settings()
engine = create_engine(settings.dst_postgres_dsn)

with engine.connect() as conn:
populate(conn, sorted_tables, sorted_generators)
populate(conn, sorted_tables, sorted_generators, num_rows)


def populate(conn: Any, tables: list, generators: list) -> None:
def populate(conn: Any, tables: list, generators: list, num_rows: int) -> None:
"""Populate a database schema with dummy data."""

for table, generator in zip(tables, generators):
stmt = insert(table).values(generator(conn).__dict__)
conn.execute(stmt)
# Run all the inserts for one table in a transaction
with conn.begin():
for _ in range(num_rows):
stmt = insert(table).values(generator(conn).__dict__)
conn.execute(stmt)
5 changes: 4 additions & 1 deletion sqlsynthgen/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@ def import_file(file_path: str) -> ModuleType:
def create_data(
orm_file: str = typer.Argument(...),
ssg_file: str = typer.Argument(...),
num_rows: int = typer.Argument(...),
) -> None:
"""Fill tables with synthetic data."""
orm_module = import_file(orm_file)
ssg_module = import_file(ssg_file)
create_db_data(orm_module.metadata.sorted_tables, ssg_module.sorted_generators)
create_db_data(
orm_module.metadata.sorted_tables, ssg_module.sorted_generators, num_rows
)


@app.command()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_create_db_data(self) -> None:
) as mock_create_engine:
mock_get_settings.return_value = get_test_settings()

create_db_data([], [])
create_db_data([], [], 0)

mock_populate.assert_called_once()
mock_create_engine.assert_called_once()
Expand Down
3 changes: 2 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,12 @@ def test_create_data(self) -> None:
"create-data",
"tests/examples/example_orm.py",
"tests/examples/expected_ssg.py",
"10",
],
catch_exceptions=False,
)

self.assertSuccess(result)
mock_create_db_data.assert_called_once_with(
example_orm.metadata.sorted_tables, expected_ssg.sorted_generators
example_orm.metadata.sorted_tables, expected_ssg.sorted_generators, 10
)